In [1]:
import dataset
import datetime
from datetime import timedelta
from parser import get_parser
import numpy as np 
import pandas as pd 
import torch
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from torch_geometric.utils import from_networkx, to_undirected
from torch_geometric.data import Data, DataLoader, Dataset
from tqdm import tqdm, tqdm_notebook, trange
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from collections import defaultdict
import random
from xgboost import XGBClassifier
%config Completer.use_jedi = False

## Load Tdata

In [2]:
data = dataset.Ndata(path='../Custom-Semi-Supervised/data/ndata.csv')
parser = get_parser()
args = parser.parse_args(args=
                         ["--data","real-t", 
                          "--sampling","xgb",
                          "--mode","scratch",
                          "--train_from","20140101",
                          "--test_from","20170101",
                          "--test_length","365",
                          "--valid_length","180",
                          "--initial_inspection_rate", "3",
                          "--final_inspection_rate", "10",
                         ])

In [3]:
# args
seed = args.seed
epochs = args.epoch
dim = args.dim
lr = args.lr
weight_decay = args.l2
initial_inspection_rate = args.initial_inspection_rate
inspection_rate_option = args.inspection_plan
mode = args.mode
train_begin = args.train_from 
test_begin = args.test_from
test_length = args.test_length
valid_length = args.valid_length
chosen_data = args.data
numWeeks = args.numweeks
semi_supervised = args.semi_supervised
save = args.save
gpu_id = args.device

# Initial dataset split
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Initial dataset split
train_start_day = datetime.date(int(train_begin[:4]), int(train_begin[4:6]), int(train_begin[6:8]))
test_start_day = datetime.date(int(test_begin[:4]), int(test_begin[4:6]), int(test_begin[6:8]))
test_length = timedelta(days=test_length)    
test_end_day = test_start_day + test_length
valid_length = timedelta(days=valid_length)
valid_start_day = test_start_day - valid_length

# data
data.split(train_start_day, valid_start_day, test_start_day, test_end_day, valid_length, test_length, args)
data.featureEngineering()

Data size:
Train labeled: (30302, 51), Train unlabeled: (979778, 51), Valid labeled: (143509, 51), Valid unlabeled: (0, 26), Test: (274808, 51)
Checking label distribution
Training: 0.052372021949017154
Validation: 0.03918984481922127
Testing: 0.025360899366070794


## Prepare DATA

In [4]:
from utils import *
from pygData_util import *

In [5]:
categories=["importer.id","HS6"]
gdata = GraphData(data,use_xgb=True, categories=categories)

Training XGBoost model...


In [38]:
best_thresh, best_auc = find_best_threshold(gdata.xgb,data.dfvalidx_lab, data.valid_cls_label)
xgb_test_pred = gdata.xgb.predict_proba(data.dfvalidx_lab)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.valid_cls_label,data.valid_reg_label,best_thresh)
print("-"*50)
xgb_test_pred = gdata.xgb.predict_proba(data.dftestx)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.test_cls_label,data.test_reg_label,best_thresh)

Checking top 1% suspicious transactions: 1436
Precision: 0.1950, Recall: 0.0517, Revenue: 0.1257
Checking top 2% suspicious transactions: 2871
Precision: 0.1358, Recall: 0.0721, Revenue: 0.1731
Checking top 5% suspicious transactions: 7170
Precision: 0.0734, Recall: 0.0972, Revenue: 0.2205
Checking top 10% suspicious transactions: 14351
Precision: 0.0471, Recall: 0.1249, Revenue: 0.2609
--------------------------------------------------
Checking top 1% suspicious transactions: 2749
Precision: 0.1124, Recall: 0.0455, Revenue: 0.0782
Checking top 2% suspicious transactions: 5497
Precision: 0.0708, Recall: 0.0572, Revenue: 0.1080
Checking top 5% suspicious transactions: 13741
Precision: 0.0374, Recall: 0.0756, Revenue: 0.1660
Checking top 10% suspicious transactions: 27478
Precision: 0.1083, Recall: 0.4378, Revenue: 0.3609


In [7]:
stage = "train_lab"
trainLab_data = gdata.get_data(stage)
train_nodeidx = torch.tensor(gdata.get_AttNode(stage))
trainLab_data.node_idx = train_nodeidx

In [8]:
stage = "train_unlab"
unlab_data = gdata.get_data(stage)
unlab_nodeidx = torch.tensor(gdata.get_AttNode(stage))
unlab_data.node_idx = unlab_nodeidx

In [9]:
stage = "valid"
valid_data = gdata.get_data(stage)
valid_nodeidx = torch.tensor(gdata.get_AttNode(stage))
valid_data.node_idx = valid_nodeidx

