In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [2]:
# Single source of truth for the checkpoint so tokenizer and model always match.
MODEL_CHECKPOINT = "google/mt5-small"

mt5_tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
mt5_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

In [None]:
# Train/validation CSV splits. Later cells read the "English" and "Hinglish" columns.
# NOTE(review): absolute, machine-specific paths; consider a configurable DATA_DIR.
data_files = dict(
    train="/home/ubuntu/teach-trans/train_hinglish_english.csv",
    valid="/home/ubuntu/teach-trans/valid_hinglish_english.csv",
)
dataset = load_dataset("csv", data_files=data_files)

In [None]:
# Language codes kept for reference only; nothing below reads them.
# NOTE(review): the real target is Hinglish (the "Hinglish" CSV column), not "hi".
source_lang = "en"
target_lang = "hi"
prefix = "translate English to Hinglish: "

# Explicit token-length caps. With `truncation=True` but no `max_length`, a tokenizer
# that defines no `model_max_length` does not truncate at all (presumably the case for
# google/mt5-small -- TODO confirm). 128 is a common translation default; tune to the data.
MAX_SOURCE_LENGTH = 128
MAX_TARGET_LENGTH = 128


def preprocess_function(
    examples,
    *,
    tokenizer=None,
    max_source_length=MAX_SOURCE_LENGTH,
    max_target_length=MAX_TARGET_LENGTH,
):
    """Tokenize a batch of English->Hinglish pairs for seq2seq training.

    Args:
        examples: batch mapping with list-valued "English" (source) and
            "Hinglish" (target) columns, as passed by ``Dataset.map(batched=True)``.
        tokenizer: tokenizer to use; defaults to the notebook's ``mt5_tokenizer``.
        max_source_length: truncation length for the prefixed English inputs.
        max_target_length: truncation length for the Hinglish labels.

    Returns:
        The tokenizer's encoding of the prefixed inputs, with an added "labels"
        key holding the target token ids.
    """
    # Resolve lazily so the default tokenizer is looked up at call time.
    tok = mt5_tokenizer if tokenizer is None else tokenizer

    inputs = [prefix + text for text in examples["English"]]
    model_inputs = tok(inputs, max_length=max_source_length, truncation=True)

    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tok(
        text_target=examples["Hinglish"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Empty CSV cells load as None, and `prefix + None` inside preprocess_function would
# raise TypeError. Drop any pair with a missing side before tokenizing.
dataset_complete = dataset.filter(
    lambda batch: [
        english is not None and hinglish is not None
        for english, hinglish in zip(batch["English"], batch["Hinglish"])
    ],
    batched=True,
)
tokenized_dataset = dataset_complete.map(preprocess_function, batched=True)

In [None]:
# Sanity check: show the first tokenized training example.
tokenized_dataset["train"][0]

In [None]:
# Batch collator for the Seq2SeqTrainer, built from the same tokenizer/model pair.
data_collator = DataCollatorForSeq2Seq(
    model=mt5_model,
    tokenizer=mt5_tokenizer,
)

In [None]:
# (m)T5 checkpoints are known to overflow to NaN/inf loss under fp16 mixed precision.
# So fp16 stays off, and bf16 (fp32-sized exponent range) is enabled only on GPUs
# that support it. Otherwise training runs in fp32.
USE_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = Seq2SeqTrainingArguments(
    output_dir="/home/ubuntu/teach-trans/results",
    # NOTE(review): newer transformers releases spell this `eval_strategy`;
    # rename if the installed version rejects `evaluation_strategy`.
    evaluation_strategy="epoch",
    # NOTE(review): 2e-5 is low for (m)T5 fine-tuning with AdamW; values around
    # 1e-4 to 5e-4 are commonly used. Tune against the validation loss.
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    fp16=False,
    bf16=USE_BF16,
)

trainer = Seq2SeqTrainer(
    model=mt5_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["valid"],
    tokenizer=mt5_tokenizer,
    data_collator=data_collator,
)

In [None]:
# Run fine-tuning; the value returned by trainer.train() is displayed as the cell output.
train_result = trainer.train()
train_result