# Import

In [None]:
# set hyperparameters
TRAIN_FOLD = 0  # index of the cross-validation fold that is held out for validation
# Experiment name: passed to the training arguments and used as the parent
# directory of the saved best model. It is derived from TRAIN_FOLD so that
# switching folds cannot silently reuse (and overwrite) another fold's outputs.
EXP_NAME = f'tasr_mt5_large_f{TRAIN_FOLD}'
MODEL_CHECKPOINT = 'google/mt5-large'  # pretrained checkpoint for both tokenizer and model

BATCH_SIZE = 64  # per-device training batch size (evaluation uses twice this value)
GRAG_ACC_STEP = 1  # gradient accumulation steps (spelling kept: the training-arguments cell reads this name)
MAX_INPUT_LEN = 192  # truncation length (tokens) for the joined ASR hypotheses
MAX_TARGET_LEN = 32  # truncation length (tokens) for the ground-truth sentence

In [None]:
# import packages
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import re
import json
import jiwer
import logging
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.model_selection import KFold

import torch.distributed as dist

from datasets import Dataset
from datasets import load_dataset, load_metric
from transformers import AutoModel, AutoTokenizer, MT5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Load Data

In [None]:
# read training data
# JSON text is UTF-8 by specification; the encoding is given explicitly so the
# result does not depend on the platform's locale default.
with open('train_all.json', encoding='utf-8') as file:
    data = json.load(file)

In [None]:
# cross-validation: build a shuffled 10-fold split and keep only fold TRAIN_FOLD
kf = KFold(n_splits=10, random_state=1998, shuffle=True)
for i_fold, (train_index, valid_index) in enumerate(kf.split(data)):
    if i_fold != TRAIN_FOLD:
        continue

    train_data = [data[idx] for idx in train_index]
    valid_data = [data[idx] for idx in valid_index]
    # The requested fold has been materialized; the remaining splits are not needed.
    break
else:
    # Reached only when no fold index matched. Without this guard an
    # out-of-range TRAIN_FOLD would leave train_data / valid_data undefined,
    # or stale from an earlier run of this cell.
    raise ValueError(
        f"TRAIN_FOLD={TRAIN_FOLD!r} does not match any of the {kf.get_n_splits()} folds"
    )

In [None]:
# sanity check: number of examples in the training and validation splits
len(train_data), len(valid_data)

In [None]:
# remove duplicate strings and convert to a Hugging Face dataset
def list_drop_dup(data_list):
    """Return the items of ``data_list`` without repeats, keeping first-seen order.

    Items must be hashable. The result is always a new list.
    """
    seen = set()
    unique_items = []
    for item in data_list:
        if item in seen:
            continue
        seen.add(item)
        unique_items.append(item)
    return unique_items

def to_hf_dataset(data, drop_duplicate=False):
    """Convert raw records into a ``Dataset`` with two text columns.

    Args:
        data: sequence of dicts, each providing ``'sentence_list'`` (the ASR
            hypotheses) and ``'ground_truth_sentence'``.
        drop_duplicate: when true, hypotheses that become identical after
            space removal are kept only once (first occurrence wins).

    Returns:
        A ``Dataset`` with the columns ``'asr_sentences'`` (space-free
        hypotheses joined by ``'</s>'``) and ``'ground_truth'``.
    """
    asr_sentences = []
    for record in data:
        # Spaces are stripped from every hypothesis before joining; the
        # separator itself contains none, so the order of the two steps does
        # not affect the result.
        hypotheses = [sentence.replace(' ', '') for sentence in record['sentence_list']]
        if drop_duplicate:
            hypotheses = list_drop_dup(hypotheses)
        asr_sentences.append('</s>'.join(hypotheses))

    ground_truth = [record['ground_truth_sentence'] for record in data]

    return Dataset.from_dict({
        'asr_sentences': asr_sentences,
        'ground_truth': ground_truth,
    })

In [None]:
# Convert both splits, collapsing repeated ASR hypotheses within each example.
train_dataset = to_hf_dataset(
    data=train_data,
    drop_duplicate=True,
)
valid_dataset = to_hf_dataset(
    data=valid_data,
    drop_duplicate=True,
)

In [None]:
# number of examples in each converted dataset
len(train_dataset), len(valid_dataset)

In [None]:
# inspect the first validation example (joined ASR hypotheses and ground truth)
valid_dataset[0]

# Preprocess

In [None]:
# load the tokenizer that belongs to MODEL_CHECKPOINT, requesting the fast implementation
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)

In [None]:
# tokenize data
def preprocess_function(examples):
    """Tokenize one batch of examples for seq2seq training.

    Args:
        examples: batch mapping with the columns ``"asr_sentences"`` (model
            input text) and ``"ground_truth"`` (target text).

    Returns:
        The tokenizer output for the inputs, extended with ``"labels"``
        (target token ids) and ``"length"`` (token count of every input).
    """
    source_texts = list(examples["asr_sentences"])
    model_inputs = tokenizer(source_texts, truncation=True, max_length=MAX_INPUT_LEN)

    # Targets are encoded with the tokenizer switched to target mode.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(
            examples["ground_truth"],
            truncation=True,
            max_length=MAX_TARGET_LEN,
        )

    model_inputs["labels"] = target_encoding["input_ids"]
    # Per-example input length, stored as an extra column.
    model_inputs["length"] = list(map(len, model_inputs["input_ids"]))

    return model_inputs

