In [2]:
%load_ext autoreload
%autoreload 2

from typing import Dict, Any
import torch.nn as nn
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
import torch
from torch_geometric.loader import DataLoader
from data_utils import *
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
include_hydrogen = False
properties=["homo", "lumo"]

transform_list = [
    SelectQM9TargetProperties(properties=properties),
    SelectQM9NodeFeatures(features=["atom_type"]),
]
if not include_hydrogen:
    transform_list.append(DropQM9Hydrogen())

max_num_nodes = 29 if include_hydrogen else 9
transform_list += [
    AddAdjacencyMatrix(max_num_nodes=max_num_nodes),
    AddNodeAttributeMatrix(max_num_nodes=max_num_nodes),
    AddEdgeAttributeMatrix(max_num_nodes=max_num_nodes),
]

pre_transform = T.Compose(transform_list)
transform = T.Compose([
    #RandomPermutation(max_num_nodes=max_num_nodes),
    T.ToDevice(device=device)
])

dataset = QM9(root="./data", pre_transform=pre_transform, pre_filter=qm9_pre_filter, transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

print(f"Training dataset size = {len(train_dataset)}")
print(f"Validation dataset size = {len(val_dataset)}")
print(f"Test dataset size = {len(test_dataset)}")

Training dataset size = 102445
Validation dataset size = 12806
Test dataset size = 12805


In [3]:
from typing import List

batch_size = 128

dataloaders = {
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

In [4]:
def evaluate_model_performance(validation_loader, model):
    mae_sum = 0
    for batch in tqdm(validation_loader):
        prediction = model(batch)
        mae_sum += torch.mean(torch.abs(prediction - batch.y), dim=0)

    mean_absolute_error = mae_sum / len(validation_loader)

    print(f"HOMO MAE = {mean_absolute_error[0]}")
    print(f"LUMO MAE = {mean_absolute_error[1]}")

val_loader = dataloaders["val"]

In [4]:
train_loader = dataloaders["train"]
homo_mean_list = []
lumo_mean_list = []
for batch in tqdm(train_loader):
    batch_mean = torch.mean(batch.y, dim=0)
    homo_mean_list.append(batch_mean[0])
    lumo_mean_list.append(batch_mean[1])

homo_mean_pred = torch.tensor(homo_mean_list).mean()
lumo_mean_pred = torch.tensor(lumo_mean_list).mean()
print(f"HOMO Mean = {homo_mean_pred}")
print(f"LUMO Mean = {lumo_mean_pred}")

class MeanPredictor(nn.Module):
    def __init__(self, property_mean_values: List[float]):
        super().__init__()
        # create mean prediction and add batch dimension
        self.register_buffer('mean_prediction', torch.tensor(property_mean_values).unsqueeze(0))

    def forward(self, x):
        batch_size = x.y.shape[0]
        return self.mean_prediction.expand(batch_size, -1)

mean_baseline_model = MeanPredictor(property_mean_values=[homo_mean_pred, lumo_mean_pred]).to(device)

evaluate_model_performance(val_loader, mean_baseline_model)

  0%|          | 0/801 [00:00<?, ?it/s]

100%|██████████| 801/801 [00:47<00:00, 16.86it/s]


HOMO Mean = -6.546605110168457
LUMO Mean = 0.32694491744041443


100%|██████████| 101/101 [00:05<00:00, 19.25it/s]

HOMO MAE = 0.4337182939052582
LUMO MAE = 1.042443871498108





In [5]:
from graph_vae.vae import GraphVAE

class PropertyPredictorVAE(nn.Module):

    def __init__(self, graph_vae: GraphVAE):
        super().__init__()
        self.graph_vae = graph_vae

    def forward(self, x: Data):
        z = self.graph_vae.encode(x=x)
        return self.graph_vae.predict_properties(z)

graph_vae_model = GraphVAE.from_pretrained("./checkpoints/graph_vae_20240303_204747.pt")

vae_prop_pred = PropertyPredictorVAE(graph_vae=graph_vae_model).to(device)
vae_prop_pred.eval()

with torch.no_grad():
    evaluate_model_performance(val_loader, vae_prop_pred)

100%|██████████| 101/101 [00:06<00:00, 15.50it/s]

HOMO MAE = 0.17140606045722961
LUMO MAE = 0.1887381374835968





In [6]:
from graph_vae.encoder import Encoder

hparams = {
    "max_num_nodes": 29 if include_hydrogen else 9,
    "adam_beta_1": 0.5,
    "num_node_features": dataset.num_node_features,
    "num_edge_features": dataset.num_edge_features,
    "latent_dim": 64,
    "include_hydrogen": include_hydrogen,
    "properties": properties,
}

class PropertyPredictor(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()
        self.graph_encoder = Encoder(hparams=hparams)
        self.property_count = len(hparams["properties"])
        dim = hparams["latent_dim"] * 2
        self.fc = nn.Sequential(
            nn.BatchNorm1d(dim),
            nn.PReLU(),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.PReLU(),
            nn.Linear(dim, self.property_count * 2)
        )

    def forward(self, x: Data):
        z = self.graph_encoder(x)
        # combine tuple (mu, log_sigma) into single latent
        z = torch.cat(list(z), dim=1)

        z = self.fc(z)

        # split output into mean and variance prediction
        mu = z[:, :self.property_count]
        log_var = z[:, self.property_count:]
        return mu, torch.exp(log_var)


In [7]:
from data_utils import create_tensorboard_writer
from tqdm import tqdm
import itertools
from graph_vae.vae import GraphVAE


def train_property_predictor(
        model: PropertyPredictor,
        train_loader: DataLoader,
        val_loader: DataLoader,
        epochs: int,
        tb_writer: SummaryWriter,
    ):
    
    learning_rate = 4e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_function = nn.GaussianNLLLoss(full=True)

    # After how many iterations to validate
    validation_interval = 10
    batches_per_validation = 2

    val_loader_iterator = itertools.cycle(iter(val_loader))

    for epoch in range(epochs):
        # Training
        model.train()
        for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
            optimizer.zero_grad()
            train_pred_mean, train_pred_var = model(train_batch)
            train_loss = loss_function(train_pred_mean, train_batch.y, train_pred_var)
            train_loss.backward()
            optimizer.step()

            iteration = len(train_loader) * epoch + batch_index
            tb_writer.add_scalars("Loss", {"Training": train_loss.item()}, iteration)

            # Validation
            if iteration % validation_interval == 0:
                model.eval()
                val_loss_sum = 0
                mae_sum = 0

                # Get the next subset of the validation set
                for _ in range(batches_per_validation):
                    with torch.no_grad():
                        val_batch = next(val_loader_iterator)
                        val_pred_mean, val_pred_var = model(val_batch)
                        val_loss_sum += loss_function(val_pred_mean, val_batch.y, val_pred_var)
                        mae_sum += torch.mean(torch.abs(val_pred_mean - val_batch.y), dim=0)
                
                val_loss = val_loss_sum / batches_per_validation
                tb_writer.add_scalars("Loss", {"Validation": val_loss.item()}, iteration)

                mean_absolute_error = mae_sum / batches_per_validation
                tb_writer.add_scalar("MAE (HOMO)", mean_absolute_error[0], iteration)
                tb_writer.add_scalar("MAE (LUMO)", mean_absolute_error[1], iteration)
                
                model.train()

In [8]:
writer = create_tensorboard_writer(experiment_name="property-predictor-2")

train_loader = dataloaders["train"]

model = PropertyPredictor(hparams=hparams).to(device)

train_property_predictor(
    model=model,
    train_loader=train_loader, 
    val_loader=val_loader,
    epochs=100,
    tb_writer=writer,
)

Epoch 1 Training:   0%|          | 0/801 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 801/801 [00:57<00:00, 14.00it/s]
Epoch 2 Training: 100%|██████████| 801/801 [00:34<00:00, 23.08it/s]
Epoch 3 Training: 100%|██████████| 801/801 [00:34<00:00, 23.40it/s]
Epoch 4 Training: 100%|██████████| 801/801 [00:34<00:00, 22.96it/s]
Epoch 5 Training: 100%|██████████| 801/801 [00:34<00:00, 23.55it/s]
Epoch 6 Training: 100%|██████████| 801/801 [00:34<00:00, 23.04it/s]
Epoch 7 Training: 100%|██████████| 801/801 [00:34<00:00, 22.95it/s]
Epoch 8 Training: 100%|██████████| 801/801 [00:34<00:00, 23.48it/s]
Epoch 9 Training: 100%|██████████| 801/801 [00:35<00:00, 22.86it/s]
Epoch 10 Training: 100%|██████████| 801/801 [00:34<00:00, 23.50it/s]
Epoch 11 Training: 100%|██████████| 801/801 [00:34<00:00, 23.02it/s]
Epoch 12 Training: 100%|██████████| 801/801 [00:34<00:00, 22.97it/s]
Epoch 13 Training: 100%|██████████| 801/801 [00:34<00:00, 23.52it/s]
Epoch 14 Training: 100%|██████████| 801/801 [00:34<00:00, 22.95it/s]
Epoch 15 Training: 100%|██████████| 801/801

In [7]:
ckpt_file = "./checkpoints/property_predictor.pt"

In [8]:
torch.save({
        "model_state_dict": model.state_dict(),
    },
    ckpt_file
)

In [9]:

model = PropertyPredictor(hparams=hparams).to(device)
model.eval()

model.load_state_dict(torch.load(ckpt_file)["model_state_dict"])

with torch.no_grad():
    evaluate_model_performance(val_loader, model)

  0%|          | 0/101 [00:00<?, ?it/s]

100%|██████████| 101/101 [00:03<00:00, 31.30it/s]

HOMO MAE = 0.11913200467824936
LUMO MAE = 0.1312396079301834



