In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics

In [6]:
# --- Experiment settings -------------------------------------------------
g_data_name = 'email'  # dataset key: wiki | email | cora | citeseer
g_toy = False          # toy mode halves the embedding width (see g_dim)
n_samples = 8          # handed to MAD as its n_samples argument
total_epoch = 200      # epochs per training run
lr = 0.005
# Embedding width: 32 normally, 16 in toy mode.
g_dim = 32 // 2 if g_toy else 32

In [7]:
def gpu(x):
    """Return ``x`` unchanged.

    Call sites wrap tensors and modules in this hook; as written it performs
    no device transfer.
    """
    return x

def cpu(x):
    """Return ``x`` unchanged.

    Counterpart of ``gpu``; as written it performs no device transfer.
    """
    return x

def ip(x, y):
    """Inner product of ``x`` and ``y`` along their last axis.

    The last axis of ``x`` is treated as a row vector and the last axis of
    ``y`` as a column vector; leading axes follow ``torch.matmul`` batch
    broadcasting. The two singleton axes produced by the product are dropped,
    so the result has the broadcast leading shape only.
    """
    row = x.unsqueeze(-2)  # (..., 1, d)
    col = y.unsqueeze(-1)  # (..., d, 1)
    product = torch.matmul(row, col)  # (..., 1, 1)
    return product[..., 0, 0]

In [8]:
class MAD(nn.Module):
    """Scores node pairs from embedding differences plus an optional edge memory.

    For a pair ``(src, dst)`` the forward pass draws ``n_samples`` random
    reference nodes per endpoint, builds one logit per reference from the
    difference between the endpoint's and the reference's ``f`` embeddings
    (plus a learned transform of the stored ``mem`` entry), and combines the
    logits with softmax weights that favour references at small embedding
    distance.

    Args:
        in_feats: Width of each row of ``feats``.
        n_nodes: Number of nodes; exclusive upper bound for sampled node ids.
        node_feats: Output width of the linear maps ``f`` and ``g``.
        n_samples: Reference nodes drawn per endpoint.
        mem: Matrix indexed as ``mem[a, b]`` holding known links, or ``None``
            to disable the memory term.
        feats: Node feature matrix, indexed by node id along its first axis.
        gather2neighbor: When true, ``f`` is reused in place of a separate
            ``g`` projection.
    """

    def __init__(
            self, in_feats, n_nodes, node_feats,
            n_samples, mem, feats, gather2neighbor=False,
    ):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # without end as soon as this class is subclassed.
        super().__init__()
        self.n_nodes = n_nodes
        self.node_feats = node_feats
        self.n_samples = n_samples
        self.mem = mem
        self.feats = feats
        self.gather2neighbor = gather2neighbor
        self.f = gpu(nn.Linear(in_feats, node_feats))
        self.g = (
            None if gather2neighbor else gpu(nn.Linear(in_feats, node_feats)))
        # Scalar affine map applied to each memory entry in ``recall``.
        self.adapt = gpu(nn.Linear(1, 1))
        # Lazily built nearest-neighbour table, see ``nns``.
        self.nn = None

    def nns(self, src, dst):
        """Return the cached nearest-neighbour ids for ``src`` and ``dst``.

        On first use, builds an ``(n_nodes, n_samples)`` table holding, for
        every node, the ids of the rows of ``feats`` at smallest Euclidean
        distance. The closest hit is dropped (presumably the node itself at
        distance 0 — holds only when no other row is identical).
        """
        if self.nn is None:
            n = self.n_samples
            self.nn = gpu(torch.empty((self.n_nodes, n), dtype=torch.long))
            # Process nodes in chunks of 64 so the pairwise difference tensor
            # stays bounded.
            for perm in DataLoader(
                    range(self.n_nodes), 64, shuffle=False):
                pairwise = (
                    self.feats[perm].unsqueeze(1) - self.feats.unsqueeze(0)
                ).norm(dim=-1)
                self.nn[perm] = pairwise.topk(
                    1 + n, largest=False).indices[..., 1:]
        return self.nn[src], self.nn[dst]

    def recall(self, src, dst):
        """Return the learned transform of ``mem[src, dst]``.

        Returns the plain number ``0`` when no memory was supplied; otherwise
        the entries are converted to float and passed element-wise through
        ``adapt``.
        """
        if self.mem is None:
            return 0
        return self.adapt(
            (0.0 + self.mem[src, dst]).unsqueeze(-1)).squeeze(-1)

    def forward(self, src, dst):
        """Return a link probability for each ``(src[i], dst[i])`` pair.

        Args:
            src: 1-D tensor of node ids.
            dst: 1-D tensor of node ids, same length as ``src``.

        Returns:
            1-D tensor of sigmoid outputs, one per pair. Reference nodes are
            sampled at random on every call, in training and evaluation mode
            alike, so repeated calls differ unless the RNG is seeded.
        """
        n = src.shape[0]
        feats = self.feats
        g = self.f if self.gather2neighbor else self.g
        # Random reference nodes per endpoint, routed through gpu() like every
        # other tensor created here. Alternative kept from the original code:
        # mid0, mid1 = self.nns(src, dst)
        mid0 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))
        mid1 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))
        # Source side: embedding difference src - reference, projected on dst,
        # plus the memory term for (reference, dst).
        srcdiff = self.f(feats[src]).unsqueeze(1) - self.f(feats[mid0])
        logits1 = (
            ip(srcdiff, g(feats[dst]).unsqueeze(1))
            + self.recall(mid0, dst.unsqueeze(1))
        )
        # Destination side: same construction with the roles swapped.
        dstdiff = self.f(feats[dst]).unsqueeze(1) - self.f(feats[mid1])
        logits2 = (
            ip(dstdiff, g(feats[src]).unsqueeze(1))
            + self.recall(src.unsqueeze(1), mid1)
        )
        logits = torch.cat((logits1, logits2), dim=1)
        dist = torch.cat((srcdiff, dstdiff), dim=1).norm(dim=2)
        # Append n_samples extra slots with logit 0 at distance 1 before
        # the softmax weighting.
        logits = torch.cat((
            logits, gpu(torch.zeros(n, self.n_samples))), dim=1)
        dist = torch.cat((
            dist, gpu(torch.ones(n, self.n_samples))), dim=1)
        # Closer references receive larger weight.
        return torch.sigmoid(ip(logits, torch.softmax(-dist, dim=1)))

