In [1]:
import dataset
import datetime
from datetime import timedelta
from parser import get_parser
import numpy as np 
import pandas as pd 
import torch
from pytorch_lightning import Trainer
from torch_geometric.utils import to_undirected

## Load data

In [2]:
# Raw transaction table for this experiment.
data = dataset.Tdata(path='../Custom-Semi-Supervised/data/tdata.csv')

# The flags are supplied explicitly rather than read from sys.argv, because a
# notebook kernel has no meaningful command line of its own.
cli_flags = [
    "--data", "real-t",
    "--sampling", "xgb",
    "--mode", "scratch",
    "--train_from", "20170101",
    "--test_from", "20190101",
    "--test_length", "365",
    "--valid_length", "90",
    "--initial_inspection_rate", "5",
    "--final_inspection_rate", "10",
]
parser = get_parser()
args = parser.parse_args(args=cli_flags)

In [3]:
# Unpack the parsed arguments into notebook-level names.
# NOTE(review): several of these (epochs, dim, lr, weight_decay, ...) are not
# used by the visible cells; the models below hard-code their own settings.
seed = args.seed
epochs = args.epoch
dim = args.dim
lr = args.lr
weight_decay = args.l2
initial_inspection_rate = args.initial_inspection_rate
inspection_rate_option = args.inspection_plan
mode = args.mode
train_begin = args.train_from
test_begin = args.test_from
test_length = args.test_length
valid_length = args.valid_length
chosen_data = args.data
numWeeks = args.numweeks
semi_supervised = args.semi_supervised
save = args.save
gpu_id = args.device

# Seed the numpy and torch RNGs so runs are repeatable.
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)


def _parse_yyyymmdd(stamp):
    """Convert a 'YYYYMMDD' string such as '20170101' into a datetime.date."""
    return datetime.datetime.strptime(stamp, "%Y%m%d").date()


# Initial dataset split: the test window runs for `test_length` days from
# test_from, and validation is the `valid_length` days immediately before it.
train_start_day = _parse_yyyymmdd(train_begin)
test_start_day = _parse_yyyymmdd(test_begin)
# test_length / valid_length are rebound here from day counts to timedeltas;
# data.split receives the timedelta versions. Re-running this cell is safe
# because both are re-read from `args` above.
test_length = timedelta(days=test_length)
test_end_day = test_start_day + test_length
valid_length = timedelta(days=valid_length)
valid_start_day = test_start_day - valid_length

# Split the table into train/valid/test periods and build features.
data.split(train_start_day, valid_start_day, test_start_day, test_end_day, valid_length, test_length, args)
data.featureEngineering()

Data size:
Train labeled: (77391, 41), Train unlabeled: (1470434, 41), Valid labeled: (134457, 41), Valid unlabeled: (0, 13), Test: (703090, 41)
Checking label distribution
Training: 0.09757342825942052
Validation: 0.09589052260946108
Testing: 0.10476480792437651


## Prepare graph data

In [4]:
from utils import *
from pygData_util import *

In [5]:
# Categorical columns GraphData uses when building the transaction graph.
categories = ["importer.id", "HS6"]
gdata = GraphData(
    data,
    use_xgb=True,  # presumably: node features come from a trained XGBoost model (see output below)
    categories=categories,
)

Training XGBoost model...


In [6]:
# Baseline: score the XGBoost model that GraphData trained. The decision
# threshold is tuned on the labelled validation split only.
best_thresh, best_auc = find_best_threshold(gdata.xgb, data.dfvalidx_lab, data.valid_cls_label)

# Validation split (predict_proba's last column = positive-class probability).
xgb_valid_pred = gdata.xgb.predict_proba(data.dfvalidx_lab)[:, -1]
overall_f1, auc, pr, re, f, rev = metrics(xgb_valid_pred, data.valid_cls_label, data.valid_reg_label, best_thresh)
print("-" * 50)

# Test split, scored with the threshold chosen on validation.
xgb_test_pred = gdata.xgb.predict_proba(data.dftestx)[:, -1]
overall_f1, auc, pr, re, f, rev = metrics(xgb_test_pred, data.test_cls_label, data.test_reg_label, best_thresh)

Checking top 1% suspicious transactions: 1345
Precision: 0.6387, Recall: 0.0730, Revenue: 0.0754
Checking top 2% suspicious transactions: 2690
Precision: 0.5647, Recall: 0.1291, Revenue: 0.1244
Checking top 5% suspicious transactions: 6723
Precision: 0.4409, Recall: 0.2519, Revenue: 0.2514
Checking top 10% suspicious transactions: 13446
Precision: 0.3244, Recall: 0.3708, Revenue: 0.3665
--------------------------------------------------
Checking top 1% suspicious transactions: 7030
Precision: 0.7296, Recall: 0.0769, Revenue: 0.1280
Checking top 2% suspicious transactions: 14062
Precision: 0.6358, Recall: 0.1341, Revenue: 0.2154
Checking top 5% suspicious transactions: 35155
Precision: 0.4679, Recall: 0.2467, Revenue: 0.3736
Checking top 10% suspicious transactions: 70309
Precision: 0.3369, Recall: 0.3553, Revenue: 0.4870


