In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to the project's source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [30]:
from deepgd.model import *
from deepgd.data import ScalableGraphData
from deepgd.datasets import SuiteSparseDataset
from deepgd.model_old import Stress

In [24]:
import numpy as np
import networkx as nx
import torch
from torch import nn, jit
import torch_geometric as pyg
import torch_scatter
from tqdm.auto import *
from attrs import define, NOTHING

In [26]:
# Choose the compute device. CUDA takes precedence, then Apple's MPS backend;
# if neither backend reports itself as available, stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [34]:
# Build the generator with every constructor argument set to attrs' NOTHING
# sentinel (presumably so that Generator falls back to its own defaults --
# TODO confirm against the Generator definition).
# The module is first scripted with TorchScript and the scripted module is then
# wrapped by torch.compile.
generator = torch.compile(
    jit.script(
        Generator(
            **dict.fromkeys(
                (
                    "params",
                    "block_config",
                    "edge_net_config",
                    "gnn_config",
                    "edge_feat_expansion",
                    "eps",
                ),
                NOTHING,
            )
        )
    )
)



In [6]:
# Serialize the generator to generator.pt in the current working directory.
# NOTE(review): `generator` was wrapped by torch.compile in the cell that builds
# it -- confirm that the wrapped module is what should be passed to jit.save.
torch.jit.save(generator, "generator.pt")

In [29]:
# Read the serialized generator back from disk, then move it to the device
# currently held in `device`.
model = torch.jit.load("generator.pt")
model = model.to(device)

In [28]:
# Training hyperparameters.
#
# `device` is intentionally NOT reassigned here. The model has already been
# moved to `device` when it was loaded from generator.pt, so overriding the
# value at this point would place the dataset and batches on a different device
# from the model. To force CPU execution, change the device-selection cell.
lr = 0.001        # learning rate handed to the AdamW optimizer
landmarks = 20    # intended for ScalableGraphData's `landmarks` argument
rand_edges = 20   # intended for ScalableGraphData's `random_edges` argument
batch_size = 1    # batch size of the training and validation loaders

In [20]:
# Seed torch's global RNG before shuffling (presumably dataset.shuffle draws its
# permutation from it, which would make the split reproducible -- TODO confirm).
torch.random.manual_seed(12345)

# Load the SuiteSparse graphs with the given node-count bounds and limit. Each
# graph is built as ScalableGraphData with the arguments in `datatype_args`.
dataset = SuiteSparseDataset(
    min_nodes=0,
    max_nodes=7500,
    limit=1000,
    datatype=ScalableGraphData,
    datatype_args=dict(
        device=device,
        # Use the hyperparameters from the configuration cell rather than
        # repeating their literal values, so there is one place to tune them.
        landmarks=landmarks,
        random_edges=rand_edges,
    ),
)

# Shuffle once and keep the permutation so that the split can be traced back to
# the original dataset indices.
shuffled_dataset, perm_idx = dataset.shuffle(return_perm=True)
len(shuffled_dataset), perm_idx

