In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from deepgd.model import Generator
from deepgd.data import GraphDrawingData
from deepgd.datasets import  RomeDataset
from deepgd.metrics import Stress

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Choose the compute device. CUDA takes precedence over Apple MPS, and the
# CPU is the fallback when neither accelerator backend is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [5]:
# Training hyper-parameters, gathered in one place for easy tuning.
batch_size = 32  # graphs per mini-batch in every DataLoader
lr = 1e-3        # learning rate handed to the AdamW optimizer
decay = 0.998    # gamma of the ExponentialLR scheduler

In [6]:
# GraphDrawingData.set_optional_fields(["edge_pair_metaindex", "face", "rng"])

# Dataset: `index` receives column 0 of the header-less index file.
dataset = RomeDataset(
    index=pd.read_csv("assets/rome_index.txt", header=None)[0],
)

# Precomputed layouts, later assigned to the graphs position by position.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# only load this file from a trusted source.
layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)

# Generator architecture settings.
params = Generator.Params(
    # block stack
    num_blocks=11,
    block_depth=3,
    block_width=8,
    block_output_dim=8,
    # edge network
    edge_net_depth=2,
    edge_net_width=16,
    # input feature sizes
    edge_attr_dim=2,
    node_attr_dim=2,
)
model = Generator(params=params)
model = model.to(device)

# Loss terms mapped to their weights; only stress is active.
criteria = {
    Stress(): 1,
    # Disabled alternatives (`dgd` is not imported in this notebook):
    # dgd.EdgeVar(): 0,
    # dgd.Occlusion(): 0,
    # dgd.IncidentAngle(): 0,
    # dgd.TSNEScore(): 0,
}

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
Processing...
Loading graphs: 100%|██████████| 11534/11534 [00:49<00:00, 235.32it/s]
Done!
  self.data, self.slices = torch.load(self.data_path)
Transform graphs: 100%|██████████| 11531/11531 [00:02<00:00, 5444.54it/s]


In [7]:
# AdamW over all model parameters; every scheduler.step() call multiplies
# the learning rate by `decay`.
optim = torch.optim.AdamW(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=decay)

In [9]:
# Materialise the dataset and attach the precomputed layout at the same
# position to each graph as a float tensor in `pos`.
# NOTE(review): this pairs graphs and layouts purely by position. The load
# log reports 11534 graphs read but 11531 after transformation, so confirm
# that `layouts` is indexed consistently with the final dataset.
datalist = list(dataset)
for idx in range(len(datalist)):
    datalist[idx].pos = torch.tensor(layouts[idx]).float()
print(datalist[0])

GraphDrawingData(G=DiGraph named 'grafo929.16' with 16 nodes and 42 edges, perm_index=[2, 240], edge_metaindex=[42], apsp_attr=[240], perm_weight=[240], aggr_metaindex=[240], pos=[16, 2], name='grafo929.16', n=[1], m=[1], num_nodes=16)


In [13]:
# Fixed positional split of `datalist`: items [0, 10000) for training,
# [10000, 11000) for testing and everything from 11000 on for validation.
# Only the training loader shuffles.
train_loader = pyg.loader.DataLoader(
    datalist[:10000],
    batch_size=batch_size,
    shuffle=True,
)
val_loader = pyg.loader.DataLoader(
    datalist[11000:],
    batch_size=batch_size,
    shuffle=False,
)
test_loader = pyg.loader.DataLoader(
    datalist[10000:11000],
    batch_size=batch_size,
    shuffle=False,
)

In [14]:
def generate_init_pos(batch):
    """Return the initial node positions fed to the model for `batch`.

    The positions stored in `batch.pos` are passed through
    `rescale_by_stress` together with `batch.apsp_attr`, `batch.perm_index`
    and `batch.batch`, and the rescaled positions are returned.
    """
    # Disabled alternative: a random start, torch.rand_like(batch.pos).
    return rescale_by_stress(
        batch.pos,
        batch.apsp_attr,
        batch.perm_index,
        batch.batch,
    )

def get_edge_features(all_pair_shortest_path):
    """Build the two-channel feature tensor for every node pair.

    A new axis is inserted at position 1 of the input, and the result is
    concatenated along the last axis with the reciprocal of its square.
    For a 1-D input of length E the output has shape (E, 2): column 0 is the
    input value d and column 1 is 1 / d**2.
    """
    dist = all_pair_shortest_path.unsqueeze(1)
    inv_sq_dist = 1 / dist.square()
    return torch.cat((dist, inv_sq_dist), dim=-1)

