In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics

In [2]:
# Experiment configuration shared by the cells below.
g_data_name = 'wiki'  # wiki | power | pages-food | dublin
g_toy = False
# Embedding width; toy mode runs with half of the base width.
g_dim = 32 // 2 if g_toy else 32
# Number of reference nodes handed to the model per pair endpoint.
n_samples = 8
# Number of training epochs per run.
total_epoch = 200
# Learning rate constant.
lr = 0.005

In [3]:
def gpu(x):
    """Return ``x`` unchanged.

    Currently the identity function; callers wrap tensors and modules in
    it, so it is the single place to change if device placement is needed.
    """
    return x

def cpu(x):
    """Return ``x`` unchanged.

    Currently the identity function; it is applied to tensors before they
    are handed to the metric functions.
    """
    return x

def ip(x, y):
    """Inner product of ``x`` and ``y`` over their last dimension.

    ``x`` is turned into row vectors and ``y`` into column vectors, so the
    leading (batch) dimensions broadcast the way ``matmul`` broadcasts them.
    The result has the broadcast batch shape with the last dimension removed.
    """
    rows = x.unsqueeze(-2)       # (..., 1, d)
    cols = y.unsqueeze(-1)       # (..., d, 1)
    product = torch.matmul(rows, cols)   # (..., 1, 1)
    # Drop the two trailing singleton dimensions.
    return product[..., 0, 0]

In [4]:
class MAD(nn.Module):
    """Scores node pairs ``(src, dst)`` with a value in (0, 1).

    For each pair, ``n_samples`` reference nodes are drawn per endpoint.
    Each reference contributes a logit built from the difference between
    the endpoint embedding and the reference embedding, plus an optional
    learned term looked up in ``mem``. The logits are combined with
    softmax weights that decrease with the norm of that difference.

    Args:
        in_feats: Input width of the linear maps ``f`` and ``g``.
        n_nodes: Number of nodes; upper bound (exclusive) for the sampled
            reference ids and row count of the neighbour table.
        node_feats: Output width of ``f`` and ``g``.
        n_samples: Number of reference nodes per endpoint.
        mem: Tensor indexed as ``mem[src, dst]``, or ``None`` to disable
            the lookup term.
        feats: Tensor of node features indexed by node id.
        gather2neighbor: When true, ``g`` is not created and ``f`` is used
            in its place.
    """

    def __init__(
            self, in_feats, n_nodes, node_feats,
            n_samples, mem, feats, gather2neighbor=False,
    ):
        # Zero-argument super(): ``super(self.__class__, self)`` resolves to
        # the subclass when MAD is subclassed and never reaches nn.Module.
        super().__init__()
        self.n_nodes = n_nodes
        self.node_feats = node_feats
        self.n_samples = n_samples
        self.mem = mem
        self.feats = feats
        self.gather2neighbor = gather2neighbor
        self.f = gpu(nn.Linear(in_feats, node_feats))
        self.g = (
            None if gather2neighbor else gpu(nn.Linear(in_feats, node_feats)))
        # Scalar affine map applied to every looked-up ``mem`` entry.
        self.adapt = gpu(nn.Linear(1, 1))
        # Nearest-neighbour table, built on the first call to ``nns``.
        self.nn = None

    def nns(self, src, dst):
        """Return the nearest-neighbour ids of ``src`` and of ``dst``.

        On first use, builds and caches an ``(n_nodes, n_samples)`` table of
        the ``n_samples`` closest nodes in ``feats`` space (Euclidean norm),
        dropping the first hit of each row, which is the query node itself
        when features are distinct.
        """
        if self.nn is None:
            n = self.n_samples
            self.nn = gpu(torch.empty((self.n_nodes, n), dtype=torch.long))
            # Process queries in chunks of 64 to bound the size of the
            # pairwise difference tensor.
            for perm in DataLoader(
                    range(self.n_nodes), 64, shuffle=False):
                self.nn[perm] = (
                    self.feats[perm].unsqueeze(1) - self.feats.unsqueeze(0)
                ).norm(dim=-1).topk(1 + n, largest=False).indices[..., 1:]
        return self.nn[src], self.nn[dst]

    def recall(self, src, dst):
        """Return the learned lookup term for ``(src, dst)``.

        Returns ``0`` when no ``mem`` was supplied; otherwise ``mem[src, dst]``
        converted to float and passed through the scalar map ``adapt``.
        """
        if self.mem is None:
            return 0
        return self.adapt(
            (0.0 + self.mem[src, dst]).unsqueeze(-1)).squeeze(-1)

    def forward(self, src, dst):
        """Score each pair ``(src[i], dst[i])``.

        Args:
            src: 1-D tensor of node ids.
            dst: 1-D tensor of node ids, same length as ``src``.

        Returns:
            1-D tensor of sigmoid outputs, one per pair.
        """
        n = src.shape[0]
        feats = self.feats
        g = self.f if self.gather2neighbor else self.g
        # Uniformly sampled reference nodes, one set per endpoint. Routed
        # through gpu() like every other tensor created in this class.
        mid0 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))
        mid1 = gpu(torch.randint(0, self.n_nodes, (n, self.n_samples)))
        # mid0, mid1 = self.nns(src, dst)
        srcdiff = self.f(feats[src]).unsqueeze(1) - self.f(feats[mid0])
        logits1 = (
            ip(srcdiff, g(feats[dst]).unsqueeze(1))
            + self.recall(mid0, dst.unsqueeze(1))
        )
        dstdiff = self.f(feats[dst]).unsqueeze(1) - self.f(feats[mid1])
        logits2 = (
            ip(dstdiff, g(feats[src]).unsqueeze(1))
            + self.recall(src.unsqueeze(1), mid1)
        )
        logits = torch.cat((logits1, logits2), dim=1)
        dist = torch.cat((srcdiff, dstdiff), dim=1).norm(dim=2)
        # Pad with n_samples extra entries of logit 0 and distance 1 so the
        # softmax always has a fixed-distance zero-logit alternative.
        logits = torch.cat((
            logits, gpu(torch.zeros(n, self.n_samples))), dim=1)
        dist = torch.cat((
            dist, gpu(torch.ones(n, self.n_samples))), dim=1)
        # Distance-weighted average of the logits, squashed to (0, 1).
        return torch.sigmoid(ip(logits, torch.softmax(-dist, dim=1)))

