In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics
from torch_geometric.datasets import Planetoid

In [2]:
# ---- Experiment configuration -------------------------------------------
# Dataset handed to load_network_data(); that function recognises
# 'Cora', 'wiki', 'power' and 'dublin'.
g_data_name = 'Cora'

# Toy switch: when True the embedding width below is halved, and the later
# cells take their `if g_toy:` branches.
g_toy = False

# Sampling / optimisation settings.
n_samples = 8        # reference nodes drawn per endpoint inside MAD.forward
total_epoch = 200    # training epochs per run
# NOTE(review): `lr` is consumed by the toy-mode MLP optimizer only; the MAD
# optimizer in the training cell hard-codes 0.01 — confirm which is intended.
lr = 0.005

# Width of the learned node embedding (the `node_feats` argument of MAD).
g_dim = 32
if g_toy:
    g_dim //= 2

In [3]:
def gpu(x):
    """Return ``x`` unchanged.

    Pass-through hook wrapped around tensors and modules throughout the
    notebook; in this version it performs no device transfer.
    """
    return x

def cpu(x):
    """Return ``x`` unchanged.

    Counterpart of ``gpu``: a pass-through hook applied before values are
    handed to the metric functions; in this version it does nothing.
    """
    return x

def ip(x, y):
    """Inner product of ``x`` and ``y`` along their last dimension.

    ``x`` is viewed as a stack of row vectors and ``y`` as a stack of column
    vectors; the leading (batch) dimensions follow matmul broadcasting. The
    two singleton dimensions produced by the matrix product are removed, so
    the result has the broadcast batch shape.
    """
    rows = x[..., None, :]   # (..., 1, d)
    cols = y[..., :, None]   # (..., d, 1)
    return torch.matmul(rows, cols)[..., 0, 0]