In [9]:
import pandas as pd
import numpy as np
def load_network_data(adj_name):
    """Load an edge list and build a dense, symmetric adjacency matrix.

    Args:
        adj_name: Dataset key, ``'wiki'`` or ``'email'``.

    Returns:
        Tuple ``(graph_np, features, nodes_numbers)``: a boolean adjacency
        tensor of shape ``(nodes_numbers, nodes_numbers)`` with both
        directions of every edge set and self-loops left out, an identity
        matrix used as node features, and the node count.

    Raises:
        ValueError: If ``adj_name`` is not a supported dataset key.
    """
    if adj_name == 'wiki':
        nodes_numbers = 2405
        raw_edges = pd.read_csv('datasets/graph.txt', header=None, sep='\t')
    elif adj_name == 'email':
        nodes_numbers = 1133
        # Every value is shifted down by one (presumably the file uses 1-based
        # node ids — TODO confirm against the data file).
        raw_edges = pd.read_csv(
            'datasets/ia-email-univ.mtx', header=None, sep=' ') - 1
    else:
        # Fail here with a clear message instead of printing and then hitting
        # an UnboundLocalError on ``raw_edges`` a few lines later.
        raise ValueError(
            "Dataset does not exist: {!r} (expected 'wiki' or 'email')".format(
                adj_name))

    drop_self_loop = raw_edges[raw_edges[0] != raw_edges[1]]

    # Endpoint columns as index tensors so the matrix is filled in two
    # vectorised assignments rather than one Python iteration per edge.
    rows = gpu(torch.as_tensor(drop_self_loop[0].to_numpy(), dtype=torch.long))
    cols = gpu(torch.as_tensor(drop_self_loop[1].to_numpy(), dtype=torch.long))

    graph_np = gpu(torch.zeros((nodes_numbers, nodes_numbers), dtype=bool))
    graph_np[rows, cols] = True
    graph_np[cols, rows] = True

    features = torch.eye(nodes_numbers)

    return graph_np, features, nodes_numbers

In [10]:
# Load the graph and list every undirected edge exactly once.
adj, node_features, n_nodes = load_network_data(g_data_name)
n_features = node_features.shape[1]

# Columns of nonzero() are the row and column ids of the set entries.
src, dst = adj.nonzero().unbind(dim=1)

# The adjacency matrix is symmetric; keep only the (src > dst) copy of a pair.
flt = src > dst
src, dst = src[flt], dst[flt]

