In [None]:
import torch
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.utils.data import Subset

In [None]:
# Optional dependency: uncomment to install RDKit into this kernel's environment.
# %pip install rdkit

In [None]:
from torch_geometric.datasets import QM9

# Load QM9 rooted at the current working directory. No transform or
# pre-transform is applied (both are left at their default of None).
qm9 = QM9(root="./")

In [None]:
from torch import nn
import torch_geometric.nn as tgnn
from graphormer.model import Graphormer


# Graphormer regressor used by the QM9 training loop further down.
model = Graphormer(
    # Input sizes: the training loop keeps the first 5 node-feature columns
    # and builds a dense edge tensor with 4 channels.
    input_dim=5,
    input_edge_attr_dim=4,
    # Internal widths and depth.
    emb_dim=128,
    edge_attr_dim=16,
    num_layers=3,
    num_heads=4,
    # Radial-basis settings; their exact meaning is defined inside
    # graphormer.model, which is not visible from this notebook.
    num_radial=10,
    radial_min=0,
    radial_max=10,
    # One output value; the training loop mean-pools the model output per graph.
    output_dim=1,
)

In [None]:
from sklearn.model_selection import train_test_split

# 80/20 train/test split over dataset indices.
# `train_test_split` returns (train, test) in that order, so unpack it in the
# conventional order with an explicit 20% test fraction instead of swapping
# the names around an 80% "test" fraction.
train_ids, test_ids = train_test_split(
    list(range(len(qm9))),
    test_size=0.2,
    random_state=42,
)

# Reshuffle the training samples at every epoch; evaluation order is irrelevant
# because the evaluation loss is a plain sum over all test samples.
train_loader = DataLoader(Subset(qm9, train_ids), batch_size=64, shuffle=True)
test_loader = DataLoader(Subset(qm9, test_ids), batch_size=64)

In [None]:
# Summed (not averaged) squared error: the training loop divides the
# accumulated sums by the number of samples itself when it reports losses.
loss_functin = nn.MSELoss(reduction="sum")

# AdamW over every model parameter with a fixed learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

In [None]:
from tqdm import tqdm
from torch_geometric.nn.pool import global_mean_pool

# Fall back to the CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

NUM_EPOCHS = 5
LOG_EVERY = 10          # print the running training loss every N batches
TARGET_INDEX = 7        # column of `batch.y` regressed on (U_0 per the PyG QM9 target table -- TODO confirm)
NUM_NODE_FEATURES = 5   # leading node-feature columns kept; matches the model's input_dim
NUM_EDGE_FEATURES = 4   # channels of the dense edge tensor; matches input_edge_attr_dim


def _prepare_batch(batch):
    """Move a mini-batch to DEVICE and reshape it into the form the model is fed.

    Returns:
        (batch, target): the same batch object, modified in place, and the
        regression target taken from column TARGET_INDEX of ``batch.y``.
    """
    batch = batch.to(DEVICE)
    target = batch.y[:, TARGET_INDEX]
    batch.x = batch.x[:, :NUM_NODE_FEATURES]

    # Scatter the sparse per-edge attributes into a dense
    # (num_nodes, num_nodes, NUM_EDGE_FEATURES) tensor over all nodes of the
    # batch; node pairs without an edge stay zero. Allocate it directly on
    # DEVICE to avoid a host allocation followed by a copy.
    num_nodes = batch.x.shape[0]
    dense_edge_attr = torch.zeros(num_nodes, num_nodes, NUM_EDGE_FEATURES, device=DEVICE)
    dense_edge_attr[batch.edge_index[0], batch.edge_index[1]] = batch.edge_attr
    batch.edge_attr = dense_edge_attr
    return batch, target


def _predict(batch):
    """Return one prediction per graph by mean-pooling the model output per graph.

    ``reshape(-1)`` (rather than a bare ``squeeze()``) keeps a 1-D result even
    when the batch holds a single graph, so it always lines up with the target.
    """
    pooled = global_mean_pool(model(batch), batch.batch)
    return pooled.reshape(-1)


model.to(DEVICE)
for epoch in range(NUM_EPOCHS):
    # ---- training ----
    model.train()
    train_loss_sum = 0.0
    for step, batch in enumerate(train_loader):
        batch, y = _prepare_batch(batch)
        optimizer.zero_grad()

        loss = loss_functin(_predict(batch), y)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()  # read the scalar once per step
        train_loss_sum += loss_value
        if step % LOG_EVERY == 0:
            # The loss is a sum over the batch; divide to report a per-sample value.
            print("LOSS", loss_value / y.numel())
    print("TRAIN_LOSS", train_loss_sum / len(train_ids))

    # ---- evaluation ----
    model.eval()
    eval_loss_sum = 0.0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            batch, y = _prepare_batch(batch)
            eval_loss_sum += loss_functin(_predict(batch), y).item()

    print("EVAL LOSS", eval_loss_sum / len(test_ids))

    

## Testing Modules

Smoke-test the Graphormer model on a small synthetic batch of two graphs.

In [None]:
from graphormer.layers import RadialBasisEmbedding, GraphormerEncoderLayer
from graphormer.model import Graphormer

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from torch_geometric.data import Data


# Synthetic mini-batch of two graphs used to smoke-test the Graphormer forward pass.
# Every size below is derived from `graph_sizes`, so the feature tensors, `ptr`
# and the batch vector cannot drift out of sync.
graph_sizes = torch.tensor([5, 7])   # nodes in graph 1 and graph 2
num_nodes = int(graph_sizes.sum())   # 12 nodes in the combined batch
node_dim = 64                        # per-node input feature size
edge_dim = 16                        # per-pair input edge-attribute size
num_rbf = 5                          # number of radial basis functions

# Node features for both graphs stacked along the node axis.
node_features = torch.rand((num_nodes, node_dim))

# Dense edge attributes: one `edge_dim` vector for every ordered node pair of the batch.
edge_attributes = torch.rand((num_nodes, num_nodes, edge_dim))

# Dense radial-basis embedding; only needed by the commented-out encoder-layer check below.
rb_embedding = torch.rand((num_nodes, num_nodes, num_rbf))

# Batch pointers: start offset of each graph in the node axis, plus the total -> [0, 5, 12].
ptr = torch.cat([graph_sizes.new_zeros(1), graph_sizes.cumsum(0)])

# 3-D coordinates for every node.
pos = torch.rand((num_nodes, 3))

# Graph id of every node -> [0]*5 + [1]*7.
batch_index = torch.repeat_interleave(torch.arange(len(graph_sizes)), graph_sizes)

# Stand-alone check of a single encoder layer (kept for reference):
# encoder_layer = GraphormerEncoderLayer(emb_dim=64, num_heads=8, num_radial=5, edge_attr_dim=16)
# output = encoder_layer(node_features, rb_embedding, edge_attributes, ptr)
# output.shape

graphormer = Graphormer(
    num_layers=3,
    input_dim=node_dim,
    emb_dim=128,
    input_edge_attr_dim=edge_dim,
    edge_attr_dim=64,
    output_dim=128,
    num_radial=num_rbf,
    radial_min=0,
    radial_max=10,
    num_heads=4,
)


data = Data(
    x=node_features,
    edge_attr=edge_attributes,
    ptr=ptr,
    batch=batch_index,
    pos=pos,
)

graphormer(data).shape