In [None]:
import os
import sys

# CUDA reads these variables when it first initialises. Set them before torch,
# torch_geometric or the star-imported gmn_config modules get a chance to probe
# the GPU; otherwise the device selection can be silently ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
from torch_geometric.data import Data

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

module_path = os.path.abspath(os.path.join('../gmn_config/'))
if module_path not in sys.path:
    sys.path.append(module_path)

# NOTE(review): star imports hide where names such as get_default_config,
# build_model, similarity and reduce_dimensions come from; prefer explicit imports.
from gmn_config.graph_utils import *

from gmn_config.evaluation import compute_similarity, auc
from gmn_config.loss import pairwise_loss, triplet_loss
from gmn_config.gmn_utils import *
from gmn_config.configure_cosine import *

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

config = get_default_config()

# cudnn autotuning: faster, but results are not bit-for-bit reproducible.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Anomaly detection records a traceback for every autograd op and slows training
# considerably; switch it on only while hunting NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Pre-trained graph matching network used as a frozen similarity scorer.
gmn, optimizer = build_model(config, 64, 4)
# map_location lets a checkpoint written on GPU load on a CPU-only machine.
gmn.load_state_dict(torch.load("../gmn_config/model64v2.pth", map_location=device))
gmn.to(device)
gmn.eval()
# The GMN is never optimised here. Freezing its weights stops backward passes
# through the similarity term from accumulating .grad on them. Gradients still
# flow through the GMN to its inputs.
gmn.requires_grad_(False)

In [None]:
import os.path as osp
import torch
from torch.nn import Linear
from math import ceil

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphConv
from torch_geometric import utils
from torch_geometric.nn import Sequential, DMoNPooling, GCNConv
import torch_geometric.transforms as T
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import torch_geometric

import os
import sys
# Make the repository root importable; the membership check keeps re-runs idempotent.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from utils.eval_metrics import *

# Seed the global torch RNG so later random operations (e.g. model
# initialisation) are repeatable across runs.
torch.manual_seed(0)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load Planetoid/Cora with row-normalised node features.
dataset_name = 'Cora'
path = osp.join('..', 'data', dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]

# Replace the raw adjacency by A = I - delta * L_sym, where L_sym is the
# symmetric normalised Laplacian returned by get_laplacian.
# Equivalently, A = (1 - delta) * I + delta * D^-1/2 Adj D^-1/2.
delta = 0.85
# Pass the node count explicitly so L matches torch.eye(num_nodes) even if the
# highest-index nodes have no edges (otherwise the size is inferred from edge_index).
edge_index, edge_weight = utils.get_laplacian(
    data.edge_index, data.edge_weight, normalization='sym', num_nodes=data.num_nodes)
L = utils.to_dense_adj(edge_index, edge_attr=edge_weight, max_num_nodes=data.num_nodes)
A = torch.eye(data.num_nodes) - delta * L
data.edge_index, data.edge_weight = utils.dense_to_sparse(A)

# CPU copy of the (dimension-reduced) input graph; used as the reference graph
# for the GMN similarity.
original_data = Data(x=reduce_dimensions(data.x.numpy()), edge_index=data.edge_index)
data = data.to(device)