In [11]:
if g_toy:
    # Toy mode: train on every edge, no hold-out split and no edge memory.
    # NOTE(review): `n_labels` is not defined in the visible cells, and the
    # evaluation loop in the training cell expects val_*/test_* tensors that
    # this branch does not create — confirm this path before enabling g_toy.
    mem = None
    train_src = gpu(src)
    train_dst = gpu(dst)
    mlp = gpu(nn.Linear(g_dim, n_labels))
    params = list(mlp.parameters())
    print('mlp params:', sum(p.numel() for p in params))
    mlp_opt = optim.Adam(params, lr=lr)
else:
    n = src.shape[0]
    # Dedicated seeded generator: the split is identical on every
    # Restart & Run All and the global RNG state is left untouched.
    split_rng = torch.Generator().manual_seed(0)
    perm = torch.randperm(n, generator=split_rng)
    # 5% of the edges for validation, 10% for test, the rest for training.
    val_num = int(0.05 * n)
    test_num = int(0.1 * n)
    train_src = gpu(src[perm[val_num + test_num:]])
    train_dst = gpu(dst[perm[val_num + test_num:]])
    val_src = gpu(src[perm[:val_num]])
    val_dst = gpu(dst[perm[:val_num]])
    test_src = gpu(src[perm[val_num:val_num + test_num]])
    test_dst = gpu(dst[perm[val_num:val_num + test_num]])
    # Add the reversed copy of every edge so each split holds both directions.
    train_src, train_dst = (
        torch.cat((train_src, train_dst)),
        torch.cat((train_dst, train_src)))
    val_src, val_dst = (
        torch.cat((val_src, val_dst)),
        torch.cat((val_dst, val_src)))
    test_src, test_dst = (
        torch.cat((test_src, test_dst)),
        torch.cat((test_dst, test_src)))
    # The memory records training edges only; validation and test edges stay
    # out of it.
    mem = gpu(torch.zeros((n_nodes, n_nodes), dtype=bool))
    mem[train_src, train_dst] = True

In [13]:
# One [validation, test] entry is appended every 10th epoch of every run.
total_aucs = []
total_aps = []

for run in range(10):
    # A distinct seed per run; it drives parameter init and all sampling below.
    torch.manual_seed(run)
    mad = MAD(
        in_feats=n_features,
        n_nodes=n_nodes,
        node_feats=g_dim,
        n_samples=n_samples,
        mem=mem,
        feats=node_features,
        gather2neighbor=g_toy,
    )
    params = list(mad.parameters())
    print('params:', sum(p.numel() for p in params))
    # NOTE(review): hard-coded 0.01 rather than the `lr` value from the config
    # cell — confirm which one is intended.
    opt = optim.Adam(params, lr=0.01)

    for epoch in range(1, total_epoch + 1):
        mad.train()
        # Loop variables here use names of their own so the notebook-level
        # `src`, `dst`, `n` and `perm` from the earlier cells are not
        # overwritten; the split cell can then be re-run safely.
        for batch in DataLoader(
                range(train_src.shape[0]), batch_size=1024, shuffle=True):
            opt.zero_grad()
            p_pos = mad(train_src[batch], train_dst[batch])
            # Uniform random pairs as negatives; drop those stored in `mem`.
            neg_src = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
            neg_dst = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
            idx = ~(mem[neg_src, neg_dst])
            p_neg = mad(neg_src[idx], neg_dst[idx])
            # Binary cross-entropy, with 1e-5 inside the logs to avoid log(0).
            loss = (
                -torch.log(1e-5 + 1 - p_neg).mean()
                - torch.log(1e-5 + p_pos).mean()
            )
            loss.backward()
            opt.step()

        # Evaluate on every 10th epoch only.
        if epoch % 10 != 0:
            continue

        with torch.no_grad():
            mad.eval()
            aucs = []
            aps = []

            for eval_src, eval_dst in (
                    (val_src, val_dst), (test_src, test_dst)):
                p_pos = mad(eval_src, eval_dst)

                # Candidate negatives: each positive with one endpoint replaced
                # by a random node, shuffled, filtered against the full
                # adjacency, then cut to at most as many as the positives.
                n_eval = eval_src.shape[0]
                shuffle = torch.randperm(n_eval * 2)
                neg_src = torch.cat((
                    eval_src, gpu(torch.randint(0, n_nodes, (n_eval, )))
                ))[shuffle]
                neg_dst = torch.cat((
                    gpu(torch.randint(0, n_nodes, (n_eval, ))), eval_dst
                ))[shuffle]
                idx = ~(adj[neg_src, neg_dst])
                neg_src = neg_src[idx][:n_eval]
                neg_dst = neg_dst[idx][:n_eval]
                p_neg = mad(neg_src, neg_dst)

                # Labels are built directly: `p * 0` would turn a NaN score
                # into a NaN label.
                y_true = cpu(torch.cat((
                    torch.ones_like(p_pos), torch.zeros_like(p_neg))))
                y_score = cpu(torch.cat((p_pos, p_neg)))

                fpr, tpr, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
                auc = metrics.auc(fpr, tpr)

                ap = metrics.average_precision_score(y_true, y_score)

                aucs.append(auc)
                aps.append(ap)

        print(epoch, '  auc:', aucs, '  ap:', aps)
        total_aucs.append(aucs)
        total_aps.append(aps)

