In [1]:
%load_ext autoreload
%autoreload 2

from typing import Dict, Any
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
import torch
from torch_geometric.loader import DataLoader
from data_utils import *
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
include_hydrogen = False
properties=["homo", "lumo"]

dataset = create_qm9_dataset(
    device=device,
    include_hydrogen=include_hydrogen,
    refresh_data_cache=False,
    properties=properties
)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

print(f"Training dataset size = {len(train_dataset)}")
print(f"Validation dataset size = {len(val_dataset)}")
print(f"Test dataset size = {len(test_dataset)}")

Training dataset size = 102445
Validation dataset size = 12806
Test dataset size = 12805


## Graph Property Predictor

In [40]:
from property_predictor import load_property_predictor_model
from graph_vae.vae import GraphVAE

graph_vae_prop = GraphVAE.from_pretrained("./checkpoints/graph_vae_20240303_204747.pt").to(device)
property_predictor = load_property_predictor_model().to(device)

graph_vae_prop.eval()
property_predictor.eval()

writer = create_tensorboard_writer(experiment_name="molopt")

In [41]:
import torch_geometric.utils as pyg_utils
import networkx as nx

num_samples = 1000
max_decode_attempts = 1
total_decode_attempts = 0
num_connected_graphs = 0
num_valid_mols = 0
generated_mol_smiles = set()

properties = {}

z, x = graph_vae_prop.sample(num_samples=num_samples, device=device)
for i in tqdm(range(num_samples), "Generating Molecules"):
    sample_matrices = (x[0][i:i+1], x[1][i:i+1], x[2][i:i+1])

    # attempt to decode multiply time until we have both a connected graph and a valid molecule
    for _ in range(max_decode_attempts):
        sample_graph = graph_vae_prop.output_to_graph(x=sample_matrices, stochastic=False)
        total_decode_attempts += 1

        properties[i] = property_predictor(sample_graph)

        # check if the generated graph is connected
        if nx.is_connected(pyg_utils.to_networkx(sample_graph, to_undirected=True)):
            num_connected_graphs += 1
        else:
            # graph is not connected; try to decode again
            continue
    
        try:
            mol = graph_to_mol(data=sample_graph, includes_h=include_hydrogen, validate=True)
        except Exception as e:
            # Molecule is invalid; try to decode again
            continue

        # Molecule is valid
        num_valid_mols += 1
        smiles = Chem.MolToSmiles(mol)
        if smiles not in generated_mol_smiles:
            writer.add_image('Generated', mol_to_image_tensor(mol=mol), global_step=i, dataformats="NCHW")
            generated_mol_smiles.add(Chem.MolToSmiles(mol))
        break

Generating Molecules:   0%|          | 0/1000 [00:00<?, ?it/s]

[11:02:32] Explicit valence for atom # 6 O, 3, is greater than permitted
[11:02:32] Explicit valence for atom # 1 C, 5, is greater than permitted
[11:02:32] Explicit valence for atom # 7 O, 3, is greater than permitted
Generating Molecules:   1%|          | 9/1000 [00:00<00:11, 85.61it/s][11:02:32] Explicit valence for atom # 5 O, 3, is greater than permitted
[11:02:32] Explicit valence for atom # 7 C, 5, is greater than permitted
Generating Molecules:   2%|▏         | 18/1000 [00:00<00:12, 80.69it/s][11:02:32] Explicit valence for atom # 1 O, 3, is greater than permitted
Generating Molecules:   3%|▎         | 28/1000 [00:00<00:11, 85.14it/s][11:02:33] Explicit valence for atom # 4 O, 3, is greater than permitted
[11:02:33] Explicit valence for atom # 6 O, 3, is greater than permitted
Generating Molecules:   4%|▎         | 37/1000 [00:00<00:11, 83.99it/s][11:02:33] Explicit valence for atom # 7 O, 3, is greater than permitted
[11:02:33] Explicit valence for atom # 8 O, 3, is greater th

In [34]:
properties