In [None]:
class DMoNNet(torch.nn.Module):
    """Stacked GCN encoder followed by DMoN pooling.

    Optionally regularised by a graph-matching-network (GMN) similarity term
    between the input graph and the pooled graph.

    Args:
        mp_units: output widths of the stacked ``GCNConv`` layers (at least one).
        mp_act: name of a ``torch.nn`` activation applied after every GCN layer.
        in_channels: input node-feature dimension.
        n_clusters: number of clusters produced by ``DMoNPooling``.
        mlp_units: hidden widths of ``self.mlp``.
            NOTE(review): ``self.mlp`` is built, and kept so existing state_dicts
            still load, but its output is not used. ``DMoNPooling`` computes the
            cluster assignment with its own internal layer.
        mlp_act: name of a ``torch.nn`` activation used inside ``self.mlp``.
    """

    def __init__(self,
                 mp_units,
                 mp_act,
                 in_channels,
                 n_clusters,
                 mlp_units=(),
                 mlp_act="Identity"):
        super().__init__()

        mp_act = self._make_activation(mp_act)
        mlp_act = self._make_activation(mlp_act)

        # normalize=False: edge weights are used exactly as given, with no GCN
        # self-loop / degree normalisation.
        mp = [
            (GCNConv(in_channels, mp_units[0], normalize=False, cached=False), 'x, edge_index, edge_weight -> x'),
            mp_act
        ]
        for i in range(len(mp_units) - 1):
            mp.append((GCNConv(mp_units[i], mp_units[i + 1], normalize=False, cached=False), 'x, edge_index, edge_weight -> x'))
            mp.append(mp_act)
        self.mp = Sequential('x, edge_index, edge_weight', mp)
        out_chan = mp_units[-1]

        self.mlp = torch.nn.Sequential()
        for units in mlp_units:
            self.mlp.append(Linear(out_chan, units))
            out_chan = units
            self.mlp.append(mlp_act)
        self.mlp.append(Linear(out_chan, n_clusters))

        self.dmon_pooling = DMoNPooling(mp_units[-1], n_clusters)

    @staticmethod
    def _make_activation(name):
        """Instantiate ``torch.nn.<name>``, in-place when the class supports it."""
        act_cls = getattr(torch.nn, name)
        try:
            return act_cls(inplace=True)
        except TypeError:
            # e.g. Tanh / Sigmoid / GELU take no ``inplace`` argument.
            return act_cls()

    def forward(self, x, edge_index, edge_weight, l_value = 0.0):
        """Encode the graph, pool it with DMoN and compute the training loss.

        Args:
            x: node features.
            edge_index: graph connectivity.
            edge_weight: edge weights, used as-is by the GCN layers and the dense adjacency.
            l_value: weight of the GMN similarity loss; 0.0 skips the GMN call entirely.

        Returns:
            ``(s, loss, x_pooled, adj_pooled)``.
            - ``s`` is DMoNPooling's soft cluster assignment for the single graph.
              It is already softmax-normalised, so it is returned as-is.
            - ``loss`` is the spectral + ortho + cluster losses plus the optional
              similarity term.
            - ``x_pooled`` / ``adj_pooled`` keep a leading batch dimension of 1.
        """
        x = self.mp(x, edge_index, edge_weight)
        # max_num_nodes keeps adj aligned with x even if the highest-index nodes
        # have no edges.
        adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight, max_num_nodes=x.size(0))
        s, x_pooled, adj_pooled, spectral_loss, ortho_loss, cluster_loss = self.dmon_pooling(x, adj)

        if l_value != 0.0:
            # Uses the module-level gmn / config / original_data from earlier cells.
            edge_index_pool = utils.dense_to_sparse(adj_pooled)[0]
            clustered_data = torch_geometric.data.Data(x=x_pooled[0], edge_index=edge_index_pool)
            sim = similarity(gmn, config, original_data, clustered_data)
            # Presumably sim is a cosine similarity in [-1, 1]; (1 - sim) / 2
            # would then map it to [0, 1]. TODO confirm against similarity().
            sim_loss = l_value * ((1 - sim) / 2)
        else:
            sim_loss = 0.0

        loss = spectral_loss + ortho_loss + cluster_loss + sim_loss

        # DMoNPooling already applied softmax; a second softmax would only flatten
        # the probabilities.
        return s[0], loss, x_pooled, adj_pooled

In [None]:
# Eight 64-wide GCN layers, then DMoN pooling into one cluster per dataset class.
model = DMoNNet(
    mp_units=[64] * 8,
    mp_act="ReLU",
    in_channels=dataset.num_features,
    n_clusters=dataset.num_classes,
    mlp_units=[16],
    mlp_act="ReLU",
).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)


