In [11]:
import os
import os.path as osp
from collections import Counter

import numpy as np
import torch
import torch_geometric.transforms as T
from sklearn.cluster import DBSCAN, KMeans
from sklearn.decomposition import PCA
from torch_geometric.datasets import Planetoid

from configure_cosine import *
from eval_metrics import *
from gmn_utils import *
from graph_utils import *
from visualize import *

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

config = get_default_config()

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.autograd.set_detect_anomaly(True)

gmn, optimizer = build_model(config, 64, 4)
gmn.load_state_dict(torch.load("models/model64_5.pth"))
gmn.to(device)
gmn.eval()

GraphMatchingNet(
  (_encoder): GraphEncoder(
    (MLP1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
    )
    (MLP2): Sequential(
      (0): Linear(in_features=4, out_features=4, bias=True)
    )
  )
  (_aggregator): GraphAggregator(
    (MLP1): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
    )
    (MLP2): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (_prop_layers): ModuleList(
    (0-4): 5 x GraphPropMatchingLayer(
      (_message_net): Sequential(
        (0): Linear(in_features=132, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
      )
      (_reverse_message_net): Sequential(
        (0): Linear(in_features=132, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
      )
      (GRU): GRU(192, 64)
    )
  )
)

In [12]:
original_data = torch.load("original_data.pt")
clustered_data = torch.load("clustered_data.pt")
sim = similarity(gmn, config, original_data, original_data)

In [43]:
cross_attentions = torch.load("cross_attentions.pt")
topk_attentions = torch.load("topk_cross_attentions.pt")
(a_x, a_y) = cross_attentions[1]
a_x, a_y = a_x[0], a_y[0]
a = ((a_x + a_y.t())/2)
a_x = a_x
a_y = a_y

In [14]:
pca = PCA(0.85)
a = pca.fit_transform(a)

In [15]:
print(a)

[[-0.0115331  -0.00022987 -0.00079375]
 [-0.01266247 -0.00083285 -0.00192065]
 [-0.00137232  0.00072734  0.00447995]
 ...
 [-0.01698002 -0.00187975 -0.00530083]
 [-0.00695435  0.00080998  0.00227842]
 [-0.00463777  0.00142905  0.00421528]]


In [16]:
clusters = KMeans(n_clusters=5).fit(a)
labels = clusters.labels_

  super()._check_params_vs_input(X, default_n_init=10)


In [17]:
Counter(labels).keys() # equals to list(set(words))
Counter(labels).values() # counts the elements' frequency

dict_values([2498, 156, 13, 40, 1])

In [18]:
dataset = 'Cora'
path = osp.join('data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

In [19]:
acc, nmi, ari = eval_metrics(data.y.cpu(), labels)

In [20]:
print(acc, nmi, ari)

0.29394387001477107 0.0032440466389573033 0.0010617573706210453