(574,
 tensor([162, 457,  79, 338, 478, 307, 156, 342, 167,  64, 127, 305, 138, 411,
         521, 382, 467,  82, 181, 118,  80, 182, 423, 104, 293, 398, 119,   8,
         140,  73, 361, 134, 391, 332, 566, 545, 531, 456, 431, 102, 269, 573,
          90, 466, 369, 266, 330, 113, 267, 196, 229,   2, 201,  45, 537, 496,
         328,  74, 325, 244, 387,  78, 170, 482, 281, 193, 542, 418, 285, 333,
         321,  10, 406, 329,  14, 558,  88, 314, 356, 529, 344,   5, 427, 376,
          31, 126, 383, 554, 108, 145, 535, 268, 351, 141, 505, 132, 180, 303,
         327, 211, 200, 133, 380, 389,  12,  69,   7, 497, 438, 123, 473, 179,
          34, 569, 386, 517, 448, 221, 403, 175, 352, 203, 253, 336, 164,  75,
         508, 433, 347, 130,  38,  68, 514, 273, 270, 420, 262, 213,  43, 360,
         107, 568, 250, 353, 552, 161, 384, 417,  52, 275, 414,  89, 358, 135,
         248, 543, 177, 442, 254, 192, 154, 546, 272,  28, 477, 373, 488, 235,
         219, 409, 189, 304, 337,  26, 238, 35

In [21]:
# The first 550 shuffled graphs are used for training and are re-shuffled by the
# loader; the remaining graphs form the validation set and keep a fixed order.
train_loader = pyg.loader.DataLoader(
    dataset=shuffled_dataset[:550],
    batch_size=batch_size,
    shuffle=True,
)
val_loader = pyg.loader.DataLoader(
    dataset=shuffled_dataset[550:],
    batch_size=batch_size,
    shuffle=False,
)

In [31]:
# Loss terms: each key is a criterion instance, each value the weight by which
# that criterion is multiplied before being added to the total loss.
# Only Stress is active. The following criteria are currently disabled:
#   EdgeVar, Occlusion, IncidentAngle, TSNEScore
criteria = {Stress(): 1}

In [36]:
# AdamW over all parameters of the loaded model, using the configured learning rate.
optim = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [37]:
def _weighted_loss(model, batch, criteria):
    """Return the weighted sum of every criterion for a single batch.

    The model is evaluated once and the resulting prediction is shared by all
    criteria, since the forward pass does not depend on the criterion.

    Args:
        model: callable accepting ``init_pos``, ``edge_index``, ``edge_attr``
            and ``batch_index`` keyword arguments.
        batch: batch object providing ``x``, ``edge_index``, ``edge_attr`` and
            ``batch``; it is also passed to each criterion.
        criteria: mapping from criterion callable to its weight.

    Returns:
        ``0`` plus ``weight * criterion(pred, batch)`` summed over all criteria.
    """
    pred = model(
        init_pos=batch.x,
        edge_index=batch.edge_index,
        edge_attr=batch.edge_attr,
        batch_index=batch.batch,
    )
    loss = 0
    for criterion, weight in criteria.items():
        loss += weight * criterion(pred, batch)
    return loss


for epoch in trange(100):
    # ---- training pass ----
    model.train()
    losses = []
    edge_ratios = []
    for batch in tqdm(train_loader, disable=False):
        batch = batch.to(device)
        model.zero_grad()
        loss = _weighted_loss(model, batch, criteria)
        loss.backward()
        optim.step()
        losses.append(loss.item())
        # Number of edge_attr rows divided by sum(n * (n - 1)) over the batch
        # (presumably batch.n holds per-graph node counts, making this the
        # fraction of ordered node pairs that carry an edge -- TODO confirm).
        edge_ratios.append((batch.edge_attr.shape[0] / (batch.n * (batch.n - 1)).sum()).item())
    print(f'[Epoch {epoch}] Train Loss:\t{np.mean(losses):.2f}')
    print(f'[Epoch {epoch}] Edge Ratio:\t{np.mean(edge_ratios):.2f}')

    # ---- validation pass ----
    # Switch to evaluation mode so that train-only layer behavior is disabled
    # while measuring validation loss; the next epoch re-enables training mode
    # at the top of the loop.
    model.eval()
    with torch.no_grad():
        losses = []
        for batch in tqdm(val_loader, disable=True):
            batch = batch.to(device)
            loss = _weighted_loss(model, batch, criteria)
            losses.append(loss.item())
        print(f'[Epoch {epoch}] Val Loss:\t{np.mean(losses):.2f}')

    # Overwrite the checkpoint with the latest weights after every epoch.
    torch.save(model.state_dict(), "model.ckpt")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/550 [00:00<?, ?it/s]

[Epoch 0] Train Loss:	86045.79
[Epoch 0] Edge Ratio:	0.15
[Epoch 0] Val Loss:	65711.33


  0%|          | 0/550 [00:00<?, ?it/s]

KeyboardInterrupt: 