In [5]:
import pandas as pd
import numpy as np
def load_network_data(adj_name):
    """Load an edge list and build a symmetric boolean adjacency matrix.

    Args:
        adj_name: One of ``'wiki'``, ``'power'``, ``'pages-food'``,
            ``'dublin'``.

    Returns:
        A tuple ``(adjacency, features, n_nodes)`` where ``adjacency`` is an
        ``(n_nodes, n_nodes)`` bool tensor with self-loops removed,
        ``features`` is the ``n_nodes`` identity matrix, and ``n_nodes`` is
        the hard-coded node count of the dataset.

    Raises:
        ValueError: If ``adj_name`` is not a known dataset name.
    """
    # name -> (node count, edge-list path, read_csv kwargs, index offset).
    # The offset is subtracted from every value of the frame; presumably it
    # converts 1-based node ids to 0-based -- TODO confirm against the files.
    datasets = {
        'wiki': (2405, 'datasets/graph.txt', {'sep': '\t'}, 0),
        'power': (1176, 'datasets/power-eris1176.mtx', {'sep': ' '}, 1),
        'pages-food': (620, 'datasets/fb-pages-food.edges', {}, 0),
        'dublin': (410, 'datasets/ia-infect-dublin.mtx', {'sep': ' '}, 1),
    }
    if adj_name not in datasets:
        # Previously this only printed a message and then crashed on the
        # unbound ``raw_edges`` below.
        raise ValueError(
            "Unknown dataset {!r}; expected one of: {}".format(
                adj_name, ', '.join(datasets)))

    nodes_numbers, path, csv_kwargs, offset = datasets[adj_name]
    raw_edges = pd.read_csv(path, header=None, **csv_kwargs)
    if offset:
        raw_edges = raw_edges - offset

    # Drop self-loops (rows whose two endpoints are equal).
    edges = raw_edges[raw_edges[0] != raw_edges[1]]

    adjacency = gpu(torch.zeros((nodes_numbers, nodes_numbers), dtype=torch.bool))

    # Set both directions of every edge in one vectorised assignment each.
    endpoints = torch.tensor(edges[[0, 1]].to_numpy())
    rows, cols = endpoints[:, 0], endpoints[:, 1]
    adjacency[rows, cols] = True
    adjacency[cols, rows] = True

    # One-hot (identity) node features.
    features = torch.eye(nodes_numbers)

    return adjacency, features, nodes_numbers

