In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import torch
import torch_geometric as pyg

In [27]:
from smartgd.model import Generator, Discriminator
from smartgd.data import GraphDrawingData
from smartgd.datasets import  RomeDataset
from smartgd.transformations import Compose, Center, NormalizeRotation, RescaleByStress
from smartgd.criteria import RGANCriterion

In [4]:
GraphDrawingData.set_optional_fields(["edge_pair_metaindex", "face", "rng"])
dataset = RomeDataset(
    index=pd.read_csv("assets/rome_index.txt", header=None)[0],
)

Transform graphs:   0%|          | 0/11531 [00:00<?, ?it/s]

In [5]:
generator = Generator(
    params=Generator.Params(
        num_blocks=10,
        block_depth=2,
        block_width=8,
        block_output_dim=8,
        edge_net_depth=1,
        edge_net_width=16,
        edge_attr_dim=2,
        node_attr_dim=2,
    ),
)

In [6]:
discriminator = Discriminator(
    params=Discriminator.Params(
        num_layers=9,
        hidden_width=16,
        edge_net_shared_depth=8,
        edge_net_embedded_depth=8,
        edge_net_width=64,
        edge_attr_dim=2
    )
)

In [7]:
loader = pyg.loader.DataLoader(dataset, batch_size=8)

In [8]:
batch = next(iter(loader))
batch

GraphDrawingDataBatch(G=[8], perm_index=[2, 13554], edge_metaindex=[784], apsp_attr=[13554], perm_weight=[13554], aggr_metaindex=[13554], pos=[307, 2], name=[8], n=[8], m=[8], edge_pair_metaindex=[2, 11403], num_nodes=307, batch=[307], ptr=[9])

In [9]:
batch.perm_index.min(), batch.perm_index.max()

(tensor(0), tensor(306))

In [10]:
edge_attr = torch.cat([batch.apsp_attr[:, None], 1 / batch.apsp_attr[:, None].square()], dim=-1)

In [11]:
pos = generator(
    init_pos=batch.pos,
    edge_index=batch.perm_index,
    edge_attr=edge_attr,
    batch_index=batch.batch,
)

In [12]:
pos.shape, batch.perm_index.shape, edge_attr.shape, batch.batch.shape

(torch.Size([307, 2]),
 torch.Size([2, 13554]),
 torch.Size([13554, 2]),
 torch.Size([307]))

In [29]:
scores = discriminator(
    pos=pos,
    edge_index=batch.perm_index,
    edge_attr=edge_attr,
    batch_index=batch.batch,
)
scores

tensor([-99.1154, -20.1891, -35.9233, -22.1679, -20.0742, -24.1629, -18.9748,
        -15.2183], grad_fn=<ViewBackward0>)

In [16]:
canonicalize = Compose(
    Center(),
    NormalizeRotation(),
    RescaleByStress(),
)

In [26]:
canonicalize(pos, batch.apsp_attr, batch.perm_index, batch.batch)

torch.Size([8])


tensor([[ 1.0138e+00,  3.5419e-01],
        [-2.0302e+00, -8.4786e-02],
        [ 7.0059e-02,  3.3077e-01],
        [ 6.2695e-01,  1.5262e-01],
        [-5.2490e-01,  2.7800e-01],
        [-1.1946e+00, -4.9322e-01],
        [-2.4869e-01, -2.4773e-02],
        [-1.0503e+00, -4.8479e-01],
        [ 2.9774e-01,  4.2932e-02],
        [ 9.9237e-01,  4.1722e-01],
        [-4.1725e-02,  1.4459e-01],
        [ 5.0276e-02,  1.9612e-02],
        [-1.8492e-01,  6.3846e-02],
        [-8.7123e-01,  2.5164e-01],
        [ 1.0618e+00, -1.7377e-01],
        [ 2.0337e+00, -7.9410e-01],
        [-2.5165e+00,  4.7397e-01],
        [ 5.9211e-01,  1.0110e+00],
        [ 1.2207e+00, -1.3011e+00],
        [ 6.0913e-01,  6.1451e-01],
        [-1.5295e+00, -2.7533e-01],
        [-9.8618e-02,  8.6112e-01],
        [ 9.0376e-01, -3.0392e-01],
        [ 2.1325e+00, -1.2338e+00],
        [ 7.5647e-01, -3.6920e-01],
        [-6.0134e-01,  3.6468e-01],
        [-4.5951e-01, -8.1281e-02],
        [-6.6933e-01,  9.147

In [28]:
criterion = RGANCriterion()

In [31]:
criterion(scores, scores)

tensor([0.6931, 0.6931, 0.6931, 0.6931, 0.6931, 0.6931, 0.6931, 0.6931],
       grad_fn=<NegBackward0>)