In [None]:
# Tokenize both splits batch-wise with preprocess_function.
tokenized_train_datasets = train_dataset.map(
    function=preprocess_function,
    batched=True,
)
tokenized_valid_datasets = valid_dataset.map(
    function=preprocess_function,
    batched=True,
)

In [None]:
# inspect the first tokenized training example
print(tokenized_train_datasets[0])

# Model

In [None]:
# load the pretrained seq2seq model from MODEL_CHECKPOINT
# NOTE(review): max_length and use_cache are presumably applied as overrides of
# the loaded model configuration — confirm against the installed transformers version.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_CHECKPOINT,
    # same limit that is used to truncate the targets during preprocessing
    max_length=MAX_TARGET_LEN,
    # NOTE(review): presumably disabled because gradient checkpointing is
    # enabled in the training arguments — confirm
    use_cache=False,
)

In [None]:
# set training hyperparameters
args = Seq2SeqTrainingArguments(
    # first positional argument: the experiment name
    EXP_NAME,
    
    # evaluate every 250 steps, log every 25 steps, save every 250 steps
    evaluation_strategy="steps",
    eval_steps=250,
    logging_strategy="steps",
    logging_steps=25,
    save_steps=250,
    
    # fixed seeds; group_by_length is enabled
    # NOTE(review): presumably relies on the "length" column added during
    # preprocessing — confirm
    seed=87,
    data_seed=87,
    group_by_length=True,
    
    # reload the best checkpoint at the end of training; "best" is the lowest
    # 'mcer' value, one of the keys returned by compute_metrics
    load_best_model_at_end=True,
    metric_for_best_model='mcer',
    greater_is_better=False,
    # evaluation runs with twice the training batch size
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE*2,
    gradient_accumulation_steps=GRAG_ACC_STEP,
    gradient_checkpointing=True,
    
    # Adafactor optimizer, 5 epochs, cosine schedule with a 6% warmup ratio
    optim="adafactor",
    num_train_epochs=5,
    learning_rate=3e-4,
    weight_decay=0.00,
    warmup_ratio=0.06,
    lr_scheduler_type='cosine',
    # evaluation uses generated sequences, which compute_metrics decodes
    predict_with_generate=True,
)

In [None]:
# batch collator for seq2seq training, built from the tokenizer and the model
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# metric functions
# corpus-level character error rate metric, used by compute_metrics
# NOTE(review): load_metric is deprecated in recent `datasets` releases — confirm
# the pinned version before upgrading.
cer = load_metric("cer")

def mcer(predictions, references):
    """Return the mean of the per-pair character error rates.

    Each prediction is scored against the reference at the same position with
    ``jiwer.cer`` (reference first, hypothesis second); the individual scores
    are then averaged with ``np.mean``.
    """
    per_pair_cer = []
    for hypothesis, truth in zip(predictions, references):
        per_pair_cer.append(jiwer.cer(truth, hypothesis))
    return np.mean(per_pair_cer)

def compute_metrics(eval_pred):
    """Decode generated ids and labels, then score them with CER and mean CER.

    Args:
        eval_pred: pair ``(predictions, labels)`` of token-id arrays as handed
            over by the trainer during evaluation.

    Returns:
        dict with the keys ``'cer'`` (value of the loaded ``cer`` metric) and
        ``'mcer'`` (mean per-sample CER), both rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # -100 is a padding/ignore marker, not a vocabulary id, so it cannot be
    # decoded. The labels contain it, and generated predictions can contain it
    # as well when sequences of different lengths are padded together, so both
    # arrays get the same replacement.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Debug output, printed only when not running distributed or on rank 0:
    # the first 100 (label, prediction) pairs, then every mismatch among the
    # first 500 pairs.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(list(zip(decoded_labels[:100], decoded_preds[:100])))
        for idx, (gt, pred) in enumerate(zip(decoded_labels[:500], decoded_preds[:500])):
            if gt != pred:
                print(idx, gt, pred)

    result = {
        'cer': cer.compute(predictions=decoded_preds, references=decoded_labels),
        'mcer': mcer(predictions=decoded_preds, references=decoded_labels),
    }

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# trainer: ties together the model, the training arguments, both tokenized
# splits, the batch collator, the tokenizer and the metric function
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train_datasets,
    eval_dataset=tokenized_valid_datasets,
    data_collator=data_collator,
    tokenizer=tokenizer,
    # returns the 'mcer' value that the training arguments use for model selection
    compute_metrics=compute_metrics
)

# Train

In [None]:
# training: run the fine-tuning loop configured above
trainer.train()

In [None]:
# save model weights and config under <EXP_NAME>/best (the tokenizer is not saved here)
# NOTE(review): load_best_model_at_end=True is set in the training arguments, so
# `model` presumably holds the best checkpoint's weights at this point — confirm.
model.save_pretrained(f"{EXP_NAME}/best")

# Eval

In [None]:
# evaluation: score the current model on the validation split and print the
# metrics dictionary returned by the trainer
eval_result = trainer.evaluate(tokenized_valid_datasets)
print(eval_result)

# To ONNX

In [None]:
# convert model to onnx format for faster inference
!cd onnxruntime/onnxruntime/python/tools/transformers/models/t5 && python convert_to_onnx.py \
-m $f"{EXP_NAME}/best" \
--output api/onnx_models/f0_best \
--use_gpu