In [None]:
import torch
from transformers import BertTokenizer, EncoderDecoderModel, Trainer, TrainingArguments
from datasets import load_dataset

In [None]:
# Encoder and decoder are both warm-started from the same BERT checkpoint, and this one
# tokenizer is used for the English sources and the German targets alike.
# NOTE(review): bert-base-uncased has an English, lower-cased vocabulary; a German or
# multilingual checkpoint would likely suit the targets better — confirm this is intended.
CHECKPOINT = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(CHECKPOINT)
model = EncoderDecoderModel.from_encoder_decoder_pretrained(CHECKPOINT, CHECKPOINT)

# EncoderDecoderModel cannot infer these ids from BERT. Without decoder_start_token_id and
# pad_token_id, training with `labels` fails (labels are shifted right with them to build
# decoder inputs) and generate() has no start token. [CLS] starts, [SEP] ends a sequence.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# generate() reads the generation config, which was built at load time before the ids
# above were set, so mirror them there as well (when this transformers version has one).
if getattr(model, "generation_config", None) is not None:
    model.generation_config.decoder_start_token_id = tokenizer.cls_token_id
    model.generation_config.eos_token_id = tokenizer.sep_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:

# Subsample IWSLT 2017 EN→DE so a fine-tuning run stays short.
N_TRAIN = 10_000  # training pairs, drawn after a seeded shuffle
N_VAL = 1_000     # upper bound on validation pairs
SEED = 42

dataset = load_dataset("iwslt2017", "iwslt2017-en-de")
train_split = dataset['train']
val_split = dataset['validation']

train_data = train_split.shuffle(seed=SEED).select(range(min(N_TRAIN, len(train_split))))
# The validation split can be smaller than N_VAL (the HF dataset card lists 888 rows for
# en-de — confirm), and select() raises IndexError past the end, so clamp to its length.
val_data = val_split.select(range(min(N_VAL, len(val_split))))



In [None]:
MAX_LEN = 64  # every source and target sequence is padded/truncated to this many tokens

def preprocess(example):
    """Tokenize a batch of EN→DE pairs into encoder inputs and decoder labels.

    Used with ``Dataset.map(..., batched=True)``, so ``example`` is a *batch*:
    ``example['translation']`` is a list of ``{'en': ..., 'de': ...}`` dicts.

    Returns:
        The tokenizer output for the English sources (``input_ids``,
        ``attention_mask``, ...) plus ``labels``: the German token ids with
        padding replaced by -100 so the loss ignores pad positions.
    """
    pairs = example['translation']
    sources = [pair['en'] for pair in pairs]
    references = [pair['de'] for pair in pairs]

    inputs = tokenizer(sources, padding='max_length', truncation=True, max_length=MAX_LEN)
    targets = tokenizer(references, padding='max_length', truncation=True, max_length=MAX_LEN)

    # -100 is the ignore index of the cross-entropy loss; the model maps it back to the
    # pad id when it shifts labels right to build decoder inputs.
    pad_id = tokenizer.pad_token_id
    inputs['labels'] = [
        [tok if tok != pad_id else -100 for tok in ids]
        for ids in targets['input_ids']
    ]
    return inputs

# Drop the raw columns (e.g. 'translation') so only tokenized, tensor-ready fields reach
# the Trainer's data collator.
train_dataset = train_data.map(preprocess, batched=True, remove_columns=train_data.column_names)
val_dataset = val_data.map(preprocess, batched=True, remove_columns=val_data.column_names)

In [None]:
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy` (and later
# removed the old name), so pass whichever keyword this installed version accepts.
_EVAL_STRATEGY_KEY = (
    "eval_strategy"
    if "eval_strategy" in TrainingArguments.__dataclass_fields__
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./transformer-en-de",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    save_steps=500,       # checkpoint every 500 optimizer steps ...
    save_total_limit=2,   # ... keeping at most the 2 most recent
    logging_steps=100,
    report_to="none",     # no external experiment trackers
    # remove_unused_columns is left at its default (True): the Trainer then drops any
    # dataset column the model's forward() does not accept, instead of handing raw
    # columns (e.g. a 'translation' dict) to the default collator, which cannot tensorize them.
    **{_EVAL_STRATEGY_KEY: "epoch"},  # evaluate on eval_dataset at the end of every epoch
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


In [None]:
# Fine-tune for `num_train_epochs` epochs, evaluating on `val_dataset` after each epoch and
# writing checkpoints to ./transformer-en-de every 500 steps (at most 2 kept on disk).
trainer.train()

In [None]:
def translate(text, max_new_tokens=None):
    """Translate one English sentence to German with the fine-tuned model.

    Args:
        text: English source sentence (truncated to MAX_LEN tokens).
        max_new_tokens: cap on generated target tokens. Defaults to MAX_LEN so outputs
            can be as long as the training targets; without an explicit cap, generate()
            falls back to the generation config's short default max_length.

    Returns:
        The decoded German string with special tokens removed.
    """
    # Disable dropout: after trainer.train() the model may still be in training mode.
    model.eval()
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LEN)
    # Trainer moves the model to the GPU when one is available; keep inputs on the same device.
    inputs = inputs.to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens or MAX_LEN,
        # Explicit ids so generation works even if the model/generation config lacks them.
        decoder_start_token_id=tokenizer.cls_token_id,
        eos_token_id=tokenizer.sep_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


In [None]:




# Smoke-test the trained model on a single sentence.
sample_en = "how are you?"
print("EN:", sample_en)
print("DE:", translate(sample_en))