In [4]:
class MAD(nn.Module):
    """Link scorer that compares a node pair against sampled reference nodes.

    For a query pair ``(src, dst)`` the forward pass draws ``n_samples``
    reference nodes for each endpoint, forms one logit per reference from the
    difference of learned embeddings plus an optional lookup in the edge
    memory ``mem``, and averages the logits with softmax weights that favour
    references whose embedding lies close to the endpoint they replace.

    Args:
        in_feats: width of the rows of ``feats``.
        n_nodes: number of nodes; reference indices are drawn from
            ``[0, n_nodes)``.
        node_feats: width of the learned embeddings.
        n_samples: number of reference nodes drawn per endpoint.
        mem: matrix indexed as ``mem[i, j]`` holding known edges, or ``None``
            to disable the memory term.
        feats: node feature matrix, indexed by node id along its first
            dimension.
        gather2neighbor: when True a single linear map ``f`` is used for both
            roles; otherwise a second map ``g`` embeds the opposite endpoint.
    """

    def __init__(
            self, in_feats, n_nodes, node_feats,
            n_samples, mem, feats, gather2neighbor=False,
    ):
        # Zero-argument super(): the previous super(self.__class__, self)
        # recursed without end as soon as MAD was subclassed.
        super().__init__()
        self.n_nodes = n_nodes
        self.node_feats = node_feats
        self.n_samples = n_samples
        self.mem = mem
        self.feats = feats
        self.gather2neighbor = gather2neighbor
        self.f = gpu(nn.Linear(in_feats, node_feats))
        self.g = (
            None if gather2neighbor else gpu(nn.Linear(in_feats, node_feats)))
        # Scalar affine map applied to every memory lookup in recall().
        self.adapt = gpu(nn.Linear(1, 1))
        # Lazily built nearest-neighbour table, filled on first nns() call.
        self.nn = None

    def nns(self, src, dst):
        """Return the cached nearest-neighbour ids of ``src`` and ``dst``.

        On the first call a ``(n_nodes, n_samples)`` table is built from the
        Euclidean distances between rows of ``self.feats``. Not called by
        ``forward`` in its current form (the call there is commented out).
        """
        if self.nn is None:
            n = self.n_samples
            self.nn = gpu(torch.empty((self.n_nodes, n), dtype=torch.long))
            # Chunks of 64 nodes keep the pairwise-difference tensor small.
            for perm in DataLoader(
                    range(self.n_nodes), 64, shuffle=False):
                # Take the 1 + n closest nodes and drop the first one, which
                # is presumably the node itself at distance 0 — this relies
                # on feature rows being distinct.
                self.nn[perm] = (
                    self.feats[perm].unsqueeze(1) - self.feats.unsqueeze(0)
                ).norm(dim=-1).topk(1 + n, largest=False).indices[..., 1:]
        return self.nn[src], self.nn[dst]

    def recall(self, src, dst):
        """Memory term for the pairs ``(src, dst)``.

        Returns the integer 0 when no memory is configured; otherwise looks
        up ``mem[src, dst]``, converts it to float and passes every entry
        through the scalar affine layer ``adapt``.
        """
        if self.mem is None:
            return 0
        return self.adapt(
            (0.0 + self.mem[src, dst]).unsqueeze(-1)).squeeze(-1)

    def forward(self, src, dst):
        """Score each pair ``(src[i], dst[i])`` with a value in (0, 1).

        ``src`` and ``dst`` are assumed to be 1-D index tensors of equal
        length ``n``; the result then has shape ``(n,)``. Reference nodes are
        drawn at random on every call, in eval mode as well, so repeated
        calls give different scores unless the RNG is reseeded.
        """
        n = src.shape[0]
        feats = self.feats
        g = self.f if self.gather2neighbor else self.g
        # Random reference nodes, wrapped in gpu() like every other tensor
        # this class creates so they follow the same device hook.
        mid0 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))
        mid1 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))
        # Alternative: nearest neighbours instead of random references.
        # mid0, mid1 = self.nns(src, dst)

        # Replace src by a reference mid0: memory of (mid0, dst) corrected by
        # the embedding offset src - mid0 projected on dst's embedding.
        srcdiff = self.f(feats[src]).unsqueeze(1) - self.f(feats[mid0])
        logits1 = (
            ip(srcdiff, g(feats[dst]).unsqueeze(1))
            + self.recall(mid0, dst.unsqueeze(1))
        )
        # Same construction with dst replaced by a reference mid1.
        dstdiff = self.f(feats[dst]).unsqueeze(1) - self.f(feats[mid1])
        logits2 = (
            ip(dstdiff, g(feats[src]).unsqueeze(1))
            + self.recall(src.unsqueeze(1), mid1)
        )
        logits = torch.cat((logits1, logits2), dim=1)
        dist = torch.cat((srcdiff, dstdiff), dim=1).norm(dim=2)
        # Append n_samples dummy references with logit 0 and distance 1.
        logits = torch.cat((
            logits, gpu(torch.zeros(n, self.n_samples))), dim=1)
        dist = torch.cat((
            dist, gpu(torch.ones(n, self.n_samples))), dim=1)
        # Softmax over negative distance: closer references weigh more.
        return torch.sigmoid(ip(logits, torch.softmax(-dist, dim=1)))

In [5]:
import pandas as pd
import numpy as np
def load_network_data(adj_name):
    """Load a graph as a dense, symmetric boolean adjacency matrix.

    Args:
        adj_name: one of ``'Cora'``, ``'wiki'``, ``'power'`` or ``'dublin'``.
            Cora is fetched through ``Planetoid``; the others are read from
            edge-list files under ``datasets/``.

    Returns:
        A tuple ``(adjacency, features, n_nodes)``: a ``(n_nodes, n_nodes)``
        bool tensor with self loops removed and both directions of every edge
        set, an identity matrix used as node features, and the node count.

    Raises:
        ValueError: if ``adj_name`` is not a known dataset. The previous
            version only printed a message and then failed with an
            ``UnboundLocalError`` on ``raw_edges``.
    """
    if adj_name == 'Cora':
        nodes_numbers = 2708
        datasets = Planetoid('./datasets', adj_name)
        edges = datasets[0].edge_index
        # edge_index holds sources in row 0 and targets in row 1; transposing
        # gives one edge per row without a per-element Python loop.
        raw_edges = pd.DataFrame(edges.t().cpu().numpy())
    elif adj_name == 'wiki':
        nodes_numbers = 2405
        raw_edges = pd.read_csv('datasets/graph.txt', header=None, sep='\t')
    elif adj_name == 'power':
        nodes_numbers = 1176
        # The "- 1" presumably converts 1-based node ids to 0-based.
        raw_edges = pd.read_csv('datasets/power-eris1176.mtx', header=None, sep=' ') - 1
    elif adj_name == 'dublin':
        nodes_numbers = 410
        raw_edges = pd.read_csv('datasets/ia-infect-dublin.mtx', header=None, sep=' ') - 1
    else:
        raise ValueError(
            "Unknown dataset %r; expected 'Cora', 'wiki', 'power' or 'dublin'."
            % (adj_name,))

    drop_self_loop = raw_edges[raw_edges[0] != raw_edges[1]]

    graph_np = gpu(torch.zeros((nodes_numbers, nodes_numbers), dtype=bool))

    # Only the first two columns are node ids; set both directions at once
    # instead of looping over rows with .iloc.
    endpoints = torch.as_tensor(
        drop_self_loop.iloc[:, :2].to_numpy(), dtype=torch.long)
    rows, cols = endpoints[:, 0], endpoints[:, 1]
    graph_np[rows, cols] = True
    graph_np[cols, rows] = True

    # No input attributes are used: every node gets a one-hot feature row.
    features = torch.eye(nodes_numbers)

    return graph_np, features, nodes_numbers