def train(model, optimizer, l_value = 0.0):
    """Run a single full-batch optimisation step on the module-level graph ``data``.

    Args:
        model: module called as ``model(x, edge_index, edge_weight, l_value)``
            that returns a 4-tuple whose second element is the scalar loss.
        optimizer: optimizer over ``model``'s parameters.
        l_value: weight forwarded to the model (the similarity-loss weight for DMoNNet).

    Returns:
        The step's loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    _clusters, step_loss, _pooled_x, _pooled_adj = model(
        data.x, data.edge_index, data.edge_weight, l_value
    )
    step_loss.backward()
    optimizer.step()
    return step_loss.item()


@torch.no_grad()
def test(model, l_value = 0.0):
    """Evaluate clustering quality and GMN similarity on the module-level graph ``data``.

    Args:
        model: a ``DMoNNet``-style module.
        l_value: accepted for signature parity with ``train``. In ``DMoNNet.forward``
            it only affects the loss term, which is discarded here. The forward pass
            therefore runs with 0.0, skipping a redundant GMN similarity evaluation.

    Returns:
        ``(acc, nmi, ari, sim)``. The first three come from ``eval_metrics`` on the
        hard cluster assignment. ``sim`` is the GMN similarity between
        ``original_data`` and the pooled graph.
    """
    model.eval()
    clust, _, x_pooled, adj_pooled = model(data.x, data.edge_index, data.edge_weight, 0.0)
    edge_index_pool = utils.dense_to_sparse(adj_pooled)[0]
    clustered_data = torch_geometric.data.Data(x=x_pooled[0], edge_index=edge_index_pool)
    sim = similarity(gmn, config, original_data, clustered_data)
    # Hard assignment: the most probable cluster for each node.
    pred = clust.argmax(dim=1)
    acc, nmi, ari = eval_metrics(data.y.cpu(), pred.cpu())
    return acc, nmi, ari, sim
    

# Pre-train DMoN without the similarity term (l_value = 0). Evaluation runs the
# GMN similarity, so it is only done at logging epochs to keep training cheap.
LOG_EVERY = 50
losses = []
nmis = []
accs = []
aris = []
sims = []
for epoch in range(1, 1001):
    train_loss = train(model, optimizer)
    if epoch % LOG_EVERY == 0:
        acc, nmi, ari, sim = test(model)
        losses.append(train_loss)
        accs.append(acc)
        nmis.append(nmi)
        aris.append(ari)
        # Store a plain float so the list does not hold (possibly GPU) tensors.
        sims.append(sim.item())
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc:.3f}, ARI: {ari:.3f}, SIM: {sim.item()}')

torch.save(model.state_dict(), "dmon.pth")

In [None]:
# Fine-tune the pre-trained DMoN model with the GMN similarity term switched on.
SIM_LOSS_WEIGHT = 0.01  # l_value used for both training and evaluation

# map_location keeps the load working if the checkpoint was written on another device.
model.load_state_dict(torch.load("dmon.pth", map_location=device))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# One entry is recorded every 50 epochs (epochs 0, 50, ..., 750).
new_losses = []
new_nmis = []
new_accs = []
new_aris = []
new_sims = []
for epoch in range(0, 751):
    train_loss = train(model, optimizer, SIM_LOSS_WEIGHT)
    if epoch % 50 == 0:
        acc, nmi, ari, sim = test(model, SIM_LOSS_WEIGHT)
        new_losses.append(train_loss)
        new_accs.append(acc)
        new_nmis.append(nmi)
        new_aris.append(ari)
        # Store a plain float so plotting never receives (possibly GPU) tensors.
        new_sims.append(sim.item())
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc:.3f}, ARI: {ari:.3f}, SIM: {sim.item()}')

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the metrics recorded by the fine-tuning loop. That loop appends one entry
# every 50 epochs (epochs 0, 50, ..., 750), so the x-axis is derived from the
# number of recorded points. This keeps x and y the same length.
RECORD_EVERY = 50
epochs = np.arange(len(new_losses)) * RECORD_EVERY

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(epochs, new_losses, label='Loss')
ax.plot(epochs, new_nmis, label='NMI')
ax.plot(epochs, new_accs, label='ACC')
ax.plot(epochs, new_aris, label='ARI')
# float() accepts Python numbers as well as 0-d tensors (CPU or GPU).
ax.plot(epochs, [float(s) for s in new_sims], label='SIM')

ax.set_xlabel('Epochs')
ax.set_ylabel('Metrics')
ax.set_title('DMoN')
ax.set_xticks(range(0, 751, 100))  # fine-tuning covers epochs 0..750
ax.legend()
plt.show()
