In [1]:
import dataset
import datetime
from datetime import timedelta
from parser import get_parser
import numpy as np 
import pandas as pd 
import torch
from pytorch_lightning import Trainer
from torch_geometric.utils import to_undirected

## Load data

In [2]:
data = dataset.Ndata(path='../Custom-Semi-Supervised/data/ndata.csv')
parser = get_parser()
args = parser.parse_args(args=
                         ["--data","real-n", 
                          "--sampling","xgb",
                          "--train_from","20140101",
                          "--test_from","20170101",
                          "--test_length","365",
                          "--valid_length","90",
                          "--initial_inspection_rate", "5",
                          "--final_inspection_rate", "10",
                         ])

In [3]:
# args
seed = args.seed
epochs = args.epoch
dim = args.dim
lr = args.lr
weight_decay = args.l2
initial_inspection_rate = args.initial_inspection_rate
inspection_rate_option = args.inspection_plan
train_begin = args.train_from 
test_begin = args.test_from
test_length = args.test_length
valid_length = args.valid_length
chosen_data = args.data
numWeeks = args.numweeks
semi_supervised = args.semi_supervised
save = args.save
gpu_id = args.device

# Initial dataset split
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Initial dataset split
train_start_day = datetime.date(int(train_begin[:4]), int(train_begin[4:6]), int(train_begin[6:8]))
test_start_day = datetime.date(int(test_begin[:4]), int(test_begin[4:6]), int(test_begin[6:8]))
test_length = timedelta(days=test_length)    
test_end_day = test_start_day + test_length
valid_length = timedelta(days=valid_length)
valid_start_day = test_start_day - valid_length

# data
data.split(train_start_day, valid_start_day, test_start_day, test_end_day, valid_length, test_length, args)
data.featureEngineering()

Data size:
Train labeled: (54134, 52), Train unlabeled: (1028538, 52), Valid labeled: (70917, 52), Valid unlabeled: (0, 26), Test: (274808, 52)
Checking label distribution
Training: 0.05022795615481618
Validation: 0.035556788645191434
Testing: 0.025360899366070794


## Prepare DATA

In [4]:
from utils import *
from pygData_util import *

In [5]:
categories=["importer.id","HS6"]
gdata = GraphData(data,use_xgb=True, categories=categories, pos_weight = 50)

Training XGBoost model...


In [6]:
best_thresh, best_auc = find_best_threshold(gdata.xgb,data.dfvalidx_lab, data.valid_cls_label)
xgb_test_pred = gdata.xgb.predict_proba(data.dfvalidx_lab)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.valid_cls_label,data.valid_reg_label,best_thresh)
print("-"*50)
xgb_test_pred = gdata.xgb.predict_proba(data.dftestx)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.test_cls_label,data.test_reg_label,best_thresh)

Checking top 1% suspicious transactions: 710
Precision: 0.2915, Recall: 0.0850, Revenue: 0.0443
Checking top 2% suspicious transactions: 1419
Precision: 0.1677, Recall: 0.0977, Revenue: 0.0501
Checking top 5% suspicious transactions: 3546
Precision: 0.0832, Recall: 0.1211, Revenue: 0.0844
Checking top 10% suspicious transactions: 7092
Precision: 0.0495, Recall: 0.1441, Revenue: 0.1336
--------------------------------------------------
Checking top 1% suspicious transactions: 2749
Precision: 0.1404, Recall: 0.0568, Revenue: 0.0942
Checking top 2% suspicious transactions: 5497
Precision: 0.0871, Recall: 0.0705, Revenue: 0.1269
Checking top 5% suspicious transactions: 13741
Precision: 0.0442, Recall: 0.0893, Revenue: 0.1859
Checking top 10% suspicious transactions: 27481
Precision: 0.0888, Recall: 0.3590, Revenue: 0.3713


In [7]:
stage = "train_lab"
trainLab_data = gdata.get_data(stage)
train_nodeidx = torch.tensor(gdata.get_AttNode(stage))
trainLab_data.node_idx = train_nodeidx

In [8]:
stage = "train_unlab"
unlab_data = gdata.get_data(stage)
unlab_nodeidx = torch.tensor(gdata.get_AttNode(stage))
unlab_data.node_idx = unlab_nodeidx

In [9]:
stage = "valid"
valid_data = gdata.get_data(stage)
valid_nodeidx = torch.tensor(gdata.get_AttNode(stage))
valid_data.node_idx = valid_nodeidx

In [10]:
stage = "test"
test_data = gdata.get_data(stage)
test_nodeidx = torch.tensor(gdata.get_AttNode(stage))
test_data.node_idx = test_nodeidx

