In [None]:
%load_ext autoreload
%autoreload 2

In [2]:
import random

import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from smartgd.model import Generator
from smartgd.data import GraphDrawingData
from smartgd.datasets import  RomeDataset
from smartgd.metrics import Stress

In [14]:
# Choose the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
# Both backends are probed (MPS first) before deciding.
mps_ok = torch.backends.mps.is_available()
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    device = "cuda"
elif mps_ok:
    device = "mps"
else:
    device = "cpu"

In [15]:
# Training hyperparameters.
batch_size = 128
lr = 0.01

# Seed torch's global RNG so that model weight initialisation and the training
# DataLoader's shuffle order (no explicit generator is passed to it) are
# reproducible across kernel restarts. Same seed as the dataset shuffle.
torch.manual_seed(12345)

In [16]:
# Tell GraphDrawingData which optional per-graph fields to carry; this is
# done before the dataset below is constructed.
GraphDrawingData.set_optional_fields(["edge_pair_metaindex", "face", "rng"])

# The headerless index file selects which Rome graphs to load (first column).
rome_index = pd.read_csv("assets/rome_index.txt", header=None)[0]
dataset = RomeDataset(index=rome_index)

# Generator architecture. edge_attr_dim=2 presumably matches the two columns
# produced by get_edge_features (d and 1/d^2) — TODO confirm.
params = Generator.Params(
    num_blocks=9, block_depth=3, block_width=8, block_output_dim=8,
    edge_net_depth=2, edge_net_width=16,
    edge_attr_dim=2, node_attr_dim=2,
)
model = Generator(params=params).to(device)

# Layout criteria mapped to their loss weights. Only stress is active.
# EdgeVar, Occlusion, IncidentAngle and TSNEScore were previously listed here
# (commented out, weight 0, via an unimported `dgd` module).
criteria = {Stress(): 1}

optim = torch.optim.AdamW(model.parameters(), lr=lr)

Transform graphs:   0%|          | 0/11531 [00:00<?, ?it/s]

In [17]:
# Materialise every graph, then shuffle with a fixed seed so the position-based
# train/test/val split that follows is reproducible.
datalist = [graph for graph in dataset]
random.seed(12345)
random.shuffle(datalist)

In [18]:
# Position-based split of the shuffled list: first 10,000 graphs train, the next
# 1,000 test, and whatever remains validates (531 graphs for an 11,531-graph index).
train_graphs = datalist[:10000]
test_graphs = datalist[10000:11000]
val_graphs = datalist[11000:]

# Only the training loader reshuffles each epoch.
train_loader = pyg.loader.DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
val_loader = pyg.loader.DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
test_loader = pyg.loader.DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

In [19]:
def get_edge_features(all_pair_shortest_path):
    """Build per-pair edge features from shortest-path distances.

    Args:
        all_pair_shortest_path: tensor of shortest-path distances, one entry
            per node pair along the first dimension (presumably 1-D — TODO confirm).

    Returns:
        Tensor whose last dimension holds ``[d, 1 / d**2]`` for each pair.
        A zero distance yields ``inf`` in the second column.
    """
    distance = all_pair_shortest_path.unsqueeze(1)
    inverse_square = 1 / distance.square()
    return torch.cat((distance, inverse_square), dim=-1)

def rescale_by_stress(pos, apsp, edge_index, batch_index):
    """Uniformly rescale each graph's layout to minimise its weighted stress.

    For every graph the scale ``s`` minimising
    ``sum_ij (s * ||x_i - x_j|| - d_ij)^2 / d_ij^2`` over the given pairs has
    the closed form ``s = sum(u) / sum(u^2)`` with ``u = ||x_i - x_j|| / d_ij``.
    Graphs with no pairs (or a fully collapsed layout) are left unchanged.

    Args:
        pos: node positions, shape (num_nodes, dim).
        apsp: target shortest-path distance for each pair in ``edge_index``.
        edge_index: (2, num_pairs) source/target node indices.
        batch_index: node-level graph id for each node (length num_nodes).

    Returns:
        Rescaled positions with the same shape as ``pos``.
    """
    src, dst = edge_index[0], edge_index[1]
    dist = (pos[dst] - pos[src]).norm(dim=1)
    ratio = dist / apsp

    # Ratios are per pair, so aggregate them by the graph of each pair's source
    # node (the node-level batch_index cannot index per-pair values directly).
    pair_batch = batch_index[src]
    num_graphs = int(batch_index.max()) + 1 if batch_index.numel() > 0 else 0
    sum_ratio = torch.zeros(num_graphs, dtype=ratio.dtype, device=ratio.device)
    sum_ratio = sum_ratio.index_add(0, pair_batch, ratio)
    sum_ratio_sq = torch.zeros_like(sum_ratio).index_add(0, pair_batch, ratio.square())

    # Avoid 0/0 for graphs without pairs: substitute 1/1 so those graphs keep
    # their layout (and no NaN leaks into gradients via torch.where).
    has_pairs = sum_ratio > 0
    ones = torch.ones_like(sum_ratio)
    scale = torch.where(has_pairs, sum_ratio_sq, ones) / torch.where(has_pairs, sum_ratio, ones)
    return pos / scale[batch_index][:, None]

In [20]:
def compute_loss(batch):
    """Weighted sum of every criterion in `criteria` for one batch.

    The model is run once and its output layout is shared by all criteria
    (previously the forward pass was repeated for each criterion).
    """
    layout = model(
        init_pos=batch.pos,
        edge_index=batch.perm_index,
        edge_attr=get_edge_features(batch.apsp_attr),
        batch_index=batch.batch,
    )
    return sum(w * c(layout, batch.perm_index, batch.apsp_attr) for c, w in criteria.items())


def train_one_epoch(loader):
    """One optimisation pass over `loader`; returns the mean loss per graph.

    Assumes each criterion returns a loss summed over the graphs in the batch
    (the original per-graph normalisation implies this) — TODO confirm.
    """
    model.train()
    loss_sum, num_graphs = 0.0, 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        optim.zero_grad()
        loss = compute_loss(batch)
        loss.backward()
        optim.step()
        loss_sum += loss.item()
        # num_graphs is the graph count; max(batch.batch) was num_graphs - 1.
        num_graphs += batch.num_graphs
    return loss_sum / num_graphs if num_graphs else float("nan")


@torch.no_grad()
def evaluate(loader):
    """Mean loss per graph over `loader`, without gradient tracking."""
    model.eval()
    loss_sum, num_graphs = 0.0, 0
    for batch in tqdm(loader, disable=True):
        batch = batch.to(device)
        loss_sum += compute_loss(batch).item()
        num_graphs += batch.num_graphs
    return loss_sum / num_graphs if num_graphs else float("nan")


# `start` is not defined in any visible cell (it apparently came from a deleted
# cell), so default to epoch 0 while still honouring a value set earlier.
start = globals().get("start", 0)
for epoch in range(start, 1000):
    print(f'[Epoch {epoch}] Train Loss: {train_one_epoch(train_loader)}')
    print(f'[Epoch {epoch}] Val Loss: {evaluate(val_loader)}')

  0%|          | 0/79 [00:00<?, ?it/s]

3795.23828125
3721.304443359375
2914.576904296875
2178.9404296875
1765.6253662109375
1903.3004150390625
1453.254150390625
1519.72021484375
1607.1717529296875
1319.3966064453125
1266.454345703125
1238.0052490234375
1089.3865966796875
1231.103759765625
1343.09423828125


KeyboardInterrupt: 