In [7]:
# Labelled training split: its graph data plus the node indices returned by
# get_AttNode (presumably the transaction nodes to predict on).
stage = "train_lab"
trainLab_data, train_nodeidx = gdata.get_data(stage), torch.tensor(gdata.get_AttNode(stage))
trainLab_data.node_idx = train_nodeidx

In [8]:
# Unlabelled training split: graph data plus its get_AttNode indices.
stage = "train_unlab"
unlab_data, unlab_nodeidx = gdata.get_data(stage), torch.tensor(gdata.get_AttNode(stage))
unlab_data.node_idx = unlab_nodeidx

In [9]:
# Validation split: graph data plus its get_AttNode indices.
stage = "valid"
valid_data, valid_nodeidx = gdata.get_data(stage), torch.tensor(gdata.get_AttNode(stage))
valid_data.node_idx = valid_nodeidx

In [10]:
# Test split: graph data plus its get_AttNode indices.
stage = "test"
test_data, test_nodeidx = gdata.get_data(stage), torch.tensor(gdata.get_AttNode(stage))
test_data.node_idx = test_nodeidx

In [11]:
# Combine the four per-split graphs into one container for the samplers.
stacked_data = StackData(
    trainLab_data,
    unlab_data,
    valid_data,
    test_data,
)

## Sampler for unsupervised pre-training

In [12]:
from torch_cluster import random_walk
from torch_geometric.data import NeighborSampler 
from pytorch_lightning import LightningDataModule

In [13]:
class UnsupSampler(NeighborSampler):
    """NeighborSampler that also draws a positive and a negative partner per seed.

    For every seed node in a mini-batch, one random-walk step gives a direct
    neighbour (positive example) and a uniformly random node id gives a negative
    example. Seeds, positives and negatives are concatenated in that order before
    neighbourhood sampling, so the per-target model output can be split into
    three equal parts.
    """

    def sample(self, batch):
        # The loader may pass a list of indices or an index tensor; as_tensor
        # avoids the copy (and warning) torch.tensor() incurs for tensors.
        batch = torch.as_tensor(batch, dtype=torch.long)
        row, col, _ = self.adj_t.coo()

        # Positive: the node reached after one random-walk step from each seed
        # (column 0 of the walk is the seed itself).
        pos_batch = random_walk(row, col, batch, walk_length=1,
                                coalesced=False)[:, 1]

        # Negative: uniformly random node ids over the whole graph.
        neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                                  dtype=torch.long)

        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super(UnsupSampler, self).sample(batch)

In [14]:
class Batch(NamedTuple):
    '''
    Mini-batch container handed to pytorch-lightning.

    Fields:
        x: features of every sampled node (the target nodes come first).
        y: labels of the target nodes only.
        rev: revenue targets of the target nodes only.
        adjs_t: per-hop sequence of (adjacency, edge ids, size) triples from the
            sampler (annotated NamedTuple, but iterated as a list of triples).
    '''
    x: Tensor
    y: Tensor
    rev: Tensor
    adjs_t: NamedTuple

    def to(self, *args, **kwargs):
        """Return a new Batch with every tensor passed through ``Tensor.to``.

        Edge-id entries may be ``None`` (the sampler is not guaranteed to
        return them); those are passed through unchanged instead of raising.
        """
        def _move(tensor):
            return None if tensor is None else tensor.to(*args, **kwargs)

        return Batch(
            x=self.x.to(*args, **kwargs),
            y=self.y.to(*args, **kwargs),
            rev=self.rev.to(*args, **kwargs),
            adjs_t=[(adj_t.to(*args, **kwargs), _move(eid), size)
                    for adj_t, eid, size in self.adjs_t],
        )


In [15]:
class UnsupData(LightningDataModule):
    def __init__(self, data, sizes, batch_size=128, num_workers=8):
        '''
        Data module serving UnsupSampler loaders, which extract k-hop subgraphs
        whose targets are seed, positive and negative nodes.

        Args:
            data (Graphdata): graph data for the edges and node index
            sizes ([int]): The number of neighbors to sample for each node in each layer.
                           If set to :obj:`sizes[l] = -1`, all neighbors are included
            batch_size (int): batch size for training
            num_workers (int): worker processes used by each loader (default 8)
        '''
        super(UnsupData, self).__init__()
        self.data = data
        self.sizes = sizes
        # Full-neighbourhood sizes (-1 on every hop). Not used by the loaders
        # below; kept for any caller that reads it.
        self.valid_sizes = [-1 for i in self.sizes]
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, node_idx=None, shuffle=False):
        '''
        Build an UnsupSampler over ``self.data.test_edge``.

        Args:
            node_idx: seed node indices; None lets the sampler seed from all nodes.
            shuffle (bool): whether to shuffle the seed order.
        '''
        return UnsupSampler(self.data.test_edge, sizes=self.sizes, node_idx=node_idx,
                            batch_size=self.batch_size, transform=self.convert_batch,
                            shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        # Shuffled batches seeded from every node of the stacked graph.
        return self._make_loader(node_idx=None, shuffle=True)

    def test_dataloader(self):
        # Fixed order so outputs line up with self.data.test_idx.
        return self._make_loader(node_idx=self.data.test_idx, shuffle=False)

    def label_loader(self):
        # Fixed order so outputs line up with self.data.train_idx.
        return self._make_loader(node_idx=self.data.train_idx, shuffle=False)

    def convert_batch(self, batch_size, n_id, adjs):
        '''
        Sampler transform: wrap (batch_size, n_id, adjs) into a Batch.

        Features are gathered for every sampled node; labels and revenue only
        for the first ``batch_size`` ids (the target nodes).
        '''
        return Batch(
            x=self.data.x[n_id],
            y=self.data.y[n_id[:batch_size]],
            rev=self.data.rev[n_id[:batch_size]],
            adjs_t=adjs,
        )

## Model

In [16]:
from models import MLP, GNNStack, UselessConv, Mish

In [17]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import LightningModule, seed_everything
from torchtools.optim import RangerLars
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F

In [19]:
class PretrainGNN(LightningModule):
    """GNN encoder trained without labels on UnsupSampler batches.

    Each batch's target nodes are seeds, sampled neighbours and random nodes in
    three equal parts. The loss raises the dot product between each seed and
    its neighbour and lowers it between each seed and its random node.

    Args:
        input_dim: ``num_embeddings`` of the leaf nn.Embedding when ``useXGB``
            is True (index 0 is padding); otherwise the MLP input width.
        hidden_dim: the working embedding width is ``hidden_dim * 2``.
        numLayers: number of layers passed to GNNStack.
        useXGB: embed ``x`` as leaf indices and sum over trees, instead of
            passing dense features through an MLP.
        lr, weight_decay, lr_gamma: RangerLars learning rate / weight decay and
            ExponentialLR decay factor. The defaults are the previously
            hard-coded values.
    """

    def __init__(self, input_dim, hidden_dim, numLayers, useXGB=True,
                 lr=0.005, weight_decay=0.0001, lr_gamma=0.8):
        super().__init__()
        self.save_hyperparameters()
        self.input_dim = input_dim
        self.dim = hidden_dim * 2
        self.numLayers = numLayers
        self.layers = [self.dim, self.dim // 2]  # * (numLayers+1)
        self.bn = nn.BatchNorm1d(self.dim)
        self.act = Mish()
        self.useXGB = useXGB
        # Optimiser settings read by configure_optimizers.
        self.lr = lr
        self.weight_decay = weight_decay
        self.lr_gamma = lr_gamma

        # GNN layer
        if self.useXGB:
            self.initEmbedding = nn.Embedding(self.input_dim, self.dim, padding_idx=0)
        else:
            self.initEmbedding = MLP(self.input_dim, self.dim, Numlayer=2)
        self.initGNN = UselessConv()
        self.GNNs = GNNStack(self.layers, self.numLayers)

    def forward(self, x, adjs):
        """Embed the sampled nodes and return the final GNN layer's output.

        Args:
            x: node inputs; with ``useXGB`` presumably (num_nodes, num_trees)
               leaf indices (summed over dim 1 below) — TODO confirm.
            adjs: per-hop (adjacency, edge ids, size) triples; ``adjs[-1]`` is
               the hop that ends at the target nodes.
        """
        # update node embedding
        leaf_emb = self.initEmbedding(x)
        if self.useXGB:
            leaf_emb = torch.sum(leaf_emb, dim=1)  # summation over the trees
            leaf_emb = self.bn(leaf_emb)
            leaf_emb = self.act(leaf_emb)

        # First update, over the (symmetrised) edges of the last hop.
        firstHop_neighbor = adjs[-1][0]
        leaf_emb = self.initGNN(leaf_emb, to_undirected(firstHop_neighbor))

        # GNN
        embeddings = self.GNNs(leaf_emb, adjs)

        return embeddings[-1]

    def training_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        # Targets are [seeds | positives | negatives] (see UnsupSampler).
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
        neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
        train_loss = -pos_loss - neg_loss
        self.log('train_loss', train_loss)
        return train_loss

    def test_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        # Keep only the seed embeddings.
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        return out

    def test_epoch_end(self, val_step_outputs):
        # `val_step_outputs` holds the test_step outputs (seed embeddings).
        val_step_outputs = torch.cat(val_step_outputs)
        val_step_outputs = val_step_outputs.cpu().detach().numpy()
        return {"log": {"predictions": val_step_outputs}}

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.lr_gamma)
        return [optimizer], [scheduler]

In [20]:
# --- pre-training model configuration ---
# seed_everything(2345)  # NOTE(review): pre-training currently runs unseeded
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [50, 20]          # neighbours sampled per node on each hop
numLayers = len(sizes)    # one GNN layer per sampled hop
batch_size = 1024

model = PretrainGNN(input_dim, hidden_size, numLayers, useXGB=gdata.use_xgb)

# --- lightning configuration ---
stacked_data = StackData(trainLab_data, unlab_data, valid_data, test_data)
datamodule = UnsupData(stacked_data, sizes=sizes, batch_size=batch_size)
checkpoint_callback = ModelCheckpoint(
    dirpath='./saved_model',
    filename='Tdata5-pretrain-{train_loss:.4f}',
    monitor='train_loss',
    mode='min',
    save_top_k=1,
)
# NOTE: checkpoint_callback is not handed to the Trainer, so pre-training
# writes no checkpoint of its own here.
trainer = Trainer(gpus=[0], max_epochs=2)
trainer.fit(model, train_dataloader=datamodule.train_dataloader())

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type        | Params
----------------------------------------------
0 | bn            | BatchNorm1d | 128   
1 | act           | Mish        | 0     
2 | initEmbedding | Embedding   | 98.8 K
3 | initGNN       | UselessConv | 0     
4 | GNNs          | GNNStack    | 33.5 K
----------------------------------------------
132 K     Trainable params
0         Non-trainable params
132 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1

In [None]:
# The pre-training Trainer was only given a train dataloader and PretrainGNN
# defines no test_dataloader, so pass the model and test loader explicitly.
trainer.test(model, datamodule.test_dataloader())

In [None]:
# trainer.save_checkpoint("./saved_model/pretrained.ckpt")

In [None]:
from tqdm import tqdm_notebook

In [None]:
# Embed the test nodes with the pre-trained encoder.
# eval() makes BatchNorm use (and not update) its running statistics, and
# no_grad() skips building an autograd graph during inference.
model.eval()
test_embeddings = []
with torch.no_grad():
    for batch in tqdm_notebook(datamodule.test_dataloader()):
        batch = batch.to(model.device)
        out = model(batch.x, batch.adjs_t)
        # Keep only the seed third; UnsupSampler's positive/negative parts are unused here.
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        test_embeddings.append(out.cpu().numpy())

In [None]:
# Embed the labelled training nodes with the pre-trained encoder.
# eval() makes BatchNorm use (and not update) its running statistics, and
# no_grad() skips building an autograd graph during inference.
model.eval()
embeddings = []
with torch.no_grad():
    for batch in tqdm_notebook(datamodule.label_loader()):
        batch = batch.to(model.device)
        out = model(batch.x, batch.adjs_t)
        # Keep only the seed third; UnsupSampler's positive/negative parts are unused here.
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        embeddings.append(out.cpu().numpy())

In [None]:
# Stack the per-batch 2-D arrays into one matrix per split. np.vstack leaves an
# already-stacked 2-D array unchanged, so re-running this cell is safe
# (np.concatenate on the 2-D result would flatten it to 1-D).
embeddings = np.vstack(embeddings)
test_embeddings = np.vstack(test_embeddings)

In [None]:
from sklearn.linear_model import LogisticRegression

In [None]:
# Linear probe on the frozen embeddings; the positive class is up-weighted 50x
# given the ~10% positive rate reported for the training split.
# NOTE(review): this rebinds `lr`, which earlier held args.lr.
lr = LogisticRegression(class_weight={0: 1, 1: 50})
lr.fit(embeddings, data.train_cls_label)

In [None]:
# Positive-class probability (predict_proba column 1) for every test embedding.
test_proba = lr.predict_proba(test_embeddings)
y_prob = test_proba[:, 1]

In [None]:
# Report the probe's test metrics; the return value is discarded.
_ = metrics(
    y_prob,
    data.test_cls_label,
    data.test_reg_label,
    None,  # threshold argument; presumably None means "no fixed threshold" — confirm in utils
)

## Fine-tuning

In [21]:
from torch_geometric.nn import GATConv,TransformerConv
from pygData_util import *

In [22]:
class Predictor(LightningModule):
    """Fine-tunes a PretrainGNN encoder with a classification and a revenue head.

    The encoder embedding of each target node is mapped to a logit (``clsLayer``,
    BCE-with-logits loss) and a revenue estimate (``revLayer``, MSE loss). The
    validation/test epoch-end hooks read labels from ``self.data``, which the
    caller must attach before fitting (e.g. ``predictor.data = data``).

    Args:
        input_dim: forwarded to PretrainGNN.
        hidden_dim: forwarded to PretrainGNN; the heads read embeddings of
            width ``hidden_dim * 2``.
        numLayers: forwarded to PretrainGNN.
        useXGB: forwarded to PretrainGNN.
    """

    # Weight of the revenue (MSE) term in the combined loss.
    # NOTE(review): training uses 10 but validation uses 1, so train_loss and
    # val_loss are on different scales; kept to preserve the logged values.
    train_reg_weight = 10
    valid_reg_weight = 1

    # Names for [*f, *pr, *re, *rev] as returned by torch_metrics.
    _METRIC_NAMES = ["F1@1", "F1@2", "F1@5", "F1@10", "Pr@1", "Pr@2", "Pr@5", "Pr@10",
                     "Re@1", "Re@2", "Re@5", "Re@10", "Rev@1", "Rev@2", "Rev@5", "Rev@10"]

    def __init__(self, input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        # Built from the hidden_dim argument, not from a notebook global.
        self.gnn_encoder = PretrainGNN(input_dim, hidden_dim, numLayers, useXGB)
        self.dim = hidden_dim * 2

        # output heads (GATConv(self.dim, 1) was an alternative tried)
        self.clsLayer = nn.Linear(self.dim, 1)
        self.revLayer = nn.Linear(self.dim, 1)
        # pos_weight of 1 is the unweighted loss (FocalLoss was an alternative tried).
        self.loss_func = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0]))

    def load_fromPretrain(self, path):
        """Copy encoder weights from the PretrainGNN Lightning checkpoint at ``path``.

        ``load_from_checkpoint`` is a classmethod that returns a new model, so
        its result must be copied into ``self.gnn_encoder``.
        """
        pretrained = type(self.gnn_encoder).load_from_checkpoint(path, map_location="cpu")
        self.loadGNN_state(pretrained)

    def loadGNN_state(self, model):
        """Copy the state_dict of an in-memory PretrainGNN into the encoder."""
        self.gnn_encoder.load_state_dict(model.state_dict())

    def forward(self, x, adjs):
        """Return (logit, revenue) predictions for the target nodes."""
        embedding = self.gnn_encoder(x, adjs)
        logit = self.clsLayer(embedding)
        revenue = self.revLayer(embedding)
        return logit, revenue

    def compute_CLS_loss(self, logit, label):
        logit = logit.flatten()
        loss = self.loss_func(logit, label)
        return loss

    def compute_REG_loss(self, pred_rev, rev):
        pred_rev = pred_rev.flatten()
        loss = F.mse_loss(pred_rev, rev)
        return loss

    def _step_loss(self, batch, reg_weight):
        """Forward a batch; return (logits, CLS_loss + reg_weight * REG_loss)."""
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        return logits, CLS_loss + reg_weight * REG_loss

    def _summarize(self, step_outputs, cls_label, reg_label, prefix):
        """Score concatenated step logits.

        Returns:
            (mean of the four top-k F1 scores, {prefix + metric name: value}).
        """
        predictions = torch.cat(step_outputs).detach().cpu().numpy().ravel()
        f, pr, re, rev = torch_metrics(predictions, cls_label, reg_label)
        f1_top = np.mean(f)
        performance = [*f, *pr, *re, *rev]
        names = [prefix + name for name in self._METRIC_NAMES]
        return f1_top, dict(zip(names, performance))

    def training_step(self, batch, batch_idx: int):
        _, train_loss = self._step_loss(batch, self.train_reg_weight)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx: int):
        logits, valid_loss = self._step_loss(batch, self.valid_reg_weight)
        self.log('val_loss', valid_loss, on_step=True, on_epoch=True, sync_dist=True)
        return logits

    def validation_epoch_end(self, val_step_outputs):
        f1_top, tensorboard_logs = self._summarize(
            val_step_outputs, self.data.valid_cls_label, self.data.valid_reg_label, "Val/")
        # "F1-top" is the metric the fine-tuning checkpoint callback monitors.
        self.log("F1-top", f1_top)
        return {"Val/F1-top": f1_top, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, val_step_outputs):
        f1_top, tensorboard_logs = self._summarize(
            val_step_outputs, self.data.test_cls_label, self.data.test_reg_label, "Test/")
        return {"F1-top": f1_top, "log": tensorboard_logs}

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.01, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
        return [optimizer], [scheduler]

