# Train a RoBERTa model

In [1]:
import sys

# Put the parent directory on the module search path. Presumably this is what
# lets the `adna` package import in the next cell resolve when the notebook is
# run from its own folder — confirm against the repository layout.
sys.path.append("..")

In [2]:
import sqlite3

import torch
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    DataCollatorWithPadding,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    Trainer,
    TrainingArguments,
)

from adna.pylib import consts

In [3]:
# Optimisation schedule.
TRAIN_EPOCHS = 20
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0.01

# Per-device batch sizes; training and validation use the same value.
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = 128

# NOTE(review): not referenced by any cell visible in this notebook.
SUMMARY_LEN = 7

## Load the tokenizer

In [4]:
# Load the tokenizer files previously saved under the project's SUB_DIR.
tokenizer_path = str(consts.SUB_DIR)
tokenizer = RobertaTokenizerFast.from_pretrained(
    pretrained_model_name_or_path=tokenizer_path,
)

## Build the training datasets

In [5]:
class ADnaDataset(Dataset):
    """Serve tokenized sequences for one data split from an in-memory SQLite copy.

    On construction, up to ``limit`` randomly ordered rows of the requested
    split are copied from the database at ``consts.SQL`` into a private
    in-memory ``seqs`` table. Items are read back by rowid and tokenized
    lazily in ``__getitem__``.
    """

    def __init__(self, split, tokenizer, limit=1_000_000):
        """Copy the rows for ``split`` into memory.

        Args:
            split: Value matched against the ``split`` column of ``seqs``.
            tokenizer: Object exposing ``encode_plus``; used for every item.
            limit: Maximum number of rows to sample (default 1,000,000).
        """
        self.split = split
        self.tokenizer = tokenizer
        self.limit = limit
        self.cxn = sqlite3.connect(":memory:")
        self.length = 0
        self.db_to_memory()

    def db_to_memory(self):
        """Populate the in-memory ``seqs`` table and record its row count."""
        create = """
            create table seqs as
            select seq, label, rev
            from aux.seqs where split = ?
            order by random()
            limit ?
            """
        # Bind the path as a parameter instead of splicing it into the SQL so
        # that a path containing a quote cannot break the statement.
        self.cxn.execute("attach database ? as aux", (str(consts.SQL),))
        self.cxn.execute(create, (self.split, self.limit))
        self.cxn.execute("detach database aux")
        count = self.cxn.execute("select count(*) from seqs")
        self.length = int(count.fetchone()[0])

    def __len__(self):
        """Return the number of rows copied into memory."""
        return self.length

    def __getitem__(self, row_id):
        """Return the tokenized sequence and label stored at position ``row_id``.

        Args:
            row_id: Zero-based position; negative values count from the end.

        Returns:
            The tokenizer's encoding with an added ``"label"`` tensor.

        Raises:
            IndexError: If ``row_id`` falls outside ``range(len(self))``.
        """
        index = int(row_id)
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(
                f"index {row_id} is out of range for a dataset of size {self.length}"
            )

        # The table was created in a single statement, so its rowids run
        # 1..length; shift the zero-based index accordingly.
        sql = "select seq, label from seqs where rowid = ?"
        seq, label = self.cxn.execute(sql, (index + 1,)).fetchone()

        # Use the tokenizer handed to the constructor rather than whatever
        # global happens to be named `tokenizer`. Truncation keeps every item
        # at most MAX_LENGTH tokens so batches have a uniform shape.
        encoded = self.tokenizer.encode_plus(
            seq,
            padding="max_length",
            truncation=True,
            max_length=consts.MAX_LENGTH,
        )
        encoded["label"] = torch.tensor(label)
        return encoded

In [6]:
# One dataset per split, built in train-then-val order with the shared tokenizer.
train_dataset, eval_dataset = (
    ADnaDataset(split_name, tokenizer) for split_name in ("train", "val")
)

In [7]:
# Spot-check one encoded example: padded input ids, attention mask, and label.
train_dataset[1]

{'input_ids': [0, 1817, 836, 515, 770, 3830, 847, 3120, 323, 326, 1821, 2696, 1065, 280, 810, 268, 481, 273, 545, 4067, 275, 293, 330, 1972, 1471, 651, 307, 277, 309, 266, 296, 285, 531, 278, 732, 2496, 640, 275, 592, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label': tensor(0)}

In [8]:
# Number of rows that were copied into memory for the training split.
len(train_dataset)

1000000

## Build the model

In [9]:
# Architecture settings; anything not listed keeps the RobertaConfig default.
config = RobertaConfig(
    type_vocab_size=1,
    num_hidden_layers=6,
    vocab_size=consts.VOCAB_SIZE,
)

In [10]:
# Instantiate the classifier from the config alone (no pretrained weights are loaded here).
model = RobertaForSequenceClassification(config)

In [11]:
# Report the model's parameter count as a quick size check.
model.num_parameters()

46660610

In [12]:
# Collate items into batches padded to the fixed maximum length. The dataset
# already pads each item in __getitem__, so this mirrors the same setting.
data_collator = DataCollatorWithPadding(
    max_length=consts.MAX_LENGTH,
    padding="max_length",
    tokenizer=tokenizer,
)

In [13]:
training_args = TrainingArguments(
    # Output location: the same SUB_DIR the tokenizer is loaded from.
    output_dir=consts.SUB_DIR,
    overwrite_output_dir=True,
    # Checkpoint cadence.
    # NOTE(review): save_total_limit is not set here, so this call places no
    # cap on how many checkpoints are kept.
    save_steps=8192,
    # Evaluate once at the end of every epoch.
    evaluation_strategy="epoch",
    # Optimisation schedule.
    num_train_epochs=TRAIN_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    # Batch sizes.
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=VALID_BATCH_SIZE,
)

In [14]:
# Wire the model, run configuration, data splits, and batching together.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Start training; with the settings above this runs TRAIN_EPOCHS epochs and
# evaluates on the validation split after each one.
trainer.train()

***** Running training *****
  Num examples = 1000000
  Num Epochs = 20
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 156260


Epoch,Training Loss,Validation Loss
