In [1]:
import dataset
import datetime
from datetime import timedelta
from parser import get_parser
import numpy as np 
import pandas as pd 
import torch
from pytorch_lightning import Trainer
from torch_geometric.utils import to_undirected

## Load data

In [2]:
data = dataset.Ndata(path='../Custom-Semi-Supervised/data/ndata.csv')
parser = get_parser()
args = parser.parse_args(args=
                         ["--data","real-n", 
                          "--sampling","xgb",
                          "--mode","scratch",
                          "--train_from","20140101",
                          "--test_from","20170101",
                          "--test_length","365",
                          "--valid_length","90",
                          "--initial_inspection_rate", "5",
                          "--final_inspection_rate", "10",
                         ])

In [3]:
# args
seed = args.seed
epochs = args.epoch
dim = args.dim
lr = args.lr
weight_decay = args.l2
initial_inspection_rate = args.initial_inspection_rate
inspection_rate_option = args.inspection_plan
mode = args.mode
train_begin = args.train_from 
test_begin = args.test_from
test_length = args.test_length
valid_length = args.valid_length
chosen_data = args.data
numWeeks = args.numweeks
semi_supervised = args.semi_supervised
save = args.save
gpu_id = args.device

# Initial dataset split
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Initial dataset split
train_start_day = datetime.date(int(train_begin[:4]), int(train_begin[4:6]), int(train_begin[6:8]))
test_start_day = datetime.date(int(test_begin[:4]), int(test_begin[4:6]), int(test_begin[6:8]))
test_length = timedelta(days=test_length)    
test_end_day = test_start_day + test_length
valid_length = timedelta(days=valid_length)
valid_start_day = test_start_day - valid_length

# data
data.split(train_start_day, valid_start_day, test_start_day, test_end_day, valid_length, test_length, args)
data.featureEngineering()

Data size:
Train labeled: (54134, 52), Train unlabeled: (1028538, 52), Valid labeled: (70917, 52), Valid unlabeled: (0, 26), Test: (274808, 52)
Checking label distribution
Training: 0.05022795615481618
Validation: 0.035556788645191434
Testing: 0.025360899366070794


## Prepare DATA

In [4]:
from utils import *
from pygData_util import *

In [5]:
categories=["importer.id","HS6"]
gdata = GraphData(data,use_xgb=True, categories=categories)

Training XGBoost model...


In [6]:
best_thresh, best_auc = find_best_threshold(gdata.xgb,data.dfvalidx_lab, data.valid_cls_label)
xgb_test_pred = gdata.xgb.predict_proba(data.dfvalidx_lab)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.valid_cls_label,data.valid_reg_label,best_thresh)
print("-"*50)
xgb_test_pred = gdata.xgb.predict_proba(data.dftestx)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.test_cls_label,data.test_reg_label,best_thresh)

Checking top 1% suspicious transactions: 710
Precision: 0.2915, Recall: 0.0850, Revenue: 0.0443
Checking top 2% suspicious transactions: 1419
Precision: 0.1677, Recall: 0.0977, Revenue: 0.0501
Checking top 5% suspicious transactions: 3546
Precision: 0.0832, Recall: 0.1211, Revenue: 0.0844
Checking top 10% suspicious transactions: 7092
Precision: 0.0495, Recall: 0.1441, Revenue: 0.1336
--------------------------------------------------
Checking top 1% suspicious transactions: 2749
Precision: 0.1404, Recall: 0.0568, Revenue: 0.0942
Checking top 2% suspicious transactions: 5497
Precision: 0.0871, Recall: 0.0705, Revenue: 0.1269
Checking top 5% suspicious transactions: 13741
Precision: 0.0442, Recall: 0.0893, Revenue: 0.1859
Checking top 10% suspicious transactions: 27481
Precision: 0.0888, Recall: 0.3590, Revenue: 0.3713


In [7]:
stage = "train_lab"
trainLab_data = gdata.get_data(stage)
train_nodeidx = torch.tensor(gdata.get_AttNode(stage))
trainLab_data.node_idx = train_nodeidx

In [8]:
stage = "train_unlab"
unlab_data = gdata.get_data(stage)
unlab_nodeidx = torch.tensor(gdata.get_AttNode(stage))
unlab_data.node_idx = unlab_nodeidx

In [9]:
stage = "valid"
valid_data = gdata.get_data(stage)
valid_nodeidx = torch.tensor(gdata.get_AttNode(stage))
valid_data.node_idx = valid_nodeidx

In [10]:
stage = "test"
test_data = gdata.get_data(stage)
test_nodeidx = torch.tensor(gdata.get_AttNode(stage))
test_data.node_idx = test_nodeidx