In [6]:
adj, node_features, n_nodes = load_network_data(g_data_name)
n_features = node_features.shape[1]

# Coordinates of every set entry of the adjacency matrix, split into the row
# index (src) and the column index (dst).
src, dst = adj.nonzero().unbind(dim=1)

# load_network_data sets both (a, b) and (b, a) for each edge; keeping only
# the entries with src > dst leaves every undirected edge exactly once.
flt = src > dst
src, dst = src[flt], dst[flt]

In [7]:
if g_toy:
    # Toy mode: train on every edge, no edge memory, plus a linear probe on
    # top of the learned embedding.
    mem = None
    train_src = gpu(src)
    train_dst = gpu(dst)
    # NOTE(review): `n_labels` is not defined in any cell visible in this
    # notebook, so this branch raises NameError on a fresh kernel — confirm
    # where the node labels are supposed to come from.
    mlp = gpu(nn.Linear(g_dim, n_labels))
    params = list(mlp.parameters())
    print('mlp params:', sum(p.numel() for p in params))
    mlp_opt = optim.Adam(params, lr=lr)
else:
    n_edges = src.shape[0]
    # Dedicated, seeded generator: the split used to come from the unseeded
    # global RNG and therefore changed on every kernel restart. A private
    # generator also leaves the global RNG state untouched.
    split_rng = torch.Generator().manual_seed(0)
    edge_perm = torch.randperm(n_edges, generator=split_rng)
    val_num = int(0.05 * n_edges)     # 5% of the edges for validation
    test_num = int(0.1 * n_edges)     # 10% of the edges for testing
    held_out = val_num + test_num

    val_idx = edge_perm[:val_num]
    test_idx = edge_perm[val_num:held_out]
    train_idx = edge_perm[held_out:]

    train_src = gpu(src[train_idx])
    train_dst = gpu(dst[train_idx])
    val_src = gpu(src[val_idx])
    val_dst = gpu(dst[val_idx])
    test_src = gpu(src[test_idx])
    test_dst = gpu(dst[test_idx])

    # Each split so far holds one direction per edge; append the reversed
    # pairs so both (a, b) and (b, a) are present.
    train_src, train_dst = (
        torch.cat((train_src, train_dst)),
        torch.cat((train_dst, train_src)))
    val_src, val_dst = (
        torch.cat((val_src, val_dst)),
        torch.cat((val_dst, val_src)))
    test_src, test_dst = (
        torch.cat((test_src, test_dst)),
        torch.cat((test_dst, test_src)))

    # Edge memory given to MAD: training edges only, so validation and test
    # edges are never visible to the model.
    mem = gpu(torch.zeros((n_nodes, n_nodes), dtype=bool))
    mem[train_src, train_dst] = 1

In [8]:
def _train_one_epoch(mad, opt, known_edges):
    """Run one pass over the training edges with random negative pairs.

    For every batch of positive edges an equally sized batch of random node
    pairs is drawn; pairs that are set in ``known_edges`` are dropped before
    they are scored as negatives. The loss is the sum of the mean negative
    log-likelihoods of positives and negatives.
    """
    mad.train()
    for batch in DataLoader(
            range(train_src.shape[0]), batch_size=1024, shuffle=True):
        opt.zero_grad()
        p_pos = mad(train_src[batch], train_dst[batch])
        neg_src = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
        neg_dst = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
        keep = ~(known_edges[neg_src, neg_dst])
        p_neg = mad(neg_src[keep], neg_dst[keep])
        # 1e-5 keeps the logarithms finite when a probability saturates.
        loss = (
            -torch.log(1e-5 + 1 - p_neg).mean()
            - torch.log(1e-5 + p_pos).mean()
        )
        loss.backward()
        opt.step()


def _link_prediction_scores(mad, pos_src, pos_dst):
    """Return ``(auc, ap)`` of ``mad`` on positive edges vs. sampled non-edges.

    Negative candidates keep one true endpoint and replace the other with a
    random node; candidates that are edges in ``adj`` are discarded and at
    most as many negatives as positives are kept. Call under
    ``torch.no_grad()``.
    """
    p_pos = mad(pos_src, pos_dst)

    n_pos = pos_src.shape[0]
    shuffle = torch.randperm(n_pos * 2)
    neg_src = torch.cat((
        pos_src, gpu(torch.randint(0, n_nodes, (n_pos, )))
    ))[shuffle]
    neg_dst = torch.cat((
        gpu(torch.randint(0, n_nodes, (n_pos, ))), pos_dst
    ))[shuffle]
    keep = ~(adj[neg_src, neg_dst])
    neg_src = neg_src[keep][:n_pos]
    neg_dst = neg_dst[keep][:n_pos]
    p_neg = mad(neg_src, neg_dst)

    y_true = cpu(torch.cat((p_pos * 0 + 1, p_neg * 0)))
    y_score = cpu(torch.cat((p_pos, p_neg)))
    fpr, tpr, _thresholds = metrics.roc_curve(y_true, y_score, pos_label=1)
    return (
        metrics.auc(fpr, tpr),
        metrics.average_precision_score(y_true, y_score),
    )


total_aucs = []
total_aps = []
# Matrix used to reject sampled negatives that are real edges. Toy mode sets
# mem = None, which used to fail on `mem[neg_src, neg_dst]`; there all edges
# are training edges, so the full adjacency matrix is used instead.
known_edges = adj if mem is None else mem
for run in range(10):
    torch.manual_seed(run)
    mad = MAD(
        in_feats=n_features,
        n_nodes=n_nodes,
        node_feats=g_dim,
        n_samples=n_samples,
        mem=mem,
        feats=node_features,
        gather2neighbor=g_toy,
    )
    params = list(mad.parameters())
    print('params:', sum(p.numel() for p in params))
    # NOTE(review): the config cell defines lr = 0.005 but this optimizer
    # hard-codes 0.01 — confirm which one is intended before changing it.
    opt = optim.Adam(params, lr=0.01)
    # Each "best" pair is [validation, test], taken at the epoch with the
    # highest validation score.
    best_aucs = [0, 0]
    best_aps = [0, 0]
    best_accs = [0, 0]
    for epoch in range(1, total_epoch + 1):
        _train_one_epoch(mad, opt, known_edges)

        # Evaluate only every 10th epoch.
        if epoch % 10:
            continue

        if g_toy:
            # NOTE(review): `train_mask`, `valid_mask`, `test_mask` and
            # `node_labels` are not defined in any cell visible in this
            # notebook — this branch cannot run on a fresh kernel as is.
            with torch.no_grad():
                embed = mad.f(node_features)
            for step in range(100):
                mlp.train()
                mlp_opt.zero_grad()
                logits = mlp(embed)
                loss = F.cross_entropy(
                    logits[train_mask], node_labels[train_mask])
                loss.backward()
                mlp_opt.step()
            with torch.no_grad():
                logits = mlp(embed)
                _, indices = torch.max(logits[valid_mask], dim=1)
                labels = node_labels[valid_mask]
                v_acc = torch.sum(indices == labels).item() * 1.0 / len(labels)
                _, indices = torch.max(logits[test_mask], dim=1)
                labels = node_labels[test_mask]
                t_acc = torch.sum(indices == labels).item() * 1.0 / len(labels)
            if v_acc > best_accs[0]:
                best_accs = [v_acc, t_acc]
                print(epoch, 'acc:', v_acc, t_acc)
            continue

        with torch.no_grad():
            mad.eval()
            # The helper keeps its working variables local; the inline loop
            # used to rebind the notebook-level `src`, `dst`, `n` and `perm`,
            # which corrupted the split cell on a later re-run.
            scores = [
                _link_prediction_scores(mad, pos_src, pos_dst)
                for pos_src, pos_dst in (
                    (val_src, val_dst), (test_src, test_dst))
            ]
        aucs = [auc for auc, _ap in scores]
        aps = [ap for _auc, ap in scores]
        if aucs[0] > best_aucs[0]:
            best_aucs = aucs
            print(epoch, 'auc:', aucs)
        if aps[0] > best_aps[0]:
            best_aps = aps
            print(epoch, 'ap:', aps)
    print(run, 'best auc:', best_aucs)
    print(run, 'best ap:', best_aps)
    print(run, 'best acc (toy):', best_accs)
    # Report the test score that belongs to the best validation score.
    total_aucs.append(best_aucs[1])
    total_aps.append(best_aps[1])
