In [1]:
# import torch
# print(torch.__version__)


In [2]:
!pip install torch-scatter torch-sparse torch-scatter torch-geometric ogb  -f https://data.pyg.org/whl/torch-2.5.1+cu121.html

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_scatter-2.1.2%2Bpt25cu121-cp310-cp310-linux_x86_64.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m83.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_sparse-0.6.18%2Bpt25cu121-cp310-cp310-linux_x86_64.whl (5.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m91.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hCollecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ogb
  Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)
Collecting outdated>=0.2.0 (from ogb)
  Downloadin

In [3]:
import torch
import torch_sparse
import torch_scatter
from ogb.graphproppred.mol_encoder import BondEncoder, AtomEncoder
from torch import nn as nn
from torch.nn import functional as F
from torch_geometric import nn as nng
from torch_geometric.data import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.graphproppred import Evaluator
from copy import copy
import os
import numpy as np
import pandas as pd
from datetime import datetime

In [None]:
# Use CUDA when available, otherwise CPU. Apple MPS is deliberately not used:
# it is currently slower than CPU here due to missing int64 min/max ops.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU and
    all CUDA devices). It also forces cuDNN to choose deterministic kernels.

    Args:
        seed: Value applied to every generator.
    """
    import random  # local import: the notebook's import cell does not bring it in

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED only affects Python interpreters started after this
    # point. String hashing in the already-running kernel is unchanged.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


set_seed()

In [11]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    The hidden width is ``2 * in_dim``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden_dim = 2 * in_dim
        self.network = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        out = self.network(x)
        return out


class OGBMolEmbedding(nn.Module):
    """Replace raw OGB atom and bond feature indices with learned embeddings.

    Args:
        dim: Embedding width used for both atoms and bonds.
    """

    def __init__(self, dim):
        super().__init__()
        self.atom_encoder = AtomEncoder(emb_dim=dim)
        self.bond_encoder = BondEncoder(emb_dim=dim)

    def forward(self, data):
        # Work on a shallow copy so the caller's batch keeps its raw features.
        embedded = copy(data)
        node_emb = self.atom_encoder(data.x)
        edge_emb = self.bond_encoder(data.edge_attr)
        embedded.x, embedded.edge_attr = node_emb, edge_emb
        return embedded


class VNAgg(nn.Module):
    """GIN-style virtual-node update.

    Computes ``v' = MLP((1 + eps) * v + sum of node embeddings in the graph)``,
    followed by BatchNorm1d and ReLU.

    Args:
        dim: Embedding width of the nodes and the virtual node.
        train_eps: If true, ``eps`` is a learnable parameter.
        eps: Initial (or fixed) value of ``eps``.
    """

    def __init__(self, dim, train_eps=False, eps=0.0):
        super().__init__()
        self.mlp = nn.Sequential(MLP(dim, dim), nn.BatchNorm1d(dim), nn.ReLU())
        self.train_eps = train_eps
        if train_eps:
            self.eps = nn.Parameter(torch.Tensor([eps]))
        else:
            # A non-persistent buffer follows module.to(device) but stays out of
            # the state_dict, so checkpoints saved before this change still load.
            self.register_buffer("eps", torch.Tensor([eps]), persistent=False)

    def forward(self, virtual_node, embeddings, batch_idx):
        """Return updated virtual-node embeddings, one row per graph.

        Args:
            virtual_node: Current virtual-node embeddings, one row per graph.
            embeddings: Node embeddings for the whole batch.
            batch_idx: Graph index of each node (same length as ``embeddings``).
        """
        num_graphs = virtual_node.size(0)
        if batch_idx.numel() > 0:
            # size= guarantees exactly one row per graph, even when the last
            # graphs in the batch contain no nodes.
            sum_embeddings = nng.global_add_pool(embeddings, batch_idx, size=num_graphs)
        else:
            sum_embeddings = torch.zeros_like(virtual_node)
        virtual_node = (1 + self.eps) * virtual_node + sum_embeddings
        return self.mlp(virtual_node)


class GlobalPool(nn.Module):
    """Graph-level readout that wraps one of PyG's ``global_<fun>_pool`` functions.

    Args:
        fun: Pool name, case-insensitive (e.g. ``"mean"``, ``"add"``, ``"max"``).
    """

    def __init__(self, fun):
        super().__init__()
        pool_name = f"global_{fun.lower()}_pool"
        self.fun = getattr(nng, pool_name)

    def forward(self, data):
        # size=num_graphs keeps one output row per graph in the batch.
        return self.fun(data.x, data.batch, size=data.num_graphs)

In [14]:
class ConvBlock(nn.Module):
    """One GINE message-passing layer with optional virtual-node support.

    Steps per call:
      1. Optionally add each graph's virtual-node vector to its nodes.
      2. GINEConv, then BatchNorm.
      3. Activation (skipped on the last layer), then dropout.
      4. Optionally update ``data.virtual_node`` via ``VNAgg``.

    Args:
        dim: Node/edge embedding width (input and output).
        dropout: Dropout probability for node and virtual-node embeddings.
        activation: Callable applied after BatchNorm; a falsy value means identity.
        virtual_node: Read ``data.virtual_node`` and add it to node features.
        virtual_node_agg: Also update the virtual node after this layer
            (only used when ``virtual_node`` is true).
        last_layer: If true, skip the activation.
        train_vn_eps, vn_eps: Forwarded to ``VNAgg``.
    """

    def __init__(
        self,
        dim,
        dropout=0.5,
        activation=F.relu,
        virtual_node=False,
        virtual_node_agg=True,
        last_layer=False,
        train_vn_eps=False,
        vn_eps=0.0,
    ):
        super().__init__()
        self.conv = nng.GINEConv(MLP(dim, dim), train_eps=True)
        self.bn = nn.BatchNorm1d(dim)
        self.activation = activation if activation else nn.Identity()
        self.dropout_ratio = dropout
        self.last_layer = last_layer
        self.virtual_node = virtual_node
        # Stays a plain flag unless it is replaced by the VNAgg module below.
        self.virtual_node_agg = virtual_node_agg
        if virtual_node and virtual_node_agg:
            self.virtual_node_agg = VNAgg(dim, train_eps=train_vn_eps, eps=vn_eps)

    def forward(self, data):
        out = copy(data)
        batch_idx = out.batch
        h = out.x
        if self.virtual_node:
            # Broadcast each graph's virtual-node vector onto that graph's nodes.
            h = h + out.virtual_node[batch_idx]
        h = self.bn(self.conv(h, out.edge_index, out.edge_attr))
        if not self.last_layer:
            h = self.activation(h)
        h = F.dropout(h, self.dropout_ratio, training=self.training)
        if self.virtual_node and self.virtual_node_agg:
            updated_vn = self.virtual_node_agg(out.virtual_node, h, batch_idx)
            out.virtual_node = F.dropout(
                updated_vn, self.dropout_ratio, training=self.training
            )
        out.x = h
        return out

In [6]:
class GINENetwork(nn.Module):
    """Graph-level GINE predictor.

    Pipeline: OGB atom/bond embedding -> ``num_layers`` ConvBlocks ->
    global mean pool -> MLP head with ``out_dim`` outputs per graph.

    Args:
        hidden_dim: Embedding width used throughout the network.
        out_dim: Number of outputs per graph.
        num_layers: Number of ConvBlocks (must be >= 1).
        dropout: Dropout probability inside each ConvBlock.
        virtual_node: Use a learned virtual node shared within each graph.
        train_vn_eps, vn_eps: Forwarded to each ConvBlock's ``VNAgg``.

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(
        self,
        hidden_dim=100,
        out_dim=128,
        num_layers=3,
        dropout=0.5,
        virtual_node=False,
        train_vn_eps=False,
        vn_eps=0.0,
    ):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        block_kwargs = dict(
            dropout=dropout,
            virtual_node=virtual_node,
            train_vn_eps=train_vn_eps,
            vn_eps=vn_eps,
        )
        # Hidden layers update the virtual node. The last layer does not update
        # it and does not apply the activation.
        convs = [ConvBlock(hidden_dim, **block_kwargs) for _ in range(num_layers - 1)]
        convs.append(
            ConvBlock(hidden_dim, virtual_node_agg=False, last_layer=True, **block_kwargs)
        )
        self.network = nn.Sequential(OGBMolEmbedding(hidden_dim), *convs)
        self.aggregate = nn.Sequential(
            GlobalPool("mean"),
            MLP(hidden_dim, out_dim),
        )

        self.virtual_node = virtual_node
        if self.virtual_node:
            # Learned initial virtual-node embedding, zero-initialised and
            # expanded to one row per graph in forward().
            self.v0 = nn.Parameter(torch.zeros(1, hidden_dim), requires_grad=True)

    def forward(self, data):
        """Return per-graph outputs of shape ``(num_graphs, out_dim)`` for a PyG batch."""
        # nn.Module has no `device` attribute; use the device of the parameters.
        model_device = next(self.parameters()).device
        # Shallow-copy first so moving tensors and attaching `virtual_node` do
        # not mutate the caller's batch object.
        data = copy(data).to(model_device, non_blocking=True)

        if self.virtual_node:
            data.virtual_node = self.v0.expand(data.num_graphs, self.v0.shape[-1])
        H = self.network(data)
        return self.aggregate(H)

In [8]:
# Load ogbg-molpcba (downloaded and processed under ./dataset on first use).
# The train/valid/test split indices are fetched in the next cell.
dataset = PygGraphPropPredDataset(root="dataset", name="ogbg-molpcba")

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/pcba.zip


Downloaded 0.04 GB: 100%|██████████| 39/39 [00:01<00:00, 32.29it/s]
Processing...


Extracting dataset/pcba.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 437929/437929 [00:04<00:00, 98155.40it/s] 


Converting graphs into PyG objects...


100%|██████████| 437929/437929 [00:13<00:00, 33354.76it/s]


Saving...


Done!
  self.data, self.slices = torch.load(self.processed_paths[0])


In [9]:
split_idx = dataset.get_idx_split()
batch_size = 100
# One loader per split; only the training split is shuffled.
# NOTE: GINETrainer builds its own loaders; these are for interactive inspection.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(
        dataset[split_idx[split]],
        batch_size=batch_size,
        shuffle=split == "train",
        num_workers=4,
    )
    for split in ("train", "valid", "test")
)



In [15]:
# def train(
#     dataset,
#     train_dataloader,
#     valid_dataloader,
#     num_layers=5,
#     hidden_dim=400,
#     dropout=0.5,
#     virtual_node=True,
#     train_vn_eps=False,
#     vn_eps=0.0,
#     lr=0.001,
#     epochs=100,
# ):
#     output_dim = dataset.num_tasks
#     model = GINENetwork(
#         hidden_dim=hidden_dim,
#         out_dim=output_dim,
#         num_layers=num_layers,
#         dropout=dropout,
#         virtual_node=virtual_node,
#         train_vn_eps=train_vn_eps,
#         vn_eps=vn_eps,
#     ).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#     criterion = nn.BCEWithLogitsLoss()

#     for epoch in range(epochs):
#         model.train()
#         train_loss = 0
#         for batch in train_dataloader:
#             batch = batch.to(device)
#             optimizer.zero_grad()
#             y_pred = model(batch)
#             y_true = batch.y.float()
#             y_available = ~torch.isnan(y_true)
#             loss = criterion(y_pred[y_available], y_true[y_available])
#             loss.backward()
#             optimizer.step()
#             train_loss += loss.item()
#         train_loss /= len(train_dataloader)
#         if epoch % 10 == 0:
#             model.eval()
#             valid_loss = 0
#             for batch in valid_dataloader:
#                 batch = batch.to(device)
#                 y_pred = model(batch).to(device)
#                 y_true = batch.y.float()
#                 y_available = ~torch.isnan(y_true)
#                 loss = criterion(y_pred[y_available], y_true[y_available])
#                 valid_loss += loss.item()
#             valid_loss /= len(valid_dataloader)
#             print(
#                 f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}"
#             )
#     return model


# model = train(
#     dataset,
#     train_dataloader,
#     valid_dataloader,
#     num_layers=5,
#     hidden_dim=400,
#     dropout=0.5,
#     virtual_node=True,
#     train_vn_eps=False,
#     vn_eps=0.0,
#     lr=0.001,
#     epochs=100,
# )

Epoch 1, Train Loss: 0.0518, Valid Loss: 0.0581
Epoch 11, Train Loss: 0.0395, Valid Loss: 0.0493


KeyboardInterrupt: 

In [None]:
class GINETrainer:
    """Build dataset, dataloaders, GINE model and optimizer; then train and evaluate.

    Args:
        dataset_name: OGB graph-property-prediction dataset name.
        num_layers, hidden_dim, dropout, virtual_node, train_vn_eps, vn_eps:
            Forwarded to ``GINENetwork``.
        lr: Adam learning rate.
        batch_size: Graphs per mini-batch for all loaders.
        num_workers: DataLoader worker processes (0 = load in the main process).
    """

    def __init__(
        self,
        dataset_name="ogbg-molpcba",
        num_layers=5,
        hidden_dim=400,
        dropout=0.5,
        virtual_node=True,
        train_vn_eps=False,
        vn_eps=0.0,
        lr=0.001,
        batch_size=100,
        num_workers=4,
    ):
        # Initialize dataset
        self.dataset_name = dataset_name

        self.dataset = PygGraphPropPredDataset(name=dataset_name)
        self.split_idx = self.dataset.get_idx_split()

        # Initialize dataloaders
        self.train_loader = DataLoader(
            self.dataset[self.split_idx["train"]],
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            # Pinned host memory only helps host->GPU copies.
            pin_memory=device.type == "cuda",
            # Keep workers alive between epochs; PyTorch rejects
            # persistent_workers=True when num_workers == 0.
            persistent_workers=num_workers > 0,
        )
        self.valid_loader = DataLoader(
            self.dataset[self.split_idx["valid"]],
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        self.test_loader = DataLoader(
            self.dataset[self.split_idx["test"]],
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        # Initialize model
        self.model = GINENetwork(
            hidden_dim=hidden_dim,
            out_dim=self.dataset.num_tasks,
            num_layers=num_layers,
            dropout=dropout,
            virtual_node=virtual_node,
            train_vn_eps=train_vn_eps,
            vn_eps=vn_eps,
        ).to(device)

        # Initialize optimizer and criterion
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.BCEWithLogitsLoss()

        # Initialize evaluator
        self.evaluator = Evaluator(name=self.dataset_name)

    def train(self, epochs=100):
        """Train for ``epochs`` epochs, validating periodically.

        Validation runs every ``max(1, epochs // 10)`` epochs and after the final
        epoch. Whenever the validation AP improves, the model weights are saved
        to ``best_model.pt``. The in-memory model keeps the last-epoch weights.

        Args:
            epochs: Number of passes over the training split.

        Returns:
            The best validation AP observed (0 if no evaluation ran).
        """
        eval_every = max(1, epochs // 10)
        best_valid_ap = 0
        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0

            for batch in self.train_loader:
                batch = batch.to(device, non_blocking=True)
                self.optimizer.zero_grad()

                y_pred = self.model(batch)
                y_true = batch.y.float()
                # Labels may be NaN for unmeasured tasks; those entries are excluded.
                y_available = ~torch.isnan(y_true)

                loss = self.criterion(y_pred[y_available], y_true[y_available])
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()

            train_loss /= max(1, len(self.train_loader))

            if epoch % eval_every == 0 or epoch == epochs - 1:
                # Evaluate on validation set
                valid_ap = self.eval(split="valid")
                print(
                    f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Valid AP: {valid_ap:.4f}"
                )

                # Save best model
                if valid_ap > best_valid_ap:
                    best_valid_ap = valid_ap
                    torch.save(self.model.state_dict(), "best_model.pt")
        return best_valid_ap

    @torch.no_grad()
    def eval(self, split="valid"):
        """Return the OGB evaluator's ``ap`` metric for the current model on a split.

        Args:
            split: One of ``"valid"``, ``"test"`` or ``"train"``.

        Raises:
            ValueError: If ``split`` is not a known split name.
        """
        loaders = {
            "valid": self.valid_loader,
            "test": self.test_loader,
            "train": self.train_loader,
        }
        if split not in loaders:
            raise ValueError(f"unknown split {split!r}; expected one of {sorted(loaders)}")
        self.model.eval()
        y_true_list = []
        y_pred_list = []

        for batch in loaders[split]:
            batch = batch.to(device, non_blocking=True)
            y_pred = self.model(batch)
            y_true_list.append(batch.y.detach().cpu())
            y_pred_list.append(y_pred.detach().cpu())

        y_true = torch.cat(y_true_list, dim=0).numpy()
        y_pred = torch.cat(y_pred_list, dim=0).numpy()

        input_dict = {"y_true": y_true, "y_pred": y_pred}
        return self.evaluator.eval(input_dict)["ap"]


# Usage example: a short 10-epoch run, then score the model on the test split.
# The test score uses the final (last-epoch) weights, not the best_model.pt checkpoint.
trainer = GINETrainer(
    # architecture
    num_layers=5,
    hidden_dim=300,
    dropout=0.5,
    # virtual node
    virtual_node=True,
    train_vn_eps=False,
    vn_eps=0.0,
    # optimisation
    batch_size=32,
    lr=0.001,
)
trainer.train(epochs=10)

test_ap = trainer.eval(split="test")
print(f"Test AP: {test_ap:.4f}")

In [None]:
def run_multiple_experiments(
    n_runs=10,
    num_layers=5,
    hidden_dim=400,
    dropout=0.5,
    virtual_node=True,
    train_vn_eps=False,
    vn_eps=0.0,
    lr=0.001,
    batch_size=32,
    epochs=100,
    output_file="experiment_results.csv",
    base_seed=42,
):
    """
    Run multiple training experiments and save results to CSV

    Each run reseeds with ``base_seed + run``, trains a fresh GINETrainer, and
    records the final-model validation and test AP. The CSV is rewritten after
    every run, so partial results survive an interruption.

    Args:
        n_runs: Number of experimental runs (must be >= 1)
        num_layers: Number of GNN layers
        hidden_dim: Hidden dimension size
        dropout: Dropout rate
        virtual_node: Whether to use virtual node
        train_vn_eps: Whether to train virtual node epsilon
        vn_eps: Virtual node epsilon value
        lr: Learning rate
        batch_size: Batch size
        epochs: Number of epochs
        output_file: Path to save results CSV
        base_seed: Seed of the first run; run i uses ``base_seed + i``

    Returns:
        DataFrame with one row per run. The printed std is NaN when n_runs == 1.

    Raises:
        ValueError: If ``n_runs < 1``.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Initialize results storage
    results = []

    # Create unique experiment ID using timestamp
    experiment_id = datetime.now().strftime("%Y%m%d_%H%M%S")

    for run in range(n_runs):
        print(f"\nStarting Run {run + 1}/{n_runs}")

        # Set different seed for each run
        seed = base_seed + run
        set_seed(seed)

        # Initialize trainer
        trainer = GINETrainer(
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            dropout=dropout,
            virtual_node=virtual_node,
            train_vn_eps=train_vn_eps,
            vn_eps=vn_eps,
            lr=lr,
            batch_size=batch_size,
        )

        # Train model
        trainer.train(epochs=epochs)

        # Get final validation and test AP
        valid_ap = trainer.eval(split="valid")
        test_ap = trainer.eval(split="test")

        # Release this run's model, optimizer and dataloader workers before the
        # next run builds its own, so two runs never hold GPU memory at once.
        del trainer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Store results
        run_results = {
            "experiment_id": experiment_id,
            "run": run + 1,
            "seed": seed,
            "num_layers": num_layers,
            "hidden_dim": hidden_dim,
            "dropout": dropout,
            "virtual_node": virtual_node,
            "train_vn_eps": train_vn_eps,
            "vn_eps": vn_eps,
            "lr": lr,
            "batch_size": batch_size,
            "epochs": epochs,
            "valid_ap": valid_ap,
            "test_ap": test_ap,
        }
        results.append(run_results)

        # Save intermediate results after each run
        pd.DataFrame(results).to_csv(output_file, index=False)

        print(f"Run {run + 1} Results:")
        print(f"Validation AP: {valid_ap:.4f}")
        print(f"Test AP: {test_ap:.4f}")

    # Calculate and print summary statistics
    df = pd.DataFrame(results)
    summary = df[["valid_ap", "test_ap"]].agg(["mean", "std"])

    print("\nSummary Statistics:")
    print(
        f"Validation AP: {summary['valid_ap']['mean']:.4f} ± {summary['valid_ap']['std']:.4f}"
    )
    print(
        f"Test AP: {summary['test_ap']['mean']:.4f} ± {summary['test_ap']['std']:.4f}"
    )

    return df

In [None]:
# Experiments with Virtual Nodes: 10 seeded runs of 100 epochs each.
results_df_vns = run_multiple_experiments(
    # repetitions / bookkeeping
    n_runs=10,
    epochs=100,
    output_file="gine_molpcba_results_virtual_nodes.csv",
    # architecture
    num_layers=5,
    hidden_dim=400,
    dropout=0.5,
    # virtual node (enabled)
    virtual_node=True,
    train_vn_eps=False,
    vn_eps=0.0,
    # optimisation
    lr=0.001,
    batch_size=32,
)

In [None]:
# Experiments without Virtual Nodes: same settings as the virtual-node cell,
# except virtual_node=False.
results_df_no_vns = run_multiple_experiments(
    # repetitions / bookkeeping
    n_runs=10,
    epochs=100,
    output_file="gine_molpcba_results_no_virtual_nodes.csv",
    # architecture
    num_layers=5,
    hidden_dim=400,
    dropout=0.5,
    # virtual node (disabled; eps settings are unused)
    virtual_node=False,
    train_vn_eps=False,
    vn_eps=0.0,
    # optimisation
    lr=0.001,
    batch_size=32,
)