{0: tensor([[-9.5113,  2.5911]], device='cuda:0', grad_fn=<AddmmBackward0>),
 1: tensor([[-8.8771,  2.5325]], device='cuda:0', grad_fn=<AddmmBackward0>),
 2: tensor([[-9.5150,  2.4138]], device='cuda:0', grad_fn=<AddmmBackward0>),
 3: tensor([[-9.3530,  2.5613]], device='cuda:0', grad_fn=<AddmmBackward0>),
 4: tensor([[-9.3158,  2.5257]], device='cuda:0', grad_fn=<AddmmBackward0>),
 5: tensor([[-9.8418,  2.3043]], device='cuda:0', grad_fn=<AddmmBackward0>),
 6: tensor([[-8.8689,  2.3996]], device='cuda:0', grad_fn=<AddmmBackward0>),
 7: tensor([[-9.0986,  2.4993]], device='cuda:0', grad_fn=<AddmmBackward0>),
 8: tensor([[-9.2769,  2.5510]], device='cuda:0', grad_fn=<AddmmBackward0>),
 9: tensor([[-9.5985,  2.2946]], device='cuda:0', grad_fn=<AddmmBackward0>),
 10: tensor([[-9.3059,  2.5132]], device='cuda:0', grad_fn=<AddmmBackward0>),
 11: tensor([[-9.4012,  2.4234]], device='cuda:0', grad_fn=<AddmmBackward0>),
 12: tensor([[-9.5070,  2.5791]], device='cuda:0', grad_fn=<AddmmBackward0

In [42]:
z.requires_grad_(True)
optimizer = torch.optim.Adam([z], lr=1e-2)

In [43]:
for i in tqdm(range(500)):
    properties_predicted = graph_vae_prop.predict_properties(z)
    # reduce homo-lumo gap
    loss = ((properties_predicted[:, 0] - properties_predicted[:, 1]) ** 2).mean()
    loss.backward()
    optimizer.step()

print(loss)

  7%|▋         | 34/500 [00:00<00:01, 339.07it/s]

100%|██████████| 500/500 [00:01<00:00, 272.50it/s]

tensor(7.6533, device='cuda:0', grad_fn=<MeanBackward0>)





In [44]:
num_samples = 1000
max_decode_attempts = 1
total_decode_attempts = 0
num_connected_graphs = 0
num_valid_mols = 0
generated_mol_smiles = set()

properties_500 = {}

z = z.detach()
x = graph_vae_prop.decode(z)
for i in tqdm(range(num_samples), "Generating Molecules"):
    sample_matrices = (x[0][i:i+1], x[1][i:i+1], x[2][i:i+1])

    # attempt to decode multiply time until we have both a connected graph and a valid molecule
    for _ in range(max_decode_attempts):
        sample_graph = graph_vae_prop.output_to_graph(x=sample_matrices, stochastic=False)
        total_decode_attempts += 1

        properties_500[i] = property_predictor(sample_graph)


        # check if the generated graph is connected
        if nx.is_connected(pyg_utils.to_networkx(sample_graph, to_undirected=True)):
            num_connected_graphs += 1
        else:
            # graph is not connected; try to decode again
            continue
    
        try:
            mol = graph_to_mol(data=sample_graph, includes_h=include_hydrogen, validate=True)
        except Exception as e:
            # Molecule is invalid; try to decode again
            continue

        # Molecule is valid
        num_valid_mols += 1
        smiles = Chem.MolToSmiles(mol)
        if smiles not in generated_mol_smiles:
            writer.add_image('Generated Optimized', mol_to_image_tensor(mol=mol), global_step=i, dataformats="NCHW")
            generated_mol_smiles.add(Chem.MolToSmiles(mol))
        break

Generating Molecules:   0%|          | 0/1000 [00:00<?, ?it/s][11:02:52] Explicit valence for atom # 5 O, 4, is greater than permitted
[11:02:52] Explicit valence for atom # 5 N, 4, is greater than permitted
[11:02:52] Explicit valence for atom # 5 N, 4, is greater than permitted
Generating Molecules:   2%|▏         | 20/1000 [00:00<00:05, 179.43it/s][11:02:52] Explicit valence for atom # 4 N, 4, is greater than permitted
[11:02:52] Explicit valence for atom # 2 O, 3, is greater than permitted
Generating Molecules:   4%|▍         | 38/1000 [00:00<00:08, 115.22it/s][11:02:53] Explicit valence for atom # 7 N, 4, is greater than permitted
Generating Molecules:   5%|▌         | 53/1000 [00:00<00:07, 126.75it/s][11:02:53] Explicit valence for atom # 5 N, 4, is greater than permitted
[11:02:53] Explicit valence for atom # 5 N, 4, is greater than permitted
[11:02:53] Explicit valence for atom # 5 N, 4, is greater than permitted
Generating Molecules:   8%|▊         | 75/1000 [00:00<00:05, 156.

In [48]:
index = 984

print(properties[index])
print(properties_500[index])
print(abs(properties[index][0][0] - properties[index][0][1]))
print(abs(properties_500[index][0][0] - properties_500[index][0][1]))

tensor([[-9.4245,  2.5634]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([[-9.6047,  2.0674]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor(11.9879, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(11.6721, device='cuda:0', grad_fn=<AbsBackward0>)


In [51]:
index = 856

print(properties[index])
print(properties_500[index])
print(abs(properties[index][0][0] - properties[index][0][1]))
print(abs(properties_500[index][0][0] - properties_500[index][0][1]))

tensor([[-9.3583,  2.4327]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([[-9.6059,  1.9501]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor(11.7911, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(11.5559, device='cuda:0', grad_fn=<AbsBackward0>)


In [54]:
index = 710

print(properties[index])
print(properties_500[index])
print(abs(properties[index][0][0] - properties[index][0][1]))
print(abs(properties_500[index][0][0] - properties_500[index][0][1]))

tensor([[-9.3656,  2.3912]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([[-9.2482,  2.5082]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor(11.7568, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(11.7564, device='cuda:0', grad_fn=<AbsBackward0>)
