In [1]:
from time import gmtime, strftime
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import torch
import torch.nn as nn
from munch import munchify
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

from data.sbm_dgl import SBM_Dataset_LGNN_DGL
from models.lgnn_dgl import LGNN
from models.losses import combinatorical_accuracy, ari_score, combinatorical_cce

Using backend: pytorch


In [6]:
import random

# Use the first GPU when one is present; fall back to CPU so the notebook still
# runs end-to-end ("Restart & Run All") on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
args = munchify({
    'clip_grad_norm': 40.0,  # max total grad norm passed to clip_grad_norm_ in the training cell
    'num_features': 8,       # hidden feature width passed to LGNN
    'num_layers': 30,        # number of LGNN layers
    'n_classes': 2,          # number of communities; also used as k for the SBM below
    'J': 2,                  # passed to LGNN as J — presumably the number of adjacency powers; confirm in models/lgnn_dgl
    'lr': 0.004              # Adamax learning rate
})

# Seed every RNG that could drive graph sampling / weight init *before* the
# random SBM graphs are drawn, so a fresh kernel reproduces the same run.
# NOTE(review): if SBM_Dataset_LGNN_DGL samples through another RNG (e.g. DGL's
# or networkx's own), seed that one as well — not visible from this notebook.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# SBM parameters: n nodes split into k equal communities, edge probability p_in
# inside a community and p_out between communities (cf. the printed
# "sizes: [50, 50], p: [[0.2, 0.1], [0.1, 0.2]]" for n=100, k=2).
n, k, p_in, p_out = 100, args.n_classes, 0.2, 0.1
n_epoch, n_samples_train, n_samples_test = 3, 500, 100
train_dataset = SBM_Dataset_LGNN_DGL(n, k, p_in, p_out, n_graphs=n_samples_train, verbose=True)
test_dataset = SBM_Dataset_LGNN_DGL(n, k, p_in, p_out, n_graphs=n_samples_test)
# batch_size=1: each step processes a single graph (and its line graph).
train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=train_dataset.collate_fn)
# Each loader uses the collate_fn of the dataset it actually wraps.
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=test_dataset.collate_fn)

  0%|          | 0/500 [00:00<?, ?it/s]

sizes: [50, 50], p: [[0.2, 0.1], [0.1, 0.2]]


100%|██████████| 500/500 [00:29<00:00, 16.90it/s]


In [7]:
# Peek at one training sample: (graph, line graph, P, labels). The names carry a
# `sample_` prefix so they are not confused with the per-batch G, G_lg, P, labels
# that the training loop rebinds on every step.
sample_G, sample_G_lg, sample_P, sample_labels = train_dataset[0]
sample_G, sample_P.shape, sample_labels.shape

In [8]:
# cuDNN is switched off globally for this session.
# NOTE(review): the reason is not recorded in this notebook (determinism? an
# op unsupported by cuDNN?) — confirm and document, or drop if no longer needed.
torch.backends.cudnn.enabled = False

# Build the LGNN from the hyper-parameters in `args` and place it on `device`.
model = LGNN(
    args.num_features,
    args.num_layers,
    args.J,
    n_classes=args.n_classes,
).to(device)

# Adamax over all model parameters with the configured learning rate.
optimizer = torch.optim.Adamax(params=model.parameters(), lr=args.lr)

In [9]:
name = f'lgnn_dgl-SBM({n}, {k}, {p_in:.2f}, {p_out:.2f})'
# One TensorBoard run directory per execution, prefixed with a UTC timestamp.
# The time part uses '-' rather than ':' because ':' is invalid in Windows paths.
writer = SummaryWriter(f'./logs/{strftime("%Y-%m-%d %H-%M-%S", gmtime())} {name}')


def to_device(G, G_lg, P, labels):
    """Move one DataLoader batch onto `device` and cast it for the model.

    P is cast to float and labels to long; labels also get a leading axis of
    size 1 (`[None]`) to line up with `pred`, which is unsqueezed the same way.
    """
    return G.to(device), G_lg.to(device), P.float().to(device), labels.long().to(device)[None]


steps_per_epoch = len(train_dataloader)
try:
    for epoch in range(n_epoch):
        # ---- training pass: one optimizer step per graph ----
        model.train()
        for it, batch in enumerate(tqdm(train_dataloader, desc=str(epoch))):
            G, G_lg, P, labels = to_device(*batch)
            pred = model(G, G_lg, P)[None]

            loss = combinatorical_cce(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

            # Global step so curves from successive epochs line up on one axis.
            step = epoch * steps_per_epoch + it
            writer.add_scalar('train/loss', loss.item(), step)
            writer.add_scalar('train/acc', combinatorical_accuracy(pred, labels), step)
            writer.add_scalar('train/ari', ari_score(pred, labels), step)
            if it % 100 == 0:
                writer.flush()

        # ---- evaluation pass on the held-out graphs ----
        # eval() puts layers such as dropout / batch-norm (if LGNN contains any)
        # into inference mode; without it the test metrics would be computed in
        # training mode and batch-norm running stats would be updated on test data.
        model.eval()
        loss_lst, acc_lst, ari_lst = [], [], []
        with torch.no_grad():
            for batch in test_dataloader:
                G, G_lg, P, labels = to_device(*batch)
                pred = model(G, G_lg, P)[None]

                loss = combinatorical_cce(pred, labels)
                loss_lst.append(loss.item())
                acc_lst.append(combinatorical_accuracy(pred, labels))
                ari_lst.append(ari_score(pred, labels))
        # Test metrics are averaged over the whole test set and logged once per epoch.
        writer.add_scalar('test/loss', np.mean(loss_lst), epoch)
        writer.add_scalar('test/acc', np.mean(acc_lst), epoch)
        writer.add_scalar('test/ari', np.mean(ari_lst), epoch)
        writer.flush()
finally:
    # Release the event file even if training is interrupted or raises.
    writer.close()

HBox(children=(IntProgress(value=0, description='0', max=500, style=ProgressStyle(description_width='initial')…




HBox(children=(IntProgress(value=0, description='1', max=500, style=ProgressStyle(description_width='initial')…




HBox(children=(IntProgress(value=0, description='2', max=500, style=ProgressStyle(description_width='initial')…


