In [None]:
%pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric rdkit-pypi tensorboard

In [None]:
import os
import random
from dataclasses import dataclass
import typing as T

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Data
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GATConv, GATv2Conv, global_add_pool
from torch_geometric.utils import to_networkx


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    The CUDA generators are additionally seeded when a GPU is visible.
    (``torch.manual_seed`` is documented to seed all devices already; the explicit
    CUDA calls are kept as a belt-and-braces measure.)
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)


@dataclass
class Config:
    """All tunable settings for the ZINC experiment, grouped as data / model / training.

    The instance built from this class is read by the dataset/loader cell,
    by ``FlexibleGNN`` and by ``train``.
    """

    # --- data ---
    dataset_path: str = "data/ZINC"  # root directory handed to torch_geometric's ZINC(...)
    batch_size: int = 128  # graphs per DataLoader batch
    num_workers: int = 2  # DataLoader worker processes
    subset: bool = True  # forwarded as ZINC(subset=...)

    # --- model ---
    # Selects the message-passing layer in FlexibleGNN. The Literal is not enforced at
    # runtime; FlexibleGNN raises ValueError for any other value.
    model_type: T.Literal["GINE", "GAT", "GATv2"] = "GINE"
    hidden_channels: int = 128  # embedding / hidden width shared by every layer
    num_layers: int = 4  # number of conv -> BatchNorm -> ReLU -> dropout blocks
    dropout: float = 0.5  # dropout probability applied after every block
    # Attention heads for GAT/GATv2 (unused by GINE). FlexibleGNN gives each head
    # hidden_channels // heads channels, so hidden_channels should be divisible by heads.
    heads: int = 4

    # --- optimisation ---
    lr: float = 0.001  # AdamW learning rate
    weight_decay: float = 1e-5  # AdamW weight decay
    epochs: int = 100
    seed: int = 42  # passed to set_seed
    log_dir: str = "runs/zinc_experiment"  # TensorBoard SummaryWriter directory
    # Evaluated once, when this class body is executed, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


# Build the default configuration, then seed every RNG before data or model are created.
config = Config()
set_seed(seed=config.seed)

