In [40]:
!pip install evaluate



In [41]:
!pip install --upgrade transformers



In [42]:
from datasets import load_dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq
import evaluate
import torch

In [43]:
dataset = load_dataset("bentrevett/multi30k", split={"train": "train", "validation": "validation", "test": "test"})

In [44]:
from transformers import MarianTokenizer, MarianConfig, MarianMTModel

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")



In [45]:
config = MarianConfig(
    vocab_size=tokenizer.vocab_size,
    encoder_layers=6,
    decoder_layers=6,
    encoder_attention_heads=8,
    decoder_attention_heads=8,
    d_model=512,
    decoder_ffn_dim=2048,
    encoder_ffn_dim=2048,
    activation_function="relu",
    dropout=0.1
)
model = MarianMTModel(config)

source_lang = "en"
target_lang = "de"

In [46]:
def preprocess(batch):
    model_inputs = tokenizer(batch["en"], padding="max_length", truncation=True, max_length=128)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(batch["de"], padding="max_length", truncation=True, max_length=128)

    labels["input_ids"] = [
        [(label if label != tokenizer.pad_token_id else -100) for label in label_seq]
        for label_seq in labels["input_ids"]
    ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [47]:
tokenized_datasets = dataset.map(preprocess, batched=True)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

bleu = evaluate.load("bleu")

In [48]:
def compute_metrics(eval_pred):
    preds, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    filtered_preds, filtered_labels = [], []
    for pred, label in zip(decoded_preds, decoded_labels):
        if label.strip():
            filtered_preds.append(pred)
            filtered_labels.append([label])

    if len(filtered_labels) == 0:
        return {"bleu": 0.0}

    return bleu.compute(predictions=filtered_preds, references=filtered_labels)

In [49]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./mt-checkpoint",
    eval_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=20,
    predict_with_generate=True,
    logging_dir="./logs",
    report_to="none"
)

In [50]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss,Bleu,Precisions,Brevity Penalty,Length Ratio,Translation Length,Reference Length
1,5.0709,4.754417,0.021373,"[0.32896825396825397, 0.061989852195014336, 0.011177347242921014, 0.0026996305768684286]",0.763122,0.787193,10080,12805
2,4.2648,4.211071,0.047791,"[0.3178559791463017, 0.08035872846741254, 0.026053864168618268, 0.009313406974225688]",0.957823,0.958688,12276,12805
3,3.9448,3.894757,0.064402,"[0.34257522629046727, 0.10027558005156013, 0.039081582804103565, 0.015291183168853703]",0.956765,0.957673,12263,12805
4,3.5753,3.565189,0.081366,"[0.359508547008547, 0.1205955334987593, 0.05010834236186349, 0.020174915523752734]",1.0,1.02335,13104,12805
5,3.222,3.298485,0.105895,"[0.4059772764696196, 0.15630614444843693, 0.07372998616327338, 0.033391915641476276]",0.947189,0.948536,12146,12805
6,2.9791,3.054304,0.13049,"[0.4211469534050179, 0.17859560067681896, 0.08902461595409958, 0.04330065359477124]",1.0,1.002265,12834,12805
7,2.7243,2.908994,0.152053,"[0.4480133238163217, 0.20172488141440276, 0.10849636140251394, 0.058011915961116337]",0.984576,0.984693,12609,12805
8,2.508,2.733187,0.174331,"[0.48776035183802174, 0.23339675636495424, 0.13469021251122418, 0.07736707736707736]",0.93935,0.941117,12051,12805
9,2.3413,2.61696,0.18392,"[0.49630651838623263, 0.24449358690844758, 0.14051112622680012, 0.07858143796485934]",0.961317,0.962046,12319,12805
10,2.184,2.521735,0.201474,"[0.5165319617927994, 0.26550956831330663, 0.15771450934350847, 0.09134354295644619]",0.955623,0.956579,12249,12805


TrainOutput(global_step=36260, training_loss=2.4686950254361406, metrics={'train_runtime': 3507.058, 'train_samples_per_second': 165.381, 'train_steps_per_second': 10.339, 'total_flos': 1.966105165824e+16, 'train_loss': 2.4686950254361406, 'epoch': 20.0})

In [51]:
print(tokenizer.pad_token_id)
print(tokenizer.eos_token)
print(model.config.eos_token_id)

58100
</s>
0


In [52]:
for i in range(3):
    input_ids = tokenized_datasets["validation"]["input_ids"][i]
    labels = tokenized_datasets["validation"]["labels"][i]
    print("Input:", tokenizer.decode(input_ids, skip_special_tokens=True))
    print("Label:", tokenizer.decode(labels, skip_special_tokens=True))

Input: A group of men are loading cotton onto a truck
Label: Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen
Input: A man sleeping in a green room on a couch.
Label: Ein Mann schläft in einem grünen Raum auf einem Sofa.
Input: A boy wearing headphones sits on a woman's shoulders.
Label: Ein Junge mit Kopfhörern sitzt auf den Schultern einer Frau.


In [53]:
for i in range(5):
    input_ids = tokenized_datasets["test"]["input_ids"][i]
    label_ids = tokenized_datasets["test"]["labels"][i]

    input_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    label_text = tokenizer.decode(label_ids, skip_special_tokens=True)
    output_ids = model.generate(torch.tensor([input_ids]).to(model.device), max_length=128)[0]
    output_text = tokenizer.decode(output_ids, skip_special_tokens=True)

    print(f"[EN] {input_text}")
    print(f"[GT] {label_text}")
    print(f"[PR] {output_text}\n")

[EN] A man in an orange hat starring at something.
[GT] Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.
[PR] Ein Mann mit einem orangefarbenen Hut starrt etwas.

[EN] A Boston Terrier is running on lush green grass in front of a white fence.
[GT] Ein Boston Terrier läuft über saftig-grünes Gras vor einem weißen Zaun.
[PR] Ein Fischer läuft vor einem weißen Zaun in der Nähe eines weißen Zauns.

[EN] A girl in karate uniform breaking a stick with a front kick.
[GT] Ein Mädchen in einem Karateanzug bricht ein Brett mit einem Tritt.
[PR] Ein Mädchen in Karateanzug macht einen Stock vor einem Waschbecken.

[EN] Five people wearing winter jackets and helmets stand in the snow, with snowmobiles in the background.
[GT] Fünf Leute in Winterjacken und mit Helmen stehen im Schnee mit Schneemobilen im Hintergrund.
[PR] Fünf Personen in orangefarbenen Westen und mit Helmen stehen im Schnee, im Hintergrund ist ein Schneeer.

[EN] People are fixing the roof of a house.
[GT] Leute Repariere