In [13]:
import os.path as osp
import torch
from torch.nn import Linear
from math import ceil

import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphConv
from torch_geometric import utils
from torch_geometric.nn import Sequential, DMoNPooling, GCNConv
import torch_geometric.transforms as T
from torch_geometric.nn.conv.gcn_conv import gcn_norm

import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from utils.eval_metrics import *

torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
dataset = 'Cora'
path = osp.join('..', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

delta = 0.85
edge_index, edge_weight = utils.get_laplacian(data.edge_index, data.edge_weight, normalization='sym')
L = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
A = torch.eye(data.num_nodes) - delta*L
data.edge_index, data.edge_weight = utils.dense_to_sparse(A)

data = data.to(device)

In [15]:
class DMoNNet(torch.nn.Module):
    def __init__(self, 
                 mp_units,
                 mp_act,
                 in_channels, 
                 n_clusters, 
                 mlp_units=[],
                 mlp_act="Identity"):
        super().__init__()
        
        mp_act = getattr(torch.nn, mp_act)(inplace=True)
        mlp_act = getattr(torch.nn, mlp_act)(inplace=True)
        
        mp = [
            (GCNConv(in_channels, mp_units[0], normalize=False, cached=False), 'x, edge_index, edge_weight -> x'),
            mp_act
        ]
        for i in range(len(mp_units)-1):
            mp.append((GCNConv(mp_units[i], mp_units[i+1], normalize=False, cached=False), 'x, edge_index, edge_weight -> x'))
            mp.append(mp_act)
        self.mp = Sequential('x, edge_index, edge_weight', mp)
        out_chan = mp_units[-1]
        
        self.mlp = torch.nn.Sequential()
        for units in mlp_units:
            self.mlp.append(Linear(out_chan, units))
            out_chan = units
            self.mlp.append(mlp_act)
        self.mlp.append(Linear(out_chan, n_clusters))
        
        self.dmon_pooling = DMoNPooling(channels=mp_units[-1], k=n_clusters)
        

    def forward(self, x, edge_index, edge_weight):
        x = self.mp(x, edge_index, edge_weight)
        s = self.mlp(x) 
        adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
        _, _, _, spectral_loss, ortho_loss, cluster_loss = DMoNPooling(x, adj)
        
        return torch.softmax(s, dim=-1), spectral_loss, ortho_loss, cluster_loss

In [16]:
model = DMoNNet([64]*10, "ELU", dataset.num_features, dataset.num_classes, [8]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def train():
    model.train()
    optimizer.zero_grad()
    _, spectral_loss, ortho_loss, cluster_loss = model(data.x, data.edge_index, data.edge_weight)
    loss = spectral_loss + ortho_loss + cluster_loss
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    model.eval()
    clust, _, _, _ = model(data.x, data.edge_index, data.edge_weight)
    return eval_metrics(data.y.cpu(), clust.max(1)[1].cpu())
    

patience = 50
best_nmi = 0
for epoch in range(1, 751):
    train_loss = train()
    acc, nmi = test()
    if epoch % 50 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc:.3f}')

Epoch: 050, Loss: 1.0796, NMI: 0.003, ACC: 0.294
Epoch: 100, Loss: 1.0796, NMI: 0.002, ACC: 0.259
Epoch: 150, Loss: 1.0796, NMI: 0.004, ACC: 0.242
Epoch: 200, Loss: 1.0453, NMI: 0.257, ACC: 0.364
Epoch: 250, Loss: 0.8715, NMI: 0.284, ACC: 0.423
Epoch: 300, Loss: 0.5950, NMI: 0.293, ACC: 0.404
Epoch: 350, Loss: -0.1095, NMI: 0.276, ACC: 0.371
Epoch: 400, Loss: -0.3183, NMI: 0.268, ACC: 0.394
Epoch: 450, Loss: -0.3533, NMI: 0.268, ACC: 0.398
Epoch: 500, Loss: -0.3668, NMI: 0.268, ACC: 0.399
Epoch: 550, Loss: -0.3756, NMI: 0.268, ACC: 0.400
Epoch: 600, Loss: -0.3809, NMI: 0.265, ACC: 0.395
Epoch: 650, Loss: -0.3842, NMI: 0.264, ACC: 0.393
Epoch: 700, Loss: -0.3866, NMI: 0.262, ACC: 0.390
Epoch: 750, Loss: -0.3887, NMI: 0.259, ACC: 0.387