In [None]:
# Load the three standard ZINC splits (train, val, test), in that order.
train_dataset, val_dataset, test_dataset = (
    ZINC(config.dataset_path, subset=config.subset, split=split)
    for split in ("train", "val", "test")
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    DataLoader(dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_workers)
    for dataset, shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

In [None]:
def visualize_graph(data: Data) -> None:
    """Draw one graph, coloured by its integer node feature, titled with its target.

    Args:
        data: A single (unbatched) graph. Assumes ``x`` holds one integer feature per
            node (e.g. shape ``[num_nodes, 1]``) and ``y`` has exactly one element.
    """
    G = to_networkx(data, to_undirected=True)
    fig, ax = plt.subplots(figsize=(8, 6))
    # view(-1) keeps a 1-D per-node array even for a single-atom graph (squeeze() would
    # give a 0-d array); .cpu() makes this work for graphs already moved to the GPU.
    node_colors = data.x.view(-1).cpu().numpy()
    pos = nx.kamada_kawai_layout(G)
    nx.draw(G, pos, ax=ax, with_labels=True, node_color=node_colors, cmap=plt.cm.tab20, node_size=300)
    # The ZINC regression target is penalized logP (constrained solubility), not plain logP.
    ax.set_title(f"Target (penalized logP): {data.y.item():.4f}")
    plt.show()


visualize_graph(train_dataset[0])

# Distribution of the regression target over the training split.
ys = [graph.y.item() for graph in train_dataset]
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(ys, bins=50, color="skyblue", edgecolor="black")
ax.set(
    title="Training-set target distribution",
    xlabel="Target (penalized logP)",
    ylabel="Number of molecules",
)
plt.show()

In [None]:
class FlexibleGNN(torch.nn.Module):
    """Graph-level regressor for ZINC with a selectable message-passing layer.

    Integer node features (atom types) and edge features (bond types) are embedded into
    ``config.hidden_channels`` dimensions. They then go through ``config.num_layers``
    blocks of conv -> BatchNorm -> ReLU -> dropout. The result is sum-pooled per graph
    and mapped to one scalar per graph by a two-layer MLP.

    Args:
        config: Experiment configuration. Reads ``model_type``, ``num_layers``,
            ``hidden_channels``, ``heads`` and ``dropout``.

    Raises:
        ValueError: If ``config.model_type`` is not one of "GINE", "GAT", "GATv2", or if
            a GAT variant is requested with ``heads`` < 1 or ``hidden_channels`` not
            divisible by ``heads``.
    """

    # Atom-type vocabulary size. The 12k ZINC subset only uses 21 types, but the full
    # dataset (Config.subset=False) uses 28. 28 covers both and avoids out-of-range indices.
    num_atom_types: int = 28
    # Number of distinct bond-type indices in ZINC's edge_attr.
    num_bond_types: int = 4

    def __init__(self, config: Config) -> None:
        super().__init__()
        if config.model_type in ("GAT", "GATv2") and (
            config.heads < 1 or config.hidden_channels % config.heads != 0
        ):
            # Heads are concatenated to (hidden_channels // heads) * heads channels. That
            # must equal hidden_channels, or the following BatchNorm1d fails with a shape error.
            raise ValueError(
                f"hidden_channels ({config.hidden_channels}) must be divisible by "
                f"heads ({config.heads}) for {config.model_type}"
            )

        self.model_type = config.model_type
        self.num_layers = config.num_layers
        self.dropout = config.dropout

        self.node_emb = torch.nn.Embedding(self.num_atom_types, config.hidden_channels)
        self.edge_emb = torch.nn.Embedding(self.num_bond_types, config.hidden_channels)

        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(config.num_layers):
            self.convs.append(self._make_conv(config))
            self.batch_norms.append(BatchNorm1d(config.hidden_channels))

        self.out_lin = Sequential(
            Linear(config.hidden_channels, config.hidden_channels), ReLU(), Linear(config.hidden_channels, 1)
        )

    @staticmethod
    def _make_conv(config: Config) -> torch.nn.Module:
        """Build one message-passing layer mapping hidden_channels -> hidden_channels."""
        hidden = config.hidden_channels
        if config.model_type == "GINE":
            mlp = Sequential(
                Linear(hidden, 2 * hidden),
                BatchNorm1d(2 * hidden),
                ReLU(),
                Linear(2 * hidden, hidden),
            )
            return GINEConv(mlp, train_eps=True, edge_dim=hidden)
        if config.model_type == "GAT":
            # Each head gets hidden // heads channels; their outputs are concatenated.
            return GATConv(hidden, hidden // config.heads, heads=config.heads, edge_dim=hidden)
        if config.model_type == "GATv2":
            return GATv2Conv(
                hidden,
                hidden // config.heads,
                heads=config.heads,
                edge_dim=hidden,
                share_weights=True,
            )
        raise ValueError(f"Unknown model type: {config.model_type}")

    def forward(self, data: Data) -> torch.Tensor:
        """Return one prediction per graph in ``data``, shape ``[num_graphs]``."""
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Flatten the index tensors with reshape(-1) rather than squeeze(). A batch with a
        # single node or edge then still yields a 1-D index vector instead of a 0-d scalar.
        x = self.node_emb(x.reshape(-1))
        edge_attr = self.edge_emb(edge_attr.reshape(-1))

        for conv, batch_norm in zip(self.convs, self.batch_norms):
            # GINEConv, GATConv and GATv2Conv are all called the same way here.
            x = conv(x, edge_index, edge_attr=edge_attr)
            x = batch_norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = global_add_pool(x, batch)

        # squeeze(-1) only drops the output-feature axis, so a one-graph batch stays [1]
        # and matches the shape of data.y.
        return self.out_lin(x).squeeze(-1)

In [None]:
def train_epoch(model: torch.nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, device: str) -> float:
    """Run one optimisation pass over ``loader`` and return the mean per-graph loss.

    Args:
        model: Module mapping a batch to one prediction per graph.
        loader: Iterable of batches exposing ``.to(device)``, ``.y`` and ``.num_graphs``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss assumed to average over the batch (e.g. ``MSELoss``).
        device: Device each batch is moved to.

    Returns:
        Loss averaged over the graphs actually seen. Counting graphs, rather than using
        ``len(loader.dataset)``, keeps the mean correct under ``drop_last`` or custom samplers.

    Raises:
        ValueError: If ``loader`` yields no graphs.
    """
    model.train()
    total_loss = 0.0
    num_graphs = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        # Flatten both sides so a [B, 1] prediction cannot silently broadcast against [B].
        pred = model(data).reshape(-1)
        loss = criterion(pred, data.y.reshape(-1))
        loss.backward()
        optimizer.step()
        # criterion averages over the batch; re-weight by batch size to get a per-graph mean.
        total_loss += loss.item() * data.num_graphs
        num_graphs += data.num_graphs
    if num_graphs == 0:
        raise ValueError("train_epoch: loader yielded no graphs")
    return total_loss / num_graphs


@torch.no_grad()
def evaluate(model: torch.nn.Module, loader: DataLoader, criterion: torch.nn.Module, device: str) -> T.Tuple[float, float]:
    """Evaluate ``model`` on ``loader`` without gradients.

    Args:
        model: Module mapping a batch to one prediction per graph.
        loader: Iterable of batches exposing ``.to(device)``, ``.y`` and ``.num_graphs``.
        criterion: Loss assumed to average over the batch (e.g. ``MSELoss``).
        device: Device each batch is moved to.

    Returns:
        ``(mean loss, mean absolute error)``, both averaged over the graphs actually seen.

    Raises:
        ValueError: If ``loader`` yields no graphs.
    """
    model.eval()
    total_loss = 0.0
    total_mae = 0.0
    num_graphs = 0
    for data in loader:
        data = data.to(device)
        # Flatten so a [B, 1] prediction cannot broadcast with a [B] target into [B, B].
        pred = model(data).reshape(-1)
        target = data.y.reshape(-1)
        total_loss += criterion(pred, target).item() * data.num_graphs
        total_mae += (pred - target).abs().sum().item()
        num_graphs += data.num_graphs
    if num_graphs == 0:
        raise ValueError("evaluate: loader yielded no graphs")
    return total_loss / num_graphs, total_mae / num_graphs

In [None]:
def train(config: Config, model: torch.nn.Module, train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader) -> None:
    """Train ``model`` with AdamW + MSE, selecting the checkpoint by validation MAE.

    Side effects:
        * Writes TensorBoard scalars to ``config.log_dir``.
        * Saves the state_dict with the lowest validation MAE to ``best_model.pth``,
          relative to the current working directory.
        * Prints progress on epoch 1 and every 10th epoch, plus a final summary. The
          summary includes the test MAE at the best-validation epoch, which is the
          number to report.

    Test metrics are logged every epoch for monitoring only; model selection uses
    validation MAE exclusively.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    criterion = torch.nn.MSELoss()
    writer = SummaryWriter(config.log_dir)

    best_val_mae = float("inf")
    test_mae_at_best = float("nan")
    best_epoch = 0

    try:
        for epoch in range(1, config.epochs + 1):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, config.device)
            val_loss, val_mae = evaluate(model, val_loader, criterion, config.device)
            test_loss, test_mae = evaluate(model, test_loader, criterion, config.device)

            writer.add_scalar("Loss/Train", train_loss, epoch)
            writer.add_scalar("Loss/Val", val_loss, epoch)
            writer.add_scalar("Loss/Test", test_loss, epoch)
            writer.add_scalar("MAE/Val", val_mae, epoch)
            writer.add_scalar("MAE/Test", test_mae, epoch)
            writer.add_scalar("LR", optimizer.param_groups[0]["lr"], epoch)

            if val_mae < best_val_mae:
                best_val_mae = val_mae
                test_mae_at_best = test_mae
                best_epoch = epoch
                torch.save(model.state_dict(), "best_model.pth")

            if epoch % 10 == 0 or epoch == 1:
                print(f"Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Val MAE: {val_mae:.4f}, Test MAE: {test_mae:.4f}")
    finally:
        # Flush and close the event file even when training is interrupted
        # (e.g. KeyboardInterrupt from stopping the notebook cell).
        writer.close()

    print(
        f"Training finished. Best Val MAE: {best_val_mae:.4f} (epoch {best_epoch}), "
        f"Test MAE at best Val: {test_mae_at_best:.4f}"
    )

In [None]:
# Re-seed here so re-running just this cell reproduces the same weight initialisation
# and training-loader shuffling order, independent of what ran before it.
set_seed(config.seed)
model = FlexibleGNN(config).to(config.device)
train(config, model, train_loader, val_loader, test_loader)

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs