In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics

In [2]:
# Experiment configuration.
g_data_name = 'citeseer'  # one of: wiki | email | citeseer (datasets load_network_data supports)
g_toy = False             # toy/debug mode; also halves the embedding width below
g_dim = 32 // 2 if g_toy else 32  # output width of MAD's linear maps
n_samples = 8             # intermediate nodes sampled per endpoint in MAD.forward
total_epoch = 200
lr = 0.005

In [3]:
def gpu(x):
    """Placeholder device hook (by its name, for moving ``x`` to the GPU).

    Currently returns ``x`` unchanged.
    """
    return x


def cpu(x):
    """Placeholder device hook (by its name, for moving ``x`` to the host).

    Currently returns ``x`` unchanged.
    """
    return x


def ip(x, y):
    """Inner product of ``x`` and ``y`` along their last dimension.

    Leading dimensions broadcast as in ``torch.matmul``; the last dimension is
    contracted away, so e.g. (n, s, d) with (n, 1, d) gives (n, s).
    """
    row = x[..., None, :]   # (..., 1, d)
    col = y[..., :, None]   # (..., d, 1)
    return torch.matmul(row, col)[..., 0, 0]

In [4]:
class MAD(nn.Module):
    """Link predictor that scores (src, dst) pairs through sampled intermediate nodes.

    For every pair, ``n_samples`` random intermediate nodes are drawn for each
    endpoint. Each intermediate contributes a logit: the inner product of an
    embedding difference (endpoint minus intermediate) with the other
    endpoint's embedding, plus an optional learned "recall" term read from
    ``mem``. The logits are combined as a weighted sum, with weights
    ``softmax(-distance)`` over the embedding-difference norms. ``n_samples``
    extra padding entries (logit 0, distance 1) are included in the weighted
    sum. The result goes through a sigmoid.

    Parameters
    ----------
    in_feats : int
        Width of each row of ``feats``.
    n_nodes : int
        Number of nodes; intermediates are drawn uniformly from ``[0, n_nodes)``.
    node_feats : int
        Output width of the linear maps ``f`` and ``g``.
    n_samples : int
        Intermediate nodes sampled per endpoint.
    mem : torch.Tensor of shape (n_nodes, n_nodes) or None
        Known-edge matrix (bool in this notebook); ``None`` disables the
        recall term.
    feats : torch.Tensor of shape (n_nodes, in_feats)
        Node feature matrix.
    gather2neighbor : bool, default False
        If True, reuse ``f`` for the "other endpoint" embedding instead of a
        separate map ``g``.
    """

    def __init__(
            self, in_feats, n_nodes, node_feats,
            n_samples, mem, feats, gather2neighbor=False,
    ):
        # Zero-argument super(): `super(self.__class__, self)` recurses
        # forever as soon as MAD is subclassed.
        super().__init__()
        self.n_nodes = n_nodes
        self.node_feats = node_feats
        self.n_samples = n_samples
        self.mem = mem
        self.feats = feats
        self.gather2neighbor = gather2neighbor
        self.f = gpu(nn.Linear(in_feats, node_feats))
        self.g = (
            None if gather2neighbor else gpu(nn.Linear(in_feats, node_feats)))
        # Affine map applied to the scalar memory bit in recall().
        self.adapt = gpu(nn.Linear(1, 1))
        # Lazily-built nearest-neighbour table, filled by nns().
        self.nn = None

    def nns(self, src, dst):
        """Return the nearest-neighbour rows for ``src`` and ``dst``.

        Neighbours are the ``n_samples`` closest nodes by Euclidean distance
        in raw feature space, skipping the single closest hit (normally the
        node itself). On the first call, the full (n_nodes, n_samples) table
        is computed in chunks of 64 rows. The table is cached in ``self.nn``.
        """
        if self.nn is None:
            k = self.n_samples
            self.nn = gpu(torch.empty((self.n_nodes, k), dtype=torch.long))
            for rows in DataLoader(range(self.n_nodes), 64, shuffle=False):
                dists = (
                    self.feats[rows].unsqueeze(1) - self.feats.unsqueeze(0)
                ).norm(dim=-1)
                # Drop the first (closest) index, normally the node itself.
                self.nn[rows] = dists.topk(1 + k, largest=False).indices[..., 1:]
        return self.nn[src], self.nn[dst]

    def recall(self, src, dst):
        """Learned score from the memory entries ``mem[src, dst]``.

        Returns the int ``0`` when ``mem`` is None. Otherwise it returns a
        float tensor with the broadcast shape of ``src`` and ``dst``.
        """
        if self.mem is None:
            return 0
        return self.adapt(
            (0.0 + self.mem[src, dst]).unsqueeze(-1)).squeeze(-1)

    def forward(self, src, dst):
        """Score each pair (src[i], dst[i]).

        Parameters
        ----------
        src, dst : 1-D index tensors of equal length n

        Returns
        -------
        torch.Tensor of shape (n,)
            Sigmoid scores in (0, 1).
        """
        n = src.shape[0]
        feats = self.feats
        g = self.f if self.gather2neighbor else self.g
        # Random intermediates, placed on the same device as the other
        # tensors this class creates. A deterministic alternative would be
        # `self.nns(src, dst)`.
        mid0 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))
        mid1 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))

        srcdiff = self.f(feats[src]).unsqueeze(1) - self.f(feats[mid0])
        logits1 = (
            ip(srcdiff, g(feats[dst]).unsqueeze(1))
            + self.recall(mid0, dst.unsqueeze(1))
        )
        dstdiff = self.f(feats[dst]).unsqueeze(1) - self.f(feats[mid1])
        logits2 = (
            ip(dstdiff, g(feats[src]).unsqueeze(1))
            + self.recall(src.unsqueeze(1), mid1)
        )

        logits = torch.cat((logits1, logits2), dim=1)            # (n, 2s)
        dist = torch.cat((srcdiff, dstdiff), dim=1).norm(dim=2)  # (n, 2s)
        # Padding entries: logit 0 at distance 1, so the weighted sum can
        # fall back toward 0 when every intermediate is far away.
        logits = torch.cat((
            logits, gpu(torch.zeros(n, self.n_samples))), dim=1)
        dist = torch.cat((
            dist, gpu(torch.ones(n, self.n_samples))), dim=1)
        return torch.sigmoid(ip(logits, torch.softmax(-dist, dim=1)))

