In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')
import numpy as np
import torch
import torch_geometric as pyg
import networkx as nx
from model.MinAggGNN import MinAggGNN

# Reproducibility: seed torch (all devices) and a numpy Generator from one value.
seed = 4
torch.manual_seed(seed)
rng = np.random.default_rng(seed)

# Fall back to the CPU when no CUDA device is available, so the notebook still
# runs (slowly) instead of failing at model.to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): K is passed as MinAggGNN's 3rd positional argument; the repr
# below prints num_layers=2, so presumably K is the number of layers - confirm
# against MinAggGNN's signature.
K = 2
# NOTE(review): m is not referenced in the visible cells.
m = 2

model = MinAggGNN(2, 8, K, 2, edge_dim = 1)
model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


MinAggGNN(2, 2, num_layers=2)

In [None]:
# weights_only=False unpickles arbitrary Python objects - only load trusted files.
train, test = [
    torch.load(path, weights_only=False)
    for path in ('../data/training_data.pt', '../data/test_data.pt')
]

# Append each graph's x_bfs columns to its node features (train first, then test).
for data in (*train, *test):
    data.x = torch.cat([data.x, data.x_bfs], 1)

# Each loader yields its whole split as a single batch.
train_loader = pyg.loader.DataLoader(train, batch_size = len(train))
test_loader = pyg.loader.DataLoader(test, batch_size = len(test))

# Node counts used for loss normalisation and class balancing.
num_reachable_nodes = sum(g.reachable.sum() for g in train)
total_nodes = sum(g.num_nodes for g in train)
num_reachable_test_nodes = sum(g.reachable.sum() for g in test)
total_test_nodes = sum(g.num_nodes for g in test)

print(f'Percent of reachable nodes (Train): {float(num_reachable_nodes / total_nodes)}')
print(f'Percent of reachable nodes (Test): {float(num_reachable_test_nodes / total_test_nodes)}')

# pos_weight for the reachability BCE: (#unreachable / #reachable) training nodes.
n_unreachable = total_nodes - num_reachable_nodes
weight = (n_unreachable / num_reachable_nodes).detach().clone().to(device)

Percent of reachable nodes (Train): 0.9124087691307068
Percent of reachable nodes (Test): 0.8161452412605286


In [3]:
from tqdm import tqdm
from model.CustomLosses import MultiplicativeLoss

# --- Loss functions and optimiser -------------------------------------------
# Regression loss on output column 0 (restricted to reachable nodes in training()).
dst_criterion = torch.nn.MSELoss()
# Classification loss on output column 1; `weight` = #unreachable / #reachable.
bfs_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weight)
# Project-defined metric used for the test-set distance loss.
test_criterion = MultiplicativeLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# --- Schedule and loss weighting --------------------------------------------
epochs = 20_000
alpha = 1.0   # weight on the distance MSE term
beta = 25.0   # weight on the reachability BCE term
eta = 1e-3    # weight on the L1 penalty summed over all parameters
eps = 1e-3    # |param| threshold used when counting the L0 "norm"

def training():
    """Run one optimisation pass over ``train_loader``.

    Objective per batch:
        alpha * MSE(out[:, 0], data.y)          over nodes selected by data.reachable
      + beta  * BCEWithLogits(out[:, 1], data.reachable)   over all nodes
      + eta   * sum(|p|)                        over all model parameters

    Returns:
        ``(mse_loss, bce_loss, bfs_acc, l1_regularization, l0_norm, total_loss)``
        for the last batch. Tensors are detached from the autograd graph, so
        callers can store them without keeping the graph alive. ``bfs_acc`` is a
        Python float: the fraction of nodes whose logit sign matches
        ``data.reachable``.
    """
    model.train()
    for data in train_loader:
        data = data.to(device)
        # Clear gradients per batch so several batches never accumulate.
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, edge_attr = data.edge_attr, batch = data.batch)
        dist_pred = out[:, 0].flatten()
        reach_logit = out[:, 1].flatten()

        mse_loss = dst_criterion(dist_pred[data.reachable], data.y[data.reachable])
        # NOTE(review): both criteria already average (default reduction='mean'),
        # so these divisions normalise a second time. Kept unchanged because
        # alpha/beta/eta were tuned on this scale - confirm intent.
        mse_loss /= num_reachable_nodes
        bce_loss = bfs_criterion(reach_logit, data.reachable.float())
        bce_loss /= data.num_nodes

        l1_regularization = torch.tensor(0., device=device)
        l0_norm = torch.tensor(0., device=device)
        for param in model.parameters():
            l1_regularization += param.abs().sum()
            # Count parameters whose magnitude exceeds eps (no gradient flows here).
            l0_norm += (param.abs() > eps).sum()

        total_loss = alpha * mse_loss + beta * bce_loss + eta * l1_regularization
        total_loss.backward()
        optimizer.step()

        bfs_acc = ((reach_logit > 0) == data.reachable).sum().item() / data.num_nodes
    # Detach so stored metrics do not drag autograd history along with them.
    return (mse_loss.detach(), bce_loss.detach(), bfs_acc,
            l1_regularization.detach(), l0_norm, total_loss.detach())

