# Train a RoBERTa model using a balanced dataset

## Note: This experiment did not work

In [1]:
import sys

# Put the project root on the import path so the local `adna` package resolves.
sys.path.extend([".."])

In [2]:
from transformers import (
    DataCollatorWithPadding,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

from adna.pylib import consts
from adna.pylib.balanced_dataset import BalancedDataset

## Data augmentation parameters

In [3]:
# Presumably the fraction of sequences to reverse-complement during augmentation.
# NOTE(review): this value is not passed to BalancedDataset or read by any
# later cell shown here, so it currently has no effect on training.
REV_COMP_RATE = 0.50

## Training parameters

In [4]:
# Optimisation settings
TRAIN_EPOCHS = 50
LEARNING_RATE = 3e-5
TRAIN_BATCH_SIZE = 192
EVAL_BATCH_SIZE = 192

# Model size. The hidden size is derived from the head count so that it is
# always an exact multiple of NUM_ATTENTION_HEADS, which RoBERTa's
# self-attention requires.
NUM_HIDDEN_LAYERS = 6
NUM_ATTENTION_HEADS = 6
HEAD_DIM = 32  # width of each attention head
HIDDEN_SIZE = NUM_ATTENTION_HEADS * HEAD_DIM

MODEL_DIR = 'train'  # Save models to this sub-directory

## Get the tokenizer

In [5]:
# Load the fast tokenizer files stored under consts.MT_DIR.
tokenizer_path = f"{consts.MT_DIR}"
tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer_path)

## Get the datasets

In [6]:
# One balanced dataset per split, both built with the same tokenizer.
train_dataset, eval_dataset = (
    BalancedDataset(split, tokenizer) for split in ('train', 'val')
)

In [7]:
# Number of `pos_records` in the training split (displayed below); the
# Trainer later reports twice this value as its example count.
len(train_dataset.pos_records)

96950

In [8]:
# Number of `pos_records` in the validation split; evaluation later runs on
# twice this many examples.
len(eval_dataset.pos_records)

32034

## Check class weights

In [9]:
# Class weights exposed by the training dataset (displayed below).
# NOTE(review): in this notebook only len(WEIGHTS) is used (as num_labels);
# the weight values themselves are never passed to the loss.
WEIGHTS = train_dataset.weights
WEIGHTS

[1.0, 1.0]

## Build the model

In [10]:
# Small RoBERTa classifier; the number of output labels equals the number of
# class weights.
# NOTE(review): RoBERTa offsets position ids by padding_idx + 1, so only
# max_position_embeddings - 2 tokens are usable; confirm tokenized inputs
# never exceed consts.MAX_LENGTH - 2 tokens.
config = RobertaConfig(
    hidden_size=HIDDEN_SIZE,
    num_attention_heads=NUM_ATTENTION_HEADS,
    num_hidden_layers=NUM_HIDDEN_LAYERS,
    max_position_embeddings=consts.MAX_LENGTH,
    vocab_size=consts.VOCAB_SIZE,
    type_vocab_size=1,  # single token-type (segment) id
    num_labels=len(WEIGHTS),
)

In [11]:
# Fresh, randomly initialised weights: built from the config, not loaded
# from a checkpoint.
model = RobertaForSequenceClassification(config)

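## Build the start-of-epoch callback

At the beginning of every training epoch, this callback calls `sample()` on the training dataset, presumably to draw a fresh balanced sample of records.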

In [12]:
class DataResetCallback(TrainerCallback):
    """Call ``sample()`` on the training dataset at the start of every epoch.

    The training dataset (a ``BalancedDataset`` in this notebook) must expose
    a ``sample()`` method. Presumably it draws a fresh balanced subset of
    records so each epoch sees different examples -- TODO confirm against
    ``adna.pylib.balanced_dataset``.
    """

    def on_epoch_begin(
        self,
        args,
        state,
        control,
        train_dataloader=None,
        eval_dataloader=None,
        **kwargs,
    ):
        """Resample the training data before the epoch's batches are drawn.

        ``train_dataloader`` is supplied by the ``Trainer``. It is compared to
        ``None`` explicitly because ``bool(DataLoader)`` calls ``len()``. That
        is falsy for an empty loader and raises ``TypeError`` for loaders
        over datasets that do not define ``__len__``.
        """
        if train_dataloader is not None:
            train_dataloader.dataset.sample()

## Build the trainer

In [13]:
# Evaluate, log and checkpoint once per epoch, so every epoch's validation
# loss has a matching saved checkpoint to choose from afterwards.
training_args = TrainingArguments(
    output_dir=consts.MT_DIR / MODEL_DIR,
    overwrite_output_dir=True,
    seed=23,
    # optimisation
    num_train_epochs=TRAIN_EPOCHS,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    # once-per-epoch schedule for evaluation, logging and checkpoints
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
)

In [14]:
# Passing the tokenizer makes the Trainer save it alongside each checkpoint.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[DataResetCallback()],
)

## Train

In [15]:
# Train for TRAIN_EPOCHS epochs, evaluating and saving a checkpoint after each
# one (per the "epoch" strategies in training_args).
trainer.train()

***** Running training *****
  Num examples = 193900
  Num Epochs = 50
  Instantaneous batch size per device = 192
  Total train batch size (w. parallel, distributed & accumulation) = 192
  Gradient Accumulation steps = 1
  Total optimization steps = 50500


Epoch,Training Loss,Validation Loss
1,0.4017,0.294142
2,0.2911,0.273105
3,0.2717,0.257891
4,0.2627,0.247982
5,0.2514,0.249661
6,0.2398,0.230601
7,0.2294,0.225182
8,0.2239,0.219988
9,0.2158,0.212062
10,0.2071,0.210061


***** Running Evaluation *****
  Num examples = 64068
  Batch size = 192
Saving model checkpoint to ../data/UF46992/train/checkpoint-1010
Configuration saved in ../data/UF46992/train/checkpoint-1010/config.json
Model weights saved in ../data/UF46992/train/checkpoint-1010/pytorch_model.bin
tokenizer config file saved in ../data/UF46992/train/checkpoint-1010/tokenizer_config.json
Special tokens file saved in ../data/UF46992/train/checkpoint-1010/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 64068
  Batch size = 192
Saving model checkpoint to ../data/UF46992/train/checkpoint-2020
Configuration saved in ../data/UF46992/train/checkpoint-2020/config.json
Model weights saved in ../data/UF46992/train/checkpoint-2020/pytorch_model.bin
tokenizer config file saved in ../data/UF46992/train/checkpoint-2020/tokenizer_config.json
Special tokens file saved in ../data/UF46992/train/checkpoint-2020/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 64068
  

TrainOutput(global_step=50500, training_loss=0.1709773406982422, metrics={'train_runtime': 24076.6204, 'train_samples_per_second': 402.673, 'train_steps_per_second': 2.097, 'total_flos': 4.670614725e+16, 'train_loss': 0.1709773406982422, 'epoch': 50.0})

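### Epoch 42 appears to be the best checkpoint

Path: `../data/UF46992/train/checkpoint-42420` (42 epochs × 1010 steps per epoch = 42420 steps). The training log above is truncated, so the epoch-42 losses are not shown here.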