In [1]:
import sys

sys.path.append("..")

In [2]:
from transformers import (
    DataCollatorWithPadding,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizerFast,
    TrainingArguments,
)

from adna.pylib import consts
from adna.pylib.adna_dataset import ADnaDataset
from adna.pylib.weighted_trainer import WeightedTrainer

In [None]:
# Checkpoint directory that the model is loaded from further down
# (hard-coded to checkpoint-6 under the "train" directory of consts.MT_DIR).
MODEL_PATH = consts.MT_DIR / "train" / "checkpoint-6"

In [3]:
# Per-device batch sizes handed to TrainingArguments.
EVAL_BATCH_SIZE = 128
TRAIN_BATCH_SIZE = 128

# Optimisation schedule for the fine-tuning run.
LEARNING_RATE = 1e-5
TRAIN_EPOCHS = 50

## Get the tokenizer

In [4]:
# The tokenizer files are loaded from consts.MT_DIR itself; the path is
# converted to a plain string before being handed to from_pretrained().
tokenizer_path = str(consts.MT_DIR)
tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer_path)

## Get the datasets

In [5]:
# Training split, built with the reverse-complement and to-N rates from consts.
# Debug toggle: add limit=TRAIN_BATCH_SIZE to this call (presumably caps the
# number of examples loaded -- confirm in ADnaDataset).
train_dataset = ADnaDataset(
    "train",
    tokenizer,
    to_n_rate=consts.TO_N_RATE,
    rev_comp_rate=consts.REV_COMP_RATE,
)

# Validation split; no rev_comp_rate / to_n_rate values are passed here.
# Debug toggle: add limit=EVAL_BATCH_SIZE to this call.
eval_dataset = ADnaDataset(
    "val",
    tokenizer,
)

In [6]:
# Display the first training example as a quick sanity check of the encoding.
train_dataset[0]

{'input_ids': [0, 262, 264, 3269, 3778, 322, 285, 695, 1441, 2096, 839, 2690, 287, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label': tensor(0)}

## Get the model

In [7]:
# Load the sequence-classification model from the MODEL_PATH checkpoint.
# local_files_only=True restricts from_pretrained() to files already on disk.
model = RobertaForSequenceClassification.from_pretrained(
    MODEL_PATH, local_files_only=True
)

## Build the data collator

In [8]:
# Collator that pads every batch out to the fixed length consts.MAX_LENGTH
# using the tokenizer loaded above.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    max_length=consts.MAX_LENGTH,
    padding="max_length",
)

## Build the trainer

In [9]:
# Training configuration for the fine-tuning run.
# Evaluation, checkpointing and logging are all done once per epoch, so each
# epoch leaves its own checkpoint directory under the output directory that can
# be matched against that epoch's validation loss afterwards.
training_args = TrainingArguments(
    # TrainingArguments documents output_dir as a str. Pass a plain string
    # (as is already done for tokenizer_path) instead of a path object so the
    # arguments remain JSON-serialisable for logging integrations.
    output_dir=str(consts.MT_DIR / "finetune"),
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    num_train_epochs=TRAIN_EPOCHS,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    save_strategy="epoch",
    logging_strategy="epoch",
    # Fixed seed so repeated runs are comparable.
    seed=23,
)

In [10]:
# Build the project's WeightedTrainer around the model, datasets and collator.
# NOTE(review): train_dataset is passed twice -- once positionally and once as
# the train_dataset= keyword. That only works if WeightedTrainer's first
# positional parameter is not itself named train_dataset; confirm against
# adna/pylib/weighted_trainer.py.
trainer = WeightedTrainer(
    train_dataset,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

## Train

In [11]:
# Run the fine-tuning loop; the returned TrainOutput is shown as the cell result.
trainer.train()

***** Running training *****
  Num examples = 128
  Num Epochs = 20
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 20


Epoch,Training Loss,Validation Loss
1,0.7038,0.704518
2,0.6961,0.718231
3,0.7114,0.717236
4,0.6881,0.709831
5,0.7102,0.700927
6,0.6753,0.698075
7,0.6873,0.699204
8,0.6891,0.70171
9,0.7078,0.704761
10,0.7028,0.706673


***** Running Evaluation *****
  Num examples = 128
  Batch size = 128
Saving model checkpoint to ../data/UF46992/finetune/checkpoint-1
Configuration saved in ../data/UF46992/finetune/checkpoint-1/config.json
Model weights saved in ../data/UF46992/finetune/checkpoint-1/pytorch_model.bin
tokenizer config file saved in ../data/UF46992/finetune/checkpoint-1/tokenizer_config.json
Special tokens file saved in ../data/UF46992/finetune/checkpoint-1/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 128
  Batch size = 128
Saving model checkpoint to ../data/UF46992/finetune/checkpoint-2
Configuration saved in ../data/UF46992/finetune/checkpoint-2/config.json
Model weights saved in ../data/UF46992/finetune/checkpoint-2/pytorch_model.bin
tokenizer config file saved in ../data/UF46992/finetune/checkpoint-2/tokenizer_config.json
Special tokens file saved in ../data/UF46992/finetune/checkpoint-2/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 128
  Batch 

TrainOutput(global_step=20, training_loss=0.6890893965959549, metrics={'train_runtime': 156.0772, 'train_samples_per_second': 16.402, 'train_steps_per_second': 0.128, 'total_flos': 52986959462400.0, 'train_loss': 0.6890893965959549, 'epoch': 20.0})

### TODO: It looks like epoch ?? is best (fill in after comparing the validation losses above)

Path = ../data/UF46992/finetune/checkpoint-??
