## Training with SAINT (Self-Attention and Intersample Attention Transformer) model

Tabular data underpins numerous high-impact applications of machine learning, from fraud detection to genomics and healthcare. Classical approaches to solving tabular problems, such as gradient boosting and random forests, are widely used. SAINT is a deep learning alternative: it performs attention over both rows and columns, and it includes an enhanced embedding method. SAINT consistently improves performance over previous deep learning methods, and it even outperforms gradient boosting methods, including XGBoost, CatBoost, and LightGBM, on average over a variety of benchmark tasks.

<center><img src="https://media.arxiv-vanity.com/render-output/6225784/figs/Tabattention_training.png" alt="SAINT pipeline" width="600"/></center>
<center>SAINT model. Source: <a href="https://arxiv.org/pdf/2106.01342v1.pdf">Paper(arxiv),</a> <a href="https://www.arxiv-vanity.com/papers/2106.01342/">image</a>.</center>


### Import libraries

In [1]:
import numpy as np
import pandas as pd
import pickle, optuna, torch, pytorch_lightning
from pathlib import Path
from sklearn.metrics import classification_report, roc_auc_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
from lit_saint import Saint, SaintConfig, SaintDatamodule, SaintTrainer
from pytorch_lightning import Trainer, seed_everything
from IPython.display import clear_output
import functools

### Load Dataset

In [2]:
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# dataset files that come from a trusted source.
with open('../data/dataset_df.pkl', 'rb') as f:
    df = pickle.load(f)

# Columns listed here are numeric; every other column is treated as categorical.
col_types = {
    'MONTHS_IN_RESIDENCE':             'float64',
    'PERSONAL_MONTHLY_INCOME':         'float64',
    'OTHER_INCOMES':                   'float64',
    'PERSONAL_ASSETS_VALUE':           'float64',
    'MONTHS_IN_THE_JOB':                 'int64',
    'AGE':                               'int64'
}

# Cast every column in one call: numeric columns get their explicit dtype,
# all remaining columns become pandas categoricals.
df = df.astype({col: col_types.get(col, 'category') for col in df.columns})

# 60/20/20 train/validation/test split: 20% is held out for test, then 25% of
# the remaining 80% becomes the validation set.
df_train, df_test = train_test_split(df, test_size=0.20, random_state=42)
df_train, df_val = train_test_split(df_train, test_size=0.25, random_state=42)

# assign() returns new frames, so the split label is never written into a
# slice of another DataFrame (avoids pandas' SettingWithCopyWarning).
df_train = df_train.assign(split="train")
df_val = df_val.assign(split="validation")

# From here on df holds only the train and validation rows, labelled by the
# "split" column; df_test is kept aside.
df = pd.concat([df_train, df_val])

### Configuration object
# Default-constructed SAINT configuration. This single instance is shared by
# the cells below, which overwrite its nested fields in place via rsetattr.
cfg = SaintConfig()

### Define helper functions to get and set nested attributes of the configuration object

In [None]:
def rsetattr(obj, attr, val):
    """
    Assign a value to an attribute addressed by a dotted path.

    Parameters
    ----------
    obj : Object
        root object from which the dotted path is resolved.
    attr : str
        attribute path such as ``'a.b.c'``; everything before the last dot
        is resolved with ``rgetattr`` and the last component is the
        attribute that gets set.
    val : *
        value stored in the addressed attribute.

    Returns
    -------
    None, i.e. the result of the underlying ``setattr`` call.
    """
    parent_path, _, leaf = attr.rpartition('.')
    # Without a dot there is no parent to resolve: set directly on obj.
    if parent_path:
        owner = rgetattr(obj, parent_path)
    else:
        owner = obj
    return setattr(owner, leaf, val)

def rgetattr(obj, attr):
    """
    Look up an attribute addressed by a dotted path.

    Parameters
    ----------
    obj : Object
        root object from which the dotted path is resolved.
    attr : str
        attribute path such as ``'a.b.c'``; each component is looked up on
        the result of the previous one.

    Returns
    -------
    value found at the end of the path. ``AttributeError`` propagates if any
    component is missing.
    """
    current = obj
    # Walk the path one component at a time, left to right.
    for name in attr.split('.'):
        current = getattr(current, name)
    return current

### Hyperparameter tuning with Optuna

In [None]:
# Data module built from the combined train/validation frame. The target is
# the column at position 28, and "split" is passed as the split column
# (presumably used by lit_saint to separate train from validation rows).
data_module = SaintDatamodule(df=df, target=df.columns[28], split_column="split")
# Loader settings handed to SaintTrainer for the pretraining and training phases.
pretrain_loader_params = {'batch_size': 1000, 'num_workers': 8}
train_loader_params = {'batch_size': 10000, 'num_workers': 8}

