**NOTES**

- **Comments with two hash symbols (##) are notes for whoever runs this notebook. They give cell-by-cell instructions on what to modify and what to leave as it is.**

- **To avoid out-of-memory issues, it is strongly recommended not to run unneeded code cells. Some cells are included for demonstration purposes only and do not need to be run. Please follow the running instructions carefully: running additional or unnecessary code can lead to excessive memory usage, causing Colab to disconnect.**

---
# **IMPORT AND DRIVE MOUNTING**

This section imports the necessary libraries and mounts Google Drive if you run the notebook in Colab. It ensures that the required dependencies are available and that the notebook can access the dataset from Google Drive.

> ## **Drive Mounting and CWD**
>
> **_Important:_**  
>
> **If the "Tabular_Transformer" folder is a shared folder, you will need to create a shortcut to it in your own Drive. You can do this by navigating to the shared folder, right-clicking, and selecting "Add shortcut to Drive". Once you add the shortcut to your Drive, you should be able to access it from Colab as described below. Be sure to set the correct ROOT_DIR path.**

In [None]:
%%capture
## CREATE A SHORTCUT TO THE DRIVE AND RUN THIS CELL
import os
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

## ADJUST THE PATHS IF NEEDED
ROOT_DIR='/content/drive/MyDrive/Tabular_Transformer/Credit_Card'
RAW_DATA_DIR = os.path.join(ROOT_DIR, 'data/raw')
PROCESSED_DATA_DIR = os.path.join(ROOT_DIR, 'data/processed')
VOCAB_DIR = os.path.join(ROOT_DIR, 'vocab')

# navigate to the root directory and install the required dependencies listed in requirements.txt
os.chdir(ROOT_DIR)
!pip install -r requirements.txt

>## **Logging**
>
> This cell initializes a basic logging configuration to monitor the activities.

In [None]:
## RUN THIS CELL
from src.utils import setup_logging

logger = setup_logging()

> ## **Imports**
>
> Run this cell to import the required dependencies.

In [None]:
## RUN THIS CELL - NO CHANGES NEEDED
import json
import pickle as pkl
import random
import inspect
from tqdm import tqdm
from pathlib import Path
from typing import Optional, Tuple, List, Any, Dict, Union
import wandb

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import init
from torch.utils.data import Dataset

from scipy import stats
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import SMOTE

from transformers import BertConfig, BertTokenizerFast, DataCollatorForLanguageModeling, BertModel, PreTrainedModel
from transformers import TrainingArguments, Trainer
from transformers.activations import ACT2FN

2024-01-14 11:52:09,475 - INFO - numexpr.utils - NumExpr defaulting to 2 threads.


---
# **DATA EXTRACTION**


> ## **Data Extractor Class**
>
> The DataExtractor class is designed to extract and load the data. Here's a summary of what the class does:
>
> - **Class Initialization:**
>   - The constructor initializes the object with the directory path where the data is stored and the number of samples per file to extract (if the directory contains multiple files).
>   - The inputs are validated and the data is extracted.
>
> - **Data Extraction:**
>   - The *_extract_data* private method extracts the data from all the files in the data directory: each file is loaded, its time columns are merged, it is split, and the pieces are concatenated into unified train, validation, and test pandas DataFrames. The Credit Card Dataset used for this project is a public dataset with 24 million transactions from 20,000 users. Each transaction (row) has 12 fields (columns) consisting of both continuous and discrete nominal attributes, such as merchant name, merchant address, transaction amount, etc. [Padhi et al.,  2021]
>
>   - The *_load_from_csv* private method loads the csv data at the specified location into a pandas DataFrame. You can either specify the number of samples per file to extract or set it to None to extract all the samples.
>
>   - The *_merge_time_col* private method merges the time-related columns ('Year', 'Month', 'Day', 'Time') into a single 'TIMESTAMP' column. This process ensures a unified and consistent time representation, which is crucial for time-series analysis.
>
>   - The *_split_data* private method sorts the data by the TIMESTAMP column and splits it chronologically into training, validation, and test sets. Splitting at this stage avoids data leakage between the sets during preprocessing.


[Padhi et al.,  2021]: https://arxiv.org/abs/2011.01843 "Padhi, I., Schiff, Y., Melnyk, I., et al. (2021). Tabular Transformers for Modeling Multivariate Time Series. arXiv:2011.01843 [cs.LG]"
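As a rough illustration of the last two steps, the sketch below builds a scaled `TIMESTAMP` column from separate date parts and then splits the rows chronologically. The toy DataFrame, the helper names `merge_time_columns` and `chronological_split`, and the 0.5/0.25 split sizes are assumptions made for this example only. A plain pandas min-max expression stands in for `MinMaxScaler` to keep the snippet dependency-light; the actual class methods are the ones defined in the notebook's `DataExtractor` cell.

```python
import pandas as pd


def merge_time_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse Year/Month/Day/Time into a single scaled TIMESTAMP column."""
    df = df.copy()
    # "HH:MM" -> two integer columns (hour, minute)
    hh_mm = df["Time"].str.split(":", expand=True).astype(int)
    stamps = pd.to_datetime(
        dict(year=df["Year"], month=df["Month"], day=df["Day"],
             hour=hh_mm[0], minute=hh_mm[1])
    )
    # integer ticks since the Unix epoch
    ticks = stamps.astype("int64")
    # plain min-max scaling to [0, 1]; stands in for MinMaxScaler (assumption)
    df["TIMESTAMP"] = (ticks - ticks.min()) / (ticks.max() - ticks.min())
    return df.drop(columns=["Year", "Month", "Day", "Time"])


def chronological_split(df: pd.DataFrame, train_size: float, val_size: float):
    """Oldest rows go to train, the next slice to validation, the rest to test."""
    df = df.sort_values(by="TIMESTAMP")
    train_end = int(len(df) * train_size)
    val_end = train_end + int(len(df) * val_size)
    return df.iloc[:train_end], df.iloc[train_end:val_end], df.iloc[val_end:]


# hypothetical toy transactions, deliberately out of chronological order
toy = pd.DataFrame({
    "Year": [2020, 2019, 2020, 2019, 2020, 2018, 2019, 2020],
    "Month": [1, 6, 1, 6, 3, 12, 6, 3],
    "Day": [2, 15, 1, 14, 9, 31, 14, 9],
    "Time": ["10:30", "08:05", "23:59", "12:00", "00:01", "23:00", "12:01", "00:02"],
    "Amount": [12.5, 3.0, 99.9, 41.2, 7.7, 5.0, 18.3, 60.0],
})

merged = merge_time_columns(toy)
train, val, test = chronological_split(merged, train_size=0.5, val_size=0.25)
print(len(train), len(val), len(test))
```

Because the rows are sorted before slicing, every training row is older than every validation row, which in turn is older than every test row.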

In [None]:
class DataExtractor:

    def __init__(self,
                 data_root_dir: str,
                 samples_per_file: Optional[int]=None,
                 train_size: Optional[float]=0.8,
                 val_size: Optional[float]=0.1) -> None:
        """
        Initialize the DataExtractor module.

        Args:
            - data_root_dir (str):
                The path to the directory containing the data.
            - samples_per_file (int, optional):
                The number of samples to extract from each file in the data directory. Default to None (all samples).
            - train_size (float, optional):
                The percentage of the data to be used for training. Default to 0.8.
            - val_size (float, optional):
                The percentage of the data to be used for validation, the remaining data will be used for testing. Default to 0.1.
        """
        logger.info('Initializing the DataExtractor...')
        self.time_cols = ['Year', 'Month', 'Day', 'Time']
        # helper function to validate the initial inputs
        self._validate_initial_inputs(data_root_dir,
                                      samples_per_file,
                                      train_size,
                                      val_size)
        self.data_root_dir = data_root_dir
        self.samples_per_file = samples_per_file
        self.train_size = train_size
        self.val_size = val_size
        # extract the data into a pandas dataframe
        self._extract_data()
        logger.info("DataExtractor successfully initialized.\n")

    def _validate_initial_inputs(self,
                                 data_root_dir: str,
                                 samples_per_file: Optional[int]=None,
                                 train_size: Optional[float]=0.8,
                                 val_size: Optional[float]=0.1) -> None:
        """Helper function to validate the initial inputs."""

        if not isinstance(data_root_dir, str) or not os.path.isdir(data_root_dir):
            raise ValueError(f'"data_root_dir" must be a valid directory path. Got {data_root_dir}')

        if samples_per_file is not None:
            if not isinstance(samples_per_file, int) or samples_per_file <= 0:
                raise ValueError(f'"samples_per_file" must be a positive integer or None. Got {samples_per_file}')

        if not isinstance(train_size, float) or not 0 <= train_size <= 1:
            raise ValueError(f'"train_size" must be a float between 0 and 1. Got {train_size}')

        if not isinstance(val_size, float) or not 0 <= val_size <= 1:
            raise ValueError(f'"val_size" must be a float between 0 and 1. Got {val_size}')

        if train_size + val_size > 1:
            raise ValueError('The sum of "train_size" and "val_size" must be less than or equal to 1.')

    def _merge_time_col(self,
                        data: pd.DataFrame) -> pd.DataFrame:
        """
        Merge the 'Year', 'Month', 'Day' and 'Time' columns into a single column named TIMESTAMP.
        pd.to_datetime assembles a datetime from these parts; it is then cast to an integer
        Unix timestamp (in nanoseconds) and min-max scaled to the [0, 1] range.

        Args:
            - data (pd.DataFrame):
                Pandas DataFrame containing the data.

        Returns:
            - data (pd.DataFrame):
                Pandas DataFrame containing the data with a single 'TIMESTAMP' column holding the scaled Unix timestamps.
        """
        splittedTime = data['Time'].str.split(':', expand=True)
        data['TIMESTAMP'] = pd.to_datetime(
            dict(year = data['Year'],
                month = data['Month'],
                day = data['Day'],
                hour = splittedTime[0],
                minute = splittedTime[1])).astype('int64')
        # use the MinMaxScaler to rescale the timestamps to the [0, 1] range
        scaler = MinMaxScaler()
        data['TIMESTAMP'] = scaler.fit_transform(data['TIMESTAMP'].values.reshape(-1, 1))
        # drop the time-related columns
        data.drop(columns=self.time_cols, inplace=True)
        return data

    def _load_from_csv(self,
                       file_path: str) -> pd.DataFrame:
        """
        Load a DataFrame from a csv file.

        Args:
            - file_path (str):
                The path to the csv file to be read.

        Returns:
            - pd.DataFrame:
                Pandas DataFrame containing the data.
        """
        # error check for the file_path argument
        if not isinstance(file_path, str):
            raise TypeError('file_path must be a string.')
        # get the dataframe from csv file
        if os.path.exists(file_path):
            data = pd.read_csv(file_path, nrows=self.samples_per_file)
            # check if the dataframe is empty
            if data.empty:
                raise ValueError('Data cannot be empty.')
            # check if the dataframe contains all the required columns
            if not all(col in data.columns for col in self.time_cols):
                missing_cols = [col for col in self.time_cols if col not in data.columns]
                raise ValueError(f"The following columns are missing from the dataframe: {', '.join(missing_cols)}")
            return data
        else:
            raise FileNotFoundError(f'The file at the provided path {os.path.split(file_path)[-1]} was not found.\n')

    def _split_data(self,
                    data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Split the data chronologically into train, validation, and test sets.
        The split proportions are taken from self.train_size and self.val_size;
        the remaining rows form the test set.

        Args:
            - data (pd.DataFrame):
                Pandas DataFrame containing the data.

        Returns:
            - train_data (pd.DataFrame):
                Pandas DataFrame containing the training data.
            - val_data (pd.DataFrame):
                Pandas DataFrame containing the validation data.
            - test_data (pd.DataFrame):
                Pandas DataFrame containing the test data.
        """
        # sort the data by timestamp
        data = data.sort_values(by='TIMESTAMP')
        # split the data into train, validation, and test sets
        train_end = int(len(data) * self.train_size)
        val_end = train_end + int(len(data) * self.val_size)
        train_data = data.iloc[:train_end]
        val_data = data.iloc[train_end:val_end]
        test_data = data.iloc[val_end:]
        return train_data, val_data, test_data

    def _extract_data(self) -> None:
        """
        Extract the data from the provided directory.
        If multiple files are present, they are merged into a single Pandas DataFrame.
        The data is then split into train, validation, and test sets.
        """
        # files inside the data directory
        all_files = [os.path.join(self.data_root_dir, file) for file in os.listdir(self.data_root_dir) if os.path.isfile(os.path.join(self.data_root_dir, file))]
        # merge all the files into a DataFrame
        train_dataframes = []
        val_dataframes = []
        test_dataframes = []
        for file_path in tqdm(all_files, desc='Data Extraction:'):
            # load the data from the csv file
            data = self._load_from_csv(file_path)
            # merge the time-related columns
            data = self._merge_time_col(data)
            # split the data into train, validation, and test sets
            train_data, val_data, test_data = self._split_data(data)
            train_dataframes.append(train_data)
            val_dataframes.append(val_data)
            test_dataframes.append(test_data)
        # concatenate all the dataframes
        self.train_data = pd.concat(train_dataframes, ignore_index=True)
        self.val_data = pd.concat(val_dataframes, ignore_index=True)
        self.test_data = pd.concat(test_dataframes, ignore_index=True)
        logger.info(
            f'Successfully extracted {len(all_files)} DataFrame. '
            f'Train DataFrame has {len(self.train_data)} rows. '
            f'Validation DataFrame has {len(self.val_data)} rows. '
            f'Test DataFrame has {len(self.test_data)} rows.\n')

> **_Important_:**
>
> **In the next section, the preprocessed data will be loaded to simplify the run of the notebook, thus there's no need to run the code cell below.**
>
> **However, feel free to review the code below to understand how to use the DataExtractor class, but running it is not required and can lead to memory issues!**
>
> **If you wish to run the code below, either load only a subset of the data (we used 5,000,000 samples per file) or restart the kernel before running the next section.**


In [None]:
## NO NEED TO RUN THIS CELL
samples_per_file = 10000

extractor = DataExtractor(data_root_dir=RAW_DATA_DIR,
                          samples_per_file=samples_per_file,
                          train_size=0.8,
                          val_size=0.1)
train_data = extractor.train_data
val_data = extractor.val_data
test_data = extractor.test_data

2024-01-14 11:52:17,188 - INFO - root - Initializing the DataExtractor...
Data Extraction:: 100%|██████████| 1/1 [00:01<00:00,  1.68s/it]
2024-01-14 11:52:19,091 - INFO - root - Successfully extracted 1 DataFrame. Train DataFrame has 8000 rows. Validation DataFrame has 1000 rows. Test DataFrame has 1000 rows.

2024-01-14 11:52:19,094 - INFO - root - DataExtractor successfully initialized.



---
# **VOCABULARY**

> ## **Vocabulary Class**
>
> The *Vocab* class is designed to manage, create, save and load vocabularies.
>
> - **Class Initialization**:
    - The class starts by defining a set of custom special tokens as class-level constants.
    - The vocab object accepts the actual data and the list of DataFrame columns from which to create the vocabulary. The directory where the vocabulary is stored is passed later to the save/load methods.
    - Various attributes are first initialized and validated to ensure they are correctly provided.
    - Special tokens are initialized and added to the vocabulary.
    - Vocabularies are created based on the provided data.
>
> - **Vocabulary Creation**:
    - The vocabulary is constructed using the unique values from the provided columns in the data. Each unique value is added to the vocabulary with a corresponding tag (the field name), global id (the index considering all the tokens of the vocabulary) and local id (the index considering only the tokens in the current field).
    - Vocabulary structure: {column tag: {token: [global_id, local_id]}}
    - id2token structure: {global_id: [token, tag, local_id]}
>    
> - **Utility Methods**:
    - The class also provides a method to retrieve the global id of a token, a method to retrieve the token corresponding to a global id, a method to map between global and local ids using a lookup tensor and two methods to save/load the vocabulary.
>   
> - **Vocabulary Summary Display**:
    - The Vocab class includes the print_vocab_summary method. This method allows for a detailed display of various statistics and characteristics of the vocabulary. It's possible to print special tokens, sample tokens from each column, the size of the vocabulary for each column, the data types of each column, and the total length of the vocabulary. You can specify the sample size and token limit per column.
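To make the two mappings and the global-to-local lookup concrete, here is a minimal standalone sketch that mirrors the structures described above on a toy table. The function name `build_mappings`, the reduced special-token list, and the sample columns are assumptions for illustration, and a NumPy array stands in for the lookup tensor; the full `Vocab` class is the one defined in the notebook.

```python
import numpy as np


def build_mappings(columns, special_tokens):
    """Build token2id {tag: {token: [global_id, local_id]}} and
    id2token {global_id: [token, tag, local_id]}."""
    token2id, id2token = {}, {}
    # special tokens come first, then one field per column
    fields = {"SPECIAL": list(special_tokens)}
    for name, values in columns.items():
        fields[name] = list(dict.fromkeys(values))  # unique values, first-seen order
    for tag, tokens in fields.items():
        token2id[tag] = {}
        for local_id, token in enumerate(tokens):
            global_id = len(id2token)  # running index over the whole vocabulary
            token2id[tag][token] = [global_id, local_id]
            id2token[global_id] = [token, tag, local_id]
    return token2id, id2token


# hypothetical two-column table
toy_columns = {
    "Use Chip": ["Swipe", "Chip", "Swipe"],
    "MCC": [5411, 5912, 5411],
}
token2id, id2token = build_mappings(toy_columns, ["[UNK]", "[PAD]"])

# lookup array: position = global id, value = local id inside its field
lookup = np.full(len(id2token), -100, dtype=np.int64)
for global_id, (_, _, local_id) in id2token.items():
    lookup[global_id] = local_id

global_ids = np.array([2, 3, 4, 5])
local_ids = lookup[global_ids]
print(token2id["MCC"], local_ids)
```

The global id is unique across the whole vocabulary, while the local id restarts from zero in every field, which is why a single indexing operation is enough to translate a batch of global ids.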


In [None]:
class Vocab:

    CLS_TOKEN = '[CLS]'
    START_TOKEN = '[START]'
    END_TOKEN = '[END]'
    UNK_TOKEN = '[UNK]'
    SEP_TOKEN = '[SEP]'
    MASK_TOKEN = '[MASK]'
    PAD_TOKEN = '[PAD]'

    def __init__(self,
                 data: pd.DataFrame=None,
                 cols_for_vocab: List[str]=None) -> None:
        """
        Initialize the vocabulary with special tokens and the provided columns and data.

        Args:
            - cols_for_vocab (List[str]):
                A list of columns used to create the vocabularies.
            - data (pd.DataFrame):
                The Pandas DataFrame containing the data for vocabulary creation.
        """
        logger.info('Initializing the vocabulary...')

        # initialize dataset attributes and attribute validation
        self.cols_for_vocab = cols_for_vocab
        self.data = data
        self.token2id = {}
        self.id2token = {}
        # validate the input attributes
        self._validate_attributes()
        # initialize the vocabularies with the special tokens
        self._initialize_special_tokens()
        # fill the token2id and id2token mappings and create a lookup table from global_ids to local_ids
        self._create_vocabularies()
        self.lookup_tensor = self._create_lookup_tensor()

        logger.info('Vocabularies successfully created.\n')

    def _validate_attributes(self) -> None:
        """Helper function to validate the class attributes."""
        # checks for the columns and the data
        if self.cols_for_vocab is None or len(self.cols_for_vocab) == 0:
            raise ValueError('"cols_for_vocab" cannot be None or empty.')
        if self.data is None or self.data.empty:
            raise ValueError('Data cannot be None or empty.')
        # checks for the columns in the data
        if not all(col in self.data.columns for col in self.cols_for_vocab):
            missing_cols = [col for col in self.cols_for_vocab if col not in self.data.columns]
            raise ValueError(f"The following columns are missing from the dataframe: {', '.join(missing_cols)}")

    def _add_tokens_to_vocab(self,
                             tokens: List[Any],
                             tag: str) -> None:
        """
        Add tokens from a given field to the vocabularies.

        Args:
            - tokens (List[Any]):
                The list of tokens to add to the vocabulary.
            - tag (str):
                The tag or category for the tokens.
        """
        # fill in the values into token2id and id2token mappings
        self.token2id[tag] = {}
        global_index = len(self.id2token)
        # structure of token2id: {tag: {token: [global_index, local_index]}}
        # structure of id2token: {global_index: [token, tag, local_index]}
        for local_index, token in enumerate(tokens):
            self.token2id[tag][token] = [global_index, local_index]
            self.id2token[global_index] = [token, tag, local_index]
            global_index += 1

    def _initialize_special_tokens(self) -> None:
        """Initialize special tokens, token2id and id2token vocabularies."""
        # store the special tokens
        self.cls_token = Vocab.CLS_TOKEN
        self.start_token = Vocab.START_TOKEN
        self.end_token = Vocab.END_TOKEN
        self.unk_token = Vocab.UNK_TOKEN
        self.sep_token = Vocab.SEP_TOKEN
        self.mask_token = Vocab.MASK_TOKEN
        self.pad_token = Vocab.PAD_TOKEN
        # add the special tokens to the token2id and id2token vocabularies
        self.special_tokens = [self.unk_token, self.sep_token, self.pad_token,
                               self.cls_token, self.mask_token, self.start_token, self.end_token]
        self.special_tag = 'SPECIAL'
        self._add_tokens_to_vocab(self.special_tokens, self.special_tag)

    def _create_vocabularies(self) -> None:
        """Create token2id and id2token vocabularies based on the provided columns (fields) and data."""
        # for each column extract the unique values and build the mappings
        for col in tqdm(self.cols_for_vocab, desc='Creating vocabularies...'):
            if col not in self.data.columns:
                raise ValueError(f"Column {col} not found in data.")
            unique_values = self.data[col].unique()
            self._add_tokens_to_vocab(unique_values, col)

    def _create_lookup_tensor(self) -> torch.Tensor:
        """
        Create a lookup tensor to map global ids to local ids.
        It constructs a tensor where each index corresponds to a global id and its value is the local id.

        Returns:
            - torch.Tensor: A tensor where each index is a global id and its value is the corresponding local id.
        """
        max_global_id = max(self.id2token.keys())
        lookup_tensor = torch.full((max_global_id + 1,), -100, dtype=torch.long)
        for global_id, value in self.id2token.items():
            lookup_tensor[global_id] = value[2]
        return lookup_tensor

    def get_id(self,
               token: Any,
               tag: str) -> int:
        """
        Retrieve the global index for a token and tag.

        Args:
            - token (Any):
                The token to find.
            - tag (str):
                The tag or category for the token.

        Returns:
            - int:
                The global index of the token.
        """
        # get the vocabulary corresponding to the given field (tag) and from that retrieve the global index of the token
        return self.token2id.get(tag, {}).get(token, self.token2id[self.special_tag][self.unk_token])[0]

    def get_token(self,
                  id: int) -> str:
        """
        Retrieve the token corresponding to an index.

        Args:
            - id (int):
                The index of the token.

        Returns:
            - str:
                The token corresponding to that global index.
        """
        return self.id2token.get(id, [self.unk_token])[0]

    def get_special_tokens(self) -> Dict[str, str]:
        """
        Create the mapping between custom tokens and the standard keys used in BERT-like tokenizers.
        Inspiration was taken from: https://github.com/IBM/TabFormer/blob/main/dataset/vocab.py

        Returns:
            - special_tokens_map (Dict[str, str]):
                The dictionary mapping the custom tokens to the standard keys used in Bert Tokenizer
        """
        special_tokens_map = {}
        # create a mapping between custom special tokens and standard keys in BERT tokenizer
        keys = ["unk_token", "sep_token", "pad_token", "cls_token", "mask_token", "bos_token", "eos_token"]
        for key, token in zip(keys, self.special_tokens):
            token = "%s_%s" % (self.special_tag, token)
            special_tokens_map[key] = token

        return special_tokens_map

    def get_global_ids(self,
                       field_name: str) -> List[int]:
        """
        Get the global ids of all the tokens belonging to a given field (column).

        Args:
            - field_name (str):
                Name of the field.
        Returns:
            - List[int]:
                List containing the token ids in a given field.
        """
        field_global_ids = [self.token2id[field_name][idx][0] for idx in self.token2id[field_name]]
        return field_global_ids

    def map_global_to_local(self,
                            global_ids: torch.Tensor) -> torch.Tensor:
        """
        Map the global ids to the corresponding field local ids using the lookup tensor. This method is intended to be used during training.

        Args:
            - global_ids (torch.Tensor):
                A tensor containing token global ids.

        Returns:
            - local_ids (torch.Tensor):
                A tensor containing token local ids corresponding to the token global ids.
        """
        lookup_tensor_device = self.lookup_tensor.to(global_ids.device)
        local_ids = lookup_tensor_device[global_ids]
        local_ids.masked_fill_(global_ids == -100, -100)
        return local_ids

    def print_vocab_summary(self,
                            print_special_tokens:bool=True,
                            print_sample_tokens:bool=True,
                            sample_size:int=5,
                            token_limit_per_column:int=5,
                            print_vocab_size_per_column:bool=True,
                            print_column_data_types:bool=True,
                            print_vocab_length:bool=True) -> None:
        """
        Print a summary of the vocabulary based on the provided flags.

        Args:
            print_special_tokens (bool): Whether to print the special tokens. Default to True.
            print_sample_tokens (bool): Whether to print sample tokens from each column. Default to True.
            sample_size (int): Number of columns to sample from. Default to 5.
            token_limit_per_column (int): Number of tokens to show per sampled column. Default to 5.
            print_vocab_size_per_column (bool): Whether to print the size of vocabulary per column. Default to True.
            print_column_data_types (bool): Whether to print the data types of each column. Default to True.
            print_vocab_length (bool): Whether to print the total length of the vocabulary. Default to True.
        """
        if print_special_tokens:
            print(f'Special tokens: {self.special_tokens}\n')
        if print_sample_tokens:
            print("Sampling from the Vocabulary:\n")
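            # never request more columns than are available, otherwise random.sample raises a ValueError
            n_columns = min(sample_size, len(self.token2id))
            vocab_sample = dict(random.sample(list(self.token2id.items()), n_columns))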
            for column, tokens in vocab_sample.items():
                print(f"COLUMN_TAG: {column}")
                print("TOKEN: [GLOBAL IDX, LOCAL IDX]")
                limited_tokens = list(tokens.items())[:token_limit_per_column]
                for token, indices in limited_tokens:
                    print(f"{token}: {indices}")
                print("\n")
        if print_vocab_size_per_column:
            for col in self.cols_for_vocab:
                print(f"Number of tokens in column '{col}': {len(self.token2id[col])}")
            print("\n")
        if print_column_data_types:
            print("Data Types per Column:")
            for col in self.cols_for_vocab:
                print(f"Column '{col}': {self.data[col].dtype}")
            print("\n")
        if print_vocab_length:
            print(f"Total Length of the Vocabulary: {len(self)}")

    def save_vocab(self,
                   vocab_dir: str) -> None:
        """
        Save the vocabularies at the specified path in two formats:
            - One compatible with BERT tokenizer
            - One to have easy access to the vocabulary object when loading it for the validation and test set
        Inspiration was taken from: https://github.com/IBM/TabFormer/blob/main/dataset/vocab.py

        Args:
            - vocab_dir (str):
                The directory where to save the vocabularies.
        """
        logger.info('Saving vocabularies...')
        if not isinstance(vocab_dir, str):
            raise TypeError(f'"vocab_dir" must be a string. Got {type(vocab_dir)}')
        # create the directory where to store the vocabularies
        if not os.path.exists(vocab_dir):
            os.makedirs(vocab_dir)
        self.vocab_file_for_bert = os.path.join(vocab_dir, 'vocab.nb')
        vocab_object_file_pickle = os.path.join(vocab_dir, 'vocab.pkl')
        # save the vocabularies in a format compatible with BERT tokenizer
        with open(self.vocab_file_for_bert, "w") as fout:
            for idx in self.id2token:
                token, field, _ = self.id2token[idx]
                token = "%s_%s" % (field, token)
                fout.write("%s\n" % token)
        # save the vocabulary object to have easy access to the vocabulary when loading it for the validation and test set
        with open(vocab_object_file_pickle, 'wb') as f:
            pkl.dump(self, f)
        logger.info('Vocabularies successfully saved.\n')

    @staticmethod
    def load_vocab(vocab_dir: str) -> 'Vocab':
        """
        Load the vocabulary object from the specified path.

        Args:
            - vocab_dir (str):
                The directory where the vocabularies are stored.

        Returns:
            - Vocab:
                The vocabulary object.
        """
        logger.info('Loading vocabularies...')
        if not isinstance(vocab_dir, str):
            raise TypeError(f'"vocab_dir" must be a string. Got {type(vocab_dir)}')
        filename = os.path.join(vocab_dir, 'vocab.pkl')
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Vocab file not found in {vocab_dir}")
        # load the vocabulary object
        with open(filename, 'rb') as f:
            vocab = pkl.load(f)
        logger.info('Vocab object successfully loaded.\n')
        return vocab

    def __len__(self) -> int:
        """
        Return the length of the vocabulary.

        Returns:
            - int:
                The length of the vocabulary.
        """
        return len(self.id2token)


---
# **THE DATASET**

> ## **TransactionDataset Class**
>
> The TransactionDataset class is designed to take raw data and prepare it for train and test phases.
>
> - **Class Initialization**:
    - Initializes with a DataFrame and the dataset mode ('train', 'train-cls', 'val' or 'test').
    When the mode is 'train', the vocabulary is created and saved in the specified directory, and transformations are applied based on statistical information from the training data.
    When the mode is 'val' or 'test', the vocabulary is loaded from the specified directory and transformations are applied based on the statistical information from the training data. The 'train-cls' mode works the same way, with the addition of dataset balancing using SMOTE, which generates synthetic instances of the minority class.
    - Other arguments include: directory paths where to save the vocabulary and the class instance, the columns to discretize, the columns to drop, the target columns, the sequence length and the stride.
    - The attributes are validated to ensure that the class is initialized correctly.
    - The data is preprocessed based on the mode.
    - The data is tokenized using the vocabulary.
    - The samples and targets are prepared for time-series analysis based on the sequence length and the stride.
>
> - **Data Preprocessing**:
    - Processes the input data based on the specified mode (train, val, test).
    - In training mode, the data is discretized and a vocabulary is created and saved.
    - In validation and test modes the data is discretized based on the bin edges found with the training data, then the vocabulary created with the training data is loaded and applied to the val/test data.
>
> - **Discretization**:
    - Compute the number of bins to discretize the dataset based on its interquartile range (IQR).
    - The method uses the Freedman-Diaconis Rule to compute the width of each bin. The rule is robust to outliers and is given as  $$ \text{Bin width} = \frac{2 \times \text{IQR}}{\sqrt[3]{\text{num observations}}} $$    
    - The number of bins is calculated as
$$ \text{Number of bins} = \frac{\text{max value} - \text{min value}}{\text{Bin width}} $$
    - The bin labels are computed based on the number of bins specified.
    - The data is discretized by assigning each value to the closest bin label.
    - The bin labels are saved in a json file and loaded when the mode is 'val' or 'test'.
>
> - **Vocabulary**:
    - Uses the *Vocab* class to manage the vocabulary of the dataset.
    - The vocabulary is created based on the columns specified in *cols_for_vocab*.
    - Each column has its own vocabulary, with the following structure: {token: [global_index, local_index]}
    - The vocabulary is created and saved when the mode is 'train', otherwise it is loaded from the specified directory.
>
> - **Tokenization**:
>   - Maps the tokens in the columns specified in *cols_for_vocab* to the corresponding global indices.
> - **Sample Preparation**:
    - Structures the data into samples and targets with a format suitable for time-series analysis.
    - A single sample contains (seq_len+1)*(ncols+1) token ids, arranged with shape (seq_len+1, ncols+1). The extra row holds the classification token and the extra column holds the separator token appended to each row; both are necessary for BERT training.
    - The number of samples obtained in the end depends on the stride and on the number of subsequent rows considered for each sample (sequence length).

>
> - **Saving and Loading the Dataset**:
    - It's possible to save the entire class instance using pickle for efficient storage and retrieval.
    - It's possible to load the class instance, a static method is provided for this purpose.
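
The discretization steps described above can be sketched on a toy array. This is a minimal illustration, not the class implementation: it uses the textbook Freedman-Diaconis rule with the number of observations in the denominator, and the helper names (`freedman_diaconis_bins`, `quantile_edges`, `snap_to_edges`) and the `amounts` values are made up for the example.

```python
import numpy as np

def freedman_diaconis_bins(values: np.ndarray) -> int:
    """Number of bins from the Freedman-Diaconis rule (textbook form)."""
    q75, q25 = np.percentile(values, [75, 25])
    iqr = q75 - q25
    # bin width = 2 * IQR / cube root of the number of observations
    bin_width = 2 * iqr / np.cbrt(len(values))
    return int((values.max() - values.min()) / bin_width)

def quantile_edges(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Bin labels taken as the unique quantile edges of the data."""
    edges = np.quantile(values, np.linspace(0, 1, n_bins + 1))
    return np.unique(edges)

def snap_to_edges(values: np.ndarray, edges: np.ndarray) -> np.ndarray:
    """Replace each value with the closest bin label."""
    nearest = np.abs(values[:, None] - edges[None, :]).argmin(axis=1)
    return edges[nearest]

# hypothetical toy data standing in for a numeric column
amounts = np.array([1.0, 2.0, 2.5, 3.0, 4.0, 5.0, 7.5, 9.0])
n_bins = freedman_diaconis_bins(amounts)   # IQR = 3.25, width = 3.25, so 2 bins
edges = quantile_edges(amounts, n_bins)    # [1.0, 3.5, 9.0]
binned = snap_to_edges(amounts, edges)
print(n_bins, edges, binned)
```

Because every value is replaced by one of a small set of labels, the column ends up with a finite vocabulary, which is what makes the later tokenization step possible.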

In [None]:
class TransactionDataset(Dataset):

    COLS_TO_DISCRETIZE = ['Amount', 'TIMESTAMP']
    COLS_FOR_VOCAB = ['Card', 'TIMESTAMP', 'Amount', 'Use Chip', 'Merchant Name', 'Merchant City', 'Merchant State','Zip', 'MCC', 'Errors?']
    TARGET_COLS = ["Is Fraud?"]

    def __init__(self,
                 data: pd.DataFrame,
                 mode: str='train',
                 vocab_dir: str='vocab',
                 save_dir: str='data/processed',
                 cols_to_discretize: Optional[List[str]]=None,
                 cols_for_vocab: Optional[List[str]]=None,
                 target_cols: Optional[List[str]]=None,
                 smote: Optional[bool]=False,
                 sequence_length: Optional[int]=10,
                 stride: Optional[int]=5) -> None:
        """
        Initialize the TransactionDataset module.

        Args:
            - data (pd.DataFrame):
                The DataFrame containing the data.
            - mode (str, optional):
                The mode of the dataset. Can be 'train', 'train-cls', 'val' or 'test'. Default to 'train'.
            - 'vocab_dir' (str, optional):
                When mode is 'train', the vocabulary is saved in this directory.
                When mode is 'val' or 'test', the vocabulary is loaded from this directory. Default to 'vocab'.
            - save_dir (str, optional):
                The directory where to save the class instance. Default to 'data/processed'.
            - cols_to_discretize (List[str], optional):
                List of columns to discretize. If not provided, defaults to the class-level constant COLS_TO_DISCRETIZE (['Amount', 'TIMESTAMP']).
            - cols_for_vocab (List[str], optional):
                List of columns to be used for the vocabulary. If not provided, defaults to the class-level constant COLS_FOR_VOCAB.
            - target_cols (List[str], optional):
                List of columns to be used as targets. If not provided, defaults to the class-level constant TARGET_COLS (['Is Fraud?']).
            - smote (bool, optional):
                Whether to apply SMOTE to the dataset. Default to False.
            - sequence_length (int, optional):
                The numbers of subsequent row to consider as a sequence. Default to 10.
            - stride (int, optional):
                The step of the sliding window when combining subsequent rows. Default to 5.
        """

        logger.info('Initializing the TransactionDataset...')
        # initialize dataset attributes
        self.data = data.copy()
        self.mode = mode
        self.vocab_dir = vocab_dir
        # initialize the columns to discretize, the vocabulary columns and the target columns
        self.cols_to_discretize = cols_to_discretize or TransactionDataset.COLS_TO_DISCRETIZE.copy()
        self.cols_for_vocab = cols_for_vocab or TransactionDataset.COLS_FOR_VOCAB.copy()
        self.target_cols = target_cols or TransactionDataset.TARGET_COLS.copy()
        self.smote = smote
        # initialize attributes for the sequences
        self.sequence_length = sequence_length
        self.stride = stride
        self.samples, self.targets = [], []
        # attributes validation
        self._validate_attributes()
        self._fill_na()
        self._preprocess_data(mode=self.mode)
        self._tokenize_data()
        self._prepare_samples()
        if self.smote:
            self.apply_smote()
        self.save(save_dir)
        logger.info('TransactionDataset successfully initialized.\n')

    def _validate_attributes(self) -> None:
        """Helper function to validate the attributes of the class."""
        # handle the validation of attributes using tuple (variable, expected type, error message)
        validations = [
            (self.data, pd.DataFrame, '"data" must be a pandas DataFrame'),
            (self.mode, str, '"mode" must be a string'),
            (self.vocab_dir, str, '"vocab_dir" must be a string'),
            (self.smote, bool, '"smote" must be a boolean'),
            (self.sequence_length, int, '"sequence_length" must be an integer'),
            (self.stride, int, '"stride" must be an integer'),
            (self.cols_to_discretize, (list, type(None)), '"cols_to_discretize" must be a list or None'),
            (self.target_cols, (list, type(None)), '"target_cols" must be a list or None')]
        # check if the mode is valid
        if self.mode not in ['train', 'train-cls','val', 'test']:
            raise ValueError(f'Invalid mode: {self.mode}. Must be one of "train", "train-cls", "val" or "test".')
        # iterate over the list of tuples (variable, expected type, error message)
        for var, expected_type, error_msg in validations:
            if not isinstance(var, expected_type):
                raise TypeError(f'{error_msg}. Got {type(var)}')
        # check positivity of sequence_length and stride
        if self.sequence_length <= 0 or self.stride <= 0:
            raise ValueError('"sequence_length" and "stride" must be positive integers.')
        # check that the data is not None or empty
        if self.data is None or self.data.empty:
            raise ValueError('Data cannot be None or empty.')
        # check if the necessary columns are in the DataFrame
        required_cols = set(self.cols_to_discretize + self.cols_for_vocab + self.target_cols)
        if not all(col in self.data.columns for col in required_cols):
            missing_cols = [col for col in required_cols if col not in self.data.columns]
            raise ValueError(f"The following columns are missing from the dataframe: {', '.join(missing_cols)}")

    def _fill_na(self) -> None:
        """
        Handle the NaN values of the dataset:
        'Zip' NaN -> 0
        'Errors?' NaN -> 'None'
        'Merchant State' NaN -> 'None'
        """
        logger.info("Filling NaN values in the dataset...")
        self.data['Zip'] = self.data['Zip'].fillna(0)
        self.data['Errors?'] = self.data['Errors?'].fillna('None')
        self.data['Merchant State'] = self.data['Merchant State'].fillna('None')
        logger.info("NaN values successfully filled.")

    def _compute_number_bins(self,
                             col_data: pd.Series) -> int:
        """
        Compute the number of bins to discretize a dataset based on its interquartile range (IQR).
        The method uses the Freedman-Diaconis Rule to compute the width of each bin.
        The rule is robust to outliers and is given as 2*IQR/cubic_root(num_observations).
        The number of bins is calculated as (max_value - min_value)/bin_width.

        Args:
            - col_data (pd.Series):
                The data series to be discretized.

        Returns:
            - int:
                The number of bins to be used for discretization.
        """
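        IQR = stats.iqr(col_data, rng=(25, 75), nan_policy='omit')
        # note: the denominator uses the number of unique values, not the number of observations
        bin_width = 2 * IQR / np.cbrt(len(col_data.unique()))
        # named value_range to avoid shadowing the built-in range()
        value_range = np.max(col_data) - np.min(col_data)
        n_bins = int(value_range / bin_width)
        return n_bins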

    def _compute_bin_labels(self,
                            col_data: pd.Series,
                            n_bins: int) -> np.ndarray:
        """
        Compute the bin labels based on the number of bins specified.
        The labels serve as the edges for each bin.

        Args:
            - col_data (pd.Series):
                The data series for which the bin labels are to be computed.
            - n_bins (int):
                The number of bins for quantile calculation.

        Returns:
            - np.ndarray:
                The unique bin labels, which are the edges for each bin.
        """
        quantiles = np.linspace(0, 1, n_bins + 1)
        bin_labels = np.quantile(col_data, quantiles)
        bin_labels = np.unique(bin_labels)
        return bin_labels

    def _discretize_column(self,
                           col: str,
                           bin_labels: Optional[np.ndarray]=None) -> np.ndarray:
        """
        Helper function to discretize a single column.
        If bin_labels is not provided, the number of bins and bin labels are computed.

        Args:
            - col (str):
                The column to be discretized.
            - bin_labels (np.ndarray, optional):
                The bin labels to be used for discretization. If not provided, they are computed.

        Returns:
            - np.ndarray:
                The bin labels used for discretization.
        """
        # compute the number of bins and the bin labels
        if bin_labels is None:
            n_bins = self._compute_number_bins(self.data[col])
            bin_labels = self._compute_bin_labels(self.data[col], n_bins)
        # replace each value with the closest bin label
        self.data[col] = self.data[col].apply(lambda x: bin_labels[np.argmin(np.abs(bin_labels - x))])
        return bin_labels

    def _encode_amount(self) -> None:
        """Encode the currency string into float without $"""
        self.data['Amount'] = self.data['Amount'].apply(lambda x: float(x.replace('$', '')))

    def _encode_fraud(self) -> None:
        """
        Encode the Yes/No into 1/0
        """
        self.data['Is Fraud?'] = (self.data['Is Fraud?'] == 'Yes').astype(int)

    def _discretize_data(self,
                         save_stats: bool=True) -> None:
        """Discretize the data. The columns specified in self.cols_to_discretize are discretized based on the Freedman-Diaconis Rule."""
        logger.info('Starting the discretization process...')
        self._encode_amount()
        self._encode_fraud()
        # filling na values by interpolating
        self.data[self.cols_to_discretize] = self.data[self.cols_to_discretize].interpolate()
        # discretize each column
        if save_stats:
            info = {}
            for col in tqdm(self.cols_to_discretize, desc='Discretizing columns'):
                # discretize the column and save the bin labels
                info[col] = self._discretize_column(col).tolist()
            # save the bin labels in a json file
            with open('bin_stats.json', 'w') as f:
                json.dump(info, f)
        else:
            # load the bin labels from the json file
            with open('bin_stats.json', 'r') as f:
                info = json.load(f)
            for col in tqdm(self.cols_to_discretize, desc='Applying saved discretization'):
                self._discretize_column(col, np.array(info[col]))
        logger.info('Discretization process completed.\n')

    def _create_and_save_vocab(self) -> None:
        """When mode is 'train', create and save the vocabulary in the specified directory."""
        self.vocab = Vocab(data=self.data,
                           cols_for_vocab=self.cols_for_vocab)
        self.vocab.save_vocab(vocab_dir=self.vocab_dir)

    def _preprocess_data(self,
                         mode: str) -> None:
        """
        Preprocess the data based on the mode.
        - If mode is 'train' or 'train-cls', the data is discretized, then the vocabulary is created and saved.
        - If mode is 'val' or 'test', the data is discretized with the saved bin labels and the vocabulary is loaded.
        """
        logger.info(f'Preprocessing the {mode} data...')
        if self.mode=='train' or self.mode=='train-cls':
            # discretize the data, then create and save the vocabulary
            self._discretize_data(save_stats=True)
            self._create_and_save_vocab()
        else:
            # apply the same preprocessing as in train mode, then load the vocabulary
            self._discretize_data(save_stats=False)
            try:
                self.vocab = Vocab.load_vocab(vocab_dir=self.vocab_dir)
            except FileNotFoundError:
                raise FileNotFoundError(f'Vocabulary not found in {self.vocab_dir}. Please run the script in train mode first.')
        logger.info(f'{mode} data successfully preprocessed.\n')

    def _tokenize_data(self) -> None:
        """Map the tokens in "cols_for_vocab" to the corresponding indices."""
        logger.info('Converting data to indices...')
        self.tokenized_data = self.data.copy()
        # apply the get_id function to each element of the dataframe to get the id
        for col in tqdm(self.cols_for_vocab, desc='Tokenizing columns...'):
            self.tokenized_data[col] = self.data[col].apply(lambda x: self.vocab.get_id(x, col))
        logger.info('Tokenization process completed.\n')

    def _prepare_samples(self) -> None:
        """
        Structuring the samples and the targets for Time-Series Analysis.
        A single stored sample contains seq_len*(ncols+1) token ids (each row is followed by the sep token), representing a sequence of consecutive rows of one user. The cls row is added later in __getitem__.
        The number of samples obtained in the end depends on the stride and on the number of subsequent rows considered for each sample.
        """
        logger.info('Preparing samples and targets...')
        sep_id = self.vocab.get_id(self.vocab.sep_token, self.vocab.special_tag)
        # get the column indices
        feature_col_indices = [self.tokenized_data.columns.get_loc(c) for c in self.cols_for_vocab]
        target_cols_indices = [self.tokenized_data.columns.get_loc(c) for c in self.target_cols]
        data_numpy = self.tokenized_data.to_numpy()
        # group by User and iterate through groups to prepare the samples
        groups = self.tokenized_data.groupby('User')
        for _, group_indices in tqdm(groups.groups.items(), desc='Preparing Samples'):
            user_data = data_numpy[group_indices]
            nrows = len(user_data) - self.sequence_length
            for start_id in range(0, nrows, self.stride):
                sample, target = [], []
                slice_start = start_id
                slice_end = start_id + self.sequence_length
                # get the values of the sample and the target
                sliced_data = user_data[slice_start:slice_end]
                sample_values = sliced_data[:, feature_col_indices]
                target_value = sliced_data[:, target_cols_indices]
                # add the sep token to the end of the sample
                sep_column = np.full((self.sequence_length, 1), sep_id)
                sample = np.hstack((sample_values, sep_column)).ravel()
                flat_target = [item for sublist in target_value.tolist() for item in sublist]
                # if there is at least one fraud in the target, the target is 1, otherwise it is 0
                if 1 in flat_target:
                    target = 1
                else:
                    target = 0
                # append the sample and the target to the list
                self.samples.append(sample)
                self.targets.append(target)
        logger.info('Samples and targets successfully organized.\n')

    def get_ncols(self) -> int:
        """
        Retrieve the number of columns used for the vocabulary (+1 for the sep token).

        Returns:
            - int:
                number of columns used for the vocabulary (+1 for the sep token).
        """
        return len(self.cols_for_vocab) + 1


    def __len__(self) -> int:
        """
        Retrieve the length of the dataset.

        Returns:
            - int:
                The number of samples in the dataset.
        """
        return len(self.samples)

    def __getitem__(self,
                    index: int)-> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieve the sample and the target at the specified index.

        Args:
            - index (int):
                The index of the sample.

        Returns:
            - tuple: A Tuple containing:
                - 'sample': The tensor containing the sample values. The shape is (sequence_length+1, ncols+1): one extra row for the cls token and one extra column for the sep token.
                - 'target': The tensor containing the target values.
        """
        # create the cls row and add it to the sample
        cls_id =  self.vocab.get_id(self.vocab.cls_token, self.vocab.special_tag)
        cls_row = torch.full((1, self.get_ncols()), cls_id, dtype=torch.long)
        sample = torch.tensor(self.samples[index].tolist(), dtype=torch.long).reshape(self.sequence_length, -1)
        sample = torch.cat((cls_row, sample), dim=0)
        # get the target
        target = torch.tensor(self.targets[index], dtype=torch.float32)
        return sample, target

    def apply_smote(self) -> None:
        """Balance the data by applying SMOTE to samples and targets"""
        logger.info('Balancing the data...')
        # convert the samples and the targets to numpy arrays
        X_train_numpy = np.array(self.samples)
        y_train_numpy = np.array(self.targets)
        # apply SMOTE
        smote = SMOTE(sampling_strategy='minority', random_state=2024)
        X_train_resampled, y_train_resampled = smote.fit_resample(X_train_numpy, y_train_numpy)
        self.samples = X_train_resampled
        self.targets = y_train_resampled
        logger.info('Data successfully balanced.\n')

    def save(self,
             data_dir: str) -> None:
        """
        Save the class instance to a file using pickle.

        Args:
            - data_dir (str): The path where to save the class instance.
        """
        logger.info('Saving Card Dataset...')
        if not isinstance(data_dir, str):
            raise TypeError(f'"data_dir" must be a string. Got {type(data_dir)}')
        # create the directory where the class instance is stored
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        self.card_file = os.path.join(data_dir, f'card_{self.mode}.pkl')
        with open(self.card_file, 'wb') as file:
            pkl.dump(self, file)
        logger.info(f'Class instance successfully saved.\n')

    @staticmethod
    def load(data_dir: str,
             mode: str) -> 'TransactionDataset':
        """
        Load a class instance from a file using pickle.

        Args:
            - data_dir (str): The path of the directory to load the class instance from.
            - mode (str): The dataset split to load. Can be 'train', 'train-cls', 'val' or 'test'.

        Returns:
            - TransactionDataset: An instance of the TransactionDataset class.
        """
        logger.info(f'Loading Card Dataset {mode} set...')
        if not isinstance(data_dir, str):
            raise TypeError(f'"data_dir" must be a string. Got {type(data_dir)}')
        # check if the mode is valid
        if mode not in ['train', 'train-cls', 'val', 'test']:
            raise ValueError(f'Invalid mode: {mode}. Must be one of "train", "train-cls", "val" or "test".')
        filename = os.path.join(data_dir, f'card_{mode}.pkl')
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Card file not found in {data_dir}")
        with open(filename, 'rb') as file:
            card = pkl.load(file)
        logger.info(f'Class instance successfully loaded.\n')
        return card
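
The sliding-window logic used to build the samples can be illustrated in isolation. The sketch below mirrors the loop in `_prepare_samples` (same start positions, same "any fraud in the window" labelling) but works on plain Python lists; the function name `sliding_windows` and the toy `rows` and `fraud_flags` are hypothetical and only serve the example.

```python
from typing import List, Sequence

def sliding_windows(rows: Sequence[int],
                    sequence_length: int = 10,
                    stride: int = 5) -> List[List[int]]:
    """Cut a user's rows into windows of sequence_length rows, moving by stride."""
    n_starts = len(rows) - sequence_length
    return [list(rows[start:start + sequence_length])
            for start in range(0, n_starts, stride)]

# 23 consecutive transactions of one hypothetical user, identified by position
rows = list(range(23))
windows = sliding_windows(rows, sequence_length=10, stride=5)

# one fraudulent transaction at position 12
fraud_flags = [0] * 23
fraud_flags[12] = 1
# a window is labelled 1 if it contains at least one fraudulent row
labels = [int(any(fraud_flags[i] for i in window)) for window in windows]

print(len(windows))   # 3 windows, starting at rows 0, 5 and 10
print(labels)         # [0, 1, 1]
```

With a stride smaller than the sequence length the windows overlap, so a single fraudulent row can mark several samples as positive. A larger stride yields fewer, less correlated samples.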

> **_Important_:**
>
> **Run the following cell with the 'load' variable set to 'True' to load the preprocessed data we used for our experiments.
If you want to preprocess the data again, set 'load' to 'False'. Keep in mind that this will take more time.**


In [None]:
## RUN THIS CELL - NOTHING TO CHANGE
load = True
if not load:
    train_dataset = TransactionDataset(data=train_data,
                                       mode='train',
                                       vocab_dir=VOCAB_DIR,
                                       save_dir=PROCESSED_DATA_DIR)
    train_dataset_cls = TransactionDataset(data=train_data,
                                           mode='train-cls',
                                           vocab_dir=VOCAB_DIR,
                                           smote=True,
                                           stride=10,
                                           save_dir=PROCESSED_DATA_DIR)
    val_dataset = TransactionDataset(data=val_data,
                                     mode='val',
                                     vocab_dir=VOCAB_DIR,
                                     save_dir=PROCESSED_DATA_DIR)
    test_dataset = TransactionDataset(data=test_data,
                                      mode='test',
                                      vocab_dir=VOCAB_DIR,
                                      save_dir=PROCESSED_DATA_DIR)
else:
    train_dataset = TransactionDataset.load(data_dir=PROCESSED_DATA_DIR,
                                            mode='train')
    train_dataset_cls = TransactionDataset.load(data_dir=PROCESSED_DATA_DIR,
                                                mode='train-cls')
    val_dataset = TransactionDataset.load(data_dir=PROCESSED_DATA_DIR,
                                          mode='val')
    test_dataset = TransactionDataset.load(data_dir=PROCESSED_DATA_DIR,
                                           mode='test')

2024-01-14 11:52:19,352 - INFO - root - Loading Card Dataset train set...
2024-01-14 11:52:48,481 - INFO - root - Class instance successfully loaded.

2024-01-14 11:52:48,487 - INFO - root - Loading Card Dataset train-cls set...
2024-01-14 11:53:09,060 - INFO - root - Class instance successfully loaded.

2024-01-14 11:53:09,069 - INFO - root - Loading Card Dataset val set...
2024-01-14 11:53:19,030 - INFO - root - Class instance successfully loaded.

2024-01-14 11:53:19,040 - INFO - root - Loading Card Dataset test set...
2024-01-14 11:53:31,661 - INFO - root - Class instance successfully loaded.



In [None]:
## RUN THIS CELL - NOTHING TO CHANGE
# -----------------------------------------------------------------------------------------------------------
# NOTE: This cell shows the structure of a single sample from the TransactionDataset. Each sample comprises multiple rows,
# each with 10 feature columns. The data in these columns has been discretized, mapped to the nearest bin edge, and
# then converted to indices. The separator index has been appended to each row and the cls row has been prepended.
# Observing this sample helps in understanding the preprocessing steps applied to the
# dataset, such as discretization and tokenization, and how the data is presented to the data collator.
# -----------------------------------------------------------------------------------------------------------
first_sample = train_dataset[0][0]
first_sample_label = train_dataset[0][1]
print(f"First sample:\n {first_sample}")
print(f"Shape [seq_len+1, num_cols+1]: {first_sample.shape}\n")
print(f"Target associated to the first sample:\n {first_sample_label}\n")

First sample:
 tensor([[    3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
             3],
        [    8,    26,  1208,  2385,  6481, 37003, 43902, 47943, 62531, 62640,
             1],
        [    8,    26,  1675,  2385,  2805, 37004, 43902, 47944, 62536, 62640,
             1],
        [    8,    26,  1305,  2385,  2805, 37004, 43902, 47944, 62536, 62640,
             1],
        [    8,    26,   563,  2385,  2758, 37004, 43902, 47944, 62542, 62640,
             1],
        [    8,    26,   989,  2385,  2601, 37003, 43902, 47943, 62535, 62640,
             1],
        [    8,    26,  1424,  2385,  2410, 37004, 43902, 47947, 62548, 62640,
             1],
        [    8,    26,   703,  2385,  2805, 37004, 43902, 47944, 62536, 62640,
             1],
        [    8,    26,  1609,  2385,  2805, 37004, 43902, 47944, 62536, 62640,
             1],
        [    8,    26,  2096,  2385,  2805, 37004, 43902, 47944, 62536, 62640,
             1],
        [    8,    26,  2

> ## **Vocabulary Extraction and Summary**
>
> **_Important_**:
>
> **Run the following cell to extract the vocabulary from the training dataset and to display a summary of it.**

In [None]:
vocab = train_dataset.vocab
vocab.print_vocab_summary(print_special_tokens=True,
                          print_sample_tokens=True,
                          sample_size=5,
                          token_limit_per_column=5,
                          print_vocab_size_per_column=True,
                          print_column_data_types=True,
                          print_vocab_length=True)

Special tokens: ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]', '[START]', '[END]']

Sampling from the Vocabulary:

COLUMN_TAG: Errors?
TOKEN: [GLOBAL IDX, LOCAL IDX]
None: [62640, 0]
Technical Glitch,: [62641, 1]
Insufficient Balance,: [62642, 2]
Bad PIN,: [62643, 3]
Bad CVV,: [62644, 4]


COLUMN_TAG: Use Chip
TOKEN: [GLOBAL IDX, LOCAL IDX]
Swipe Transaction: [2385, 0]
Online Transaction: [2386, 1]
Chip Transaction: [2387, 2]


COLUMN_TAG: Merchant City
TOKEN: [GLOBAL IDX, LOCAL IDX]
Quincy: [34574, 0]
Panama City: [34575, 1]
Lincoln Park: [34576, 2]
Brandon: [34577, 3]
Saint Marks: [34578, 4]


COLUMN_TAG: Zip
TOKEN: [GLOBAL IDX, LOCAL IDX]
32352.0: [44055, 0]
32401.0: [44056, 1]
48146.0: [44057, 2]
33511.0: [44058, 3]
32355.0: [44059, 4]


COLUMN_TAG: Card
TOKEN: [GLOBAL IDX, LOCAL IDX]
3: [7, 0]
0: [8, 1]
2: [9, 2]
4: [10, 3]
6: [11, 4]


Number of tokens in column 'Card': 8
Number of tokens in column 'TIMESTAMP': 255
Number of tokens in column 'Amount': 2115
Number of tokens in col

---
# **BERT PARAMETERS**

>## **Bert Custom Config**
>
> The *CustomBertConfig* class is an extension of the BertConfig class, specifically designed for handling tabular and time series data.
>
> 1. **Number of Columns (ncols)**: Specifies the number of columns in the tabular data, aligning with the number of input indices in one row.
>
> 2. **Vocabulary Size (vocab_size)**: Number of unique tokens in the data.
>
> 3. **Field Hidden Size (field_hidden_size)**: Sets the hidden size for field embeddings.
>
> 4. **Hidden Size (hidden_size)**: Determines the dimensionality of the encoder output.
>
> 5. **Number of Hidden Layers (num_hidden_layers)**: Defines the depth of the Transformer encoder.
>
> 6. **Number of Attention Heads (num_attention_heads)**: The number of attention mechanisms in each encoder layer.
>
> 7. **Pad Token ID (pad_token_id)**:  Represents the index used for padding.
>
> 8. **Masked Language Model Probability (mlm_probability)**: Ratio of tokens to mask for masked language modeling.


In [None]:
class CustomBertConfig(BertConfig):
    """
    Custom config class for a hierarchical Bert Model for Tabular Data and Time Series analysis.

    The BertConfig is the configuration class to store the configuration of a [`BertModel`].

    Refer to the following link for source code and documentation of BertConfig:
        - https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/bert/configuration_bert.py#L72
    """

    def __init__(self,
                 ncols: Optional[int]=12,
                 vocab_size: Optional[int]=429,
                 field_hidden_size: Optional[int]=64,
                 hidden_size: Optional[int]=768,
                 num_hidden_layers: Optional[int]=6,
                 num_attention_heads: Optional[int]=8,
                 pad_token_id: Optional[int]=0,
                 mlm_probability: Optional[float]=0.15,
                 **kwargs) -> None:

        """
        Initialize the CustomBertConfig module.

        Args:
            - ncols (int, optional):
                The number of columns in the tabular data. Corresponds to the number of 'input_ids' in one row. Default to 12.
            - vocab_size (int, optional):
                Vocabulary size of the model. Defines the number of different tokens that can be represented by the `inputs_ids`. Default to 429.
            - field_hidden_size (int, optional):
                Hidden size for field embeddings. Default to 64.
            - hidden_size (int, optional):
                Dimensionality of the encoder layers and the pooler layer.
                Corresponds to the dimensionality of the row embedding. Default to 768.
            - num_hidden_layers (int, optional):
                Number of hidden layers in the Transformer encoder. Default to 6.
            - num_attention_heads (int, optional):
                Number of attention heads for each attention layer in the Transformer encoder. Default to 8.
            - pad_token_id (int, optional):
                Index used for padding. Default to 0.
            - mlm_probability (float, optional):
                Ratio of tokens to mask for masked language modeling. Default to 0.15.
        """

        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.ncols = ncols
        self.field_hidden_size = field_hidden_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_attention_heads = num_attention_heads
        self.pad_token_id = pad_token_id
        self.num_hidden_layers = num_hidden_layers
        self.mlm_probability = mlm_probability

> **_Important_**:
>
> **Run the cell below to create a CustomBertConfig object. This is for demonstration purposes only as the config object will be created in the training manager given the dictionary of parameters.**

In [None]:
ncols=train_dataset.get_ncols()

model_config_values = {
    "vocab_size": len(vocab),
    "ncols": ncols,
    "field_hidden_size": 128,
    "hidden_size": 128*ncols,
    "num_hidden_layers": 12,
    "num_attention_heads": ncols,
    "pad_token_id": vocab.get_id(vocab.pad_token, vocab.special_tag),
    "sequence_length": train_dataset.sequence_length,
    "stride": train_dataset.stride
    }
config =  CustomBertConfig(**model_config_values)
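
The choice of `hidden_size = 128*ncols` and `num_attention_heads = ncols` above can be checked with a few lines of arithmetic. A standard BERT encoder requires the hidden size to be divisible by the number of attention heads. The sketch below assumes `ncols = 11` (the 10 vocabulary columns plus the separator, as in the sample printed earlier), and the helper `head_dim` is a hypothetical name introduced only for this example.

```python
def head_dim(hidden_size: int, num_attention_heads: int) -> int:
    """Per-head dimensionality; fails if the hidden size cannot be split evenly."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) is not a multiple of "
            f"num_attention_heads ({num_attention_heads})")
    return hidden_size // num_attention_heads

ncols = 11                                # assumption: 10 feature columns + 1 separator
field_hidden_size = 128
hidden_size = field_hidden_size * ncols   # 1408
per_head = head_dim(hidden_size, num_attention_heads=ncols)

print(hidden_size, per_head)              # 1408 128
```

Tying both values to `ncols` keeps the per-head dimension equal to `field_hidden_size` whatever the number of columns is. An arbitrary pair such as 768 and 11 would fail the divisibility check.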

---
# **TOKENIZER AND DATA COLLATOR**

> ## **Tokenizer**
>
> In this cell we initialize the BERT tokenizer. The tokenizer is initialized with the vocabulary file from the vocab object.
This tokenizer is used in the CustomDataCollator class to pad the input ids and mask tokens for MLM tasks.
>
> **_Important_**:
>
> **Run the following cell to initialize the tokenizer.**

In [None]:
tokenizer = BertTokenizerFast(vocab_file=vocab.vocab_file_for_bert,
                              do_lower_case=False,
                              **vocab.get_special_tokens())

> ## **Custom Data Collator**
>
> The *CustomDataCollator* class is designed to handle tabular and time series data, preparing the samples, the targets and the masked language model labels.
>
> - **Hugging Face Compatibility**:
    - It inherits from *DataCollatorForLanguageModeling*, ensuring compatibility with Hugging Face.
    - The class requires a Bert tokenizer for padding and masking the input ids.
>    
> - ***`__call__`* Method**:
    - Each row represents a collection of features (already converted to indices). Recall that the sequence length parameter of the dataset defines the number of consecutive rows that constitute a single sample.
    - The collator efficiently groups multiple samples into a batch, the batch received by the model has shape [batch, seq_len+1, ncols+1].
    - The method can handle both MLM and classification tasks.
    - In MLM mode, it masks certain tokens in the input ids based on mlm_probability and returns the labels.
    - In classification mode, it returns the targets.
> - For the source code of the parent class and its usage, refer to the [Hugging Face repository](https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/data/data_collator.py#L607).
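>
> The flatten, mask, and reshape steps of MLM mode can be sketched without PyTorch. The code below is a simplified illustration for a single sample stored as nested lists, not the collator itself: `mask_grid`, `mask_id`, and the toy ids are hypothetical, and special tokens are not excluded from masking here. It follows the usual BERT rule (a selected token becomes the mask token 80% of the time, a random id 10% of the time, and stays unchanged 10% of the time) and marks unselected positions with -100 so the loss ignores them.

```python
import random
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_grid(grid: List[List[int]],
              mask_id: int,
              vocab_size: int,
              mlm_probability: float = 0.15,
              seed: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Flatten a [rows, cols] grid of ids, mask it, and restore the shape."""
    rng = random.Random(seed)
    n_cols = len(grid[0])
    flat = [tok for row in grid for tok in row]            # [rows*cols]
    inputs, labels = [], []
    for tok in flat:
        if rng.random() < mlm_probability:
            labels.append(tok)                             # the model must predict the original id
            draw = rng.random()
            if draw < 0.8:
                inputs.append(mask_id)                     # 80%: replace with the mask token
            elif draw < 0.9:
                inputs.append(rng.randrange(vocab_size))   # 10%: replace with a random id
            else:
                inputs.append(tok)                         # 10%: keep the original id
        else:
            inputs.append(tok)
            labels.append(IGNORE_INDEX)                    # position is not scored by the loss

    def unflatten(values: List[int]) -> List[List[int]]:
        return [values[i:i + n_cols] for i in range(0, len(values), n_cols)]

    return unflatten(inputs), unflatten(labels)


grid = [[10, 11, 12], [13, 14, 15]]                        # one sample: 2 rows x 3 fields
masked, labels = mask_grid(grid, mask_id=1, vocab_size=20, mlm_probability=0.5)
```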

In [None]:
class CustomDataCollator(DataCollatorForLanguageModeling):
    """
    Custom Data Collator for Tabular Data and Time Series analysis.

    This class inherits from DataCollatorForLanguageModeling from huggingface.
    It is designed to handle tabular and time series data where each row consists of multiple columns and each sample consists of multiple rows.
    The collator can be used for Masked Language Modeling and classification tasks.

    Refer to the following link for source code and documentation of DataCollatorForLanguageModeling:
        - https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/data/data_collator.py#L607
    """

    def __call__(self,
                 samples: List[Tuple[torch.Tensor, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Collates the samples into a batch. It can handle both MLM and classification tasks.
        For MLM tasks, it masks certain tokens in the input ids based on mlm_probability and returns the labels.
        For classification tasks, it returns the labels.

        Args:
            - samples (List[Tuple[Tensor, Tensor]]):
                List of (input_ids, target) tuples, one per sample.

        Returns:
            - Dict[str, Tensor]:
                A dictionary containing input ids, attention mask and target (MLM labels or classification labels).
        """
        # pad the batch
        input_ids = [sample[0] for sample in samples]
        targets = [sample[1] for sample in samples]
        batch = self.tokenizer.pad({"input_ids": input_ids}, return_tensors="pt") # expected shape [batch, seq_len+1, ncols+1]
        if self.mlm:
            # get the shape of the input ids and flatten the samples to mask tokens for MLM
            sz = batch['input_ids'].shape
            input_ids = batch['input_ids'].view(sz[0], -1) # expected shape [batch, (seq_len+1)*(ncols+1)]
            # mask the tokens with a method from DataCollatorForLanguageModeling
            input_ids, labels = self.torch_mask_tokens(input_ids)
            # reconstruct the initial shape
            batch['input_ids'] = input_ids.view(sz)
            batch['labels'] = labels.view(sz)
        else:
            batch['labels'] = torch.stack(targets)
        return batch

> **_Important_:**
>
> **In the following cell, we demonstrate how to use the CustomDataCollator class for MLM and classification tasks.
> Note that there's no need to create the collator object manually, as the CustomDataCollator object will be created in the training manager.**


In [None]:
## NO NEED TO RUN THIS CELL
# data collator for Masked Language Model
data_collator_for_mlm = CustomDataCollator(tokenizer=tokenizer,
                                           mlm=True,
                                           mlm_probability=config.mlm_probability)
# data collator for the classification task
data_collator_for_classification = CustomDataCollator(tokenizer=tokenizer,
                                                      mlm=False)

---
# **THE MODEL**


In [None]:
## RUN THIS CELL - NO CHANGES NEEDED
from src.utils import set_seed

# setting the seed for reproducibility
set_seed(2024)


> ## **Hierarchical Bert Language Model**
>
> This module contains the implementation of the Hierarchical Bert Language Model.
> It can be used for Masked Language Modeling and classification tasks. It represents the core component of our project.
>
> The model is composed of three components:
> - **TabRowEmbeddings**:
    - An embedding layer for tabular data.
    - It is designed to handle tabular data where each row consists of multiple columns.
    - Each individual token is mapped to an embedding. The sequence is then passed to a transformer encoder to capture relationships between columns.
    - A final linear layer transforms the embeddings to the desired hidden size.
> - **BertModel**:
    - A BertModel from the HuggingFace library.
    - It is used to capture relationships between rows.
> - **MLM-specific layers or Classification-specific layers**:
    - The MLM layers are used for Masked Language Modeling. They are used when pretraining the model to obtain a representation of the input.
    - The classification layers are used for the classification task. They are used to fine-tune the model after pretraining.
>
> The forward step of the model is different for MLM and classification tasks:
> - **Masked Language Modeling**:
    - The input ids are passed through the TabRowEmbeddings layer to obtain the embeddings of the tabular data.
    - The embeddings are then passed to the BertModel.
    - The output of the BertModel is passed through the MLM layers to obtain the predictions at field level.
    - The predictions are compared to the masked LM labels at field level and the cross entropy loss is computed.
> - **Classification**:  
    - The input ids are passed through the TabRowEmbeddings layer to obtain the embeddings of the tabular data.
    - The embeddings are then passed to the BertModel.
    - The CLS embedding is extracted from the output of the BertModel and it is passed through the classification layers to obtain a prediction for each sample.
    - The predictions are compared to the targets (1 for Fraud, 0 for No Fraud) and the Binary Cross Entropy Loss is computed.  
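>
> The field-level loss described for MLM can be illustrated in plain Python. This is a minimal sketch for a single masked position, under the assumption that each field owns a fixed set of global token ids. `field_cross_entropy` and the toy values are hypothetical, and the real `Vocab` methods used by the model are not reproduced here. The softmax is restricted to the field's own tokens, so scores assigned to tokens of other fields do not affect the loss.

```python
import math
from typing import Dict, List


def field_cross_entropy(logits: List[float],
                        field_global_ids: List[int],
                        target_global_id: int) -> float:
    """Cross-entropy over the sub-vocabulary of a single field."""
    # keep only the scores of the tokens that belong to this field
    field_logits = [logits[g] for g in field_global_ids]
    # global id -> position inside the field's sub-vocabulary (the "local" id)
    global_to_local: Dict[int, int] = {g: i for i, g in enumerate(field_global_ids)}
    local_target = global_to_local[target_global_id]
    # numerically stable log-sum-exp over the field's scores
    top = max(field_logits)
    log_norm = top + math.log(sum(math.exp(x - top) for x in field_logits))
    return log_norm - field_logits[local_target]


# Toy vocabulary of 6 tokens; the field owns global ids 2, 3 and 4.
logits = [5.0, 5.0, 1.0, 2.0, 3.0, 5.0]
loss = field_cross_entropy(logits, field_global_ids=[2, 3, 4], target_global_id=4)
print(round(loss, 4))  # 0.4076
```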
    


In [None]:
class TabRowEmbeddings(nn.Module):
    """
    Custom embedding class for tabular row data.

    This custom class is designed to handle embeddings for tabular data where each row consists of multiple columns.
    Each individual token is mapped to an embedding.
    The sequence is then passed to a transformer encoder to capture relationships between columns.
    A final linear projection transforms the embeddings to the desired hidden size.
    """
    def __init__(self,
                 config: CustomBertConfig) -> None:
        """
        Initializes the TabRowEmbeddings class.

        Args:
        - config (CustomBertConfig):
            CustomBertConfig object with attributes:
            - vocab_size: Vocabulary size of the model.
            - field_hidden_size: Hidden size for field embeddings.
            - ncols: The number of columns in the tabular data.
            - hidden_size: Hidden size for the output row embeddings.
            - pad_token_id (optional): Index used for padding.
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(num_embeddings=config.vocab_size,
                                            embedding_dim=config.field_hidden_size,
                                            padding_idx=getattr(config, 'pad_token_id', 0))
        encoder_layer = nn.TransformerEncoderLayer(d_model=config.field_hidden_size,
                                                   nhead=8,
                                                   dim_feedforward=config.field_hidden_size,
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                                         num_layers=1)
        self.linear = nn.Linear(in_features=config.field_hidden_size*config.ncols,
                                out_features=config.hidden_size)
        self._init_model_weights()

    def forward(self,
                input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward step of TabRowEmbeddings.

        Args:
            - input_ids:
                Tensor of shape [batch_size, seq_len+1, ncols+1] containing the token ids.

        Returns:
            - inputs_embeds:
                Tensor of shape [batch_size, seq_len+1, hidden_size] containing the output row embeddings.
        """

        inputs_embeds = self.word_embeddings(input_ids) #[batch_size, seq_len+1, ncols+1, field_hidden_size]
        embeds_shape = inputs_embeds.shape
        # reshape the embeddings
        inputs_embeds = inputs_embeds.view(embeds_shape[0]*embeds_shape[1], embeds_shape[2], -1)  #[batch_size*(seq_len+1), ncols+1, field_hidden_size]
        # passing through the transformer encoder
        inputs_embeds = self.transformer_encoder(inputs_embeds)
        # reshape the embeddings to have a single row embedding
        inputs_embeds = inputs_embeds.contiguous().view(embeds_shape[0], embeds_shape[1], -1)  # [batch_size, seq_len+1, (ncols+1)*field_hidden_size]
        # final linear projection to hidden size
        inputs_embeds = self.linear(inputs_embeds) # [batch_size, seq_len+1, hidden_size]
        return inputs_embeds

    def _init_model_weights(self):
        """
        Initializes the weights of the model.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                nn.init.uniform_(module.weight, -1.0, 1.0)
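
The reshapes in `TabRowEmbeddings.forward` are easier to follow as a list of shapes. The sketch below is pure arithmetic and does not run the model. `trace_row_embedding_shapes` is a hypothetical helper, and `rows` and `fields` simply stand for the two trailing dimensions of `input_ids` as the tensor arrives. Note that the final projection only works if the linear layer's `in_features` equals `fields * field_hidden`.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def trace_row_embedding_shapes(batch: int, rows: int, fields: int,
                               field_hidden: int, hidden: int) -> List[Tuple[str, Shape]]:
    """List the tensor shape after each step of a row-embedding forward pass."""
    return [
        ("input_ids", (batch, rows, fields)),
        ("token embeddings", (batch, rows, fields, field_hidden)),
        ("per-row view for the field encoder", (batch * rows, fields, field_hidden)),
        ("field encoder output", (batch * rows, fields, field_hidden)),
        ("concatenated fields", (batch, rows, fields * field_hidden)),
        ("linear projection", (batch, rows, hidden)),
    ]


# Illustrative sizes only
for name, shape in trace_row_embedding_shapes(batch=4, rows=10, fields=12,
                                              field_hidden=128, hidden=1536):
    print(f"{name:<36} {shape}")
```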

In [None]:
class HierarchicalBertLM(PreTrainedModel):
    def __init__(self,
                 config: BertConfig,
                 vocab: Vocab,
                 mode: str='mlm') -> None:
        """
        Initializes the HierarchicalBertLM class. It can be used for masked LM and classification tasks.

        Args:
            - config (CustomBertConfig):
                CustomBertConfig object with attributes:
                - vocab_size: Vocabulary size of the model.
                - field_hidden_size: Hidden size for field embeddings.
                - ncols: The number of columns in the tabular data.
                - hidden_size: Hidden size for the output row embeddings.
                - pad_token_id (optional): Index used for padding.
                - hidden_act (optional): Activation function used in the feedforward layer.
                - layer_norm_eps (optional): Epsilon value for layer normalization.
            - vocab (Vocab):
                Vocab object containing the vocabulary of the model.
            - mode (str, 'mlm' or 'classification'):
                Mode of the model. If 'mlm', the model is trained with masked LM. If 'classification', the model is trained for classification.
        """
        super().__init__(config)
        self.config = config
        self.vocab = vocab
        self.mode = mode
        # tabular embeddings
        self.tabular_row_embeddings = TabRowEmbeddings(self.config)
        # bert model for sequence of rows
        self.bert = BertModel(config)
        # MLM-specific layers
        self.mlm_linear = nn.Linear(in_features=config.field_hidden_size,
                                    out_features=config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias
        if isinstance(config.hidden_act, str):
            self.activation_function = ACT2FN[config.hidden_act]
        else:
            self.activation_function = config.hidden_act
        # CrossEntropyLoss for masked LM
        self.loss_fct_mlm = nn.CrossEntropyLoss()
        # precompute the global ids for each column
        if self.mode=='mlm':
            self.precomputed_global_ids = {key: self.vocab.get_global_ids(key) for key in self.vocab.cols_for_vocab}
        # classification-specific layers
        self.hidden_layer_cls = nn.Linear(config.hidden_size, 256)
        self.dropout = nn.Dropout(0.1)
        self.output_layer_cls = nn.Linear(256, 1)
        self.loss_fct_cls = nn.BCELoss()
        self._init_model_weights()

    def compute_masked_lm_loss(self,
                               sequence_output: torch.Tensor,
                               masked_lm_labels: torch.Tensor,
                               outputs: tuple) -> torch.Tensor:
        """
        Computes the masked LM loss.

        Args:
            - sequence_output (torch.Tensor):
                Tensor of shape [batch_size, seq_len, hidden_size] containing the output of the BERT model.
            - masked_lm_labels (torch.Tensor):
                Tensor of shape [batch_size, seq_len, ncols] containing the masked LM labels.
            - outputs (tuple):
                Tuple containing the outputs of the BERT model.

        Returns:
            - total_masked_lm_loss (torch.Tensor):
                Tensor containing the total masked LM loss
        """
        # we must reshape the output to reconstruct field embeddings
        output_shape = sequence_output.shape # [batch_size, seq_len, hidden_size]
        expected_shape = [output_shape[0], output_shape[1]*self.config.ncols, -1] # [batch_size, seq_len*ncols, field_hidden_size]
        sequence_output = sequence_output.view(expected_shape)
        masked_lm_labels = masked_lm_labels.view(expected_shape[0], -1) # [batch_size, seq_len*ncols]
        # pass the output of BERT through the feedforward layer, output shape [batch_size, seq_len*ncols, hidden_size]
        hidden_state = self.mlm_linear(sequence_output)
        hidden_state = self.activation_function(hidden_state)
        hidden_state = self.layer_norm(hidden_state)
        prediction_scores = self.decoder(hidden_state) # [batch_size, seq_len*ncols, vocab_size]
        outputs = (prediction_scores, ) + outputs[2:]
        total_masked_lm_loss = 0
        seq_len = prediction_scores.size(1)
        # get the field names
        field_names = self.vocab.cols_for_vocab
        # iterate over the field names
        for index, key in enumerate(field_names):
            # get the global ids for the field
            col_ids = list(range(index, seq_len, len(field_names)+1))
            global_ids_field = self.precomputed_global_ids[key]
            # remember that prediction_scores has shape [batch_size, seq_len*ncols, vocab_size], so we need to select the right columns.
            # we select the prediction scores for the particular field and from them we select the scores only corresponding to the global ids of the field
            prediction_scores_field = prediction_scores[:, col_ids, :][:, :, global_ids_field]  # [batch_size, seq_len, K] where K is the number of unique tokens in the field (the vocab size of the field)
            # selection of the masked LM labels for the field
            masked_lm_labels_field = masked_lm_labels[:, col_ids]
            # map the global ids to local ids
            masked_lm_labels_field_local = self.vocab.map_global_to_local(global_ids=masked_lm_labels_field)
            # compute the masked LM loss for the field
            masked_lm_loss_field = self.loss_fct_mlm(prediction_scores_field.view(-1, len(global_ids_field)),
                                                     masked_lm_labels_field_local.view(-1))
            if not torch.isnan(masked_lm_loss_field):
                total_masked_lm_loss += masked_lm_loss_field
        return total_masked_lm_loss

    def forward(self,
                input_ids=None,
                attention_mask=None,
                labels=None) -> dict:
        """
        Forward step of HierarchicalBertLM. Works for both masked LM and classification.

        Args:
            - input_ids (torch.Tensor):
                Tensor of shape [batch_size, seq_len+1, ncols+1] containing the token ids.
            - attention_mask (torch.Tensor):
                Tensor of shape [batch_size, seq_len+1, ncols+1] containing the attention mask.
            - labels (torch.Tensor):
                Tensor of shape [batch_size, seq_len+1, ncols+1] containing the masked LM labels (MLM mode), or of shape [batch_size, 1] containing the classification targets (classification mode).

        Returns:
            - dict:
                Dictionary containing the loss and the predictions (if present).
        """
        # construct the embeddings of the tabular data, output shape [batch_size, seq_len, hidden_size]
        inputs_embeds = self.tabular_row_embeddings(input_ids)
        # pass the time series of rows through BERT
        outputs = self.bert(inputs_embeds=inputs_embeds)
        sequence_output = outputs[0]
        if self.mode=='mlm':
            total_loss = self.compute_masked_lm_loss(sequence_output,
                                                     labels,
                                                     outputs)
            return {'loss': total_loss}
        # classification task
        elif self.mode=='classification':
            # extract the CLS token
            cls_embedding = sequence_output[:, 0, :]
            # pass the output of BERT through the hidden layer
            x = self.hidden_layer_cls(cls_embedding)
            x = self.activation_function(x)
            x = self.dropout(x)
            logits = self.output_layer_cls(x)
            probs = torch.sigmoid(logits)
            # reshape the probs and the labels
            probs_flat = probs.view(-1)
            labels_flat = labels.view(-1)
            total_loss = self.loss_fct_cls(probs_flat, labels_flat)
            return {'loss': total_loss,
                    'predictions': probs_flat,
                    'labels': labels_flat}
        else:
            raise ValueError(f"Unsupported mode '{self.mode}': expected 'mlm' or 'classification'.")

    def freeze_model_except_nlayers(self,
                                    n: int=2) -> None:
        """
        Freezes all the layers of the encoder model except the last n layers.

        Args:
            - n (int):
                Number of layers to keep unfrozen. Defaults to 2.
        """
        # freeze all the layers
        for param in self.tabular_row_embeddings.parameters():
            param.requires_grad = False
        for param in self.bert.parameters():
            param.requires_grad = False
        # unfreeze the last n layers
        for param in self.bert.encoder.layer[-n:].parameters():
            param.requires_grad = True
        # unfreeze the classification layers (if present)
        if self.mode=='classification':
            for param in self.hidden_layer_cls.parameters():
                param.requires_grad = True
            for param in self.output_layer_cls.parameters():
                param.requires_grad = True

    def _init_model_weights(self) -> None:
        """Initializes the weights of the model."""
        # initialize weights for MLM layers
        if self.mode == 'mlm':
            init.xavier_uniform_(self.mlm_linear.weight)
            self.mlm_linear.bias.data.zero_()
            init.xavier_uniform_(self.decoder.weight)
            self.decoder.bias.data.zero_()
        # initialize weights for classification layers
        if self.mode == 'classification':
            init.xavier_uniform_(self.hidden_layer_cls.weight)
            self.hidden_layer_cls.bias.data.zero_()
            init.xavier_uniform_(self.output_layer_cls.weight)
            self.output_layer_cls.bias.data.zero_()
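
`compute_masked_lm_loss` selects the positions of one field with a strided `range` over the flattened sequence. The sketch below isolates that indexing rule; `field_positions` and `row_width` are hypothetical names. The stride has to equal the number of tokens each row contributes to the flattened sequence, otherwise the selected positions drift from one field to the next.

```python
from typing import List


def field_positions(field_index: int, flat_len: int, row_width: int) -> List[int]:
    """Positions of one field inside a row-major flattened [rows, row_width] grid."""
    return list(range(field_index, flat_len, row_width))


rows, row_width = 3, 4
flat_len = rows * row_width
print(field_positions(0, flat_len, row_width))  # [0, 4, 8]
print(field_positions(2, flat_len, row_width))  # [2, 6, 10]
```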

> **_Important_**:
>
> **The following code is not meant to be executed. It is only used to show how to instantiate the model.
> The models will be instantiated in the training manager based on the mode (mlm or classification).**

In [None]:
## NO NEED TO RUN THIS CELL
model_mlm = HierarchicalBertLM(config=config,
                               vocab=vocab,
                               mode='mlm')

model_classification = HierarchicalBertLM(config=config,
                                          vocab=vocab,
                                          mode='classification')
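
The freezing logic in `freeze_model_except_nlayers` can be tried out on plain Python objects. This is a sketch under simplifying assumptions: `ToyLayer` and `freeze_except_last` are hypothetical stand-ins that track a single flag instead of real parameters. One detail of standard Python slicing is worth knowing when choosing `n`: `seq[-0:]` is the whole sequence, so "unfreeze nothing" needs an explicit guard.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyLayer:
    """Stand-in for an encoder layer; only tracks whether it is trainable."""
    trainable: bool = True


def freeze_except_last(layers: List[ToyLayer], n: int) -> None:
    """Freeze every layer, then re-enable the last n."""
    for layer in layers:
        layer.trainable = False
    # guard n == 0: layers[-0:] would select the whole list, not an empty slice
    for layer in (layers[-n:] if n > 0 else []):
        layer.trainable = True


encoder = [ToyLayer() for _ in range(12)]
freeze_except_last(encoder, n=3)
print([layer.trainable for layer in encoder].count(True))  # 3
```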

---
# **TRAINING AND EVALUATION**

> ## **Weights & Biases**
>
> In order to log the training process and the metrics, we will use [Weights & Biases](https://wandb.ai/site).
>
> **_Important_:**
>
> **You can create a free account and log in from the notebook by running the following cell.**
>
> **While running the cell, you will be prompted to enter your API key. You can find your API key [here](https://wandb.ai/authorize).**




In [None]:
## RUN THIS CELL - NO CHANGES NEEDED
import wandb

wandb.login()

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

> ## **MLM Training Configuration**
>
> In this cell we define a dictionary containing the training parameters to train the Masked Language Model.
>
> The parameters are passed to the TrainingArguments class from the transformers library, which is used to instantiate the Trainer class. Make sure every key is a valid TrainingArguments parameter. The TrainingManager class checks this as well, but it is better to verify the keys beforehand.
>
> **_Important_:**
>
> **If you are not planning to train the model, you can skip this cell. Pretrained models will be loaded in the next sections.**
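>
> The kind of check mentioned above can be written with `inspect.signature`, which lists the parameters a callable accepts. The sketch below uses a hypothetical stand-in class, `ToyArguments`, so that it runs without the transformers library. `invalid_kwargs` is also a hypothetical helper.

```python
import inspect
from typing import Any, Callable, Dict, Set


def invalid_kwargs(target: Callable[..., Any], kwargs: Dict[str, Any]) -> Set[str]:
    """Return the keys of kwargs that target does not accept as parameters."""
    accepted = set(inspect.signature(target).parameters)
    return set(kwargs) - accepted


class ToyArguments:
    """Hypothetical stand-in for a settings class such as TrainingArguments."""

    def __init__(self, learning_rate: float = 5e-5, seed: int = 42) -> None:
        self.learning_rate = learning_rate
        self.seed = seed


# "sed" is a deliberate typo for "seed"
print(invalid_kwargs(ToyArguments, {"learning_rate": 1e-4, "sed": 2024}))  # {'sed'}
```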


In [None]:
## NO NEED TO RUN THIS CELL UNLESS YOU WANT TO TRAIN THE MODEL
training_mlm_config_dict = {
    'per_device_train_batch_size': 128,
    'per_device_eval_batch_size': 128,
    'num_train_epochs': 50,
    'logging_strategy': 'steps',
    'logging_first_step': True,
    'logging_steps': 1,
    'save_strategy': 'steps',
    'save_steps': 750,
    'evaluation_strategy': 'steps',
    'eval_steps': 250,
    'load_best_model_at_end': True,
    'disable_tqdm': False,
    'seed': 2024,
    'learning_rate': 5e-5,
    'report_to':'wandb',
    'lr_scheduler_type':'constant'}
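
The values above keep `save_steps` (750) a multiple of `eval_steps` (250). As far as we are aware, TrainingArguments requires this alignment when `load_best_model_at_end` is enabled, together with matching save and evaluation strategies. Treat the exact rules as an assumption and check the documentation of your transformers version. `check_best_model_settings` below is a hypothetical, simplified restatement of that constraint.

```python
def check_best_model_settings(cfg: dict) -> None:
    """Raise ValueError if the save/eval settings cannot support best-model reloading."""
    if not cfg.get("load_best_model_at_end", False):
        return
    if cfg.get("save_strategy") != cfg.get("evaluation_strategy"):
        raise ValueError("save_strategy and evaluation_strategy must match")
    if cfg.get("save_strategy") == "steps" and cfg["save_steps"] % cfg["eval_steps"] != 0:
        raise ValueError("save_steps must be a multiple of eval_steps")


# Same save/eval values as the dictionary above: 750 is 3 x 250, so this passes
check_best_model_settings({
    "load_best_model_at_end": True,
    "save_strategy": "steps", "save_steps": 750,
    "evaluation_strategy": "steps", "eval_steps": 250,
})
```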

> ## **Training Manager Class**
>
> The TrainingManager class is responsible for setting up the model, the data collator, the training arguments, and the HuggingFace Trainer.
>
> - **Class Initialization:**
>     - To initialize the class, we need to provide the model configuration dictionary, the training configuration dictionary, the vocabulary, the training, validation, and test sets, the root directory, the project name, the model name, the mode (either 'mlm' or 'classification'), the path to the pretrained model checkpoint (only required for 'classification' mode), and the number of layers to unfreeze.
>     - The model configuration dictionary contains the model parameters to be logged.
>     - The training configuration dictionary contains the training parameters for the TrainingArguments class from HuggingFace. Be sure to check the [documentation](https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments) for the list of available parameters. A check is performed to ensure that only valid parameters are provided.
>     - The training, validation, and test sets are instances of the TransactionDataset class.
>     - The root directory is the directory where the output and logs directories will be created.
>     - The project name is the name of the project for logging on wandb.
>     - The model name is the name of the model for logging on wandb.
>     - The mode is either 'mlm' or 'classification'.
>     - The pretrained model path is the path to the pretrained model checkpoint.
>
> - **Directories and Logging:**
>     - The setup_directories() method sets up the checkpoints and logs directories.
>     - The setup_wandb() method sets up wandb for logging.
>
> - **Tokenizer and Collator:**
>     - The setup_tokenizer() method sets up the tokenizer needed in the data collator.
>     - The setup_collator() method sets up the data collator for training. If the mode is 'mlm', the data collator returns the labels for masked language modeling. If the mode is 'classification', the data collator returns the labels for classification.
>
> - **Model:**
>     - The setup_model() method sets up the model for training. If a pretrained model path is provided, the model is initialized from that checkpoint; otherwise it is initialized from scratch. In 'mlm' mode the model is trained with masked language modeling. In 'classification' mode the model is initialized from the pretrained checkpoint (after MLM training) and trained for classification, with everything frozen except the last n (default 3) BERT encoder layers and the classification layers.
>
> - **Training:**
>     - The setup_training() method sets up the training arguments and trainer.
>     - The train() method must be called to train the model.
>     - The evaluate() method can be used to evaluate the model on the validation or test set.
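>
> The directory layout produced by setup_directories() can be reproduced in isolation. The sketch below mirrors that layout with `pathlib` inside a temporary folder, so nothing is written to Drive. `make_run_dirs` is a hypothetical helper, not a method of the class.

```python
import tempfile
from pathlib import Path
from typing import Tuple


def make_run_dirs(root_dir: str, mode: str, model_name: str) -> Tuple[Path, Path]:
    """Create <root>/output/<mode>/{checkpoints,logs}/<model_name> and return both paths."""
    base_dir = Path(root_dir) / "output" / mode
    checkpoint_dir = base_dir / "checkpoints" / model_name
    logs_dir = base_dir / "logs" / model_name
    for directory in (checkpoint_dir, logs_dir):
        directory.mkdir(parents=True, exist_ok=True)   # idempotent: safe to call again
    return checkpoint_dir, logs_dir


with tempfile.TemporaryDirectory() as tmp:
    ckpt, logs = make_run_dirs(tmp, mode="mlm", model_name="credit-0")
    print(ckpt.relative_to(tmp).as_posix())   # output/mlm/checkpoints/credit-0
    print(logs.relative_to(tmp).as_posix())   # output/mlm/logs/credit-0
```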

In [None]:
class TrainingManager:
    def __init__(self,
                 model_config_dict: dict,
                 training_config_dict: dict,
                 vocab: Vocab,
                 train_set: TransactionDataset,
                 val_set: TransactionDataset,
                 test_set: TransactionDataset,
                 root_dir: str='/content',
                 project_name: str='CreditTabBert',
                 model_name: str='credit-0',
                 mode: str='mlm',
                 pretrained_model_path: str=None,
                 layers_to_unfreeze: Optional[int] = 3) -> None:
        """
        Initializes the TrainingManager class.

        Args:
            - model_config_dict (dict): configuration dictionary containing model and training parameters to be logged.
            - training_config_dict (dict): configuration dictionary containing training parameters for training_args
            - vocab (Vocab): Vocab object containing the vocabulary of the model.
            - train_set (TransactionDataset): TransactionDataset object containing the training data.
            - val_set (TransactionDataset): TransactionDataset object containing the validation data.
            - test_set (TransactionDataset): TransactionDataset object containing the test data.
            - root_dir (str): Root directory of the project. Defaults to '/content'.
            - project_name (str): Name of the project. Defaults to 'CreditTabBert'.
            - model_name (str): Name of the model. Defaults to 'credit-0'.
            - mode (str): Mode of the model. Either 'mlm' or 'classification'.
            - pretrained_model_path (str): Path to the pretrained model checkpoint.
            - layers_to_unfreeze (int): number of layers of TabBert to unfreeze.
        """
        self.model_config_dict = model_config_dict
        self.model_config =  CustomBertConfig(**self.model_config_dict)
        self.training_config_dict = training_config_dict
        self.vocab = vocab
        self.train_set = train_set
        self.val_set = val_set
        self.test_set = test_set
        self.root_dir = root_dir
        self.project_name = project_name
        self.model_name = model_name
        self.mode = mode
        self.pretrained_model_path = pretrained_model_path
        self.layers_to_unfreeze = layers_to_unfreeze
        self._validate_attributes()
        self.setup_directories()
        self.setup_model()
        self.setup_wandb()
        self.setup_tokenizer()
        self.setup_collator()
        self.setup_training()

    def _validate_attributes(self) -> None:
        """Helper function to validate the attributes."""
        validations = [
            (self.model_config_dict, dict, '"model_config_dict" must be a dictionary'),
            (self.training_config_dict, dict, '"training_config_dict" must be a dictionary'),
            (self.vocab, Vocab, '"vocab" must be an instance of Vocab'),
            (self.train_set, TransactionDataset, '"train_set" must be an instance of TransactionDataset'),
            (self.val_set, TransactionDataset, '"val_set" must be an instance of TransactionDataset'),
            (self.test_set, TransactionDataset, '"test_set" must be an instance of TransactionDataset'),
            (self.root_dir, (str, type(None)), '"root_dir" must be a string or None'),
            (self.project_name, str, '"project_name" must be a string'),
            (self.model_name, str, '"model_name" must be a string'),
            (self.mode, str, '"mode" must be a string'),
            (self.pretrained_model_path, (str, type(None)), '"pretrained_model_path" must be a string or None'),
            (self.layers_to_unfreeze, int, '"layers_to_unfreeze" must be an integer')
        ]
        for var, var_type, err_msg in validations:
            if not isinstance(var, var_type):
                raise TypeError(f'{err_msg}. Got {type(var)} instead.')
        if self.mode not in ['mlm', 'classification']:
            raise ValueError('"mode" must be either "mlm" or "classification"')
        if self.mode == 'classification' and self.pretrained_model_path is None:
            raise ValueError('"pretrained_model_path" is required for "classification" mode')
        if self.mode == 'classification' and not os.path.exists(self.pretrained_model_path):
            raise ValueError(f'"{self.pretrained_model_path}" does not exist')

        valid_params  = set(inspect.signature(TrainingArguments).parameters.keys())
        input_params = set(self.training_config_dict.keys())
        invalid_params = input_params - valid_params
        if invalid_params:
            raise ValueError(f'Invalid parameters in training_config_dict: {invalid_params}')

    def setup_directories(self) -> None:
        """Sets up the output and logs directories."""
        try:
            base_dir = Path(self.root_dir) / 'output' / self.mode
            base_dir.mkdir(parents=True, exist_ok=True)
            self.CHECKPOINT_DIR = base_dir / 'checkpoints' / self.model_name
            self.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
            self.LOGS_DIR = base_dir / 'logs' / self.model_name
            self.LOGS_DIR.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            logger.error(f"An error occurred while setting up directories: {e}")
            raise

    def setup_collator(self) -> None:
        """Sets up the data collator for training."""
        try:
            self.data_collator = CustomDataCollator(tokenizer=self.tokenizer,
                                                    mlm_probability=self.model_config.mlm_probability,
                                                    mlm=self.mode=='mlm')
        except Exception as e:
            logger.error(f"An error occurred while setting up the data collator: {e}")
            raise

    def setup_tokenizer(self) -> None:
        """Sets up the tokenizer for training."""
        try:
            self.tokenizer = BertTokenizerFast(vocab_file=self.vocab.vocab_file_for_bert,
                                               do_lower_case=False,
                                               **self.vocab.get_special_tokens())
        except Exception as e:
            logger.error(f"An error occurred while setting up the tokenizer: {e}")
            raise

    def setup_model(self) -> None:
        """
        Sets up the model for training.
        If the mode is 'mlm', the model is initialized from scratch and trained with masked language modeling.
        If a pretrained_model_path is provided, the model is initialized from the pretrained checkpoint.
        If the mode is 'classification', the model is initialized from a pretrained checkpoint and trained for classification.
        """
        try:
            if self.pretrained_model_path:
                self.model = HierarchicalBertLM.from_pretrained(self.pretrained_model_path,
                                                                config=self.model_config,
                                                                vocab=self.vocab,
                                                                mode=self.mode,
                                                                ignore_mismatched_sizes=True)
            else:
                self.model = HierarchicalBertLM(config=self.model_config,
                                                vocab=self.vocab,
                                                mode=self.mode)
            if self.mode == 'classification':
                self.model.freeze_model_except_nlayers(n=self.layers_to_unfreeze)
        except Exception as e:
            logger.error(f"An error occurred while setting up the model: {e}")
            raise

    def setup_wandb(self) -> None:
        """Sets up wandb for logging."""
        try:
            wandb.init(config=self.model_config_dict,
                       project=self.project_name,
                       name=self.model_name,
                       group=self.mode,
                       dir=str(self.LOGS_DIR))
            wandb.config.update(self.model_config_dict)
        except Exception as e:
            logger.error(f"An error occurred while setting up wandb: {e}")
            raise

    def setup_training(self) -> None:
        """Sets up the training arguments and trainer."""
        try:
            self.training_args = TrainingArguments(output_dir=str(self.CHECKPOINT_DIR),
                                                   logging_dir=str(self.LOGS_DIR),
                                                   **self.training_config_dict)
            # Metrics are only computed for classification; None keeps the Trainer default for MLM
            compute_metrics = self.compute_metrics if self.mode == 'classification' else None
            self.trainer = Trainer(model=self.model,
                                   args=self.training_args,
                                   data_collator=self.data_collator,
                                   train_dataset=self.train_set,
                                   eval_dataset=self.val_set,
                                   compute_metrics=compute_metrics)

        except Exception as e:
            logger.error(f"An error occurred while setting up training: {e}")
            raise

    def _cleanup(self) -> None:
        """Helper function to terminate wandb process."""
        if wandb.run:
            wandb.finish()
        logger.info('Cleanup completed.')

    def train(self,
              resume_from_checkpoint: Union[bool, str, None] = None) -> None:
        """
        Trains the model.

        Args:
            - resume_from_checkpoint (Union[bool, str], optional):
                If a str, local path to a saved checkpoint as saved by a previous instance of Trainer.
                If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer.
                If None, training starts from scratch.
        """
        if not wandb.run:
            self.setup_wandb()
        try:
            self.trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        except Exception as e:
            logger.error(f"An error occurred during training: {e}")
            raise
        finally:
            self._cleanup()

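    def compute_metrics(self, out):
        """
        Computes precision, recall, F1 and accuracy for the classification task.

        Args:
            - out: Prediction output whose predictions field holds the model scores at index 0
                and the ground-truth labels at index 1.
        """
        predictions = out.predictions[0]
        labels = out.predictions[1]
        threshold = 0.5  # scores above the threshold are assigned to the positive class
        binary_predictions = (predictions > threshold).astype(int)
        precision = precision_score(labels, binary_predictions)
        recall = recall_score(labels, binary_predictions)
        f1 = f1_score(labels, binary_predictions)
        accuracy = accuracy_score(labels, binary_predictions)
        return {'precision': precision,
                'recall': recall,
                'f1': f1,
                'accuracy': accuracy}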

    def evaluate(self,
                 test=False):
        """
        Evaluates the model on the validation or test set based on the value of test.

        Args:
            - test (bool): If True and the mode is 'classification', evaluate on the test set.
                Otherwise, evaluate on the validation set.

        Returns:
            - If test is True and the mode is 'classification': a tuple (metrics, predictions, labels).
            - Otherwise: the metrics dictionary returned by Trainer.evaluate() on the validation set.
        """
        if self.mode == 'classification' and test:
            if self.test_set is None:
                raise ValueError('Test set is None. Cannot evaluate.')
            predictions = self.trainer.predict(self.test_set)
            metrics = self.compute_metrics(predictions)
            return metrics, predictions.predictions[0], predictions.predictions[1]
        return self.trainer.evaluate()



> ## **MLM Training**
>
> **_Important_:**
>
> **Run the first cell below to initialize the TrainingManager class for MLM training. The second cell calls the train() method and is only needed if you want to train the model yourself.**
>
> **A pretrained model checkpoint is provided so that the model can be evaluated on the val/test set without retraining. If you prefer to train the model from scratch, set mlm_pretrained_model_path to None.**


In [None]:
## RUN THIS CELL
## SET TO NONE TO TRAIN THE MODEL FROM SCRATCH
mlm_pretrained_model_path = os.path.join(ROOT_DIR, 'output/mlm/checkpoints/card-model/checkpoint-final')

training_manager_mlm = TrainingManager(model_config_dict=model_config_values,
                                       training_config_dict=training_mlm_config_dict,
                                       vocab=vocab,
                                       train_set=train_dataset,
                                       val_set=val_dataset,
                                       test_set=test_dataset,
                                       root_dir=ROOT_DIR,
                                       model_name='card-model',
                                       mode='mlm',
                                       pretrained_model_path=mlm_pretrained_model_path)

In [None]:
## NO NEED TO RUN THIS CELL
training_manager_mlm.train()
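
According to its docstring, `train()` forwards `resume_from_checkpoint` to the Trainer, where `True` means "load the last checkpoint in the output directory". The sketch below illustrates that kind of lookup using only the standard library. It assumes checkpoints are stored in folders named `checkpoint-<step>`; `find_latest_checkpoint` is a hypothetical helper written for illustration and is not part of this notebook or of transformers.

```python
import re
from pathlib import Path
from typing import Optional

# Assumption: checkpoint folders are named "checkpoint-<global step>".
_CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the checkpoint folder with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    candidates = []
    for child in root.iterdir():
        match = _CHECKPOINT_PATTERN.match(child.name)
        if child.is_dir() and match:
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    # Compare by numeric step so that checkpoint-1000 ranks above checkpoint-200.
    return str(max(candidates, key=lambda item: item[0])[1])
```

In this sketch a folder such as `checkpoint-final` has no numeric step and is skipped, so it would have to be passed to `train()` explicitly as a path string.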

> ## **MLM Evaluation**
>
> The following cell evaluates the model on the validation set.
> It returns the cross-entropy loss of the MLM task.
>
> **_Important_:**
>
> **If you ran the .train() cell without letting the training finish, rerun the training manager cell so that the provided checkpoint is loaded again.**

In [None]:
## RUN THIS CELL
training_manager_mlm.trainer.evaluate()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 11.246283531188965,
 'eval_runtime': 771.5818,
 'eval_samples_per_second': 128.978,
 'eval_steps_per_second': 1.008}
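
As a reminder of what an MLM cross-entropy loss measures, the sketch below computes it in plain Python: the average negative log-probability of the true token, taken only over masked positions. The `-100` ignore label follows a common Hugging Face convention and is an assumption here. How `HierarchicalBertLM` aggregates the loss across fields is not shown in this section, so the toy numbers are not comparable with the `eval_loss` above.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # assumed label for positions that were not masked


def masked_cross_entropy(logits: Sequence[Sequence[float]],
                         labels: Sequence[int]) -> float:
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX."""
    losses: List[float] = []
    for position_logits, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # unmasked token: contributes nothing to the loss
        # Numerically stable log-sum-exp over the vocabulary dimension
        max_logit = max(position_logits)
        log_norm = max_logit + math.log(
            sum(math.exp(x - max_logit) for x in position_logits))
        losses.append(log_norm - position_logits[label])
    if not losses:
        raise ValueError("No masked positions to score.")
    return sum(losses) / len(losses)


# Toy example: 3 positions, vocabulary of 4 tokens, only positions 0 and 2 are masked
toy_logits = [[2.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 0.0]]
toy_labels = [0, IGNORE_INDEX, 3]
toy_loss = masked_cross_entropy(toy_logits, toy_labels)
```

With uniform logits the loss of a position equals the logarithm of the vocabulary size, which gives a simple "no information" baseline for a single prediction.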

> ## **Classification Training Configuration**
>
> In this cell we define a dictionary containing the training parameters for the classification task.
>
> The parameters are passed to the TrainingArguments class from the transformers library, which is used to instantiate the Trainer class. Make sure that every key is a valid TrainingArguments parameter. The TrainingManager class performs a check, but it is better to verify the parameters beforehand.
>
> **_Important_:**
>
> **Run this cell as it is.**


In [None]:
training_cls_config_dict = {
    'per_device_train_batch_size': 128,
    'per_device_eval_batch_size': 128,
    'num_train_epochs': 10,
    'logging_strategy': 'steps',
    'logging_first_step': True,
    'logging_steps': 1,
    'save_strategy': 'steps',
    'save_steps': 100,
    'evaluation_strategy': 'steps',
    'eval_steps': 100,
    'load_best_model_at_end': True,
    'disable_tqdm': False,
    'seed': 2024,
    'learning_rate': 5e-5,
    'report_to':'wandb'}
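
Because the dictionary is unpacked into `TrainingArguments`, a misspelled key only surfaces when the training manager is built. One possible pre-check is sketched below. It relies on `TrainingArguments` being a dataclass; to stay runnable without transformers, the example uses a small stand-in dataclass (`DemoArguments`). Both `DemoArguments` and `find_unknown_keys` are hypothetical names introduced for illustration and are not part of this notebook.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, List


@dataclass
class DemoArguments:
    """Stand-in for transformers.TrainingArguments (illustration only)."""
    per_device_train_batch_size: int = 8
    num_train_epochs: float = 3.0
    learning_rate: float = 5e-5
    seed: int = 42


def find_unknown_keys(config: Dict[str, Any], arguments_cls: type) -> List[str]:
    """Return the keys of config that are not fields of the dataclass arguments_cls."""
    valid_names = {field.name for field in fields(arguments_cls)}
    return sorted(key for key in config if key not in valid_names)


demo_config = {'per_device_train_batch_size': 128,
               'num_train_epochs': 10,
               'learning_rate': 5e-5,
               'sead': 2024}  # deliberate typo for 'seed'
unknown = find_unknown_keys(demo_config, DemoArguments)
print(unknown)  # ['sead']
```

With transformers installed, the same helper could presumably be called as `find_unknown_keys(training_cls_config_dict, TrainingArguments)` before creating the training manager.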


> ## **Classification Training**
>
> **_Important_:**
>
> **Run the first cell below to initialize the TrainingManager class for classification training. The second cell calls the train() method and is only needed if you want to train the model yourself.**
>
> **A trained classification checkpoint is provided so that the model can be evaluated on the val/test set. If you prefer to train the classifier yourself, starting from the pretrained MLM model, set cls_pretrained_model_path to mlm_pretrained_model_path.**


In [None]:
## RUN THIS CELL TO SET UP THE TRAINING MANAGER FOR CLASSIFICATION - NO CHANGES NEEDED

## SET TO 'mlm_pretrained_model_path' TO TRAIN THE MODEL FROM THE PRETRAINED BERT (from mlm)
cls_pretrained_model_path = os.path.join(ROOT_DIR, 'output/classification/checkpoints/card-model-cls/checkpoint-final')

training_manager_cls = TrainingManager(model_config_dict=model_config_values,
                                       training_config_dict=training_cls_config_dict,
                                       vocab=vocab,
                                       train_set=train_dataset_cls,
                                       val_set=val_dataset,
                                       test_set=test_dataset,
                                       root_dir=ROOT_DIR,
                                       model_name='card-model-cls',
                                       mode='classification',
                                       pretrained_model_path=cls_pretrained_model_path)

In [None]:
## NO NEED TO RUN THIS CELL
training_manager_cls.train()

> ## **Classification Evaluation**
>
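> The following cell evaluates the model on the validation or test set, depending on the value of test.
> With test=True, the evaluate method returns the metrics, the predictions and the labels for the test set. With test=False, it returns the metrics computed on the validation set.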
>
> The following metrics are computed:
>
> - **F1**: A metric that combines precision and recall, providing a balanced measure of a model's performance by considering both false positives and false negatives.
> - **Precision**: It measures the accuracy of positive predictions, representing the ratio of true positives to the total predicted positives.
> - **Recall**: It measures the ability of the model to capture all relevant instances of the positive class. It is the ratio of true positives to the total actual positives.
> - **Accuracy**: It represents the overall correctness of the model's predictions, calculated as the ratio of correct predictions to the total number of instances.
>
> In our evaluation, the model achieved an F1 score of 0.52 on the test set, notably lower than the 0.76 reported in the reference paper (Padhi et al., 2021). We attribute this gap primarily to the amount of training data: because of limited computational resources and time, our model was trained on 5 million samples, whereas the paper's model was trained on the full dataset of 24 million samples.
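
The metric definitions above can be made concrete with a short sketch that thresholds scores at 0.5 and derives the four metrics from confusion counts. It uses plain Python rather than the scikit-learn functions called in `compute_metrics`; `binary_metrics` and the toy inputs are illustrative assumptions, not part of the notebook.

```python
from typing import Dict, Sequence


def binary_metrics(scores: Sequence[float],
                   labels: Sequence[int],
                   threshold: float = 0.5) -> Dict[str, float]:
    """Precision, recall, F1 and accuracy for binary labels (1 = positive class)."""
    predicted = [1 if score > threshold else 0 for score in scores]
    pairs = list(zip(predicted, labels))
    tp = sum(1 for p, y in pairs if p == 1 and y == 1)  # true positives
    fp = sum(1 for p, y in pairs if p == 1 and y == 0)  # false positives
    fn = sum(1 for p, y in pairs if p == 0 and y == 1)  # false negatives
    tn = sum(1 for p, y in pairs if p == 0 and y == 0)  # true negatives
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    return {'precision': precision, 'recall': recall,
            'f1': f1, 'accuracy': accuracy}


# Toy data: 2 true positives, 1 false positive, 1 false negative, 4 true negatives
toy_scores = [0.9, 0.8, 0.7, 0.2, 0.1, 0.1, 0.3, 0.4]
toy_labels = [1, 1, 0, 1, 0, 0, 0, 0]
toy = binary_metrics(toy_scores, toy_labels)

# F1 is the harmonic mean of precision and recall. Plugging in the test-set
# precision and recall printed by the evaluation cell gives an F1 of about 0.52.
reported_precision, reported_recall = 0.6470588235294118, 0.4408014571948998
reported_f1 = (2 * reported_precision * reported_recall
               / (reported_precision + reported_recall))
```

The 0.5 threshold mirrors the one hard-coded in `compute_metrics`; a different threshold would trade precision against recall.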


In [None]:
## RUN THIS CELL
## SET TEST TO FALSE/TRUE TO EVALUATE THE MODEL ON THE VAL/TEST SET
training_manager_cls.evaluate(test=True)


({'precision': 0.6470588235294118,
  'recall': 0.4408014571948998,
  'f1': 0.5243770314192849,
  'accuracy': 0.9955810114350137},
 array([0.00013916, 0.00012273, 0.00011576, ..., 0.00011405, 0.00010707,
        0.00012873], dtype=float32),
 array([0., 0., 0., ..., 0., 0., 0.], dtype=float32))

---
# **WANDB**

* [Link](https://wandb.ai/neural-network-tab-bert/CreditTabBert) to the wandb **project**

* MLM task training/evaluation plots can be seen at the following [link](https://api.wandb.ai/links/neural-network-tab-bert/cwliqy1o)

* CLASSIFICATION task training/evaluation plots can be seen at the following [link](https://api.wandb.ai/links/neural-network-tab-bert/1lff34z1)

---
# **REFERENCES**

- Inkit Padhi, Yair Schiff, Igor Melnyk, Mattia Rigotti, Youssef Mroueh, Pierre Dognin, Jerret Ross, Ravi Nair, and Erik Altman. "Tabular Transformers for Modeling Multivariate Time Series". 2021. arXiv:2011.01843 [cs.LG].