In [None]:
# --- fine-tuning model configuration ---
seed_everything(1324)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [-1, 200]          # per-hop neighbour counts; -1 = all neighbours
numLayers = len(sizes)
batch_size = 512
pretrain_path = "./saved_model/pretrained.ckpt"

# Use the same feature mode as the pre-trained encoder, then copy its
# in-memory weights (load_fromPretrain(pretrain_path) loads from disk instead).
predictor = Predictor(input_dim, hidden_size, numLayers, useXGB=gdata.use_xgb)
predictor.loadGNN_state(model)
predictor.data = data  # validation/test epoch-end hooks read labels from here

# --- lightning configuration ---
stacked_data = StackData(trainLab_data, unlab_data, valid_data, test_data)
datamodule = CustomData(stacked_data, sizes=sizes, batch_size=batch_size)
logger = TensorBoardLogger("ssl_exp", name="Tdata")
# NOTE(review): these are the pre-trained encoder's hparams, not the predictor's.
logger.log_hyperparams(model.hparams, metrics={"F1-top": 0})
checkpoint_callback = ModelCheckpoint(
    monitor='F1-top',
    dirpath='./saved_model',
    filename='Analysis-Tdata5-{F1-top:.4f}',
    save_top_k=1,
    mode='max',
)
# The logger is passed so training metrics go to the same TensorBoard run that
# the hyperparameters above were logged to.
trainer = Trainer(gpus=[0], max_epochs=40,
                  num_sanity_val_steps=0,
                  check_val_every_n_epoch=1,
                  callbacks=[checkpoint_callback],
                  logger=logger,
                  )
