In [2]:
%load_ext autoreload
%autoreload 2

import itertools
from tqdm import tqdm

import torch
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

from models import GraphVAE
from data_utils import *

In [3]:
# Whether hydrogen atoms are removed from every molecule graph.
drop_hydrogen = True

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Node budget handed to the matrix transforms below: 9 without hydrogens, 29 with.
max_num_nodes = 29 if not drop_hydrogen else 9

# TODO: pre-transform and store matrices to disk
# Per-sample pipeline (transform classes presumably come from data_utils):
# select targets and node features, optionally drop hydrogens, add the
# adjacency / node-attribute / edge-attribute matrices, then move to `device`.
transform_list = [
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
    *([DropQM9Hydrogen()] if drop_hydrogen else []),
    AddAdjacencyMatrix(max_num_nodes=max_num_nodes),
    AddNodeAttributeMatrix(max_num_nodes=max_num_nodes),
    AddEdgeAttributeMatrix(max_num_nodes=max_num_nodes),
    T.ToDevice(device=device),
]
transform = T.Compose(transform_list)

# The transform is passed as `transform=` (not `pre_transform=`), see the TODO above.
dataset = QM9(root="./data", transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

In [6]:
# Hyperparameters for the model and training run; the same dict is passed to
# GraphVAE and logged to TensorBoard together with the generation metrics.
hparams = {
    "batch_size": 32,
    "max_num_nodes": max_num_nodes,
    "learning_rate": 1e-3,
    "adam_beta_1": 0.5,
    "epochs": 20000,
    "num_node_features": dataset.num_node_features,
    "num_edge_features": dataset.num_edge_features,
    "latent_dim": 128,  # c in the paper
    "kl_weight": 1e-2,
    "drop_hydrogen": drop_hydrogen,
}

batch_size = hparams["batch_size"]

# Training loaders over progressively larger slices of the training split
# (handy for overfitting sanity checks), plus the full validation loader.
# NOTE(review): "train_single" slices [16:18], i.e. two samples, not one —
# confirm whether that is intended.
dataloaders = {
    "train_single": DataLoader(train_dataset[16:18], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:batch_size], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:8192], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

# Number of validation subsets requested from the helper. The value was
# previously defined here but hard-coded again in the call; it is now the
# single source of truth.
val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(
    validation_dataset=val_dataset,
    subset_count=val_subset_count,
    batch_size=batch_size,
)

In [7]:
# Build the model and optimizer from the hyperparameter dict.
graph_vae_model = GraphVAE(hparams=hparams).to(device=device)
optimizer = torch.optim.Adam(
    graph_vae_model.parameters(),
    lr=hparams["learning_rate"],
    betas=(hparams["adam_beta_1"], 0.999)
)
epochs = hparams["epochs"]

# Train on the tiny slice; validation walks through the subset loaders in a
# round-robin fashion, one subset per validation step.
train_loader = dataloaders["train_tiny"]
val_subset_loader_iterator = itertools.cycle(dataloaders["val_subsets"])

# Validate every `validation_interval` optimizer steps (and on the very first step).
validation_interval = 100

writer = create_tensorboard_writer(experiment_name="graph-vae-3")

for epoch in range(epochs):
    graph_vae_model.train()
    # leave=False clears each finished bar; otherwise every epoch leaves a line
    # behind and thousands of epochs flood the cell output.
    epoch_progress = tqdm(train_loader, desc=f"Epoch {epoch + 1} Training", leave=False)
    for batch_index, train_batch in enumerate(epoch_progress):
        optimizer.zero_grad()

        loss, recon_loss, kl_div = graph_vae_model.negative_elbo(x=train_batch)

        loss.backward()
        optimizer.step()

        # Global step counter used as the x-axis for all logged scalars.
        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": loss.item()}, iteration)
        writer.add_scalars("Reconstruction Loss", {"Training": recon_loss.item()}, iteration)
        writer.add_scalars("KL-Divergence", {"Training": kl_div.item()}, iteration)

        if (iteration + 1) % validation_interval == 0 or iteration == 0:
            graph_vae_model.eval()
            val_loss_sum = 0
            num_val_batches = 0

            # Get the next subset of the validation set
            val_loader = next(val_subset_loader_iterator)
            with torch.no_grad():
                for val_batch in val_loader:
                    batch_val_loss, _, _ = graph_vae_model.negative_elbo(x=val_batch)
                    val_loss_sum += batch_val_loss
                    num_val_batches += 1

            # Skip logging when the subset yielded no batches: dividing by zero
            # (and calling .item() on the int 0) would abort the training run.
            if num_val_batches > 0:
                val_loss = val_loss_sum / num_val_batches
                writer.add_scalars("Loss", {"Validation": val_loss.item()}, iteration)

            graph_vae_model.train()

Epoch 1 Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 1/1 [00:00<00:00,  2.04it/s]
Epoch 2 Training: 100%|██████████| 1/1 [00:00<00:00, 23.77it/s]
Epoch 3 Training: 100%|██████████| 1/1 [00:00<00:00, 23.03it/s]
Epoch 4 Training: 100%|██████████| 1/1 [00:00<00:00, 24.23it/s]
Epoch 5 Training: 100%|██████████| 1/1 [00:00<00:00, 23.07it/s]
Epoch 6 Training: 100%|██████████| 1/1 [00:00<00:00, 23.54it/s]
Epoch 7 Training: 100%|██████████| 1/1 [00:00<00:00, 22.97it/s]
Epoch 8 Training: 100%|██████████| 1/1 [00:00<00:00, 22.22it/s]
Epoch 9 Training: 100%|██████████| 1/1 [00:00<00:00, 20.75it/s]
Epoch 10 Training: 100%|██████████| 1/1 [00:00<00:00, 21.26it/s]
Epoch 11 Training: 100%|██████████| 1/1 [00:00<00:00, 21.82it/s]
Epoch 12 Training: 100%|██████████| 1/1 [00:00<00:00, 22.77it/s]
Epoch 13 Training: 100%|██████████| 1/1 [00:00<00:00, 22.46it/s]
Epoch 14 Training: 100%|██████████| 1/1 [00:00<00:00, 23.15it/s]
Epoch 15 Training: 100%|██████████| 1/1 [00:00<00:00, 24.50it/s]
Epoch 16 Training: 100%|██████████

KeyboardInterrupt: 

In [8]:
# Visualize molecule reconstruction of the first training batch and collect the
# SMILES strings of its molecules in `train_mol_smiles`.
# NOTE(review): only the first batch is visited. That covers the whole training
# set for the "train_tiny" loader (a single batch), but with a larger loader
# `train_mol_smiles` would be incomplete — confirm before using it for novelty.
graph_vae_model.eval()
train_mol_smiles = set()
for batch in train_loader:
    # Inference only: no gradients are needed for the reconstructions.
    with torch.no_grad():
        # Iterate over the graphs actually present in the batch; it may contain
        # fewer than `batch_size` samples, and indexing past the end would fail.
        for sample_index in tqdm(range(batch.num_graphs)):
            sample = batch[sample_index]

            mol = graph_to_mol(data=sample, includes_h=not drop_hydrogen, validate=True)
            train_mol_smiles.add(Chem.MolToSmiles(mol))

            writer.add_image('Input', molecule_graph_data_to_image(sample, includes_h=not drop_hydrogen), global_step=sample_index, dataformats="NCHW")
            reconstructed_sample = graph_vae_model.output_to_graph(graph_vae_model(sample))
            writer.add_image('Reconstruction', molecule_graph_data_to_image(reconstructed_sample, includes_h=not drop_hydrogen), global_step=sample_index, dataformats="NCHW")
    break

100%|██████████| 32/32 [00:01<00:00, 29.89it/s]


In [13]:
# Sample molecules from the model and report validity / uniqueness / novelty.
from rdkit import RDLogger

# Silence RDKit messages below CRITICAL; invalid samples are expected here.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

num_samples = 1000
num_valid_mols = 0

gen_mol_smiles = set()

# Make the cell independent of the mode an earlier cell left the model in, and
# skip autograd bookkeeping while sampling.
graph_vae_model.eval()
with torch.no_grad():
    z, x = graph_vae_model.sample(num_samples=num_samples, device=device)

for i in tqdm(range(num_samples)):
    # Slice with i:i+1 so each of the three output tensors keeps its leading dimension.
    sample_matrices = (x[0][i:i+1], x[1][i:i+1], x[2][i:i+1])
    sample_graph = graph_vae_model.output_to_graph(x=sample_matrices)

    try:
        mol = graph_to_mol(data=sample_graph, includes_h=not drop_hydrogen, validate=True)
        # Convert before counting, so a failure in the SMILES conversion does not
        # leave the sample counted as valid.
        smiles = Chem.MolToSmiles(mol)
        num_valid_mols += 1
        gen_mol_smiles.add(smiles)
    except Exception:
        # Any failure above marks the sample as invalid; rebuild it without
        # validation so it can still be drawn.
        mol = graph_to_mol(data=sample_graph, includes_h=not drop_hydrogen, validate=False)

    writer.add_image('Generated', mol_to_image_tensor(mol=mol), global_step=i, dataformats="NCHW")

# Novel = generated SMILES that do not appear in `train_mol_smiles`.
non_novel_mols = train_mol_smiles.intersection(gen_mol_smiles)
num_unique_mols = len(gen_mol_smiles)
novel_mol_count = num_unique_mols - len(non_novel_mols)
print(novel_mol_count)
print(num_unique_mols)

# Guard the ratios: when no valid molecule is produced, both denominators are
# zero and the metrics are reported as 0.0 instead of raising ZeroDivisionError.
metrics = {
    "Validity": num_valid_mols / num_samples,
    "Uniqueness": num_unique_mols / num_valid_mols if num_valid_mols else 0.0,
    "Novelty": novel_mol_count / num_unique_mols if num_unique_mols else 0.0
}
print(metrics)
writer.add_hparams(hparam_dict=hparams, metric_dict=metrics)

100%|██████████| 1000/1000 [00:14<00:00, 69.93it/s]

186
214
{'Validity': 0.857, 'Uniqueness': 0.24970828471411902, 'Novelty': 0.8691588785046729}





In [61]:
# Count the trainable parameters of the encoder on a freshly built model.
# A separate name is used on purpose: rebinding `graph_vae_model` here would
# replace the trained model with an untrained encoder for every cell run afterwards.
param_count_encoder = GraphVAE(hparams=hparams).encoder
sum(p.numel() for p in param_count_encoder.parameters() if p.requires_grad)

43616