In [1]:
%load_ext autoreload
%autoreload 2

from typing import Dict, Any
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
import torch
from torch_geometric.loader import DataLoader
from data_utils import *
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
include_hydrogen = False
properties=["homo", "lumo"]

transform_list = [
    SelectQM9TargetProperties(properties=properties),
    SelectQM9NodeFeatures(features=["atom_type"]),
]
if not include_hydrogen:
    transform_list.append(DropQM9Hydrogen())

max_num_nodes = 29 if include_hydrogen else 9
transform_list += [
    AddAdjacencyMatrix(max_num_nodes=max_num_nodes),
    AddNodeAttributeMatrix(max_num_nodes=max_num_nodes),
    AddEdgeAttributeMatrix(max_num_nodes=max_num_nodes),
]

pre_transform = T.Compose(transform_list)
transform = T.Compose([
    #RandomPermutation(max_num_nodes=max_num_nodes),
    T.ToDevice(device=device)
])

dataset = QM9(root="./data", pre_transform=pre_transform, pre_filter=qm9_pre_filter, transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

print(f"Training dataset size = {len(train_dataset)}")
print(f"Validation dataset size = {len(val_dataset)}")
print(f"Test dataset size = {len(test_dataset)}")

Training dataset size = 102445
Validation dataset size = 12806
Test dataset size = 12805


## Create Dataloaders

In [2]:
from typing import List
from data_utils import create_validation_subset_loaders

batch_size = 128

dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:16], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val_small": DataLoader(val_dataset[:512], batch_size=batch_size, shuffle=False),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(validation_dataset=val_dataset, subset_count=32, batch_size=batch_size)

## Baseline model (mean prediction)

### Training

In [5]:
import torch.nn as nn

train_loader = dataloaders["train"]
homo_mean_list = []
lumo_mean_list = []
for batch in tqdm(train_loader):
    batch_mean = torch.mean(batch.y, dim=0)
    homo_mean_list.append(batch_mean[0])
    lumo_mean_list.append(batch_mean[1])

homo_mean_pred = torch.tensor(homo_mean_list).mean()
lumo_mean_pred = torch.tensor(lumo_mean_list).mean()
print(f"HOMO Mean = {homo_mean_pred}")
print(f"LUMO Mean = {lumo_mean_pred}")

class MeanPredictor(nn.Module):
    def __init__(self, property_mean_values: List[float]):
        super().__init__()
        # create mean prediction and add batch dimension
        self.register_buffer('mean_prediction', torch.tensor(property_mean_values).unsqueeze(0))

    def forward(self, x):
        batch_size = x.y.shape[0]
        return self.mean_prediction.expand(batch_size, -1)

mean_baseline_model = MeanPredictor(property_mean_values=[homo_mean_pred, lumo_mean_pred]).to(device)

  0%|          | 0/801 [00:00<?, ?it/s]

100%|██████████| 801/801 [06:30<00:00,  2.05it/s]


HOMO Mean = -6.546781063079834
LUMO Mean = 0.3270353078842163


### Validation

In [3]:
def evaluate_model_performance(validation_loader, model):
    mae_sum = 0
    for batch in tqdm(validation_loader):
        prediction = model(batch)
        mae_sum += torch.mean(torch.abs(prediction - batch.y), dim=0)

    mean_absolute_error = mae_sum / len(validation_loader)

    print(f"HOMO MAE = {mean_absolute_error[0]}")
    print(f"LUMO MAE = {mean_absolute_error[1]}")

val_loader = dataloaders["val"]

print("Mean Baseline:")
#evaluate_model_performance(validation_loader=val_loader, model=mean_baseline_model)

Mean Baseline:


## Graph Property Predictor

In [6]:
from data_utils import create_tensorboard_writer
from tqdm import tqdm
import itertools
from graph_vae.vae import GraphVAE

class PropertyPredictorVAE(nn.Module):

    def __init__(self, graph_vae: GraphVAE, hparams: Dict[str, Any]):
        super().__init__()
        self.graph_vae = graph_vae

    def forward(self, x: Data):
        z = self.graph_vae.encode(x=x)
        return self.graph_vae.predict_properties(z)

class PropertyPredictorVAEEncoder(nn.Module):

    def __init__(self, graph_vae: GraphVAE, hparams: Dict[str, Any]):
        super().__init__()
        self.graph_vae = graph_vae
        property_count = len(hparams["properties"])
        self.property_predictor = nn.Sequential(
            nn.Linear(self.graph_vae.latent_dim, 256),
            nn.BatchNorm1d(256),
            nn.PReLU(),
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.PReLU(),
            nn.Linear(256, property_count)
        )

    def forward(self, x: Data):
        z = self.graph_vae.encode(x=x)
        return self.property_predictor(z)


