In [237]:
%load_ext autoreload
%autoreload 2

import itertools
import random
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import dense_to_sparse, remove_self_loops
from tqdm import tqdm

from data_utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [258]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
drop_hydrogen = True

# QM9 molecules contain at most 9 heavy atoms (C, N, O, F) and at most 29 atoms
# once hydrogens are included, so the dense matrices need 29 slots when the
# hydrogens are kept.
max_num_nodes = 9 if drop_hydrogen else 29

# TODO: pre-transform and store matrices to disk
# The transforms below come from data_utils and are applied in list order.
transform_list = [
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
]
if drop_hydrogen:
    transform_list.append(DropQM9Hydrogen())

# Dense, fixed-size matrices are added after the optional hydrogen removal;
# presumably these attach the `adj_triu_mat`, `node_mat` and `edge_triu_mat`
# attributes that the loss reads later (verify against data_utils).
transform_list += [
    AddAdjacencyMatrix(max_num_nodes=max_num_nodes),
    AddNodeAttributeMatrix(max_num_nodes=max_num_nodes),
    AddEdgeAttributeMatrix(max_num_nodes=max_num_nodes),
    T.ToDevice(device=device)
]
transform = T.Compose(transform_list)

dataset = QM9(root="./data", transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

# Feature dimensions as reported by the dataset; used to size the model layers.
num_node_features = dataset.num_node_features
num_edge_features = dataset.num_edge_features

In [332]:
class Encoder(nn.Module):
    """Graph encoder that maps a (batched) molecule graph to Gaussian posterior parameters.

    Two GCN layers (32 and 64 channels) are followed by mean pooling over the
    nodes of each graph and a two-layer MLP whose output is split into the
    mean and the log standard deviation of the latent distribution.

    Args:
        hparams: Hyperparameter dictionary; reads ``"latent_dim"`` and
            ``"num_node_features"``.
    """

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.latent_dim = hparams["latent_dim"]

        # TODO: two graph convolutional layers (32 and 64 channels) with identity connection (edge conditioned graph convolution)
        self.conv1 = GCNConv(in_channels=hparams["num_node_features"], out_channels=32)
        self.conv2 = GCNConv(in_channels=32, out_channels=64)

        self.fc1 = nn.Linear(in_features=64, out_features=128)
        # The head emits twice the latent size: the first half is mu,
        # the second half is log(sigma).
        self.fc2 = nn.Linear(in_features=128, out_features=2 * self.latent_dim)

    def forward(self, data: Data) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``data`` and return ``(mu, log_sigma)``, one row per graph."""
        # Node-level message passing.
        hidden = F.relu(self.conv1(data.x, data.edge_index))
        hidden = self.conv2(hidden, data.edge_index)

        # Graph-level readout followed by the MLP head.
        pooled = global_mean_pool(hidden, data.batch)
        stats = self.fc2(F.relu(self.fc1(pooled)))

        # Equal halves along the feature axis: [mu | log_sigma].
        mu, log_sigma = stats.chunk(2, dim=1)
        return mu, log_sigma
    
class Decoder(nn.Module):
    """MLP decoder that maps a latent vector to dense graph logits.

    A shared three-layer MLP (128 -> 256 -> 512, ReLU) feeds three linear
    heads that predict, for a graph with at most ``max_num_nodes`` nodes:

    * the upper triangle of the adjacency matrix including the diagonal
      (the diagonal encodes node existence),
    * per-node feature logits,
    * per-edge feature logits for the strictly upper triangular node pairs.

    Args:
        hparams: Hyperparameter dictionary; reads ``"latent_dim"``,
            ``"max_num_nodes"``, ``"num_node_features"`` and
            ``"num_edge_features"``.
    """

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.fcls = nn.Sequential(
            nn.Linear(in_features=hparams["latent_dim"], out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.ReLU(),
        )

        self.max_num_nodes = hparams["max_num_nodes"]
        self.num_node_features = hparams["num_node_features"]

        # The atom graph is symmetric, so only the upper triangular part is
        # predicted, plus the diagonal that indicates the presence of nodes.
        upper_triangular_diag_size = self.max_num_nodes * (self.max_num_nodes + 1) // 2
        self.fc_adjacency = nn.Linear(in_features=512, out_features=upper_triangular_diag_size)

        # Size the node head from the hyperparameters (not from a notebook
        # global) so the layer always matches the reshape in forward().
        self.fc_node_features = nn.Linear(
            in_features=512,
            out_features=self.max_num_nodes * self.num_node_features,
        )

        # Number of unordered node pairs (strict upper triangle).
        self.max_num_edges = self.max_num_nodes * (self.max_num_nodes - 1) // 2
        self.num_edge_features = hparams["num_edge_features"]
        self.fc_edge_features = nn.Linear(in_features=512, out_features=self.max_num_edges * self.num_edge_features)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode latent vectors into graph logits.

        Args:
            z: Latent vectors with ``latent_dim`` features per row.

        Returns:
            A tuple ``(adj_triu_mat, node_mat, edge_triu_mat)`` of logits with
            shapes ``(B, n * (n + 1) / 2)``, ``(B, n, num_node_features)`` and
            ``(B, n * (n - 1) / 2, num_edge_features)`` where
            ``n = max_num_nodes``.
        """
        x = self.fcls(z)
        # predict upper triangular matrix including the diagonal
        adj_triu_mat = self.fc_adjacency(x)
        node_features = self.fc_node_features(x)
        edge_features = self.fc_edge_features(x)

        # reshape the flat head outputs into per-node / per-edge matrices
        node_mat = node_features.view(-1, self.max_num_nodes, self.num_node_features)
        edge_triu_mat = edge_features.view(-1, self.max_num_edges, self.num_edge_features)

        return adj_triu_mat, node_mat, edge_triu_mat


class GraphVAE(nn.Module):

    def __init__(self, hparams: Dict[str, Any]) -> None:
        super().__init__()

        self.encoder = Encoder(hparams=hparams)
        self.decoder = Decoder(hparams=hparams)
        self.latent_dim = hparams["latent_dim"]

        rows, cols = torch.triu_indices(hparams["max_num_nodes"], hparams["max_num_nodes"])
        self.diag_triu_mask = rows == cols

        self.edge_triu_rows, self.edge_triu_cols = torch.triu_indices(hparams["max_num_nodes"], hparams["max_num_nodes"], offset=1)

    def _sample_with_reparameterization(self, mu: torch.Tensor, log_sigma: torch.Tensor) -> torch.Tensor:
        sigma = torch.exp(log_sigma)
        std_norm = torch.randn_like(mu)
        return std_norm * sigma + mu

    def forward(self, data: Data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mu, log_sigma = self.encoder(data)
        z = self._sample_with_reparameterization(mu=mu, log_sigma=log_sigma)
        x = self.decoder(z)
        return x
    
    def _kl_divergence(self, mu: torch.Tensor, log_sigma: torch.Tensor) -> torch.Tensor:
        log_sigma_squared = log_sigma + log_sigma
        sigma_squared = torch.exp(log_sigma_squared)
        mu_squared = mu * mu
        kl_div_sample = 0.5 * torch.sum(sigma_squared + mu_squared - log_sigma_squared - 1, dim=1)
        # average over the batch
        return torch.mean(kl_div_sample)
    
    def _reconstruction_loss(
        self, 
        input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 
        target: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        input_adj_triu_mat, input_node_mat, input_edge_mat = input
        target_adj_triu_mat, target_node_mat, target_edge_mat = target

        # average loss over nodes and edges separately
        input_adj_triu_mat_diag = input_adj_triu_mat[:, self.diag_triu_mask]
        input_adj_triu_mat_off_diag = input_adj_triu_mat[:, ~self.diag_triu_mask]
        target_adj_triu_mat_diag = target_adj_triu_mat[:, self.diag_triu_mask]
        target_adj_triu_mat_off_diag = target_adj_triu_mat[:, ~self.diag_triu_mask]
        adjacency_loss = (
            F.binary_cross_entropy_with_logits(input=input_adj_triu_mat_diag, target=target_adj_triu_mat_diag)
            + F.binary_cross_entropy_with_logits(input=input_adj_triu_mat_off_diag, target=target_adj_triu_mat_off_diag)
        )

        # compute the node feature loss only for nodes that exist in the input graph
        node_mat_logits = input_node_mat.view(-1, input_node_mat.size(-1))
        node_targets = target_node_mat.argmax(dim=2).view(-1)
        per_node_feature_loss = F.cross_entropy(input=node_mat_logits, target=node_targets, reduction="none")
        node_mask = target_adj_triu_mat_diag.view(-1)
        node_feature_loss = (per_node_feature_loss * node_mask).sum() / node_mask.sum()

        # Compute edge feature loss only for edges that are connected to nodes existing in the input graph
        edge_mat_logits = input_edge_mat.view(-1, input_edge_mat.size(-1))
        edge_targets = target_edge_mat.argmax(dim=2).view(-1)
        per_edge_feature_loss = F.cross_entropy(input=edge_mat_logits, target=edge_targets, reduction="none")
        edge_mask = (
            target_adj_triu_mat_diag[:, self.edge_triu_rows].int() & target_adj_triu_mat_diag[:, self.edge_triu_cols].int()
        ).view(-1)
        edge_feature_loss = (per_edge_feature_loss * edge_mask).sum() / edge_mask.sum()
        
        return adjacency_loss + node_feature_loss + edge_feature_loss
    
    def negative_elbo(self, x: Data):
        mu, log_sigma = self.encoder(x)
        z = self._sample_with_reparameterization(mu=mu, log_sigma=log_sigma)

        x_recon = self.decoder(z)
        x_target = (x.adj_triu_mat, x.node_mat, x.edge_triu_mat)
        return self._reconstruction_loss(input=x_recon, target=x_target) + self._kl_divergence(mu=mu, log_sigma=log_sigma)
    
    def sample(self, num_samples: int, device: str):
        z = torch.randn((num_samples, self.latent_dim), device=device)
        x = self.decoder(z)
        return z, x

In [335]:
# Hyperparameters shared by the model, the optimizer and the data loaders.
hparams = {
    "batch_size": 256,
    "max_num_nodes": max_num_nodes,
    "learning_rate": 1e-3,
    "beta_1": 0.5,  # first-moment coefficient passed to Adam
    "epochs": 500,
    "num_node_features": num_node_features,
    "num_edge_features": num_edge_features,
    "latent_dim": 128  # c in the paper
}

batch_size = hparams["batch_size"]
# Training loaders of increasing size for debugging / overfitting experiments,
# plus the full validation loader.
dataloaders = {
    "train_single": DataLoader(train_dataset[16:24], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:batch_size], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

# Number of validation subsets requested from the helper; the training loop
# cycles through the resulting loaders.
val_subset_count = 32
dataloaders["val_subsets"] = create_validation_subset_loaders(
    validation_dataset=val_dataset,
    subset_count=val_subset_count,
    batch_size=batch_size,
)

In [336]:
graph_vae_model = GraphVAE(hparams=hparams).to(device=device)
optimizer = torch.optim.Adam(
    graph_vae_model.parameters(),
    lr=hparams["learning_rate"],
    betas=(hparams["beta_1"], 0.999)
)
epochs = hparams["epochs"]

train_loader = dataloaders["train_small"]
# Endless iterator over the validation subset loaders; one subset is consumed per validation run.
val_subset_loader_iterator = itertools.cycle(dataloaders["val_subsets"])

# Run validation every `validation_interval` training iterations.
validation_interval = 50

writer = create_tensorboard_writer(experiment_name="graph-vae")

for epoch in range(epochs):
    graph_vae_model.train()
    for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
        optimizer.zero_grad()

        # negative_elbo runs the encoder and decoder itself, so no separate
        # forward pass through the model is needed here.
        loss = graph_vae_model.negative_elbo(x=train_batch)

        loss.backward()
        optimizer.step()

        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": loss.item()}, iteration)

        if iteration % validation_interval == 0:
            graph_vae_model.eval()
            val_loss_sum = 0.0

            # Get the next subset of the validation set
            val_loader = next(val_subset_loader_iterator)
            with torch.no_grad():
                for val_batch in val_loader:
                    val_loss_sum += graph_vae_model.negative_elbo(x=val_batch).item()

            # mean of the per-batch losses of this subset
            val_loss = val_loss_sum / len(val_loader)
            writer.add_scalars("Loss", {"Validation": val_loss}, iteration)

            graph_vae_model.train()

    # Visualization
    # Pick a random molecule from the data that was actually trained on.
    # randrange excludes the upper bound, so the index is always valid.
    sample_index = random.randrange(len(train_loader.dataset))
    sample = train_loader.dataset[sample_index]
    writer.add_image('Input', molecule_graph_data_to_image(sample, includes_h=not drop_hydrogen), global_step=epoch, dataformats="NCHW")

    # encode and decode the molecule (no gradients needed for visualization)
    graph_vae_model.eval()
    with torch.no_grad():
        pred_adj_triu_mat, pred_node_mat, pred_edge_triu_mat = graph_vae_model(sample)

    #####################################
    # Convert predicted matrices to graph
    #####################################

    # drop the batch dimension
    pred_adj_triu_mat = pred_adj_triu_mat[0]
    pred_node_mat = pred_node_mat[0]
    pred_edge_triu_mat = pred_edge_triu_mat[0]

    n = hparams["max_num_nodes"]

    # scatter the predicted edge classes into the strict upper triangle of a dense matrix
    edge_triu_mat = pred_edge_triu_mat.argmax(dim=1).float()
    edge_mat = torch.zeros(n, n, device=device)
    edge_mat[torch.ones(n, n, device=device).triu(diagonal=1) == 1] = edge_triu_mat
    # shift the class indices by one so that 0 can mean "no edge" after
    # masking with the adjacency matrix; the shift is undone below
    edge_mat = edge_mat + 1

    # convert the predicted upper triangular logits into a dense binary adjacency
    # matrix whose diagonal marks the nodes that exist
    adj_triu_mat = torch.where(torch.sigmoid(pred_adj_triu_mat) > 0.5, 1.0, 0.0)
    adj_mat = torch.zeros(n, n, device=device)
    adj_mat[torch.ones(n, n, device=device).triu() == 1] = adj_triu_mat
    diagonal = adj_mat.diagonal()

    # combine the adjacency matrix with edge features
    edge_mat *= adj_mat

    node_mask = diagonal == 1
    # NOTE(review): only the upper triangle is populated, so every edge appears in
    # one direction only. Also, depending on the torch_geometric version,
    # dense_to_sparse with `mask` may already compact the row indices, in which case
    # the remapping below would be applied twice to them -- confirm against the
    # installed version.
    edge_index, edge_attr = dense_to_sparse(adj=edge_mat.unsqueeze(0), mask=node_mask.unsqueeze(0))

    # undo the +1 shift and one-hot encode the edge classes
    edge_attr = F.one_hot((edge_attr - 1).long(), num_classes=num_edge_features)

    # Adjust indices in edge_index to account for removed nodes
    # Create a mapping from old indices to new indices
    old_to_new_indices = torch.cumsum(node_mask, 0) - 1
    old_to_new_indices[~node_mask] = -1
    new_edge_index = old_to_new_indices[edge_index]

    # Remove edges that contain removed nodes
    valid_edges = (new_edge_index >= 0).all(dim=0)
    new_edge_index = new_edge_index[:, valid_edges]
    new_edge_attr = edge_attr[valid_edges, :]

    # the node-existence diagonal shows up as self loops; drop them
    edge_index, edge_attr = remove_self_loops(edge_index=new_edge_index, edge_attr=new_edge_attr)

    # convert node feature logits into one-hot vector
    x = F.one_hot(pred_node_mat[node_mask].argmax(dim=1), num_classes=num_node_features)
    reconstructed_sample = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    writer.add_image('Reconstruction', molecule_graph_data_to_image(reconstructed_sample, includes_h=not drop_hydrogen), global_step=epoch, dataformats="NCHW")


Epoch 1 Training:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 16/16 [00:04<00:00,  3.61it/s]
Epoch 2 Training: 100%|██████████| 16/16 [00:03<00:00,  4.02it/s]
Epoch 3 Training: 100%|██████████| 16/16 [00:04<00:00,  3.95it/s]
Epoch 4 Training: 100%|██████████| 16/16 [00:04<00:00,  3.71it/s]
Epoch 5 Training: 100%|██████████| 16/16 [00:04<00:00,  3.95it/s]
Epoch 6 Training: 100%|██████████| 16/16 [00:04<00:00,  3.97it/s]
Epoch 7 Training: 100%|██████████| 16/16 [00:04<00:00,  3.65it/s]
Epoch 8 Training: 100%|██████████| 16/16 [00:04<00:00,  3.92it/s]
Epoch 9 Training: 100%|██████████| 16/16 [00:04<00:00,  3.96it/s]
Epoch 10 Training: 100%|██████████| 16/16 [00:04<00:00,  3.60it/s]
Epoch 11 Training: 100%|██████████| 16/16 [00:03<00:00,  4.01it/s]
Epoch 12 Training: 100%|██████████| 16/16 [00:04<00:00,  3.92it/s]
Epoch 13 Training: 100%|██████████| 16/16 [00:04<00:00,  3.63it/s]
Epoch 14 Training: 100%|██████████| 16/16 [00:04<00:00,  3.98it/s]
Epoch 15 Training: 100%|██████████| 16/16 [00:04<00:00,  3.91it/s]
Epoc

KeyboardInterrupt: 