In [1]:
import dataset
import datetime
from datetime import timedelta
from parser import get_parser
import numpy as np 
import pandas as pd 
import torch
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from torch_geometric.utils import from_networkx, to_undirected
from torch_geometric.data import Data, DataLoader, Dataset
from tqdm import tqdm, tqdm_notebook, trange
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from collections import defaultdict
import random

## Load Tdata

In [2]:
data = dataset.Tdata(path='./data/tdata.csv')
parser = get_parser()
args = parser.parse_args(args=
                         ["--data","real-t", 
                          "--sampling","xgb",
                          "--mode","scratch",
                          "--train_from","20170101",
                          "--test_from","20190101",
                          "--test_length","365",
                          "--valid_length","90",
                          "--initial_inspection_rate", "2",
                          "--final_inspection_rate", "10",
                         ])

In [3]:
# args
seed = args.seed
epochs = args.epoch
dim = args.dim
lr = args.lr
weight_decay = args.l2
initial_inspection_rate = args.initial_inspection_rate
inspection_rate_option = args.inspection_plan
mode = args.mode
train_begin = args.train_from 
test_begin = args.test_from
test_length = args.test_length
valid_length = args.valid_length
chosen_data = args.data
numWeeks = args.numweeks
semi_supervised = args.semi_supervised
save = args.save
gpu_id = args.device

# Initial dataset split
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Initial dataset split
train_start_day = datetime.date(int(train_begin[:4]), int(train_begin[4:6]), int(train_begin[6:8]))
test_start_day = datetime.date(int(test_begin[:4]), int(test_begin[4:6]), int(test_begin[6:8]))
test_length = timedelta(days=test_length)    
test_end_day = test_start_day + test_length
valid_length = timedelta(days=valid_length)
valid_start_day = test_start_day - valid_length

# data
data.split(train_start_day, valid_start_day, test_start_day, test_end_day, valid_length, test_length, args)
data.featureEngineering()

Data size:
Train labeled: (30957, 41), Train unlabeled: (1516868, 41), Valid labeled: (134457, 41), Valid unlabeled: (0, 13), Test: (703090, 41)
Checking label distribution
Training: 0.0979606313176095
Validation: 0.09589052260946108
Testing: 0.10476480792437651


## Prepare DATA

In [4]:
# re-label hscode
encoder = LabelEncoder()
encoder.fit(data.dftrainx_lab["HS6"])
num_hs = len(encoder.classes_)

# build transaction-hs bipartite graph
G = nx.Graph()
hs_nodes = encoder.transform(data.dftrainx_lab["HS6"])
transaction_nodes = np.array(range(data.dftrainx_lab.shape[0])) + num_hs
labeled_nodes = np.array(range(data.dftrainx_lab.shape[0])) + num_hs
train_edges = list(zip(hs_nodes,transaction_nodes))
G.add_edges_from(train_edges)
print("Number of Nodes: %d, Number of Edges: %d" % (G.number_of_nodes(), G.number_of_edges()))

# node feature
scaler = MinMaxScaler()
transaction_feature = scaler.fit_transform(data.dftrainx_lab.values)
feature_dim = data.dftrainx_lab.shape[1]

# init hs node embedding with zeros(only receive information from transaction)
# refer to this paper https://arxiv.org/pdf/2011.12193.pdf
nodeFeature = np.zeros((num_hs, feature_dim)) 
nodeFeature = np.vstack((nodeFeature,transaction_feature))
print("Check node feature size:",nodeFeature.shape)

assert feature_dim == nodeFeature.shape[1]
assert G.number_of_nodes() == nodeFeature.shape[0]

Number of Nodes: 34454, Number of Edges: 30957
Check node feature size: (34454, 29)


In [42]:
# convert data to PyG Dataloader format
train_data = from_networkx(G)
node_feature = torch.FloatTensor(nodeFeature)
train_y = torch.FloatTensor(data.train_cls_label)
train_data.x = node_feature
train_data.y = train_y
train_data.label_idx = torch.tensor(labeled_nodes) # record the transaction node index
train_loader = DataLoader([train_data]*10,batch_size=1) # duplicate the graph for 10 times