In [11]:
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)

## New Sampler

In [12]:
from torch_cluster import random_walk
from torch_geometric.data import NeighborSampler 
from pytorch_lightning import LightningDataModule

In [13]:
class UnsupSampler(NeighborSampler):
    def sample(self, batch):
        batch = torch.tensor(batch)
        row, col, _ = self.adj_t.coo()

        # For each node in `batch`, we sample a direct neighbor (as positive
        # example) and a random node (as negative example):
        pos_batch = random_walk(row, col, batch, walk_length=1,
                                coalesced=False)[:, 1]

        neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                                  dtype=torch.long)

        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super(UnsupSampler, self).sample(batch)

In [14]:
class Batch(NamedTuple):
    '''
    convert batch data for pytorch-lightning
    '''
    x: Tensor
    y: Tensor
    rev: Tensor
    adjs_t: NamedTuple
    def to(self, *args, **kwargs):
        return Batch(
            x=self.x.to(*args, **kwargs),
            y=self.y.to(*args, **kwargs),
            rev=self.rev.to(*args, **kwargs),
            adjs_t=[(adj_t.to(*args, **kwargs), eid.to(*args, **kwargs), size) for adj_t, eid, size in self.adjs_t],
        )


In [15]:
class UnsupData(LightningDataModule):
    def __init__(self,data,sizes, batch_size = 128):
        '''
        defining dataloader with NeighborSampler to extract k-hop subgraph.
        Args:
            data (Graphdata): graph data for the edges and node index
            sizes ([int]): The number of neighbors to sample for each node in each layer. 
                           If set to :obj:`sizes[l] = -1`, all neighbors are included
            batch_size (int): batch size for training
        '''
        super(UnsupData,self).__init__()
        self.data = data
        self.sizes = sizes
        self.valid_sizes = [-1 for i in self.sizes]
        self.batch_size = batch_size

    def train_dataloader(self):
        return UnsupSampler(self.data.test_edge, sizes=self.sizes,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=True,num_workers=8)
    
    def test_dataloader(self):
        return UnsupSampler(self.data.test_edge, sizes=self.sizes,node_idx=self.data.test_idx,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=False,num_workers=8)
    
    def label_loader(self):
        return UnsupSampler(self.data.test_edge, sizes=self.sizes,node_idx=self.data.train_idx,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=False,num_workers=8)

    def convert_batch(self, batch_size, n_id, adjs):
        return Batch(
            x=self.data.x[n_id],
            y=self.data.y[n_id[:batch_size]],
            rev = self.data.rev[n_id[:batch_size]],
            adjs_t=adjs,
        )

## Model

In [16]:
from models import MLP, GNNStack, UselessConv, Mish

In [17]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import LightningModule, seed_everything
from torchtools.optim import RangerLars
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class PretrainGNN(LightningModule):
    def __init__(self,input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.save_hyperparameters()
        self.input_dim = input_dim
        self.dim = hidden_dim*2
        self.numLayers = numLayers
        self.layers = [self.dim, self.dim//2] #* (numLayers+1)
        self.bn = nn.BatchNorm1d(self.dim)
        self.act = Mish()
        self.useXGB = useXGB
        
        # GNN layer
        if self.useXGB:
            self.initEmbedding = nn.Embedding(self.input_dim, self.dim, padding_idx=0)
        else:
            self.initEmbedding = MLP(self.input_dim, self.dim, Numlayer=2)
        self.initGNN = UselessConv()
        self.GNNs = GNNStack(self.layers,self.numLayers)

    def forward(self, x,adjs):
        # update node embedding
        leaf_emb = self.initEmbedding(x)
        if self.useXGB:
            leaf_emb = torch.sum(leaf_emb,dim=1) # summation over the trees
            leaf_emb = self.bn(leaf_emb)
            leaf_emb = self.act(leaf_emb)
        
        # first update 
        firstHop_neighbor = adjs[-1][0]
        leaf_emb = self.initGNN(leaf_emb,to_undirected(firstHop_neighbor))
        
        # GNN 
        embeddings = self.GNNs(leaf_emb, adjs)
        
        return embeddings[-1]
    
    def training_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
        neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
        train_loss = -pos_loss - neg_loss
        self.log('train_loss', train_loss)
        return train_loss
    
    def test_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        return out
    
    def test_epoch_end(self, val_step_outputs):
        val_step_outputs = torch.cat(val_step_outputs)
        val_step_outputs = val_step_outputs.cpu().detach().numpy()
        return {"log":{"predictions":val_step_outputs}}
    
    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.01, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.8)
        return [optimizer], [scheduler]

In [19]:
# model config
seed_everything(2345)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [50,20]
numLayers = len(sizes)
batch_size = 1024

model = PretrainGNN(input_dim, hidden_size, numLayers, useXGB=gdata.use_xgb)

# lightning config
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)
datamodule = UnsupData(stacked_data, sizes = sizes, batch_size=batch_size)
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',    
    dirpath='./saved_model',
    filename='Tdata5-pretrain-{train_loss:.4f}',
    save_top_k=1,
    mode='min',
)
trainer = Trainer(gpus=[0], max_epochs=3,
#                   callbacks=[checkpoint_callback],
                 )