In [11]:
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)

## New Sampler

In [12]:
from torch_cluster import random_walk
from torch_geometric.data import NeighborSampler 
from pytorch_lightning import LightningDataModule

In [13]:
class UnsupSampler(NeighborSampler):
    def sample(self, batch):
        batch = torch.tensor(batch)
        row, col, _ = self.adj_t.coo()

        # For each node in `batch`, we sample a direct neighbor (as positive
        # example) and a random node (as negative example):
        pos_batch = random_walk(row, col, batch, walk_length=1,
                                coalesced=False)[:, 1]

        neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                                  dtype=torch.long)

        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super(UnsupSampler, self).sample(batch)

In [14]:
class Batch(NamedTuple):
    '''
    convert batch data for pytorch-lightning
    '''
    x: Tensor
    y: Tensor
    rev: Tensor
    adjs_t: NamedTuple
    def to(self, *args, **kwargs):
        return Batch(
            x=self.x.to(*args, **kwargs),
            y=self.y.to(*args, **kwargs),
            rev=self.rev.to(*args, **kwargs),
            adjs_t=[(adj_t.to(*args, **kwargs), eid.to(*args, **kwargs), size) for adj_t, eid, size in self.adjs_t],
        )


In [15]:
class UnsupData(LightningDataModule):
    def __init__(self,data,sizes, batch_size = 128):
        '''
        defining dataloader with NeighborSampler to extract k-hop subgraph.
        Args:
            data (Graphdata): graph data for the edges and node index
            sizes ([int]): The number of neighbors to sample for each node in each layer. 
                           If set to :obj:`sizes[l] = -1`, all neighbors are included
            batch_size (int): batch size for training
        '''
        super(UnsupData,self).__init__()
        self.data = data
        self.sizes = sizes
        self.valid_sizes = [-1 for i in self.sizes]
        self.batch_size = batch_size

    def train_dataloader(self):
        return UnsupSampler(self.data.train_edge, sizes=self.sizes,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=True,num_workers=8)
    
    def test_dataloader(self):
        return UnsupSampler(self.data.test_edge, sizes=self.sizes,node_idx=self.data.test_idx,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=False,num_workers=8)
    
    def label_loader(self):
        return UnsupSampler(self.data.test_edge, sizes=self.sizes,node_idx=self.data.train_idx,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=False,num_workers=8)

    def convert_batch(self, batch_size, n_id, adjs):
        return Batch(
            x=self.data.x[n_id],
            y=self.data.y[n_id[:batch_size]],
            rev = self.data.rev[n_id[:batch_size]],
            adjs_t=adjs,
        )

## Model

In [16]:
from models import MLP, GNNStack, UselessConv, Mish

In [17]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import LightningModule, seed_everything
from torchtools.optim import RangerLars
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class PretrainGNN(LightningModule):
    def __init__(self,input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.save_hyperparameters()
        self.input_dim = input_dim
        self.dim = hidden_dim*2
        self.numLayers = numLayers
        self.layers = [self.dim, self.dim//2] #* (numLayers+1)
        self.bn = nn.BatchNorm1d(self.dim)
        self.act = Mish()
        self.useXGB = useXGB
        
        # GNN layer
        if self.useXGB:
            self.initEmbedding = nn.Embedding(self.input_dim, self.dim, padding_idx=0)
        else:
            self.initEmbedding = MLP(self.input_dim, self.dim, Numlayer=2)
        self.initGNN = UselessConv()
        self.GNNs = GNNStack(self.layers,self.numLayers)

    def forward(self, x,adjs):
        # update node embedding
        leaf_emb = self.initEmbedding(x)
        if self.useXGB:
            leaf_emb = torch.sum(leaf_emb,dim=1) # summation over the trees
            leaf_emb = self.bn(leaf_emb)
            leaf_emb = self.act(leaf_emb)
        
        # first update 
        firstHop_neighbor = adjs[-1][0]
        leaf_emb = self.initGNN(leaf_emb,to_undirected(firstHop_neighbor))
        
        # GNN 
        embeddings = self.GNNs(leaf_emb, adjs)
        
        return embeddings[-1]
    
    def training_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
        neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
        train_loss = -pos_loss - neg_loss
        self.log('train_loss', train_loss)
        return train_loss
    
    def test_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        return out
    
    def test_epoch_end(self, val_step_outputs):
        val_step_outputs = torch.cat(val_step_outputs)
        val_step_outputs = val_step_outputs.cpu().detach().numpy()
        return {"log":{"predictions":val_step_outputs}}
    
    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.01, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.8)
        return [optimizer], [scheduler]

In [19]:
# model config
seed_everything(2345)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [50,20]
numLayers = len(sizes)
batch_size = 1024

model = PretrainGNN(input_dim, hidden_size, numLayers, useXGB=gdata.use_xgb)

# lightning config
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)
datamodule = UnsupData(stacked_data, sizes = sizes, batch_size=batch_size)
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',    
    dirpath='./saved_model',
    filename='Tdata5-pretrain-{train_loss:.4f}',
    save_top_k=1,
    mode='min',
)
trainer = Trainer(gpus=[1], max_epochs=3,
#                   callbacks=[checkpoint_callback],
                 )
