In [50]:
%load_ext autoreload
%autoreload 2

from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import to_dense_adj
from tqdm import tqdm

from data_utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-sample transform pipeline handed to the QM9 dataset. The first three
# transforms come from data_utils; judging by their names and arguments they
# select the HOMO/LUMO targets, keep only the atom-type node features and
# attach an adjacency matrix for up to 29 nodes -- TODO confirm in data_utils.
# The last step moves every sample onto `device`.
transform = T.Compose(
    [
        SelectQM9TargetProperties(properties=["homo", "lumo"]),
        SelectQM9NodeFeatures(features=["atom_type"]),
        AddAdjacencyMatrix(max_num_nodes=29),
        T.ToDevice(device=device),
    ]
)

dataset = QM9(root="./data", transform=transform)
train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

# Feature widths reported by the dataset; they are stored in `hparams` later
# and used to size the model layers.
num_node_features, num_edge_features = dataset.num_node_features, dataset.num_edge_features

In [52]:
class Encoder(nn.Module):
    """Graph encoder: two GCN layers, mean pooling and a linear projection.

    Args:
        hparams: Hyperparameter dictionary. Only ``"num_node_features"`` is
            read here; it sets the input width of the first convolution.
    """

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        # TODO: two graph convolutional layers (32 and 64 channels) with identity connection (edge conditioned graph convolution)
        self.conv1 = GCNConv(in_channels=hparams["num_node_features"], out_channels=32)
        self.conv2 = GCNConv(in_channels=32, out_channels=64)
        self.fc = nn.Linear(in_features=64, out_features=128)
        # TODO: use nn.Sequential

    def forward(self, data: Data) -> torch.Tensor:
        """Encode a (batched) graph into one vector per graph.

        Args:
            data: Graph or batch of graphs providing ``x``, ``edge_index`` and
                ``batch``.

        Returns:
            The output of the final linear layer (128 features per pooled
            graph).
        """
        # Edge attributes are intentionally not read: the GCN layers used here
        # only consume node features and connectivity (see the TODO in
        # __init__ about switching to an edge-conditioned convolution).
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # No non-linearity is applied between the second convolution and pooling.
        x = self.conv2(x, edge_index)
        # Average node embeddings per graph, using `batch` as the node-to-graph assignment.
        x = global_mean_pool(x, batch)
        return self.fc(x)
    
class Decoder(nn.Module):
    """MLP decoder mapping a latent vector to flattened graph predictions.

    A shared trunk (128 -> 128 -> 256 -> 512, each followed by batch
    normalisation and ReLU) feeds three independent linear heads for the
    adjacency matrix, the node features and the edge features.

    Note: the trunk contains ``BatchNorm1d`` layers, so in training mode it
    must be fed more than one sample per batch.

    Args:
        hparams: Hyperparameter dictionary. Reads ``"max_node_count"``,
            ``"num_node_features"`` and ``"num_edge_features"``.
    """

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.fcls = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.BatchNorm1d(num_features=512),
            nn.ReLU(),
        )

        n = hparams["max_node_count"]
        # The atom graph is symmetric, so only the upper triangular part
        # (diagonal included) is predicted: n * (n + 1) / 2 entries. Integer
        # division keeps the size an exact int without a float round-trip.
        upper_triangular_size = n * (n + 1) // 2
        self.fc_adjacency = nn.Linear(in_features=512, out_features=upper_triangular_size)

        self.fc_node_features = nn.Linear(in_features=512, out_features=n * hparams["num_node_features"])
        # NOTE(review): this head emits n * num_edge_features values, i.e. one
        # edge-feature vector per node slot rather than per node pair --
        # confirm this matches the intended edge target layout.
        self.fc_edge_features = nn.Linear(in_features=512, out_features=n * hparams["num_edge_features"])

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode latent vectors into raw graph predictions.

        Args:
            z: Latent vectors with 128 features in the last dimension.

        Returns:
            A tuple ``(adjacency_matrix, node_features, edge_features)`` of
            flattened, unnormalised linear outputs with
            ``n * (n + 1) / 2``, ``n * num_node_features`` and
            ``n * num_edge_features`` features respectively, where ``n`` is
            ``max_node_count``. The adjacency output is the flattened upper
            triangular part only.
        """
        hidden = self.fcls(z)
        adjacency_matrix = self.fc_adjacency(hidden)
        node_features = self.fc_node_features(hidden)
        edge_features = self.fc_edge_features(hidden)

        return adjacency_matrix, node_features, edge_features


class GraphVAE(nn.Module):
    """Encoder/decoder pair operating on graphs.

    The forward pass encodes the input and feeds the resulting vector
    straight into the decoder.
    """

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        # Both sub-modules are configured from the same hyperparameter dict.
        self.encoder = Encoder(hparams=hparams)
        self.decoder = Decoder(hparams=hparams)

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``data`` and return the decoder's output for that encoding."""
        latent = self.encoder(data)
        return self.decoder(latent)