In [6]:
adj, node_features, n_nodes = load_network_data(g_data_name)
n_features = node_features.shape[1]

# Coordinates of every set entry of the adjacency matrix, one row per entry.
edge_index = adj.nonzero()
edge_src, edge_dst = edge_index[:, 0], edge_index[:, 1]

# The matrix is filled symmetrically, so keep a single orientation of each
# edge: the one whose first endpoint has the larger id.
flt = edge_src > edge_dst

src = edge_src[flt]
dst = edge_dst[flt]

In [7]:
if g_toy:
    # NOTE(review): `n_labels` is not defined in any cell visible in this
    # notebook, so this branch raises NameError as written. It also leaves
    # the val_*/test_* tensors undefined -- confirm before enabling g_toy.
    mem = None
    train_src = gpu(src)
    train_dst = gpu(dst)
    mlp = gpu(nn.Linear(g_dim, n_labels))
    params = list(mlp.parameters())
    print('mlp params:', sum(p.numel() for p in params))
    mlp_opt = optim.Adam(params, lr=lr)
else:
    n = src.shape[0]
    # Draw the split from a dedicated, seeded generator: the global RNG is
    # not seeded at this point, so an unseeded randperm would produce a
    # different train/val/test split (and `mem`) on every kernel restart.
    split_rng = torch.Generator().manual_seed(0)
    perm = torch.randperm(n, generator=split_rng)
    # 5% of the edges for validation, 10% for test, the rest for training.
    val_num = int(0.05 * n)
    test_num = int(0.1 * n)
    train_src = gpu(src[perm[val_num + test_num:]])
    train_dst = gpu(dst[perm[val_num + test_num:]])
    val_src = gpu(src[perm[:val_num]])
    val_dst = gpu(dst[perm[:val_num]])
    test_src = gpu(src[perm[val_num:val_num + test_num]])
    test_dst = gpu(dst[perm[val_num:val_num + test_num]])
    # Add the reversed copy of every edge so each split holds both
    # orientations (src -> dst and dst -> src).
    train_src, train_dst = (
        torch.cat((train_src, train_dst)),
        torch.cat((train_dst, train_src)))
    val_src, val_dst = (
        torch.cat((val_src, val_dst)),
        torch.cat((val_dst, val_src)))
    test_src, test_dst = (
        torch.cat((test_src, test_dst)),
        torch.cat((test_dst, test_src)))
    # Boolean lookup of training edges only; validation and test edges stay
    # False here.
    mem = gpu(torch.zeros((n_nodes, n_nodes), dtype=torch.bool))
    mem[train_src, train_dst] = True

In [8]:
# One entry per evaluation step (every 10th epoch of every run); each entry
# is a two-element list [validation, test].
total_aucs = []
total_aps = []
total_accs = []
total_f1s = []