trainer.fit(model, train_dataloader=datamodule.train_dataloader(),
           )

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type        | Params
----------------------------------------------
0 | bn            | BatchNorm1d | 128   
1 | act           | Mish        | 0     
2 | initEmbedding | Embedding   | 86.1 K
3 | initGNN       | UselessConv | 0     
4 | GNNs          | GNNStack    | 33.5 K
----------------------------------------------
119 K     Trainable params
0         Non-trainable params
119 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1

In [20]:
trainer.test()

1

In [21]:
# trainer.save_checkpoint("./saved_model/pretrained.ckpt")

In [22]:
from tqdm import tqdm_notebook

In [23]:
test_embeddings = []
for batch in tqdm_notebook(datamodule.test_dataloader()):
    batch = batch.to(model.device)
    out = model(batch.x, batch.adjs_t)
    out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
    test_embeddings.append(out.cpu().detach().numpy())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=269.0), HTML(value='')))




In [24]:
embeddings = []
for batch in tqdm_notebook(datamodule.label_loader()):
    batch = batch.to(model.device)
    out = model(batch.x, batch.adjs_t)
    out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
    embeddings.append(out.cpu().detach().numpy())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=53.0), HTML(value='')))




In [25]:
embeddings = np.concatenate(embeddings)
test_embeddings = np.concatenate(test_embeddings)

In [26]:
from sklearn.linear_model import LogisticRegression

In [27]:
lr = LogisticRegression(class_weight={1:50,0:1})
lr.fit(embeddings,data.train_cls_label)

LogisticRegression(class_weight={0: 1, 1: 50})

In [28]:
y_prob = lr.predict_proba(test_embeddings)[:,1]

In [29]:
_ = metrics(y_prob, data.test_cls_label,data.test_reg_label,None)

Checking top 1% suspicious transactions: 2749
Precision: 0.0517, Recall: 0.0209, Revenue: 0.0244
Checking top 2% suspicious transactions: 5497
Precision: 0.0446, Recall: 0.0360, Revenue: 0.0544
Checking top 5% suspicious transactions: 13741
Precision: 0.0387, Recall: 0.0783, Revenue: 0.1310
Checking top 10% suspicious transactions: 27481
Precision: 0.0343, Recall: 0.1386, Revenue: 0.2171


## Fine tuning

In [30]:
from torch_geometric.nn import GATConv,TransformerConv
from pygData_util import *