trainer.fit(predictor, datamodule=datamodule)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type              | Params
--------------------------------------------------
0 | gnn_encoder | PretrainGNN       | 132 K 
1 | clsLayer    | Linear            | 65    
2 | revLayer    | Linear            | 65    
3 | loss_func   | BCEWithLogitsLoss | 0     
--------------------------------------------------
132 K     Trainable params
0         Non-trainable params
132 K     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.5100, Recall: 0.0583, Revenue: 0.0537
Checking top 2% suspicious transactions: 2690
Precision: 0.5257, Recall: 0.1202, Revenue: 0.1138
Checking top 5% suspicious transactions: 6723
Precision: 0.4428, Recall: 0.2530, Revenue: 0.2690
Checking top 10% suspicious transactions: 13446
Precision: 0.3258, Recall: 0.3724, Revenue: 0.4031


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.5807, Recall: 0.0664, Revenue: 0.0538
Checking top 2% suspicious transactions: 2689
Precision: 0.6010, Recall: 0.1374, Revenue: 0.1313
Checking top 5% suspicious transactions: 6723
Precision: 0.4687, Recall: 0.2678, Revenue: 0.2837
Checking top 10% suspicious transactions: 13446
Precision: 0.3359, Recall: 0.3839, Revenue: 0.4121


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6885, Recall: 0.0787, Revenue: 0.0760
Checking top 2% suspicious transactions: 2690
Precision: 0.6346, Recall: 0.1451, Revenue: 0.1427
Checking top 5% suspicious transactions: 6723
Precision: 0.4840, Recall: 0.2766, Revenue: 0.3000
Checking top 10% suspicious transactions: 13446
Precision: 0.3427, Recall: 0.3917, Revenue: 0.4218


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7138, Recall: 0.0816, Revenue: 0.0804
Checking top 2% suspicious transactions: 2690
Precision: 0.6193, Recall: 0.1416, Revenue: 0.1396
Checking top 5% suspicious transactions: 6723
Precision: 0.4825, Recall: 0.2757, Revenue: 0.3074
Checking top 10% suspicious transactions: 13446
Precision: 0.3357, Recall: 0.3837, Revenue: 0.3986


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7175, Recall: 0.0820, Revenue: 0.0839
Checking top 2% suspicious transactions: 2690
Precision: 0.6152, Recall: 0.1407, Revenue: 0.1354
Checking top 5% suspicious transactions: 6723
Precision: 0.4794, Recall: 0.2739, Revenue: 0.3047
Checking top 10% suspicious transactions: 13446
Precision: 0.3371, Recall: 0.3853, Revenue: 0.4034


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7383, Recall: 0.0844, Revenue: 0.0877
Checking top 2% suspicious transactions: 2690
Precision: 0.6312, Recall: 0.1443, Revenue: 0.1461
Checking top 5% suspicious transactions: 6720
Precision: 0.4914, Recall: 0.2807, Revenue: 0.3143
Checking top 10% suspicious transactions: 13446
Precision: 0.3417, Recall: 0.3906, Revenue: 0.4211


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7569, Recall: 0.0865, Revenue: 0.0853
Checking top 2% suspicious transactions: 2690
Precision: 0.6312, Recall: 0.1443, Revenue: 0.1363
Checking top 5% suspicious transactions: 6723
Precision: 0.4795, Recall: 0.2740, Revenue: 0.2932
Checking top 10% suspicious transactions: 13446
Precision: 0.3397, Recall: 0.3882, Revenue: 0.4082


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6840, Recall: 0.0782, Revenue: 0.0828
Checking top 2% suspicious transactions: 2690
Precision: 0.5989, Recall: 0.1369, Revenue: 0.1321
Checking top 5% suspicious transactions: 6723
Precision: 0.4743, Recall: 0.2711, Revenue: 0.2971
Checking top 10% suspicious transactions: 13446
Precision: 0.3311, Recall: 0.3784, Revenue: 0.4118


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6996, Recall: 0.0800, Revenue: 0.0895
Checking top 2% suspicious transactions: 2690
Precision: 0.6037, Recall: 0.1380, Revenue: 0.1436
Checking top 5% suspicious transactions: 6723
Precision: 0.4836, Recall: 0.2763, Revenue: 0.3100
Checking top 10% suspicious transactions: 13446
Precision: 0.3390, Recall: 0.3874, Revenue: 0.4163


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6967, Recall: 0.0796, Revenue: 0.0832
Checking top 2% suspicious transactions: 2690
Precision: 0.6011, Recall: 0.1374, Revenue: 0.1434
Checking top 5% suspicious transactions: 6723
Precision: 0.4821, Recall: 0.2755, Revenue: 0.2988
Checking top 10% suspicious transactions: 13446
Precision: 0.3417, Recall: 0.3905, Revenue: 0.4145


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6974, Recall: 0.0797, Revenue: 0.0858
Checking top 2% suspicious transactions: 2690
Precision: 0.6145, Recall: 0.1405, Revenue: 0.1430
Checking top 5% suspicious transactions: 6723
Precision: 0.4971, Recall: 0.2841, Revenue: 0.3239
Checking top 10% suspicious transactions: 13446
Precision: 0.3452, Recall: 0.3945, Revenue: 0.4180


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7517, Recall: 0.0859, Revenue: 0.0973
Checking top 2% suspicious transactions: 2690
Precision: 0.6409, Recall: 0.1465, Revenue: 0.1663
Checking top 5% suspicious transactions: 6723
Precision: 0.4788, Recall: 0.2736, Revenue: 0.3176
Checking top 10% suspicious transactions: 13446
Precision: 0.3383, Recall: 0.3867, Revenue: 0.4135


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6944, Recall: 0.0794, Revenue: 0.0878
Checking top 2% suspicious transactions: 2690
Precision: 0.6112, Recall: 0.1397, Revenue: 0.1537
Checking top 5% suspicious transactions: 6723
Precision: 0.4784, Recall: 0.2734, Revenue: 0.3131
Checking top 10% suspicious transactions: 13446
Precision: 0.3338, Recall: 0.3815, Revenue: 0.4153


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7063, Recall: 0.0807, Revenue: 0.0945
Checking top 2% suspicious transactions: 2690
Precision: 0.6275, Recall: 0.1435, Revenue: 0.1559
Checking top 5% suspicious transactions: 6723
Precision: 0.4746, Recall: 0.2712, Revenue: 0.3186
Checking top 10% suspicious transactions: 13446
Precision: 0.3362, Recall: 0.3843, Revenue: 0.4274


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7457, Recall: 0.0853, Revenue: 0.0971
Checking top 2% suspicious transactions: 2690
Precision: 0.6446, Recall: 0.1474, Revenue: 0.1605
Checking top 5% suspicious transactions: 6723
Precision: 0.4922, Recall: 0.2813, Revenue: 0.3467
Checking top 10% suspicious transactions: 13446
Precision: 0.3386, Recall: 0.3870, Revenue: 0.4393


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6625, Recall: 0.0757, Revenue: 0.0857
Checking top 2% suspicious transactions: 2690
Precision: 0.6045, Recall: 0.1382, Revenue: 0.1558
Checking top 5% suspicious transactions: 6723
Precision: 0.4788, Recall: 0.2736, Revenue: 0.3280
Checking top 10% suspicious transactions: 13446
Precision: 0.3400, Recall: 0.3886, Revenue: 0.4220


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7004, Recall: 0.0801, Revenue: 0.0938
Checking top 2% suspicious transactions: 2690
Precision: 0.6089, Recall: 0.1392, Revenue: 0.1531
Checking top 5% suspicious transactions: 6723
Precision: 0.4819, Recall: 0.2754, Revenue: 0.3270
Checking top 10% suspicious transactions: 13446
Precision: 0.3429, Recall: 0.3918, Revenue: 0.4240


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7747, Recall: 0.0886, Revenue: 0.1175
Checking top 2% suspicious transactions: 2690
Precision: 0.6335, Recall: 0.1448, Revenue: 0.1632
Checking top 5% suspicious transactions: 6723
Precision: 0.4681, Recall: 0.2675, Revenue: 0.3089
Checking top 10% suspicious transactions: 13446
Precision: 0.3282, Recall: 0.3751, Revenue: 0.4229


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6848, Recall: 0.0783, Revenue: 0.0997
Checking top 2% suspicious transactions: 2690
Precision: 0.6004, Recall: 0.1373, Revenue: 0.1541
Checking top 5% suspicious transactions: 6723
Precision: 0.4641, Recall: 0.2652, Revenue: 0.3193
Checking top 10% suspicious transactions: 13446
Precision: 0.3334, Recall: 0.3810, Revenue: 0.4137


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6766, Recall: 0.0773, Revenue: 0.1079
Checking top 2% suspicious transactions: 2690
Precision: 0.6059, Recall: 0.1385, Revenue: 0.1591
Checking top 5% suspicious transactions: 6723
Precision: 0.4760, Recall: 0.2720, Revenue: 0.3191
Checking top 10% suspicious transactions: 13446
Precision: 0.3318, Recall: 0.3792, Revenue: 0.4184


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7234, Recall: 0.0827, Revenue: 0.1098
Checking top 2% suspicious transactions: 2690
Precision: 0.5989, Recall: 0.1369, Revenue: 0.1578
Checking top 5% suspicious transactions: 6723
Precision: 0.4581, Recall: 0.2618, Revenue: 0.3077
Checking top 10% suspicious transactions: 13446
Precision: 0.3308, Recall: 0.3781, Revenue: 0.4108


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6900, Recall: 0.0789, Revenue: 0.1080
Checking top 2% suspicious transactions: 2690
Precision: 0.6048, Recall: 0.1383, Revenue: 0.1711
Checking top 5% suspicious transactions: 6723
Precision: 0.4754, Recall: 0.2717, Revenue: 0.3218
Checking top 10% suspicious transactions: 13446
Precision: 0.3350, Recall: 0.3828, Revenue: 0.4217


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7078, Recall: 0.0809, Revenue: 0.1022
Checking top 2% suspicious transactions: 2690
Precision: 0.6167, Recall: 0.1410, Revenue: 0.1788
Checking top 5% suspicious transactions: 6723
Precision: 0.4764, Recall: 0.2722, Revenue: 0.3406
Checking top 10% suspicious transactions: 13446
Precision: 0.3323, Recall: 0.3798, Revenue: 0.4260


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7301, Recall: 0.0835, Revenue: 0.1093
Checking top 2% suspicious transactions: 2690
Precision: 0.6257, Recall: 0.1431, Revenue: 0.1623
Checking top 5% suspicious transactions: 6723
Precision: 0.4684, Recall: 0.2677, Revenue: 0.3120
Checking top 10% suspicious transactions: 13446
Precision: 0.3389, Recall: 0.3873, Revenue: 0.4249


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6900, Recall: 0.0789, Revenue: 0.1075
Checking top 2% suspicious transactions: 2690
Precision: 0.6059, Recall: 0.1385, Revenue: 0.1628
Checking top 5% suspicious transactions: 6723
Precision: 0.4672, Recall: 0.2670, Revenue: 0.3138
Checking top 10% suspicious transactions: 13446
Precision: 0.3387, Recall: 0.3871, Revenue: 0.4196


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.6186, Recall: 0.0707, Revenue: 0.0965
Checking top 2% suspicious transactions: 2690
Precision: 0.5755, Recall: 0.1316, Revenue: 0.1506
Checking top 5% suspicious transactions: 6723
Precision: 0.4647, Recall: 0.2655, Revenue: 0.3095
Checking top 10% suspicious transactions: 13446
Precision: 0.3304, Recall: 0.3776, Revenue: 0.4194


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7011, Recall: 0.0802, Revenue: 0.1079
Checking top 2% suspicious transactions: 2690
Precision: 0.6186, Recall: 0.1414, Revenue: 0.1672
Checking top 5% suspicious transactions: 6723
Precision: 0.4639, Recall: 0.2651, Revenue: 0.3172
Checking top 10% suspicious transactions: 13446
Precision: 0.3360, Recall: 0.3840, Revenue: 0.4222


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 1345
Precision: 0.7026, Recall: 0.0803, Revenue: 0.1068
Checking top 2% suspicious transactions: 2690
Precision: 0.6093, Recall: 0.1393, Revenue: 0.1608
Checking top 5% suspicious transactions: 6723
Precision: 0.4674, Recall: 0.2671, Revenue: 0.3204
Checking top 10% suspicious transactions: 13446
Precision: 0.3365, Recall: 0.3846, Revenue: 0.4119


