In [1]:
from __future__ import division
from __future__ import print_function
import os, sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import argparse
import random
import numpy as np
import scipy.sparse as sp
import torch

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

from torch import optim
import torch.nn.functional as F
from model import ROD_cluster
from optimizer import loss_function
from utils import *
from sklearn.cluster import SpectralClustering, KMeans
from clustering_metric import clustering_metrics
from tqdm import tqdm

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--num_hops', type=int, default=3)
parser.add_argument('--dims', type=int, default=[64])
parser.add_argument('--lr', type=float, default=1e-2)
parser.add_argument('--batch_size', type=int, default=1000)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--upd', type=int, default=10)
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args(args=[])
print("Using {} dataset".format(args.dataset))

Using cora dataset


### 1.Data Fetching

In [3]:
if args.dataset == 'cora':
    n_clusters = 7
    lr = 1e-2
    start_hops = 4
elif args.dataset == 'citeseer':
    n_clusters = 6
    lr = 1e-3
    start_hops = 3
elif args.dataset == 'pubmed':
    n_clusters = 3
    lr = 1e-3
    start_hops = 15

device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
adj, features, true_labels, idx_train, idx_val, idx_test = load_data(args.dataset)
n_nodes, feat_dim = features.shape
dims = [feat_dim] + args.dims

adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj.eliminate_zeros()

n = adj.shape[0]

adj_normalized = preprocess_graph(adj, norm='sym', renorm=True)
features = sp.csr_matrix(features).toarray()

for i in range(start_hops):
    features = adj_normalized.dot(features)

feature_list = [features]
for i in range(args.num_hops):
    feature_list.append(adj_normalized.dot(feature_list[-1]))
input_feature = [torch.FloatTensor(feat).to(device) for feat in feature_list]

adj_1st = (adj + sp.eye(n)).toarray()
adj_label = torch.FloatTensor(adj_1st)
neg_num = pos_num = adj_label.sum().long()

model = ROD_cluster(dims, n_clusters, args.num_hops)

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)

### 2. Data Preprocessing

In [4]:
sm_sim_mx_list = []
for i in range(args.num_hops+1):
    cur_feat = F.normalize(input_feature[i].cpu().data)
    sm_sim_mx_list.append(torch.mm(cur_feat, cur_feat.t()).reshape([-1,]))

adj_label = adj_label.reshape([-1,])
model = model.to(device)

pos_inds_list = []
neg_inds_list = []
for i in range(args.num_hops+1):
    pos_inds_list.append(np.argpartition(-sm_sim_mx_list[i], pos_num)[:pos_num])
    neg_inds_list.append(np.argpartition(sm_sim_mx_list[i], pos_num*200)[:pos_num*200])

length = len(pos_inds_list[0])
length_neg = len(neg_inds_list[0])

pos_inds_cuda_list = [torch.LongTensor(pos_inds).to(device) for pos_inds in pos_inds_list]

batch_size = args.batch_size

### 3. Clustering

In [5]:
kmeans_list = [KMeans(n_clusters=n_clusters, n_init=20)]
for _ in range(args.num_hops+1):
    kmeans_list.append(KMeans(n_clusters=n_clusters, n_init=20))

