## muP on Hugging Face: translation model demo with the IWSLT En-Ko corpus and the T5 architecture

### requirements:
* numpy>=1.18.5
* pandas>=1.1.2
* torch>=1.6.0
* torchvision>=0.7.0
* seaborn>=0.11.2
* transformers>=4.16.2
* pyyaml
* sacrebleu
* sentencepiece
* mup

In [None]:
import os
import random

import numpy as np
import torch
from torch.optim import AdamW

import transformers
from transformers import MT5Tokenizer
from transformers import T5Config, T5ForConditionalGeneration
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import SchedulerType, get_scheduler

from datasets import load_dataset, load_metric

In [None]:
# Seed every random number generator used here (PyTorch, Python's `random`,
# NumPy) with the same value so runs are reproducible.
SEED = 890112
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(SEED)

## Load and prepare the IWSLT en-ko corpus with HF datasets.
### you probably want to change DATA_DIR.

In [None]:
# Cache directory handed to `load_dataset`; point this at a real location.
DATA_DIR = '/path/to/data'

# Language keys of each translation pair, plus the task prefix used during
# pre-processing.
source_lang = "en"
target_lang = "ko"
prefix = "translate English to Korean: "

# Raw (untokenized) IWSLT 2017 English-Korean dataset.
raw_dataset = load_dataset(
    "iwslt2017",
    "iwslt2017-en-ko",
    cache_dir=DATA_DIR,
)

# Only the tokenizer is pre-trained (mT5); the model itself is still trained
# from scratch.
tokenizer = MT5Tokenizer.from_pretrained(
    "google/mt5-base",
    model_max_length=128,
)

# define preprocess function that prepends the task prefix to each source (encoder-side) sentence and tokenizes the dataset
def preprocess_function(examples):
    """Tokenize one batch of translation pairs.

    Each source sentence gets the module-level `prefix` prepended before
    tokenization; the target sentences are tokenized separately and their
    token ids are stored under "labels".

    Args:
        examples: batch with a "translation" entry, iterated as a sequence of
            mappings indexed by `source_lang` and `target_lang`.

    Returns:
        The tokenizer output for the sources, with an added "labels" entry.
    """
    sources = []
    references = []
    for pair in examples["translation"]:
        sources.append(prefix + pair[source_lang])
        references.append(pair[target_lang])

    # Sources and targets are both truncated to 128 tokens.
    batch = tokenizer(sources, max_length=128, truncation=True)

    # Targets are tokenized inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(references, max_length=128, truncation=True)["input_ids"]

    batch["labels"] = target_ids
    return batch

# Apply `preprocess_function` to the raw dataset in batches to obtain the
# tokenized dataset.
tokenized_dataset = raw_dataset.map(
    function=preprocess_function,
    batched=True,
)

### Set configs with different widths.

In [None]:
# Both configs start from the "t5-small" config and override the fields below.
# The vocabulary size is taken from the tokenizer in both cases.

# narrow model
narrow_config = T5Config.from_pretrained("t5-small")
narrow_config.update(
    {
        "d_ff": 1024,
        "d_kv": 32,
        "d_model": 256,
        "num_heads": 8,
        "vocab_size": tokenizer.vocab_size,
    }
)

# wide model
wide_config = T5Config.from_pretrained("t5-small")
wide_config.update(
    {
        "d_ff": 3072,
        "d_kv": 64,
        "d_model": 768,
        "num_heads": 12,
        "vocab_size": tokenizer.vocab_size,
    }
)

### Post processing & metric

In [None]:
# Load the "sacrebleu" metric; `compute_metrics` below calls `metric.compute`
# on the decoded predictions and references.
metric = load_metric("sacrebleu")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and references.

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded reference strings.

    Returns:
        A `(preds, labels)` tuple: a list of stripped predictions, and a list
        in which each stripped reference is wrapped in its own one-element
        list.
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    wrapped_refs = []
    for text in labels:
        wrapped_refs.append([text.strip()])

    return cleaned_preds, wrapped_refs

def compute_metrics(eval_preds):
    """Compute BLEU and the mean prediction length for one evaluation pass.

    Args:
        eval_preds: `(predictions, labels)` pair of token-id arrays. If
            `predictions` is a tuple, only its first element is used.

    Returns:
        Dict with "bleu" (the metric's "score") and "gen_len" (mean number of
        non-pad tokens per prediction), both rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id

    # -100 is a padding/ignore marker, not a vocabulary id, so it cannot be
    # decoded. Labels carry it; predictions may too if the trainer pads
    # gathered batches with it (NOTE(review): assumption about the trainer;
    # this is a no-op when no -100 is present). Masking predictions as well
    # also keeps such padding out of the gen_len count below.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

## Build model and Train 
### You will probably need to restart the notebook kernel before training another model.

In [None]:
# Build the model. Keep exactly one (model, OUT_DIR) pair active: to train the
# wide model, comment out the narrow pair and uncomment the wide pair.
model = T5ForConditionalGeneration(config=narrow_config)  # narrow model
OUT_DIR = './results/SP-narrow'  # narrow model

# model = T5ForConditionalGeneration(config=wide_config) # wide model
# OUT_DIR = './results/SP-wide' # wide model
print("model size: %d" % sum(p.numel() for p in model.parameters()))

# build optimizer and constant lr scheduler (with 500 warmup steps)
optimizer = AdamW(params=model.parameters(), lr=1e-3, weight_decay=0.01)
lr_scheduler = get_scheduler(name=SchedulerType.CONSTANT_WITH_WARMUP, optimizer=optimizer, num_warmup_steps=500)

# load data collator; HF trainer needs this.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# set batch_size infos; default batch_size=64
total_batch_size = 64
num_gpus = torch.cuda.device_count()
# device_count() is 0 on a CPU-only machine; treat that as a single device so
# the split does not divide by zero, and never go below one sample per device.
per_device_batch_size = max(1, total_batch_size // max(1, num_gpus))

# Mixed precision is only enabled when CUDA is available.
use_fp16 = torch.cuda.is_available()

# HF Trainer args
training_args = Seq2SeqTrainingArguments(
    # Write to the directory that matches the model selected above instead of
    # a hard-coded path.
    output_dir=OUT_DIR,
    evaluation_strategy="steps",
    eval_steps=2000,
    per_device_train_batch_size=per_device_batch_size,
    per_device_eval_batch_size=per_device_batch_size,
    save_total_limit=3,
    num_train_epochs=3,
    fp16=use_fp16,
    # compute_metrics decodes token ids, so evaluation has to generate
    # sequences rather than hand back raw model outputs.
    predict_with_generate=True,
    # Same 128-token limit as used for tokenization; otherwise generation
    # falls back to the model config's max_length.
    generation_max_length=128,
)

# build HF Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    optimizers=(optimizer, lr_scheduler),
    compute_metrics=compute_metrics,
)

In [None]:
# Run training with the `trainer` built in the previous cell.
trainer.train()