In [1]:
import os.path as osp
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import (v_measure_score, homogeneity_score, completeness_score)
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models.autoencoder import ARGVA
from torch_geometric.utils import train_test_split_edges

In [2]:
dataset = 'Cora'
path = osp.join('.', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset.get(0)

In [3]:
num_nodes = data.x.shape[0]

In [4]:
data.train_mask = data.val_mask = data.test_mask = None
data = train_test_split_edges(data)
data

Data(test_neg_edge_index=[2, 527], test_pos_edge_index=[2, 527], train_neg_adj_mask=[2708, 2708], train_pos_edge_index=[2, 8976], val_neg_edge_index=[2, 263], val_pos_edge_index=[2, 263], x=[2708, 1433], y=[2708])

## Define the model

In [5]:
class VEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
        self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [6]:
class Discriminator(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(Discriminator, self).__init__()
        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin3 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        return x

In [7]:
def train():
    model.train()
    encoder_optimizer.zero_grad()
    
    z = model.encode(data.x, data.train_pos_edge_index)

    for _ in range(5):
        idx = range(num_nodes)  
        discriminator.train()
        discriminator_optimizer.zero_grad()
        discriminator_loss = model.discriminator_loss(z[idx]) # Comment
        discriminator_loss.backward(retain_graph=True)
        discriminator_optimizer.step()
 
    loss = 0
    loss = loss + model.reg_loss(z)  # Comment
    
    loss = loss + model.recon_loss(z, data.train_pos_edge_index)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()

    encoder_optimizer.step()
    
    return loss

In [8]:
@torch.no_grad()
def test():
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)

    # Cluster embedded values using k-means.
    kmeans_input = z.cpu().numpy()
    kmeans = KMeans(n_clusters=7, random_state=0).fit(kmeans_input)
    pred = kmeans.predict(kmeans_input)

    labels = data.y.cpu().numpy()
    completeness = completeness_score(labels, pred)
    hm = homogeneity_score(labels, pred)
    nmi = v_measure_score(labels, pred)

    auc, ap = model.test(z, data.test_pos_edge_index, data.test_neg_edge_index)

    return auc, ap, completeness, hm, nmi

In [9]:
latent_size = 32
encoder = VEncoder(data.num_features, out_channels=latent_size)

discriminator = Discriminator(in_channels=latent_size, hidden_channels=64, 
                              out_channels=1) # Comment

In [10]:
# use_cuda = False
model = ARGVA(encoder, discriminator)

# device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
device = 'cuda'
model, data = model.to(device), data.to(device)

In [11]:
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005)

## wandb