trainer.fit(model, train_dataloader=datamodule.train_dataloader(),
           )

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type        | Params
----------------------------------------------
0 | bn            | BatchNorm1d | 128   
1 | act           | Mish        | 0     
2 | initEmbedding | Embedding   | 86.1 K
3 | initGNN       | UselessConv | 0     
4 | GNNs          | GNNStack    | 33.5 K
----------------------------------------------
119 K     Trainable params
0         Non-trainable params
119 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1

In [20]:
trainer.test()

1

In [21]:
# trainer.save_checkpoint("./saved_model/pretrained.ckpt")

In [22]:
from tqdm import tqdm_notebook

In [23]:
test_embeddings = []
for batch in tqdm_notebook(datamodule.test_dataloader()):
    batch = batch.to(model.device)
    out = model(batch.x, batch.adjs_t)
    out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
    test_embeddings.append(out.cpu().detach().numpy())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=269.0), HTML(value='')))




In [24]:
embeddings = []
for batch in tqdm_notebook(datamodule.label_loader()):
    batch = batch.to(model.device)
    out = model(batch.x, batch.adjs_t)
    out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
    embeddings.append(out.cpu().detach().numpy())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=53.0), HTML(value='')))




In [25]:
embeddings = np.concatenate(embeddings)
test_embeddings = np.concatenate(test_embeddings)

In [26]:
from sklearn.linear_model import LogisticRegression

In [30]:
lr = LogisticRegression(class_weight={1:50,0:1})
lr.fit(embeddings,data.train_cls_label)

LogisticRegression(class_weight={0: 1, 1: 50})

In [31]:
y_prob = lr.predict_proba(test_embeddings)[:,1]

In [32]:
_ = metrics(y_prob, data.test_cls_label,data.test_reg_label,None)

Checking top 1% suspicious transactions: 2749
Precision: 0.0382, Recall: 0.0154, Revenue: 0.0240
Checking top 2% suspicious transactions: 5497
Precision: 0.0326, Recall: 0.0263, Revenue: 0.0348
Checking top 5% suspicious transactions: 13741
Precision: 0.0313, Recall: 0.0633, Revenue: 0.0959
Checking top 10% suspicious transactions: 27481
Precision: 0.0309, Recall: 0.1248, Revenue: 0.2028


## Fine tuning

In [33]:
from torch_geometric.nn import GATConv,TransformerConv
from pygData_util import *