In [31]:
class Predictor(LightningModule):
    def __init__(self,input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.gnn_encoder = PretrainGNN(input_dim, hidden_size, numLayers, useXGB)
        self.dim = hidden_dim * 2
        
        # output
        self.clsLayer = nn.Linear(self.dim,1) #GATConv(self.dim,1)
        self.revLayer = nn.Linear(self.dim,1) #GATConv(self.dim,1)
        self.loss_func = nn.BCEWithLogitsLoss(pos_weight = torch.tensor([20])) #FocalLoss(logits=True)
        
    def load_fromPretrain(self,path):
        self.gnn_encoder.load_from_checkpoint(path)
        
    def loadGNN_state(self,model):
        self.gnn_encoder.load_state_dict(model.state_dict())

    def forward(self, x,adjs):
#         firstHop_neighbor = adjs[-1][0]
        # get node embedding from pre-trained model
        embedding = self.gnn_encoder(x,adjs)
        logit = self.clsLayer(embedding)
        revenue = self.revLayer(embedding)
        
        return logit, revenue
    
    def compute_CLS_loss(self,logit, label):
        logit = logit.flatten()
        loss = self.loss_func(logit,label)
        return loss
    
    def compute_REG_loss(self,pred_rev, rev):
        pred_rev = pred_rev.flatten()
        loss = F.mse_loss(pred_rev,rev)
        return loss 

    def training_step(self, batch, batch_idx: int):
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        train_loss = CLS_loss + 10 * REG_loss
        self.log('train_loss', train_loss)
        return train_loss
    
    def validation_step(self, batch, batch_idx: int):
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        valid_loss = CLS_loss + 1 * REG_loss
        self.log('val_loss', valid_loss, on_step=True, on_epoch=True, sync_dist=True)
        return logits
    
    def validation_epoch_end(self, val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.valid_cls_label, self.data.valid_reg_label)
        f1_top = np.mean(f)
        self.log("F1-top",f1_top)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Val/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        return {"Val/F1-top":f1_top, "log":tensorboard_logs}
        
    def test_step(self,batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    
    def test_epoch_end(self,val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.test_cls_label, self.data.test_reg_label)
        f1_top = np.mean(f)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Test/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        
        return {"F1-top":f1_top, "log":tensorboard_logs}

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.05, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.99)
        return [optimizer], [scheduler]

In [32]:
# model config
seed_everything(5674)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [-1,200]
numLayers = len(sizes)
batch_size = 512
pretrain_path = "./saved_model/pretrained.ckpt"

predictor = Predictor(input_dim, hidden_size, numLayers)
predictor.loadGNN_state(model)
# predictor.load_fromPretrain(pretrain_path)
predictor.data = data

# lightning config
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)
datamodule = CustomData(stacked_data, sizes = sizes, batch_size=batch_size)
logger = TensorBoardLogger("ssl_exp",name="Tdata")
logger.log_hyperparams(model.hparams, metrics={"F1-top":0})
checkpoint_callback = ModelCheckpoint(
    monitor='F1-top',    
    dirpath='./saved_model',
    filename='Analysis-Tdata5-{F1-top:.4f}',
    save_top_k=1,
    mode='max',
)
trainer = Trainer(gpus=[1], max_epochs=40,
                 num_sanity_val_steps=0,
                  check_val_every_n_epoch=1,
                  callbacks=[checkpoint_callback],
                 )