def objective(trial):
    """
    Optuna's objective function to be optimized.

    Samples a set of hyperparameters, writes them into the shared ``cfg``
    object, fits a SAINT model (pretraining followed by training) and scores
    it on the validation frame ``df_val``.

    Parameters
    ----------
    trial : Object
        Optuna's trial object, defines params grid to feed the SAINT model.

    Returns
    -------
    roc_auc_score : float
        ROC AUC score of the current trial.
    recall_score : float
        Recall score of the current trial.
    """
    # `step` is passed by keyword: recent Optuna releases deprecate passing it
    # positionally to suggest_int.
    params = {
        'network.transformer.depth': trial.suggest_int('network.transformer.depth', 2, 10, step=1),
        'network.transformer.heads': trial.suggest_int('network.transformer.heads', 1, 5, step=1),
        'network.transformer.dropout': trial.suggest_float('network.transformer.dropout', 0.1, 0.9),
        # The trial records the exponent (4..7); the model is configured with
        # 2 ** exponent. Anything reading trial.params later must apply the
        # same transformation.
        'network.transformer.dim_head': 2 ** trial.suggest_int('network.transformer.dim_head', 4, 7, step=1),
        'pretrain.aug.cutmix.lam': trial.suggest_float('pretrain.aug.cutmix.lam', 0.1, 0.9),
        'pretrain.aug.mixup.lam': trial.suggest_float('pretrain.aug.mixup.lam', 0.1, 0.9),
        'pretrain.task.contrastive.nce_temp': trial.suggest_float('pretrain.task.contrastive.nce_temp', 0.1, 0.9),
        'pretrain.task.contrastive.weight': trial.suggest_float('pretrain.task.contrastive.weight', 0.1, 0.9),
        'pretrain.task.contrastive.dropout': trial.suggest_float('pretrain.task.contrastive.dropout', 0.1, 0.9),
        'pretrain.task.denoising.weight_cross_entropy': trial.suggest_float('pretrain.task.denoising.weight_cross_entropy', 0.1, 0.9),
        'pretrain.task.denoising.weight_mse': trial.suggest_float('pretrain.task.denoising.weight_mse', 0.1, 0.9),
        'pretrain.task.denoising.dropout': trial.suggest_float('pretrain.task.denoising.dropout', 0.1, 0.9),
        'pretrain.optimizer.learning_rate': trial.suggest_float('pretrain.optimizer.learning_rate', 1e-5, 1e-3, log=True),
        'train.optimizer.learning_rate': trial.suggest_float('train.optimizer.learning_rate', 1e-5, 1e-3, log=True),
        'train.mlpfory_dropout': trial.suggest_float('train.mlpfory_dropout', 0.1, 0.9),
        'train.internal_dimension_output_layer': trial.suggest_int('train.internal_dimension_output_layer', 10, 30, step=10),
        # Fixed (not tuned) training lengths.
        'pretrain.epochs': 90,
        'train.epochs': 400
    }

    # Write every dotted key into the shared module-level configuration object.
    for key, value in params.items():
        rsetattr(cfg, key, value)
    print(cfg)

    model = Saint(categories=data_module.categorical_dims, continuous=data_module.numerical_columns,
                  config=cfg, dim_target=data_module.dim_target)
    pretrainer = Trainer(max_epochs=cfg.pretrain.epochs, accelerator='gpu', devices=-1, log_every_n_steps=30, enable_progress_bar=False)
    trainer = Trainer(max_epochs=cfg.train.epochs, accelerator='gpu', devices=-1, log_every_n_steps=15, enable_progress_bar=False)
    saint_trainer = SaintTrainer(pretrainer=pretrainer, trainer=trainer,
                                 pretrain_loader_params=pretrain_loader_params, train_loader_params=train_loader_params)
    saint_trainer.fit(model=model, datamodule=data_module, enable_pretraining=True)

    prediction = saint_trainer.predict(model=model, datamodule=data_module, df=df_val)
    # presumably column 1 holds the positive-class score -- TODO confirm against lit_saint
    y_pred = prediction["prediction"][:, 1]
    y_test = df_val[df.columns[28]]

    # Wipe the cell output (printed config, trainer logs) before the next trial.
    clear_output(wait=True)
    # Recall is computed on hard labels obtained by rounding the scores.
    return roc_auc_score(y_test, y_pred), recall_score(y_test, y_pred.round())