In [34]:
class Predictor(LightningModule):
    def __init__(self,input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.gnn_encoder = PretrainGNN(input_dim, hidden_size, numLayers, useXGB)
        self.dim = hidden_dim * 2
        
        # output
        self.clsLayer = nn.Linear(self.dim,1) #GATConv(self.dim,1)
        self.revLayer = nn.Linear(self.dim,1) #GATConv(self.dim,1)
        self.loss_func = nn.BCEWithLogitsLoss(pos_weight = torch.tensor([20])) #FocalLoss(logits=True)
        
    def load_fromPretrain(self,path):
        self.gnn_encoder.load_from_checkpoint(path)
        
    def loadGNN_state(self,model):
        self.gnn_encoder.load_state_dict(model.state_dict())

    def forward(self, x,adjs):
#         firstHop_neighbor = adjs[-1][0]
        # get node embedding from pre-trained model
        embedding = self.gnn_encoder(x,adjs)
        logit = self.clsLayer(embedding)
        revenue = self.revLayer(embedding)
        
        return logit, revenue
    
    def compute_CLS_loss(self,logit, label):
        logit = logit.flatten()
        loss = self.loss_func(logit,label)
        return loss
    
    def compute_REG_loss(self,pred_rev, rev):
        pred_rev = pred_rev.flatten()
        loss = F.mse_loss(pred_rev,rev)
        return loss 

    def training_step(self, batch, batch_idx: int):
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        train_loss = CLS_loss + 10 * REG_loss
        self.log('train_loss', train_loss)
        return train_loss
    
    def validation_step(self, batch, batch_idx: int):
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        valid_loss = CLS_loss + 1 * REG_loss
        self.log('val_loss', valid_loss, on_step=True, on_epoch=True, sync_dist=True)
        return logits
    
    def validation_epoch_end(self, val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.valid_cls_label, self.data.valid_reg_label)
        f1_top = np.mean(f)
        self.log("F1-top",f1_top)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Val/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        return {"Val/F1-top":f1_top, "log":tensorboard_logs}
        
    def test_step(self,batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    
    def test_epoch_end(self,val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.test_cls_label, self.data.test_reg_label)
        f1_top = np.mean(f)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Test/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        
        return {"F1-top":f1_top, "log":tensorboard_logs}

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.05, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.99)
        return [optimizer], [scheduler]

In [35]:
# model config
seed_everything(5674)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [-1,200]
numLayers = len(sizes)
batch_size = 512
pretrain_path = "./saved_model/pretrained.ckpt"

predictor = Predictor(input_dim, hidden_size, numLayers)
predictor.loadGNN_state(model)
# predictor.load_fromPretrain(pretrain_path)
predictor.data = data

# lightning config
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)
datamodule = CustomData(stacked_data, sizes = sizes, batch_size=batch_size)
logger = TensorBoardLogger("ssl_exp",name="Tdata")
logger.log_hyperparams(model.hparams, metrics={"F1-top":0})
checkpoint_callback = ModelCheckpoint(
    monitor='F1-top',    
    dirpath='./saved_model',
    filename='Analysis-Tdata5-{F1-top:.4f}',
    save_top_k=1,
    mode='max',
)
trainer = Trainer(gpus=[0], max_epochs=40,
                 num_sanity_val_steps=0,
                  check_val_every_n_epoch=1,
                  callbacks=[checkpoint_callback],
                 )
