In [None]:
import torch
import numpy as np
import random
# Single global seed applied to Python's `random`, NumPy's global RNG and torch.
seed = 1234
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(seed)
del _seeder

# https://github.com/huggingface/notebooks/blob/master/examples/translation.ipynb

In [None]:
import sys

In [None]:
from transformers import AutoTokenizer
import sacrebleu
import sys
import os
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, MBartForConditionalGeneration, MBart50TokenizerFast

In [None]:
from mbart_covariate import MBartSeq2SeqTrainer

In [None]:
# Two-letter language key used in the data file names -> mBART-50 language code
# (the tokenizer's language token, e.g. used as the forced BOS for generation).
tgt_lang_to_code = dict(
    hi="hi_IN",
    de="de_DE",
    es="es_XX",
    it="it_IT",
    ru="ru_RU",
    ja="ja_XX",
)

In [None]:
from datasets import load_dataset, load_metric
# Token budget applied (with truncation) to both source and target sequences.
MAX_LENGTH = 96
# SacreBLEU scorer consumed by compute_metrics.
# NOTE(review): `datasets.load_metric` is deprecated (removed in datasets 3.x);
# `evaluate.load("sacrebleu")` is the maintained equivalent.
metric = load_metric(path="sacrebleu")

In [None]:
def read_file(fname, encoding="utf-8"):
    """Read a text file and return its lines with surrounding whitespace stripped.

    Args:
        fname: path of the file to read.
        encoding: text encoding. Defaults to UTF-8 because the corpora include
            non-Latin scripts (Hindi, Japanese), which the platform's default
            locale encoding may be unable to decode.

    Returns:
        list[str]: one entry per line, with the newline and any leading/trailing
        whitespace removed.
    """
    with open(fname, encoding=encoding) as f:
        return [line.strip() for line in f]

def get_data(data_dir, tgt_lang, domain, split):
    """Load one parallel split: English sources plus formal and informal references.

    Files are read from ``{data_dir}/en-{tgt_lang}/{split}.{domain}.<suffix>`` where
    ``<suffix>`` is ``en``, ``formal.{tgt_lang}`` or ``informal.{tgt_lang}``.

    Returns:
        (source, formal_translations, informal_translations): three lists of lines.

    Raises:
        ValueError: if the three files have different line counts. Examples are
            paired by line index downstream, so a mismatch would silently misalign
            sources and references.
    """
    prefix = f"{data_dir}/en-{tgt_lang}/{split}.{domain}"
    source = read_file(f"{prefix}.en")
    formal_translations = read_file(f"{prefix}.formal.{tgt_lang}")
    informal_translations = read_file(f"{prefix}.informal.{tgt_lang}")
    if not (len(source) == len(formal_translations) == len(informal_translations)):
        raise ValueError(
            f"Misaligned split {prefix}: {len(source)} source lines, "
            f"{len(formal_translations)} formal lines, "
            f"{len(informal_translations)} informal lines"
        )
    return source, formal_translations, informal_translations

In [None]:
import torch
from torch.utils.data import Dataset

