In [None]:
import os
os.environ["DGLBACKEND"] = "pytorch"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import dgl
import numpy as np
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F


# Use the GPU when one is available; otherwise fall back to CPU so the notebook still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Expected width of each node's feature vector (must match the columns of ndata.pkl).
num_features = 49

batch_size = 32

# Feature standardizer. affine=False means it has no learnable parameters. It is only
# ever called in training mode, so each call standardizes with the statistics of the
# tensor passed in (i.e. per graph).
norm = nn.BatchNorm1d(num_features, affine=False)

# One uniformly sampled negative per positive edge.
negative_sampler = dgl.dataloading.negative_sampler.Uniform(1)
# Neighbour sampling with fan-out 1000 for each of the two GNN layers.
sampler = dgl.dataloading.NeighborSampler([1000, 1000])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler
)

def load_graph(url):
    """Load one pickled graph directory and wrap it in an edge-prediction DataLoader.

    Parameters
    ----------
    url : str
        Directory containing ``es.pkl`` (``es[0]`` = source node ids,
        ``es[1]`` = destination node ids) and ``ndata.pkl`` (node features,
        ``num_features`` columns wide).

    Returns
    -------
    dgl.dataloading.DataLoader
        Shuffled minibatches of ``batch_size`` edges drawn from the edges incident
        to node 0, sampled with the module-level ``sampler`` and moved to ``device``.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(os.path.join(url, 'es.pkl'), 'rb') as fh:
        es = pickle.load(fh)
    with open(os.path.join(url, 'ndata.pkl'), 'rb') as fh:
        ndata = pickle.load(fh)

    g = dgl.graph((es[0], es[1]))
    # Add the reverse of every edge so messages can flow in both directions.
    g = dgl.add_reverse_edges(g)

    # Convert via numpy in one shot; torch.tensor on a list of ndarrays is very slow.
    feats = torch.as_tensor(np.asarray(ndata, dtype=np.float32))
    with torch.no_grad():
        g.ndata['feat'] = norm(feats)

    # Target edges: every edge with node 0 as an endpoint.
    # NOTE(review): node 0 is presumably the graph's central account -- confirm.
    # torch.unique drops the duplicate id a self-loop on node 0 would otherwise add.
    eid = torch.unique(torch.cat((g.in_edges(0, 'eid'), g.out_edges(0, 'eid'))))

    return dgl.dataloading.DataLoader(
        g,
        eid,
        sampler,
        device=device,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=0,
    )

# for file in os.listdir('graph/Ponzi'):
#     dl = load_graph('graph/Ponzi/{}'.format(file))
#     print(len(dl))

# print(sum([len(l)for l in train_dataloaders]), sum([len(l)for l in test_dataloaders]))


In [None]:
from dgl.nn import SAGEConv, GATv2Conv, GATConv, GraphConv

class _TwoLayerSAGE(nn.Module):
    """Shared two-layer GraphSAGE encoder behind the SAGE_* variants.

    ``forward(mfgs, x)`` consumes the two message-flow graphs from the two-layer
    neighbour sampler: ``mfgs[0]`` feeds the first conv, ``mfgs[1]`` the second.
    Destination-node features are taken as the leading ``num_dst_nodes()`` rows of
    the source features. Both layers output ``h_feats`` features followed by ReLU.
    """

    def __init__(self, in_feats, h_feats, aggregator_type):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type=aggregator_type)
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type=aggregator_type)
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = F.relu(self.conv1(mfgs[0], (x, h_dst)))
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = F.relu(self.conv2(mfgs[1], (h, h_dst)))
        return h


class SAGE_GCN(_TwoLayerSAGE):
    """Two-layer GraphSAGE encoder using the ``'gcn'`` aggregator."""

    def __init__(self, in_feats, h_feats):
        super(SAGE_GCN, self).__init__(in_feats, h_feats, 'gcn')


class SAGE_MEAN(_TwoLayerSAGE):
    """Two-layer GraphSAGE encoder using the ``'mean'`` aggregator."""

    def __init__(self, in_feats, h_feats):
        super(SAGE_MEAN, self).__init__(in_feats, h_feats, 'mean')


class SAGE_POOL(_TwoLayerSAGE):
    """Two-layer GraphSAGE encoder using the ``'pool'`` aggregator."""

    def __init__(self, in_feats, h_feats):
        super(SAGE_POOL, self).__init__(in_feats, h_feats, 'pool')
    
class GAT(nn.Module):
    """Two-layer graph-attention encoder over the two sampled MFGs.

    GATConv returns one output per attention head, shaped
    ``(num_dst_nodes, num_heads, h_feats)``. Heads are averaged after EACH layer so
    that conv2 receives ``h_feats``-wide node features (matching its declared input
    size) and the encoder returns ``(num_dst_nodes, h_feats)`` like the other models.
    Previously the first layer's head axis was carried into conv2, and the dot-product
    predictor then scored only the first head.
    """

    def __init__(self, in_feats, h_feats, num_heads=3):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_feats, h_feats, num_heads=num_heads)
        self.conv2 = GATConv(h_feats, h_feats, num_heads=num_heads)
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst)).mean(1)  # average heads -> (N, h_feats)
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst)).mean(1)  # average heads -> (N_dst, h_feats)
        h = F.relu(h)
        return h

class GCN(nn.Module):
    """Two-layer graph-convolution encoder over the two sampled MFGs.

    Each layer is a GraphConv followed by ReLU; the hidden and output widths are
    both ``h_feats``. Destination features are the leading ``num_dst_nodes()`` rows
    of each layer's input.
    """

    def __init__(self, in_feats, h_feats):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, h_feats)

    def forward(self, mfgs, x):
        block1, block2 = mfgs[0], mfgs[1]
        hidden = F.relu(self.conv1(block1, (x, x[: block1.num_dst_nodes()])))
        out = self.conv2(block2, (hidden, hidden[: block2.num_dst_nodes()]))
        return F.relu(out)


import dgl.function as fn

class DotPredictor(nn.Module):
    """Score each edge of ``g`` as the dot product of its endpoint embeddings ``h``.

    The graph's features are only modified inside a local scope, so ``g`` is left
    unchanged. Returns the first column of the per-edge ``u_dot_v`` result.
    """

    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            edge_scores = g.edata["score"]
        return edge_scores[:, 0]


# The predictor has no parameters; a single shared instance is used everywhere.
predictor = DotPredictor().to(device)

def metrics(pos_score, neg_score, threshold=0.5):
    """Confusion-matrix counts for binary link prediction.

    An edge is predicted positive when its score is strictly greater than
    ``threshold``. Scores exactly at the threshold count as negative predictions.
    For the default 0.5 this matches the earlier "closer to 1 than to 0 in MSE"
    rule, because ``mse(s, 1) < mse(s, 0)`` iff ``mean(s) > 0.5``. Rows with more
    than one element are therefore reduced by their mean.

    Parameters
    ----------
    pos_score : torch.Tensor
        Scores of real edges; the first dimension indexes edges.
    neg_score : torch.Tensor
        Scores of sampled negative edges; the first dimension indexes edges.
    threshold : float
        Decision threshold.

    Returns
    -------
    tuple of int
        ``(TP, FP, TN, FN)``. FP counts negatives predicted positive; FN counts
        positives predicted negative. (The previous implementation had these two
        swapped, which silently swapped precision and recall downstream.)
    """
    def _row_scores(score):
        # One value per edge: average any trailing dimensions.
        return score.flatten(1).mean(1) if score.dim() > 1 else score

    pos_pred = _row_scores(pos_score) > threshold
    neg_pred = _row_scores(neg_score) > threshold

    TP = int(pos_pred.sum())
    FN = pos_pred.numel() - TP
    FP = int(neg_pred.sum())
    TN = neg_pred.numel() - FP
    return TP, FP, TN, FN

def test(data_loader, model):
    """Evaluate ``model`` with the global ``predictor`` over every minibatch of ``data_loader``.

    Runs in eval mode without gradient tracking. The model's previous train/eval
    mode is restored afterwards (even on error), so calling this does not leak
    state into a surrounding training loop.

    Returns
    -------
    tuple of int
        ``(TP, FP, TN, FN)`` from ``metrics``, summed over all minibatches.
    """
    was_training = model.training
    model.eval()
    TP = FP = TN = FN = 0
    try:
        with torch.no_grad():
            for _input_nodes, pos_graph, neg_graph, mfgs in data_loader:
                outputs = model(mfgs, mfgs[0].srcdata["feat"])
                pos_score = predictor(pos_graph, outputs)
                neg_score = predictor(neg_graph, outputs)
                tp, fp, tn, fn = metrics(pos_score, neg_score)
                TP += tp
                FP += fp
                TN += tn
                FN += fn
    finally:
        model.train(was_training)
    return TP, FP, TN, FN

In [None]:

import tqdm

# Experiment driver: for every repetition, task and encoder, train with MSE link
# prediction and log held-out metrics for each epoch under RESULT_DIR.

epoch = 50
batch_size = 32  # read by load_graph at call time; also part of the result file names

TASKS = ['Money Laundering', 'Gambles', 'Blackmail', 'Ponzi']

# Held-out graph directories per task; every other graph of that task is used for training.
TEST_GRAPHS = {
    'Money Laundering': {'17UVSMegvrzfobKC82dHXpZLtLcqzW9stF', '3F2sZ4jbhvDKQdGbHYPC6ZxFXEau2m5Lqj'},
    'Gambles': {'1DzF87L4vQqtvK9VPzr8KShck141MZFXTi', '1PhioUeHy92gkLtUan7sWGH9EYXyjoKkha'},
    'Blackmail': {'1HMsu3Dg3ocegPN2psqQtnsgZESeHVuxmN', '1JaF87nh5MV47ZbqjamaxgZ7mtAYSDgFd9'},
    'Ponzi': {'1MHwLU6hqHi7HcZENp4XsZQkYb2nNWGBLf', '31murN3u4dvWjVLEdSQRnhnPeuorxAxcer'},
}

# Encoder name (used in file names / progress bars) -> class.
MODEL_CLASSES = {
    'GCN': GCN,
    'GAT': GAT,
    'SAGE_POOL': SAGE_POOL,
    'SAGE_MEAN': SAGE_MEAN,
    'SAGE_GCN': SAGE_GCN,
}

RESULT_DIR = 'result/sub-task-1'


def _load_split(task):
    """Return ``(train_dataloaders, test_dataloaders)`` for every graph under ``graph/<task>``."""
    task_dir = os.path.join('graph', task)
    held_out = TEST_GRAPHS[task]
    train_dls, test_dls = [], []
    # Sorted so the training order does not depend on filesystem listing order.
    for name in sorted(os.listdir(task_dir)):
        target = test_dls if name in held_out else train_dls
        target.append(load_graph(os.path.join(task_dir, name)))
    return train_dls, test_dls


def _train_one_epoch(dataloaders, model, opt):
    """One optimisation pass over every minibatch of every training graph.

    The loss is MSE between edge scores and labels (1 for real edges, 0 for sampled
    negatives). Returns the sum of the per-batch losses.
    """
    loss_sum = 0
    for dl in dataloaders:
        for _input_nodes, pos_graph, neg_graph, mfgs in dl:
            outputs = model(mfgs, mfgs[0].srcdata["feat"])
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            score = torch.cat([pos_score, neg_score])
            label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.mse_loss(score, label)
            loss_sum += loss.item()

            opt.zero_grad()
            loss.backward()
            opt.step()
    return loss_sum


def _evaluate(dataloaders, model):
    """Pool ``test`` counts over several dataloaders.

    Returns ``(precision, recall, f1, TP, FP, TN, FN)``. Each ratio is 0 when its
    denominator is 0.
    """
    TP = FP = TN = FN = 0
    for dl in dataloaders:
        tp, fp, tn, fn = test(dl, model)
        TP += tp
        FP += fp
        TN += tn
        FN += fn
    P = TP / (TP + FP) if (TP + FP) else 0
    R = TP / (TP + FN) if (TP + FN) else 0
    F1 = 2 * P * R / (P + R) if (P + R) else 0
    return P, R, F1, TP, FP, TN, FN


os.makedirs(RESULT_DIR, exist_ok=True)

for times in range(1, 5):
    for t in TASKS:
        # Graph loading does not depend on the encoder, so load once per task
        # rather than once per (task, encoder) pair.
        train_dataloaders, test_dataloaders = _load_split(t)
        for t2, model_cls in MODEL_CLASSES.items():
            model = model_cls(num_features, num_features * 3).to(device)
            opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=0.001)

            result_path = os.path.join(RESULT_DIR, '{}_{}_{}_{}.txt'.format(t, t2, batch_size, times))
            tqdm_epoch = tqdm.trange(epoch, desc='{}_{}'.format(t, t2))
            with open(result_path, 'w+') as result_file:
                for e in tqdm_epoch:
                    model.train()
                    loss_sum = _train_one_epoch(train_dataloaders, model, opt)

                    # Held-out metrics are logged; training-set metrics (P2/R2/F2,
                    # where F2 is the training F1) are shown on the progress bar only.
                    P, R, F1, TP, FP, TN, FN = _evaluate(test_dataloaders, model)
                    P2, R2, F2, *_ = _evaluate(train_dataloaders, model)

                    tqdm_epoch.set_postfix(loss='{}'.format(loss_sum), P=P, R=R, F1=F1, P2=P2, R2=R2, F2=F2, refresh=False)
                    result_file.write('{}, {}, {}, {}, {}, {}, {}, {}, {}\n'.format(e, loss_sum, P, R, F1, TP, FP, TN, FN))
                    result_file.flush()
            del model, opt