trainer.fit(predictor, datamodule=datamodule)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type              | Params
--------------------------------------------------
0 | gnn_encoder | PretrainGNN       | 119 K 
1 | clsLayer    | Linear            | 65    
2 | revLayer    | Linear            | 65    
3 | loss_func   | BCEWithLogitsLoss | 0     
--------------------------------------------------
119 K     Trainable params
0         Non-trainable params
119 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.1380, Recall: 0.0402, Revenue: 0.0147
Checking top 2% suspicious transactions: 1419
Precision: 0.0930, Recall: 0.0542, Revenue: 0.0251
Checking top 5% suspicious transactions: 3545
Precision: 0.0663, Recall: 0.0965, Revenue: 0.1236
Checking top 10% suspicious transactions: 7092
Precision: 0.0544, Recall: 0.1585, Revenue: 0.1886


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.1310, Recall: 0.0382, Revenue: 0.0204
Checking top 2% suspicious transactions: 1419
Precision: 0.1015, Recall: 0.0591, Revenue: 0.0331
Checking top 5% suspicious transactions: 3546
Precision: 0.0795, Recall: 0.1158, Revenue: 0.0877
Checking top 10% suspicious transactions: 7092
Precision: 0.0649, Recall: 0.1889, Revenue: 0.1925


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.1620, Recall: 0.0472, Revenue: 0.0301
Checking top 2% suspicious transactions: 1419
Precision: 0.1276, Recall: 0.0743, Revenue: 0.0502
Checking top 5% suspicious transactions: 3546
Precision: 0.0902, Recall: 0.1314, Revenue: 0.0957
Checking top 10% suspicious transactions: 7092
Precision: 0.0623, Recall: 0.1815, Revenue: 0.2061


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2000, Recall: 0.0583, Revenue: 0.0398
Checking top 2% suspicious transactions: 1418
Precision: 0.1474, Recall: 0.0858, Revenue: 0.0484
Checking top 5% suspicious transactions: 3546
Precision: 0.0911, Recall: 0.1326, Revenue: 0.0764
Checking top 10% suspicious transactions: 7090
Precision: 0.0674, Recall: 0.1963, Revenue: 0.1528


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.1831, Recall: 0.0534, Revenue: 0.0384
Checking top 2% suspicious transactions: 1419
Precision: 0.1388, Recall: 0.0809, Revenue: 0.0484
Checking top 5% suspicious transactions: 3546
Precision: 0.0919, Recall: 0.1339, Revenue: 0.0985
Checking top 10% suspicious transactions: 7092
Precision: 0.0666, Recall: 0.1938, Revenue: 0.1669


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2000, Recall: 0.0583, Revenue: 0.0422
Checking top 2% suspicious transactions: 1419
Precision: 0.1416, Recall: 0.0825, Revenue: 0.0527
Checking top 5% suspicious transactions: 3546
Precision: 0.0902, Recall: 0.1314, Revenue: 0.0984
Checking top 10% suspicious transactions: 7092
Precision: 0.0588, Recall: 0.1713, Revenue: 0.2093


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2493, Recall: 0.0727, Revenue: 0.0441
Checking top 2% suspicious transactions: 1419
Precision: 0.1586, Recall: 0.0924, Revenue: 0.0504
Checking top 5% suspicious transactions: 3546
Precision: 0.1001, Recall: 0.1458, Revenue: 0.1137
Checking top 10% suspicious transactions: 7092
Precision: 0.0919, Recall: 0.2678, Revenue: 0.1948


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2352, Recall: 0.0686, Revenue: 0.0448
Checking top 2% suspicious transactions: 1419
Precision: 0.1663, Recall: 0.0969, Revenue: 0.0522
Checking top 5% suspicious transactions: 3546
Precision: 0.1100, Recall: 0.1602, Revenue: 0.1223
Checking top 10% suspicious transactions: 7092
Precision: 0.1052, Recall: 0.3064, Revenue: 0.1927


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2634, Recall: 0.0768, Revenue: 0.0467
Checking top 2% suspicious transactions: 1419
Precision: 0.1783, Recall: 0.1039, Revenue: 0.0547
Checking top 5% suspicious transactions: 3546
Precision: 0.1111, Recall: 0.1618, Revenue: 0.1243
Checking top 10% suspicious transactions: 7092
Precision: 0.1103, Recall: 0.3211, Revenue: 0.2192


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2577, Recall: 0.0752, Revenue: 0.0468
Checking top 2% suspicious transactions: 1419
Precision: 0.1705, Recall: 0.0994, Revenue: 0.0533
Checking top 5% suspicious transactions: 3546
Precision: 0.1043, Recall: 0.1520, Revenue: 0.1194
Checking top 10% suspicious transactions: 7091
Precision: 0.0966, Recall: 0.2813, Revenue: 0.2569


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2986, Recall: 0.0871, Revenue: 0.0475
Checking top 2% suspicious transactions: 1419
Precision: 0.1755, Recall: 0.1023, Revenue: 0.0550
Checking top 5% suspicious transactions: 3545
Precision: 0.1044, Recall: 0.1520, Revenue: 0.1215
Checking top 10% suspicious transactions: 7092
Precision: 0.1004, Recall: 0.2924, Revenue: 0.2458


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3155, Recall: 0.0920, Revenue: 0.0523
Checking top 2% suspicious transactions: 1419
Precision: 0.1994, Recall: 0.1162, Revenue: 0.0695
Checking top 5% suspicious transactions: 3546
Precision: 0.1385, Recall: 0.2016, Revenue: 0.1404
Checking top 10% suspicious transactions: 7092
Precision: 0.1348, Recall: 0.3926, Revenue: 0.2458


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3155, Recall: 0.0920, Revenue: 0.0632
Checking top 2% suspicious transactions: 1419
Precision: 0.2008, Recall: 0.1170, Revenue: 0.0842
Checking top 5% suspicious transactions: 3546
Precision: 0.1351, Recall: 0.1967, Revenue: 0.1432
Checking top 10% suspicious transactions: 7092
Precision: 0.1428, Recall: 0.4160, Revenue: 0.2689


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3437, Recall: 0.1002, Revenue: 0.0644
Checking top 2% suspicious transactions: 1419
Precision: 0.2121, Recall: 0.1236, Revenue: 0.0877
Checking top 5% suspicious transactions: 3546
Precision: 0.1277, Recall: 0.1860, Revenue: 0.1340
Checking top 10% suspicious transactions: 7092
Precision: 0.1304, Recall: 0.3799, Revenue: 0.2465


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3070, Recall: 0.0895, Revenue: 0.0605
Checking top 2% suspicious transactions: 1416
Precision: 0.2161, Recall: 0.1257, Revenue: 0.0860
Checking top 5% suspicious transactions: 3546
Precision: 0.1421, Recall: 0.2070, Revenue: 0.1459
Checking top 10% suspicious transactions: 7092
Precision: 0.1434, Recall: 0.4177, Revenue: 0.2656


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.3300, Recall: 0.0961, Revenue: 0.0695
Checking top 2% suspicious transactions: 1419
Precision: 0.2276, Recall: 0.1326, Revenue: 0.0883
Checking top 5% suspicious transactions: 3546
Precision: 0.1475, Recall: 0.2148, Revenue: 0.1474
Checking top 10% suspicious transactions: 7092
Precision: 0.1387, Recall: 0.4041, Revenue: 0.2562


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3338, Recall: 0.0973, Revenue: 0.0723
Checking top 2% suspicious transactions: 1419
Precision: 0.2333, Recall: 0.1359, Revenue: 0.0896
Checking top 5% suspicious transactions: 3546
Precision: 0.1892, Recall: 0.2756, Revenue: 0.1714
Checking top 10% suspicious transactions: 7092
Precision: 0.1803, Recall: 0.5253, Revenue: 0.3190


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3676, Recall: 0.1072, Revenue: 0.0785
Checking top 2% suspicious transactions: 1419
Precision: 0.2650, Recall: 0.1544, Revenue: 0.0954
Checking top 5% suspicious transactions: 3546
Precision: 0.2177, Recall: 0.3170, Revenue: 0.1962
Checking top 10% suspicious transactions: 7080
Precision: 0.1919, Recall: 0.5581, Revenue: 0.3252


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.3540, Recall: 0.1031, Revenue: 0.0787
Checking top 2% suspicious transactions: 1419
Precision: 0.2459, Recall: 0.1433, Revenue: 0.0941
Checking top 5% suspicious transactions: 3546
Precision: 0.1748, Recall: 0.2546, Revenue: 0.1763
Checking top 10% suspicious transactions: 7090
Precision: 0.1640, Recall: 0.4776, Revenue: 0.2812


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3394, Recall: 0.0990, Revenue: 0.0746
Checking top 2% suspicious transactions: 1419
Precision: 0.2438, Recall: 0.1421, Revenue: 0.0920
Checking top 5% suspicious transactions: 3545
Precision: 0.1800, Recall: 0.2620, Revenue: 0.1791
Checking top 10% suspicious transactions: 7092
Precision: 0.1600, Recall: 0.4661, Revenue: 0.3016


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3648, Recall: 0.1064, Revenue: 0.0763
Checking top 2% suspicious transactions: 1419
Precision: 0.2622, Recall: 0.1528, Revenue: 0.0948
Checking top 5% suspicious transactions: 3546
Precision: 0.2239, Recall: 0.3261, Revenue: 0.1990
Checking top 10% suspicious transactions: 7092
Precision: 0.1827, Recall: 0.5322, Revenue: 0.3072


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3690, Recall: 0.1076, Revenue: 0.0767
Checking top 2% suspicious transactions: 1419
Precision: 0.2777, Recall: 0.1618, Revenue: 0.1007
Checking top 5% suspicious transactions: 3546
Precision: 0.2403, Recall: 0.3499, Revenue: 0.2077
Checking top 10% suspicious transactions: 7092
Precision: 0.1923, Recall: 0.5602, Revenue: 0.3408


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 707
Precision: 0.3564, Recall: 0.1035, Revenue: 0.0763
Checking top 2% suspicious transactions: 1419
Precision: 0.2685, Recall: 0.1565, Revenue: 0.1003
Checking top 5% suspicious transactions: 3546
Precision: 0.2343, Recall: 0.3413, Revenue: 0.2041
Checking top 10% suspicious transactions: 7092
Precision: 0.1936, Recall: 0.5639, Revenue: 0.3306


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3845, Recall: 0.1121, Revenue: 0.0807
Checking top 2% suspicious transactions: 1419
Precision: 0.2925, Recall: 0.1704, Revenue: 0.1043
Checking top 5% suspicious transactions: 3546
Precision: 0.2504, Recall: 0.3647, Revenue: 0.2135
Checking top 10% suspicious transactions: 7092
Precision: 0.1973, Recall: 0.5745, Revenue: 0.3426


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3789, Recall: 0.1105, Revenue: 0.0799
Checking top 2% suspicious transactions: 1414
Precision: 0.3112, Recall: 0.1807, Revenue: 0.1081
Checking top 5% suspicious transactions: 3541
Precision: 0.2717, Recall: 0.3951, Revenue: 0.2174
Checking top 10% suspicious transactions: 7092
Precision: 0.1951, Recall: 0.5684, Revenue: 0.3286


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.3822, Recall: 0.1113, Revenue: 0.0793
Checking top 2% suspicious transactions: 1419
Precision: 0.2911, Recall: 0.1696, Revenue: 0.1061
Checking top 5% suspicious transactions: 3546
Precision: 0.2592, Recall: 0.3774, Revenue: 0.2289
Checking top 10% suspicious transactions: 7091
Precision: 0.1902, Recall: 0.5540, Revenue: 0.3463


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3944, Recall: 0.1150, Revenue: 0.0812
Checking top 2% suspicious transactions: 1419
Precision: 0.3101, Recall: 0.1807, Revenue: 0.1153
Checking top 5% suspicious transactions: 3546
Precision: 0.2908, Recall: 0.4234, Revenue: 0.2333
Checking top 10% suspicious transactions: 7050
Precision: 0.1997, Recall: 0.5782, Revenue: 0.3424


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.4175, Recall: 0.1216, Revenue: 0.0839
Checking top 2% suspicious transactions: 1419
Precision: 0.3524, Recall: 0.2053, Revenue: 0.1201
Checking top 5% suspicious transactions: 3546
Precision: 0.2919, Recall: 0.4251, Revenue: 0.2302
Checking top 10% suspicious transactions: 7083
Precision: 0.2044, Recall: 0.5947, Revenue: 0.3490


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3817, Recall: 0.1113, Revenue: 0.0813
Checking top 2% suspicious transactions: 1419
Precision: 0.3143, Recall: 0.1832, Revenue: 0.1074
Checking top 5% suspicious transactions: 3546
Precision: 0.2876, Recall: 0.4189, Revenue: 0.2389
Checking top 10% suspicious transactions: 7092
Precision: 0.2069, Recall: 0.6025, Revenue: 0.3525


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.4127, Recall: 0.1203, Revenue: 0.0850
Checking top 2% suspicious transactions: 1418
Precision: 0.3456, Recall: 0.2012, Revenue: 0.1148
Checking top 5% suspicious transactions: 3546
Precision: 0.3003, Recall: 0.4374, Revenue: 0.2496
Checking top 10% suspicious transactions: 7092
Precision: 0.2015, Recall: 0.5869, Revenue: 0.3449


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.4014, Recall: 0.1170, Revenue: 0.0825
Checking top 2% suspicious transactions: 1419
Precision: 0.3284, Recall: 0.1914, Revenue: 0.1096
Checking top 5% suspicious transactions: 3546
Precision: 0.2879, Recall: 0.4193, Revenue: 0.2365
Checking top 10% suspicious transactions: 7092
Precision: 0.1957, Recall: 0.5700, Revenue: 0.3440


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3873, Recall: 0.1129, Revenue: 0.0816
Checking top 2% suspicious transactions: 1419
Precision: 0.3178, Recall: 0.1852, Revenue: 0.1082
Checking top 5% suspicious transactions: 3546
Precision: 0.2806, Recall: 0.4086, Revenue: 0.2366
Checking top 10% suspicious transactions: 7052
Precision: 0.2026, Recall: 0.5869, Revenue: 0.3475


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.4042, Recall: 0.1179, Revenue: 0.0848
Checking top 2% suspicious transactions: 1411
Precision: 0.3451, Recall: 0.2000, Revenue: 0.1158
Checking top 5% suspicious transactions: 3546
Precision: 0.3170, Recall: 0.4616, Revenue: 0.2793
Checking top 10% suspicious transactions: 7091
Precision: 0.2110, Recall: 0.6144, Revenue: 0.3777


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.4127, Recall: 0.1203, Revenue: 0.0857
Checking top 2% suspicious transactions: 1412
Precision: 0.3421, Recall: 0.1984, Revenue: 0.1123
Checking top 5% suspicious transactions: 3546
Precision: 0.2936, Recall: 0.4275, Revenue: 0.2548
Checking top 10% suspicious transactions: 7092
Precision: 0.2023, Recall: 0.5893, Revenue: 0.3448


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3789, Recall: 0.1105, Revenue: 0.0793
Checking top 2% suspicious transactions: 1419
Precision: 0.3333, Recall: 0.1943, Revenue: 0.1084
Checking top 5% suspicious transactions: 3546
Precision: 0.2840, Recall: 0.4136, Revenue: 0.2414
Checking top 10% suspicious transactions: 7040
Precision: 0.1999, Recall: 0.5778, Revenue: 0.3487


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.4155, Recall: 0.1211, Revenue: 0.0866
Checking top 2% suspicious transactions: 1419
Precision: 0.3460, Recall: 0.2016, Revenue: 0.1164
Checking top 5% suspicious transactions: 3546
Precision: 0.3029, Recall: 0.4411, Revenue: 0.2637
Checking top 10% suspicious transactions: 7061
Precision: 0.2037, Recall: 0.5906, Revenue: 0.3723


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3972, Recall: 0.1158, Revenue: 0.0840
Checking top 2% suspicious transactions: 1419
Precision: 0.3474, Recall: 0.2025, Revenue: 0.1155
Checking top 5% suspicious transactions: 3545
Precision: 0.3058, Recall: 0.4452, Revenue: 0.2635
Checking top 10% suspicious transactions: 7038
Precision: 0.2062, Recall: 0.5959, Revenue: 0.3639


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3817, Recall: 0.1113, Revenue: 0.0797
Checking top 2% suspicious transactions: 1419
Precision: 0.3531, Recall: 0.2057, Revenue: 0.1141
Checking top 5% suspicious transactions: 3546
Precision: 0.2953, Recall: 0.4300, Revenue: 0.2479
Checking top 10% suspicious transactions: 7033
Precision: 0.1927, Recall: 0.5565, Revenue: 0.3389


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.3949, Recall: 0.1150, Revenue: 0.0851
Checking top 2% suspicious transactions: 1419
Precision: 0.3841, Recall: 0.2238, Revenue: 0.1258
Checking top 5% suspicious transactions: 3546
Precision: 0.3314, Recall: 0.4825, Revenue: 0.2904
Checking top 10% suspicious transactions: 7048
Precision: 0.2030, Recall: 0.5877, Revenue: 0.3707


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.4521, Recall: 0.1318, Revenue: 0.0893
Checking top 2% suspicious transactions: 1419
Precision: 0.4193, Recall: 0.2444, Revenue: 0.1310
Checking top 5% suspicious transactions: 3545
Precision: 0.3337, Recall: 0.4858, Revenue: 0.2740
Checking top 10% suspicious transactions: 7092
Precision: 0.2006, Recall: 0.5844, Revenue: 0.3382