In [5]:
import pandas as pd
import numpy as np
def load_network_data(adj_name):
    """Load a supported edge-list dataset as a dense symmetric adjacency matrix.

    Parameters
    ----------
    adj_name : str
        Dataset key: ``'wiki'``, ``'email'`` or ``'citeseer'``.

    Returns
    -------
    graph : torch.Tensor of shape (N, N), dtype bool
        Symmetric adjacency matrix with self-loops removed.
    features : torch.Tensor of shape (N, N)
        Identity matrix, used as one-hot node features.
    nodes_numbers : int
        Node count N for the dataset.

    Raises
    ------
    ValueError
        If ``adj_name`` is not one of the supported datasets.
    """
    if adj_name == 'wiki':
        nodes_numbers = 2405
        raw_edges = pd.read_csv('datasets/graph.txt', header=None, sep='\t')
    elif adj_name == 'email':
        nodes_numbers = 1133
        # Node ids in this file are shifted down by one to be 0-based.
        raw_edges = pd.read_csv('datasets/ia-email-univ.mtx', header=None, sep=' ') - 1
    elif adj_name == 'citeseer':
        nodes_numbers = 3327
        raw_edges = pd.read_csv('datasets/citeseer-edges.txt', header=None)
    else:
        # Fail loudly here; falling through would crash later with an
        # unrelated UnboundLocalError on `raw_edges`.
        raise ValueError(f"Dataset {adj_name!r} does not exist!")

    edges = raw_edges[raw_edges[0] != raw_edges[1]]  # drop self-loops
    rows = torch.as_tensor(edges[0].to_numpy(), dtype=torch.long)
    cols = torch.as_tensor(edges[1].to_numpy(), dtype=torch.long)

    graph = gpu(torch.zeros((nodes_numbers, nodes_numbers), dtype=bool))
    # Vectorized scatter instead of a per-row Python loop; set both
    # orientations so the matrix is symmetric.
    graph[rows, cols] = True
    graph[cols, rows] = True

    features = torch.eye(nodes_numbers)

    return graph, features, nodes_numbers

In [6]:
adj, node_features, n_nodes = load_network_data(g_data_name)
n_features = node_features.shape[1]

# The adjacency is symmetric (both (i, j) and (j, i) are set), so each
# undirected edge appears twice; keep only the i > j orientation.
src, dst = adj.nonzero().unbind(dim=1)
flt = src > dst
src, dst = src[flt], dst[flt]