for run in range(10):
    # Seed per run so model init and sampling are repeatable for a given run.
    torch.manual_seed(run)
    mad = MAD(
        in_feats=n_features,
        n_nodes=n_nodes,
        node_feats=g_dim,
        n_samples=n_samples,
        mem=mem,
        feats=node_features,
        gather2neighbor=g_toy,
    )
    params = list(mad.parameters())
    print('params:', sum(p.numel() for p in params))
    # NOTE(review): the optimiser uses 0.01, not the configured `lr`
    # (0.005) -- confirm which value is intended.
    opt = optim.Adam(params, lr=0.01)

    # Not read anywhere in this cell.
    best_accs = [0, 0]

    for epoch in range(1, total_epoch + 1):
        mad.train()
        for batch in DataLoader(
                range(train_src.shape[0]), batch_size=1024, shuffle=True):
            opt.zero_grad()
            p_pos = mad(train_src[batch], train_dst[batch])
            # Uniform random pairs as negatives; pairs that are known
            # training edges are filtered out.
            neg_src = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
            neg_dst = gpu(torch.randint(0, n_nodes, (batch.shape[0], )))
            keep = ~(mem[neg_src, neg_dst])
            p_neg = mad(neg_src[keep], neg_dst[keep])
            # Binary cross-entropy; 1e-5 keeps the logs finite.
            loss = (
                -torch.log(1e-5 + 1 - p_neg).mean()
                - torch.log(1e-5 + p_pos).mean()
            )
            loss.backward()
            opt.step()

        # Evaluate only on every 10th epoch.
        if epoch % 10:
            continue

        with torch.no_grad():
            mad.eval()
            aucs = []
            aps = []
            accs = []
            f1s = []

            # Loop variables deliberately do not reuse the notebook-level
            # names `src`, `dst`, `n`, `perm`: rebinding them here would make
            # a later re-run of the split cell operate on the test edges.
            for eval_src, eval_dst in ((val_src, val_dst), (test_src, test_dst)):
                p_pos = mad(eval_src, eval_dst)

                n_eval = eval_src.shape[0]
                shuffle = torch.randperm(n_eval * 2)
                # Candidate negatives: half keep the true source with a random
                # destination, half keep the true destination with a random
                # source; shuffled, filtered against the full adjacency, and
                # truncated to at most n_eval pairs.
                neg_src = torch.cat((
                    eval_src, gpu(torch.randint(0, n_nodes, (n_eval, )))
                ))[shuffle]
                neg_dst = torch.cat((
                    gpu(torch.randint(0, n_nodes, (n_eval, ))), eval_dst
                ))[shuffle]
                keep = ~(adj[neg_src, neg_dst])
                neg_src = neg_src[keep][:n_eval]
                neg_dst = neg_dst[keep][:n_eval]
                p_neg = mad(neg_src, neg_dst)

                y_true = cpu(torch.cat((
                    torch.ones_like(p_pos), torch.zeros_like(p_neg))))
                y_score = cpu(torch.cat((p_pos, p_neg)))

                fpr, tpr, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
                auc = metrics.auc(fpr, tpr)

                ap = metrics.average_precision_score(y_true, y_score)

                # Threshold into a separate tensor so the raw scores are not
                # overwritten.
                y_pred = (y_score > 0.5).float()
                acc = metrics.accuracy_score(y_true, y_pred)
                f1 = metrics.f1_score(y_true, y_pred)

                aucs.append(auc)
                aps.append(ap)
                accs.append(acc)
                f1s.append(f1)

        print(epoch, '  auc:', aucs, '  ap:', aps, '  accs:', accs, '  f1:', f1s)
        total_aucs.append(aucs)
        total_aps.append(aps)
        total_accs.append(accs)
        total_f1s.append(f1s)

