In [2]:
import os
import sys

import json
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data

from utils import preprocess

In [3]:
cache_path = '../dataset/b/cache.pt'
model_dir = '../models/tuned_b'
model_path = os.path.join(model_dir, 'model.pt')
device = torch.device('cuda:0')

In [4]:
model = torch.load(model_path).to(device)

In [5]:
config, features, edge_index, edge_attr, labels, train_mask, eval_mask, test_mask = torch.load(cache_path)

In [6]:
data = Data(x=features, edge_index=edge_index, edge_attr=edge_attr, y=labels, train_mask=train_mask, test_mask=test_mask, eval_mask=eval_mask).to(device)

In [7]:
out = model(data)

In [8]:
out

tensor([[-2.4068, -2.2860, -1.9563, -0.5715, -3.2606, -2.7521],
        [-3.7088, -0.1774, -2.8229, -3.8807, -3.4557, -3.6327],
        [-3.1680, -3.5078, -3.5060, -3.0157, -2.8196, -0.2366],
        ...,
        [-3.2262, -2.1226, -1.0199, -0.8682, -3.5160, -3.4889],
        [-2.1501, -0.6912, -3.2305, -2.8570, -1.7174, -2.2439],
        [-2.5790, -1.4328, -2.8318, -3.1127, -2.7010, -0.6636]],
       device='cuda:0', grad_fn=<LogSoftmaxBackward>)

In [9]:
probs = F.softmax(out, dim=-1)

In [10]:
pred_labels = probs.argmax(dim=-1).cpu().detach()

In [11]:
for i in range(config['n_class']):
    n_correct = (pred_labels == labels).logical_and(labels == i).sum()
    n_total = (labels == i).sum()
    print('class {} : total {} recall {}'.format(i, n_total.item(), (n_correct / n_total).item()))

class 0 : total 249 recall 0.3534136414527893
class 1 : total 590 recall 0.7966101765632629
class 2 : total 668 recall 0.8218562602996826
class 3 : total 701 recall 0.8502140045166016
class 4 : total 596 recall 0.8808724880218506
class 5 : total 523 recall 0.7801147103309631


In [12]:
class_features = [[] for i in range(config['n_class'])]
for i in range(config['n_vertex']):
    class_features[labels[i]].append(features[i])

In [13]:
edge_distri = torch.zeros(config['n_class'], config['n_class'])
for i in range(config['n_edge']):
    u = edge_index[0][i]
    v = edge_index[1][i]
    edge_distri[labels[u], labels[v]] += 1
print(edge_distri)


tensor([[ 190.,  108.,   43.,   64.,   93.,   18.],
        [ 108.,  904.,  238.,   60.,   79.,   31.],
        [  43.,  238., 2082.,  180.,   47.,   67.],
        [  64.,   60.,  180., 1256.,   50.,   37.],
        [  93.,   79.,   47.,   50., 1378.,   86.],
        [  18.,   31.,   67.,   37.,   86.,  892.]])


In [17]:
sum_features = torch.zeros(config['n_class'], config['n_feature'])
for clas in range(config['n_class']):
    for ft in class_features[clas]:
        sum_features[clas] += torch.tensor(ft).float()

In [29]:
top_clas = torch.zeros(config['n_class'], 20)
for i, sum_ft in enumerate(sum_features):
    top_clas[i] = sum_ft.sort(descending=True)[1][:20]

print(top_clas)

tensor([[1774., 2753., 1592., 2537.,  591., 2508., 2532., 1602., 2584., 2116.,
         2345., 2708.,  804., 2059., 2534., 2650., 2703., 1223.,  197.,   63.],
        [1845.,  100., 2532., 2537., 2584.,  591.,  166., 2508.,  717., 2059.,
         1223., 1943., 1257., 2116., 1602., 1774.,  411., 3173., 2213., 2211.],
        [3534.,  865., 3429., 2849.,  591., 2719., 2537., 3217., 2708.,  717.,
          981., 2608., 2586., 2508., 1845., 3560., 2584., 3602., 1602., 2345.],
        [ 717.,  719., 2608.,  165., 2116., 2270., 2537., 2584.,  591., 2508.,
         1959., 1590., 2532., 2700., 1602., 2534., 1819., 3124., 2345., 2685.],
        [  63., 2158., 2537., 1592., 2116.,  165., 1602.,  796.,  591.,  804.,
          852.,  166.,  174.,   17., 2586., 2978., 2508., 2112.,  244., 2584.],
        [3429., 1597., 1599.,  796., 2537., 2584., 2508., 2586.,  165., 2708.,
         1451.,  997.,  793., 1602., 2700.,  804., 3192.,   17., 2116., 1223.]])


In [28]:
def overlap(x, y):
    return len(set(x.tolist()).intersection(set(y.tolist())))

matrix = torch.zeros(config['n_class'], config['n_class'])
for i in range(config['n_class']):
    for j in range(config['n_class']):
        matrix[i, j] = overlap(top_clas[i], top_clas[j])

print(matrix)

tensor([[10.,  5.,  2.,  5.,  5.,  3.],
        [ 5., 10.,  3.,  5.,  2.,  3.],
        [ 2.,  3., 10.,  3.,  2.,  3.],
        [ 5.,  5.,  3., 10.,  4.,  4.],
        [ 5.,  2.,  2.,  4., 10.,  3.],
        [ 3.,  3.,  3.,  4.,  3., 10.]])
