In [None]:
import logging
from pathlib import Path

import lightning as L
import numpy as np
import torch as th
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader
from torchmetrics.classification.f_beta import F1Score
from torchmetrics import MetricCollection

from mamkit.configs.base import ConfigKey
from mamkit.configs.text import TransformerConfig
from mamkit.data.collators import UnimodalCollator, TextTransformerCollator
from mamkit.data.datasets import MMUSEDFallacy, InputMode
from mamkit.data.processing import UnimodalProcessor
from mamkit.models.text import Transformer
from mamkit.utility.callbacks import PycharmProgressBar
from mamkit.utility.model import MAMKitLightingModel

In [None]:
# Setup
logging.basicConfig(level=logging.INFO)

# Output directory for this experiment's artefacts (the averaged metrics are written here).
save_path = Path('results/mmused-fallacy/afc/text_only_bert')
save_path.mkdir(exist_ok=True, parents=True)

# Root data directory handed to the dataset loader.
base_data_path = Path('data')

# Load dataset: 'afc' task, text-only inputs.
loader = MMUSEDFallacy(
    base_data_path=base_data_path,
    input_mode=InputMode.TEXT_ONLY,
    task_name='afc',
)

# Load config: hyper-parameters registered for (mmused-fallacy, afc, text-only, {anonymous, bert}).
config = TransformerConfig.from_config(
    key=ConfigKey(
        dataset='mmused-fallacy',
        task_name='afc',
        input_mode=InputMode.TEXT_ONLY,
        tags={'anonymous', 'bert'},
    )
)

# Training args, unpacked into lightning.Trainer by the training cell.
# Gradients are accumulated over 8 batches before each optimizer step.
trainer_args = dict(
    accelerator='auto',
    devices=1,
    accumulate_grad_batches=8,
    max_epochs=3,
)

# Results container, filled as metrics[<split>][<metric name>] -> list of values.
metrics = dict()

In [None]:
# Training and Evaluation Loop
#
# For every seed and every data split returned by the loader, this cell:
#   1. fits a UnimodalProcessor on the training portion and applies it to train/val/test,
#   2. builds the train/val/test DataLoaders with a text-transformer collator,
#   3. trains a fresh Transformer classifier, early-stopping and checkpointing on val_loss,
#   4. evaluates the best checkpoint on the validation and on the test data,
#   5. appends every returned metric to `metrics[<split>][<metric name>]`.
#
# NOTE(review): results are *appended* to `metrics`, which is created in the setup cell;
# re-running only this cell (without the setup cell) accumulates duplicate entries.


def build_f1_metrics(num_classes: int) -> MetricCollection:
    """Return a fresh metric collection holding a multiclass F1 score under the key 'f1'.

    A new instance is built for each evaluation stage because torchmetrics metrics keep
    internal state across updates.

    Args:
        num_classes: number of target classes; must match the classifier head.
    """
    # NOTE(review): torchmetrics' multiclass F1Score uses average='micro' by default;
    # pass average='macro' if the task is evaluated with macro-F1.
    return MetricCollection({'f1': F1Score(task='multiclass', num_classes=num_classes)})


for seed in config.seeds:
    seed_everything(seed=seed)

    for split_info in loader.get_splits(key='mm-argfallacy-2025'):
        # Fit the processor on the training data only, then apply it to every portion.
        processor = UnimodalProcessor()
        processor.fit(split_info.train)
        split_info.train = processor(split_info.train)
        split_info.val = processor(split_info.val)
        split_info.test = processor(split_info.test)
        processor.clear()

        # Text features go through the collator configured with the model card / tokenizer args;
        # labels are stacked into a tensor.
        collator = UnimodalCollator(
            features_collator=TextTransformerCollator(model_card=config.model_card, tokenizer_args=config.tokenizer_args),
            label_collator=lambda labels: th.tensor(labels)
        )

        train_loader = DataLoader(split_info.train, batch_size=config.batch_size, shuffle=True, collate_fn=collator)
        val_loader = DataLoader(split_info.val, batch_size=config.batch_size, shuffle=False, collate_fn=collator)
        test_loader = DataLoader(split_info.test, batch_size=config.batch_size, shuffle=False, collate_fn=collator)

        base_model = Transformer(
            model_card=config.model_card,
            is_transformer_trainable=config.is_transformer_trainable,
            dropout_rate=config.dropout_rate,
            head=config.head
        )

        # Metrics use the same class count as the model instead of a hard-coded value.
        model = MAMKitLightingModel(
            model=base_model,
            loss_function=config.loss_function,
            num_classes=config.num_classes,
            optimizer_class=config.optimizer,
            val_metrics=build_f1_metrics(config.num_classes),
            test_metrics=build_f1_metrics(config.num_classes),
            **config.optimizer_args
        )

        # NOTE(review): with max_epochs=3 (setup cell) a patience of 5 can never trigger early stopping.
        trainer = L.Trainer(
            **trainer_args,
            callbacks=[
                EarlyStopping(monitor='val_loss', mode='min', patience=5),
                ModelCheckpoint(monitor='val_loss', mode='min'),
                PycharmProgressBar()
            ]
        )

        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # Both evaluations reload the checkpoint with the lowest val_loss.
        val_metrics = trainer.test(ckpt_path='best', dataloaders=val_loader)[0]
        test_metrics = trainer.test(ckpt_path='best', dataloaders=test_loader)[0]

        logging.info(f'Validation metrics: {val_metrics}')
        logging.info(f'Test metrics: {test_metrics}')

        for name, val in val_metrics.items():
            metrics.setdefault('validation', {}).setdefault(name, []).append(val)
        for name, val in test_metrics.items():
            metrics.setdefault('test', {}).setdefault(name, []).append(val)

        processor.reset()

In [None]:
# Averaging Metrics
#
# The training loop appends one value per (seed, data split) in seed-major order, so each raw
# list reshapes to (n_seeds, n_splits). For each raw metric <name> this adds:
#   per_seed_avg_<name>: (per-seed mean over data splits, per-seed std over data splits)
#   avg_<name>:          (mean of the per-seed means, std of the per-seed means)
for split in ['validation', 'test']:
    split_metrics = metrics[split]
    # Snapshot the raw entries before inserting summary keys: adding keys while iterating the
    # live dict raises "RuntimeError: dictionary changed size during iteration".
    # Only list-valued entries are raw results, so re-running this cell does not
    # re-summarise earlier summaries (which are tuples).
    raw_metrics = [(name, values) for name, values in split_metrics.items() if isinstance(values, list)]
    for metric_name, values in raw_metrics:
        values_np = np.array(values).reshape(len(config.seeds), -1)
        per_seed_avg = values_np.mean(axis=-1)
        per_seed_std = values_np.std(axis=-1)
        avg = per_seed_avg.mean()
        std = per_seed_avg.std()

        split_metrics[f'per_seed_avg_{metric_name}'] = (per_seed_avg, per_seed_std)
        split_metrics[f'avg_{metric_name}'] = (avg, std)

logging.info(metrics)
# np.save stores the dict as a pickled 0-d object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save(save_path / 'metrics.npy', metrics)