def train_property_predictor(
        model: PropertyPredictorVAEEncoder,
        train_loader: DataLoader,
        val_subset_loaders: List[DataLoader],
        epochs: int,
        tb_writer: SummaryWriter,
    ):
    
    learning_rate = 1e-3
    optimizer = torch.optim.Adam(model.property_predictor.parameters(), lr=learning_rate)
    loss_function = nn.MSELoss()

    # After how many iterations to validate
    validation_interval = 10

    val_subset_loader_iterator = itertools.cycle(val_subset_loaders)

    for epoch in range(epochs):
        # Training
        model.train()
        for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
            optimizer.zero_grad()
            train_prediction = model(train_batch)
            train_loss = loss_function(train_prediction, train_batch.y)
            train_loss.backward()
            optimizer.step()

            iteration = len(train_loader) * epoch + batch_index
            tb_writer.add_scalars("Loss", {"Training": train_loss.item()}, iteration)

            # Validation
            if iteration % validation_interval == 0:
                model.eval()
                val_loss_sum = 0
                mae_sum = 0

                # Get the next subset of the validation set
                val_loader = next(val_subset_loader_iterator)
                with torch.no_grad():
                    for val_batch in val_loader:
                        val_prediction = model(val_batch)
                        val_loss_sum += loss_function(val_prediction, val_batch.y)
                        mae_sum += torch.mean(torch.abs(val_prediction - val_batch.y), dim=0)
                
                val_loss = val_loss_sum / len(val_loader)
                tb_writer.add_scalars("Loss", {"Validation": val_loss.item()}, iteration)

                mean_absolute_error = mae_sum / len(val_loader)
                tb_writer.add_scalar("MAE (HOMO)", mean_absolute_error[0], iteration)
                tb_writer.add_scalar("MAE (LUMO)", mean_absolute_error[1], iteration)
                
                model.train()

In [7]:
writer = create_tensorboard_writer(experiment_name="property-predictor")

train_loader = dataloaders["train"]
val_subset_loaders = dataloaders["val_subsets"]

hparams = {
    "properties": properties
}

graph_vae_no_prop = GraphVAE.from_pretrained("./checkpoints/graph_vae_20240225_182115.pt")
model_vae_no_prop = PropertyPredictorVAEEncoder(graph_vae=graph_vae_no_prop, hparams=hparams).to(device)

# train_property_predictor(
#     model=model_vae_no_prop,
#     train_loader=train_loader, 
#     val_subset_loaders=val_subset_loaders,
#     epochs=1,
#     tb_writer=writer,
# )

In [11]:
model_vae_no_prop.eval()
evaluate_model_performance(validation_loader=dataloaders["val"], model=model_vae_no_prop)

  0%|          | 0/101 [00:00<?, ?it/s]

100%|██████████| 101/101 [00:24<00:00,  4.07it/s]

HOMO MAE = 0.32732459902763367
LUMO MAE = 0.47747018933296204





In [8]:
graph_vae_prop = GraphVAE.from_pretrained("./checkpoints/graph_vae_20240225_185057.pt")
model_vae_prop = PropertyPredictorVAE(graph_vae=graph_vae_prop, hparams=hparams).to(device)
writer = create_tensorboard_writer(experiment_name="property-predictor")

# train_property_predictor(
#     model=model_vae_no_prop,
#     train_loader=train_loader, 
#     val_subset_loaders=val_subset_loaders,
#     epochs=1,
#     tb_writer=writer,
# )

In [9]:
model_vae_prop.eval()
evaluate_model_performance(validation_loader=dataloaders["val"], model=model_vae_prop)

  0%|          | 0/101 [00:00<?, ?it/s]

100%|██████████| 101/101 [01:21<00:00,  1.24it/s]


HOMO MAE = 0.19954803586006165
LUMO MAE = 0.1990249752998352


In [10]:
import torch_geometric.utils as pyg_utils
import networkx as nx

num_samples = 1000
max_decode_attempts = 1
total_decode_attempts = 0
num_connected_graphs = 0
num_valid_mols = 0
generated_mol_smiles = set()

z, x = graph_vae_prop.sample(num_samples=num_samples, device=device)
for i in tqdm(range(num_samples), "Generating Molecules"):
    sample_matrices = (x[0][i:i+1], x[1][i:i+1], x[2][i:i+1])

    # attempt to decode multiply time until we have both a connected graph and a valid molecule
    for _ in range(max_decode_attempts):
        sample_graph = graph_vae_prop.output_to_graph(x=sample_matrices, stochastic=False)
        total_decode_attempts += 1

        # check if the generated graph is connected
        if nx.is_connected(pyg_utils.to_networkx(sample_graph, to_undirected=True)):
            num_connected_graphs += 1
        else:
            # graph is not connected; try to decode again
            continue
    
        try:
            mol = graph_to_mol(data=sample_graph, includes_h=include_hydrogen, validate=True)
        except Exception as e:
            # Molecule is invalid; try to decode again
            continue

        # Molecule is valid
        num_valid_mols += 1
        smiles = Chem.MolToSmiles(mol)
        if smiles not in generated_mol_smiles:
            writer.add_image('Generated', mol_to_image_tensor(mol=mol), global_step=i, dataformats="NCHW")
            generated_mol_smiles.add(Chem.MolToSmiles(mol))
        break