total_aucs = torch.tensor(total_aucs)
total_aps = torch.tensor(total_aps)
print('auc mean:', total_aucs.mean().item(), 'std:', total_aucs.std().item())
print('ap mean:', total_aps.mean().item(), 'std:', total_aps.std().item())

params: 173378
10 auc: [0.5715204788272202, 0.5831349985057376]
10 ap: [0.5732373900496415, 0.5916410000334205]
20 auc: [0.607165059491969, 0.6240588847401604]
20 ap: [0.6117713011051301, 0.6486214493292064]
30 auc: [0.6742471338316298, 0.6794969196590922]
30 ap: [0.7122946350262612, 0.7116665292954127]
40 auc: [0.7114205785828913, 0.7087970647645726]
40 ap: [0.7444894809724467, 0.7419148763145389]
50 auc: [0.7297055039107114, 0.7277732969909516]
50 ap: [0.7776494509170714, 0.7752896390224998]
60 auc: [0.7323764981422313, 0.7375823194552963]
70 auc: [0.751199959519438, 0.7389226548181861]
70 ap: [0.8001215297258093, 0.7875418252033126]
80 ap: [0.8031045511267677, 0.8048968186429593]
90 ap: [0.8033826268865888, 0.7986296732584093]
100 auc: [0.752251731266897, 0.7512512197141819]
100 ap: [0.808041047222043, 0.8026915977689655]
110 auc: [0.753542049183883, 0.7599314439615598]
110 ap: [0.8153646944710875, 0.8086199544470012]
160 ap: [0.8160275669986139, 0.7984751970741659]
0 best auc: [0.7

70 auc: [0.7539829981639175, 0.744084161178703]
70 ap: [0.8024981523585689, 0.7998188835781672]
80 ap: [0.8029846635556199, 0.7979642964993308]
90 auc: [0.7543878037849325, 0.7421524219652971]
90 ap: [0.8061200418409615, 0.7983639104495489]
100 auc: [0.7609080657519987, 0.7512368171850977]
100 ap: [0.8164731086280538, 0.8039318323362599]
110 auc: [0.7668825629978748, 0.7543585653640779]
160 ap: [0.8198912070231053, 0.8088765904448187]
170 auc: [0.7732221081698449, 0.7436781898901447]
170 ap: [0.8317487810093248, 0.801330520076184]
8 best auc: [0.7732221081698449, 0.7436781898901447]
8 best ap: [0.8317487810093248, 0.801330520076184]
8 best acc (toy): [0, 0]
params: 173378
10 auc: [0.5874163281238705, 0.5993441448318324]
10 ap: [0.5764732987571561, 0.6160308103131411]
20 auc: [0.6559766658474172, 0.6575861721318264]
20 ap: [0.6627437057694232, 0.6937027379498327]
30 auc: [0.7107085544102125, 0.690740794083441]
30 ap: [0.7393510091491311, 0.7415545975539304]
40 auc: [0.721775650941896, 0