In [25]:
# Evaluate the best checkpoint (highest F1-top, per checkpoint_callback) on the test split.
trainer.test(ckpt_path="best")

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

Checking top 1% suspicious transactions: 7031
Precision: 0.8175, Recall: 0.0862, Revenue: 0.1513
Checking top 2% suspicious transactions: 14062
Precision: 0.7298, Recall: 0.1539, Revenue: 0.2582
Checking top 5% suspicious transactions: 35155
Precision: 0.5184, Recall: 0.2733, Revenue: 0.4318
Checking top 10% suspicious transactions: 70309
Precision: 0.3510, Recall: 0.3702, Revenue: 0.5326

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'F1-top': 0.2821087085208171,
 'Test/F1@1': 0.15597313615087172,
 'Test/F1@10': 0.36033668411408715,
 'Test/F1@2': 0.25421125644074516,
 'Test/F1@5': 0.3579137573775643,
 'Test/Pr@1': 0.8175224007964728,
 'Test/Pr@10': 0.35102191753545064,
 'Test/Pr@2': 0.7297681695349167,
 'Test/Pr@5': 0.5183615417437064,
 'Test/Re@1': 0.08621051684314725,
 'Test/Re@10': 0.3701592824789273,
 'Test/Re@2': 0.15391306956234815,
 'Test/Re@5': 0.2733149353571107,
 'Test/Rev@1': 0.15128984825490807,
 'Test/Rev@10': 

