In [225]:
import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from deepgd.model import Generator
from deepgd.data import GraphDrawingData
from deepgd.datasets import  RomeDataset
from deepgd.metrics import Stress

from egnn_pytorch.egnn_pytorch_geometric import EGNN_Network, EGNN_Sparse_Network


In [226]:
device = "cpu"
# for backend, device_name in {
#     torch.backends.mps: "mps",
#     torch.cuda: "cuda",
# }.items():
#     if backend.is_available():
#         device = device_name

In [227]:
batch_size = 4
lr = 0.001
decay = 0.998

In [None]:
dataset = RomeDataset(
    index=pd.read_csv("assets/rome_index.txt", header=None)[0],
)
layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)
datalist = list(dataset)
for i, data in enumerate(datalist):
    if i > 500:
        break
    data.pos = torch.tensor(layouts[i]).float()
train_datalist = datalist[0:450]
test_datalist = datalist[450:500]
val_datalist = datalist[500:550]

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  self.data, self.slices = torch.load(self.data_path)
Transform graphs: 100%|██████████| 11531/11531 [00:02<00:00, 5208.38it/s]


In [229]:
m=0
for data in train_datalist:
    m = max(len(data.G.nodes), m)
for data in test_datalist:
    m = max(len(data.G.nodes), m)
print(m)

100


# Model

In [230]:
model = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,
    dim = 32,
    depth = 20,
    coor_weights_clamp_value = 2. 
)

optim = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=decay)

In [231]:
train_loader = pyg.loader.DataLoader(train_datalist, batch_size=batch_size, shuffle=True)
val_loader = pyg.loader.DataLoader(val_datalist, batch_size=batch_size, shuffle=False)
test_loader = pyg.loader.DataLoader(test_datalist, batch_size=batch_size, shuffle=False)

# preprocess

In [232]:
def generate_init_pos(batch):
    # pos = torch.rand_like(batch.pos)
    pos = rescale_by_stress(
        pos=batch.pos,
        apsp=batch.apsp_attr,
        edge_index=batch.perm_index,
        batch_index=batch.batch,
    )
    return pos

def get_edge_features(all_pair_shortest_path):
    return torch.cat([
        all_pair_shortest_path[:, None],
        1 / all_pair_shortest_path[:, None].square()
    ], dim=-1)

def rescale_by_stress(pos, apsp, edge_index, batch_index):
    src_pos, dst_pos = pos[edge_index[0]], pos[edge_index[1]]
    dist = (dst_pos - src_pos).norm(dim=1)
    u_over_d = dist / apsp
    scatterd_u_over_d_2 = pyg.utils.scatter(u_over_d ** 2, batch_index[edge_index[0]])
    scatterd_u_over_d = pyg.utils.scatter(u_over_d, batch_index[edge_index[0]])
    scale = scatterd_u_over_d_2 / scatterd_u_over_d
    return pos / scale[batch_index][:, None]

criteria = {
    Stress(): 1,
    # dgd.EdgeVar(): 0,
    # dgd.Occlusion(): 0,
    # dgd.IncidentAngle(): 0,
    # dgd.TSNEScore(): 0,
}

In [233]:
# feats = torch.randint(0, 21, (1, 512)) # (1, 256)
# coors = torch.randn(1, 512, 2)         # (1, 256, 3)
# mask = torch.ones_like(feats).bool()    # (1, 256)

# feats_out, coors_out = model(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)

In [234]:
# def stress_loss(pred_pos, edge_index, apsp):
#     start, end = pred_pos[edge_index[0]], pred_pos[edge_index[1]]
#     dist = (end - start).norm(p=2, dim=1)
#     loss = ((dist - apsp) / apsp).pow(2).mean()
#     return loss

# def rescale_stress(pos, apsp, edge_index):
#     src_pos, dst_pos = pos[edge_index[0]], pos[edge_index[1]]
#     # print(pos.shape)
#     # print(src_pos.shape, dst_pos.shape)
#     dist = (dst_pos - src_pos).norm(dim=1)
#     u_over_d = dist / apsp
#     scatterd_u_over_d_2 = pyg.utils.scatter(u_over_d ** 2, edge_index[0])
#     scatterd_u_over_d = pyg.utils.scatter(u_over_d, edge_index[0])
#     scale = scatterd_u_over_d_2 / scatterd_u_over_d
#     return pos / scale[:, None]

# Test single pic

In [235]:
# single = datalist[0]
# single.draw()


In [236]:

# print(torch.isnan(single.pos))  # Should print tensor([[False, False], ...]) indicating no NaNs
# coors = single.pos
# coors = coors.unsqueeze(0)
# feats = torch.randint(1, 21, coors.shape[:2])
# print(coors)
# feats, pred = model(
#         feats,
#         coors
#     )
# single.pos = pred[0]
# single.draw()

In [237]:
# for epoch in range(2):
#     model.train()
#     losses = []
#     for batch in tqdm(train_datalist):
#         batch = batch.to(device)
#         model.zero_grad()
#         loss = 0
#         init_pos = batch.pos.unsqueeze(0)
#         feats = torch.randint(1, 21, init_pos.shape[:2])
#         coors = init_pos
#         feats, pred = model(
#                 feats,
#                 coors
#             )
#         pos = pred[0]
#         loss += stress_loss(pos, batch.perm_index, batch.apsp_attr)
#         loss.backward(retain_graph=True)
#         optim.step()
#         losses.append(loss.item())
#     scheduler.step()
#     print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')


        
            


In [238]:

# coors = single.pos
# coors = coors.unsqueeze(0)
# feats = torch.randint(1, 21, coors.shape[:2])
# feats, pred = model(
#         feats,
#         coors
#     )
# print(pred[0])

In [None]:
for epoch in range(2):
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        model.zero_grad()
        loss = 0
        for c, w in criteria.items():
            coors = generate_init_pos(batch).unsqueeze(0)
            feats = torch.randint(0, 21, (1, batch.pos.shape[0]))
            feats, pred = model( #EGNN
                feats,
                coors
            )
            pos = rescale_by_stress(pred[0], batch.apsp_attr, batch.perm_index, batch.batch)
            loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
        loss.backward()
        optim.step()
        losses.append(loss.item())
    scheduler.step()
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')

100%|██████████| 113/113 [02:02<00:00,  1.08s/it]


[Epoch 0] Train Loss: nan


100%|██████████| 113/113 [02:10<00:00,  1.15s/it]

[Epoch 1] Train Loss: nan



