In [19]:
%load_ext autoreload
%autoreload 2

import itertools
import datetime
import os
import shutil

from tqdm import tqdm
import torch
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

from rdkit import RDLogger
# Raise the RDKit logger threshold to CRITICAL so lower-severity RDKit messages are suppressed
# (presumably to hide sanitization warnings emitted for invalid generated molecules).
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

from graph_vae.vae import GraphVAE
# NOTE(review): star import. Names used later in this notebook but not imported anywhere else
# (SelectQM9TargetProperties, SelectQM9NodeFeatures, DropQM9Hydrogen, Add*Matrix, qm9_pre_filter,
# create_qm9_data_split, create_validation_subset_loaders, create_tensorboard_writer,
# graph_to_mol, mol_to_image_tensor and RDKit's `Chem`) presumably come from here.
# TODO: replace with explicit imports.
from data_utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
# args
include_hydrogen = False    # keep explicit hydrogens (graphs padded to 29 nodes) or drop them (padded to 9)
use_pre_transform = True    # bake the feature transforms into the processed files instead of applying them on every access
refresh_data_cache = True   # wipe data/processed and rebuild it (needed after pre_transform / pre_filter changes)

In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# root directory of the QM9 download; PyG keeps the processed tensors in `<root>/processed`
data_root = "./data"

# per-graph feature selection / encoding
transform_list = [
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
]
if not include_hydrogen:
    transform_list.append(DropQM9Hydrogen())

# dense matrices are padded to this many nodes: 29 with hydrogens, 9 without
max_num_nodes = 29 if include_hydrogen else 9
transform_list += [
    AddAdjacencyMatrix(max_num_nodes=max_num_nodes),
    AddNodeAttributeMatrix(max_num_nodes=max_num_nodes),
    AddEdgeAttributeMatrix(max_num_nodes=max_num_nodes),
    # DropAttributes(attributes=["z", "pos", "idx", "name"]),
]

if use_pre_transform:
    # transforms run once at processing time; only device placement happens on each access
    pre_transform = T.Compose(transform_list)
    transform = T.ToDevice(device=device)
else:
    pre_transform = None
    transform = T.Compose(transform_list + [T.ToDevice(device=device)])

# note: when the pre_filter or pre_transform is changed, the processed folder must be deleted to update the dataset.
if refresh_data_cache:
    # Delete the processed files BEFORE constructing the dataset, so raw data is processed exactly once.
    # (Constructing first and deleting afterwards could process the whole dataset twice.)
    shutil.rmtree(os.path.join(data_root, "processed"), ignore_errors=True)