params: 153986
10   auc: [0.7936156377054119, 0.8051702324383356]   ap: [0.808508996299231, 0.815060517656427]   accs: [0.7193436960276338, 0.7288610871440897]   f1: [0.7298420615128844, 0.7434170238824249]
20   auc: [0.8672402838554951, 0.8803359088679137]   ap: [0.8875603806449899, 0.8956740550162563]   accs: [0.7966321243523317, 0.8116911130284729]   f1: [0.7850296668188043, 0.8051774157554118]
30   auc: [0.8880350255487843, 0.9031662027528119]   ap: [0.9049619634320816, 0.9198564790325886]   accs: [0.8147668393782384, 0.838006902502157]   f1: [0.7969711310932324, 0.8260365994903868]
40   auc: [0.8951686995325752, 0.9131328441331337]   ap: [0.9111666905178757, 0.9283153983556458]   accs: [0.8199481865284974, 0.8425366695427092]   f1: [0.7980629539951574, 0.827096162955945]
50   auc: [0.8989227153003361, 0.9163475475347304]   ap: [0.9121363449358834, 0.9296183767734363]   accs: [0.8199481865284974, 0.8492234685073339]   f1: [0.7944800394282897, 0.8312001931900508]
60   auc: [0.903897

40   auc: [0.9037207561127667, 0.9191098139555312]   ap: [0.9208966289142009, 0.93152180973666]   accs: [0.8337651122625216, 0.8584987057808455]   f1: [0.8140994688556253, 0.8445497630331754]
50   auc: [0.9025014840070278, 0.9219012998769431]   ap: [0.9200799100107222, 0.9373014463192316]   accs: [0.8268566493955095, 0.8584987057808455]   f1: [0.8025603151157066, 0.8428366075706755]
60   auc: [0.9065276920185777, 0.9182184330754325]   ap: [0.9249219700058944, 0.9345634296923899]   accs: [0.822538860103627, 0.8546160483175151]   f1: [0.7952167414050823, 0.8384467881112176]
70   auc: [0.910644133623274, 0.9174759413704207]   ap: [0.9278274890163739, 0.9322082449007484]   accs: [0.8182210708117443, 0.8485763589301122]   f1: [0.7881227981882235, 0.8303528274528759]
80   auc: [0.9082093180726701, 0.9195490370220378]   ap: [0.9243805683259305, 0.9340988403564313]   accs: [0.8182210708117443, 0.8421052631578947]   f1: [0.7874810701665825, 0.8205002452182442]
90   auc: [0.9029340086683909, 0.9

70   auc: [0.9018206305314683, 0.9056210130270583]   ap: [0.9137912570705833, 0.9235656419297341]   accs: [0.809153713298791, 0.836281276962899]   f1: [0.7778894472361809, 0.8155528554070475]
80   auc: [0.8999354195936654, 0.9127204211181428]   ap: [0.9220908894290893, 0.9294553372397314]   accs: [0.8169257340241797, 0.8399482312338222]   f1: [0.7832310838445808, 0.8174212598425196]
90   auc: [0.9048348799818638, 0.9082068457753814]   ap: [0.9232284722193677, 0.925886167984324]   accs: [0.8052677029360967, 0.8367126833477135]   f1: [0.7676455435342607, 0.8132247717739945]
100   auc: [0.8990256263404537, 0.9140801515096246]   ap: [0.9199866004350189, 0.9343204521207715]   accs: [0.8074265975820379, 0.8352027610008628]   f1: [0.7698658410732715, 0.8086172344689379]
110   auc: [0.9023090851059388, 0.9129174201079298]   ap: [0.923093450635854, 0.9312886209183514]   accs: [0.7983592400690847, 0.8313201035375324]   f1: [0.756135770234987, 0.8036162732295328]
120   auc: [0.8944132728395393, 0

100   auc: [0.9095747536846627, 0.9074616554540711]   ap: [0.9204147097672233, 0.925039032162441]   accs: [0.8104490500863558, 0.8280845556514237]   f1: [0.7747562852744997, 0.8002005515166708]
110   auc: [0.9145637019338327, 0.9143619242734766]   ap: [0.9328704607656801, 0.9330153722161743]   accs: [0.8069948186528497, 0.8254961173425367]   f1: [0.7663355985363303, 0.7939903234020882]
120   auc: [0.9055567487270351, 0.911728260877657]   ap: [0.923575277949769, 0.9330695084632364]   accs: [0.7979274611398963, 0.8272217428817946]   f1: [0.7521186440677966, 0.7967520933773153]
130   auc: [0.9077857421974043, 0.9115430799661426]   ap: [0.9252992030757333, 0.9288544683325832]   accs: [0.7944732297063903, 0.8192407247627265]   f1: [0.7484143763213531, 0.7851282051282051]
140   auc: [0.9034970364603375, 0.9054650516161549]   ap: [0.9160654254014551, 0.9259685588807163]   accs: [0.7841105354058722, 0.8067299396031061]   f1: [0.7320471596998929, 0.7682359027418519]
150   auc: [0.90387139401206

130   auc: [0.9012516368821235, 0.9148333446240958]   ap: [0.9211229390034551, 0.9356641028644518]   accs: [0.8061312607944733, 0.8252804141501294]   f1: [0.7667532467532467, 0.7926267281105991]
140   auc: [0.8976236498518976, 0.9115334021697619]   ap: [0.9197358733590641, 0.9292399641790746]   accs: [0.8026770293609672, 0.8125539257981018]   f1: [0.7596002104155707, 0.7760886369492399]
150   auc: [0.8952104605343618, 0.9140864792995658]   ap: [0.9176338938929107, 0.9337438187635724]   accs: [0.7914507772020726, 0.8121225194132873]   f1: [0.7424, 0.7733541504033308]
160   auc: [0.89763185290582, 0.9098317850099867]   ap: [0.9206936909655358, 0.9303832293875318]   accs: [0.7966321243523317, 0.8093183779119931]   f1: [0.7498672331386085, 0.7691906005221932]
170   auc: [0.8909456182268876, 0.9053572930756857]   ap: [0.9120566864463953, 0.9270005883448172]   accs: [0.7871329879101899, 0.8091026747195859]   f1: [0.7359400107123729, 0.7685064085796495]
180   auc: [0.8938949889780785, 0.90795