In [38]:
# construct validation data
valid_batch_size = 256
batch_start = 0
validTranscation_feature = scaler.transform(data.dfvalidx_lab.values)
data_list = []

for batch_idx in trange(validTranscation_feature.shape[0] //valid_batch_size + 1):
    # build graph
    origin_nodeNum = G.number_of_nodes()
    unseen_idx = G.number_of_nodes()
    valid_hsnode = data.dfvalidx_lab["HS6"][batch_start:batch_start+valid_batch_size].values
    unseen_HS = 0
    for i in range(len(valid_hsnode)):
        if valid_hsnode[i] in encoder.classes_:
            valid_hsnode[i] = encoder.transform([valid_hsnode[i]])[0]
        else:
            valid_hsnode[i] = unseen_idx
            unseen_idx +=1
            unseen_HS += 1
    validTr_id = list(range(unseen_idx, unseen_idx+valid_batch_size))
    valid_edges = list(zip(validTr_id,valid_hsnode))
    G.add_edges_from(valid_edges)
    
    # node feautres
    current_batch = validTranscation_feature[batch_start:batch_start+valid_batch_size,:]
    current_feature = np.zeros((unseen_HS,feature_dim))
    current_feature = torch.FloatTensor(np.vstack((current_feature, current_batch)))
    current_feature = torch.cat((node_feature,current_feature), dim=0)
    
    # pyG data
    valid_Data = Data()
    valid_Data.edge_index = to_undirected(torch.LongTensor(list(G.edges())).T)
    valid_Data.x = current_feature
    valid_Data.y = torch.FloatTensor(data.valid_cls_label[batch_start:batch_start+valid_batch_size])
    valid_Data.label_idx = torch.arange(origin_nodeNum + unseen_HS, origin_nodeNum + unseen_HS + valid_Data.y.shape[0])
    G.remove_edges_from(valid_edges)
    G.remove_nodes_from(list(range(origin_nodeNum,G.number_of_nodes())))
    data_list.append(valid_Data)
    batch_start+= valid_batch_size

100%|██████████| 526/526 [06:06<00:00,  1.43it/s]


In [39]:
valid_loader = DataLoader(data_list, batch_size=1)

## Model

In [8]:
import torch
import torch.optim as optim
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchtools.optim import RangerLars
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
from torch_geometric.nn import SAGEConv, GATConv, GCNConv

In [9]:
class GNN(nn.Module):
    def __init__(self,inDim, outDim):
        '''
        Basic GNN model.
        '''
        super(GNN,self).__init__()
        self.gnn = SAGEConv(inDim, outDim)
        self.norm = nn.LayerNorm(outDim)
        self.act = nn.ReLU()
        
    def forward(self,x,edge_index):
        feature = self.gnn(x,edge_index)
        feature = self.act(self.norm(feature))
        return feature

class GNNStack(nn.Module):
    def __init__(self,layer_dims):
        '''
        Create a model that stacks multi-layer GNN.
        The embedding obtainable are concatenated as output. 
        '''
        super(GNNStack,self).__init__()
        self.gnns = nn.ModuleList([])
        for i in range(len(layer_dims) -1):
            gnn_module = GNN(layer_dims[i], layer_dims[i+1])
            self.gnns.append(gnn_module)
    
    def forward(self,x,edge_index):
        features = [x] # raw feature
        current_feature = x
        for gnn in self.gnns:
            current_feature = gnn(current_feature,edge_index) # update node embedding
            features.append(current_feature)
    
        return torch.cat(features,dim=-1) # concat node embedding of each layer

In [10]:
def metrics(y_prob,xgb_testy,revenue_test, args, best_thresh=None, display=True):
    """ Evaluate the performance"""
    pr, re, f, rev = [], [], [], []
    # For validatation, we measure the performance on 5% (previously, 1%, 2%, 5%, and 10%)
    for i in [99,98,95,90]: 
        threshold = np.percentile(y_prob, i)
        precision = xgb_testy[y_prob > threshold].mean()
        recall = sum(xgb_testy[y_prob > threshold])/ sum(xgb_testy)
        revenue = sum(revenue_test[y_prob > threshold]) / sum(revenue_test)
        f1 = 2 * (precision * recall) / (precision + recall)
        if display:
            print(f'Checking top {100-i}% suspicious transactions: {len(y_prob[y_prob > threshold])}')
            print('Precision: %.4f, Recall: %.4f, Revenue: %.4f' % (precision, recall, revenue))
        # save results
        pr.append(precision)
        re.append(recall)
        f.append(f1)
        rev.append(revenue)
    return f, pr, re, rev

In [44]:
class SSLGNN(pl.LightningModule):
    def __init__(self,layers):
        super(SSLGNN,self).__init__()
        self.gnn = GNNStack(layers) # GNN Layer
        self.cls_layer = nn.Linear(sum(layers),1) # output layer
        self.lr = 0.001
        self.l2 = 0.0001
        self.epochs = 200
        self._weight_init()
        
        # hyperparameters
        hparam = {"lr":self.lr, "l2":self.l2,"epoch":self.epochs}
        self.hparams = hparam
    
    def _weight_init(self):
        for p in self.parameters():
            if p.dim() > 1 and p.requires_grad:
                nn.init.kaiming_normal_(p)

    def forward(self, data):
        x, edge_index = data.x , data.edge_index
        gnn_feature = self.gnn(x,edge_index)
        out = torch.sigmoid(self.cls_layer(gnn_feature))
        return out

    def training_step(self, batch, batch_idx):
        data = batch
        out = self(data)
        out = out[data.label_idx,:] # This step selects the output of Transaction nodes
        loss = F.binary_cross_entropy(out.flatten(), data.y.flatten())
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self,batch, batch_idx):
        data = batch
        out = self(data)
        out = out[data.label_idx,:]
        loss = F.binary_cross_entropy(out.flatten(), data.y.flatten())
        self.log("val_loss",loss)
        return out
    
    def validation_epoch_end(self, val_step_outputs):
        y_prob = []
        for pred in val_step_outputs:
            prob = pred.detach().cpu().numpy().ravel().tolist()
            y_prob.extend(prob)
        y_prob = np.array(y_prob)
        f,pr, re, rev = metrics(y_prob, data.valid_cls_label,data.valid_reg_label,args)

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=self.lr, weight_decay=self.l2)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.99)
        return [optimizer], [scheduler]
    
    def train_dataloader(self):
        return train_loader
    
    def val_dataloader(self):
        return valid_loader

