In [19]:
import os
import os.path as osp
import sys
from math import ceil
import numpy as np

import torch
import torch.nn.functional as F
import torch_geometric
import torch_geometric.transforms as T
from dataset import GraphDataset
from models import *
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.model_selection import train_test_split
from torch.nn import Linear
from torch_geometric import utils
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.nn import DMoNPooling, GCNConv, GraphConv, Sequential
from torch_geometric.nn.conv.gcn_conv import gcn_norm

In [20]:
dataset = GraphDataset(torch.load("../data/cycle_line_star_complete_1.pt"))

train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [21]:
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [22]:
def train(model, optimizer):
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        output, loss = model(data)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)


@torch.no_grad()
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    for data in loader:
        output, _ = model(data)
        pred = torch.argmax(output, dim=1)
        target = data.node_classes
        batch = data.batch
        node_counter = [0] * len(target)
        for node_idx in range(data.num_nodes):
            graph_index = batch[node_idx].item()
            node_index = node_counter[graph_index]
            node_counter[graph_index] += 1
            true_classes = [c - 1 for c in target[graph_index][node_index]]
            predicted_class = pred[node_idx].item()

            if predicted_class in true_classes:
                correct += 1
        total += data.num_nodes
    accuracy = correct / total

    return accuracy


def run(model, optimizer, early_stopping):
    for epoch in range(1001):
        train_loss = train(model, optimizer)
        accuracy = test(model, test_loader)
        if epoch % 100 == 0:
            print(
                f"Epoch: {epoch:03d}, Loss: {train_loss:.4f}, Accuracy: {accuracy:.4f}"
            )
        early_stopping(train_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break

In [23]:
for m in [DMoN, JustBalance, MinCut]:
    for i in [5, 7, 10]:
        for j in [0, 1]:
            for dim in [32, 64]:
                for lr in [1e-3, 1e-4]:
                    model = m(
                        [64] * i,
                        "ReLU",
                        dataset.num_features,
                        len(np.unique(dataset[0].node_classes)),
                        [16] * j,
                        "ReLU",
                    )
                    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                    early_stopping = EarlyStopping(patience=50)
                    print(f"{model.name}: {(i, j, dim, lr)}")
                    run(model, optimizer, early_stopping)
                    print("--------------------")

DMoN: (5, 0, 32, 0.001)
Epoch: 000, Loss: 2.0603, Accuracy: 0.1419
Epoch: 100, Loss: 0.9834, Accuracy: 0.1404
Epoch: 200, Loss: 0.9239, Accuracy: 0.1348
Epoch: 300, Loss: 0.8929, Accuracy: 0.1334
Epoch: 400, Loss: 0.8783, Accuracy: 0.1376
Epoch: 500, Loss: 0.8677, Accuracy: 0.1334
Epoch: 600, Loss: 0.8595, Accuracy: 0.1306
Epoch: 700, Loss: 0.8587, Accuracy: 0.1461
Epoch: 800, Loss: 0.8587, Accuracy: 0.1362
Early stopping
--------------------
DMoN: (5, 0, 32, 0.0001)
Epoch: 000, Loss: 1.8442, Accuracy: 0.1320
Epoch: 100, Loss: 1.5755, Accuracy: 0.1433
Epoch: 200, Loss: 1.1763, Accuracy: 0.1306
Early stopping
--------------------
DMoN: (5, 0, 64, 0.001)
Epoch: 000, Loss: 1.8472, Accuracy: 0.1362
Epoch: 100, Loss: 0.9447, Accuracy: 0.1475
Epoch: 200, Loss: 0.9056, Accuracy: 0.1601
Epoch: 300, Loss: 0.8764, Accuracy: 0.1433
Epoch: 400, Loss: 0.8743, Accuracy: 0.1503
Epoch: 500, Loss: 0.8607, Accuracy: 0.1419
Epoch: 600, Loss: 0.8547, Accuracy: 0.1531
Epoch: 700, Loss: 0.8423, Accuracy: 0.