In [3]:
from datasets import load_dataset, concatenate_datasets
from toxicity_detection.trainer import Trainer

# Languages to sample from, and the (deliberately tiny) sample sizes used by
# this notebook cell.
languages = ["en", "fi", "de"]
n_train_samples, n_dev_samples = 10, 4

# Both parquet files are read with the generic "parquet" loader, and
# split="train" is requested for each of them -- including the dev file.
train_dataset = load_dataset(
    "parquet",
    data_files="../../data/train_combined.parquet",
    split="train",
)
dev_dataset = load_dataset(
    "parquet",
    data_files="../../data/dev.parquet",
    split="train",
)

def stratified_sample_equal_per_lang(dataset, languages, total_samples, seed=0):
    """Draw a shuffled sample with (near-)equal row counts per language.

    ``total_samples`` is split evenly across ``languages``. When it does not
    divide evenly, the first ``total_samples % len(languages)`` languages each
    receive one extra row.

    Args:
        dataset: Dataset whose rows expose a ``"lang"`` field and which
            supports ``filter``, ``shuffle``, ``select`` and ``len()``
            (presumably a Hugging Face ``datasets.Dataset`` -- confirm).
        languages: Sequence of language codes, matched against ``"lang"``.
        total_samples: Total number of rows to return across all languages.
        seed: Seed used for every per-language shuffle and the final shuffle.

    Returns:
        The per-language samples concatenated and shuffled with ``seed``.

    Raises:
        ValueError: If ``languages`` is empty, ``total_samples`` is negative,
            or a language has fewer rows than were requested for it.
    """
    # Validate up front: an empty language list would otherwise surface as an
    # opaque ZeroDivisionError, and a negative total as a silently short sample.
    if not languages:
        raise ValueError("`languages` must contain at least one language.")
    if total_samples < 0:
        raise ValueError(
            f"`total_samples` must be non-negative, got {total_samples}."
        )

    base_n, remainder = divmod(total_samples, len(languages))

    per_lang_datasets = []
    for position, lang in enumerate(languages):
        # The first `remainder` languages absorb the leftover rows.
        n = base_n + (1 if position < remainder else 0)

        # Bind `lang` as a default argument so the predicate does not depend
        # on the loop variable's later values.
        lang_subset = dataset.filter(
            lambda ex, target=lang: ex["lang"] == target
        )

        # Check availability before shuffling to avoid wasted work on failure.
        if n > len(lang_subset):
            raise ValueError(f"Requested {n} samples for language '{lang}', "
                             f"but only {len(lang_subset)} available.")

        per_lang_datasets.append(
            lang_subset.shuffle(seed=seed).select(range(n))
        )

    # Final shuffle so rows are not grouped by language in the result.
    return concatenate_datasets(per_lang_datasets).shuffle(seed=seed)

# Draw a small language-balanced sample for this run.
# NOTE(review): the sample is taken from `dev_dataset`, not `train_dataset`
# (which is loaded above but unused here) -- confirm this is intentional.
train_samples = stratified_sample_equal_per_lang(
    dev_dataset, languages, total_samples=n_train_samples, seed=0
)
# dev_samples = stratified_sample_equal_per_lang(
#     dev_dataset, languages, total_samples=n_dev_samples, seed=1
# )

# NOTE(review): `val_dataset` is the same object as `train_dataset`, so the
# reported validation metrics are measured on the training samples; this looks
# like an overfitting sanity check -- confirm before reading them as held-out
# scores.
trainer = Trainer(
    model_name="google/gemma-3-270m",
    train_dataset=train_samples,
    val_dataset=train_samples,
    # Reuse the list that drove the sampling so the two cannot drift apart.
    languages=languages,
    batch_size=1,
    learning_rate=1e-4,
    warmup_steps=1,
    accumulation_steps=2,
    eval_steps=1,
    num_epochs=5,
    max_num_checkpoints=1
)
model, tokenizer = trainer.train_model()


Training on device: cpu


Loading weights: 100%|██████████| 236/236 [00:00<00:00, 828.04it/s, Materializing param=norm.weight]                                



Epoch 1/5


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.16it/s].93it/s, loss=2.55, tox_loss=0.653, lang_loss=1.89, labeled=1, step=0]


Validation - Macro F1: 0.4889


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 12.02it/s].83s/it, loss=2.03, tox_loss=0.934, lang_loss=1.1, labeled=1, step=1] 


Validation - Macro F1: 0.5333


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 11.67it/s].24s/it, loss=1.22, tox_loss=0.768, lang_loss=0.45, labeled=1, step=2]


Validation - Macro F1: 0.2667


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 11.07it/s].67s/it, loss=2.28, tox_loss=1.76, lang_loss=0.52, labeled=1, step=3]  


Validation - Macro F1: 0.2222


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  8.82it/s].39s/it, loss=6.06, tox_loss=0.000709, lang_loss=6.06, labeled=1, step=4]


Validation - Macro F1: 0.2667