In [46]:
# building model
layers = [feature_dim,32,32] # 3-layer GNN
model = SSLGNN(layers)

# setting configs and logger for training (Can ignore this part)
logger = TensorBoardLogger("ssl_exp",name="SSL_GNN")
logger.log_hyperparams(model.hparams, metrics={"F1-top":0})
checkpoint_callback = ModelCheckpoint(
    monitor='F1-top',
    dirpath='./saved_model',
    filename='SSLMLP-{epoch:02d}-{F1-top:.4f}',
    save_top_k=0,
    mode='max',
)
early_stopping = EarlyStopping(monitor='val_loss', patience=10)
trainer = pl.Trainer(max_epochs=model.epochs,gpus=[0],
#                      fast_dev_run=True,
                     num_sanity_val_steps=0,
                     check_val_every_n_epoch=10
                    )

# training
trainer.fit(model)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type     | Params
---------------------------------------
0 | gnn       | GNNStack | 4.1 K 
1 | cls_layer | Linear   | 94    
---------------------------------------
4.2 K     Trainable params
0         Non-trainable params
4.2 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.1903, Recall: 0.0218, Revenue: 0.0239
Checking top 2% suspicious transactions: 2690
Precision: 0.2112, Recall: 0.0483, Revenue: 0.0431
Checking top 5% suspicious transactions: 6723
Precision: 0.1575, Recall: 0.0900, Revenue: 0.0775
Checking top 10% suspicious transactions: 13446
Precision: 0.1470, Recall: 0.1680, Revenue: 0.1308


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.2632, Recall: 0.0301, Revenue: 0.0237
Checking top 2% suspicious transactions: 2690
Precision: 0.2405, Recall: 0.0550, Revenue: 0.0520
Checking top 5% suspicious transactions: 6723
Precision: 0.1946, Recall: 0.1112, Revenue: 0.0903
Checking top 10% suspicious transactions: 13446
Precision: 0.1768, Recall: 0.2020, Revenue: 0.1599


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.3323, Recall: 0.0380, Revenue: 0.0351
Checking top 2% suspicious transactions: 2690
Precision: 0.3160, Recall: 0.0722, Revenue: 0.0693
Checking top 5% suspicious transactions: 6723
Precision: 0.2945, Recall: 0.1683, Revenue: 0.1519
Checking top 10% suspicious transactions: 13445
Precision: 0.2164, Recall: 0.2473, Revenue: 0.2085


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.4981, Recall: 0.0569, Revenue: 0.0508
Checking top 2% suspicious transactions: 2690
Precision: 0.4349, Recall: 0.0994, Revenue: 0.0948
Checking top 5% suspicious transactions: 6723
Precision: 0.3481, Recall: 0.1989, Revenue: 0.1883
Checking top 10% suspicious transactions: 13446
Precision: 0.2622, Recall: 0.2996, Revenue: 0.2758


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.5762, Recall: 0.0659, Revenue: 0.0566
Checking top 2% suspicious transactions: 2690
Precision: 0.4874, Recall: 0.1114, Revenue: 0.1113
Checking top 5% suspicious transactions: 6723
Precision: 0.3650, Recall: 0.2086, Revenue: 0.2027
Checking top 10% suspicious transactions: 13446
Precision: 0.2733, Recall: 0.3124, Revenue: 0.2967


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6000, Recall: 0.0686, Revenue: 0.0616
Checking top 2% suspicious transactions: 2690
Precision: 0.5019, Recall: 0.1147, Revenue: 0.1201
Checking top 5% suspicious transactions: 6723
Precision: 0.3646, Recall: 0.2083, Revenue: 0.2065
Checking top 10% suspicious transactions: 13446
Precision: 0.2836, Recall: 0.3241, Revenue: 0.3096


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6022, Recall: 0.0688, Revenue: 0.0600
Checking top 2% suspicious transactions: 2690
Precision: 0.5156, Recall: 0.1179, Revenue: 0.1169
Checking top 5% suspicious transactions: 6723
Precision: 0.3643, Recall: 0.2082, Revenue: 0.2007
Checking top 10% suspicious transactions: 13446
Precision: 0.2825, Recall: 0.3228, Revenue: 0.3127


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.5896, Recall: 0.0674, Revenue: 0.0575
Checking top 2% suspicious transactions: 2690
Precision: 0.5164, Recall: 0.1181, Revenue: 0.1165
Checking top 5% suspicious transactions: 6723
Precision: 0.3650, Recall: 0.2086, Revenue: 0.2002
Checking top 10% suspicious transactions: 13446
Precision: 0.2826, Recall: 0.3230, Revenue: 0.3149


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.5851, Recall: 0.0669, Revenue: 0.0580
Checking top 2% suspicious transactions: 2690
Precision: 0.5097, Recall: 0.1165, Revenue: 0.1094
Checking top 5% suspicious transactions: 6723
Precision: 0.3653, Recall: 0.2088, Revenue: 0.2023
Checking top 10% suspicious transactions: 13446
Precision: 0.2877, Recall: 0.3289, Revenue: 0.3227


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.5903, Recall: 0.0675, Revenue: 0.0620
Checking top 2% suspicious transactions: 2690
Precision: 0.5037, Recall: 0.1152, Revenue: 0.1150
Checking top 5% suspicious transactions: 6723
Precision: 0.3683, Recall: 0.2105, Revenue: 0.2112
Checking top 10% suspicious transactions: 13446
Precision: 0.2912, Recall: 0.3328, Revenue: 0.3301


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.5866, Recall: 0.0671, Revenue: 0.0632
Checking top 2% suspicious transactions: 2690
Precision: 0.4892, Recall: 0.1119, Revenue: 0.1088
Checking top 5% suspicious transactions: 6723
Precision: 0.3626, Recall: 0.2072, Revenue: 0.2088
Checking top 10% suspicious transactions: 13446
Precision: 0.2915, Recall: 0.3332, Revenue: 0.3300



1