def rescale_by_stress(pos, apsp, edge_index, batch_index):
    """Uniformly rescale each graph's layout using its distance ratios.

    For every index pair (edge_index[0][k], edge_index[1][k]) the Euclidean
    distance between the two positions is divided by apsp[k]. These ratios r
    are aggregated per graph (the graph of a pair is taken from its source
    node via `batch_index`) and each node position is divided by
    sum(r**2) / sum(r) of its graph.

    NOTE(review): this is presumably the closed-form scale minimising the
    normalised stress; confirm against deepgd.metrics.Stress. The per-graph
    aggregation relies on pyg.utils.scatter's default reduction (sum,
    according to the PyG docs).

    Args:
        pos: node positions, one row per node.
        apsp: target distance for each index pair.
        edge_index: two rows of node indices (sources, destinations).
        batch_index: graph id of every node.

    Returns:
        The rescaled positions, with the same shape as `pos`.
    """
    src, dst = edge_index[0], edge_index[1]
    ratio = (pos[dst] - pos[src]).norm(dim=1) / apsp
    pair_graph = batch_index[src]
    sum_ratio_sq = pyg.utils.scatter(ratio ** 2, pair_graph)
    sum_ratio = pyg.utils.scatter(ratio, pair_graph)
    scale = sum_ratio_sq / sum_ratio
    # Broadcast each graph's scale to its nodes and across coordinates.
    return pos / scale[batch_index].unsqueeze(1)

In [16]:
# model.load_state_dict(torch.load("model_359.pt", map_location=device))
# Restore the stress-only checkpoint onto the selected device.
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code. It should also
# silence the warning echoed under this cell (presumably torch.load's
# FutureWarning about the weights_only default).
model.load_state_dict(
    torch.load("model_stress_only.pt", map_location=device, weights_only=True)
)

  model.load_state_dict(torch.load("model_stress_only.pt", map_location=device))


<All keys matched successfully>

In [18]:
# Training loop: up to 1000 epochs over train_loader, one scheduler step per
# epoch, reporting the mean per-batch training and validation loss.
for epoch in range(1000):
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        model.zero_grad()
        # The forward pass and the rescaling do not depend on the criterion,
        # so run them once per batch instead of once per criterion.
        pred = model(
            init_pos=generate_init_pos(batch),
            edge_index=batch.perm_index,
            edge_attr=get_edge_features(batch.apsp_attr),
            batch_index=batch.batch,
        )
        pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
        # Weighted sum of every configured criterion on the same layout.
        loss = 0
        for c, w in criteria.items():
            loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
        loss.backward()
        optim.step()
        losses.append(loss.item())
    scheduler.step()
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')
    # Validation pass: same loss computation, no gradients, eval mode.
    with torch.no_grad():
        model.eval()
        losses = []
        for batch in tqdm(val_loader, disable=True):
            batch = batch.to(device)
            pred = model(
                init_pos=generate_init_pos(batch),
                edge_index=batch.perm_index,
                edge_attr=get_edge_features(batch.apsp_attr),
                batch_index=batch.batch,
            )
            pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
            loss = 0
            for c, w in criteria.items():
                loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
            losses.append(loss.item())
        print(f'[Epoch {epoch}] Val Loss: {np.mean(losses)}')

  0%|          | 0/313 [00:36<?, ?it/s]


KeyboardInterrupt: 

In [17]:
# Evaluate on the held-out test loader: eval mode, no gradients, and report
# the mean of the per-batch weighted losses.
with torch.no_grad():
    model.eval()
    losses = []
    for batch in tqdm(test_loader, disable=True):
        batch = batch.to(device)
        # The forward pass and the rescaling do not depend on the criterion,
        # so run them once per batch instead of once per criterion.
        pred = model(
            init_pos=generate_init_pos(batch),
            edge_index=batch.perm_index,
            edge_attr=get_edge_features(batch.apsp_attr),
            batch_index=batch.batch,
        )
        pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
        # Weighted sum of every configured criterion on the same layout.
        loss = 0
        for c, w in criteria.items():
            loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
        losses.append(loss.item())
    print(f'Test Loss: {np.mean(losses)}')

Test Loss: 238.12293910980225
