# Training a model
see README.md for details

Loosely following this tutorial:
https://medium.com/nlplanet/a-full-guide-to-finetuning-t5-for-text2text-and-building-a-demo-with-streamlit-c72009631887

In [1]:
import os

os.environ["WANDB_DISABLED"] = "true"

import pandas as pd
# from glob import glob

import zipfile

import torch
import wandb
from datasets import Dataset
from evaluate import load
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import T5Config
from transformers import T5ForConditionalGeneration
from transformers import ByT5Tokenizer  # a "dummy" tokenizer, tokenizing into bytes
from transformers import DataCollatorForSeq2Seq

from config import data_root, model_root, checkpoint_name
from config import token_len, annot_len

In [None]:
torch.cuda.is_available()

In [8]:
dfs = []
# for fname in glob(f"{data_root}/*.csv"):
#     dfs += [pd.read_csv(fname, names=["inputs", "labels"])]
with zipfile.ZipFile(f"{data_root}/data.zip") as zf:
    for name in zf.namelist():
        dfs += [pd.read_csv(zf.open(name), names=["inputs", "labels"])]
df = pd.concat(dfs, axis=0)

ds = Dataset.from_pandas(df)
ds = ds.train_test_split(test_size=0.1)
ds

DatasetDict({
    train: Dataset({
        features: ['inputs', 'labels', '__index_level_0__'],
        num_rows: 36995
    })
    test: Dataset({
        features: ['inputs', 'labels', '__index_level_0__'],
        num_rows: 4111
    })
})

In [None]:
tokenizer = ByT5Tokenizer()
# tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")


# Create a tokenization function
def preprocess_function(examples):
    inputs = examples["inputs"]
    targets = examples["labels"]
    model_inputs = tokenizer(inputs, max_length=64, truncation=True)

    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=64, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Apply the tokenization function to the dataset
tds = ds.map(preprocess_function, batched=True)
tds = tds.remove_columns(["inputs"])
tds

In [None]:
exact_match_metric = load("exact_match")

args = Seq2SeqTrainingArguments(
    output_dir=f"{model_root}/byT5-ocs-ue",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    #     gradient_checkpointing=True,
    # torch_empty_cache_steps=100,
    disable_tqdm=False,
    report_to=None,  # disable wandb.ai
)

config = T5Config.from_pretrained("t5-base")
# config.task_specific_params = {}
model = T5ForConditionalGeneration(config)
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tds["train"],
    eval_dataset=tds["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=exact_match_metric,
)

In [None]:
trainer.train()

In [None]:
tds