In [None]:
# NOTE(review): package versions are unpinned, so re-running later may install different releases;
# pin them (pkg==X.Y.Z) once a working set is known. rdkit is only needed for the Chem/Descriptors
# imports, which no cell in this notebook currently uses.
%pip install -q torch-geometric rdkit tensorboard pandas seaborn xgboost

In [None]:
import random
from dataclasses import dataclass, replace
import typing as T
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import networkx as nx
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from torch.nn import functional as F
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Data
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GATConv, GATv2Conv, global_add_pool
from torch_geometric.utils import to_networkx, degree
from rdkit import Chem
from rdkit.Chem import Descriptors
import xgboost as xgb
from sklearn.metrics import mean_absolute_error

def set_seed(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value.

    When CUDA is available, the current device and all visible devices are
    seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

@dataclass
class Config:
    """All tunable settings for the ZINC experiments in this notebook.

    Experiment variants are derived with ``dataclasses.replace``. That goes
    through ``__init__`` and therefore ``__post_init__``, so every variant is
    validated when it is created rather than when its model is first built.

    Raises:
        ValueError: From ``__post_init__`` for an unknown ``model_type``, or when
            an attention model's ``heads`` is not a positive divisor of
            ``hidden_channels``.
    """

    # Data loading
    dataset_path: str = "data/ZINC"  # root directory passed to torch_geometric's ZINC
    batch_size: int = 128
    num_workers: int = 2  # DataLoader worker processes
    subset: bool = True  # forwarded as ZINC(subset=...)

    # Model architecture
    model_type: T.Literal["GINE", "GAT", "GATv2"] = "GINE"  # Literal is not enforced at runtime; see __post_init__
    hidden_channels: int = 128  # width of the node/edge embeddings and of every conv layer
    num_layers: int = 4  # number of message-passing layers
    dropout: float = 0.5  # dropout probability applied after each conv layer
    heads: int = 4  # attention heads for GAT/GATv2 (ignored by GINE); must divide hidden_channels

    # Optimisation and bookkeeping
    lr: float = 0.001  # AdamW learning rate
    weight_decay: float = 1e-5  # AdamW weight decay
    epochs: int = 20
    seed: int = 42
    log_dir: str = "runs"  # TensorBoard logs and checkpoints go to <log_dir>/<run_name>
    # Evaluated once, when the class body executes, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self) -> None:
        """Reject settings that would otherwise fail only once a model is built."""
        valid_model_types = ("GINE", "GAT", "GATv2")
        if self.model_type not in valid_model_types:
            raise ValueError(f"Unknown model type: {self.model_type}; expected one of {valid_model_types}")
        if self.model_type in ("GAT", "GATv2"):
            # Each head gets hidden_channels // heads features. The concatenated heads
            # must add back up to hidden_channels to match the following BatchNorm.
            if self.heads <= 0 or self.hidden_channels % self.heads != 0:
                raise ValueError(
                    f"{self.model_type} needs heads ({self.heads}) to be a positive divisor of "
                    f"hidden_channels ({self.hidden_channels})"
                )

# Base settings for the whole notebook. Seed the global RNGs right away so the
# cells below run from a fixed RNG state.
config = Config()
set_seed(seed=config.seed)

In [None]:
# One ZINC split per stage, built in train/val/test order; only the training loader shuffles.
train_dataset, val_dataset, test_dataset = (
    ZINC(config.dataset_path, subset=config.subset, split=split_name)
    for split_name in ("train", "val", "test")
)

train_loader, val_loader, test_loader = (
    DataLoader(split_ds, batch_size=config.batch_size, shuffle=is_train, num_workers=config.num_workers)
    for split_ds, is_train in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

In [None]:
def visualize_grid(dataset: ZINC, n: int = 25) -> None:
    """Draw ``n`` randomly chosen molecules from ``dataset`` as graphs in a grid.

    The grid is sized from ``n`` (roughly square), and panels left over in the
    last row are hidden. Nodes are coloured by atom-type index on one fixed
    colour scale shared by every panel, so an atom type has the same colour
    everywhere. Each panel is titled with the graph's regression target.

    Args:
        dataset: Indexable collection of PyG ``Data`` graphs with integer atom
            types in ``x`` and a scalar target in ``y``.
        n: Number of molecules to draw; capped at ``len(dataset)``. Nothing is
            drawn when this is 0.
    """
    n = min(n, len(dataset))
    if n <= 0:
        return

    ncols = int(np.ceil(np.sqrt(n)))
    nrows = int(np.ceil(n / ncols))
    # squeeze=False keeps ``axes`` 2-D even for a 1x1 or 1xK grid, so flatten() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    axes = axes.flatten()
    indices = random.sample(range(len(dataset)), n)

    for ax, idx in zip(axes, indices):
        data = dataset[idx]
        G = to_networkx(data, to_undirected=True)

        # reshape(-1) keeps a 1-D array even for a single-atom molecule (squeeze() would give 0-d).
        node_colors = data.x.reshape(-1).numpy()
        pos = nx.kamada_kawai_layout(G)
        # Fixed vmin/vmax: without them matplotlib rescales the colour map per panel.
        # NOTE(review): tab20 has 20 colours for 21 atom-type indices, so indices 19 and 20 share one.
        nx.draw(G, pos, ax=ax, with_labels=False, node_color=node_colors,
                cmap=plt.cm.tab20, vmin=0, vmax=20, node_size=100)
        # PyG's ZINC target is penalized logP (logP minus synthetic accessibility and a ring penalty).
        ax.set_title(f"Penalized logP: {data.y.item():.2f}")
        ax.axis('off')

    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

visualize_grid(dataset=train_dataset)

In [None]:
def plot_dataset_stats(dataset: ZINC) -> None:
    """Plot histograms (with KDE) of the target and of basic graph-size statistics.

    Panels show the regression target, atoms per molecule, bonds per molecule
    and average node degree. Bonds are counted as ``num_edges // 2``, which
    assumes each undirected bond is stored as two directed edges.

    Args:
        dataset: Iterable of PyG ``Data`` graphs with a scalar ``y`` and at
            least one node each (the degree divides by ``num_nodes``).
    """
    ys, num_nodes, num_edges, avg_degrees = [], [], [], []
    # Collect all four statistics in a single pass over the dataset.
    for data in dataset:
        n_bonds = data.num_edges // 2
        ys.append(data.y.item())
        num_nodes.append(data.num_nodes)
        num_edges.append(n_bonds)
        avg_degrees.append(2 * n_bonds / data.num_nodes)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    # (axis, values, colour, title, x-label). The y-axis of every histplot is already "Count".
    panels = (
        (axes[0, 0], ys, 'skyblue', "Target (penalized logP) Distribution", "Penalized logP"),
        (axes[0, 1], num_nodes, 'salmon', "Number of Nodes (Atoms) Distribution", "Atoms per molecule"),
        (axes[1, 0], num_edges, 'lightgreen', "Number of Edges (Bonds) Distribution", "Bonds per molecule"),
        (axes[1, 1], avg_degrees, 'orange', "Average Degree Distribution", "Degree"),
    )
    for ax, values, color, title, xlabel in panels:
        sns.histplot(values, kde=True, ax=ax, color=color)
        ax.set_title(title)
        ax.set_xlabel(xlabel)

    plt.tight_layout()
    plt.show()

plot_dataset_stats(dataset=train_dataset)

In [None]:
class FlexibleGNN(torch.nn.Module):
    """Graph-level regressor for ZINC with a selectable message-passing layer.

    Pipeline:
      1. Embed atom-type indices (``data.x``) and bond-type indices
         (``data.edge_attr``) into ``hidden_channels`` dims.
      2. Apply ``num_layers`` x [conv -> BatchNorm -> ReLU -> dropout].
      3. Sum-pool per graph.
      4. Apply a 2-layer MLP that returns one scalar per graph, shape
         ``(num_graphs, 1)``.

    ``config.model_type`` selects the conv:
      * ``"GINE"``: GIN with edge features. The node-update MLP is
        Linear(h, 2h) -> BatchNorm -> ReLU -> Linear(2h, h).
      * ``"GAT"`` / ``"GATv2"``: ``heads`` attention heads of width
        ``hidden_channels // heads``. PyG concatenates the heads by default,
        giving back ``hidden_channels`` features.

    Raises:
        ValueError: For an unknown ``model_type``, or when an attention model's
            ``heads`` is not a positive divisor of ``hidden_channels``.
    """

    # Embedding vocabulary sizes for the categorical atom / bond type indices.
    # NOTE(review): sized for PyG's ZINC subset; confirm they also cover subset=False.
    NUM_ATOM_TYPES = 21
    NUM_BOND_TYPES = 4

    def __init__(self, config: Config) -> None:
        super().__init__()
        if config.model_type not in ("GINE", "GAT", "GATv2"):
            raise ValueError(f"Unknown model type: {config.model_type}")
        is_attention = config.model_type in ("GAT", "GATv2")
        if is_attention and (config.heads <= 0 or config.hidden_channels % config.heads != 0):
            raise ValueError(
                f"{config.model_type} needs heads ({config.heads}) to be a positive divisor of "
                f"hidden_channels ({config.hidden_channels}) so the concatenated heads match BatchNorm"
            )

        self.model_type = config.model_type
        self.num_layers = config.num_layers
        self.dropout = config.dropout
        hidden = config.hidden_channels

        self.node_emb = torch.nn.Embedding(self.NUM_ATOM_TYPES, hidden)
        self.edge_emb = torch.nn.Embedding(self.NUM_BOND_TYPES, hidden)

        # Same construction order as before (conv, then its BatchNorm, per layer) so the
        # parameter-init RNG sequence and the state_dict keys are unchanged.
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(config.num_layers):
            self.convs.append(self._make_conv(config))
            self.batch_norms.append(BatchNorm1d(hidden))

        self.out_lin = Sequential(Linear(hidden, hidden), ReLU(), Linear(hidden, 1))

    @staticmethod
    def _make_conv(config: Config) -> torch.nn.Module:
        """Build one message-passing layer of type ``config.model_type``."""
        hidden = config.hidden_channels
        if config.model_type == "GINE":
            mlp = Sequential(
                Linear(hidden, 2 * hidden),
                BatchNorm1d(2 * hidden),
                ReLU(),
                Linear(2 * hidden, hidden),
            )
            return GINEConv(mlp, train_eps=True, edge_dim=hidden)
        if config.model_type == "GAT":
            return GATConv(hidden, hidden // config.heads, heads=config.heads, edge_dim=hidden)
        if config.model_type == "GATv2":
            return GATv2Conv(
                hidden, hidden // config.heads, heads=config.heads, edge_dim=hidden, share_weights=True
            )
        raise ValueError(f"Unknown model type: {config.model_type}")

    @staticmethod
    def _as_index_vector(t: torch.Tensor) -> torch.Tensor:
        """Drop a trailing singleton feature axis while always keeping one dimension.

        ``squeeze()`` would turn a batch holding a single node (or edge) into a
        0-d tensor, so the embedding output would lose its row axis.
        """
        return t.squeeze(-1) if t.dim() > 1 else t

    def forward(self, data: Data) -> torch.Tensor:
        """Return predictions of shape ``(num_graphs, 1)`` for a (batched) ``Data``."""
        x = self.node_emb(self._as_index_vector(data.x))
        edge_attr = self.edge_emb(self._as_index_vector(data.edge_attr))

        for conv, batch_norm in zip(self.convs, self.batch_norms):
            # GINE, GAT and GATv2 all accept (x, edge_index, edge_attr=...), so no per-type branch is needed.
            x = conv(x, data.edge_index, edge_attr=edge_attr)
            x = batch_norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = global_add_pool(x, data.batch)
        return self.out_lin(x)

In [None]:
def train_epoch(model: torch.nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, device: str) -> float:
    """Run one optimisation pass over ``loader`` and return the epoch's training loss.

    Each batch's loss is weighted by the number of graphs in the batch, and the
    sum is divided by ``len(loader.dataset)``. With a batch-mean criterion
    such as ``MSELoss``, this gives the per-graph mean loss over the epoch.
    """
    model.train()
    weighted_loss_sum = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Add a trailing axis so the target matches the (B, 1) model output; assumes y is 1-D per batch.
        target = batch.y[:, None]
        batch_loss = criterion(model(batch), target)
        batch_loss.backward()
        optimizer.step()
        weighted_loss_sum += batch_loss.item() * batch.num_graphs
    return weighted_loss_sum / len(loader.dataset)

@torch.no_grad()
def evaluate(model: torch.nn.Module, loader: DataLoader, criterion: torch.nn.Module, device: str) -> T.Tuple[float, float]:
    """Score ``model`` on ``loader`` without gradients.

    Returns:
        ``(loss, mae)``. ``loss`` is the graph-weighted sum of per-batch
        ``criterion`` values divided by ``len(loader.dataset)``. ``mae`` is the
        summed absolute prediction error divided by the same count.
    """
    model.eval()
    loss_sum, abs_err_sum = 0.0, 0.0
    for batch in loader:
        batch = batch.to(device)
        target = batch.y[:, None]  # trailing axis to match the (B, 1) model output
        pred = model(batch)
        loss_sum += criterion(pred, target).item() * batch.num_graphs
        abs_err_sum += (pred - target).abs().sum().item()
    n_graphs = len(loader.dataset)
    return loss_sum / n_graphs, abs_err_sum / n_graphs

In [None]:
def calculate_baseline_mae(train_dataset: ZINC, val_dataset: ZINC) -> float:
    """MAE on ``val_dataset`` of a model that always predicts the mean training target."""
    def targets_of(dataset_split):
        # One scalar per graph, collected into a default-dtype (float32) tensor.
        return torch.tensor([graph.y.item() for graph in dataset_split])

    train_mean = targets_of(train_dataset).mean()
    return (targets_of(val_dataset) - train_mean).abs().mean().item()

def train_eval_pipeline(config: Config, train_loader: DataLoader, val_loader: DataLoader, run_name: str) -> T.Tuple[float, float]:
    """Train one ``FlexibleGNN`` and keep the checkpoint with the best validation MAE.

    The global RNGs are reseeded with ``config.seed`` first, so model
    initialisation is repeatable across runs. Per-epoch train loss, val loss,
    val MAE and learning rate are logged to TensorBoard under
    ``<config.log_dir>/<run_name>``. Whenever the val MAE improves, the weights
    are saved there as ``best_model.pth``.

    NOTE(review): re-running with the same ``run_name`` reuses that directory,
    overwriting the checkpoint and adding another event file next to the old one.

    Args:
        config: Model, optimiser and logging settings.
        train_loader: Batches used for optimisation.
        val_loader: Batches used for model selection.
        run_name: Sub-directory name for this run's logs and checkpoint.

    Returns:
        ``(final_train_loss, best_val_mae)``: the last epoch's training loss and
        the lowest validation MAE seen. If ``config.epochs < 1`` this is
        ``(nan, inf)``.
    """
    set_seed(config.seed)
    model = FlexibleGNN(config).to(config.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    criterion = torch.nn.MSELoss()

    log_dir = Path(config.log_dir) / run_name
    log_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = log_dir / "best_model.pth"

    best_val_mae = float("inf")
    # NaN rather than 0.0 so "no epoch ran" cannot be mistaken for a perfect fit.
    final_train_loss = float("nan")

    writer = SummaryWriter(str(log_dir))
    try:
        for epoch in range(1, config.epochs + 1):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, config.device)
            val_loss, val_mae = evaluate(model, val_loader, criterion, config.device)

            writer.add_scalar("train/loss", train_loss, epoch)
            writer.add_scalar("val/loss", val_loss, epoch)
            writer.add_scalar("val/mae", val_mae, epoch)
            writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], epoch)

            if val_mae < best_val_mae:
                best_val_mae = val_mae
                torch.save(model.state_dict(), checkpoint_path)

            final_train_loss = train_loss
    finally:
        # Flush and release the event file even if training raises (e.g. CUDA OOM, KeyboardInterrupt).
        writer.close()

    return final_train_loss, best_val_mae

In [None]:
def zinc_to_smiles(data: Data) -> str:
    """Placeholder for turning a PyG ZINC graph back into a SMILES string.

    PyG's ZINC ``Data`` objects store only integer atom/bond type indices, not
    SMILES. This notebook has no index -> element/bond-order mapping, so an
    RDKit ``Mol`` cannot be rebuilt without ambiguity, and RDKit descriptors
    cannot be computed. The XGBoost baseline therefore uses graph statistics
    and atom/bond-type histograms from ``compute_graph_descriptors`` instead.

    Raises:
        NotImplementedError: Always. Raising is better than silently returning
            ``None`` despite the ``str`` return annotation.
    """
    raise NotImplementedError(
        "zinc_to_smiles requires an atom/bond type-index -> RDKit mapping, which PyG's ZINC dataset does not provide"
    )

def compute_graph_descriptors(
    dataset: ZINC,
    num_atom_types: int = 21,
    num_bond_types: int = 4,
) -> T.Tuple[np.ndarray, np.ndarray]:
    """Build one fixed-length feature vector per molecule for tree-based baselines.

    Column layout of ``X``:

    * 0: number of atoms (``num_nodes``)
    * 1: number of bonds (``num_edges // 2``; assumes each bond is stored in both directions)
    * next ``num_atom_types`` columns: count of each atom-type index in ``x``
    * next ``num_bond_types`` columns: count of each bond-type index in
      ``edge_attr``, counted over stored (directed) edges

    Args:
        dataset: Iterable of PyG ``Data`` graphs with integer ``x`` / ``edge_attr``
            type indices and a scalar ``y``.
        num_atom_types: Atom-type histogram length. The default matches the
            node embedding size used by ``FlexibleGNN``.
        num_bond_types: Bond-type histogram length. The default matches the
            edge embedding size.

    Returns:
        ``(X, y)``: an int64 matrix of shape
        ``(len(dataset), 2 + num_atom_types + num_bond_types)`` (also for an
        empty dataset) and a float target vector of shape ``(len(dataset),)``.

    Raises:
        ValueError: If a graph holds a type index ``>=`` its histogram length.
            Otherwise rows would have different lengths.
    """
    n_features = 2 + num_atom_types + num_bond_types
    features, targets = [], []
    for data in dataset:
        atom_hist = _type_histogram(data.x, num_atom_types, "atom")
        bond_hist = _type_histogram(data.edge_attr, num_bond_types, "bond")
        counts = [data.num_nodes, data.num_edges // 2]
        features.append(np.concatenate((counts, atom_hist, bond_hist)))
        targets.append(data.y.item())
    # reshape keeps a 2-D (0, n_features) matrix when the dataset is empty.
    X = np.array(features, dtype=np.int64).reshape(-1, n_features)
    return X, np.array(targets, dtype=float)


def _type_histogram(types: torch.Tensor, num_types: int, kind: str) -> np.ndarray:
    """Count how often each type index in ``[0, num_types)`` occurs in ``types``."""
    # Drop only a trailing singleton axis: squeeze() would turn a single atom/bond
    # into a 0-d tensor, which bincount rejects.
    flat = types.squeeze(-1) if types.dim() > 1 else types
    hist = torch.bincount(flat, minlength=num_types)
    if hist.numel() > num_types:
        raise ValueError(
            f"{kind} type index {hist.numel() - 1} is out of range for num_{kind}_types={num_types}"
        )
    return hist.numpy()

def run_xgboost_baseline(
    train_dataset: ZINC,
    val_dataset: ZINC,
    n_estimators: int = 100,
    max_depth: int = 6,
    learning_rate: float = 0.1,
    random_state: int = 42,
) -> float:
    """Fit gradient-boosted trees on hand-crafted graph descriptors and return the val MAE.

    Features come from ``compute_graph_descriptors``: atom/bond counts plus
    atom- and bond-type histograms. The model therefore sees composition but
    not how atoms are connected. The defaults reproduce the hyper-parameters
    this baseline always used.

    Args:
        train_dataset: Graphs used to fit the regressor.
        val_dataset: Graphs used to score it.
        n_estimators, max_depth, learning_rate: Forwarded to ``xgb.XGBRegressor``.
        random_state: Seed for the regressor.

    Returns:
        Mean absolute error on ``val_dataset`` as a plain Python float.
    """
    X_train, y_train = compute_graph_descriptors(train_dataset)
    X_val, y_val = compute_graph_descriptors(val_dataset)

    regressor = xgb.XGBRegressor(
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
        random_state=random_state,
        n_jobs=-1,
    )
    regressor.fit(X_train, y_train)

    preds = regressor.predict(X_val)
    return float(mean_absolute_error(y_val, preds))

In [None]:
def run_experiments(base_config: Config, configs: T.List[T.Tuple[str, Config]]) -> pd.DataFrame:
    """Score two non-GNN baselines, then train and score every GNN config.

    The rows come out in this order: mean baseline, XGBoost baseline, then one
    row per ``(name, cfg)`` pair in ``configs``. The columns are
    ``Experiment``, ``Config``, ``Final Train Loss`` and ``Best Val MAE``.
    Progress lines are printed along the way.

    NOTE(review): this reads the notebook globals ``train_dataset``,
    ``val_dataset``, ``train_loader`` and ``val_loader``. ``base_config`` is
    accepted but not used. Because the loaders are shared, a per-experiment
    ``batch_size`` override would have no effect.
    """
    def make_row(experiment, config_desc, train_loss, val_mae):
        return {
            "Experiment": experiment,
            "Config": config_desc,
            "Final Train Loss": train_loss,
            "Best Val MAE": val_mae,
        }

    baseline_mae = calculate_baseline_mae(train_dataset, val_dataset)
    rows = [make_row("Mean Baseline", "N/A", np.nan, baseline_mae)]
    print(f"Mean Baseline MAE: {baseline_mae:.4f}")

    xgb_mae = run_xgboost_baseline(train_dataset, val_dataset)
    rows.append(make_row("XGBoost Baseline", "Graph Stats + Atom Hist", np.nan, xgb_mae))
    print(f"XGBoost Baseline MAE: {xgb_mae:.4f}")

    for run_name, run_cfg in configs:
        print(f"Running experiment: {run_name}")
        train_loss, val_mae = train_eval_pipeline(run_cfg, train_loader, val_loader, run_name)
        rows.append(make_row(run_name, str(run_cfg), train_loss, val_mae))
        print(f"Result: Train Loss={train_loss:.4f}, Val MAE={val_mae:.4f}")

    return pd.DataFrame(rows)

In [None]:
# (run name, config) pairs. The default run reuses `config` itself; each variant
# overrides base fields via `replace`.
experiment_configs = [("GINE_Default", config)] + [
    (run_name, replace(config, **overrides))
    for run_name, overrides in (
        ("GINE_Deep", {"num_layers": 6}),
        ("GAT_Default", {"model_type": "GAT"}),
        ("GATv2_Default", {"model_type": "GATv2"}),
    )
]

df_results = run_experiments(base_config=config, configs=experiment_configs)
display(df_results)

In [None]:
def evaluate_best_model(
    results_df: pd.DataFrame,
    test_loader: DataLoader,
    base_log_dir: str,
    configs: T.Optional[T.Mapping[str, Config]] = None,
) -> float:
    """Reload the GNN run with the lowest validation MAE and report its test MAE.

    Rows whose ``Experiment`` contains "Baseline" are ignored. The winning
    checkpoint is read from ``<base_log_dir>/<experiment>/best_model.pth``.

    The model has to be rebuilt with exactly the architecture it was trained
    with, or ``load_state_dict`` fails. The config is chosen as follows:
      * from ``configs[experiment]``, when given;
      * otherwise, the architecture fields are parsed from the row's ``Config``
        column (the ``str(cfg)`` recorded by ``run_experiments``) and applied
        to the global ``config``;
      * fields that cannot be parsed fall back to guessing from the
        experiment name.

    Args:
        results_df: Output of ``run_experiments``.
        test_loader: Loader over the held-out test split.
        base_log_dir: Root directory the runs logged to (``config.log_dir``).
        configs: Optional mapping from experiment name to the exact Config used.

    Returns:
        Test MAE of the best model.

    Raises:
        ValueError: If ``results_df`` contains no GNN experiments.
    """
    dl_results = results_df[~results_df["Experiment"].str.contains("Baseline")]
    if dl_results.empty:
        raise ValueError("results_df contains no GNN experiments to evaluate")
    best_row = dl_results.sort_values("Best Val MAE").iloc[0]
    best_exp_name = best_row["Experiment"]
    print(f"Best Experiment: {best_exp_name} with Val MAE: {best_row['Best Val MAE']:.4f}")

    if configs is not None and best_exp_name in configs:
        best_config = configs[best_exp_name]
    else:
        best_config = _rebuild_config(best_exp_name, str(best_row["Config"]))

    model = FlexibleGNN(best_config).to(best_config.device)
    model_path = Path(base_log_dir) / best_exp_name / "best_model.pth"
    # A state_dict holds only tensors, so weights_only=True loads it without unpickling
    # arbitrary objects. map_location allows loading a checkpoint saved on another device.
    state_dict = torch.load(model_path, map_location=best_config.device, weights_only=True)
    model.load_state_dict(state_dict)

    criterion = torch.nn.MSELoss()
    _, test_mae = evaluate(model, test_loader, criterion, best_config.device)
    print(f"Test MAE for best model: {test_mae:.4f}")
    return test_mae


def _rebuild_config(exp_name: str, config_repr: str) -> Config:
    """Recreate a run's architecture on top of the notebook-level ``config``.

    Parses ``model_type``, ``hidden_channels``, ``num_layers`` and ``heads`` from
    a dataclass repr such as ``Config(..., model_type='GAT', num_layers=4, ...)``.
    Fields not found there fall back to guesses from the experiment name.
    """
    import re

    overrides: T.Dict[str, T.Any] = {}

    model_match = re.search(r"\bmodel_type='([^']*)'", config_repr)
    if model_match:
        overrides["model_type"] = model_match.group(1)
    elif "GATv2" in exp_name:  # check before "GAT", which is a substring of it
        overrides["model_type"] = "GATv2"
    elif "GAT" in exp_name:
        overrides["model_type"] = "GAT"
    else:
        overrides["model_type"] = "GINE"

    for field_name in ("hidden_channels", "num_layers", "heads"):
        int_match = re.search(rf"\b{field_name}=(\d+)", config_repr)
        if int_match:
            overrides[field_name] = int(int_match.group(1))
    if "num_layers" not in overrides:
        overrides["num_layers"] = 6 if "Deep" in exp_name else 4

    return replace(config, **overrides)

evaluate_best_model(results_df=df_results, test_loader=test_loader, base_log_dir=config.log_dir)

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs