In [None]:
# import torch
# print(torch.__version__)


In [None]:
!pip install torch-scatter torch-sparse torch-scatter torch-geometric ogb  -f https://data.pyg.org/whl/torch-2.5.1+cu121.html

In [None]:
import torch
import torch_sparse
import torch_scatter
from ogb.graphproppred.mol_encoder import BondEncoder, AtomEncoder
from torch import nn as nn
from torch.nn import functional as F
from torch_geometric import nn as nng
from torch_geometric.data import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset
from copy import copy

In [None]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    The hidden layer is twice as wide as the input.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden_dim = 2 * in_dim
        # Layers are instantiated in forward order, so module indices inside
        # ``self.network`` are 0 (Linear), 1 (BatchNorm1d), 2 (ReLU), 3 (Linear).
        self.network = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        projected = self.network(x)
        return projected


class OGBMolEmbedding(nn.Module):
    """Encodes the atom and bond features of a graph batch with OGB encoders.

    Args:
        dim: Embedding size, passed as ``emb_dim`` to both ``AtomEncoder``
            and ``BondEncoder``.
    """

    def __init__(self, dim):
        super().__init__()
        self.atom_encoder = AtomEncoder(emb_dim=dim)
        self.bond_encoder = BondEncoder(emb_dim=dim)

    def forward(self, data):
        """Return a shallow copy of ``data`` with encoded ``x`` and ``edge_attr``.

        The object passed in is not modified; only the copy receives the
        encoded node and edge features.
        """
        embedded = copy(data)
        node_features = self.atom_encoder(embedded.x)
        embedded.x = node_features
        edge_features = self.bond_encoder(embedded.edge_attr)
        embedded.edge_attr = edge_features
        return embedded


class VNAgg(nn.Module):
    """Updates per-graph virtual-node states from the node embeddings.

    For every graph the update is
    ``mlp((1 + eps) * virtual_node + sum(node embeddings of that graph))``,
    where ``mlp`` is ``MLP -> BatchNorm1d -> ReLU``.

    Args:
        dim: Feature size of the virtual node and of the node embeddings.
        train_eps: If true, ``eps`` is a learnable ``nn.Parameter``;
            otherwise it is a fixed constant.
        eps: Initial (or fixed) value of ``eps``.
    """

    def __init__(self, dim, train_eps=False, eps=0.0):
        super().__init__()
        self.mlp = nn.Sequential(MLP(dim, dim), nn.BatchNorm1d(dim), nn.ReLU())
        self.train_eps = train_eps
        eps_value = torch.tensor([float(eps)])
        if train_eps:
            self.eps = nn.Parameter(eps_value)
        else:
            # A plain tensor attribute would stay on the CPU when the module is
            # moved with ``.to(device)`` and break the forward pass on GPU.
            # A non-persistent buffer follows the module across devices while
            # keeping ``state_dict()`` identical to the previous layout.
            self.register_buffer("eps", eps_value, persistent=False)

    def forward(self, virtual_node, embeddings, batch_idx):
        """Return the updated virtual-node states.

        Args:
            virtual_node: Current virtual-node states, one row per graph.
            embeddings: Node embeddings, one row per node.
            batch_idx: Graph index of every node in ``embeddings``.

        Returns:
            Tensor with the same number of rows as ``virtual_node``.
        """
        if batch_idx.size(0) > 0:
            # Pass ``size`` explicitly so trailing graphs without any node still
            # get a (zero) row and the result lines up with ``virtual_node``.
            sum_embeddings = nng.global_add_pool(
                embeddings, batch_idx, size=virtual_node.size(0)
            )
        else:
            # ``zeros_like`` already inherits device and dtype from
            # ``virtual_node``; no dependency on a notebook-global ``device``.
            sum_embeddings = torch.zeros_like(virtual_node)
        virtual_node = (1 + self.eps) * virtual_node + sum_embeddings
        virtual_node = self.mlp(virtual_node)
        return virtual_node


class GlobalPool(nn.Module):
    """Graph-level readout wrapping one of the ``global_*_pool`` functions.

    Args:
        fun: Pooling flavour. It is lower-cased and inserted into
            ``global_{fun}_pool``, which is looked up on ``torch_geometric.nn``.
    """

    def __init__(self, fun):
        super().__init__()
        pool_name = f"global_{fun.lower()}_pool"
        self.fun = getattr(nng, pool_name)

    def forward(self, data):
        """Pool ``data.x`` per graph and return one row per graph in the batch."""
        # ``size`` is given explicitly so the output has ``data.num_graphs`` rows.
        return self.fun(data.x, data.batch, size=data.num_graphs)

In [None]:
class ConvBlock(nn.Module):
    """GINE convolution followed by batch norm, activation and dropout.

    Optionally adds a per-graph virtual node to the node features before the
    convolution and updates that virtual node afterwards.

    Args:
        dim: Feature size of nodes, edges and the virtual node.
        dropout: Dropout probability applied to node features (and to the
            updated virtual node).
        activation: Callable applied after batch norm unless ``last_layer`` is
            set. A falsy value selects the identity.
        virtual_node: If true, ``data.virtual_node[data.batch]`` is added to the
            node features before the convolution.
        virtual_node_agg: If true (and ``virtual_node`` is true), the virtual
            node is updated from the new node features with a ``VNAgg`` module.
        last_layer: If true, the activation is skipped.
        train_vn_eps: Forwarded to ``VNAgg`` as ``train_eps``.
        vn_eps: Forwarded to ``VNAgg`` as ``eps``.
    """

    def __init__(
        self,
        dim,
        dropout=0.5,
        activation=F.relu,
        virtual_node=False,
        virtual_node_agg=True,
        last_layer=False,
        train_vn_eps=False,
        vn_eps=0.0,
    ):
        super().__init__()
        self.conv = nng.GINEConv(MLP(dim, dim), train_eps=True)
        self.bn = nn.BatchNorm1d(dim)
        self.activation = activation or nn.Identity()
        self.dropout_ratio = dropout
        self.last_layer = last_layer
        self.virtual_node = virtual_node
        # Holds the boolean flag unless an aggregation module is needed, in
        # which case it is replaced by the ``VNAgg`` module just below.
        self.virtual_node_agg = virtual_node_agg

        if self.virtual_node and self.virtual_node_agg:
            self.virtual_node_agg = VNAgg(dim, train_eps=train_vn_eps, eps=vn_eps)

    def forward(self, data):
        """Return a shallow copy of ``data`` with updated ``x`` (and ``virtual_node``).

        The object passed in is left untouched.
        """
        data = copy(data)
        h, edge_index, edge_attr, batch_idx = (
            data.x,
            data.edge_index,
            data.edge_attr,
            data.batch,
        )
        if self.virtual_node:
            # Broadcast each graph's virtual node onto all of its nodes.
            h = h + data.virtual_node[batch_idx]
        h = self.conv(h, edge_index, edge_attr)
        h = self.bn(h)
        if not self.last_layer:
            h = self.activation(h)
        h = F.dropout(h, self.dropout_ratio, training=self.training)
        if self.virtual_node and self.virtual_node_agg:
            v = self.virtual_node_agg(data.virtual_node, h, batch_idx)
            v = F.dropout(v, self.dropout_ratio, training=self.training)
            # Store the refreshed virtual node for the next block. The previous
            # assignment target was a garbled attribute name and raised here.
            data.virtual_node = v
        data.x = h
        return data

In [None]:
class GINENetwork(nn.Module):
    """Graph-level predictor: OGB embedding -> ConvBlocks -> mean pool -> MLP.

    The network stacks ``num_layers`` ``ConvBlock`` modules. The last block is
    built with ``last_layer=True`` and ``virtual_node_agg=False``, so it skips
    the activation and does not update the virtual node.

    Args:
        hidden_dim: Feature size used by the embedding and all conv blocks.
        out_dim: Size of the per-graph output.
        num_layers: Total number of conv blocks (including the last one).
        dropout: Dropout probability forwarded to every conv block.
        virtual_node: If true, a learnable virtual node (initialised to zeros)
            is attached to every graph.
        train_vn_eps: Forwarded to every conv block.
        vn_eps: Forwarded to every conv block.
    """

    def __init__(
        self,
        hidden_dim=100,
        out_dim=128,
        num_layers=3,
        dropout=0.5,
        virtual_node=False,
        train_vn_eps=False,
        vn_eps=0.0,
    ):
        super().__init__()
        convs = [
            ConvBlock(
                hidden_dim,
                dropout=dropout,
                virtual_node=virtual_node,
                train_vn_eps=train_vn_eps,
                vn_eps=vn_eps,
            )
            for _ in range(num_layers - 1)
        ]
        convs.append(
            ConvBlock(
                hidden_dim,
                dropout=dropout,
                virtual_node=virtual_node,
                virtual_node_agg=False,
                last_layer=True,
                train_vn_eps=train_vn_eps,
                vn_eps=vn_eps,
            )
        )
        self.network = nn.Sequential(OGBMolEmbedding(hidden_dim), *convs)
        self.aggregate = nn.Sequential(
            GlobalPool("mean"),
            MLP(hidden_dim, out_dim),
        )

        self.virtual_node = virtual_node
        if self.virtual_node:
            self.v0 = nn.Parameter(torch.zeros(1, hidden_dim), requires_grad=True)

    def forward(self, data):
        """Return one ``out_dim``-sized row per graph in ``data``.

        ``data`` itself is not modified: when a virtual node is used it is
        attached to a shallow copy, matching how the sub-modules treat their
        input.
        """
        if self.virtual_node:
            # Work on a copy so the caller's batch does not silently gain a
            # ``virtual_node`` attribute that references a model parameter.
            data = copy(data)
            data.virtual_node = self.v0.expand(data.num_graphs, self.v0.shape[-1])
        H = self.network(data)
        return self.aggregate(H)

In [None]:
# Use CUDA when it is available; every other setup runs on the CPU.
# MPS is deliberately not selected even when present: it is currently slower
# than CPU due to missing int64 min/max ops.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the OGB "ogbg-molpcba" graph property prediction dataset as a PyG dataset.
# NOTE(review): presumably downloads and preprocesses the data on first use — confirm.
dataset = PygGraphPropPredDataset(name="ogbg-molpcba")
# The train / valid / test split of this dataset is taken in the following cell.

In [None]:
# Split indices provided by the dataset, keyed by "train" / "valid" / "test".
split_idx = dataset.get_idx_split()
batch_size = 100

# One loader per split, built in train / valid / test order; only the training
# loader shuffles its samples.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(
        dataset[split_idx[split_name]],
        batch_size=batch_size,
        shuffle=(split_name == "train"),
        num_workers=4,
    )
    for split_name in ("train", "valid", "test")
)

In [None]:
def train(
    dataset,
    train_dataloader,
    valid_dataloader,
    num_layers=5,
    hidden_dim=400,
    dropout=0.5,
    virtual_node=True,
    train_vn_eps=False,
    vn_eps=0.0,
    lr=0.001,
    epochs=100,
):
    """Train a ``GINENetwork`` with Adam and a masked BCE-with-logits loss.

    Relies on the notebook-level ``device`` for model and batch placement.

    Args:
        dataset: Dataset whose ``num_tasks`` sets the model's output size.
        train_dataloader: Loader iterated once per epoch for optimisation.
        valid_dataloader: Loader evaluated every 10th epoch (starting with the
            first) to report the validation loss.
        num_layers: Number of conv blocks in the model.
        hidden_dim: Hidden feature size of the model.
        dropout: Dropout probability of the model.
        virtual_node: Whether the model uses a virtual node.
        train_vn_eps: Whether the virtual-node ``eps`` is learnable.
        vn_eps: Initial / fixed virtual-node ``eps``.
        lr: Adam learning rate.
        epochs: Number of passes over ``train_dataloader``.

    Returns:
        The trained model (left in whichever mode the last epoch ended in).
    """
    output_dim = dataset.num_tasks
    model = GINENetwork(
        hidden_dim=hidden_dim,
        out_dim=output_dim,
        num_layers=num_layers,
        dropout=dropout,
        virtual_node=virtual_node,
        train_vn_eps=train_vn_eps,
        vn_eps=vn_eps,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch in train_dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = _masked_loss(criterion, model(batch), batch.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_dataloader)
        if epoch % 10 == 0:
            valid_loss = _mean_eval_loss(model, valid_dataloader, criterion)
            print(
                f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}"
            )
    return model


def _masked_loss(criterion, y_pred, y_true):
    """Apply ``criterion`` only to entries whose target is not NaN."""
    y_true = y_true.float()
    y_available = ~torch.isnan(y_true)
    return criterion(y_pred[y_available], y_true[y_available])


def _mean_eval_loss(model, dataloader, criterion):
    """Return the mean per-batch masked loss of ``model`` in eval mode."""
    model.eval()
    total_loss = 0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time without changing the reported loss.
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            total_loss += _masked_loss(criterion, model(batch), batch.y).item()
    return total_loss / len(dataloader)


# Hyperparameters for this run, passed to ``train`` as keyword arguments.
train_config = {
    "num_layers": 5,
    "hidden_dim": 400,
    "dropout": 0.5,
    "virtual_node": True,
    "train_vn_eps": False,
    "vn_eps": 0.0,
    "lr": 0.001,
    "epochs": 100,
}
model = train(dataset, train_dataloader, valid_dataloader, **train_config)