1

In [36]:
trainer.test()

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

Checking top 1% suspicious transactions: 2749
Precision: 0.4569, Recall: 0.1848, Revenue: 0.1668
Checking top 2% suspicious transactions: 5497
Precision: 0.4370, Recall: 0.3534, Revenue: 0.2662
Checking top 5% suspicious transactions: 13740
Precision: 0.2868, Recall: 0.5797, Revenue: 0.4017
Checking top 10% suspicious transactions: 27481
Precision: 0.1540, Recall: 0.6226, Revenue: 0.4392

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'F1-top': 0.3211316321176841,
 'Test/F1@1': 0.2631468677980306,
 'Test/F1@10': 0.24692222416710427,
 'Test/F1@2': 0.39075972018870997,
 'Test/F1@5': 0.38369771631689154,
 'Test/Pr@1': 0.4568934157875591,
 'Test/Pr@10': 0.15399730723045013,
 'Test/Pr@2': 0.43696561760960523,
 'Test/Pr@5': 0.2867540029112082,
 'Test/Re@1': 0.18478740620862144,
 'Test/Re@10': 0.6226276298366926,
 'Test/Re@2': 0.35339120200088275,
 'Test/Re@5': 0.5796675003678093,
 'Test/Rev@1': 0.16681706804627314,
 'Test/Rev@10': 