params: 72578
10   auc: [0.6789508542387543, 0.672781752377746]   ap: [0.6668460170783869, 0.6715799744855789]
20   auc: [0.7738936797145328, 0.7335939735712482]   ap: [0.7716249904433259, 0.7358742419976504]
30   auc: [0.8321697934688581, 0.7794840501641276]   ap: [0.8396634604892331, 0.7798103984509703]
40   auc: [0.8550294658304499, 0.8155332042757344]   ap: [0.8624518214447177, 0.8310010529994554]
50   auc: [0.8604596939878892, 0.8274901102600791]   ap: [0.8741608337671933, 0.8389188419116094]
60   auc: [0.8662886570069204, 0.8287391633700868]   ap: [0.8767563416154398, 0.8367650320175001]
70   auc: [0.8574826989619376, 0.8330376230956991]   ap: [0.8691961715920853, 0.8408092783465242]
80   auc: [0.8773214478806228, 0.8449726454002189]   ap: [0.8803883207299357, 0.8580090180289252]
90   auc: [0.8702388354238754, 0.8475288275397694]   ap: [0.8753221696922164, 0.8655125847346464]
100   auc: [0.8680221399221454, 0.8457621412339029]   ap: [0.881162070631897, 0.8616633049191873]
110   a

40   auc: [0.8390496593858131, 0.8221917347024661]   ap: [0.8404972723103232, 0.8303238091991471]
50   auc: [0.8600947502162628, 0.8318155037454761]   ap: [0.8702942654068446, 0.8379531804357947]
60   auc: [0.8443616187283738, 0.8372535981819712]   ap: [0.8592746850412194, 0.8486767186629548]
70   auc: [0.8599494485294119, 0.8471837387425301]   ap: [0.882387516530621, 0.8609488344477478]
80   auc: [0.8467134785899655, 0.8462730409898158]   ap: [0.861052371074339, 0.8562857202626446]
90   auc: [0.866261624134948, 0.8501540274387678]   ap: [0.8772463765060383, 0.8586170586191146]
100   auc: [0.8678024978373702, 0.8621959431024324]   ap: [0.878781529423242, 0.8725593924474665]
110   auc: [0.8637948745674741, 0.8459245854726034]   ap: [0.8833772696787909, 0.8546558116341156]
120   auc: [0.8743816230536332, 0.8395353926437169]   ap: [0.8942922494216737, 0.8550121175933715]
130   auc: [0.8701408412629759, 0.8620478074236175]   ap: [0.8857545292483482, 0.880900100840007]
140   auc: [0.8587059

80   auc: [0.8672922523788926, 0.8441006649271947]   ap: [0.8811018949163243, 0.8563686580932639]
90   auc: [0.8639976211072664, 0.8412642033498863]   ap: [0.8767870192151274, 0.8578029182292813]
100   auc: [0.8617843047145328, 0.8481828128945375]   ap: [0.8772479252725489, 0.8653312541796006]
110   auc: [0.8680694474480968, 0.8421227169430183]   ap: [0.8721056050415036, 0.8600960080749663]
120   auc: [0.8760069744809689, 0.8486331116909351]   ap: [0.8872677742055874, 0.8699374587688379]
130   auc: [0.8789907277249136, 0.8500791179193671]   ap: [0.8859790425745152, 0.8723653636248366]
140   auc: [0.8574995945069204, 0.8476138372190892]   ap: [0.8685590768936116, 0.861899176707793]
150   auc: [0.8752703287197231, 0.846785624105715]   ap: [0.8917304143579363, 0.8529648062529018]
160   auc: [0.8718608077422145, 0.8445290800437674]   ap: [0.8902060553541782, 0.8586935926627008]
170   auc: [0.8725974535034602, 0.8464784109081727]   ap: [0.8801052605898352, 0.863177637333515]
180   auc: [0.8