In [30]:
!pip install evaluate



In [31]:
!pip install --upgrade transformers



In [32]:
from datasets import load_dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq
import evaluate
import torch

In [33]:
dataset = load_dataset("bentrevett/multi30k", split={"train": "train", "validation": "validation", "test": "test"})

In [34]:
from transformers import MarianTokenizer, MarianConfig, MarianMTModel

model_name = "transformer-wmt-new"
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
tokenizer.save_pretrained(model_name)



('transformer-wmt-new/tokenizer_config.json',
 'transformer-wmt-new/special_tokens_map.json',
 'transformer-wmt-new/vocab.json',
 'transformer-wmt-new/source.spm',
 'transformer-wmt-new/target.spm',
 'transformer-wmt-new/added_tokens.json')

In [35]:
config = MarianConfig(
    vocab_size=tokenizer.vocab_size,
    encoder_layers=6,
    decoder_layers=6,
    encoder_attention_heads=8,
    decoder_attention_heads=8,
    d_model=512,
    decoder_ffn_dim=2048,
    encoder_ffn_dim=2048,
    activation_function="relu",
    dropout=0.1
)
model = MarianMTModel(config)

source_lang = "en"
target_lang = "de"

In [36]:
def preprocess(batch):
    model_inputs = tokenizer(batch["en"], padding="max_length", truncation=True, max_length=128)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(batch["de"], padding="max_length", truncation=True, max_length=128)

    labels["input_ids"] = [
        [(label if label != tokenizer.pad_token_id else -100) for label in label_seq]
        for label_seq in labels["input_ids"]
    ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [37]:
tokenized_datasets = dataset.map(preprocess, batched=True)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

bleu = evaluate.load("bleu")

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]



Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [38]:
def compute_metrics(eval_pred):
    preds, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    filtered_preds, filtered_labels = [], []
    for pred, label in zip(decoded_preds, decoded_labels):
        if label.strip():
            filtered_preds.append(pred)
            filtered_labels.append([label])

    if len(filtered_labels) == 0:
        return {"bleu": 0.0}

    return bleu.compute(predictions=filtered_preds, references=filtered_labels)

In [42]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./mt-checkpoint",
    eval_strategy="epoch",
    learning_rate=5e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    predict_with_generate=True,
    logging_dir="./logs",
    report_to="none"
)

In [61]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss,Bleu,Precisions,Brevity Penalty,Length Ratio,Translation Length,Reference Length
1,5.7317,11.061228,0.0,"[0.022682445759368838, 0.0, 0.0, 0.0]",1.0,1.504569,19266,12805
2,5.6787,10.565073,0.0,"[0.022682445759368838, 0.0, 0.0, 0.0]",1.0,1.504569,19266,12805
3,5.6361,11.116978,0.0,"[0.022682445759368838, 0.0, 0.0, 0.0]",1.0,1.504569,19266,12805
4,5.6159,11.226173,0.0,"[0.022682445759368838, 0.0, 0.0, 0.0]",1.0,1.504569,19266,12805
5,5.5799,11.129913,0.0,"[0.022682445759368838, 0.0, 0.0, 0.0]",1.0,1.504569,19266,12805


TrainOutput(global_step=9065, training_loss=5.636854202262091, metrics={'train_runtime': 868.602, 'train_samples_per_second': 166.935, 'train_steps_per_second': 10.436, 'total_flos': 4915262914560000.0, 'train_loss': 5.636854202262091, 'epoch': 5.0})

In [57]:
print(tokenizer.pad_token_id)
print(tokenizer.eos_token)
print(model.config.eos_token_id)

58100
</s>
0


In [60]:
for i in range(3):
    input_ids = tokenized_datasets["validation"]["input_ids"][i]
    labels = tokenized_datasets["validation"]["labels"][i]
    print("Input:", tokenizer.decode(input_ids, skip_special_tokens=True))
    print("Label:", tokenizer.decode(labels, skip_special_tokens=True))

Input: A group of men are loading cotton onto a truck
Label: Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen
Input: A man sleeping in a green room on a couch.
Label: Ein Mann schläft in einem grünen Raum auf einem Sofa.
Input: A boy wearing headphones sits on a woman's shoulders.
Label: Ein Junge mit Kopfhörern sitzt auf den Schultern einer Frau.
