In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import torch_geometric as pyg
import networkx as nx
from MINAR.MinAggGNN import MinAggGNN

seed = 0
torch.manual_seed(seed)
rng = np.random.default_rng(seed)

device = torch.device('cuda')
K = 2
m = 2

model = MinAggGNN(1, 8, K, 1, edge_dim = 1)
model.to(device)

MinAggGNN(1, 1, num_layers=2)

In [2]:
# Pre-built lists of torch_geometric graphs. weights_only=False is needed to unpickle
# Data objects (torch>=2.6 defaults torch.load to weights_only=True). This allows
# arbitrary pickle execution, so only load these files from a trusted source.
train = torch.load('data/training_data.pt', weights_only=False)
test = torch.load('data/test_data.pt', weights_only=False)

# A single batch holding every graph -> full-batch training and evaluation.
train_loader = pyg.loader.DataLoader(train, batch_size = len(train))
test_loader = pyg.loader.DataLoader(test, batch_size = len(test))

# Total number of nodes flagged in `reachable`; the training/test losses are divided by these.
num_reachable_nodes = sum(int(data.reachable.sum()) for data in train)
num_reachable_test_nodes = sum(int(data.reachable.sum()) for data in test)
assert num_reachable_nodes > 0 and num_reachable_test_nodes > 0, 'no reachable nodes to normalise losses by'

In [None]:
from tqdm import tqdm
from MINAR.utils import MultiplicativeLoss

# --- Optimisation hyper-parameters ---
epochs = 20_000  # number of full-batch training epochs
eta = 3e-4       # weight of the L1 parameter penalty in the total training loss

# Training objective, evaluation metric and optimizer over all model parameters.
criterion = torch.nn.MSELoss()
test_criterion = MultiplicativeLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

def training():
    """Run one training epoch over ``train_loader`` and return the last batch's losses.

    For each batch this computes the MSE on the nodes selected by ``data.reachable``,
    adds an L1 penalty over all model parameters weighted by ``eta``, back-propagates
    and takes an optimizer step.

    Returns:
        Tuple ``(mse_loss, l1_regularization, total_loss)`` for the last batch, detached
        from the autograd graph so callers can store them without keeping graphs alive.
    """
    model.train()
    for data in train_loader:
        data = data.to(device)
        # Clear gradients per batch so they never accumulate across batches.
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, edge_attr = data.edge_attr, batch = data.batch)
        reachable = data.reachable
        # NOTE(review): MSELoss() uses reduction='mean', so the loss is already averaged over
        # the selected nodes; dividing by num_reachable_nodes again scales it by 1/N relative
        # to the L1 term. Kept as-is to preserve the training dynamics — confirm intent.
        mse_loss = criterion(out.flatten()[reachable], data.y[reachable]) / num_reachable_nodes
        l1_regularization = sum(
            (param.abs().sum() for param in model.parameters()),
            torch.tensor(0., device=device),
        )
        total_loss = mse_loss + eta * l1_regularization
        total_loss.backward()
        optimizer.step()
    return mse_loss.detach(), l1_regularization.detach(), total_loss.detach()

def testing():
    """Evaluate the model on ``test_loader`` with gradient tracking disabled.

    Returns:
        The ``test_criterion`` loss on the nodes selected by ``data.reachable``, summed
        over batches and divided by ``num_reachable_test_nodes``, as a detached CPU tensor.
    """
    model.eval()
    total_loss = torch.tensor(0.)
    # torch.no_grad only takes effect as a context manager (or decorator);
    # calling it as a bare statement does nothing.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, edge_attr = data.edge_attr, batch = data.batch).flatten()
            reachable = data.reachable
            total_loss += test_criterion(out[reachable], data.y[reachable]).cpu()
    total_loss /= num_reachable_test_nodes
    return total_loss.detach().cpu()

# Per-epoch loss histories (on CPU) and weight snapshots taken every `checkpoint_every` epochs.
mse_losses = torch.zeros(epochs)
l1_regs = torch.zeros(epochs)
test_losses = torch.zeros(epochs)
model_checkpoints = []
checkpoint_every = 100

pbar = tqdm(range(epochs))
for epoch in pbar:
    optimizer.zero_grad()
    mse_loss, l1_reg, total_loss = training()
    # Store detached values: writing a tensor that requires grad into these buffers would
    # make them non-leaf autograd tensors that chain every epoch's graph together.
    mse_losses[epoch] = mse_loss.detach()
    l1_regs[epoch] = l1_reg.detach()
    test_loss = testing()
    test_losses[epoch] = test_loss
    if epoch % checkpoint_every == 0:
        # copy=True: Tensor.cpu() returns the same tensor when it is already on the CPU,
        # which would make every snapshot alias the live parameters.
        model_checkpoints.append({k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()})
    pbar.set_description(f'Train MSE Loss: {float(mse_loss):.4f}, L1 Reg: {float(l1_reg):.4f}, Test Loss: {float(test_loss):.4f}')
model.eval()

Train MSE Loss: 0.0002, L1 Reg: 32.0317, Test Loss: 0.1218:  11%|█         | 2108/20000 [03:35<38:34,  7.73it/s]  

In [None]:
from pathlib import Path

# Persist the final weights, the periodic checkpoints and the per-epoch loss histories.
OUTPUT_DIR = Path('model_progress/bellman-ford')
# torch.save does not create missing parent directories, so make sure the target exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

artifacts = {
    'model_final.pt': model.state_dict(),
    'model_checkpoints.pt': model_checkpoints,
    'mse_losses.pt': mse_losses,
    'l1_regs.pt': l1_regs,
    'test_losses.pt': test_losses,
}
for filename, obj in artifacts.items():
    torch.save(obj, OUTPUT_DIR / filename)