In [7]:
def _both_directions(a, b):
    """Return (a ++ b, b ++ a): list every undirected edge in both orientations."""
    return torch.cat((a, b)), torch.cat((b, a))


if g_toy:
    # Toy mode: no held-out split and no edge memory (MAD.recall returns 0).
    # The unused MLP built from an undefined `n_labels` (a NameError) is gone.
    mem = None
    train_src = gpu(src)
    train_dst = gpu(dst)
    # NOTE(review): toy mode has no held-out edges, but the training cell
    # evaluates on val/test, so reuse the training edges there.
    val_src, val_dst = train_src, train_dst
    test_src, test_dst = train_src, train_dst
else:
    n_edges = src.shape[0]
    # Dedicated seeded generator: the split is reproducible and does not
    # disturb the global RNG that the training cell reseeds per run.
    perm = torch.randperm(n_edges, generator=torch.Generator().manual_seed(0))
    val_num = int(0.05 * n_edges)
    test_num = int(0.1 * n_edges)
    val_idx = perm[:val_num]
    test_idx = perm[val_num:val_num + test_num]
    train_idx = perm[val_num + test_num:]

    train_src, train_dst = _both_directions(
        gpu(src[train_idx]), gpu(dst[train_idx]))
    val_src, val_dst = _both_directions(
        gpu(src[val_idx]), gpu(dst[val_idx]))
    test_src, test_dst = _both_directions(
        gpu(src[test_idx]), gpu(dst[test_idx]))

    # Memory of training edges only, read by MAD.recall; val/test stay unseen.
    mem = gpu(torch.zeros((n_nodes, n_nodes), dtype=bool))
    mem[train_src, train_dst] = True

In [8]:
# Train MAD for 10 seeds; every 10 epochs record [val, test] AUC and AP.
total_aucs = []
total_aps = []
# Pairs that must not be used as training negatives. In toy mode `mem` is
# None, so fall back to the full adjacency instead of failing on indexing.
known_pairs = mem if mem is not None else adj

for run in range(10):
    torch.manual_seed(run)
    mad = MAD(
        in_feats=n_features,
        n_nodes=n_nodes,
        node_feats=g_dim,
        n_samples=n_samples,
        mem=mem,
        feats=node_features,
        gather2neighbor=g_toy,
    )
    params = list(mad.parameters())
    print('params:', sum(p.numel() for p in params))
    # NOTE(review): hard-coded 0.01 differs from the config-cell `lr` (0.005);
    # confirm which value is intended before changing it.
    opt = optim.Adam(params, lr=0.01)

    for epoch in range(1, total_epoch + 1):
        mad.train()
        for batch in DataLoader(
                range(train_src.shape[0]), batch_size=1024, shuffle=True):
            opt.zero_grad()
            p_pos = mad(train_src[batch], train_dst[batch])
            # Uniform random negatives, minus any pair that is a known edge.
            neg_src = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
            neg_dst = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
            keep = ~(known_pairs[neg_src, neg_dst])
            p_neg = mad(neg_src[keep], neg_dst[keep])
            # Binary cross-entropy; the 1e-5 keeps log() finite at p in {0, 1}.
            loss = (
                -torch.log(1e-5 + 1 - p_neg).mean()
                - torch.log(1e-5 + p_pos).mean()
            )
            loss.backward()
            opt.step()

        if epoch % 10:
            continue

        with torch.no_grad():
            mad.eval()
            aucs = []
            aps = []
            # Loop names must not be `src`/`dst`: those globals hold the full
            # edge list from the prep cell and would be silently overwritten.
            for eval_src, eval_dst in ((val_src, val_dst), (test_src, test_dst)):
                p_pos = mad(eval_src, eval_dst)

                # Negatives: corrupt one endpoint of each positive pair (the
                # first half keeps the source, the second half keeps the
                # destination), shuffle, drop real edges, keep at most n_pos.
                n_pos = eval_src.shape[0]
                order = torch.randperm(n_pos * 2)
                neg_src = torch.cat((
                    eval_src, gpu(torch.randint(0, n_nodes, (n_pos, )))
                ))[order]
                neg_dst = torch.cat((
                    gpu(torch.randint(0, n_nodes, (n_pos, ))), eval_dst
                ))[order]
                keep = ~(adj[neg_src, neg_dst])
                neg_src = neg_src[keep][:n_pos]
                neg_dst = neg_dst[keep][:n_pos]
                p_neg = mad(neg_src, neg_dst)

                y_true = cpu(torch.cat((
                    torch.ones_like(p_pos), torch.zeros_like(p_neg))))
                y_score = cpu(torch.cat((p_pos, p_neg)))

                fpr, tpr, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
                aucs.append(metrics.auc(fpr, tpr))
                aps.append(metrics.average_precision_score(y_true, y_score))

        print(epoch, '  auc:', aucs, '  ap:', aps)
        total_aucs.append(aucs)
        total_aps.append(aps)

