In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import sacrebleu
import sys
import os
# Make project-local modules importable (presumably `mbart_covariate`, imported further
# down — TODO confirm). NOTE(review): absolute, user-specific path; not portable.
sys.path.extend(["/fs/clip-scratch/sweagraw/contrastive-controlled-mt/IWSLT2022"])

In [None]:
def read_file(fname):
    """Read a text file and return its lines with surrounding whitespace stripped.

    The file is decoded as UTF-8 explicitly so that non-Latin scripts (Hindi,
    Japanese, Russian, ...) load correctly regardless of the platform's
    default locale encoding.

    Args:
        fname: Path to the file.

    Returns:
        List with one entry per line of the file, each ``str.strip()``-ed
        (so blank lines become ``""``).
    """
    with open(fname, encoding="utf-8") as f:
        return [line.strip() for line in f]

In [None]:
def get_data(tgt_lang, domain, split, data_dir="../internal_split"):
    """Load English sources and their formal/informal references for one split.

    Reads three files from ``{data_dir}/en-{tgt_lang}/``:
    ``{split}.{domain}.en``, ``{split}.{domain}.formal.{tgt_lang}`` and
    ``{split}.{domain}.informal.{tgt_lang}``.

    Args:
        tgt_lang: Target-language suffix used in the directory and file names (e.g. "hi").
        domain: Domain component of the file names (e.g. "combined").
        split: Split component of the file names (e.g. "dev").
        data_dir: Root directory containing the ``en-{tgt_lang}`` folders;
            defaults to ``../internal_split`` (relative to the working directory).

    Returns:
        Tuple ``(source, formal_translations, informal_translations)``, each a
        list of stripped lines as returned by ``read_file``.

    Raises:
        ValueError: If the three files have different line counts. The files are
            expected to be line-aligned (one reference per source line), so a
            mismatch would silently pair sources with the wrong references.
    """
    prefix = f"{data_dir}/en-{tgt_lang}/{split}.{domain}"
    source = read_file(f"{prefix}.en")
    formal_translations = read_file(f"{prefix}.formal.{tgt_lang}")
    informal_translations = read_file(f"{prefix}.informal.{tgt_lang}")
    if not (len(source) == len(formal_translations) == len(informal_translations)):
        raise ValueError(
            f"Misaligned data under {prefix}: {len(source)} source lines, "
            f"{len(formal_translations)} formal, {len(informal_translations)} informal"
        )
    return source, formal_translations, informal_translations

In [None]:
# Short target-language keys -> mBART-50 language tokens (looked up via
# tokenizer.lang_code_to_id to force the first generated token).
tgt_lang_to_code = dict(
    hi="hi_IN",
    de="de_DE",
    es="es_XX",
    it="it_IT",
    ru="ru_RU",
    ja="ja_XX",
)

In [None]:
# Pretrained mBART-50 one-to-many checkpoint; the tokenizer's source language is English.
# The cache location can be overridden via the MBART_CACHE_DIR environment variable
# instead of editing this user-specific path.
MODEL_NAME = "facebook/mbart-large-50-one-to-many-mmt"
CACHE_DIR = os.environ.get("MBART_CACHE_DIR", "/fs/clip-scratch/sweagraw/CACHE")
model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME, src_lang="en_XX", cache_dir=CACHE_DIR)

