In [None]:
import os
import sys

from torch_geometric.data import Data

# Make the parent directory and its ``gmn_config`` sub-directory importable,
# appending each one to sys.path only if it is not already listed.
for relative_dir in ('..', '../gmn_config/'):
    module_path = os.path.abspath(relative_dir)
    if module_path not in sys.path:
        sys.path.append(module_path)

from gmn_config.graph_utils import *

from gmn_config.evaluation import compute_similarity, auc
from gmn_config.loss import pairwise_loss, triplet_loss
from gmn_config.gmn_utils import *
from gmn_config.configure_cosine import *

# Restrict this process to one physical GPU (PCI bus index 1).
# NOTE(review): these variables are set after torch has been imported; they are
# presumably still honoured because CUDA has not been initialised yet - confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Print configure
config = get_default_config()
for k, v in config.items():
    print("%s= %s" % (k, v))

# Set random seeds
seed = config["seed"]
random.seed(seed)
np.random.seed(seed + 1)
# NOTE(review): torch's own RNG is left unseeded in this cell (call disabled).
# torch.manual_seed(seed + 2)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# NOTE(review): anomaly detection is a debugging aid that slows every backward
# pass; consider disabling it once training is stable.
torch.autograd.set_detect_anomaly(True)


training_set, validation_set = build_datasets(config)

# Pull one batch only to discover the feature dimensions needed to build the model.
if config["training"]["mode"] == "pair":
    training_data_iter = training_set.pairs(config["training"]["batch_size"])
    first_batch_graphs, _ = next(training_data_iter)
else:
    training_data_iter = training_set.triplets(config["training"]["batch_size"])
    first_batch_graphs = next(training_data_iter)

node_feature_dim = first_batch_graphs.node_features.shape[-1]
edge_feature_dim = first_batch_graphs.edge_features.shape[-1]

gmn, optimizer = build_model(config, node_feature_dim, edge_feature_dim)
# map_location remaps the stored tensors onto the device chosen above, so the
# checkpoint loads even if it was saved from a different CUDA index or if this
# machine has no GPU at all.
gmn.load_state_dict(torch.load("../gmn_config/model64_5.pth", map_location=device))
gmn.to(device)
gmn.eval()

In [None]:
import os.path as osp
import torch
from torch.nn import Linear

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphConv, dense_mincut_pool
from torch_geometric import utils
from torch_geometric.nn import Sequential
from torch_geometric.nn.conv.gcn_conv import gcn_norm

import os
import sys
# Ensure the parent directory is on sys.path exactly as before: it is appended
# only when no existing entry equals its absolute path.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from utils.eval_metrics import *

# Fix the seed of torch's global random number generator.
torch.manual_seed(0)

# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load the Cora graph from ../data/Cora with the NormalizeFeatures transform.
# The name string gets its own variable so `dataset` only ever refers to the
# Planetoid object that later cells read.
dataset_name = 'Cora'
path = osp.join('..', 'data', dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]

# Replace the graph connectivity by A = I - delta * L, where L is the Laplacian
# returned by get_laplacian(normalization='sym'), densified and then converted
# back to a sparse edge list with per-edge weights.
delta = 0.85
edge_index, edge_weight = utils.get_laplacian(data.edge_index, data.edge_weight, normalization='sym')
L = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
A = torch.eye(data.num_nodes) - delta*L
data.edge_index, data.edge_weight = utils.dense_to_sparse(A)

data = data.to(device)
# `data` may now live on the GPU, and Tensor.numpy() only works on CPU tensors,
# so copy the features back to the host before handing them to reduce_dimensions.
original_data = Data(x=reduce_dimensions(data.x.cpu().numpy()), edge_index=data.edge_index)

