In [None]:
import torch
import torch.nn as nn
from egnn_pytorch import EGNN
import numpy as np

import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from deepgd.model import Generator
from deepgd.data import GraphDrawingData
from deepgd.datasets import  RomeDataset
from deepgd.metrics import Stress

from egnn_pytorch.egnn_pytorch_geometric import EGNN_Network, EGNN_Sparse_Network

class EGNNModel(nn.Module):
    def __init__(self, hidden_dim=32, num_layers=3):
        super().__init__()
        # EGNN layers from egnn-pytorch
        self.egnn = EGNN(
            dim=hidden_dim,               
        )
                
        # Multiple EGNN layers
        self.layers = nn.ModuleList([
            EGNN(
                dim=hidden_dim
            ) for _ in range(num_layers)
        ])
                
    def forward(self, x, coords):
        for layer in self.layers:
            x, coords = layer(x, coords)        
        return x, coords

def stress_loss(pred_pos, edge_index, apsp):
    # Handle batching by ensuring the correct indexing and shape
    start = pred_pos[edge_index[0]]  # First node of the edge
    end = pred_pos[edge_index[1]]    # Second node of the edge
    
    dist = (end - start).norm(p=2, dim=1)  # shape: (num_edges,)
    
    # Ensure apsp has the same shape as dist
    if apsp.shape[0] != dist.shape[0]:
        raise ValueError(f"APSP shape mismatch: Expected {dist.shape[0]}, but got {apsp.shape[0]}")
    
    # Compute the stress loss
    loss = ((dist - apsp) / apsp).pow(2).mean()
    return loss

def generate_random_graph(num_nodes, num_features):
    """
    Generate random graph data for testing
    """
    # Generate random node coordinates
    coords = torch.randn(num_nodes, 3)
    
    # Generate random node features
    features = torch.randn(num_nodes, num_features)
    
    return features, coords


In [76]:
dataset = RomeDataset(
    index=pd.read_csv("assets/rome_index.txt", header=None)[0],
)
layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)
datalist = list(dataset)
for i, data in enumerate(datalist):
    if i > 500:
        break
    data.pos = torch.tensor(layouts[i]).float()
train_datalist = datalist[0:450]
test_datalist = datalist[450:500]
val_datalist = datalist[500:550]

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  self.data, self.slices = torch.load(self.data_path)
Transform graphs: 100%|██████████| 11531/11531 [00:05<00:00, 2283.06it/s]


In [77]:
batch_size = 4
lr = 0.001
decay = 0.998
device = 'cpu'

model = EGNNModel(
    hidden_dim=2,
    num_layers=8
)
optim = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=decay)

In [78]:
# Test single layer

init_pos = datalist[0].pos.unsqueeze(0)
print(init_pos.shape)
coors = init_pos
feats = init_pos
net = EGNN(dim=2)
feats, coors = net(feats, coors)



torch.Size([1, 16, 2])


In [79]:

num_epochs = 5
losses = []
train_datalist = datalist[0:10]

for epoch in range(num_epochs):
    for batch in tqdm(train_datalist):
        batch = batch.to(device)
        model.zero_grad()
        loss = 0
        init_pos = batch.pos.unsqueeze(0)
        print(init_pos)
        coors = init_pos
        feats = init_pos
        # feats, coors = generate_random_graph(20, 2)
        feats, pred = model(
                feats,
                coors
            )
        print(pred)
        pos = pred[0]
        loss = stress_loss(pos, batch.perm_index, batch.apsp_attr)
        print(loss.item())
        loss.backward(retain_graph=True)
        optim.step()
        losses.append(loss.item())
    scheduler.step()
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')


 10%|█         | 1/10 [00:00<00:01,  8.70it/s]

tensor([[[ -96.8112,   47.7321],
         [ 194.0765,   99.4092],
         [ -69.2928,   39.6170],
         [ -23.8750,  -83.9988],
         [  11.2142,    6.4823],
         [-151.6656,   75.6432],
         [-160.1674,   87.1233],
         [ -62.8766, -141.4434],
         [ -89.6383,  -61.7515],
         [  47.7557,   56.6002],
         [  71.4003,  -82.4087],
         [  91.2888,   10.9137],
         [ 130.1548,   63.7983],
         [ 176.8823,  -56.4701],
         [ -39.6213,  -75.6672],
         [ -28.8244,   14.4206]]])
