In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import torch
import torch_geometric as pyg
import numpy as np
from tqdm.auto import *

In [None]:
from deepgd_demo import *

In [None]:
# Device used for the model and for every batch during training.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

dataset = RomeDataset()
model = DeepGD().to(device)

# Loss weight per layout criterion. Only stress is currently weighted non-zero;
# the remaining criteria are listed with weight 0.
criteria = dict([
    (Stress(), 1),
    (EdgeVar(), 0),
    (Occlusion(), 0),
    (IncidentAngle(), 0),
    (TSNEScore(), 0),
])

# AdamW over all model parameters, using the library-default hyperparameters.
optim = torch.optim.AdamW(model.parameters())
optim = torch.optim.AdamW(model.parameters())

In [None]:
# Materialize the dataset into a list, then shuffle it with a fixed seed so the
# index-based train/val/test split in the next cell is the same on every run.
datalist = [sample for sample in dataset]
random.seed(12345)
random.shuffle(datalist)

In [None]:
# Split of the shuffled list: [:10000] -> train, [11000:] -> validation,
# [10000:11000] -> test. Only the training loader reshuffles between epochs.
train_loader = pyg.loader.DataLoader(
    datalist[:10000],
    batch_size=128,
    shuffle=True,
)
val_loader = pyg.loader.DataLoader(
    datalist[11000:],
    batch_size=128,
    shuffle=False,
)
test_loader = pyg.loader.DataLoader(
    datalist[10000:11000],
    batch_size=128,
    shuffle=False,
)

In [None]:
def weighted_criteria_loss(pred, batch):
    """Return the weighted sum of the criteria in `criteria` for one batch.

    `pred` is the model output for `batch`. It is computed once by the caller so
    every criterion scores the same output and the model is not re-run per
    criterion. Zero-weight criteria are skipped: they add nothing to the
    objective, skipping them saves their cost, and it keeps a non-finite value
    from one of them from turning the total into NaN (0 * nan == nan).
    """
    return sum(w * c(pred, batch) for c, w in criteria.items() if w != 0)


# If every weight were 0 the loss would be the plain int 0 and `.backward()`
# would fail mid-epoch; fail early with a clear message instead.
if not any(w != 0 for w in criteria.values()):
    raise ValueError('at least one criterion must have a non-zero weight')

for epoch in range(1000):
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optim.zero_grad()
        loss = weighted_criteria_loss(model(batch), batch)
        loss.backward()
        optim.step()
        losses.append(loss.item())
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')

    # Validation reports the same weighted objective that training optimizes.
    model.eval()
    with torch.no_grad():
        losses = []
        for batch in val_loader:
            batch = batch.to(device)
            losses.append(weighted_criteria_loss(model(batch), batch).item())
        print(f'[Epoch {epoch}] Val Loss: {np.mean(losses)}')
        