In [1]:
from datasets import load_dataset, concatenate_datasets
from toxicity_detection.trainer import Trainer

# Experiment configuration: the languages to balance across and the size of
# the train / validation subsets drawn from the full datasets below.
languages = ["en", "fi", "de"]
n_train_samples = 10
n_dev_samples = 4

# `load_dataset` exposes a single `data_files` entry under its default
# "train" split, which is why the dev file is also read with split="train".
train_dataset = load_dataset(
    "parquet",
    data_files="../../data/train_combined.parquet",
    split="train",
)
dev_dataset = load_dataset(
    "parquet",
    data_files="../../data/dev.parquet",
    split="train",
)

def stratified_sample_equal_per_lang(dataset, languages, total_samples, seed=0):
    """Sample equally from each language in `languages`."""
    num_langs = len(languages)
    base_n = total_samples // num_langs
    remainder = total_samples % num_langs
    per_lang_counts = {lang: base_n for lang in languages}
    for lang in languages[:remainder]:
        per_lang_counts[lang] += 1

    per_lang_datasets = []
    for lang in languages:
        lang_subset = dataset.filter(lambda ex, l=lang: ex["lang"] == l)
        lang_subset = lang_subset.shuffle(seed=seed)
        n = per_lang_counts[lang]
        if n > len(lang_subset):
            raise ValueError(f"Requested {n} samples for language '{lang}', "
                            f"but only {len(lang_subset)} available.")
        per_lang_datasets.append(lang_subset.select(range(n)))

    combined = concatenate_datasets(per_lang_datasets).shuffle(seed=seed)
    return combined

# Draw small, language-balanced subsets. Fixed seeds make the draws
# reproducible across runs.
train_samples = stratified_sample_equal_per_lang(
    train_dataset, languages, total_samples=n_train_samples, seed=0
)
dev_samples = stratified_sample_equal_per_lang(
    dev_dataset, languages, total_samples=n_dev_samples, seed=1
)

# Very small run: batch size 1, a single epoch, and eval_steps=1 (presumably
# an evaluation after every optimizer step). `languages` is reused rather
# than re-declared so the sampler and the trainer cannot drift apart.
trainer = Trainer(
    model_name="google/gemma-3-270m",
    train_dataset=train_samples,
    val_dataset=dev_samples,
    languages=languages,
    batch_size=1,
    learning_rate=1e-4,
    warmup_steps=1,
    accumulation_steps=2,
    eval_steps=1,
    num_epochs=1,
    max_num_checkpoints=1,
)
model, tokenizer = trainer.train_model()


  from .autonotebook import tqdm as notebook_tqdm


Training on device: cpu


Loading weights: 100%|██████████| 236/236 [00:00<00:00, 507.20it/s, Materializing param=norm.weight]                                



Epoch 1/1


Evaluating: 100%|██████████| 4/4 [00:00<00:00, 12.04it/s] 1.86it/s, loss=3.02, tox_loss=0.612, lang_loss=2.41, labeled=1, step=0]


Validation - Macro F1: 0.3333


Evaluating: 100%|██████████| 4/4 [00:00<00:00,  8.93it/s] 3.11s/it, loss=2, tox_loss=1.36, lang_loss=0.644, labeled=1, step=1]   


Validation - Macro F1: 0.6667


Evaluating: 100%|██████████| 4/4 [00:00<00:00,  9.92it/s] 3.38s/it, loss=6.29, tox_loss=0.531, lang_loss=5.76, labeled=1, step=2]


Validation - Macro F1: 0.3333


Evaluating: 100%|██████████| 4/4 [00:00<00:00,  8.00it/s] 2.78s/it, loss=13.6, tox_loss=0, lang_loss=13.6, labeled=0, step=3]    


Validation - Macro F1: 0.0000


Evaluating: 100%|██████████| 4/4 [00:00<00:00,  8.33it/s] 2.44s/it, loss=1.83, tox_loss=0, lang_loss=1.83, labeled=0, step=4]


Validation - Macro F1: 0.0000


Epoch 1/1 (λ=1.000): 100%|██████████| 10/10 [00:33<00:00,  3.39s/it, loss=15, tox_loss=0.0737, lang_loss=15, labeled=1, step=5]



Training - Loss: 4.8503, Toxicity Loss: 0.2576, Language Loss: 4.5927, Labeled Samples: 4


Evaluating: 100%|██████████| 4/4 [00:00<00:00,  9.57it/s]



Validation Metrics:
  Loss: 22.3908
  Toxicity Loss: 3.7195
  Language Loss: 18.6714
  Toxicity Macro F1: 0.0000
  Labeled Samples: 4

Training completed!