params: 212994
10   auc: [0.585049195598595, 0.5729742784687839]   ap: [0.6212121090583536, 0.6370911215049881]
20   auc: [0.6124997574181529, 0.6099734331602464]   ap: [0.6690525349267862, 0.671899567988977]
30   auc: [0.6433270585495545, 0.6322207462866805]   ap: [0.7165819701030725, 0.7017901710520332]
40   auc: [0.6400764617982108, 0.638305760173892]   ap: [0.7099599331222483, 0.7131429132529894]
50   auc: [0.6601476838285236, 0.6557034174616592]   ap: [0.73782808208756, 0.7367301402038269]
60   auc: [0.6659502416115197, 0.6743738678903515]   ap: [0.7489216875276189, 0.7511980997114565]
70   auc: [0.6600943158221584, 0.6624695085134645]   ap: [0.7457297864984236, 0.7490427260926319]
80   auc: [0.6702051272099206, 0.6650175099625649]   ap: [0.7476139893164624, 0.7553996021724241]
90   auc: [0.6714325913563237, 0.6702644608139113]   ap: [0.7551477919742403, 0.7535474294051363]
100   auc: [0.6726066874963613, 0.6653459727086101]   ap: [0.7625424082061775, 0.7489882784745675]
110   auc

40   auc: [0.6644801956180015, 0.6591872962202633]   ap: [0.7468009218390751, 0.727000761228682]
50   auc: [0.6432639872693047, 0.6599118463953628]   ap: [0.7333430305300309, 0.7411099243396811]
60   auc: [0.6686380484775563, 0.673738678903514]   ap: [0.7487891036240597, 0.7568278762273897]
70   auc: [0.6488336664790701, 0.6748750150947953]   ap: [0.7390832602529696, 0.7489403622939961]
80   auc: [0.6669060140891537, 0.6754208428933703]   ap: [0.7600801799264576, 0.7576829559583309]
90   auc: [0.6418036445496711, 0.6751793261683372]   ap: [0.7216995436372171, 0.7592689946307407]
100   auc: [0.6528314153195289, 0.6759364811013162]   ap: [0.7479116928521935, 0.761735986868735]
110   auc: [0.6546944439053737, 0.6845429295978747]   ap: [0.7487427321033185, 0.770861787950463]
120   auc: [0.6653437869937316, 0.6759908223644489]   ap: [0.7604045984175576, 0.7617964848927234]
130   auc: [0.6575083933319101, 0.6934428209153485]   ap: [0.7558689471825456, 0.7738581888329494]
140   auc: [0.652719

80   auc: [0.7165140018242155, 0.6951793261683372]   ap: [0.7807923896133668, 0.7694485679612451]
90   auc: [0.700658852296765, 0.6881874169786257]   ap: [0.7706366094966326, 0.7645119238402924]
100   auc: [0.6983397698383433, 0.6979108803284628]   ap: [0.7659611057007341, 0.7678087872759651]
110   auc: [0.6970055696792098, 0.7008308175341142]   ap: [0.7636242044196333, 0.7721493904597231]
120   auc: [0.6879524151448698, 0.6918717546190074]   ap: [0.7624505950445624, 0.7682501185933841]
130   auc: [0.6871713015971588, 0.7022267842048062]   ap: [0.7660540105449832, 0.7745328533683076]
140   auc: [0.6929908401094529, 0.6899830938292477]   ap: [0.7646632812214841, 0.7714676403188274]
150   auc: [0.6848813289603912, 0.7038050960028982]   ap: [0.7588100208437717, 0.7730971304558784]
160   auc: [0.7051660230161656, 0.6943835285593527]   ap: [0.7755944914269028, 0.7707667888168763]
170   auc: [0.6870597139474859, 0.6993056394155295]   ap: [0.7637725133645493, 0.7785668831632647]
180   auc: [0