class FormalityData(Dataset):
    """Tokenized English -> ``tgt_lang`` examples for one domain, split and formality.

    Uses the notebook-level ``tokenizer`` (mBART-50), ``MAX_LENGTH``,
    ``tgt_lang_to_code`` and ``get_data``.

    Args:
        data_dir: root directory containing the ``en-{tgt_lang}`` sub-directories.
        domain: domain tag used in the file names (e.g. "telephony").
        split: split tag used in the file names (e.g. "train", "dev").
        src_lang: accepted for interface compatibility. The data files and the
            tokenizer's source language are always English (``en_XX``).
        tgt_lang: target language key of ``tgt_lang_to_code``.
        direction: "formal" uses the formal references as labels; any other
            value uses the informal references.
    """

    def __init__(self, data_dir, domain, split, src_lang, tgt_lang, direction):
        self.source, self.formal_translations, self.informal_translations = get_data(
            data_dir, tgt_lang, domain, split
        )
        self.direction = direction
        self.tgt_lang = tgt_lang
        self.max_target_length = MAX_LENGTH
        self.max_input_length = MAX_LENGTH
        tgt_code = tgt_lang_to_code[tgt_lang]
        # mBART-50 language tokens use an underscore ("en_XX"). "en-XX" is not in
        # the vocabulary and would make every source sequence start with <unk>.
        tokenizer.src_lang = "en_XX"
        tokenizer.tgt_lang = tgt_code
        # The tokenizer is shared module state, so encode immediately while its
        # language settings still match this dataset.
        self.model_inputs = self.encode_split()
        # Target-language token id; constant for every example of this dataset.
        self.forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)

    def __len__(self):
        return len(self.model_inputs["input_ids"])

    def encode_split(self):
        """Tokenize the sources and the references selected by ``self.direction``.

        Returns:
            The tokenizer's encoding of the sources, with an added "labels" entry
            holding the target token ids.
        """
        if self.direction == "formal":
            targets = self.formal_translations
        else:
            targets = self.informal_translations
        model_inputs = tokenizer(self.source, max_length=self.max_input_length, truncation=True)
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=self.max_target_length, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of encoded fields plus ``forced_bos_token_id``."""
        item = {key: values[idx] for key, values in self.model_inputs.items()}
        item["forced_bos_token_id"] = self.forced_bos_token_id
        return item

In [None]:
def _build_split(data_dir, split, direction, domains, tgt_langs):
    """Concatenate one FormalityData per (domain, tgt_lang) pair for a split."""
    return torch.utils.data.ConcatDataset(
        [
            FormalityData(data_dir, domain, split, "en", tgt_lang, direction)
            for domain in domains
            for tgt_lang in tgt_langs
        ]
    )


def get_training_data(data_dir, direction,
                      domains=("telephony", "topical-chat"),
                      tgt_langs=("hi", "ja", "de", "es")):
    """Build the concatenated train and dev datasets for one formality direction.

    Args:
        data_dir: root directory of the formality corpora.
        direction: "formal" or informal references (see FormalityData).
        domains: domain tags to include; the outer iteration order.
        tgt_langs: target language keys to include; the inner iteration order.

    Returns:
        (train_dataset, dev_dataset) as torch ConcatDatasets.
    """
    train_dataset = _build_split(data_dir, "train", direction, domains, tgt_langs)
    dev_dataset = _build_split(data_dir, "dev", direction, domains, tgt_langs)
    return train_dataset, dev_dataset

In [None]:
import numpy as np

def postprocess_text(preds, labels):
    """Strip decoded texts and wrap each reference in its own list.

    Returns:
        (predictions, references): stripped prediction strings, and a list in
        which each element is a one-item list holding that prediction's single
        stripped reference.
    """
    stripped_preds = [text.strip() for text in preds]
    wrapped_refs = []
    for ref in labels:
        wrapped_refs.append([ref.strip()])
    return stripped_preds, wrapped_refs

def compute_metrics(eval_preds):
    """Compute sacrebleu BLEU and mean generated length for the Seq2Seq trainer.

    Args:
        eval_preds: ``(predictions, labels)`` token-id arrays. Predictions may
            arrive wrapped in a tuple, in which case the first element is used.

    Returns:
        dict with "bleu" (sacrebleu ``score``) and "gen_len" (mean count of
        non-pad tokens per prediction), each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 is the ignore/padding value the Trainer uses for labels and, when it
    # concatenates generated batches of different lengths, may also use for
    # predictions. It is not a decodable token id, and it would otherwise be
    # counted as a real token in gen_len, so map it to the pad id in both arrays.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": bleu["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
def get_trainer(data_dir, direction, exp_name, cache_dir="/fs/clip-scratch/sweagraw/CACHE"):
    """Build the model, datasets and MBartSeq2SeqTrainer for one formality direction.

    Reads the notebook-level ``model_name``, ``train_batch_size``,
    ``eval_batch_size`` and ``tokenizer``.

    Args:
        data_dir: root directory of the formality corpora.
        direction: "formal" or informal references (see FormalityData).
        exp_name: suffix appended to the output directory name.
        cache_dir: download cache for the pretrained weights. Defaults to the
            path the original experiments used; pass another path on other machines.

    Returns:
        (trainer, model)
    """
    # NOTE(review): Seq2SeqTrainingArguments has its own `seed` (default 42) that
    # the Trainer re-applies, so the notebook's global seed may not govern data
    # order during training. Pass seed=... here if that is intended.
    args = Seq2SeqTrainingArguments(
        output_dir=f"../models/{model_name}-finetuned-en-to-xx-{direction}-{exp_name}",
        evaluation_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=10,
        # With the default step-based save strategy, checkpoints are written every
        # `save_steps`, while evaluation runs once per epoch. `eval_steps` only
        # matters for a step-based evaluation strategy.
        save_steps=100,
        eval_steps=100,
        predict_with_generate=True,
        fp16=True,  # mixed precision; expects a CUDA device
        push_to_hub=False,
    )

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
    # Fine-tune all parameters.
    model.requires_grad_(True)
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    train_dataset, dev_dataset = get_training_data(data_dir, direction)

    trainer = MBartSeq2SeqTrainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    return trainer, model

In [None]:
# Experiment configuration.
train_batch_size = 8
eval_batch_size = 4
exp_name = "test"
data_dir = "../internal_split"
src_lang = "en"
direction = "formal"
model_name = "facebook/mbart-large-50-one-to-many-mmt"
# Download cache for pretrained files. Set the HF_CACHE_DIR environment variable
# instead of editing this machine-specific default.
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/fs/clip-scratch/sweagraw/CACHE")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)

In [None]:
trainer, model = get_trainer(data_dir=data_dir, direction=direction, exp_name=exp_name)

In [None]:
# MBartForConditionalGeneration has no `embedding` attribute; use the generic
# accessor to inspect the input token embedding.
model.get_input_embeddings()

In [None]:
# get_trainer returns (trainer, model). Unpack it so `trainer` is the Trainer
# itself; assigning the tuple made `trainer.train()` fail.
# This builds a fresh model. The previous trainer/model are released once their
# names are rebound, and empty_cache then returns cached GPU blocks.
trainer, model = get_trainer(data_dir, direction, exp_name)
torch.cuda.empty_cache()
trainer.train()

In [None]:
# https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/finetune_trainer.py
# https://github.com/huggingface/transformers/blob/master/examples/pytorch/translation/run_translation.py