In [None]:
class MinCutNet(torch.nn.Module):
    """GraphConv message passing followed by an MLP that scores cluster membership.

    The network stacks one ``GraphConv`` layer per entry of ``mp_units`` (each
    followed by the ``mp_act`` activation), then feeds the node embeddings to an
    MLP whose last ``Linear`` layer has ``n_clusters`` outputs. Those outputs are
    used as the assignment scores passed to ``dense_mincut_pool``.

    Args:
        mp_units: Output widths of the GraphConv layers; must be non-empty.
        mp_act: Name of an activation class in ``torch.nn`` (e.g. ``"ELU"``).
        in_channels: Input feature width of the first GraphConv layer.
        n_clusters: Output width of the final Linear layer.
        mlp_units: Optional widths of hidden Linear layers placed before the
            final one. ``None`` (the default) means no hidden layers.
        mlp_act: Name of the ``torch.nn`` activation applied after each hidden
            Linear layer.
    """

    def __init__(self,
                 mp_units,
                 mp_act,
                 in_channels,
                 n_clusters,
                 mlp_units=None,
                 mlp_act="Identity"):
        super().__init__()

        # None stands in for "no hidden layers" so that no mutable list object
        # is shared between calls as a default argument.
        mlp_units = [] if mlp_units is None else list(mlp_units)

        mp_act = self._make_activation(mp_act)
        mlp_act = self._make_activation(mlp_act)

        # The same (stateless) activation instance is reused after every layer.
        mp = [
            (GraphConv(in_channels, mp_units[0]), 'x, edge_index, edge_weight -> x'),
            mp_act
        ]
        for i in range(len(mp_units)-1):
            mp.append((GraphConv(mp_units[i], mp_units[i+1]), 'x, edge_index, edge_weight -> x'))
            mp.append(mp_act)
        self.mp = Sequential('x, edge_index, edge_weight', mp)
        out_chan = mp_units[-1]

        self.mlp = torch.nn.Sequential()
        for units in mlp_units:
            self.mlp.append(Linear(out_chan, units))
            out_chan = units
            self.mlp.append(mlp_act)
        self.mlp.append(Linear(out_chan, n_clusters))

    @staticmethod
    def _make_activation(name):
        """Instantiate ``torch.nn.<name>``, in-place when the class supports it."""
        act_cls = getattr(torch.nn, name)
        try:
            return act_cls(inplace=True)
        except TypeError:
            # Activations such as Tanh or Sigmoid take no ``inplace`` argument.
            return act_cls()

    def forward(self, x, edge_index, edge_weight, epoch, max_epochs, l_value=0.0):
        """Compute soft cluster assignments and the training loss.

        Args:
            x: Node features passed to the message-passing stack.
            edge_index: Edge list of the graph.
            edge_weight: Per-edge weights matching ``edge_index``.
            epoch: Accepted for interface compatibility; not used here.
            max_epochs: Accepted for interface compatibility; not used here.
            l_value: When non-zero, the pooled graph is compared against the
                module-level ``original_data`` with ``similarity(gmn, config, ...)``
                and the pooling loss is divided by ``l_value * (1 + sim) / 2``.
                When zero, the loss is just the sum of the two pooling losses.

        Returns:
            A tuple ``(assignments, loss)`` where ``assignments`` is the softmax
            of the MLP output over its last dimension.
        """
        x = self.mp(x, edge_index, edge_weight)
        s = self.mlp(x)
        adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
        x_pooled, adj_pooled, mc_loss, o_loss = dense_mincut_pool(x, adj, s)

        if l_value != 0.0:
            edge_index_pool = utils.dense_to_sparse(adj_pooled)[0]
            # x_pooled[0]: presumably strips the leading batch dimension added
            # by the dense pooling operator - TODO confirm.
            clustered_data = torch_geometric.data.Data(x=x_pooled[0], edge_index=edge_index_pool)
            # NOTE: relies on the notebook globals gmn, config and original_data.
            sim = similarity(gmn, config, original_data, clustered_data)
            sim_loss = - l_value * ((1 + sim)/2)
        else:
            sim_loss = -1.0

        # NOTE(review): if sim can reach -1 the divisor becomes zero - confirm
        # the range returned by similarity().
        loss = (mc_loss + o_loss)/(-sim_loss)

        return torch.softmax(s, dim=-1), loss

In [None]:
# One 64-unit GraphConv layer with ELU; one output cluster per dataset class.
model = MinCutNet(
    mp_units=[64],
    mp_act="ELU",
    in_channels=dataset.num_features,
    n_clusters=dataset.num_classes,
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)


def train(epoch, max_epochs, l_value=0.0):
    """Run one full-graph optimisation step on the module-level ``model``.

    Uses the notebook globals ``model``, ``optimizer`` and ``data``.

    Args:
        epoch: Current epoch index, forwarded to the model.
        max_epochs: Epoch bound, forwarded to the model.
        l_value: Similarity weight forwarded to the model (0.0 disables it).

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    _, loss = model(data.x, data.edge_index, data.edge_weight, epoch, max_epochs, l_value)
    loss.backward()
    optimizer.step()

    return loss.item()


@torch.no_grad()
def test(epoch, max_epochs, l_value=0.0):
    """Evaluate the module-level ``model`` on ``data`` without tracking gradients.

    Each node is assigned to the cluster with the highest assignment value and
    the resulting labels are scored against ``data.y`` by ``eval_metrics``.

    Returns:
        Whatever ``eval_metrics`` returns for (true labels, predicted clusters).
    """
    model.eval()
    outputs = model(data.x, data.edge_index, data.edge_weight, epoch, max_epochs, l_value)
    assignments = outputs[0]
    # Tensor.max(dim) yields (values, indices); only the indices are needed.
    _, predicted = assignments.max(1)
    true_labels = data.y.cpu()
    return eval_metrics(true_labels, predicted.cpu())
    

# Number of pre-training epochs (similarity term disabled, l_value left at 0.0).
PRETRAIN_EPOCHS = 750

for epoch in range(1, PRETRAIN_EPOCHS + 1):
    # The epoch bound handed to train/test now matches the real loop bound.
    train_loss = train(epoch, PRETRAIN_EPOCHS + 1)
    acc, nmi = test(epoch, PRETRAIN_EPOCHS + 1)
    if epoch % 50 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc:.3f}')

# Persist the pre-trained weights to mincut.pth in the working directory.
torch.save(model.state_dict(), "mincut.pth")

In [None]:
# Number of fine-tuning epochs run with the similarity term enabled.
FINETUNE_EPOCHS = 750

# map_location keeps the load working when the checkpoint was written from a
# different device than the one selected in this session.
model.load_state_dict(torch.load("mincut.pth", map_location=device))
# train() looks up `optimizer` as a global, so rebinding it here makes the
# following steps use the smaller learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, FINETUNE_EPOCHS + 1):
    train_loss = train(epoch, FINETUNE_EPOCHS + 1, l_value=1.0)
    acc, nmi = test(epoch, FINETUNE_EPOCHS + 1, l_value=1.0)
    if epoch % 50 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}, ACC: {acc:.3f}')