trainer.fit(predictor, datamodule=datamodule)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type              | Params
--------------------------------------------------
0 | gnn_encoder | PretrainGNN       | 119 K 
1 | clsLayer    | Linear            | 65    
2 | revLayer    | Linear            | 65    
3 | loss_func   | BCEWithLogitsLoss | 0     
--------------------------------------------------
119 K     Trainable params
0         Non-trainable params
119 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.1577, Recall: 0.0460, Revenue: 0.0256
Checking top 2% suspicious transactions: 1419
Precision: 0.1071, Recall: 0.0624, Revenue: 0.0508
Checking top 5% suspicious transactions: 3546
Precision: 0.0804, Recall: 0.1170, Revenue: 0.0991
Checking top 10% suspicious transactions: 7091
Precision: 0.0656, Recall: 0.1910, Revenue: 0.2237


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.1887, Recall: 0.0550, Revenue: 0.0401
Checking top 2% suspicious transactions: 1419
Precision: 0.1325, Recall: 0.0772, Revenue: 0.0513
Checking top 5% suspicious transactions: 3546
Precision: 0.0905, Recall: 0.1318, Revenue: 0.1154
Checking top 10% suspicious transactions: 7089
Precision: 0.0633, Recall: 0.1844, Revenue: 0.1711


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2451, Recall: 0.0715, Revenue: 0.0466
Checking top 2% suspicious transactions: 1419
Precision: 0.1501, Recall: 0.0875, Revenue: 0.0518
Checking top 5% suspicious transactions: 3545
Precision: 0.0934, Recall: 0.1359, Revenue: 0.0988
Checking top 10% suspicious transactions: 7092
Precision: 0.0626, Recall: 0.1823, Revenue: 0.2102


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2155, Recall: 0.0628, Revenue: 0.0425
Checking top 2% suspicious transactions: 1419
Precision: 0.1445, Recall: 0.0842, Revenue: 0.0502
Checking top 5% suspicious transactions: 3546
Precision: 0.0936, Recall: 0.1363, Revenue: 0.1061
Checking top 10% suspicious transactions: 7092
Precision: 0.0718, Recall: 0.2090, Revenue: 0.2232


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2732, Recall: 0.0797, Revenue: 0.0481
Checking top 2% suspicious transactions: 1419
Precision: 0.1705, Recall: 0.0994, Revenue: 0.0553
Checking top 5% suspicious transactions: 3546
Precision: 0.0956, Recall: 0.1392, Revenue: 0.1240
Checking top 10% suspicious transactions: 7092
Precision: 0.0746, Recall: 0.2172, Revenue: 0.2242


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2845, Recall: 0.0830, Revenue: 0.0484
Checking top 2% suspicious transactions: 1419
Precision: 0.1875, Recall: 0.1092, Revenue: 0.0831
Checking top 5% suspicious transactions: 3546
Precision: 0.1111, Recall: 0.1618, Revenue: 0.1372
Checking top 10% suspicious transactions: 7092
Precision: 0.0842, Recall: 0.2452, Revenue: 0.1993


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2944, Recall: 0.0858, Revenue: 0.0469
Checking top 2% suspicious transactions: 1419
Precision: 0.1720, Recall: 0.1002, Revenue: 0.0531
Checking top 5% suspicious transactions: 3546
Precision: 0.0894, Recall: 0.1302, Revenue: 0.1010
Checking top 10% suspicious transactions: 7091
Precision: 0.0753, Recall: 0.2193, Revenue: 0.1738


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2901, Recall: 0.0846, Revenue: 0.0479
Checking top 2% suspicious transactions: 1419
Precision: 0.1846, Recall: 0.1076, Revenue: 0.0687
Checking top 5% suspicious transactions: 3546
Precision: 0.1108, Recall: 0.1614, Revenue: 0.1250
Checking top 10% suspicious transactions: 7089
Precision: 0.0959, Recall: 0.2793, Revenue: 0.2152


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3042, Recall: 0.0887, Revenue: 0.0485
Checking top 2% suspicious transactions: 1419
Precision: 0.1917, Recall: 0.1117, Revenue: 0.0736
Checking top 5% suspicious transactions: 3546
Precision: 0.1089, Recall: 0.1585, Revenue: 0.1241
Checking top 10% suspicious transactions: 7092
Precision: 0.0936, Recall: 0.2727, Revenue: 0.2549


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3225, Recall: 0.0940, Revenue: 0.0550
Checking top 2% suspicious transactions: 1419
Precision: 0.1882, Recall: 0.1097, Revenue: 0.0800
Checking top 5% suspicious transactions: 3546
Precision: 0.1153, Recall: 0.1680, Revenue: 0.1423
Checking top 10% suspicious transactions: 7092
Precision: 0.1008, Recall: 0.2936, Revenue: 0.2218


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.3061, Recall: 0.0891, Revenue: 0.0502
Checking top 2% suspicious transactions: 1419
Precision: 0.1924, Recall: 0.1121, Revenue: 0.0722
Checking top 5% suspicious transactions: 3546
Precision: 0.1156, Recall: 0.1684, Revenue: 0.1366
Checking top 10% suspicious transactions: 7091
Precision: 0.1121, Recall: 0.3265, Revenue: 0.2863


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3099, Recall: 0.0903, Revenue: 0.0510
Checking top 2% suspicious transactions: 1418
Precision: 0.1968, Recall: 0.1146, Revenue: 0.0706
Checking top 5% suspicious transactions: 3546
Precision: 0.1368, Recall: 0.1992, Revenue: 0.1399
Checking top 10% suspicious transactions: 7092
Precision: 0.1447, Recall: 0.4214, Revenue: 0.2497


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3070, Recall: 0.0895, Revenue: 0.0579
Checking top 2% suspicious transactions: 1417
Precision: 0.1870, Recall: 0.1088, Revenue: 0.0660
Checking top 5% suspicious transactions: 3546
Precision: 0.1221, Recall: 0.1778, Revenue: 0.1263
Checking top 10% suspicious transactions: 7092
Precision: 0.1574, Recall: 0.4583, Revenue: 0.2805


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3479, Recall: 0.1014, Revenue: 0.0624
Checking top 2% suspicious transactions: 1419
Precision: 0.2255, Recall: 0.1314, Revenue: 0.0890
Checking top 5% suspicious transactions: 3546
Precision: 0.1627, Recall: 0.2370, Revenue: 0.1642
Checking top 10% suspicious transactions: 7092
Precision: 0.1578, Recall: 0.4595, Revenue: 0.2844


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3268, Recall: 0.0953, Revenue: 0.0615
Checking top 2% suspicious transactions: 1419
Precision: 0.2178, Recall: 0.1269, Revenue: 0.0888
Checking top 5% suspicious transactions: 3546
Precision: 0.1579, Recall: 0.2300, Revenue: 0.1661
Checking top 10% suspicious transactions: 7092
Precision: 0.1506, Recall: 0.4386, Revenue: 0.2766


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3408, Recall: 0.0994, Revenue: 0.0581
Checking top 2% suspicious transactions: 1419
Precision: 0.2156, Recall: 0.1257, Revenue: 0.0831
Checking top 5% suspicious transactions: 3546
Precision: 0.1399, Recall: 0.2037, Revenue: 0.1734
Checking top 10% suspicious transactions: 7092
Precision: 0.1389, Recall: 0.4045, Revenue: 0.3199


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3408, Recall: 0.0994, Revenue: 0.0606
Checking top 2% suspicious transactions: 1419
Precision: 0.2340, Recall: 0.1363, Revenue: 0.0896
Checking top 5% suspicious transactions: 3546
Precision: 0.1929, Recall: 0.2809, Revenue: 0.1900
Checking top 10% suspicious transactions: 7091
Precision: 0.1722, Recall: 0.5014, Revenue: 0.2861


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3254, Recall: 0.0949, Revenue: 0.0630
Checking top 2% suspicious transactions: 1419
Precision: 0.2100, Recall: 0.1224, Revenue: 0.0875
Checking top 5% suspicious transactions: 3546
Precision: 0.1362, Recall: 0.1984, Revenue: 0.1461
Checking top 10% suspicious transactions: 7092
Precision: 0.1337, Recall: 0.3893, Revenue: 0.2594


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.3413, Recall: 0.0994, Revenue: 0.0719
Checking top 2% suspicious transactions: 1419
Precision: 0.2304, Recall: 0.1343, Revenue: 0.1070
Checking top 5% suspicious transactions: 3546
Precision: 0.1813, Recall: 0.2641, Revenue: 0.1818
Checking top 10% suspicious transactions: 7092
Precision: 0.1722, Recall: 0.5014, Revenue: 0.2906


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2775, Recall: 0.0809, Revenue: 0.0479
Checking top 2% suspicious transactions: 1418
Precision: 0.1989, Recall: 0.1158, Revenue: 0.0587
Checking top 5% suspicious transactions: 3546
Precision: 0.1906, Recall: 0.2776, Revenue: 0.1927
Checking top 10% suspicious transactions: 7092
Precision: 0.1806, Recall: 0.5261, Revenue: 0.3373


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3028, Recall: 0.0883, Revenue: 0.0496
Checking top 2% suspicious transactions: 1419
Precision: 0.2149, Recall: 0.1253, Revenue: 0.0788
Checking top 5% suspicious transactions: 3546
Precision: 0.1968, Recall: 0.2867, Revenue: 0.1774
Checking top 10% suspicious transactions: 7092
Precision: 0.1805, Recall: 0.5257, Revenue: 0.3091


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3141, Recall: 0.0916, Revenue: 0.0570
Checking top 2% suspicious transactions: 1419
Precision: 0.2361, Recall: 0.1376, Revenue: 0.0895
Checking top 5% suspicious transactions: 3535
Precision: 0.2260, Recall: 0.3281, Revenue: 0.1904
Checking top 10% suspicious transactions: 7092
Precision: 0.1809, Recall: 0.5269, Revenue: 0.3053


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3296, Recall: 0.0961, Revenue: 0.0696
Checking top 2% suspicious transactions: 1419
Precision: 0.2502, Recall: 0.1458, Revenue: 0.0969
Checking top 5% suspicious transactions: 3546
Precision: 0.2431, Recall: 0.3540, Revenue: 0.2241
Checking top 10% suspicious transactions: 7092
Precision: 0.2023, Recall: 0.5893, Revenue: 0.3495


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3521, Recall: 0.1027, Revenue: 0.0710
Checking top 2% suspicious transactions: 1419
Precision: 0.2664, Recall: 0.1552, Revenue: 0.0971
Checking top 5% suspicious transactions: 3546
Precision: 0.2521, Recall: 0.3671, Revenue: 0.2132
Checking top 10% suspicious transactions: 7092
Precision: 0.1981, Recall: 0.5770, Revenue: 0.3308


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3704, Recall: 0.1080, Revenue: 0.0711
Checking top 2% suspicious transactions: 1419
Precision: 0.2657, Recall: 0.1548, Revenue: 0.0969
Checking top 5% suspicious transactions: 3543
Precision: 0.2433, Recall: 0.3540, Revenue: 0.2084
Checking top 10% suspicious transactions: 7092
Precision: 0.1881, Recall: 0.5478, Revenue: 0.3068


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3662, Recall: 0.1068, Revenue: 0.0734
Checking top 2% suspicious transactions: 1419
Precision: 0.2875, Recall: 0.1676, Revenue: 0.0996
Checking top 5% suspicious transactions: 3546
Precision: 0.2840, Recall: 0.4136, Revenue: 0.2398
Checking top 10% suspicious transactions: 7092
Precision: 0.2064, Recall: 0.6012, Revenue: 0.3502


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3606, Recall: 0.1051, Revenue: 0.0737
Checking top 2% suspicious transactions: 1419
Precision: 0.2889, Recall: 0.1684, Revenue: 0.1013
Checking top 5% suspicious transactions: 3546
Precision: 0.2882, Recall: 0.4197, Revenue: 0.2371
Checking top 10% suspicious transactions: 7092
Precision: 0.1994, Recall: 0.5807, Revenue: 0.3351


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3831, Recall: 0.1117, Revenue: 0.0775
Checking top 2% suspicious transactions: 1419
Precision: 0.3206, Recall: 0.1869, Revenue: 0.1058
Checking top 5% suspicious transactions: 3546
Precision: 0.3029, Recall: 0.4411, Revenue: 0.2545
Checking top 10% suspicious transactions: 7088
Precision: 0.2091, Recall: 0.6086, Revenue: 0.3474


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3437, Recall: 0.1002, Revenue: 0.0702
Checking top 2% suspicious transactions: 1419
Precision: 0.2932, Recall: 0.1708, Revenue: 0.1012
Checking top 5% suspicious transactions: 3546
Precision: 0.2871, Recall: 0.4181, Revenue: 0.2420
Checking top 10% suspicious transactions: 7092
Precision: 0.2046, Recall: 0.5959, Revenue: 0.3480


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3817, Recall: 0.1113, Revenue: 0.0792
Checking top 2% suspicious transactions: 1419
Precision: 0.3369, Recall: 0.1963, Revenue: 0.1112
Checking top 5% suspicious transactions: 3546
Precision: 0.3096, Recall: 0.4509, Revenue: 0.2585
Checking top 10% suspicious transactions: 7092
Precision: 0.2016, Recall: 0.5873, Revenue: 0.3348


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3648, Recall: 0.1064, Revenue: 0.0747
Checking top 2% suspicious transactions: 1419
Precision: 0.3263, Recall: 0.1901, Revenue: 0.1082
Checking top 5% suspicious transactions: 3546
Precision: 0.2981, Recall: 0.4341, Revenue: 0.2530
Checking top 10% suspicious transactions: 7053
Precision: 0.1979, Recall: 0.5733, Revenue: 0.3323


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.3992, Recall: 0.1162, Revenue: 0.0834
Checking top 2% suspicious transactions: 1419
Precision: 0.3298, Recall: 0.1922, Revenue: 0.1083
Checking top 5% suspicious transactions: 3546
Precision: 0.2930, Recall: 0.4267, Revenue: 0.2445
Checking top 10% suspicious transactions: 7092
Precision: 0.2025, Recall: 0.5897, Revenue: 0.3446


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3535, Recall: 0.1031, Revenue: 0.0718
Checking top 2% suspicious transactions: 1419
Precision: 0.3122, Recall: 0.1819, Revenue: 0.1079
Checking top 5% suspicious transactions: 3546
Precision: 0.2978, Recall: 0.4337, Revenue: 0.2578
Checking top 10% suspicious transactions: 7092
Precision: 0.2002, Recall: 0.5832, Revenue: 0.3420


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3901, Recall: 0.1138, Revenue: 0.0777
Checking top 2% suspicious transactions: 1419
Precision: 0.3411, Recall: 0.1988, Revenue: 0.1117
Checking top 5% suspicious transactions: 3541
Precision: 0.3047, Recall: 0.4431, Revenue: 0.2455
Checking top 10% suspicious transactions: 7092
Precision: 0.2046, Recall: 0.5959, Revenue: 0.3453


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3662, Recall: 0.1068, Revenue: 0.0748
Checking top 2% suspicious transactions: 1419
Precision: 0.3136, Recall: 0.1828, Revenue: 0.1068
Checking top 5% suspicious transactions: 3546
Precision: 0.2876, Recall: 0.4189, Revenue: 0.2469
Checking top 10% suspicious transactions: 7090
Precision: 0.1965, Recall: 0.5721, Revenue: 0.3296


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.4175, Recall: 0.1216, Revenue: 0.0810
Checking top 2% suspicious transactions: 1419
Precision: 0.3686, Recall: 0.2148, Revenue: 0.1181
Checking top 5% suspicious transactions: 3544
Precision: 0.3287, Recall: 0.4784, Revenue: 0.2466
Checking top 10% suspicious transactions: 7092
Precision: 0.2016, Recall: 0.5873, Revenue: 0.3380


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.4127, Recall: 0.1203, Revenue: 0.0821
Checking top 2% suspicious transactions: 1419
Precision: 0.3841, Recall: 0.2238, Revenue: 0.1288
Checking top 5% suspicious transactions: 3539
Precision: 0.3227, Recall: 0.4690, Revenue: 0.2634
Checking top 10% suspicious transactions: 7091
Precision: 0.1986, Recall: 0.5782, Revenue: 0.3322


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3803, Recall: 0.1109, Revenue: 0.0779
Checking top 2% suspicious transactions: 1419
Precision: 0.3510, Recall: 0.2045, Revenue: 0.1169
Checking top 5% suspicious transactions: 3546
Precision: 0.3254, Recall: 0.4739, Revenue: 0.2702
Checking top 10% suspicious transactions: 7092
Precision: 0.2032, Recall: 0.5918, Revenue: 0.3426


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3746, Recall: 0.1092, Revenue: 0.0689
Checking top 2% suspicious transactions: 1419
Precision: 0.3594, Recall: 0.2094, Revenue: 0.1213
Checking top 5% suspicious transactions: 3540
Precision: 0.3223, Recall: 0.4686, Revenue: 0.2563
Checking top 10% suspicious transactions: 7061
Precision: 0.2039, Recall: 0.5914, Revenue: 0.3484


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3606, Recall: 0.1051, Revenue: 0.0690
Checking top 2% suspicious transactions: 1419
Precision: 0.3573, Recall: 0.2082, Revenue: 0.1221
Checking top 5% suspicious transactions: 3546
Precision: 0.3204, Recall: 0.4665, Revenue: 0.2689
Checking top 10% suspicious transactions: 7003
Precision: 0.2018, Recall: 0.5803, Revenue: 0.3387