# Multi-objective study: one 'maximize' direction per value returned by
# objective (ROC AUC, recall).
study = optuna.create_study(directions=['maximize', 'maximize'])
# Stop after 500 trials or 10 hours (timeout is given in seconds), whichever comes first.
study.optimize(objective, n_trials=500, timeout=60*60*10)

### Results of the hyperparameter tuning

In [21]:
# The study optimises two directions, so there is no single best trial:
# best_trials is the list of Pareto-optimal trials and index 0 is just the
# first entry of that list, not a trial selected by either metric.
roc_auc, recall = study.best_trials[0].values
# params holds the raw suggested values; for 'network.transformer.dim_head'
# that is the exponent, while objective configured the model with 2 ** exponent.
best_params = study.best_trials[0].params
print('Metrics on validation dataset using the best params:')
print(f'ROC AUC: {roc_auc}, Recall: {recall}')
# Bare expression so the notebook displays the parameter dict as the cell output.
best_params

Metrics on validation dataset using the best params:
ROC AUC: 0.6273642907415584, Recall: 0.01466615206483983


{'network.transformer.depth': 7,
 'network.transformer.heads': 3,
 'network.transformer.dropout': 0.13811672612633155,
 'network.transformer.dim_head': 5,
 'pretrain.aug.cutmix.lam': 0.7280559655740964,
 'pretrain.aug.mixup.lam': 0.535127688615649,
 'pretrain.task.contrastive.nce_temp': 0.5680516605800289,
 'pretrain.task.contrastive.weight': 0.2691343955809094,
 'pretrain.task.contrastive.dropout': 0.518691770224818,
 'pretrain.task.denoising.weight_cross_entropy': 0.20017568961839238,
 'pretrain.task.denoising.weight_mse': 0.35979232505043923,
 'pretrain.task.denoising.dropout': 0.8164470869440895,
 'pretrain.optimizer.learning_rate': 3.5047338907672324e-05,
 'train.optimizer.learning_rate': 0.0004803097270582274,
 'train.mlpfory_dropout': 0.3976845870661899,
 'train.internal_dimension_output_layer': 20}

### Training the best model

In [None]:
data_module = SaintDatamodule(df=df, target=df.columns[28], split_column="split")
# NOTE(review): the tuning cell pretrains with batch_size 1000; confirm that
# 2500 is intended for the final model.
pretrain_loader_params = {'batch_size': 2500, 'num_workers': 8}
train_loader_params = {'batch_size': 10000, 'num_workers': 8}
# Fixed training lengths, same values as used during tuning.
params = {
    'pretrain.epochs': 90,
    'train.epochs': 400
}
for key, value in params.items():
    rsetattr(cfg, key, value)
# Apply the tuned hyperparameters. Values must be read from best_params:
# params only contains the two epoch keys, so indexing it with a tuned key
# raises KeyError.
for key, value in best_params.items():
    if key == 'network.transformer.dim_head':
        # The study stores the exponent; objective trained with 2 ** exponent.
        value = 2 ** value
    rsetattr(cfg, key, value)

model = Saint(categories=data_module.categorical_dims, continuous=data_module.numerical_columns,
              config=cfg, dim_target=data_module.dim_target)
pretrainer = Trainer(max_epochs=cfg.pretrain.epochs, accelerator='gpu', devices=-1, log_every_n_steps=10, enable_progress_bar=False)
trainer = Trainer(max_epochs=cfg.train.epochs, accelerator='gpu', devices=-1, log_every_n_steps=1, enable_progress_bar=False)
saint_trainer = SaintTrainer(pretrainer=pretrainer, trainer=trainer,
                             pretrain_loader_params=pretrain_loader_params, train_loader_params=train_loader_params)
# Pretraining followed by supervised training, as in the tuning objective.
saint_trainer.fit(model=model, datamodule=data_module, enable_pretraining=True)

### Make predictions

In [116]:
# Score the final model on the validation frame (df_val) with the same two
# metrics used by the tuning objective.
prediction = saint_trainer.predict(model=model, datamodule=data_module, df=df_val)
# presumably column 1 holds the positive-class score -- TODO confirm against lit_saint
y_pred = prediction["prediction"][:,1]
y_test = df_val[df.columns[28]]
# ROC AUC uses the raw scores; recall uses hard labels obtained by rounding them.
print(f'ROC_AUC: {roc_auc_score(y_test, y_pred):.4f}, Recall: {recall_score(y_test, y_pred.round()):.4f}')

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 3it [00:00, ?it/s]

ROC_AUC: 0.6204, Recall: 0.0000
