# Training a model
See README.md for details.

Loosely following this tutorial:
https://medium.com/nlplanet/a-full-guide-to-finetuning-t5-for-text2text-and-building-a-demo-with-streamlit-c72009631887

!pip install -r requirements.txt

In [None]:
import os

os.environ["WANDB_DISABLED"] = "true"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import pandas as pd

# from glob import glob

import zipfile

import torch
import wandb
from datasets import Dataset
from evaluate import load
from accelerate import Accelerator, DataLoaderConfiguration
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import T5Config
from transformers import TFT5ForConditionalGeneration
from transformers import ByT5Tokenizer  # a "dummy" tokenizer, tokenizing into bytes
from transformers import DataCollatorForSeq2Seq
from transformers import EvalPrediction

from config import data_root, model_root, checkpoint_name
from config import token_len, annot_len

In [None]:
# Accelerate configuration: disable the seedable dataloader sampler and use the
# model folder as the project directory.
# NOTE(review): `accelerator` is not passed to the Seq2SeqTrainer below; confirm
# whether this object is still needed.
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=False)
accelerator = Accelerator(project_dir=model_root, dataloader_config=dataloader_config)

# Device selection: pinned to the CPU. Alternatives used before:
#   "cuda" if torch.cuda.is_available() else "cpu"
#   accelerator.device
device = "cpu"
# Tensors created without an explicit device now land on `device`.
# NOTE(review): the Trainer picks its own device from its training arguments;
# confirm whether training should also be forced onto the CPU there.
torch.set_default_device(device)
device  # cell output: show the selected device

In [None]:
# Read every file in data-ue.zip as a header-less, two-column CSV
# ("input", "label"), stack them into one DataFrame, and make a 90/10
# train/test split as a Hugging Face DatasetDict.
dfs = []
# for fname in glob(f"{data_root}/*.csv"):
#     dfs += [pd.read_csv(fname, names=["inputs", "labels"])]
with zipfile.ZipFile(f"{data_root}/data-ue.zip") as zf:
    for name in zf.namelist():
        if name.endswith("/"):
            continue  # directory entry inside the archive, not a data file
        # Close each member handle as soon as it has been parsed.
        with zf.open(name) as member:
            # dtype=str + keep_default_na=False keep every cell as literal text.
            # Without them, pandas turns numeric-looking cells into numbers and
            # strings such as "NA", "null" or "nan" into NaN, and the tokenizer
            # rejects both. Empty cells become "".
            dfs.append(
                pd.read_csv(
                    member,
                    names=["input", "label"],
                    dtype=str,
                    keep_default_na=False,
                )
            )
# ignore_index=True builds a fresh 0..n-1 index. Keeping the per-file indices
# would leave duplicates that Dataset.from_pandas stores as an extra
# "__index_level_0__" column.
df = pd.concat(dfs, ignore_index=True)

ds = Dataset.from_pandas(df)
ds = ds.train_test_split(test_size=0.1, shuffle=True)
ds

In [None]:
# Byte-level tokenizer: it needs no vocabulary file, so it is constructed
# directly instead of being downloaded.
tokenizer = ByT5Tokenizer()
# tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")


# Create a tokenization function
def preprocess_function(examples, max_length=64):
    """Tokenize a batch of input/label strings for seq2seq training.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)``;
            ``examples["input"]`` and ``examples["label"]`` hold the source
            and target strings.
        max_length: Maximum sequence length kept for both inputs and labels.
            This tokenizer works on UTF-8 bytes, so the limit is in bytes plus
            any special tokens the tokenizer appends. Non-ASCII characters take
            several bytes each.
            NOTE(review): ``config`` also provides ``token_len`` / ``annot_len``,
            which may be the intended limits. Confirm before changing this default.

    Returns:
        The tokenizer output for the inputs, with the target token ids stored
        under ``"labels"``.
    """
    # ``text_target`` tokenizes the targets in target mode and stores their ids
    # under "labels". It replaces the deprecated ``as_target_tokenizer()``
    # context manager.
    return tokenizer(
        examples["input"],
        text_target=examples["label"],
        max_length=max_length,
        truncation=True,
    )


# Tokenize every split in batches, then drop the raw text columns so that only
# the tokenizer outputs and "labels" remain for the trainer.
tds = ds.map(preprocess_function, batched=True).remove_columns(["input", "label"])
tds

