In [1]:
import torch
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.utils.data import Subset

In [2]:
# !pip install rdkit

In [3]:
from torch_geometric.datasets import QM9

qm9 = QM9(root='./', transform=None, pre_transform=None)

In [4]:
from torch import nn
import torch_geometric.nn as tgnn
from graphormer.model import Graphormer


model = Graphormer(
    num_layers=3,
    input_dim=5,
    emb_dim=128,
    input_edge_attr_dim=4,
    edge_attr_dim=16,
    output_dim=1,
    num_radial=10,
    radial_min=0,
    radial_max=10,
    num_heads=4,
)

In [5]:
from sklearn.model_selection import train_test_split

test_ids, train_ids = train_test_split([i for i in range(len(qm9))], test_size=0.8, random_state=42)
train_loader = DataLoader(Subset(qm9, train_ids), batch_size=64)
test_loader = DataLoader(Subset(qm9, test_ids), batch_size=64)

In [7]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_functin = nn.MSELoss(reduction="sum")

In [8]:
from tqdm import tqdm
from torch_geometric.nn.pool import global_mean_pool

DEVICE = "cuda"

model.to(DEVICE)
for epoch in range(5):
    model.train()
    batch_loss = 0.0
    for i, batch in enumerate(train_loader):
        batch.to(DEVICE)
        y = batch.y[:, 7]
        optimizer.zero_grad()
        batch.x = batch.x[:, :5]

        edge_attr = torch.zeros(batch.x.shape[0], batch.x.shape[0], 4).to(DEVICE)
        edge_attr[batch.edge_index[0], batch.edge_index[1]] = batch.edge_attr

        batch.edge_attr = edge_attr

        output = global_mean_pool(model(batch), batch.batch)

        loss = loss_functin(output.squeeze(), y)
        batch_loss += loss.item()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if i % 10 == 0:
            print("LOSS", loss.item() / len(batch.y))
    print("TRAIN_LOSS", batch_loss / len(train_ids))

    model.eval()
    with torch.no_grad():
        batch_loss = 0.0



        for batch in tqdm(test_loader):
            batch.to(DEVICE)
            y = batch.y[:, 7]
            batch.x = batch.x[:, :5]
            edge_attr = torch.zeros(batch.x.shape[0], batch.x.shape[0], 4).to(DEVICE)
            edge_attr[batch.edge_index[0], batch.edge_index[1]] = batch.edge_attr

            batch.edge_attr = edge_attr
            with torch.no_grad():
                output = global_mean_pool(model(batch).squeeze(), batch.batch)
                loss = loss_functin(output, y)
                
            batch_loss += loss.item()

    print("EVAL LOSS", batch_loss / len(test_ids))

    

LOSS 125989504.0
LOSS 115867776.0
LOSS 106052224.0
LOSS 99492432.0
LOSS 81012992.0
LOSS 74148688.0
LOSS 67983800.0
LOSS 53378548.0
LOSS 41076088.0
LOSS 35717424.0
LOSS 19474852.0
LOSS 8957322.0
LOSS 5815762.0
LOSS 2303852.5
LOSS 1375826.0
LOSS 1486422.375
LOSS 1143528.25
LOSS 852999.375
LOSS 1231683.5
LOSS 667563.875
LOSS 1279514.75
LOSS 1078682.625
LOSS 1166886.5
LOSS 1148435.25
LOSS 1313434.25
LOSS 1337940.25
LOSS 1442692.5
LOSS 1378165.5
LOSS 997401.875
LOSS 1123959.25
LOSS 914281.3125
LOSS 957916.375
LOSS 1664437.5
LOSS 1012348.75
LOSS 1165281.625
LOSS 1237530.0
LOSS 1058093.0
LOSS 1903017.25
LOSS 1210704.75
LOSS 1595341.875
LOSS 1754575.75
LOSS 1284520.625
LOSS 1022583.8125
LOSS 1411748.5
LOSS 1537039.0
LOSS 961344.375
LOSS 1521828.25
LOSS 1446379.25
LOSS 950908.375
LOSS 1124355.875
LOSS 786653.8125
LOSS 1413205.5
LOSS 1297839.25
LOSS 1167053.5
LOSS 1434335.375
LOSS 1193116.0
LOSS 1504906.75
LOSS 1057144.25
LOSS 938360.375
LOSS 1147558.0
LOSS 1212051.25
LOSS 1833844.0
LOSS 1921363

100%|██████████| 409/409 [00:18<00:00, 21.68it/s]


EVAL LOSS 1286762.841397233
LOSS 1053508.5
LOSS 771471.8125
LOSS 809538.625
LOSS 910274.375
LOSS 1975319.125
LOSS 1540618.25
LOSS 1178334.75
LOSS 1200442.0
LOSS 2360203.25
LOSS 1386791.375
LOSS 1031284.4375
LOSS 1116211.0
LOSS 1664225.5
LOSS 1313573.625
LOSS 1156836.25
LOSS 1554585.5
LOSS 1355772.0
LOSS 859233.625
LOSS 1238049.5
LOSS 682777.25
LOSS 1338279.25
LOSS 1144232.75
LOSS 1147202.25
LOSS 1151116.75
LOSS 1302691.625
LOSS 1324457.5
LOSS 1515842.25
LOSS 1407987.75
LOSS 1065332.0
LOSS 1121113.875
LOSS 974077.625
LOSS 882390.625
LOSS 1659560.125
LOSS 1094194.0
LOSS 1183420.0
LOSS 1258961.75
LOSS 1046128.5625
LOSS 1888862.25
LOSS 1145642.0
LOSS 1578507.75
LOSS 1847499.0
LOSS 1375171.75
LOSS 997508.9375
LOSS 1416063.25
LOSS 1541481.25
LOSS 1024271.9375
LOSS 1523158.125
LOSS 1345392.0
LOSS 948938.625
LOSS 1081888.25
LOSS 753246.1875
LOSS 1388248.75
LOSS 1308362.0
LOSS 1208069.75
LOSS 1436563.25
LOSS 1206327.75
LOSS 1356524.375
LOSS 1023808.0
LOSS 986289.625
LOSS 1111279.25
LOSS 1183772