Generating Molecules:   0%|          | 1/1000 [00:00<15:28,  1.08it/s][15:20:40] Explicit valence for atom # 1 C, 5, is greater than permitted
Generating Molecules:   2%|▎         | 25/1000 [00:02<00:49, 19.53it/s][15:20:42] Explicit valence for atom # 6 C, 5, is greater than permitted
Generating Molecules:   3%|▎         | 28/1000 [00:02<00:51, 18.82it/s][15:20:42] Explicit valence for atom # 6 N, 4, is greater than permitted
Generating Molecules:   4%|▎         | 36/1000 [00:02<00:51, 18.88it/s][15:20:42] Explicit valence for atom # 7 C, 6, is greater than permitted
Generating Molecules:   4%|▍         | 44/1000 [00:03<00:51, 18.39it/s][15:20:42] Explicit valence for atom # 8 O, 3, is greater than permitted
[15:20:42] Explicit valence for atom # 6 N, 4, is greater than permitted
Generating Molecules:   7%|▋         | 69/1000 [00:04<00:51, 17.95it/s][15:20:44] Explicit valence for atom # 1 O, 3, is greater than permitted
Generating Molecules:   8%|▊         | 80/1000 [00:05<00:43, 21.

In [20]:
z.requires_grad_(True)
optimizer = torch.optim.Adam([z], lr=1e-2)

In [21]:
for i in tqdm(range(10000)):
    properties_predicted = graph_vae_prop.predict_properties(z)
    # reduce homo-lumo gap
    loss = ((properties_predicted[:, 0] - properties_predicted[:, 1]) ** 2).mean()
    loss.backward()
    optimizer.step()

print(loss)

  0%|          | 22/10000 [00:00<00:46, 215.10it/s]

100%|██████████| 10000/10000 [00:41<00:00, 241.22it/s]

tensor(11.5145, device='cuda:0', grad_fn=<MeanBackward0>)





In [19]:
num_samples = 1000
max_decode_attempts = 1
total_decode_attempts = 0
num_connected_graphs = 0
num_valid_mols = 0
generated_mol_smiles = set()

z = z.detach()
x = graph_vae_prop.decode(z)
for i in tqdm(range(num_samples), "Generating Molecules"):
    sample_matrices = (x[0][i:i+1], x[1][i:i+1], x[2][i:i+1])

    # attempt to decode multiply time until we have both a connected graph and a valid molecule
    for _ in range(max_decode_attempts):
        sample_graph = graph_vae_prop.output_to_graph(x=sample_matrices, stochastic=False)
        total_decode_attempts += 1

        # check if the generated graph is connected
        if nx.is_connected(pyg_utils.to_networkx(sample_graph, to_undirected=True)):
            num_connected_graphs += 1
        else:
            # graph is not connected; try to decode again
            continue
    
        try:
            mol = graph_to_mol(data=sample_graph, includes_h=include_hydrogen, validate=True)
        except Exception as e:
            # Molecule is invalid; try to decode again
            continue

        # Molecule is valid
        num_valid_mols += 1
        smiles = Chem.MolToSmiles(mol)
        if smiles not in generated_mol_smiles:
            writer.add_image('Generated Optimized 20000', mol_to_image_tensor(mol=mol), global_step=i, dataformats="NCHW")
            generated_mol_smiles.add(Chem.MolToSmiles(mol))
        break

Generating Molecules:   0%|          | 0/1000 [00:00<?, ?it/s]

Generating Molecules:  16%|█▋        | 165/1000 [00:01<00:07, 110.70it/s][15:31:43] Explicit valence for atom # 2 O, 3, is greater than permitted
Generating Molecules:  18%|█▊        | 182/1000 [00:01<00:06, 125.44it/s][15:31:44] Explicit valence for atom # 4 O, 3, is greater than permitted
Generating Molecules:  24%|██▍       | 239/1000 [00:02<00:07, 100.85it/s][15:31:44] Explicit valence for atom # 4 O, 3, is greater than permitted
Generating Molecules:  29%|██▉       | 293/1000 [00:03<00:07, 100.79it/s][15:31:45] Explicit valence for atom # 4 O, 3, is greater than permitted
Generating Molecules:  38%|███▊      | 375/1000 [00:03<00:05, 107.09it/s][15:31:45] Explicit valence for atom # 4 O, 3, is greater than permitted
Generating Molecules:  48%|████▊     | 475/1000 [00:04<00:05, 89.88it/s] [15:31:46] Explicit valence for atom # 4 C, 5, is greater than permitted
Generating Molecules:  61%|██████    | 608/1000 [00:05<00:03, 114.99it/s][15:31:48] Explicit valence for atom # 7 N, 4, is g