In [10]:
stage = "test"
test_data = gdata.get_data(stage)
test_nodeidx = torch.tensor(gdata.get_AttNode(stage))
test_data.node_idx = test_nodeidx

In [11]:
trainLab_data

Data(edge_attr=[60604], edge_index=[2, 121208], edge_label=[121208], node_idx=[30302], rev=[45573], x=[45573, 100], y=[45573])

In [12]:
valid_data

Data(edge_attr=[287018], edge_index=[2, 574036], edge_label=[574036], node_idx=[143509], rev=[153663], x=[153663, 100], y=[153663])

In [13]:
test_data

Data(edge_attr=[549616], edge_index=[2, 1099232], edge_label=[1099232], node_idx=[274808], rev=[292483], x=[292483, 100], y=[292483])

## Model

In [14]:
from models import *

In [15]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchtools.optim import RangerLars
from pytorch_lightning.loggers import TensorBoardLogger
import torch

In [16]:
class SSLGNN(LightningModule):
    def __init__(self, data, input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.save_hyperparameters()
        self.data = data
        self.input_dim = input_dim
        self.dim = hidden_dim*2
        self.numLayers = numLayers
        self.layers = [self.dim, self.dim//2] #* (numLayers+1)
        self.bn = nn.BatchNorm1d(self.dim)
        self.act = Mish()
        self.useXGB = useXGB
        
        # GNN layer
        if self.useXGB:
            self.initEmbedding = nn.Embedding(self.input_dim, self.dim, padding_idx=0)
        else:
            self.initEmbedding = MLP(self.input_dim, self.dim, Numlayer=2)
        self.initGNN = UselessConv()
        self.GNNs = GNNStack(self.layers,self.numLayers)
        
        # output
        self.outLayer = nn.ModuleList([nn.Linear(self.dim,1) for _ in range(numLayers+1)])
        self.revLayer = nn.ModuleList([nn.Linear(self.dim,1) for _ in range(numLayers+1)])
        self.combined = nn.Linear(numLayers+1, 1, bias=False)
        self.combinedRev = nn.Linear(numLayers+1, 1, bias=False)
        self.loss_func = FocalLoss(logits=True)

    def forward(self, x,adjs):
        # update node embedding
        leaf_emb = self.initEmbedding(x)
        if self.useXGB:
            leaf_emb = torch.sum(leaf_emb,dim=1) # summation over the trees
            leaf_emb = self.bn(leaf_emb)
            leaf_emb = self.act(leaf_emb)
        
        # first update 
        firstHop_neighbor = adjs[-2][0]
        leaf_emb = self.initGNN(leaf_emb,to_undirected(firstHop_neighbor))
        
        # GNN 
        embeddings = self.GNNs(leaf_emb, adjs)
        
        # logits
        logits = [self.outLayer[i](v) for i,v in enumerate(embeddings)]
        ensemble = torch.cat(logits, dim=-1)
        ensemble = self.combined(ensemble)
        logits.append(ensemble)
        
        # revenue
        revenues = [torch.relu(self.revLayer[i](v)) for i,v in enumerate(embeddings)]
        ensemble = torch.cat(revenues, dim=-1)
        ensemble = self.combinedRev(ensemble)
        revenues.append(ensemble)
        
        return logits, revenues
    
    def compute_CLS_loss(self,Logits, label):
        loss = 0
        for logit in Logits:
            logit = logit.flatten()
            l = self.loss_func(logit,label)
            loss+= l
        return loss
    
    def compute_REG_loss(self,preds, rev):
        loss = 0
        for pred in preds:
            pred = pred.flatten()
            l = F.mse_loss(pred,rev)
            loss += l
        return loss 

    def training_step(self, batch, batch_idx: int):
        target_idx = torch.arange(batch.y.shape[0])
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        train_loss = CLS_loss + 10 * REG_loss
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx: int):
        target_idx = torch.arange(batch.y.shape[0])
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        valid_loss = CLS_loss + 1 * REG_loss
        self.log('val_loss', valid_loss, on_step=True, on_epoch=True, sync_dist=True)
        return logits[-1]
    
    def validation_epoch_end(self, val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.valid_cls_label, self.data.valid_reg_label)
        f1_top = np.mean(f)
        self.log("F1-top",f1_top)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Val/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        return {"Val/F1-top":f1_top, "log":tensorboard_logs}
        
    def test_step(self,batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    
    def test_epoch_end(self,val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.test_cls_label, self.data.test_reg_label)
        f1_top = np.mean(f)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Test/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        
        return {"F1-top":f1_top, "log":tensorboard_logs}

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.01, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.99)
        return [optimizer], [scheduler]


In [45]:
class SSLGNN(LightningModule):
    def __init__(self,input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.save_hyperparameters()
        self.input_dim = input_dim
        self.dim = hidden_dim*2
        self.numLayers = numLayers
        self.layers = [self.dim, self.dim//2] #* (numLayers+1)
        self.bn = nn.BatchNorm1d(self.dim)
        self.act = Mish()
        self.useXGB = useXGB
        
        # GNN layer
        if self.useXGB:
            self.initEmbedding = nn.Embedding(self.input_dim, self.dim, padding_idx=0)
        else:
            self.initEmbedding = MLP(self.input_dim, self.dim, Numlayer=2)
        self.initGNN = UselessConv()
        self.GNNs = GNNStack(self.layers,self.numLayers)
        
        # output
        self.outLayer = nn.ModuleList([nn.Linear(self.dim,1) for _ in range(numLayers+1)])
        self.revLayer = nn.ModuleList([nn.Linear(self.dim,1) for _ in range(numLayers+1)])
        self.combined = nn.Linear(numLayers+1, 1, bias=False)
        self.combinedRev = nn.Linear(numLayers+1, 1, bias=False)
        self.loss_func = FocalLoss(logits=True)

    def forward(self, x,adjs):
        # update node embedding
        leaf_emb = self.initEmbedding(x)
        if self.useXGB:
            leaf_emb = torch.sum(leaf_emb,dim=1) # summation over the trees
            leaf_emb = self.bn(leaf_emb)
            leaf_emb = self.act(leaf_emb)
        
        # first update 
        firstHop_neighbor = adjs[-2][0]
        leaf_emb = self.initGNN(leaf_emb,to_undirected(firstHop_neighbor))
        
        # GNN 
        embeddings = self.GNNs(leaf_emb, adjs)
        
        # logits
        logits = [self.outLayer[i](v) for i,v in enumerate(embeddings)]
        ensemble = torch.cat(logits, dim=-1)
        ensemble = self.combined(ensemble)
        logits.append(ensemble)
        
        # revenue
        revenues = [torch.relu(self.revLayer[i](v)) for i,v in enumerate(embeddings)]
        ensemble = torch.cat(revenues, dim=-1)
        ensemble = self.combinedRev(ensemble)
        revenues.append(ensemble)
        
        return logits, revenues
    
    def compute_CLS_loss(self,Logits, label):
        loss = 0
        for logit in Logits:
            logit = logit.flatten()
            l = self.loss_func(logit,label)
            loss+= l
        return loss
    
    def compute_REG_loss(self,preds, rev):
        loss = 0
        for pred in preds:
            pred = pred.flatten()
            l = F.mse_loss(pred,rev)
            loss += l
        return loss 

    def training_step(self, batch, batch_idx: int):
        target_idx = torch.arange(batch.y.shape[0])
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        train_loss = CLS_loss + 10 * REG_loss
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx: int):
        target_idx = torch.arange(batch.y.shape[0])
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        valid_loss = CLS_loss + 1 * REG_loss
        self.log('val_loss', valid_loss, on_step=True, on_epoch=True, sync_dist=True)
        return logits[-1]
    
    def validation_epoch_end(self, val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.valid_cls_label, self.data.valid_reg_label)
        f1_top = np.mean(f)
        self.log("F1-top",f1_top)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Val/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        return {"Val/F1-top":f1_top, "log":tensorboard_logs}
        
    def test_step(self,batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    
    def test_epoch_end(self,val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.test_cls_label, self.data.test_reg_label)
        f1_top = np.mean(f)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Test/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        
        return {"F1-top":f1_top, "log":tensorboard_logs}

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.05, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.99)
        return [optimizer], [scheduler]

In [46]:
# model config
seed_everything(1234)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [-1,200]
numLayers = len(sizes)
batch_size = 512

model = SSLGNN(input_dim, hidden_size, numLayers, useXGB=gdata.use_xgb)
model.data = data

# lightning config
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)
datamodule = CustomData(stacked_data, sizes = sizes, batch_size=batch_size)
logger = TensorBoardLogger("ssl_exp",name="SSL_GNN")
logger.log_hyperparams(model.hparams, metrics={"F1-top":0})
checkpoint_callback = ModelCheckpoint(
    monitor='F1-top',    
    dirpath='./saved_model',
    filename='GNN-{epoch:02d}-{F1-top:.4f}',
    save_top_k=1,
    mode='max',
)
trainer = Trainer(gpus=[3], max_epochs=20,
                  logger = logger,
                 num_sanity_val_steps=0,
                  check_val_every_n_epoch=1,
                  callbacks=[checkpoint_callback],
#                   fast_dev_run=True
                 )
trainer.fit(model, datamodule=datamodule)

GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: None, using: 0 TPU cores
INFO:lightning:TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type        | Params
----------------------------------------------
0 | bn            | BatchNorm1d | 128   
1 | act           | Mish        | 0     
2 | initEmbedding | Embedding   | 74.0 K
3 | initGNN       | UselessConv | 0     
4 | GNNs          | GNNStack    | 33.5 K
5 | outLayer      | ModuleList  | 195   
6 | revLayer      | ModuleList  | 195   
7 | combined      | Linear      | 3     
8 | combinedRev   | Linear      | 3     
9 | loss_func     | FocalLoss   | 0     
----------------------------------------------
108 K     Trainable params
0         Non-trainable params
108 K     Total params
INFO:lightning:
  | Name          | Type        | Params
-------------------------------

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1436
Precision: 0.0815, Recall: 0.0216, Revenue: 0.0146
Checking top 2% suspicious transactions: 2871
Precision: 0.0662, Recall: 0.0351, Revenue: 0.0215
Checking top 5% suspicious transactions: 7175
Precision: 0.0461, Recall: 0.0612, Revenue: 0.1087
Checking top 10% suspicious transactions: 14351
Precision: 0.0427, Recall: 0.1133, Revenue: 0.2013


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1436
Precision: 0.0843, Recall: 0.0224, Revenue: 0.0387
Checking top 2% suspicious transactions: 2871
Precision: 0.0711, Recall: 0.0377, Revenue: 0.0851
Checking top 5% suspicious transactions: 7176
Precision: 0.0492, Recall: 0.0652, Revenue: 0.1267
Checking top 10% suspicious transactions: 14351
Precision: 0.0415, Recall: 0.1099, Revenue: 0.2488


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1436
Precision: 0.0884, Recall: 0.0235, Revenue: 0.0405
Checking top 2% suspicious transactions: 2871
Precision: 0.0714, Recall: 0.0379, Revenue: 0.0853
Checking top 5% suspicious transactions: 7176
Precision: 0.0497, Recall: 0.0660, Revenue: 0.1371
Checking top 10% suspicious transactions: 14351
Precision: 0.0419, Recall: 0.1112, Revenue: 0.2414


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

Checking top 1% suspicious transactions: 2749<br>
Precision: 0.1124, Recall: 0.0455, Revenue: 0.0782<br>
Checking top 2% suspicious transactions: 5497<br>
Precision: 0.0708, Recall: 0.0572, Revenue: 0.1080<br>
Checking top 5% suspicious transactions: 13741<br>
Precision: 0.0374, Recall: 0.0756, Revenue: 0.1660<br>
Checking top 10% suspicious transactions: 27478<br>
Precision: 0.1083, Recall: 0.4378, Revenue: 0.3609

In [32]:
checkpoint_callback.best_model_path

'/dhome/roytsai/gnn_wco/saved_model/GNN-epoch=15-F1-top=0.0991.ckpt'

In [33]:
testing_summary = trainer.test(model, test_dataloaders=datamodule.test_dataloader())

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

Checking top 1% suspicious transactions: 2749
Precision: 0.1131, Recall: 0.0458, Revenue: 0.0792
Checking top 2% suspicious transactions: 5497
Precision: 0.0737, Recall: 0.0596, Revenue: 0.1121
Checking top 5% suspicious transactions: 13741
Precision: 0.0595, Recall: 0.1202, Revenue: 0.1867
Checking top 10% suspicious transactions: 27481
Precision: 0.0785, Recall: 0.3175, Revenue: 0.2862

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'F1-top': 0.08412887078527967,
 'Test/F1@1': 0.0651581814372512,
 'Test/F1@10': 0.12591166345761126,
 'Test/F1@2': 0.06588579795021962,
 'Test/F1@5': 0.0795598402960366,
 'Test/Pr@1': 0.11313204801746089,
 'Test/Pr@10': 0.07852698227866525,
 'Test/Pr@2': 0.07367655084591596,
 'Test/Pr@5': 0.05945709919219853,
 'Test/Re@1': 0.0457554803589819,
 'Test/Re@10': 0.31749301162277477,
 'Test/Re@2': 0.05958511107841695,
 'Test/Re@5': 0.12020008827423863,
 'Test/Rev@1': 0.0791594356297533,
 'Test/Rev@10'

In [None]:
# transform summary as dataframe for saving csv
# df_summary = pd.DataFrame(testing_summary)
# df_summary = df_summary[[i for i in df_summary.columns if "Test" in i]]
# df_summary