1

In [33]:
trainer.test()

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

Checking top 1% suspicious transactions: 2749
Precision: 0.3965, Recall: 0.1604, Revenue: 0.1398
Checking top 2% suspicious transactions: 5497
Precision: 0.3955, Recall: 0.3198, Revenue: 0.2228
Checking top 5% suspicious transactions: 13741
Precision: 0.2857, Recall: 0.5776, Revenue: 0.3950
Checking top 10% suspicious transactions: 27481
Precision: 0.1666, Recall: 0.6735, Revenue: 0.4737

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'F1-top': 0.30786554154637824,
 'Test/F1@1': 0.22836790278650743,
 'Test/F1@10': 0.26711009977244876,
 'Test/F1@2': 0.35366845615747516,
 'Test/F1@5': 0.38231570746908167,
 'Test/Pr@1': 0.3965078210258276,
 'Test/Pr@10': 0.16658782431498126,
 'Test/Pr@2': 0.39548844824449697,
 'Test/Pr@5': 0.2857142857142857,
 'Test/Re@1': 0.1603648668530234,
 'Test/Re@10': 0.6735324407826982,
 'Test/Re@2': 0.31984699131969985,
 'Test/Re@5': 0.5776077681329999,
 'Test/Rev@1': 0.1398259733054742,
 'Test/Rev@10': 