dataset = QM9(root=data_root, pre_transform=pre_transform, pre_filter=qm9_pre_filter, transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

Processing...
100%|██████████| 133885/133885 [04:36<00:00, 484.34it/s]
Done!


In [30]:
# model / optimisation hyper-parameters (also logged to TensorBoard alongside the evaluation metrics)
hparams = dict(
    batch_size=32,
    max_num_nodes=max_num_nodes,
    learning_rate=1e-3,
    adam_beta_1=0.5,
    epochs=1,
    num_node_features=dataset.num_node_features,
    num_edge_features=dataset.num_edge_features,
    latent_dim=128,  # c in the paper
    kl_weight=1e-2,
    include_hydrogen=include_hydrogen,
)

# optional checkpoint file to resume training from (None: start from scratch)
in_checkpoint = None
# optional cap on the number of training graphs (None: use the whole training split)
train_sample_limit = None

In [31]:
batch_size = hparams["batch_size"]

# optionally restrict training to the first `train_sample_limit` graphs (for quick experiments)
train_source = train_dataset if train_sample_limit is None else train_dataset[:train_sample_limit]

dataloaders = {
    # full validation split, used for the final evaluation
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
    "train": DataLoader(train_source, batch_size=batch_size, shuffle=True),
}

# Loaders over `val_subset_count` validation subsets (presumably disjoint slices of the split);
# the training loop cycles through them so each periodic validation step stays cheap.
val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(
    validation_dataset=val_dataset,
    subset_count=val_subset_count,
    batch_size=batch_size,
)

In [32]:
def mean_validation_metrics(model, loader):
    """Average loss (-ELBO), ELBO and reconstruction loss of `model` over every batch in `loader`.

    Runs under `torch.no_grad()`; switching the model between eval()/train() is the caller's job.

    Returns:
        tuple of Python floats ``(loss, elbo, recon_loss)``, each averaged over batches.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    elbo_total = 0.0
    recon_total = 0.0
    batch_count = 0
    with torch.no_grad():
        for batch in loader:
            elbo, recon_loss = model.elbo(x=batch)
            elbo_total += elbo.item()
            recon_total += recon_loss.item()
            batch_count += 1
    if batch_count == 0:
        raise ValueError("validation subset loader yielded no batches")
    elbo_mean = elbo_total / batch_count
    return -elbo_mean, elbo_mean, recon_total / batch_count


# create checkpoint dir and unique filename
os.makedirs("./checkpoints/", exist_ok=True)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
out_checkpoint = f"./checkpoints/graph_vae_{timestamp}.pt"

# setup model and optimizer
graph_vae_model = GraphVAE(hparams=hparams).to(device=device)
optimizer = torch.optim.Adam(
    graph_vae_model.parameters(),
    lr=hparams["learning_rate"],
    betas=(hparams["adam_beta_1"], 0.999)
)

# load checkpoint (resume training)
if in_checkpoint is not None:
    # map_location lets a checkpoint written on a GPU machine be resumed on CPU (and vice versa)
    checkpoint = torch.load(in_checkpoint, map_location=device)
    graph_vae_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
else:
    start_epoch = 0

# get dataloaders
train_loader = dataloaders["train"]
# each validation step uses the next validation subset, wrapping around indefinitely
val_subset_loader_iterator = itertools.cycle(dataloaders["val_subsets"])

# create tensorboard summary writer
writer = create_tensorboard_writer(experiment_name="graph_vae")

validation_interval = 100  # run a validation step every this many training iterations
epochs = hparams["epochs"]

for epoch in range(start_epoch, epochs):
    graph_vae_model.train()
    for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
        optimizer.zero_grad()

        train_elbo, train_recon_loss = graph_vae_model.elbo(x=train_batch)

        # maximizing the ELBO is done by minimizing its negation
        train_loss = -train_elbo
        train_loss.backward()
        optimizer.step()

        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": train_loss.item()}, iteration)
        writer.add_scalars("ELBO", {"Training": train_elbo.item()}, iteration)
        writer.add_scalars("Reconstruction Loss", {"Training": train_recon_loss.item()}, iteration)

        # validate on the very first iteration and then every `validation_interval` iterations
        if (iteration + 1) % validation_interval == 0 or iteration == 0:
            graph_vae_model.eval()
            # all three validation metrics are averaged over the whole subset
            # (previously the reconstruction loss was only the last batch's value)
            val_loss, val_elbo, val_recon_loss = mean_validation_metrics(
                graph_vae_model, next(val_subset_loader_iterator)
            )
            writer.add_scalars("Loss", {"Validation": val_loss}, iteration)
            writer.add_scalars("ELBO", {"Validation": val_elbo}, iteration)
            writer.add_scalars("Reconstruction Loss", {"Validation": val_recon_loss}, iteration)

            graph_vae_model.train()

    # overwrite this run's checkpoint at the end of every epoch
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': graph_vae_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        out_checkpoint,
    )

Epoch 1 Training:   0%|          | 0/3202 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 3202/3202 [01:30<00:00, 35.37it/s]


## Evaluation

In [13]:
graph_vae_model.eval()

# Copy instead of aliasing: the logging-only entries added here must not leak into `hparams`,
# which is the dict handed to GraphVAE when the model is (re)built.
log_hparams = dict(hparams)
log_hparams.update({
    "Encoder Parameter Count": sum(p.numel() for p in graph_vae_model.encoder.parameters() if p.requires_grad),
    "Decoder Parameter Count": sum(p.numel() for p in graph_vae_model.decoder.parameters() if p.requires_grad),
})

# evaluate average ELBO and reconstruction log-likelihood on the full validation set
val_loader = dataloaders["val"]
val_elbo_sum = 0.0
val_log_likelihood_sum = 0.0
with torch.no_grad():  # pure evaluation: no autograd graph needed
    for val_batch in tqdm(val_loader, desc="Evaluating Reconstruction Performance..."):
        val_elbo, val_recon_loss = graph_vae_model.elbo(x=val_batch)
        val_elbo_sum += val_elbo.item()
        # the negated reconstruction loss is reported as the log-likelihood
        val_log_likelihood_sum -= val_recon_loss.item()

# per-batch averages, stored as plain floats for TensorBoard's hparams view
metrics = {
    "ELBO": val_elbo_sum / len(val_loader),
    "Log-likelihood": val_log_likelihood_sum / len(val_loader),
}

Evaluating Reconstruction Performance...:   2%|▏         | 8/401 [00:00<00:04, 79.22it/s]

Evaluating Reconstruction Performance...: 100%|██████████| 401/401 [00:04<00:00, 85.58it/s]


In [14]:
# decoding quality metrics: validity, uniqueness and novelty of molecules sampled from the model

# canonical SMILES of the training molecules, used as the reference set for novelty
train_mol_smiles = set()
for batch in tqdm(train_loader, desc="Converting training graphs to SMILES..."):
    for sample in batch.to_data_list():
        mol = graph_to_mol(data=sample, includes_h=include_hydrogen, validate=False)
        train_mol_smiles.add(Chem.MolToSmiles(mol))

num_samples = 1000
num_valid_mols = 0

generated_mol_smiles = set()
with torch.no_grad():  # sampling only: no autograd graph needed
    z, x = graph_vae_model.sample(num_samples=num_samples, device=device)
for i in tqdm(range(num_samples), "Generating Molecules..."):
    # take the i-th sample from each of the three decoder outputs, keeping a batch dimension of 1
    sample_matrices = (x[0][i:i+1], x[1][i:i+1], x[2][i:i+1])
    sample_graph = graph_vae_model.output_to_graph(x=sample_matrices)

    try:
        mol = graph_to_mol(data=sample_graph, includes_h=include_hydrogen, validate=True)
        smiles = Chem.MolToSmiles(mol)
    except Exception:
        # conversion with validate=True is the validity check: a failure means an invalid molecule
        continue

    # only count a molecule as valid once it has been converted to SMILES successfully
    num_valid_mols += 1
    if smiles in generated_mol_smiles:
        continue
    generated_mol_smiles.add(smiles)
    # image logging sits outside the try so logging errors are not mistaken for invalid molecules
    writer.add_image('Generated', mol_to_image_tensor(mol=mol), global_step=i, dataformats="NCHW")

non_novel_mols = train_mol_smiles.intersection(generated_mol_smiles)
novel_mol_count = len(generated_mol_smiles) - len(non_novel_mols)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is 0 (e.g. no valid molecules)."""
    return numerator / denominator if denominator else 0.0


metrics.update({
    "Validity": _safe_ratio(num_valid_mols, num_samples),
    "Uniqueness": _safe_ratio(len(generated_mol_smiles), num_valid_mols),
    "Novelty": _safe_ratio(novel_mol_count, len(generated_mol_smiles)),
})
# rebind to a new dict rather than mutating: log_hparams may alias `hparams`
log_hparams = {**log_hparams, "checkpoint": out_checkpoint}
writer.add_hparams(hparam_dict=log_hparams, metric_dict=metrics)

Converting training graphs to SMILES...:   0%|          | 1/3202 [00:00<08:41,  6.14it/s]

Converting training graphs to SMILES...: 100%|██████████| 3202/3202 [02:43<00:00, 19.53it/s]
Generating Molecules...: 100%|██████████| 1000/1000 [00:11<00:00, 89.79it/s]
