In [14]:
import sys
from time import gmtime, strftime

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from munch import munchify
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

from .data.sbm import SBM_Dataset
from .models.losses import combinatorical_accuracy, ari_score

sys.path.append('./GNN4CD/src')
from load import get_lg_inputs
from losses import compute_loss_multiclass
from models import lGNN_multiclass

In [15]:
class SBM_Dataset_adjacency(SBM_Dataset):
    def __getitem__(self, idx):
        G, labels = super().__getitem__(idx)
        A = np.array(nx.adjacency_matrix(G).todense(), dtype=np.float32)
        return A, labels

n, k, p_in, p_out = 50, 5, 0.8, 0.2
n_epoch, n_samples_train, n_samples_test = 3, 500, 100
train_dataset = SBM_Dataset_adjacency(n, k, p_in, p_out, n_graphs=n_samples_train)
test_dataset = SBM_Dataset_adjacency(n, k, p_in, p_out, n_graphs=n_samples_test)
train_dataloader = DataLoader(train_dataset, batch_size=1)
test_dataloader = DataLoader(test_dataset, batch_size=1)

In [16]:
device = 'cuda:0'
args = munchify({
    'clip_grad_norm': 40.0,
    'num_features': 8,
    'num_layers': 30,
    'n_classes': k,
    'J': 2,
    'lr': 0.004
})

torch.backends.cudnn.enabled = False
model = lGNN_multiclass(args.num_features, args.num_layers, args.J + 2, n_classes=args.n_classes).to(device)
optimizer = torch.optim.Adamax(model.parameters(), lr=args.lr)

In [17]:
name = f'GNN4CD_SBM({n}, {k}, {p_in:.2f}, {p_out:.2f})_bs=2'
writer = SummaryWriter(f'./logs/{strftime("%Y-%m-%d %H:%M:%S", gmtime())} {name}')

for epoch in range(n_epoch):
    model.train()
    for it, (W, labels) in enumerate(tqdm(train_dataloader, desc=str(epoch))):
        WW, x, WW_lg, y, P = get_lg_inputs(W.numpy(), args.J)
        print(W.shape, WW.shape, x.shape, WW_lg.shape, y.shape, P.shape)
        WW, x, WW_lg, y, P, labels = [x.float().to(device) for x in (WW, x, WW_lg, y, P, labels)]
        pred = model(WW, x, WW_lg, y, P)

        loss = compute_loss_multiclass(pred, labels, args.n_classes)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
        optimizer.step()

        writer.add_scalar('train/loss', loss.item(), epoch * len(train_dataloader) + it)
        writer.add_scalar('train/acc', combinatorical_accuracy(pred, labels), epoch * len(train_dataloader) + it)
        writer.add_scalar('train/ari', ari_score(pred, labels), epoch * len(train_dataloader) + it)
        if it % 100 == 0:
            writer.flush()

#     model.eval()
    loss_lst, acc_lst, ari_lst = [], [], []
    with torch.no_grad():
        for _, (W, labels) in enumerate(test_dataloader):
            WW, x, WW_lg, y, P = get_lg_inputs(W.numpy(), args.J)
            WW, x, WW_lg, y, P, labels = [x.float().to(device) for x in (WW, x, WW_lg, y, P, labels)]
            pred = model(WW, x, WW_lg, y, P)

            loss = compute_loss_multiclass(pred, labels, args.n_classes)
            loss_lst.append(loss.item())
            acc_lst.append(combinatorical_accuracy(pred, labels))
            ari_lst.append(ari_score(pred, labels))
    writer.add_scalar('test/loss', np.mean(loss_lst), epoch)
    writer.add_scalar('test/acc', np.mean(acc_lst), epoch)
    writer.add_scalar('test/ari', np.mean(ari_lst), epoch)
    writer.flush()

HBox(children=(IntProgress(value=0, description='0', max=500, style=ProgressStyle(description_width='initial')…

torch.Size([1, 50, 50]) torch.Size([1, 50, 50, 4]) torch.Size([1, 50, 1]) torch.Size([1, 734, 734, 4]) torch.Size([1, 734, 1]) torch.Size([1, 50, 734, 2])
torch.Size([1, 50, 50]) torch.Size([1, 50, 50, 4]) torch.Size([1, 50, 1]) torch.Size([1, 808, 808, 4]) torch.Size([1, 808, 1]) torch.Size([1, 50, 808, 2])
torch.Size([1, 50, 50]) torch.Size([1, 50, 50, 4]) torch.Size([1, 50, 1]) torch.Size([1, 742, 742, 4]) torch.Size([1, 742, 1]) torch.Size([1, 50, 742, 2])
torch.Size([1, 50, 50]) torch.Size([1, 50, 50, 4]) torch.Size([1, 50, 1]) torch.Size([1, 788, 788, 4]) torch.Size([1, 788, 1]) torch.Size([1, 50, 788, 2])
torch.Size([1, 50, 50]) torch.Size([1, 50, 50, 4]) torch.Size([1, 50, 1]) torch.Size([1, 768, 768, 4]) torch.Size([1, 768, 1]) torch.Size([1, 50, 768, 2])
torch.Size([1, 50, 50]) torch.Size([1, 50, 50, 4]) torch.Size([1, 50, 1]) torch.Size([1, 782, 782, 4]) torch.Size([1, 782, 1]) torch.Size([1, 50, 782, 2])
torch.Size([1, 50, 50]) torch.Size([1, 50, 50, 4]) torch.Size([1, 50, 

KeyboardInterrupt: 