[{'Test/F1@1': 0.22836790278650743,
  'Test/F1@2': 0.35366845615747516,
  'Test/F1@5': 0.38231570746908167,
  'Test/F1@10': 0.26711009977244876,
  'Test/Pr@1': 0.3965078210258276,
  'Test/Pr@2': 0.39548844824449697,
  'Test/Pr@5': 0.2857142857142857,
  'Test/Pr@10': 0.16658782431498126,
  'Test/Re@1': 0.1603648668530234,
  'Test/Re@2': 0.31984699131969985,
  'Test/Re@5': 0.5776077681329999,
  'Test/Re@10': 0.6735324407826982,
  'Test/Rev@1': 0.1398259733054742,
  'Test/Rev@2': 0.22284642973679733,
  'Test/Rev@5': 0.39500999701544054,
  'Test/Rev@10': 0.4737276623781135,
  'F1-top': 0.30786554154637824,
  'val_loss_epoch': 3.0153324604034424,
  'val_loss': 1.4027796983718872,
  'Val/F1-top': 0.27629969836209844,
  'Val/F1@1': 0.16279809220985691,
  'Val/F1@2': 0.2631032693305656,
  'Val/F1@5': 0.379869587025581,
  'Val/F1@10': 0.2994278448823903,
  'Val/Pr@1': 0.36056338028169016,
  'Val/Pr@2': 0.3572938689217759,
  'Val/Pr@5': 0.320360970107163,
  'Val/Pr@10': 0.20177066971298016,
  'V