In [12]:
import wandb

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdamon[0m (use `wandb login --relogin` to force relogin)


True

In [None]:
wandb.init(project='ARGVA&ARGA')

for epoch in range(1, 5001):
    loss = train()
    auc, ap, completeness, hm, nmi = test()
    print((f'Epoch: {epoch:03d}, Loss: {loss:.3f}, AUC: {auc:.3f}, '
           f'AP: {ap:.3f}, Completeness: {completeness:.3f}, '
           f'Homogeneity: {hm:.3f}, NMI: {nmi:.3f}'))
    wandb.log({'epoch' : epoch,
              'loss' : loss,
              'auc' : auc,
              'ap' : ap})

Epoch: 001, Loss: 5.577, AUC: 0.755, AP: 0.770, Completeness: 0.102, Homogeneity: 0.105, NMI: 0.104
Epoch: 002, Loss: 4.940, AUC: 0.733, AP: 0.755, Completeness: 0.096, Homogeneity: 0.098, NMI: 0.097
Epoch: 003, Loss: 4.314, AUC: 0.719, AP: 0.743, Completeness: 0.087, Homogeneity: 0.086, NMI: 0.087
Epoch: 004, Loss: 3.907, AUC: 0.715, AP: 0.741, Completeness: 0.061, Homogeneity: 0.059, NMI: 0.060
Epoch: 005, Loss: 3.612, AUC: 0.717, AP: 0.742, Completeness: 0.129, Homogeneity: 0.124, NMI: 0.126
Epoch: 006, Loss: 3.231, AUC: 0.724, AP: 0.746, Completeness: 0.157, Homogeneity: 0.150, NMI: 0.154
Epoch: 007, Loss: 2.816, AUC: 0.730, AP: 0.750, Completeness: 0.210, Homogeneity: 0.210, NMI: 0.210
Epoch: 008, Loss: 2.555, AUC: 0.733, AP: 0.750, Completeness: 0.235, Homogeneity: 0.237, NMI: 0.236
Epoch: 009, Loss: 2.332, AUC: 0.731, AP: 0.748, Completeness: 0.246, Homogeneity: 0.251, NMI: 0.248
Epoch: 010, Loss: 2.208, AUC: 0.740, AP: 0.759, Completeness: 0.270, Homogeneity: 0.275, NMI: 0.272


Epoch: 083, Loss: 4.040, AUC: 0.880, AP: 0.879, Completeness: 0.442, Homogeneity: 0.406, NMI: 0.423
Epoch: 084, Loss: 4.069, AUC: 0.882, AP: 0.879, Completeness: 0.445, Homogeneity: 0.416, NMI: 0.430
Epoch: 085, Loss: 4.056, AUC: 0.884, AP: 0.881, Completeness: 0.461, Homogeneity: 0.455, NMI: 0.458
Epoch: 086, Loss: 4.054, AUC: 0.885, AP: 0.882, Completeness: 0.462, Homogeneity: 0.457, NMI: 0.460
Epoch: 087, Loss: 4.046, AUC: 0.887, AP: 0.884, Completeness: 0.468, Homogeneity: 0.465, NMI: 0.467
Epoch: 088, Loss: 3.880, AUC: 0.888, AP: 0.885, Completeness: 0.473, Homogeneity: 0.471, NMI: 0.472
Epoch: 089, Loss: 3.960, AUC: 0.889, AP: 0.887, Completeness: 0.469, Homogeneity: 0.466, NMI: 0.468
Epoch: 090, Loss: 3.925, AUC: 0.891, AP: 0.888, Completeness: 0.470, Homogeneity: 0.470, NMI: 0.470
Epoch: 091, Loss: 3.924, AUC: 0.892, AP: 0.889, Completeness: 0.473, Homogeneity: 0.474, NMI: 0.474
Epoch: 092, Loss: 3.929, AUC: 0.894, AP: 0.890, Completeness: 0.476, Homogeneity: 0.476, NMI: 0.476


Epoch: 165, Loss: 3.618, AUC: 0.910, AP: 0.910, Completeness: 0.418, Homogeneity: 0.437, NMI: 0.427
Epoch: 166, Loss: 3.555, AUC: 0.911, AP: 0.910, Completeness: 0.417, Homogeneity: 0.438, NMI: 0.427
Epoch: 167, Loss: 3.577, AUC: 0.911, AP: 0.911, Completeness: 0.416, Homogeneity: 0.437, NMI: 0.426
Epoch: 168, Loss: 3.588, AUC: 0.911, AP: 0.911, Completeness: 0.417, Homogeneity: 0.437, NMI: 0.426
Epoch: 169, Loss: 3.521, AUC: 0.911, AP: 0.911, Completeness: 0.417, Homogeneity: 0.437, NMI: 0.427
Epoch: 170, Loss: 3.618, AUC: 0.910, AP: 0.911, Completeness: 0.415, Homogeneity: 0.435, NMI: 0.425
Epoch: 171, Loss: 3.592, AUC: 0.910, AP: 0.910, Completeness: 0.419, Homogeneity: 0.439, NMI: 0.429
Epoch: 172, Loss: 3.527, AUC: 0.909, AP: 0.909, Completeness: 0.418, Homogeneity: 0.438, NMI: 0.428
Epoch: 173, Loss: 3.523, AUC: 0.908, AP: 0.908, Completeness: 0.414, Homogeneity: 0.433, NMI: 0.423
Epoch: 174, Loss: 3.561, AUC: 0.908, AP: 0.908, Completeness: 0.415, Homogeneity: 0.436, NMI: 0.425


Epoch: 247, Loss: 3.312, AUC: 0.897, AP: 0.909, Completeness: 0.354, Homogeneity: 0.369, NMI: 0.362
Epoch: 248, Loss: 3.332, AUC: 0.897, AP: 0.908, Completeness: 0.363, Homogeneity: 0.381, NMI: 0.371
Epoch: 249, Loss: 3.323, AUC: 0.896, AP: 0.908, Completeness: 0.361, Homogeneity: 0.379, NMI: 0.370
Epoch: 250, Loss: 3.355, AUC: 0.895, AP: 0.907, Completeness: 0.364, Homogeneity: 0.385, NMI: 0.374
Epoch: 251, Loss: 3.331, AUC: 0.894, AP: 0.906, Completeness: 0.338, Homogeneity: 0.354, NMI: 0.346
Epoch: 252, Loss: 3.355, AUC: 0.893, AP: 0.905, Completeness: 0.351, Homogeneity: 0.370, NMI: 0.360
Epoch: 253, Loss: 3.296, AUC: 0.894, AP: 0.905, Completeness: 0.365, Homogeneity: 0.381, NMI: 0.373
Epoch: 254, Loss: 3.375, AUC: 0.894, AP: 0.905, Completeness: 0.369, Homogeneity: 0.389, NMI: 0.379
Epoch: 255, Loss: 3.319, AUC: 0.894, AP: 0.906, Completeness: 0.366, Homogeneity: 0.385, NMI: 0.375
Epoch: 256, Loss: 3.410, AUC: 0.894, AP: 0.906, Completeness: 0.362, Homogeneity: 0.377, NMI: 0.370


Epoch: 329, Loss: 3.169, AUC: 0.888, AP: 0.902, Completeness: 0.294, Homogeneity: 0.309, NMI: 0.301
Epoch: 330, Loss: 3.196, AUC: 0.888, AP: 0.902, Completeness: 0.294, Homogeneity: 0.309, NMI: 0.301
Epoch: 331, Loss: 3.143, AUC: 0.888, AP: 0.903, Completeness: 0.292, Homogeneity: 0.308, NMI: 0.300
Epoch: 332, Loss: 3.161, AUC: 0.889, AP: 0.903, Completeness: 0.292, Homogeneity: 0.307, NMI: 0.299
Epoch: 333, Loss: 3.190, AUC: 0.890, AP: 0.904, Completeness: 0.268, Homogeneity: 0.282, NMI: 0.275
Epoch: 334, Loss: 3.086, AUC: 0.890, AP: 0.905, Completeness: 0.306, Homogeneity: 0.322, NMI: 0.314
Epoch: 335, Loss: 3.117, AUC: 0.890, AP: 0.905, Completeness: 0.306, Homogeneity: 0.323, NMI: 0.314
Epoch: 336, Loss: 3.118, AUC: 0.890, AP: 0.905, Completeness: 0.313, Homogeneity: 0.327, NMI: 0.320
Epoch: 337, Loss: 3.147, AUC: 0.891, AP: 0.905, Completeness: 0.299, Homogeneity: 0.315, NMI: 0.307
Epoch: 338, Loss: 3.089, AUC: 0.891, AP: 0.905, Completeness: 0.297, Homogeneity: 0.313, NMI: 0.305


Epoch: 411, Loss: 3.134, AUC: 0.894, AP: 0.909, Completeness: 0.229, Homogeneity: 0.243, NMI: 0.236
Epoch: 412, Loss: 3.183, AUC: 0.893, AP: 0.908, Completeness: 0.223, Homogeneity: 0.237, NMI: 0.230
Epoch: 413, Loss: 3.142, AUC: 0.890, AP: 0.905, Completeness: 0.222, Homogeneity: 0.235, NMI: 0.229
Epoch: 414, Loss: 3.190, AUC: 0.887, AP: 0.902, Completeness: 0.221, Homogeneity: 0.234, NMI: 0.227
Epoch: 415, Loss: 3.220, AUC: 0.883, AP: 0.899, Completeness: 0.219, Homogeneity: 0.231, NMI: 0.225
Epoch: 416, Loss: 3.202, AUC: 0.880, AP: 0.897, Completeness: 0.204, Homogeneity: 0.215, NMI: 0.209
Epoch: 417, Loss: 3.250, AUC: 0.878, AP: 0.896, Completeness: 0.209, Homogeneity: 0.220, NMI: 0.215
Epoch: 418, Loss: 3.203, AUC: 0.876, AP: 0.895, Completeness: 0.184, Homogeneity: 0.191, NMI: 0.187
Epoch: 419, Loss: 3.164, AUC: 0.874, AP: 0.894, Completeness: 0.198, Homogeneity: 0.207, NMI: 0.203
Epoch: 420, Loss: 3.188, AUC: 0.872, AP: 0.893, Completeness: 0.190, Homogeneity: 0.195, NMI: 0.193


Epoch: 493, Loss: 3.231, AUC: 0.871, AP: 0.891, Completeness: 0.278, Homogeneity: 0.288, NMI: 0.283
Epoch: 494, Loss: 3.197, AUC: 0.872, AP: 0.892, Completeness: 0.228, Homogeneity: 0.242, NMI: 0.235
Epoch: 495, Loss: 3.190, AUC: 0.873, AP: 0.894, Completeness: 0.270, Homogeneity: 0.284, NMI: 0.277
Epoch: 496, Loss: 3.090, AUC: 0.874, AP: 0.895, Completeness: 0.259, Homogeneity: 0.273, NMI: 0.266
Epoch: 497, Loss: 3.115, AUC: 0.875, AP: 0.896, Completeness: 0.282, Homogeneity: 0.297, NMI: 0.289
Epoch: 498, Loss: 3.114, AUC: 0.876, AP: 0.896, Completeness: 0.264, Homogeneity: 0.278, NMI: 0.271
Epoch: 499, Loss: 3.062, AUC: 0.877, AP: 0.897, Completeness: 0.285, Homogeneity: 0.301, NMI: 0.293
Epoch: 500, Loss: 3.033, AUC: 0.877, AP: 0.897, Completeness: 0.280, Homogeneity: 0.296, NMI: 0.288
Epoch: 501, Loss: 3.083, AUC: 0.876, AP: 0.897, Completeness: 0.281, Homogeneity: 0.296, NMI: 0.289
Epoch: 502, Loss: 3.053, AUC: 0.875, AP: 0.896, Completeness: 0.273, Homogeneity: 0.288, NMI: 0.280


Epoch: 575, Loss: 3.002, AUC: 0.880, AP: 0.901, Completeness: 0.210, Homogeneity: 0.222, NMI: 0.216
Epoch: 576, Loss: 2.942, AUC: 0.880, AP: 0.900, Completeness: 0.209, Homogeneity: 0.221, NMI: 0.215
Epoch: 577, Loss: 3.005, AUC: 0.881, AP: 0.900, Completeness: 0.230, Homogeneity: 0.242, NMI: 0.236
Epoch: 578, Loss: 3.035, AUC: 0.881, AP: 0.900, Completeness: 0.227, Homogeneity: 0.239, NMI: 0.233
Epoch: 579, Loss: 3.078, AUC: 0.881, AP: 0.899, Completeness: 0.225, Homogeneity: 0.237, NMI: 0.231
Epoch: 580, Loss: 3.047, AUC: 0.881, AP: 0.898, Completeness: 0.241, Homogeneity: 0.246, NMI: 0.244
Epoch: 581, Loss: 3.092, AUC: 0.880, AP: 0.897, Completeness: 0.241, Homogeneity: 0.247, NMI: 0.244
Epoch: 582, Loss: 3.084, AUC: 0.880, AP: 0.896, Completeness: 0.241, Homogeneity: 0.246, NMI: 0.244
Epoch: 583, Loss: 3.161, AUC: 0.879, AP: 0.895, Completeness: 0.238, Homogeneity: 0.242, NMI: 0.240
Epoch: 584, Loss: 3.154, AUC: 0.878, AP: 0.894, Completeness: 0.239, Homogeneity: 0.244, NMI: 0.242


Epoch: 657, Loss: 3.230, AUC: 0.862, AP: 0.883, Completeness: 0.147, Homogeneity: 0.152, NMI: 0.150
Epoch: 658, Loss: 3.360, AUC: 0.861, AP: 0.881, Completeness: 0.148, Homogeneity: 0.152, NMI: 0.150
Epoch: 659, Loss: 3.243, AUC: 0.861, AP: 0.881, Completeness: 0.163, Homogeneity: 0.167, NMI: 0.165
Epoch: 660, Loss: 3.340, AUC: 0.861, AP: 0.881, Completeness: 0.182, Homogeneity: 0.188, NMI: 0.185
Epoch: 661, Loss: 3.288, AUC: 0.862, AP: 0.882, Completeness: 0.185, Homogeneity: 0.191, NMI: 0.188
Epoch: 662, Loss: 3.261, AUC: 0.864, AP: 0.885, Completeness: 0.151, Homogeneity: 0.156, NMI: 0.153
Epoch: 663, Loss: 3.296, AUC: 0.869, AP: 0.889, Completeness: 0.160, Homogeneity: 0.165, NMI: 0.162
Epoch: 664, Loss: 3.242, AUC: 0.872, AP: 0.893, Completeness: 0.173, Homogeneity: 0.180, NMI: 0.177
Epoch: 665, Loss: 3.257, AUC: 0.876, AP: 0.896, Completeness: 0.195, Homogeneity: 0.206, NMI: 0.201
Epoch: 666, Loss: 3.170, AUC: 0.877, AP: 0.898, Completeness: 0.199, Homogeneity: 0.210, NMI: 0.205


Epoch: 739, Loss: 3.207, AUC: 0.871, AP: 0.888, Completeness: 0.197, Homogeneity: 0.209, NMI: 0.203
Epoch: 740, Loss: 3.244, AUC: 0.872, AP: 0.890, Completeness: 0.188, Homogeneity: 0.200, NMI: 0.194
Epoch: 741, Loss: 3.219, AUC: 0.873, AP: 0.891, Completeness: 0.200, Homogeneity: 0.212, NMI: 0.206
Epoch: 742, Loss: 3.218, AUC: 0.875, AP: 0.894, Completeness: 0.200, Homogeneity: 0.212, NMI: 0.206
Epoch: 743, Loss: 3.256, AUC: 0.875, AP: 0.896, Completeness: 0.195, Homogeneity: 0.206, NMI: 0.200
Epoch: 744, Loss: 3.152, AUC: 0.876, AP: 0.897, Completeness: 0.230, Homogeneity: 0.241, NMI: 0.236
Epoch: 745, Loss: 3.226, AUC: 0.876, AP: 0.898, Completeness: 0.204, Homogeneity: 0.216, NMI: 0.210
Epoch: 746, Loss: 3.222, AUC: 0.875, AP: 0.898, Completeness: 0.184, Homogeneity: 0.194, NMI: 0.189
Epoch: 747, Loss: 3.131, AUC: 0.873, AP: 0.897, Completeness: 0.192, Homogeneity: 0.203, NMI: 0.197
Epoch: 748, Loss: 3.122, AUC: 0.872, AP: 0.895, Completeness: 0.176, Homogeneity: 0.187, NMI: 0.181


Epoch: 821, Loss: 3.132, AUC: 0.860, AP: 0.886, Completeness: 0.165, Homogeneity: 0.175, NMI: 0.170
Epoch: 822, Loss: 3.096, AUC: 0.860, AP: 0.886, Completeness: 0.167, Homogeneity: 0.176, NMI: 0.171
Epoch: 823, Loss: 3.166, AUC: 0.859, AP: 0.885, Completeness: 0.149, Homogeneity: 0.157, NMI: 0.153
Epoch: 824, Loss: 3.066, AUC: 0.857, AP: 0.884, Completeness: 0.171, Homogeneity: 0.182, NMI: 0.176
Epoch: 825, Loss: 3.068, AUC: 0.856, AP: 0.884, Completeness: 0.174, Homogeneity: 0.183, NMI: 0.178
Epoch: 826, Loss: 2.993, AUC: 0.855, AP: 0.883, Completeness: 0.179, Homogeneity: 0.188, NMI: 0.183
Epoch: 827, Loss: 2.979, AUC: 0.854, AP: 0.882, Completeness: 0.154, Homogeneity: 0.162, NMI: 0.158
Epoch: 828, Loss: 2.909, AUC: 0.853, AP: 0.882, Completeness: 0.173, Homogeneity: 0.182, NMI: 0.177
Epoch: 829, Loss: 2.902, AUC: 0.853, AP: 0.881, Completeness: 0.157, Homogeneity: 0.164, NMI: 0.161
Epoch: 830, Loss: 2.872, AUC: 0.853, AP: 0.880, Completeness: 0.179, Homogeneity: 0.189, NMI: 0.184


Epoch: 903, Loss: 2.975, AUC: 0.854, AP: 0.876, Completeness: 0.186, Homogeneity: 0.187, NMI: 0.186
Epoch: 904, Loss: 2.903, AUC: 0.853, AP: 0.874, Completeness: 0.168, Homogeneity: 0.176, NMI: 0.172
Epoch: 905, Loss: 2.852, AUC: 0.851, AP: 0.872, Completeness: 0.147, Homogeneity: 0.154, NMI: 0.151
Epoch: 906, Loss: 2.888, AUC: 0.850, AP: 0.871, Completeness: 0.141, Homogeneity: 0.148, NMI: 0.145
Epoch: 907, Loss: 2.806, AUC: 0.849, AP: 0.869, Completeness: 0.172, Homogeneity: 0.181, NMI: 0.176
Epoch: 908, Loss: 2.795, AUC: 0.847, AP: 0.868, Completeness: 0.187, Homogeneity: 0.195, NMI: 0.190
Epoch: 909, Loss: 2.842, AUC: 0.847, AP: 0.868, Completeness: 0.182, Homogeneity: 0.191, NMI: 0.186
Epoch: 910, Loss: 2.850, AUC: 0.846, AP: 0.868, Completeness: 0.187, Homogeneity: 0.196, NMI: 0.191
Epoch: 911, Loss: 2.885, AUC: 0.846, AP: 0.869, Completeness: 0.178, Homogeneity: 0.186, NMI: 0.182
Epoch: 912, Loss: 2.918, AUC: 0.848, AP: 0.869, Completeness: 0.179, Homogeneity: 0.186, NMI: 0.182


Epoch: 985, Loss: 2.918, AUC: 0.854, AP: 0.877, Completeness: 0.108, Homogeneity: 0.114, NMI: 0.111
Epoch: 986, Loss: 2.878, AUC: 0.853, AP: 0.875, Completeness: 0.116, Homogeneity: 0.122, NMI: 0.119
Epoch: 987, Loss: 2.907, AUC: 0.852, AP: 0.874, Completeness: 0.108, Homogeneity: 0.113, NMI: 0.110
Epoch: 988, Loss: 2.914, AUC: 0.851, AP: 0.875, Completeness: 0.102, Homogeneity: 0.108, NMI: 0.105
Epoch: 989, Loss: 2.928, AUC: 0.851, AP: 0.875, Completeness: 0.113, Homogeneity: 0.119, NMI: 0.116
Epoch: 990, Loss: 2.956, AUC: 0.850, AP: 0.875, Completeness: 0.081, Homogeneity: 0.086, NMI: 0.083
Epoch: 991, Loss: 3.006, AUC: 0.850, AP: 0.874, Completeness: 0.109, Homogeneity: 0.114, NMI: 0.112
Epoch: 992, Loss: 3.022, AUC: 0.850, AP: 0.873, Completeness: 0.125, Homogeneity: 0.132, NMI: 0.128
Epoch: 993, Loss: 3.080, AUC: 0.850, AP: 0.872, Completeness: 0.117, Homogeneity: 0.124, NMI: 0.121
Epoch: 994, Loss: 3.046, AUC: 0.851, AP: 0.873, Completeness: 0.130, Homogeneity: 0.137, NMI: 0.133


Epoch: 1067, Loss: 3.069, AUC: 0.860, AP: 0.886, Completeness: 0.143, Homogeneity: 0.150, NMI: 0.147
Epoch: 1068, Loss: 2.971, AUC: 0.860, AP: 0.886, Completeness: 0.127, Homogeneity: 0.134, NMI: 0.131
Epoch: 1069, Loss: 3.087, AUC: 0.859, AP: 0.885, Completeness: 0.132, Homogeneity: 0.139, NMI: 0.135
Epoch: 1070, Loss: 3.042, AUC: 0.858, AP: 0.885, Completeness: 0.146, Homogeneity: 0.154, NMI: 0.150
Epoch: 1071, Loss: 3.075, AUC: 0.857, AP: 0.884, Completeness: 0.146, Homogeneity: 0.153, NMI: 0.149
Epoch: 1072, Loss: 3.078, AUC: 0.856, AP: 0.883, Completeness: 0.142, Homogeneity: 0.149, NMI: 0.145
Epoch: 1073, Loss: 3.092, AUC: 0.856, AP: 0.883, Completeness: 0.147, Homogeneity: 0.155, NMI: 0.151
Epoch: 1074, Loss: 3.142, AUC: 0.857, AP: 0.884, Completeness: 0.148, Homogeneity: 0.156, NMI: 0.152
Epoch: 1075, Loss: 3.105, AUC: 0.858, AP: 0.884, Completeness: 0.157, Homogeneity: 0.166, NMI: 0.161
Epoch: 1076, Loss: 3.143, AUC: 0.860, AP: 0.885, Completeness: 0.131, Homogeneity: 0.136, N

Epoch: 1149, Loss: 3.103, AUC: 0.861, AP: 0.888, Completeness: 0.101, Homogeneity: 0.106, NMI: 0.104
Epoch: 1150, Loss: 3.112, AUC: 0.861, AP: 0.887, Completeness: 0.120, Homogeneity: 0.126, NMI: 0.123
Epoch: 1151, Loss: 3.103, AUC: 0.861, AP: 0.887, Completeness: 0.108, Homogeneity: 0.114, NMI: 0.111
Epoch: 1152, Loss: 3.128, AUC: 0.861, AP: 0.887, Completeness: 0.114, Homogeneity: 0.120, NMI: 0.117
Epoch: 1153, Loss: 3.149, AUC: 0.861, AP: 0.887, Completeness: 0.104, Homogeneity: 0.109, NMI: 0.106
Epoch: 1154, Loss: 3.129, AUC: 0.861, AP: 0.887, Completeness: 0.102, Homogeneity: 0.107, NMI: 0.104
Epoch: 1155, Loss: 3.094, AUC: 0.860, AP: 0.887, Completeness: 0.119, Homogeneity: 0.126, NMI: 0.122
Epoch: 1156, Loss: 3.106, AUC: 0.860, AP: 0.887, Completeness: 0.096, Homogeneity: 0.101, NMI: 0.098
Epoch: 1157, Loss: 3.097, AUC: 0.861, AP: 0.888, Completeness: 0.124, Homogeneity: 0.131, NMI: 0.127
Epoch: 1158, Loss: 3.089, AUC: 0.861, AP: 0.889, Completeness: 0.107, Homogeneity: 0.113, N

Epoch: 1231, Loss: 3.028, AUC: 0.866, AP: 0.890, Completeness: 0.097, Homogeneity: 0.102, NMI: 0.099
Epoch: 1232, Loss: 3.002, AUC: 0.866, AP: 0.891, Completeness: 0.110, Homogeneity: 0.117, NMI: 0.113
Epoch: 1233, Loss: 2.982, AUC: 0.867, AP: 0.891, Completeness: 0.118, Homogeneity: 0.125, NMI: 0.122
Epoch: 1234, Loss: 2.961, AUC: 0.867, AP: 0.892, Completeness: 0.114, Homogeneity: 0.120, NMI: 0.117
Epoch: 1235, Loss: 2.959, AUC: 0.867, AP: 0.893, Completeness: 0.090, Homogeneity: 0.095, NMI: 0.092
Epoch: 1236, Loss: 2.978, AUC: 0.868, AP: 0.893, Completeness: 0.124, Homogeneity: 0.131, NMI: 0.127
Epoch: 1237, Loss: 2.937, AUC: 0.868, AP: 0.894, Completeness: 0.132, Homogeneity: 0.140, NMI: 0.136
Epoch: 1238, Loss: 2.949, AUC: 0.868, AP: 0.894, Completeness: 0.145, Homogeneity: 0.154, NMI: 0.149
Epoch: 1239, Loss: 2.986, AUC: 0.868, AP: 0.894, Completeness: 0.139, Homogeneity: 0.147, NMI: 0.143
Epoch: 1240, Loss: 2.970, AUC: 0.868, AP: 0.894, Completeness: 0.127, Homogeneity: 0.134, N

Epoch: 1313, Loss: 3.007, AUC: 0.862, AP: 0.888, Completeness: 0.140, Homogeneity: 0.148, NMI: 0.144
Epoch: 1314, Loss: 3.096, AUC: 0.863, AP: 0.889, Completeness: 0.149, Homogeneity: 0.158, NMI: 0.154
Epoch: 1315, Loss: 3.093, AUC: 0.864, AP: 0.890, Completeness: 0.151, Homogeneity: 0.159, NMI: 0.155
Epoch: 1316, Loss: 3.134, AUC: 0.865, AP: 0.891, Completeness: 0.157, Homogeneity: 0.165, NMI: 0.161
Epoch: 1317, Loss: 3.106, AUC: 0.864, AP: 0.891, Completeness: 0.160, Homogeneity: 0.170, NMI: 0.165
Epoch: 1318, Loss: 3.149, AUC: 0.863, AP: 0.890, Completeness: 0.167, Homogeneity: 0.176, NMI: 0.171
Epoch: 1319, Loss: 3.174, AUC: 0.863, AP: 0.890, Completeness: 0.131, Homogeneity: 0.139, NMI: 0.135
Epoch: 1320, Loss: 3.127, AUC: 0.862, AP: 0.889, Completeness: 0.134, Homogeneity: 0.142, NMI: 0.138
Epoch: 1321, Loss: 3.086, AUC: 0.862, AP: 0.889, Completeness: 0.152, Homogeneity: 0.161, NMI: 0.156
Epoch: 1322, Loss: 3.072, AUC: 0.861, AP: 0.888, Completeness: 0.128, Homogeneity: 0.136, N

Epoch: 1395, Loss: 2.952, AUC: 0.861, AP: 0.889, Completeness: 0.128, Homogeneity: 0.135, NMI: 0.132
Epoch: 1396, Loss: 3.012, AUC: 0.860, AP: 0.888, Completeness: 0.095, Homogeneity: 0.101, NMI: 0.098
Epoch: 1397, Loss: 2.974, AUC: 0.859, AP: 0.887, Completeness: 0.110, Homogeneity: 0.117, NMI: 0.113
Epoch: 1398, Loss: 3.007, AUC: 0.858, AP: 0.886, Completeness: 0.119, Homogeneity: 0.124, NMI: 0.121
Epoch: 1399, Loss: 2.907, AUC: 0.858, AP: 0.886, Completeness: 0.108, Homogeneity: 0.114, NMI: 0.111
Epoch: 1400, Loss: 2.963, AUC: 0.858, AP: 0.886, Completeness: 0.118, Homogeneity: 0.124, NMI: 0.121
Epoch: 1401, Loss: 2.954, AUC: 0.859, AP: 0.886, Completeness: 0.123, Homogeneity: 0.129, NMI: 0.126
Epoch: 1402, Loss: 2.965, AUC: 0.859, AP: 0.886, Completeness: 0.115, Homogeneity: 0.122, NMI: 0.118
Epoch: 1403, Loss: 3.008, AUC: 0.859, AP: 0.887, Completeness: 0.110, Homogeneity: 0.115, NMI: 0.113
Epoch: 1404, Loss: 3.056, AUC: 0.860, AP: 0.888, Completeness: 0.138, Homogeneity: 0.145, N

Epoch: 1477, Loss: 3.090, AUC: 0.850, AP: 0.882, Completeness: 0.102, Homogeneity: 0.108, NMI: 0.105
Epoch: 1478, Loss: 3.059, AUC: 0.850, AP: 0.883, Completeness: 0.085, Homogeneity: 0.089, NMI: 0.087
Epoch: 1479, Loss: 3.065, AUC: 0.851, AP: 0.883, Completeness: 0.083, Homogeneity: 0.088, NMI: 0.085
Epoch: 1480, Loss: 3.114, AUC: 0.852, AP: 0.884, Completeness: 0.081, Homogeneity: 0.086, NMI: 0.084
Epoch: 1481, Loss: 3.052, AUC: 0.852, AP: 0.884, Completeness: 0.099, Homogeneity: 0.105, NMI: 0.102
Epoch: 1482, Loss: 3.037, AUC: 0.853, AP: 0.884, Completeness: 0.086, Homogeneity: 0.092, NMI: 0.089
Epoch: 1483, Loss: 3.049, AUC: 0.853, AP: 0.884, Completeness: 0.090, Homogeneity: 0.095, NMI: 0.092
Epoch: 1484, Loss: 3.059, AUC: 0.853, AP: 0.884, Completeness: 0.106, Homogeneity: 0.112, NMI: 0.109
Epoch: 1485, Loss: 3.028, AUC: 0.852, AP: 0.883, Completeness: 0.062, Homogeneity: 0.066, NMI: 0.064
Epoch: 1486, Loss: 2.980, AUC: 0.851, AP: 0.882, Completeness: 0.079, Homogeneity: 0.084, N

Epoch: 1559, Loss: 2.913, AUC: 0.852, AP: 0.883, Completeness: 0.116, Homogeneity: 0.123, NMI: 0.119
Epoch: 1560, Loss: 2.940, AUC: 0.851, AP: 0.882, Completeness: 0.099, Homogeneity: 0.105, NMI: 0.102
Epoch: 1561, Loss: 2.956, AUC: 0.851, AP: 0.882, Completeness: 0.099, Homogeneity: 0.105, NMI: 0.102
Epoch: 1562, Loss: 2.969, AUC: 0.851, AP: 0.882, Completeness: 0.108, Homogeneity: 0.113, NMI: 0.111
Epoch: 1563, Loss: 2.972, AUC: 0.851, AP: 0.882, Completeness: 0.086, Homogeneity: 0.091, NMI: 0.089
Epoch: 1564, Loss: 2.987, AUC: 0.850, AP: 0.882, Completeness: 0.109, Homogeneity: 0.115, NMI: 0.112
Epoch: 1565, Loss: 2.989, AUC: 0.851, AP: 0.882, Completeness: 0.092, Homogeneity: 0.098, NMI: 0.095
Epoch: 1566, Loss: 3.036, AUC: 0.851, AP: 0.882, Completeness: 0.108, Homogeneity: 0.115, NMI: 0.111
Epoch: 1567, Loss: 3.079, AUC: 0.850, AP: 0.881, Completeness: 0.095, Homogeneity: 0.101, NMI: 0.098
Epoch: 1568, Loss: 3.088, AUC: 0.849, AP: 0.880, Completeness: 0.105, Homogeneity: 0.111, N

Epoch: 1641, Loss: 3.074, AUC: 0.854, AP: 0.883, Completeness: 0.126, Homogeneity: 0.133, NMI: 0.129
Epoch: 1642, Loss: 2.956, AUC: 0.854, AP: 0.884, Completeness: 0.125, Homogeneity: 0.132, NMI: 0.128
Epoch: 1643, Loss: 2.998, AUC: 0.855, AP: 0.884, Completeness: 0.124, Homogeneity: 0.132, NMI: 0.128
Epoch: 1644, Loss: 2.986, AUC: 0.855, AP: 0.884, Completeness: 0.115, Homogeneity: 0.122, NMI: 0.119
Epoch: 1645, Loss: 2.969, AUC: 0.854, AP: 0.883, Completeness: 0.127, Homogeneity: 0.134, NMI: 0.130
Epoch: 1646, Loss: 3.017, AUC: 0.853, AP: 0.882, Completeness: 0.093, Homogeneity: 0.099, NMI: 0.096
Epoch: 1647, Loss: 2.927, AUC: 0.851, AP: 0.880, Completeness: 0.144, Homogeneity: 0.152, NMI: 0.148
Epoch: 1648, Loss: 2.888, AUC: 0.848, AP: 0.877, Completeness: 0.105, Homogeneity: 0.111, NMI: 0.108
Epoch: 1649, Loss: 2.955, AUC: 0.846, AP: 0.875, Completeness: 0.120, Homogeneity: 0.127, NMI: 0.123
Epoch: 1650, Loss: 2.934, AUC: 0.845, AP: 0.874, Completeness: 0.090, Homogeneity: 0.095, N

In [None]:
@torch.no_grad()
def plot_points(colors):
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    z = TSNE(n_components=2).fit_transform(z.cpu().numpy())
    y = data.y.cpu().numpy()

    fig = plt.figure(1, figsize=(8, 8))
    fig.clf()
    for i in range(dataset.num_classes):
        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])
    plt.axis('off')
    plt.show()

In [None]:

#%%
colors = [
    '#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700'
]
plot_points(colors)