[{'Test/F1@1': 0.2631468677980306,
  'Test/F1@2': 0.39075972018870997,
  'Test/F1@5': 0.38369771631689154,
  'Test/F1@10': 0.24692222416710427,
  'Test/Pr@1': 0.4568934157875591,
  'Test/Pr@2': 0.43696561760960523,
  'Test/Pr@5': 0.2867540029112082,
  'Test/Pr@10': 0.15399730723045013,
  'Test/Re@1': 0.18478740620862144,
  'Test/Re@2': 0.35339120200088275,
  'Test/Re@5': 0.5796675003678093,
  'Test/Re@10': 0.6226276298366926,
  'Test/Rev@1': 0.16681706804627314,
  'Test/Rev@2': 0.26624723219174173,
  'Test/Rev@5': 0.40170490177245305,
  'Test/Rev@10': 0.4391643519836273,
  'F1-top': 0.3211316321176841,
  'val_loss_epoch': 3.2925469875335693,
  'val_loss': 1.5478118658065796,
  'Val/F1-top': 0.3018214384189277,
  'Val/F1@1': 0.20413354531001593,
  'Val/F1@2': 0.3087701089776855,
  'Val/F1@5': 0.3956521739130435,
  'Val/F1@10': 0.2987299254749659,
  'Val/Pr@1': 0.45211267605633804,
  'Val/Pr@2': 0.4193093727977449,
  'Val/Pr@5': 0.33370944992947815,
  'Val/Pr@10': 0.20064861816130852,
  