In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from smartgd.model import Generator, Discriminator
from smartgd.data import GraphDrawingData
from smartgd.datasets import  RomeDataset
from smartgd.metrics import Stress
from smartgd.transformations import Compose, Center, NormalizeRotation, RescaleByStress
from smartgd.criteria import RGANCriterion

In [13]:
# Device that the models and mini-batches are moved to. Currently pinned to CPU.
device = "cpu"
# Disabled alternative: pick an accelerator automatically when one is available
# (if both are available, the last matching entry wins).
# for backend, device_name in {
#     torch.backends.mps: "mps",
#     torch.cuda: "cuda",
# }.items():
#     if backend.is_available():
#         device = device_name

In [14]:
# Training hyperparameters.
batch_size = 8  # number of graphs per mini-batch in every DataLoader
lr = 0.001  # initial AdamW learning rate, shared by generator and discriminator
decay = 0.998  # ExponentialLR gamma; the schedulers are stepped once per epoch

In [15]:
# Configure the optional fields on GraphDrawingData, then build the Rome dataset
# from the graph names listed in the first column of the index file.
GraphDrawingData.set_optional_fields(["edge_pair_metaindex", "face", "rng"])
dataset = RomeDataset(
    index=pd.read_csv("assets/rome_index.txt", header=None)[0],
)

# NOTE(security): allow_pickle=True unpickles arbitrary Python objects. Only load
# layout files that come from a trusted source.
init_layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)
# Kept as a separate array from init_layouts: the training loop overwrites
# entries of target_layouts, and that must not change the initial layouts.
target_layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)

generator = Generator(
    params=Generator.Params(
        num_blocks=11,
        block_depth=3,
        block_width=8,
        block_output_dim=8,
        edge_net_depth=2,
        edge_net_width=16,
        edge_attr_dim=2,
        node_attr_dim=2,
    ),
).to(device)
# The discriminator must live on the same device as the generator and the
# batches; previously it was left on the default device, which breaks as soon as
# `device` is anything other than "cpu".
discriminator = Discriminator(
    params=Discriminator.Params(
        num_layers=9,
        hidden_width=16,
        edge_net_shared_depth=8,
        edge_net_embedded_depth=8,
        edge_net_width=64,
        edge_attr_dim=2
    )
).to(device)

# Applied to every layout (initial, generated and target) before it is scored or
# passed to the discriminator.
canonicalizer = Compose(
    Center(),
    NormalizeRotation(),
    RescaleByStress(),
)
# Metric -> weight; `evaluate` returns the weighted sum of all entries.
metrics = {
    Stress(): 1,
    # dgd.EdgeVar(): 0,
    # dgd.Occlusion(): 0,
    # dgd.IncidentAngle(): 0,
    # dgd.TSNEScore(): 0,
}
criterion = RGANCriterion()

Transform graphs:   0%|          | 0/11531 [00:00<?, ?it/s]

In [16]:
# One AdamW optimizer per network, both starting from the same learning rate.
gen_optim, dis_optim = [
    torch.optim.AdamW(network.parameters(), lr=lr)
    for network in (generator, discriminator)
]
# Matching exponential learning-rate schedules (multiplicative factor `decay`).
gen_scheduler, dis_scheduler = [
    torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay)
    for optimizer in (gen_optim, dis_optim)
]