In [None]:
def translate_text(text, tgt_lang, batch_size=None):
    """Translate English text into ``tgt_lang`` with the globally loaded ``model``.

    Args:
        text: A single English sentence (str) or a list of sentences.
        tgt_lang: Short target-language key looked up in ``tgt_lang_to_code``
            (e.g. "hi").
        batch_size: If given, translate a list in consecutive chunks of this many
            sentences to bound peak memory. ``None`` (default) sends everything
            in a single padded batch, as before.

    Returns:
        List of decoded translations, one per input sentence, in input order.

    Raises:
        ValueError: If ``batch_size`` is given and is less than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    # Force the first generated token to be the target-language code.
    forced_bos = tokenizer.lang_code_to_id[tgt_lang_to_code[tgt_lang]]
    if batch_size is None or isinstance(text, str):
        batches = [text]
    else:
        batches = [text[i:i + batch_size] for i in range(0, len(text), batch_size)]

    translations = []
    for batch in batches:
        model_inputs = tokenizer(batch, return_tensors="pt", padding=True)
        generated_tokens = model.generate(**model_inputs, forced_bos_token_id=forced_bos)
        translations.extend(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
    return translations

In [None]:
# Run configuration: English -> Hindi, "combined" domain, dev split.
src_lang, tgt_lang = "en", "hi"
domain, split = "combined", "dev"
source, formal_translations, informal_translations = get_data(
    tgt_lang=tgt_lang, domain=domain, split=split
)

In [None]:
# Directory for this run's translation outputs (created if missing).
# NOTE(review): the cells below write base-model output (no formality control) here,
# yet the directory is named "mBART_informal" — confirm this is intended.
output_dir = "../experiments/{}-{}/mBART_informal/{}/".format(src_lang, tgt_lang, domain)
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Base-model translations for every source sentence of the loaded split.
outputs = translate_text(text=source, tgt_lang=tgt_lang)

In [None]:
 with open(output_dir+"/out."+split, "w") as f:
        for out in outputs:
            f.write(out + "\n")

In [None]:
# Other decoding strategies

In [None]:
# Tokenize only the first source sentence for the decoding experiments that follow.
model_inputs = tokenizer(text=source[0], padding=True, return_tensors="pt")

In [None]:
# Beam search with 5 beams on the tokenized sentence, keeping all 5 hypotheses and
# capping output at 50 tokens; the target language is taken from `tgt_lang`.
generated_tokens = model.generate(
    **model_inputs,
    num_beams=5,
    num_return_sequences=5,
    max_length=50,
    early_stopping=True,
    forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang_to_code[tgt_lang]],
)

In [None]:
# Raw token ids from the beam search above; presumably one row per returned
# hypothesis (num_return_sequences=5), at most 50 tokens each — TODO confirm.
generated_tokens

In [None]:
# Decode each beam hypothesis back to text, dropping special tokens.
[tokenizer.decode(hypothesis, skip_special_tokens=True) for hypothesis in generated_tokens]

# Evaluating the Covariate-Conditioned Model

In [None]:
from transformers import AutoTokenizer
import sacrebleu
import sys
import os
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, MBartForConditionalGeneration, MBart50TokenizerFast

In [None]:
from mbart_covariate import CMBartForConditionalGeneration

In [None]:
# Covariate-conditioned mBART loaded from a local checkpoint. Per its directory name, it
# was fine-tuned from the one-to-many model. It is paired with the base tokenizer
# (source language English). The cache location can be overridden via MBART_CACHE_DIR.
CACHE_DIR = os.environ.get("MBART_CACHE_DIR", "/fs/clip-scratch/sweagraw/CACHE")
model = CMBartForConditionalGeneration.from_pretrained(
    "../models/facebook/mbart-large-50-one-to-many-mmt-finetuned-covariate-en-to-xx",
    cache_dir=CACHE_DIR,
)
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX", cache_dir=CACHE_DIR
)

In [None]:
import torch


def translate_text(text, tgt_lang, covariate_index, batch_size=None):
    """Translate English text into ``tgt_lang`` with the covariate-conditioned ``model``.

    This redefines the plain ``translate_text`` from earlier in the notebook. The
    single ``covariate_index`` is repeated for every sentence in a batch and passed
    to ``model.generate`` as ``covariate_ids``.

    Args:
        text: A single English sentence (str) or a list of sentences.
        tgt_lang: Short target-language key looked up in ``tgt_lang_to_code``
            (e.g. "hi").
        covariate_index: Covariate id applied to every sentence. Its meaning is
            defined by ``CMBartForConditionalGeneration`` — TODO confirm.
        batch_size: If given, translate a list in consecutive chunks of this many
            sentences to bound peak memory. ``None`` (default) sends everything in
            a single padded batch.

    Returns:
        List of decoded translations, one per input sentence, in input order.

    Raises:
        ValueError: If ``batch_size`` is given and is less than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    # Force the first generated token to be the target-language code.
    forced_bos = tokenizer.lang_code_to_id[tgt_lang_to_code[tgt_lang]]
    if batch_size is None or isinstance(text, str):
        batches = [text]
    else:
        batches = [text[i:i + batch_size] for i in range(0, len(text), batch_size)]

    translations = []
    for batch in batches:
        model_inputs = tokenizer(batch, return_tensors="pt", padding=True)
        # Size covariate_ids from the tokenized batch rather than len(text): for a
        # single str, len(text) is its character count, not 1.
        n_rows = model_inputs["input_ids"].shape[0]
        covariate_ids = torch.tensor([covariate_index] * n_rows)
        generated_tokens = model.generate(
            **model_inputs,
            forced_bos_token_id=forced_bos,
            covariate_ids=covariate_ids,
        )
        translations.extend(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
    return translations

In [None]:
# Covariate-model translations (covariate index 1) for the loaded split.
# NOTE(review): this rebinds `outputs`, replacing the base-model translations.
outputs = translate_text(text=source, tgt_lang=tgt_lang, covariate_index=1)

In [None]:
# Display the covariate-model translations for the whole split.
# NOTE(review): shows every sentence; consider `outputs[:10]` for a quick look.
outputs

In [None]:
# NOTE(review): identical to the previous cell (redisplays `outputs`); candidate for deletion.
outputs