In [None]:
# Hugging Face `evaluate` metric comparing predicted and reference strings;
# used by compute_exact_match below.
exact_match_metric = load("exact_match")


def compute_exact_match(pred: EvalPrediction):
    """Compute the exact-match rate between predicted and reference sequences.

    Used as ``compute_metrics`` for the Seq2SeqTrainer. Meaningful scores need
    ``pred.predictions`` to hold generated token ids, which requires
    ``predict_with_generate=True`` in the training arguments. If logits arrive
    instead, they are reduced with argmax as a teacher-forced fallback.

    Args:
        pred: Evaluation output. ``predictions`` holds the predicted token ids
            (or logits), and ``label_ids`` holds the reference token ids.

    Returns:
        ``{"exact_match": score}`` as computed by the ``exact_match`` metric.
    """
    import numpy as np

    predictions = pred.predictions
    # Model outputs may come as a tuple (main output, extra tensors...).
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    # Logits (batch, seq, vocab) instead of generated ids: take the argmax token.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)
    references = np.asarray(pred.label_ids)

    # -100 is the "ignore" index used to pad labels (DataCollatorForSeq2Seq's
    # default label_pad_token_id) and may also pad predictions. It is not a
    # real token id, so swap it for the pad id before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    references = np.where(references != -100, references, tokenizer.pad_token_id)

    # Decode with the tokenizer. The original code called `.decode` on the id
    # arrays themselves, shadowing `pred`.
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(references, skip_special_tokens=True)

    result = exact_match_metric.compute(
        predictions=decoded_preds, references=decoded_labels
    )
    return {"exact_match": result["exact_match"]}


# Training configuration: effective batch size 1 x 4 (gradient accumulation),
# evaluation on a step schedule, and the best checkpoint reloaded at the end.
args = Seq2SeqTrainingArguments(
    output_dir=f"{model_root}/byT5-ocs-ue",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    #     gradient_checkpointing=True,
    # torch_empty_cache_steps=100,
    disable_tqdm=False,
    # The string "none" disables all reporting integrations (wandb.ai included).
    # Passing None means "library default", which can enable any installed
    # integration.
    report_to="none",
    load_best_model_at_end=True,
    save_total_limit=1,
    eval_strategy="steps",
    # compute_exact_match compares decoded sequences, so evaluation must run
    # generate() and return token ids instead of raw logits.
    predict_with_generate=True,
    # Labels are truncated to 64 tokens in preprocess_function. Allow one extra
    # position for the decoder start token so full-length outputs fit.
    generation_max_length=65,
)


def init_model():
    """Build a freshly initialised (untrained) T5 model from the global ``config``.

    Passed to the Seq2SeqTrainer as ``model_init``, so the trainer calls it to
    create the model.

    Returns:
        A PyTorch ``T5ForConditionalGeneration`` moved to the global ``device``.
    """
    # The PyTorch class is required: Seq2SeqTrainer trains torch modules, and the
    # TensorFlow variant (TFT5ForConditionalGeneration) has no ``.to()`` method.
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration(config)
    # model = model.cuda()
    model.to(device)
    return model


# Architecture of t5-base. Weights are not loaded: init_model builds the model
# from this config alone.
config = T5Config.from_pretrained("t5-base")
# t5-base's config is sized for T5's SentencePiece vocabulary, which is far
# larger than the byte-level tokenizer's. Match the embedding/LM-head size to
# the tokenizer. Otherwise the model can emit ids the byte tokenizer may fail
# to decode, and the unused embedding rows are wasted parameters.
config.vocab_size = len(tokenizer)
# config.task_specific_params = {}
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Dynamically pads inputs and labels to the longest sequence in each batch.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

In [None]:
# Assemble the trainer. Because it receives `model_init` rather than a model,
# it creates the model itself by calling init_model().
trainer = Seq2SeqTrainer(
    args=args,
    model_init=init_model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tds["train"],
    eval_dataset=tds["test"],
    compute_metrics=compute_exact_match,
)

In [None]:
# Run training. Evaluation runs on a step schedule (eval_strategy="steps"), and
# load_best_model_at_end=True restores the best checkpoint when training ends.
trainer.train()

In [None]:
# Save the model held by the trainer after training into a directory separate
# from the intermediate checkpoints in args.output_dir.
final_model_dir = f"{model_root}/byT5-ocs-ue-final"
trainer.save_model(output_dir=final_model_dir)