100%|██████████| 409/409 [00:15<00:00, 26.32it/s]


EVAL LOSS 1295548.0493770542
LOSS 1055480.625
LOSS 771214.625
LOSS 830544.0
LOSS 903456.3125
LOSS 1969276.625
LOSS 1544384.625
LOSS 1170258.0
LOSS 1205579.75
LOSS 2364878.0
LOSS 1374504.0
LOSS 1029308.9375
LOSS 1091530.375
LOSS 1655970.5
LOSS 1313538.875
LOSS 1157654.875
LOSS 1578203.75
LOSS 1355252.0
LOSS 813174.125
LOSS 1232024.75
LOSS 697403.5625
LOSS 1323875.875
LOSS 1187870.625
LOSS 1154526.0
LOSS 1146803.125
LOSS 1313956.5
LOSS 1318671.125
LOSS 1543991.5
LOSS 1398674.75
LOSS 1085521.0
LOSS 1130304.125
LOSS 960623.125
LOSS 858993.5625
LOSS 1590680.25
LOSS 1090940.0
LOSS 1193480.5
LOSS 1278830.0
LOSS 1039464.25
LOSS 1857345.625
LOSS 1139411.5
LOSS 1606492.25
LOSS 1869466.5
LOSS 1393805.75
LOSS 1001390.5625
LOSS 1416204.0
LOSS 1534687.25
LOSS 1046036.9375
LOSS 1499592.5
LOSS 1337200.25
LOSS 946318.625
LOSS 1073577.375
LOSS 773575.125
LOSS 1394116.5
LOSS 1297809.25
LOSS 1205472.25
LOSS 1443639.5
LOSS 1203087.75
LOSS 1319587.375
LOSS 1020574.3125
LOSS 996905.8125
LOSS 1120815.25
LOSS 

100%|██████████| 409/409 [00:15<00:00, 25.96it/s]


EVAL LOSS 1285173.490483834
LOSS 1050643.25
LOSS 763190.0
LOSS 822113.0
LOSS 887309.625
LOSS 1975760.75
LOSS 1539209.0
LOSS 1163099.0
LOSS 1203051.5
LOSS 2362766.0
LOSS 1367309.625
LOSS 1026526.875
LOSS 1099550.5
LOSS 1642013.25
LOSS 1310855.5
LOSS 1158448.75
LOSS 1567104.75
LOSS 1347892.5
LOSS 823866.375
LOSS 1229487.0
LOSS 687674.125
LOSS 1321990.25
LOSS 1163828.75
LOSS 1143096.75
LOSS 1138430.875
LOSS 1299242.625
LOSS 1315443.25
LOSS 1528043.875
LOSS 1400036.625
LOSS 1071237.5
LOSS 1126735.75
LOSS 959884.3125
LOSS 863721.1875
LOSS 1619812.25
LOSS 1084755.125
LOSS 1179961.875
LOSS 1258436.75
LOSS 1035990.375
LOSS 1870282.375
LOSS 1137953.375
LOSS 1589119.25
LOSS 1845904.25
LOSS 1377546.0
LOSS 995358.3125
LOSS 1410440.0
LOSS 1533160.125
LOSS 1029499.5
LOSS 1503097.375
LOSS 1339576.75
LOSS 941117.9375
LOSS 1069057.5
LOSS 758130.625
LOSS 1393681.75
LOSS 1302452.5
LOSS 1199475.0
LOSS 1439875.75
LOSS 1202915.0
LOSS 1344602.75
LOSS 1014532.4375
LOSS 975221.3125
LOSS 1103996.25
LOSS 1167746

KeyboardInterrupt: 

Testing Modules

In [1]:
from graphormer.layers import RadialBasisEmbedding, GraphormerEncoderLayer
from graphormer.model import Graphormer

In [7]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from torch_geometric.data import Data


# Node features (embeddings) for both graphs combined
node_features = torch.rand((12, 64))  # 5 nodes from graph 1 + 7 nodes from graph 2, each with 64-dim features

# Edge attributes for both graphs combined
edge_attributes = torch.rand((12, 12, 16))  # 4 edges from graph 1 + 6 edges from graph 2, each with 16-dim attributes

# Radial basis embedding for both graphs
rb_embedding = torch.rand((12, 12, 5))  # 5 radial basis functions

# Batch pointers
ptr = torch.tensor([0, 5, 12])  # Pointers showing where each graph starts in the node features

pos = torch.rand((12, 3))  # 5 nodes from graph 1 + 7 nodes from graph 2, each with 2-dim positions

# Initialize GraphormerEncoderLayer
# encoder_layer = GraphormerEncoderLayer(emb_dim=64, num_heads=8, num_radial=5, edge_attr_dim=16)
# output = encoder_layer(node_features, rb_embedding, edge_attributes, ptr)
# output.shape

graphormer = Graphormer(
    num_layers=3,
    input_dim=64,
    emb_dim=128,
    input_edge_attr_dim=16,
    edge_attr_dim=64,
    output_dim=128,
    num_radial=5,
    radial_min=0,
    radial_max=10,
    num_heads=4,
)


data = Data(
    x = node_features,
    edge_attr = edge_attributes,
    ptr = ptr,
    batch=torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]),
    pos = pos,
)

graphormer(data).shape