In [1]:
import torch
import warnings
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Blanket filter: suppress every warning raised through the `warnings` module
# for the rest of the session (this also hides library deprecation notices).
warnings.filterwarnings("ignore")

# Disabled reference snippet, kept for documentation only. It loads
# "facebook/mbart-large-50" (model + tokenizer, en_XX source / ro_RO target)
# with cache_dir='models/mbart-large-50', tokenizes one English/Romanian
# sentence pair, and runs a single forward pass.
# model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50", cache_dir='models/mbart-large-50')
# tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", cache_dir='models/mbart-large-50', src_lang="en_XX", tgt_lang="ro_RO")

# src_text = " UN Chief Says There Is No Military Solution in Syria"
# tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"

# model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")

# model(**model_inputs)  # forward pass

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Restore the seq2seq model and its tokenizer from the local
# "models/mbart-large-50" directory. The two loads are independent of each
# other; both names are rebound for the cells that follow.
model = AutoModelForSeq2SeqLM.from_pretrained("models/mbart-large-50")

tokenizer = AutoTokenizer.from_pretrained("models/mbart-large-50")

In [None]:
# The input sentence is English, so tag the source side with the English code.
tokenizer.src_lang = "en_XX"

# Sentence containing a <mask> placeholder for the model to fill in.
text = "I love going <mask> the beach"
encoded_ar = tokenizer(text, return_tensors="pt")

# Generate with the Vietnamese language code forced as the first output token.
generated_tokens = model.generate(
    **encoded_ar,
    forced_bos_token_id=tokenizer.lang_code_to_id["vi_VN"],
)

# Last expression of the cell: the decoded strings are shown as its output.
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Load the checkpoint stored under "models/mbart-large-50" through the concrete
# mBART classes. The tokenizer is configured with English as the source
# language and Vietnamese as the target language.
model = MBartForConditionalGeneration.from_pretrained(
    "models/mbart-large-50",
    cache_dir='models/mbart-large-50',
)
tokenizer = MBart50TokenizerFast.from_pretrained(
    "models/mbart-large-50",
    cache_dir='models/mbart-large-50',
    src_lang="en_XX",
    tgt_lang="vi_VN",
)

In [None]:
# The input sentence is English, so tag the source side with the English code.
tokenizer.src_lang = 'en_XX'

# Sentence containing a <mask> placeholder for the model to fill in.
text = "I love<mask> to the beach"
encoded_ar = tokenizer(text, return_tensors="pt")

# mBART-50 keeps its default decoder start token and expects the target
# language code to be *forced as the first generated token*. Overriding
# `decoder_start_token_id` with the language code is the older mBART (cc25)
# convention and feeds the decoder a prefix it was not trained on, so use
# `forced_bos_token_id` here, like the other generation cells in this notebook.
generated_tokens = model.generate(
    **encoded_ar,
    forced_bos_token_id=tokenizer.lang_code_to_id["vi_VN"],
)

# Last expression of the cell: the decoded strings are shown as its output.
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

### Many to En (many-to-one: any supported language → English)

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Fetch the many-to-one translation checkpoint (downloads are kept under
# the local 'cache' directory) together with its tokenizer.
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt", cache_dir='cache')
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt", cache_dir='cache')

# Translate Vietnamese to English.
# The source sentence means "I don't like eating fish."
source = "Tôi không thích ăn cá."

# The source language code must match the language of `source`; the sentence
# is Vietnamese, so use "vi_VN" (it was previously tagged as Hindi, "hi_IN").
tokenizer.src_lang = "vi_VN"
encoded_src = tokenizer(source, return_tensors="pt")

# No forced_bos_token_id is passed: this cell relies on the many-to-one
# checkpoint's own generation defaults to produce English.
generated_tokens = model.generate(**encoded_src)

# Last expression of the cell: the decoded English text is shown as its output.
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)


In [None]:
# Persist whatever `model` and `tokenizer` are currently bound to into the
# local directory 'models/mbart-large-50-many-to-one-mmt'.
# NOTE(review): this assumes the many-to-one checkpoint cell was the last one
# to assign `model`/`tokenizer` — run the cells in order, otherwise a different
# checkpoint would be written under this name.
model.save_pretrained('models/mbart-large-50-many-to-one-mmt')
tokenizer.save_pretrained('models/mbart-large-50-many-to-one-mmt')

### En to many (one-to-many: English → any supported language)

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Load the English-to-many translation checkpoint and its tokenizer from the
# local "models/mbart-large-50-one-to-many-mmt" directory; the tokenizer is
# set up with English as the source language.
model = MBartForConditionalGeneration.from_pretrained(
    "models/mbart-large-50-one-to-many-mmt",
    cache_dir='cache',
)
tokenizer = MBart50TokenizerFast.from_pretrained(
    "models/mbart-large-50-one-to-many-mmt",
    src_lang="en_XX",
    cache_dir='cache',
)

# Disabled: would write the loaded model and tokenizer to that same directory.
# model.save_pretrained('models/mbart-large-50-one-to-many-mmt')
# tokenizer.save_pretrained('models/mbart-large-50-one-to-many-mmt')

In [None]:
# Prefer the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Place the model on the selected device and keep the returned reference.
model = model.to(device)

In [None]:
# English sentence to translate into Vietnamese.
article_en = "The Vietnamese's government has announced about the COVID-19 situation in the country."

# Tokenize without padding or truncation. The fixed-length variant is kept
# below, disabled, for reference.
# model_inputs = tokenizer(article_en, return_tensors="pt", max_length=256, truncation=True, padding="max_length")
model_inputs = tokenizer(article_en, return_tensors="pt")
print(model_inputs)

# Rebuild the inputs as a plain dict whose tensors live on `device`.
model_inputs = {name: tensor.to(device) for name, tensor in model_inputs.items()}

# Force the Vietnamese language code as the first generated token.
generated_tokens = model.generate(forced_bos_token_id=tokenizer.lang_code_to_id["vi_VN"], **model_inputs)

# Last expression of the cell: the decoded translation is shown as its output.
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

In [None]:
# Small batch of English sentences to translate into Vietnamese.
sources = ["I love swimming in the sea.", "Studying natural language processing is fun."]

# Tokenize the batch to a fixed length of 64 tokens: longer inputs are
# truncated, shorter ones are padded up to that length.
model_inputs = tokenizer(
    sources,
    return_tensors="pt",
    max_length=64,
    truncation=True,
    padding="max_length",
)

# Rebuild the inputs as a plain dict whose tensors live on `device`.
model_inputs = {name: tensor.to(device) for name, tensor in model_inputs.items()}

# Run generation with gradient tracking disabled, forcing the Vietnamese
# language code as the first generated token.
with torch.no_grad():
    generated_tokens = model.generate(forced_bos_token_id=tokenizer.lang_code_to_id["vi_VN"], **model_inputs)

# Last expression of the cell: the decoded translations are shown as its output.
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)