In [53]:
# Fixed training / model settings.
hparams: Dict[str, Any] = {
    "batch_size": 32,
    "max_node_count": 29,
    "learning_rate": 1e-3,
    "epochs": 25,
}
# Feature widths are taken from the loaded dataset rather than hard-coded.
hparams.update(
    {
        "num_node_features": num_node_features,
        "num_edge_features": num_edge_features,
    }
)
print(hparams)

{'batch_size': 32, 'max_node_count': 29, 'learning_rate': 0.001, 'epochs': 25, 'num_node_features': 5, 'num_edge_features': 4}


In [54]:
batch_size = hparams["batch_size"]

# TODO: implement reconstruction loss adjacency
# TODO: implement reconstruction loss node features
# TODO: implement reconstruction loss edge features
# TODO: add kl-loss
# TODO: implement encoder from the paper
# TODO: graph matching

# Loaders keyed by name. Besides the full training and validation sets,
# smaller leading slices of each are exposed; training loaders shuffle,
# validation loaders do not.
dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:16], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val_small": DataLoader(val_dataset[:512], batch_size=batch_size, shuffle=False),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

# Number of validation subsets to build. The variable is passed to the helper
# (instead of repeating the literal) so the two cannot drift apart.
val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(
    validation_dataset=val_dataset,
    subset_count=val_subset_count,
    batch_size=batch_size,
)

In [55]:
# Build the model and optimiser, then run a debugging pass.
# NOTE: this loop is currently a scaffold, not a training loop: no loss is
# computed and neither backward() nor optimizer.step() is called. It performs
# one forward pass on the first batch, prints the predicted and target
# adjacency vectors for the first graph, and then stops.
model = GraphVAE(hparams=hparams).to(device=device)
# TODO: add beta_1 = 0.5 as in the paper
optimizer = torch.optim.Adam(model.parameters(), lr=hparams["learning_rate"])
epochs = hparams["epochs"]
# Loader built from the first 16 training samples (see `dataloaders`).
train_loader = dataloaders["train_tiny"]

for epoch in range(epochs):
    model.train()
    for batch_index, train_batch in enumerate(tqdm(train_loader,  desc=f"Epoch {epoch + 1} Training")):
        optimizer.zero_grad()
        # The model returns the decoder's three outputs: adjacency, node
        # features and edge features.
        out = model(train_batch)
        a, f, e = out

        # Keep only the first row of the adjacency prediction for inspection.
        a = a[0]

        # `adj_triu_mat` is read from the batch; presumably it is attached by
        # the AddAdjacencyMatrix transform -- TODO confirm in data_utils.
        adj_triu_mat, batch = train_batch.adj_triu_mat, train_batch.batch

        # assumes the first dimension indexes graphs in the batch -- TODO confirm
        adj_triu_mat = adj_triu_mat[0]

        print("pred")
        print(a)

        print("target")
        print(adj_triu_mat)

        
        # Debug only: stop after the first batch ...
        break
    # ... and after the first epoch.
    break


Epoch 1 Training:   0%|          | 0/1 [00:00<?, ?it/s]

pred
tensor([ 3.2702e-02, -3.6735e-02,  5.0752e-02, -4.9446e-01,  2.0766e-02,
         1.9365e-01,  1.3127e+00, -6.7586e-01,  6.4589e-01,  3.2640e-01,
        -7.4867e-01,  6.7585e-01,  9.9469e-01, -5.3812e-01, -9.2381e-01,
        -4.2445e-01, -2.6701e-01, -5.7261e-01,  3.9967e-01,  5.0608e-01,
        -1.6453e-02,  3.1527e-01, -1.2763e+00,  9.1606e-01, -7.6159e-02,
        -5.0747e-01, -3.6732e-01,  1.8161e-01, -3.5841e-01, -5.4424e-01,
         2.6100e-01, -5.4432e-01,  9.9129e-01,  5.6860e-02,  7.5466e-01,
        -6.7616e-01,  5.3304e-01, -1.0713e+00,  1.8274e+00,  1.0923e+00,
         1.5836e+00, -9.4519e-01,  1.6245e-01, -6.4333e-02, -1.8237e+00,
         6.1454e-01,  7.7781e-02,  6.0585e-01,  2.1424e-01,  8.6889e-02,
         1.1978e+00,  1.1860e+00,  5.7142e-04,  5.7663e-02, -4.4496e-01,
        -8.9290e-01,  1.6065e-01, -1.1018e+00, -5.2224e-01, -3.4467e-01,
         8.0591e-01,  1.3219e+00, -1.5670e+00, -1.3879e-01,  7.1336e-01,
         9.2058e-01, -6.3079e-01,  9.5702e-01,