tensor([[[ 42.3846, -20.8929],
         [-84.9964, -43.5305],
         [ 30.3355, -17.3398],
         [ 10.4534,  36.7690],
         [ -4.9075,  -2.8362],
         [ 66.4119, -33.1164],
         [ 70.1382, -38.1444],
         [ 27.5323,  61.9241],
         [ 39.2448,  27.0324],
         [-20.9041, -24.7735],
         [-31.2570,  36.0745],
         [-39.9624,  -4.7761],
         [-56.9843, -27.9284],
         [-77.4550,  24.7264],
         [ 17.3467,  33.1221],
     

100%|██████████| 10/10 [00:00<00:00, 40.57it/s]


tensor([[[ -22.3154,  144.3292],
         [ 273.7987,  111.5100],
         [-103.6608,  169.2280],
         [ 183.2265,   55.5031],
         [ 223.7144,   58.7570],
         [ -49.1775,  120.7634],
         [ 208.2686,   81.8795],
         [-111.9069,   51.2366],
         [  76.1138, -181.0899],
         [-201.9590,   46.1481],
         [ 142.7386,   52.2491],
         [ 144.3938,  -81.2459],
         [ -32.8967, -156.3029],
         [-434.8773,  195.0268],
         [ -20.9228, -125.8012],
         [-289.8093, -132.4203],
         [-244.0366,  -99.3396],
         [  94.0969, -259.1067],
         [  60.6767, -120.4313],
         [ 400.4089,  166.4134],
         [-198.2639,  -66.2589],
         [-369.1114,  162.3752],
         [  76.1138, -181.0899],
         [   2.8178,  -94.1051],
         [-167.6117,   80.1965],
         [-303.3454,  129.7237],
         [ 183.2265,   55.5031],
         [ 223.7144,   58.7570],
         [-237.5794,   97.0721],
         [ 337.1038,  138.9617],
         [

  0%|          | 0/10 [00:00<?, ?it/s]

tensor([[[ -96.8112,   47.7321],
         [ 194.0765,   99.4092],
         [ -69.2928,   39.6170],
         [ -23.8750,  -83.9988],
         [  11.2142,    6.4823],
         [-151.6656,   75.6432],
         [-160.1674,   87.1233],
         [ -62.8766, -141.4434],
         [ -89.6383,  -61.7515],
         [  47.7557,   56.6002],
         [  71.4003,  -82.4087],
         [  91.2888,   10.9137],
         [ 130.1548,   63.7983],
         [ 176.8823,  -56.4701],
         [ -39.6213,  -75.6672],
         [ -28.8244,   14.4206]]])
tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]], grad_fn=<AddBackward0>)
nan
tensor([[[ 1.2722e+02,  1.6020e+02],
         [ 1.2139e+02, -1.3313e+01],
         [ 1.0611e+01,  5.4986

 50%|█████     | 5/10 [00:00<00:00, 47.10it/s]

tensor([[[-211.5836,  -94.7887],
         [-190.5026,  -64.0636],
         [-134.0330, -157.4993],
         [  -8.7641,  137.5646],
         [-134.5559, -102.9980],
         [-170.2398, -198.3637],
         [  67.4433,  170.3457],
         [ -65.4021,  207.0429],
         [-246.2977,  -88.4186],
         [-246.2977,  -88.4186],
         [ -81.6457,  266.7404],
         [  61.7573,  104.7804],
         [ -61.0842,  142.9463],
         [-302.0927, -112.7737],
         [-302.0927, -112.7737],
         [ -72.9895,  205.3480],
         [ -97.8892,  326.4378],
         [ 132.9782,  105.4171],
         [-114.9786,   64.7867],
         [  39.1507,  133.1943],
         [   1.7458,  127.2528],
         [ 111.7567,  149.2728],
         [  85.7579,  214.8277],
         [-119.2215,  -28.2710],
         [ 292.2055,  -50.6080],
         [ -45.2295,   18.7372],
         [  89.0317,   96.7635],
         [  86.2901,   80.5306],
         [ 109.1670,  131.1308],
         [ 356.6686,  -65.1415],
         [

100%|██████████| 10/10 [00:00<00:00, 43.40it/s]


nan
[Epoch 1] Train Loss: nan


 90%|█████████ | 9/10 [00:00<00:00, 73.67it/s]

tensor([[[ -96.8112,   47.7321],
         [ 194.0765,   99.4092],
         [ -69.2928,   39.6170],
         [ -23.8750,  -83.9988],
         [  11.2142,    6.4823],
         [-151.6656,   75.6432],
         [-160.1674,   87.1233],
         [ -62.8766, -141.4434],
         [ -89.6383,  -61.7515],
         [  47.7557,   56.6002],
         [  71.4003,  -82.4087],
         [  91.2888,   10.9137],
         [ 130.1548,   63.7983],
         [ 176.8823,  -56.4701],
         [ -39.6213,  -75.6672],
         [ -28.8244,   14.4206]]])
tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]], grad_fn=<AddBackward0>)
nan
tensor([[[ 1.2722e+02,  1.6020e+02],
         [ 1.2139e+02, -1.3313e+01],
         [ 1.0611e+01,  5.4986

100%|██████████| 10/10 [00:00<00:00, 75.03it/s]


tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]], grad_fn=<AddBackward0>)
nan
[Epoch 2] Train Loss: nan


  0%|          | 0/10 [00:00<?, ?it/s]

tensor([[[ -96.8112,   47.7321],
         [ 194.0765,   99.4092],
         [ -69.2928,   39.6170],
         [ -23.8750,  -83.9988],
         [  11.2142,    6.4823],
         [-151.6656,   75.6432],
         [-160.1674,   87.1233],
         [ -62.8766, -141.4434],
         [ -89.6383,  -61.7515],
         [  47.7557,   56.6002],
         [  71.4003,  -82.4087],
         [  91.2888,   10.9137],
         [ 130.1548,   63.7983],
         [ 176.8823,  -56.4701],
         [ -39.6213,  -75.6672],
         [ -28.8244,   14.4206]]])
tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]], grad_fn=<AddBackward0>)
nan
tensor([[[ 1.2722e+02,  1.6020e+02],
         [ 1.2139e+02, -1.3313e+01],
         [ 1.0611e+01,  5.4986

100%|██████████| 10/10 [00:00<00:00, 81.55it/s]


tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan

  0%|          | 0/10 [00:00<?, ?it/s]

tensor([[[ -96.8112,   47.7321],
         [ 194.0765,   99.4092],
         [ -69.2928,   39.6170],
         [ -23.8750,  -83.9988],
         [  11.2142,    6.4823],
         [-151.6656,   75.6432],
         [-160.1674,   87.1233],
         [ -62.8766, -141.4434],
         [ -89.6383,  -61.7515],
         [  47.7557,   56.6002],
         [  71.4003,  -82.4087],
         [  91.2888,   10.9137],
         [ 130.1548,   63.7983],
         [ 176.8823,  -56.4701],
         [ -39.6213,  -75.6672],
         [ -28.8244,   14.4206]]])
tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]], grad_fn=<AddBackward0>)
nan
tensor([[[ 1.2722e+02,  1.6020e+02],
         [ 1.2139e+02, -1.3313e+01],
         [ 1.0611e+01,  5.4986

 80%|████████  | 8/10 [00:00<00:00, 74.44it/s]

tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]], grad_fn=<AddBackward0>)
nan
tensor([[[-6.2000e+01,  6.6671e+01],
         [-1.3551e+02, -1.9739e+01],
         [-1.2665e-01, -8.6222e+01],
         [ 1.4763e+02, -2.9964e+00],
   

100%|██████████| 10/10 [00:00<00:00, 66.61it/s]

nan
[Epoch 4] Train Loss: nan



