In [1]:
import sys
sys.path.append('/Users/xuefengli/24fall/DeepAesthetic/SmartGD')

import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from smartgd.model import Generator, Discriminator
from smartgd.data import GraphDrawingData
from smartgd.datasets import  RomeDataset
from smartgd.metrics import Stress, Crossings
from smartgd.transformations import Compose, Center, NormalizeRotation, RescaleByStress
from smartgd.criteria import RGANCriterion

from egnn_clean import EGNN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cpu"
# for backend, device_name in {
#     torch.backends.mps: "mps",
#     torch.cuda: "cuda",
# }.items():
#     if backend.is_available():
#         device = device_name

In [3]:
model_name = "stress"
batch_size = 8
start_epoch = 0
max_epoch = 2000
max_lr = 0.01
min_lr = 0.0001
wr_period = 200



In [32]:
GraphDrawingData.set_optional_fields([
    "edge_pair_metaindex",
    # "face",
    # "rng"
])
dataset = RomeDataset(
    index=pd.read_csv("assets/rome_index.txt", header=None)[0],
)
init_layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)
datalist = list(dataset)
for i, data in enumerate(datalist):
    if i > 550:
        break
    data.pos = torch.tensor(init_layouts[i]).float()
train_datalist = datalist[100:108]
test_datalist = datalist[450:500]
val_datalist = datalist[500:550]

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  self.data, self.slices = torch.load(self.data_path)
Transform graphs: 100%|██████████| 11531/11531 [00:04<00:00, 2792.34it/s]


In [45]:
train_datalist = datalist[100:101]
test_datalist = datalist[450:500]
val_datalist = datalist[500:550]
m=0
for data in train_datalist:
    m = max(len(data.G.nodes), m)
for data in test_datalist:
    m = max(len(data.G.nodes), m)
print(m)

97


# Model

In [46]:
model = EGNN(in_node_nf=2, hidden_nf=64, out_node_nf=2, in_edge_nf=2, n_layers=16, act_fn=torch.nn.LeakyReLU(), tanh=True)
canonicalizer = Compose(
    
    RescaleByStress(),
)
optim = torch.optim.AdamW(model.parameters(), lr=1e-9, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=wr_period)

In [47]:
train_loader = pyg.loader.DataLoader(train_datalist, batch_size=batch_size, shuffle=False)
val_loader = pyg.loader.DataLoader(val_datalist, batch_size=batch_size, shuffle=False)
test_loader = pyg.loader.DataLoader(test_datalist, batch_size=batch_size, shuffle=False)

# preprocess

In [48]:
def generate_init_pos(batch):
    pos = canonicalizer(
        pos=batch.pos,
        apsp=batch.apsp_attr,
        edge_index=batch.perm_index,
        batch_index=batch.batch,
    )
    return pos

def get_edge_features(all_pair_shortest_path):
    return torch.cat([
        all_pair_shortest_path[:, None],
        1 / all_pair_shortest_path[:, None].square()
    ], dim=-1)

def rescale_by_stress(pos, apsp, edge_index, batch_index):
    src_pos, dst_pos = pos[edge_index[0]], pos[edge_index[1]]
    dist = (dst_pos - src_pos).norm(dim=1)
    u_over_d = dist / apsp
    scatterd_u_over_d_2 = pyg.utils.scatter(u_over_d ** 2, batch_index[edge_index[0]])
    scatterd_u_over_d = pyg.utils.scatter(u_over_d, batch_index[edge_index[0]])
    scale = scatterd_u_over_d_2 / scatterd_u_over_d
    return pos / scale[batch_index][:, None]

criteria = {
    Stress(): 1,
}

In [49]:
import torch

def stress_loss(positions, edge_index, edge_attr, weights=None):
    src, dst = edge_index[0], edge_index[1]
    edge_distances = torch.norm(positions[src] - positions[dst], dim=1, keepdim=True)
    diff = edge_distances - edge_attr
    if weights is not None:
        diff = weights * diff
    loss = (diff ** 2).sum()
    return loss


# Test single pic

In [50]:
for epoch in range(100):
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        model.zero_grad()
        loss = 0
        for c, w in criteria.items():
            coors = generate_init_pos(batch)
            feats = generate_init_pos(batch)
            print(batch.perm_index.shape)
            print(get_edge_features(batch.apsp_attr).shape)
            print(get_edge_features(batch.apsp_attr))
            feats, pred = model(
                feats,
                coors,
                batch.perm_index,
                edge_attr = get_edge_features(batch.apsp_attr)
            )
            pos = canonicalizer(pred, batch.apsp_attr, batch.perm_index, batch.batch)
            stress = Stress()
            loss = stress(pos, batch.perm_index, batch.apsp_attr, batch.batch, batch.edge_pair_index)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optim.step()
        losses.append(loss.item())
    scheduler.step()
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00,  7.36it/s]


[Epoch 0] Train Loss: 387.6455993652344


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 13.04it/s]


[Epoch 1] Train Loss: 383.2989196777344


100%|██████████| 1/1 [00:00<00:00, 12.51it/s]


torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])
[Epoch 2] Train Loss: 386.708740234375


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 15.19it/s]


[Epoch 3] Train Loss: 547.4810791015625


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 15.91it/s]


[Epoch 4] Train Loss: 395.20489501953125


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 13.26it/s]


[Epoch 5] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00,  7.25it/s]


[Epoch 6] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00,  5.32it/s]


[Epoch 7] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00,  9.42it/s]


[Epoch 8] Train Loss: nan


100%|██████████| 1/1 [00:00<00:00, 14.18it/s]


torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])
[Epoch 9] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00,  7.19it/s]


[Epoch 10] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 15.37it/s]


[Epoch 11] Train Loss: nan


100%|██████████| 1/1 [00:00<00:00, 15.41it/s]


torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])
[Epoch 12] Train Loss: nan


100%|██████████| 1/1 [00:00<00:00, 16.16it/s]


torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])
[Epoch 13] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 11.85it/s]


[Epoch 14] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 16.45it/s]


[Epoch 15] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 15.81it/s]


[Epoch 16] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 13.38it/s]


[Epoch 17] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 14.78it/s]


[Epoch 18] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 14.92it/s]


[Epoch 19] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


100%|██████████| 1/1 [00:00<00:00, 14.47it/s]


[Epoch 20] Train Loss: nan


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 4160])
torch.Size([4160, 2])
tensor([[1.0000, 1.0000],
        [7.0000, 0.0204],
        [2.0000, 0.2500],
        ...,
        [6.0000, 0.0278],
        [1.0000, 1.0000],
        [4.0000, 0.0625]])


  0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: value cannot be converted to type float without overflow