tqdm.write('Start Training...')
for epoch in tqdm(range(args.epochs)):
    model.train()

    optimizer.zero_grad()
    z_list = model(input_feature)
    start, end = 0, batch_size
    loss1 = 0.
    ran_head = np.random.randint(0, length_neg-length-1)
    sampled_neg_list = []
    for i in range(args.num_hops+1):
        sampled_neg_list.append(torch.LongTensor(neg_inds_list[i][np.arange(ran_head, ran_head+length)]).to(device))

    if epoch % args.upd == 0:
        label_list = []
        centroid_list = []
        for i in range(args.num_hops+1):
            label_list.append(kmeans_list[i].fit_predict(z_list[i].data.cpu().numpy()))
            centroid_list.append(kmeans_list[i].cluster_centers_)

        new_label_list = [label_list[0]]
        new_centroid_list = [torch.FloatTensor(centroid_list[0]).to(device)]

        for i in range(1, args.num_hops+1):
            temp_label, temp_index = munkres(label_list[i], label_list[0])
            temp_centroid = np.array([centroid_list[i][temp_index[j][1]] for j in range(n_clusters)])
            new_label_list.append(temp_label)
            new_centroid_list.append(torch.FloatTensor(temp_centroid).to(device))

    dist_list = []
    for i in range(args.num_hops+1):
        for j in range(n_clusters):
            if j == 0:
                dist = torch.norm(z_list[i] - new_centroid_list[i][j], p=2, dim=1, keepdim=True)
            else:
                dist = torch.cat((dist, torch.norm(z_list[i] - new_centroid_list[i][j], p=2, dim=1, keepdim=True)), 1)
        dist_list.append(dist)

    dist_norm_list = [F.softmax(dist, 1) for dist in dist_list]

    attention_scores = [torch.sigmoid(model.lr_att2(dist_norm)).view(n_nodes, 1) for dist_norm in dist_norm_list]
    W = torch.cat(attention_scores, dim=1)
    W = F.softmax(W, 1)

    dist_ensemble = torch.mul(dist_norm_list[0], W[:, 0].view(n_nodes, 1))
    for i in range(1, args.num_hops+1):
        dist_ensemble += torch.mul(dist_norm_list[i], W[:, i].view(n_nodes, 1))

    label_ensemble = dist_ensemble.min(1)[1].long().cpu().numpy()
    if len(list(set(label_ensemble))) < n_clusters:
        y_pred = kmeans_list[args.num_hops+1].fit_predict(dist_ensemble.data.cpu().numpy())
    else:
        y_pred = label_ensemble

    if epoch == 0:
        cm = clustering_metrics(true_labels, y_pred)
        best_acc, best_nmi, best_ari = cm.evaluationClusterModelFromLabel(tqdm)
    else:
        cm = clustering_metrics(true_labels, y_pred)
        acc, nmi, ari = cm.evaluationClusterModelFromLabel(tqdm)
        if acc > best_acc:
            best_acc = acc
            best_nmi = nmi
            best_ari = ari

    loss3 = 0.
    for i in range(args.num_hops+1):
        loss3 += F.mse_loss(dist_norm_list[i], dist_ensemble)

    loss2 = 0.
    for i in range(args.num_hops+1):
        loss_tmp = -dist_list[i].mean(1).sum()
        loss_tmp += 2 * np.sum([dist_list[i][j, x] for j, x in zip(range(dist_list[i].shape[0]), new_label_list[i])])
        loss2 += loss_tmp / n_nodes

    while end <= length:
        for i in range(args.num_hops+1):
            sampled_inds = torch.cat((pos_inds_cuda_list[i][start:end], sampled_neg_list[i][start:end]), 0)
            xind = sampled_inds // n_nodes
            yind = sampled_inds % n_nodes
            zx = torch.index_select(z_list[i], 0, xind)
            zy = torch.index_select(z_list[i], 0, yind)
            batch_label = torch.cat((torch.ones(end-start), torch.zeros(end-start))).to(device)
            batch_pred = (zx * zy).sum(1)
            weight = torch.cat((batch_pred[:batch_size], 1-batch_pred[batch_size:]), 0).data
            loss1 += loss_function(adj_preds=batch_pred, adj_labels=batch_label, weight=weight)

        start = end
        if end < length <= end + batch_size:
            end += length - end
        else:
            end += batch_size

    loss = 1*loss1 + 10*loss2 + 10*loss3
    loss.backward()
    optimizer.step()


tqdm.write("Optimization Finished!")
tqdm.write('best_acc: {}, best_nmi: {}, best_adj: {}'.format(best_acc, best_nmi, best_ari))


  0%|          | 0/200 [00:00<?, ?it/s]

Start Training...


100%|██████████| 200/200 [04:34<00:00,  1.37s/it]

Optimization Finished!
best_acc: 0.7470457902511078, best_nmi: 0.5810151004989821, best_adj: 0.5345240231222482



