In [50]:
%load_ext autoreload
%autoreload 2

from typing import List, Dict, Any
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import to_dense_adj
from data_utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
    AddAdjacencyMatrix(max_num_nodes=29),
    T.ToDevice(device=device)
])

dataset = QM9(root="./data", transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

num_node_features = dataset.num_node_features
num_edge_features = dataset.num_edge_features

In [52]:
class Encoder(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        # TODO: two graph convolutional layers (32 and 64 channeles) with identity connection (edge conditioned graph convolution)
        self.conv1 = GCNConv(in_channels=hparams["num_node_features"], out_channels=32)
        self.conv2 = GCNConv(in_channels=32, out_channels=64)
        self.fc = nn.Linear(in_features=64, out_features=128)
        # TODO: use nn.Sequential

    def forward(self, data: Data):
        x, edge_index, batch, edge_attr = data.x, data.edge_index, data.batch, data.edge_attr

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        return x
    
class Decoder(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.fcls = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.BatchNorm1d(num_features=512),
            nn.ReLU(),
        )

        n = hparams["max_node_count"]
        # the atom graph is symmetric so we only predict the upper triangular part
        upper_triangular_size = int(n * (n + 1) / 2)
        self.fc_adjacency = nn.Linear(in_features=512, out_features=upper_triangular_size)

        self.fc_node_features = nn.Linear(in_features=512, out_features=n * hparams["num_node_features"])
        self.fc_edge_features = nn.Linear(in_features=512, out_features=n * hparams["num_edge_features"])

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.fcls(z)
        adjacency_matrix = self.fc_adjacency(x)
        node_features = self.fc_node_features(x)
        edge_features = self.fc_edge_features(x)

        # upper triangular
        return adjacency_matrix, node_features, edge_features


class GraphVAE(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.encoder = Encoder(hparams=hparams)
        self.decoder = Decoder(hparams=hparams)

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        z = self.encoder(data)
        x = self.decoder(z)
        return x

In [53]:
hparams = {
    "batch_size": 32,
    "max_node_count": 29,
    "learning_rate": 1e-3,
    "epochs": 25,
    "num_node_features": num_node_features,
    "num_edge_features": num_edge_features
}
print(hparams)

{'batch_size': 32, 'max_node_count': 29, 'learning_rate': 0.001, 'epochs': 25, 'num_node_features': 5, 'num_edge_features': 4}


In [54]:
batch_size = hparams["batch_size"]

# TODO: implement reconstruction loss adjacency
# TODO: implement reconstruction loss node features
# TODO: implement reconstruction loss edge features
# TODO: add kl-loss
# TODO: implement encoder from the paper
# TODO: graph matching

dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:16], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val_small": DataLoader(val_dataset[:512], batch_size=batch_size, shuffle=False),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(validation_dataset=val_dataset, subset_count=32, batch_size=batch_size)

In [58]:
model = GraphVAE(hparams=hparams).to(device=device)
# TODO: add beta_1 = 0.5 as in the paper
optimizer = torch.optim.Adam(model.parameters(), lr=hparams["learning_rate"])
epochs = hparams["epochs"]
train_loader = dataloaders["train_tiny"]

writer = create_tensorboard_writer(experiment_name="graph-vae")

for epoch in range(epochs):
    model.train()
    for batch_index, train_batch in enumerate(tqdm(train_loader,  desc=f"Epoch {epoch + 1} Training")):
        optimizer.zero_grad()
        a_pred, f_pred, e_pred = model(train_batch)
        adj_triu_mat, batch = train_batch.adj_triu_mat, train_batch.batch

        loss = F.binary_cross_entropy_with_logits(input=adj_triu_mat, target=a_pred)
        # TODO: average over diagonal and off-diagonal separately

        loss.backward()
        optimizer.step()

        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": loss.item()}, iteration)


Epoch 1 Training: 100%|██████████| 1/1 [00:00<00:00,  4.91it/s]


0.7487122416496277


Epoch 2 Training: 100%|██████████| 1/1 [00:00<00:00, 15.32it/s]


0.7230315208435059


Epoch 3 Training: 100%|██████████| 1/1 [00:00<00:00, 16.07it/s]


0.6988722085952759


Epoch 4 Training: 100%|██████████| 1/1 [00:00<00:00, 16.55it/s]


0.6766119003295898


Epoch 5 Training: 100%|██████████| 1/1 [00:00<00:00, 13.15it/s]


0.6547343730926514


Epoch 6 Training: 100%|██████████| 1/1 [00:00<00:00, 14.25it/s]


0.6333499550819397


Epoch 7 Training: 100%|██████████| 1/1 [00:00<00:00, 12.99it/s]


0.6120489239692688


Epoch 8 Training: 100%|██████████| 1/1 [00:00<00:00, 13.96it/s]


0.5906163454055786


Epoch 9 Training: 100%|██████████| 1/1 [00:00<00:00, 13.53it/s]


0.5689801573753357


Epoch 10 Training: 100%|██████████| 1/1 [00:00<00:00, 14.16it/s]


0.5469658374786377


Epoch 11 Training: 100%|██████████| 1/1 [00:00<00:00, 13.72it/s]


0.5247527956962585


Epoch 12 Training: 100%|██████████| 1/1 [00:00<00:00, 14.02it/s]


0.5022435784339905


Epoch 13 Training: 100%|██████████| 1/1 [00:00<00:00, 13.84it/s]


0.47950586676597595


Epoch 14 Training: 100%|██████████| 1/1 [00:00<00:00, 14.10it/s]


0.45649707317352295


Epoch 15 Training: 100%|██████████| 1/1 [00:00<00:00, 13.59it/s]


0.4331735074520111


Epoch 16 Training: 100%|██████████| 1/1 [00:00<00:00,  9.31it/s]


0.40960362553596497


Epoch 17 Training: 100%|██████████| 1/1 [00:00<00:00, 15.13it/s]


0.3858155608177185


Epoch 18 Training: 100%|██████████| 1/1 [00:00<00:00, 15.17it/s]


0.3617958426475525


Epoch 19 Training: 100%|██████████| 1/1 [00:00<00:00, 16.18it/s]


0.33762118220329285


Epoch 20 Training: 100%|██████████| 1/1 [00:00<00:00, 17.29it/s]

0.31322863698005676



Epoch 21 Training: 100%|██████████| 1/1 [00:00<00:00, 15.97it/s]


0.28870704770088196


Epoch 22 Training: 100%|██████████| 1/1 [00:00<00:00, 11.87it/s]


0.2640085518360138


Epoch 23 Training: 100%|██████████| 1/1 [00:00<00:00, 14.96it/s]


0.23912642896175385


Epoch 24 Training: 100%|██████████| 1/1 [00:00<00:00, 15.92it/s]


0.21409767866134644


Epoch 25 Training: 100%|██████████| 1/1 [00:00<00:00, 13.66it/s]

0.18888774514198303