Epoch 1/5 (λ=1.000): 100%|██████████| 10/10 [00:32<00:00,  3.24s/it, loss=2.63, tox_loss=1.85, lang_loss=0.779, labeled=1, step=5]  



Training - Loss: 2.7001, Toxicity Loss: 0.9016, Language Loss: 1.7984, Labeled Samples: 10

Epoch 2/5


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.11it/s].21it/s, loss=9.53, tox_loss=0.0511, lang_loss=9.48, labeled=1, step=5]


Validation - Macro F1: 0.2667


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.35it/s].69s/it, loss=15.9, tox_loss=4.05e-6, lang_loss=15.9, labeled=1, step=6] 


Validation - Macro F1: 0.3333


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.53it/s].04s/it, loss=13.3, tox_loss=2.3, lang_loss=11, labeled=1, step=7]      


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.52it/s].73s/it, loss=29.3, tox_loss=0.0256, lang_loss=29.2, labeled=1, step=8]   


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.01it/s].56s/it, loss=2.68, tox_loss=0.00276, lang_loss=2.68, labeled=1, step=9]


Validation - Macro F1: 0.6667


Epoch 2/5 (λ=1.000): 100%|██████████| 10/10 [00:32<00:00,  3.26s/it, loss=35, tox_loss=3.81, lang_loss=31.2, labeled=1, step=10]   



Training - Loss: 14.1738, Toxicity Loss: 0.7575, Language Loss: 13.4163, Labeled Samples: 10

Epoch 3/5


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  8.37it/s].62it/s, loss=31.8, tox_loss=0.121, lang_loss=31.7, labeled=1, step=10]


Validation - Macro F1: 0.3333


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.12it/s].86s/it, loss=0.0833, tox_loss=0.0322, lang_loss=0.0511, labeled=1, step=11]


Validation - Macro F1: 0.6667


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 11.19it/s].35s/it, loss=41, tox_loss=2.68e-5, lang_loss=41, labeled=1, step=12]       


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.96it/s].19s/it, loss=37.1, tox_loss=0.132, lang_loss=37, labeled=1, step=13] 


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.92it/s].12s/it, loss=8.75, tox_loss=0.00343, lang_loss=8.74, labeled=1, step=14] 


Validation - Macro F1: 0.6000


Epoch 3/5 (λ=1.000): 100%|██████████| 10/10 [00:26<00:00,  2.67s/it, loss=41.4, tox_loss=0.00944, lang_loss=41.4, labeled=1, step=15]



Training - Loss: 29.9610, Toxicity Loss: 1.0344, Language Loss: 28.9266, Labeled Samples: 10

Epoch 4/5


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.76it/s].71it/s, loss=35.7, tox_loss=4.94e-5, lang_loss=35.7, labeled=1, step=15]


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 11.08it/s].78s/it, loss=18.3, tox_loss=0.0149, lang_loss=18.3, labeled=1, step=16]  


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.79it/s].87s/it, loss=45.4, tox_loss=0.00352, lang_loss=45.3, labeled=1, step=17]


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.40it/s].95s/it, loss=48, tox_loss=0.0285, lang_loss=48, labeled=1, step=18]     


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 10.30it/s].06s/it, loss=5.69, tox_loss=0.0293, lang_loss=5.66, labeled=1, step=19] 


Validation - Macro F1: 0.6667


Epoch 4/5 (λ=1.000): 100%|██████████| 10/10 [00:24<00:00,  2.46s/it, loss=52.2, tox_loss=6.06, lang_loss=46.2, labeled=1, step=20] 



Training - Loss: 34.7347, Toxicity Loss: 0.6588, Language Loss: 34.0759, Labeled Samples: 10

Epoch 5/5


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 12.87it/s].15it/s, loss=54, tox_loss=0.131, lang_loss=53.9, labeled=1, step=20]


Validation - Macro F1: 0.6667


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 11.10it/s].64s/it, loss=23.8, tox_loss=6.68e-6, lang_loss=23.8, labeled=1, step=21]


Validation - Macro F1: 0.6667


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 11.95it/s].93s/it, loss=37.4, tox_loss=2.26e-5, lang_loss=37.4, labeled=1, step=22]


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 11.31it/s].97s/it, loss=51, tox_loss=0.000153, lang_loss=51, labeled=1, step=23]   


Validation - Macro F1: 0.6000


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.29it/s].04s/it, loss=48.9, tox_loss=0.0012, lang_loss=48.9, labeled=1, step=24] 


Validation - Macro F1: 0.6000


Epoch 5/5 (λ=1.000): 100%|██████████| 10/10 [00:24<00:00,  2.45s/it, loss=55.7, tox_loss=0.432, lang_loss=55.2, labeled=1, step=25]



Training - Loss: 39.0322, Toxicity Loss: 0.2150, Language Loss: 38.8173, Labeled Samples: 10


Evaluating: 100%|██████████| 10/10 [00:01<00:00,  9.58it/s]



Validation Metrics:
  Loss: 44.0141
  Toxicity Loss: 0.2113
  Language Loss: 43.8028
  Toxicity Macro F1: 0.6000
  Labeled Samples: 10

Training completed!
