In [34]:
%load_ext autoreload
%autoreload 2

from typing import List, Dict, Any
import itertools
import random
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import dense_to_sparse, remove_self_loops

from models import GraphVAE
from data_utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
# --- Dataset configuration -------------------------------------------------
# When True, DropQM9Hydrogen is inserted into the transform pipeline below.
drop_hydrogen = True

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Size passed to the matrix-building transforms.
# NOTE(review): 9 / 29 are presumably the largest QM9 molecule sizes
# without / with hydrogens — confirm.
if drop_hydrogen:
    max_num_nodes = 9
else:
    max_num_nodes = 29

# TODO: pre-transform and store matrices to disk
# The pipeline is built in one literal; the hydrogen-removal step is spliced in
# between the selection transforms and the matrix transforms only when enabled.
transform_list = [
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
    *([DropQM9Hydrogen()] if drop_hydrogen else []),
    AddAdjacencyMatrix(max_num_nodes=max_num_nodes),
    AddNodeAttributeMatrix(max_num_nodes=max_num_nodes),
    AddEdgeAttributeMatrix(max_num_nodes=max_num_nodes),
    T.ToDevice(device=device),
]
transform = T.Compose(transform_list)

# The composed transform is handed to QM9 as its `transform` argument.
dataset = QM9(root="./data", transform=transform)

# Split into train / validation / test using the project helper.
train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

In [32]:
# Hyperparameters handed to GraphVAE and reused by the optimizer / training loop.
hparams = {
    "batch_size": 32,
    "max_num_nodes": max_num_nodes,
    "learning_rate": 1e-3,
    "adam_beta_1": 0.5,
    "epochs": 800,
    "num_node_features": dataset.num_node_features,
    "num_edge_features": dataset.num_edge_features,
    "latent_dim": 128,  # c in the paper
    "kl_weight": 1.0,
    "drop_hydrogen": drop_hydrogen,
}

batch_size = hparams["batch_size"]

# Training loaders of increasing size, plus the full validation loader.
dataloaders = {
    # two samples only (indices 16 and 17 of the training split)
    "train_single": DataLoader(train_dataset[16:18], batch_size=batch_size, shuffle=True),
    # exactly one batch worth of samples
    "train_tiny": DataLoader(train_dataset[:batch_size], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

# Number of validation subset loaders to create. The value is passed through
# (instead of repeating the literal) so that changing it here takes effect.
val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(
    validation_dataset=val_dataset,
    subset_count=val_subset_count,
    batch_size=batch_size,
)

In [33]:
# Build the model and its optimizer from the hyperparameter dictionary.
graph_vae_model = GraphVAE(hparams=hparams).to(device=device)
optimizer = torch.optim.Adam(
    graph_vae_model.parameters(),
    lr=hparams["learning_rate"],
    betas=(hparams["adam_beta_1"], 0.999)
)
epochs = hparams["epochs"]

train_loader = dataloaders["train_tiny"]
# Cycle endlessly through the validation subset loaders so that each
# validation step uses the next subset.
val_subset_loader_iterator = itertools.cycle(dataloaders["val_subsets"])

# Validate every `validation_interval` training iterations.
validation_interval = 100

writer = create_tensorboard_writer(experiment_name="graph-vae-2")

for epoch in range(epochs):
    graph_vae_model.train()
    # leave=False clears each finished bar; otherwise every epoch leaves one
    # line of output behind (hundreds of lines when the loader has one batch).
    epoch_progress = tqdm(train_loader, desc=f"Epoch {epoch + 1} Training", leave=False)
    for batch_index, train_batch in enumerate(epoch_progress):
        optimizer.zero_grad()
        loss = graph_vae_model.negative_elbo(x=train_batch)

        loss.backward()
        optimizer.step()

        # Global step counter across epochs, used as the x-axis for logging.
        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": loss.item()}, iteration)

        # Validate on the very first iteration and then at every interval.
        if (iteration + 1) % validation_interval == 0 or iteration == 0:
            graph_vae_model.eval()

            # Get the next subset of the validation set
            val_loader = next(val_subset_loader_iterator)
            with torch.no_grad():
                val_losses = [
                    graph_vae_model.negative_elbo(x=val_batch).item()
                    for val_batch in val_loader
                ]

            # Guard against an empty subset: the mean is undefined and the
            # original division by len(val_loader) would raise.
            if val_losses:
                val_loss = sum(val_losses) / len(val_losses)
                writer.add_scalars("Loss", {"Validation": val_loss}, iteration)

            graph_vae_model.train()

# visualize molecule reconstruction of the first training batch
graph_vae_model.eval()
for batch in train_loader:
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        # Iterate over the graphs actually present in the batch; a loader with
        # fewer than `batch_size` samples would otherwise be indexed out of range.
        for sample_index in tqdm(range(batch.num_graphs)):
            sample = batch[sample_index]
            writer.add_image(
                'Input',
                molecule_graph_data_to_image(sample, includes_h=not drop_hydrogen),
                global_step=sample_index,
                dataformats="NCHW",
            )
            reconstructed_sample = graph_vae_model.output_to_graph(graph_vae_model(sample))
            writer.add_image(
                'Reconstruction',
                molecule_graph_data_to_image(reconstructed_sample, includes_h=not drop_hydrogen),
                global_step=sample_index,
                dataformats="NCHW",
            )
    # Only the first batch is visualized.
    break

# Push pending events to disk so the images show up in TensorBoard.
# NOTE(review): assumes `writer` is a TensorBoard SummaryWriter (it is used via
# add_scalars / add_image) — confirm create_tensorboard_writer's return type.
writer.flush()

Epoch 1 Training: 100%|██████████| 1/1 [00:02<00:00,  2.83s/it]
Epoch 2 Training: 100%|██████████| 1/1 [00:00<00:00,  4.68it/s]
Epoch 3 Training: 100%|██████████| 1/1 [00:00<00:00,  4.51it/s]
Epoch 4 Training: 100%|██████████| 1/1 [00:00<00:00,  4.31it/s]
Epoch 5 Training: 100%|██████████| 1/1 [00:00<00:00,  5.80it/s]
Epoch 6 Training: 100%|██████████| 1/1 [00:00<00:00,  4.79it/s]
Epoch 7 Training: 100%|██████████| 1/1 [00:00<00:00,  4.73it/s]
Epoch 8 Training: 100%|██████████| 1/1 [00:00<00:00,  3.70it/s]
Epoch 9 Training: 100%|██████████| 1/1 [00:00<00:00,  4.63it/s]
Epoch 10 Training: 100%|██████████| 1/1 [00:00<00:00,  4.68it/s]
Epoch 11 Training: 100%|██████████| 1/1 [00:00<00:00,  4.56it/s]
Epoch 12 Training: 100%|██████████| 1/1 [00:00<00:00,  4.63it/s]
Epoch 13 Training: 100%|██████████| 1/1 [00:00<00:00,  3.97it/s]
Epoch 14 Training: 100%|██████████| 1/1 [00:00<00:00,  4.72it/s]
Epoch 15 Training: 100%|██████████| 1/1 [00:00<00:00,  3.73it/s]
Epoch 16 Training: 100%|██████████

In [None]:
# TODO: evaluate validity
# TODO: generate 100 molecules