In [8]:
# Load Your Existing Model
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_path = "t5_decorator_model"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

In [2]:
# Add New Duplet
from storage.duplets_dictionary import duplets, new_duplets
duplets += new_duplets

In [3]:
# Train/Validation Split
from sklearn.model_selection import train_test_split

train_duplets, val_duplets = train_test_split(duplets, test_size=0.1, random_state=42)

In [4]:
# Dataset Class
from torch.utils.data import Dataset

class DupletDataset(Dataset):
    def __init__(self, p_duplets, p_tokenizer, max_length=32):
        self.duplets = p_duplets
        self.tokenizer = p_tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.duplets)

    def __getitem__(self, idx):
        p_input_text = f"normalize: {self.duplets[idx][0]}"
        target_text = self.duplets[idx][1]
        input_ids = self.tokenizer(p_input_text, truncation=True, padding="max_length", max_length=self.max_length).input_ids
        labels = self.tokenizer(target_text, truncation=True, padding="max_length", max_length=self.max_length).input_ids
        return {"input_ids": input_ids, "labels": labels}

In [5]:
# Trainer Setup
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

train_dataset = DupletDataset(train_duplets, tokenizer)
val_dataset = DupletDataset(val_duplets, tokenizer)

training_args = Seq2SeqTrainingArguments(
    output_dir="./t5_decorator_model",
    eval_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=25,
    logging_dir="./logs",
    predict_with_generate=True,
    save_strategy="epoch",
    save_total_limit=2,
    dataloader_pin_memory=False,  # 👈 disables pin_memory
)

from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

from sklearn.metrics import accuracy_score

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    # Compute exact match accuracy
    acc = accuracy_score(decoded_labels, decoded_preds)
    return {"accuracy": acc}

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [6]:
# Train
trainer.train()

Epoch,Training Loss,Validation Loss


TrainOutput(global_step=525, training_loss=0.08811043727965583, metrics={'train_runtime': 766.9177, 'train_samples_per_second': 5.281, 'train_steps_per_second': 0.685, 'total_flos': 34258393497600.0, 'train_loss': 0.08811043727965583, 'epoch': 25.0})

In [10]:
# Test
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_path = "t5_decorator_model"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.eval()

from storage.duplets_dictionary import test_cases, normalize

for input_text, expected_output in test_cases:
    predicted_output = normalize(input_text, tokenizer, model)
    if predicted_output != expected_output:
        print(f"❌ {input_text} → {predicted_output} (expected: {expected_output})")
    else:
        print(f"✅ {input_text} → {predicted_output}")

✅ copilului → copil
✅ copiii → copii
✅ mamei → mama
✅ mamelor → mame
✅ fratelui → frate
✅ sportivilor → sportivi
✅ echipelor locale → echipe locale
❌ manelei → maneți (expected: manea)
✅ acadelei → acadea
❌ Stelei București → Stel București (expected: Steaua București)
✅ Stelei → Steaua
✅ Unirii → Unirea
✅ Sibiului → Sibiu
✅ jucătoarei → jucătoare
✅ fotbalistelor tinere → fotbaliste tinere
✅ rugbiștilor francezi → rugbiști francezi
✅ cetățeanului → cetățean
✅ comisiilor → comisii
✅ prefecților → prefecți