[{'Test/F1@1': 0.15597313615087172,
  'Test/F1@2': 0.25421125644074516,
  'Test/F1@5': 0.3579137573775643,
  'Test/F1@10': 0.36033668411408715,
  'Test/Pr@1': 0.8175224007964728,
  'Test/Pr@2': 0.7297681695349167,
  'Test/Pr@5': 0.5183615417437064,
  'Test/Pr@10': 0.35102191753545064,
  'Test/Re@1': 0.08621051684314725,
  'Test/Re@2': 0.15391306956234815,
  'Test/Re@5': 0.2733149353571107,
  'Test/Re@10': 0.3701592824789273,
  'Test/Rev@1': 0.15128984825490807,
  'Test/Rev@2': 0.2581825200543137,
  'Test/Rev@5': 0.4317528452617966,
  'Test/Rev@10': 0.5325944658083146,
  'F1-top': 0.2821087085208171,
  'val_loss_epoch': 0.3280564844608307,
  'val_loss': 0.31299495697021484,
  'Val/F1-top': 0.26713715287047246,
  'Val/F1@1': 0.147673531655225,
  'Val/F1@2': 0.2285714285714286,
  'Val/F1@5': 0.3408697533535266,
  'Val/F1@10': 0.3514338979017096,
  'Val/Pr@1': 0.7197026022304833,
  'Val/Pr@2': 0.6141263940520446,
  'Val/Pr@5': 0.4686895731072438,
  'Val/Pr@10': 0.3294660121969359,
  'Val/R