# 05. Modeling - Nubank AI Core Transaction Dataset Interview Project

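In this section we pre-train our masked-language model under different dataset hyperparameters (transactions per sample, stride, number of bins, and column-order randomization). Every run is logged to Weights & Biases so the results can be compared.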

In [1]:
import argparse
import itertools
import logging
import os

from transformers import (
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
    set_seed,
)
from datasets import Dataset
from sklearn.model_selection import train_test_split
from nubert.datasets import NuDataset
from nubert.config import NubertPreTrainConfig, TrainerConfig

In [2]:
def split_dataset(dataset, test_size=0.1, val_size=0.1, seed=42):
    """Split ``dataset`` into train / validation / test partitions.

    The test partition is carved off first. The validation fraction is then
    rescaled so that ``val_size`` is a fraction of the *full* dataset rather
    than of what remains after removing the test rows.

    Args:
        dataset: Any indexable accepted by ``sklearn.model_selection.train_test_split``.
        test_size: Fraction of ``dataset`` reserved for the test partition.
        val_size: Fraction of ``dataset`` reserved for the validation partition.
        seed: ``random_state`` used for both shuffles, so splits are reproducible.

    Returns:
        Tuple ``(train, val, test)``.
    """
    remainder, test_part = train_test_split(dataset, test_size=test_size, random_state=seed)
    # Express val_size relative to the non-test remainder so it still equals
    # val_size of the original dataset.
    val_fraction_of_remainder = val_size / (1 - test_size)
    train_part, val_part = train_test_split(
        remainder, test_size=val_fraction_of_remainder, random_state=seed
    )
    return train_part, val_part, test_part

def create_hf_dataset(data):
    """Wrap ``data`` as the single ``input_ids`` column of a Hugging Face ``Dataset``."""
    columns = {"input_ids": data}
    return Dataset.from_dict(columns)

def resize_model_embeddings(model, tokenizer):
    """Resize ``model``'s token embeddings to ``len(tokenizer)`` entries.

    Needed when the tokenizer's vocabulary differs from the checkpoint's
    (e.g. after adding tokens). The same ``model`` object is returned so the
    call can be chained or reassigned.
    """
    vocab_size = len(tokenizer)
    model.resize_token_embeddings(vocab_size)
    return model

In [3]:
import gc
import torch
import wandb

def train_model(
    dataset,
    config: NubertPreTrainConfig,
    ):
    """Pre-train a masked-language model on ``dataset`` and save it to ``config.trainer.output_dir``.

    Args:
        dataset: A ``NuDataset``-like object.
            - ``dataset.data`` is split into train/val/test partitions
              (80/10/10 with ``split_dataset``'s defaults).
            - ``dataset.tokenizer.base_tokenizer`` is used both for MLM masking
              and to size the model's embeddings.
        config: Supplies ``model_name`` (checkpoint to start from) and ``trainer``.
            The fields of ``trainer`` are forwarded verbatim to ``TrainingArguments``.

    Side effects:
        - Writes the tokenizer and model files to ``config.trainer.output_dir``.
        - Calls ``wandb.finish()``.
        - Runs ``gc.collect()`` / ``torch.cuda.empty_cache()``.
        The cleanup steps run even when training raises (e.g. ``KeyboardInterrupt``),
        so a later call in the same kernel does not inherit an open W&B run.
    """
    # Build the arguments first so we can seed *before* the model is created.
    # resize_token_embeddings initialises rows for added tokens randomly, and
    # Trainer only seeds itself (from args.seed) when it is constructed.
    training_args = TrainingArguments(**config.trainer.model_dump())
    set_seed(training_args.seed)

    model = AutoModelForMaskedLM.from_pretrained(config.model_name)
    tokenizer = dataset.tokenizer.base_tokenizer

    # Save the tokenizer up front so output_dir (and any checkpoints in it) is
    # usable even if training is interrupted. It is not modified afterwards.
    tokenizer.save_pretrained(config.trainer.output_dir)
    model = resize_model_embeddings(model, tokenizer)

    # The test split is held out here; evaluation happens outside this notebook.
    train_data, val_data, _test_data = split_dataset(dataset.data)
    train_dataset = create_hf_dataset(train_data)
    val_dataset = create_hf_dataset(val_data)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)

    trainer = None
    try:
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
        )
        trainer.train()
        trainer.save_model()
    finally:
        wandb.finish()
        # Drop the trainer too: it holds a reference to the model, so deleting
        # `model` alone would leave the weights alive and empty_cache ineffective.
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()


In [4]:
os.environ["WANDB_PROJECT"] = "nubert"
os.environ["WANDB_LOG_MODEL"] = "end"

# Settings shared by every run in the sweep.
DATASET_PATH = "/notebooks/nubank/nubert/analyses/nubank-2013-2014"
FILE_NAME = "nubank_raw"
BATCH_SIZE = 64

# Hyperparameter grid: each combination below is pre-trained once.
num_transactions_to_test = [10]
stride_to_test = [1]
num_bins_to_test = [15]
randomized_to_test = [False]

# itertools.product varies the last list fastest, matching the original nested-loop order.
sweep = itertools.product(
    num_transactions_to_test,
    stride_to_test,
    num_bins_to_test,
    randomized_to_test,
)
for num_transactions, stride, num_bins, randomize_column_order in sweep:
    # Build a fresh TrainerConfig per run so no run can see settings that
    # another run may have modified.
    trainer_config = TrainerConfig(
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
    )
    config = NubertPreTrainConfig(
        run_name="",
        dataset_path=DATASET_PATH,
        file_name=FILE_NAME,
        num_transactions=num_transactions,
        stride=stride,
        num_bins=num_bins,
        trainer=trainer_config,
        randomize_column_order=randomize_column_order,
    )
    full_dataset = NuDataset.from_config(config)
    train_model(dataset=full_dataset, config=config)

  df = pd.read_csv(path.join(root, f"{fname}.csv"))
  df['Transaction Date'] = pd.to_datetime(df['Transaction Date'])


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

 95%|█████████▍| 104/110 [24:54<01:26, 14.37s/it]  


KeyboardInterrupt: 

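With the models trained, we evaluate their performance outside this notebook, since all runs are logged to Weights & Biases. Note that the saved output above shows the last run was stopped manually (`KeyboardInterrupt` at step 104/110), so re-run the training cell to obtain a fully trained model.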

### Fine-tuning

To continue with fine-tuning for amount prediction, see the `amount` directory under `analyses`.