def testing():
    """Evaluate the model on ``test_loader`` without tracking gradients.

    Returns:
        ``(test_loss, bfs_acc)`` as 0-dim CPU tensors. ``test_loss`` is the
        summed ``test_criterion`` over reachable nodes divided by
        ``num_reachable_test_nodes``. ``bfs_acc`` is the fraction of all test
        nodes whose reachability logit sign (out[:, 1] > 0) matches
        ``data.reachable``.
    """
    model.eval()
    test_loss = torch.tensor(0.)
    n_correct = torch.tensor(0.)
    n_nodes = 0
    # torch.no_grad() only disables autograd when used as a context manager;
    # a bare call creates and immediately discards the guard.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, edge_attr = data.edge_attr, batch = data.batch)
            test_loss += test_criterion(out[:, 0].flatten()[data.reachable], data.y[data.reachable]).cpu()
            n_correct += ((out[:, 1].flatten() > 0) == data.reachable).sum().item()
            # Accumulate over every batch, not just the last one.
            n_nodes += data.num_nodes
    test_loss /= num_reachable_test_nodes
    bfs_acc = n_correct / max(n_nodes, 1)
    return test_loss.detach().cpu(), bfs_acc.detach().cpu()

# Per-epoch metric histories (CPU, one slot per epoch).
mse_losses = torch.zeros(epochs)
train_accs = torch.zeros(epochs)
l1_regs = torch.zeros(epochs)
test_losses = torch.zeros(epochs)
test_accs = torch.zeros(epochs)

# Weight snapshots taken every CHECKPOINT_EVERY epochs.
CHECKPOINT_EVERY = 100
model_checkpoints = []

pbar = tqdm(range(epochs))
for epoch in pbar:
    optimizer.zero_grad()
    mse_loss, bce_loss, train_acc, l1_reg, l0_norm, total_loss = training()
    # Store plain floats: writing a tensor that still has autograd history into
    # these buffers would make them record a graph that grows every epoch.
    mse_losses[epoch] = float(mse_loss)
    train_accs[epoch] = float(train_acc)
    l1_regs[epoch] = float(l1_reg)
    test_loss, test_acc = testing()
    test_losses[epoch] = float(test_loss)
    test_accs[epoch] = float(test_acc)
    if epoch % CHECKPOINT_EVERY == 0:
        # copy=True guarantees an independent snapshot; a plain .cpu() returns
        # the live parameter tensor when the model already lives on the CPU.
        model_checkpoints.append({k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()})
    pbar.set_description(f'Train MSE: {float(mse_loss):.4f}, Train BCE: {float(beta * bce_loss):.4f}, Train Acc: {float(train_acc):.4f}, L1 Norm: {float(l1_reg):.4f}, L0 Norm: {float(l0_norm):.4f}, Test Loss: {float(test_loss):.4f}, Test Acc: {float(test_acc):.4f}')
model.eval()

Train MSE: 0.0001, Train BCE: 0.0020, Train Acc: 0.9927, L1 Norm: 15.9885, L0 Norm: 14.0000, Test Loss: 0.0602, Test Acc: 1.0000: 100%|██████████| 20000/20000 [29:53<00:00, 11.15it/s]   


MinAggGNN(2, 2, num_layers=2)

In [None]:
from pathlib import Path

# Every artefact for this run goes under one seed-specific directory.
run_dir = Path(f'../model_progress/parallel/seed_{seed}')
# torch.save does not create missing parent directories.
run_dir.mkdir(parents=True, exist_ok=True)

artifacts = {
    'model_final.pt': model.state_dict(),
    'model_checkpoints.pt': model_checkpoints,
    'mse_losses.pt': mse_losses,
    'train_accs.pt': train_accs,
    'l1_regs.pt': l1_regs,
    'test_losses.pt': test_losses,
    'test_accs.pt': test_accs,
}
for filename, obj in artifacts.items():
    torch.save(obj, run_dir / filename)