In [17]:
def create_dataloaders(train_end=10000, test_end=11000):
    """Attach layouts to every graph and build the train/val/test loaders.

    Each graph ``i`` of the module-level ``dataset`` receives:

    * ``pos`` from ``init_layouts[i]``,
    * ``target_pos`` from ``target_layouts[i]``,
    * ``fake_pos``, a zero placeholder with the shape of ``target_pos`` that the
      training loop later overwrites on the batch,
    * ``index``, its position in the dataset, used to write back into
      ``target_layouts``.

    Args:
        train_end: Graphs ``[0, train_end)`` form the (shuffled) training split.
        test_end: Graphs ``[train_end, test_end)`` form the test split and
            graphs ``[test_end, ...)`` form the validation split.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If the split boundaries are negative or out of order.
    """
    if not 0 <= train_end <= test_end:
        raise ValueError(
            f"expected 0 <= train_end <= test_end, got {train_end} and {test_end}"
        )
    datalist = list(dataset)
    for i, data in enumerate(datalist):
        data.pos = torch.tensor(init_layouts[i]).float()
        data.target_pos = torch.tensor(target_layouts[i]).float()
        data.fake_pos = torch.zeros_like(data.target_pos)
        data.index = i
    train_loader = pyg.loader.DataLoader(datalist[:train_end], batch_size=batch_size, shuffle=True)
    val_loader = pyg.loader.DataLoader(datalist[test_end:], batch_size=batch_size, shuffle=False)
    test_loader = pyg.loader.DataLoader(datalist[train_end:test_end], batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader

In [18]:
# Initial loaders for interactive inspection; the training loop rebuilds them at
# the start of every epoch.
train_loader, val_loader, test_loader = create_dataloaders()

In [19]:
def generate_init_pos(batch):
    """Return the canonicalized initial layout of a batch.

    The stored ``batch.pos`` is passed through the module-level
    ``canonicalizer`` together with the batch's ``apsp_attr``, ``perm_index``
    and ``batch`` assignment vector.
    """
    # Alternative starting point (disabled): torch.rand_like(batch.pos)
    return canonicalizer(
        pos=batch.pos,
        apsp=batch.apsp_attr,
        edge_index=batch.perm_index,
        batch_index=batch.batch,
    )

def get_edge_features(all_pair_shortest_path):
    """Build per-pair edge features from shortest-path distances.

    For every distance ``d`` the features are ``[d, 1 / d**2]``; a 1-D input of
    length ``E`` therefore yields a tensor of shape ``(E, 2)``.
    """
    # Add a feature axis so both features can be concatenated along it.
    distance = all_pair_shortest_path.unsqueeze(1)
    inverse_square = 1 / distance.square()
    return torch.cat((distance, inverse_square), dim=-1)

In [12]:
# model.load_state_dict(torch.load("model_359.pt", map_location=device))

<All keys matched successfully>

In [29]:
# Pull a single mini-batch from the training loader to inspect its fields.
batch = next(iter(train_loader))
batch

GraphDrawingDataBatch(G=[32], perm_index=[2, 102334], edge_metaindex=[4476], apsp_attr=[102334], perm_weight=[102334], aggr_metaindex=[102334], pos=[1672, 2], name=[32], n=[32], m=[32], edge_pair_metaindex=[2, 94147], num_nodes=1672, target_pos=[1672, 2], index=[32], batch=[1672], ptr=[33])

In [31]:
# Smoke-test a training-mode forward pass on the sample batch.
# NOTE: `forward` is defined in a later cell of this notebook, so on a fresh
# kernel this cell fails unless that cell has been run first.
output = forward(batch, train=True)
output

{'fake_pos': tensor([[-5.1948,  2.0996],
         [-1.4264,  0.2930],
         [-1.9859,  1.4574],
         ...,
         [-4.5070, -0.3866],
         [ 1.2863,  1.6330],
         [ 2.6649, -3.0330]], grad_fn=<DivBackward0>),
 'fake_score': tensor([ 196.0659, 1371.5764,  885.7131,  140.4702,  159.6006, 1187.2657,
         1410.8604, 1199.0756,  181.6657,  487.0019,  303.8329,  187.4160,
           60.4895,  607.2047,  171.2368,   45.6460,  623.2626,  645.0079,
          812.8198,   33.4490,  416.7246,  501.5317,  230.6680,  651.7480,
          200.6717,  902.0440,  365.8091, 1375.2031,  349.8654,  577.6883,
          264.9240,  177.8495], grad_fn=<AddBackward0>),
 'fake_logits': tensor([ 5.8247e-01,  2.3285e+01,  5.4998e+00,  1.7368e+00,  2.1116e+00,
          1.4585e+01,  8.0327e+00,  1.2721e+01,  2.1300e+00,  4.3922e+00,
          1.6836e+00,  2.1429e+00, -2.7185e-01,  3.9511e+00,  7.5187e-01,
          2.1842e-02,  3.4293e+00,  4.7576e+00,  5.8462e+00, -1.8753e-01,
         -2.4070e

In [35]:
# Overwrite the sample batch's target positions with the generated layout, then
# split the batch back into per-graph objects to inspect the result.
batch.target_pos = output['fake_pos']
batch.to_data_list()

[GraphDrawingData(G=DiGraph named 'grafo9969.33' with 33 nodes and 78 edges, perm_index=[2, 1056], edge_metaindex=[78], apsp_attr=[1056], perm_weight=[1056], aggr_metaindex=[1056], pos=[33, 2], name='grafo9969.33', n=[1], m=[1], edge_pair_metaindex=[2, 741], target_pos=[33, 2], index=[1], num_nodes=33),
 GraphDrawingData(G=DiGraph named 'grafo8759.94' with 94 nodes and 248 edges, perm_index=[2, 8742], edge_metaindex=[248], apsp_attr=[8742], perm_weight=[8742], aggr_metaindex=[8742], pos=[94, 2], name='grafo8759.94', n=[1], m=[1], edge_pair_metaindex=[2, 7626], target_pos=[94, 2], index=[1], num_nodes=94),
 GraphDrawingData(G=DiGraph named 'grafo6187.81' with 81 nodes and 224 edges, perm_index=[2, 6480], edge_metaindex=[224], apsp_attr=[6480], perm_weight=[6480], aggr_metaindex=[6480], pos=[81, 2], name='grafo6187.81', n=[1], m=[1], edge_pair_metaindex=[2, 6216], target_pos=[81, 2], index=[1], num_nodes=81),
 GraphDrawingData(G=DiGraph named 'grafo5508.33' with 33 nodes and 82 edges, pe

In [20]:
def evaluate(pos, batch):
    """Score a layout as the weighted sum of every metric in ``metrics``.

    Args:
        pos: Node positions to score.
        batch: Batch providing ``perm_index``, ``apsp_attr`` and ``batch``,
            which are forwarded to each metric.

    Returns:
        The sum of ``weight * metric(...)`` over all configured metrics
        (``0`` when no metric is configured).
    """
    return sum(
        weight * metric(pos, batch.perm_index, batch.apsp_attr, batch.batch)
        for metric, weight in metrics.items()
    )

In [25]:
def forward(batch, train=False):
    """Run the generator (and, when training, the discriminator) on a batch.

    Args:
        batch: Mini-batch providing ``apsp_attr``, ``perm_index``, ``batch`` and,
            when ``train`` is true, ``target_pos``.
        train: When true, also score the target layout and compute discriminator
            logits for both the generated and the target layout.

    Returns:
        Dict with ``'fake_pos'`` and ``'fake_score'``; in training mode it also
        holds ``'fake_logits'``, ``'real_pos'``, ``'real_score'`` and
        ``'real_logits'``.
    """
    apsp = batch.apsp_attr
    edge_index = batch.perm_index
    batch_index = batch.batch
    edge_attr = get_edge_features(apsp)

    def canonical(pos):
        # Same canonicalization for generated and target layouts.
        return canonicalizer(pos, apsp, edge_index, batch_index)

    def discriminate(pos):
        return discriminator(
            pos=pos,
            edge_index=edge_index,
            edge_attr=edge_attr,
            batch_index=batch_index,
        )

    raw_pos = generator(
        init_pos=generate_init_pos(batch),
        edge_index=edge_index,
        edge_attr=edge_attr,
        batch_index=batch_index,
    )
    fake_pos = canonical(raw_pos)
    result = {
        'fake_pos': fake_pos,
        'fake_score': evaluate(fake_pos, batch),
    }
    if not train:
        return result

    # Training mode: add the discriminator's view of both layouts, keeping the
    # fake-before-real call order.
    result['fake_logits'] = discriminate(fake_pos)
    real_pos = canonical(batch.target_pos)
    result['real_pos'] = real_pos
    result['real_score'] = evaluate(real_pos, batch)
    result['real_logits'] = discriminate(real_pos)
    return result
    

In [28]:
# Adversarial training loop. Each epoch alternates discriminator and generator
# updates on every training batch, replaces target layouts that the generator
# has beaten, and then reports the mean score on the test split.
for epoch in range(1000):
    # Rebuild the loaders every epoch so that `target_pos` reflects the entries
    # of `target_layouts` replaced during the previous epoch.
    train_loader, val_loader, test_loader = create_dataloaders()

    generator.train()
    discriminator.train()
    gen_losses = []
    dis_losses = []
    scores = []

    for batch in tqdm(train_loader):
        batch = batch.to(device)

        # Discriminator step: favour the target ("real") layout over the
        # generated one.
        generator.zero_grad()
        discriminator.zero_grad()
        output = forward(batch, train=True)
        dis_loss = criterion(encourage=output['real_logits'], discourage=output['fake_logits'])
        dis_loss.backward()
        dis_optim.step()

        # Generator step: the forward pass is recomputed so the logits come from
        # the discriminator that was just updated.
        generator.zero_grad()
        discriminator.zero_grad()
        output = forward(batch, train=True)
        gen_loss = criterion(encourage=output['fake_logits'], discourage=output['real_logits'])
        gen_loss.backward()
        gen_optim.step()

        gen_losses.append(gen_loss.item())
        dis_losses.append(dis_loss.item())
        scores += output['fake_score'].tolist()

        # Replace a graph's target layout whenever the generated layout scores
        # better. Detach first: the generated positions still carry autograd
        # history and cannot be converted to numpy otherwise.
        batch.fake_pos = output['fake_pos'].detach()
        data_list = batch.to_data_list()
        for real_score, fake_score, data in zip(output['real_score'], output['fake_score'], data_list):
            # NOTE(review): the score is a weighted sum of cost-style metrics
            # (stress), so a lower value is treated as the better layout.
            # Confirm this if other metrics are enabled.
            if fake_score < real_score:
                # `index` comes back from the batch as a 1-element tensor; use a
                # plain int so exactly one entry of target_layouts is replaced.
                target_layouts[int(data['index'])] = data['fake_pos'].cpu().numpy()

    gen_scheduler.step()
    dis_scheduler.step()

    print(f'[Epoch {epoch}] Train Loss:\tgen={np.mean(gen_losses)}\tdis={np.mean(dis_losses)}')
    print(f'[Epoch {epoch}] Train Score:\t{np.mean(scores)}')

    # Evaluation on the test split: generator only, no gradients.
    with torch.no_grad():
        generator.eval()
        discriminator.eval()
        scores = []
        for batch in tqdm(test_loader, disable=True):
            batch = batch.to(device)
            output = forward(batch)
            scores += output['fake_score'].tolist()

        print(f'[Epoch {epoch}] Test Score:\t{np.mean(scores)}')

  0%|          | 0/313 [00:00<?, ?it/s]

1502.2365336418152
1013.4369616508484
1028.484326839447
1112.1174216270447
1195.7657984495163
1392.5567245483398
1307.5690392255783
2090.2365527153015
1836.993859887123
1615.2328161001205
1273.5825086832047
1169.8704580068588
1467.5561834573746
1452.4366165399551
1412.351623058319
1281.3525722026825
1111.1325639486313
1451.2303624153137
1111.1522946357727
1137.2717475891113
1101.7943007946014
1190.426450252533
1093.3279912471771
1160.4399166107178
916.7488149404526
1414.460598230362
1199.7763127088547
1229.6246408224106
1352.5197132229805
993.1665745973587
654.2742004394531
1121.6510025262833
1222.7069644927979
1228.1739521026611
1342.7336966991425
976.5077821016312
989.6945796012878
1083.231187582016
1292.5433490276337
1391.1503765583038
1375.043908238411
1333.7593258619308
893.7991659641266
1313.7932803630829
927.9698474407196
1035.6549826860428
940.9055020809174
747.1860494613647
945.3521990776062
1125.4266443252563
861.2695200443268
836.0386967658997
865.8436059951782
971.351586639

NameError: name 'dis_scheduler' is not defined