In [1]:
%load_ext autoreload
%autoreload 2

from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from tqdm import tqdm

from data_utils import *

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transform pipeline handed to the QM9 dataset. The two Select* helpers are not
# defined in this notebook (presumably they come from the data_utils star import);
# the last step moves the data onto `device`.
transform = T.Compose(
    [
        SelectQM9TargetProperties(properties=["homo", "lumo"]),
        SelectQM9NodeFeatures(features=["atom_type"]),
        T.ToDevice(device=device),
    ]
)

dataset = QM9(root="./data", transform=transform)

# Train / validation / test datasets produced by the project helper.
train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

# Feature widths reported by the dataset; the hyperparameter cell forwards
# them to the model.
num_node_features, num_edge_features = (
    dataset.num_node_features,
    dataset.num_edge_features,
)

In [29]:
class Encoder(nn.Module):
    """Graph encoder: two GCN layers, mean pooling per graph, then a linear projection.

    Args:
        hparams: Hyperparameter dictionary. Only ``"num_node_features"`` is read
            here; it sets the input width of the first graph convolution.
    """

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        # TODO: two graph convolutional layers (32 and 64 channels) with identity connection (edge conditioned graph convolution)
        self.conv1 = GCNConv(in_channels=hparams["num_node_features"], out_channels=32)
        self.conv2 = GCNConv(in_channels=32, out_channels=64)
        # Maps the pooled 64-wide graph vector to the 128-wide code consumed by
        # the first Linear layer of Decoder.
        self.fc = nn.Linear(in_features=64, out_features=128)
        # TODO: use nn.Sequential

    def forward(self, data):
        """Encode a batch of graphs.

        Args:
            data: Batch object exposing ``x``, ``edge_index`` and ``batch``.
                Edge attributes are not used by this encoder.

        Returns:
            The output of ``self.fc`` applied to the pooled node embeddings
            (last dimension 128).
        """
        x = F.relu(self.conv1(data.x, data.edge_index))
        # No activation after the second convolution; its output is pooled directly.
        x = self.conv2(x, data.edge_index)
        # Collapse node embeddings into one vector per graph using the batch assignment.
        x = global_mean_pool(x, data.batch)
        return self.fc(x)
    
class Decoder(nn.Module):
    """MLP decoder mapping a 128-wide code to flat adjacency, node and edge outputs.

    Args:
        hparams: Hyperparameter dictionary. Reads ``"max_node_count"``,
            ``"num_node_features"`` and ``"num_edge_features"`` to size the
            three output heads.
    """

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        # Shared trunk 128 -> 128 -> 256 -> 512; every Linear is followed by
        # BatchNorm1d and ReLU.
        # NOTE(review): BatchNorm1d in training mode rejects a batch with a single
        # sample, which the "train_single" loader would produce - confirm before using it.
        self.fcls = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.BatchNorm1d(num_features=512),
            nn.ReLU(),
        )

        n = hparams["max_node_count"]
        # the atom graph is symmetric so we only predict the upper triangular part
        # (diagonal included); floor division keeps the count an exact integer.
        upper_triangular_size = n * (n + 1) // 2
        self.fc_adjacency = nn.Linear(in_features=512, out_features=upper_triangular_size)

        self.fc_node_features = nn.Linear(in_features=512, out_features=n * hparams["num_node_features"])
        # NOTE(review): this head is sized per node (n * num_edge_features), not per
        # node pair - confirm that is intended for edge features.
        self.fc_edge_features = nn.Linear(in_features=512, out_features=n * hparams["num_edge_features"])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode a batch of codes.

        Args:
            x: Tensor whose last dimension is 128 (input width of the trunk).

        Returns:
            Tuple ``(adjacency_matrix, node_features, edge_features)``: the raw,
            flat outputs of the three linear heads (no activation is applied).
        """
        x = self.fcls(x)

        adjacency_matrix = self.fc_adjacency(x)
        node_features = self.fc_node_features(x)
        edge_features = self.fc_edge_features(x)

        return adjacency_matrix, node_features, edge_features


class GraphVAE(nn.Module):
    """Chains this notebook's Encoder and Decoder into a single module."""

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        # Both sub-modules are configured from the same hyperparameter dictionary.
        self.encoder = Encoder(hparams=hparams)
        self.decoder = Decoder(hparams=hparams)

    def forward(self, data):
        """Encode ``data`` and return whatever the decoder produces for that code."""
        return self.decoder(self.encoder(data))

In [13]:
# Single source of truth for the run: loader/optimizer settings plus the sizes
# the model classes read (max_node_count and the two feature widths).
hparams = dict(
    batch_size=32,
    max_node_count=29,
    learning_rate=1e-3,
    epochs=25,
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
)
print(hparams)

{'batch_size': 32, 'max_node_count': 29, 'learning_rate': 0.001, 'epochs': 25, 'num_node_features': 5, 'num_edge_features': 4}


In [11]:
batch_size = hparams["batch_size"]

# TODO: implement reconstruction loss
# TODO: add KL loss
# TODO: implement encoder from the paper
# TODO: graph matching

# Named loaders. The "train_*" entries are progressively larger slices of the
# training split (1, 16 and 4096 samples, then the full split) and are shuffled;
# the validation loaders keep a fixed order.
dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:16], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val_small": DataLoader(val_dataset[:512], batch_size=batch_size, shuffle=False),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

# Pass the named count instead of repeating the literal so the two cannot drift apart.
val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(
    validation_dataset=val_dataset,
    subset_count=val_subset_count,
    batch_size=batch_size,
)

In [30]:
model = GraphVAE(hparams=hparams).to(device=device)
# TODO: add beta from paper
optimizer = torch.optim.Adam(model.parameters(), lr=hparams["learning_rate"])
epochs = hparams["epochs"]
train_loader = dataloaders["train_tiny"]

# Forward-pass smoke test only: there is no loss, backward pass or optimizer
# step yet. Both loops stop after the first batch of the first epoch, so exactly
# one adjacency output and its shape are printed.
for epoch in range(epochs):
    model.train()
    for train_batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} Training"):
        optimizer.zero_grad()
        # The model returns the decoder's (adjacency, node features, edge features) tuple.
        adjacency_out, node_features_out, edge_features_out = model(train_batch)

        print(adjacency_out)
        print(adjacency_out.shape)
        break
    break


Epoch 1 Training:   0%|          | 0/1 [00:00<?, ?it/s]

tensor([[-0.1027, -0.1559, -0.3459,  ..., -0.0055, -0.2288,  0.0569],
        [-0.1306, -0.2203,  0.0454,  ..., -0.4390,  0.1434,  0.3531],
        [-0.2077, -0.5292, -0.2885,  ..., -0.2427, -0.4502,  0.1299],
        ...,
        [-0.1218, -0.2589,  0.0697,  ...,  0.5250, -0.0473,  0.4270],
        [-0.0907, -0.1220, -0.1019,  ...,  0.3809,  0.0100,  0.2199],
        [-0.2599, -0.4633, -0.3692,  ..., -0.0621, -0.4965,  0.3781]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
torch.Size([16, 435])



