In [None]:
import logging
from pathlib import Path

import lightning as L
import numpy as np
import torch as th
import librosa
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torchmetrics.classification.f_beta import F1Score
from torchmetrics import MetricCollection

from mamkit.configs.audio import BiLSTMMFCCsConfig
from mamkit.configs.base import ConfigKey
from mamkit.data.collators import UnimodalCollator, AudioCollatorOutput
from mamkit.data.datasets import MMUSEDFallacy, InputMode
from mamkit.data.processing import UnimodalProcessor, MFCCExtractor
from mamkit.models.audio import BiLSTM
from mamkit.utility.callbacks import PycharmProgressBar
from mamkit.utility.model import MAMKitLightingModel


In [None]:
# Run-wide settings: INFO-level logging, the results directory (created up
# front so later cells can write into it), and the dataset root.
logging.basicConfig(level=logging.INFO)

base_data_path = Path('data')
save_path = Path('results') / 'mmused' / 'afc' / 'audio_only_lstm_mfcc'
save_path.mkdir(exist_ok=True, parents=True)


In [None]:
# Audio-only loader for the 'afc' task of the MM-USED fallacy dataset
# (NOTE(review): 'afc' presumably = argumentative fallacy classification).
loader = MMUSEDFallacy(
    task_name='afc',
    input_mode=InputMode.AUDIO_ONLY,
    base_data_path=base_data_path,
)

# Stored BiLSTM-on-MFCCs configuration for this dataset / input mode / task / tag.
config = BiLSTMMFCCsConfig.from_config(
    key=ConfigKey(
        dataset='mmused-fallacy',
        input_mode=InputMode.AUDIO_ONLY,
        task_name='afc',
        tags='anonymous',
    )
)

# Override whatever optimizer the stored config specifies: Adam with lr=1e-3.
config.optimizer, config.optimizer_args = th.optim.Adam, {'lr': 1e-3}


In [None]:
# Keyword arguments forwarded verbatim to lightning's Trainer constructor.
# Only genuine Trainer options may appear here: batch size is NOT one of them
# (passing 'batch_size' makes L.Trainer raise TypeError). Batch size is applied
# on the DataLoaders via config.batch_size in the training cell.
# NOTE(review): this dict previously carried 'batch_size': 8; if that value was
# intended, assign it to config.batch_size instead.
trainer_args = {
    'accelerator': 'auto',
    'devices': 1,
    'max_epochs': 5,
}


In [None]:
# Train and evaluate one BiLSTM per (seed, split) and collect the metrics of every
# run into metrics[<'validation'|'test'>][<metric name>] -> list of values.
# The seed loop must stay the OUTER loop: the averaging cell reshapes each list
# to (len(config.seeds), -1), i.e. it assumes values are appended seed-major.
metrics = {}
for seed in config.seeds:
    seed_everything(seed=seed)
    for split_info in loader.get_splits(key='mm-argfallacy-2025'):
        # MFCC features at 16 kHz with 13 coefficients and padding enabled;
        # max_length is expressed in samples.
        processor = UnimodalProcessor(features_processor=MFCCExtractor(
            sampling_rate=16000,
            normalize=True,
            remove_energy=True,
            pooling_sizes=config.pooling_sizes,
            mfccs=13,
            padding=True,
            max_length=5 * 16000  # Standardize to 5 seconds
        ))

        # Fit on the training split only, then transform every split with the
        # fitted processor so validation/test data never influence fitting.
        processor.fit(split_info.train)
        split_info.train = processor(split_info.train)
        split_info.val = processor(split_info.val)
        split_info.test = processor(split_info.test)
        # NOTE(review): presumably releases transient processing state — confirm in mamkit.
        processor.clear()

        unimodal_collator = UnimodalCollator(
            features_collator=AudioCollatorOutput(),
            label_collator=lambda labels: th.tensor(labels)
        )

        train_dataloader = DataLoader(split_info.train,
                                      batch_size=config.batch_size,
                                      shuffle=True,
                                      collate_fn=unimodal_collator)
        val_dataloader = DataLoader(split_info.val,
                                    batch_size=config.batch_size,
                                    shuffle=False,
                                    collate_fn=unimodal_collator)
        test_dataloader = DataLoader(split_info.test,
                                     batch_size=config.batch_size,
                                     shuffle=False,
                                     collate_fn=unimodal_collator)

        model = BiLSTM(embedding_dim=config.embedding_dim,
                       lstm_weights=config.lstm_weights,
                       head=config.head)
        # The F1 metrics take their class count from the same config value as the
        # model, so metric and model cannot disagree if the config changes.
        model = MAMKitLightingModel(model=model,
                                    loss_function=config.loss_function,
                                    num_classes=config.num_classes,
                                    optimizer_class=config.optimizer,
                                    val_metrics=MetricCollection({'f1': F1Score(task='multiclass',
                                                                                num_classes=config.num_classes)}),
                                    test_metrics=MetricCollection({'f1': F1Score(task='multiclass',
                                                                                 num_classes=config.num_classes)}),
                                    **config.optimizer_args)

        # Keep the checkpoint with the lowest logged 'val_loss'.
        trainer = L.Trainer(**trainer_args,
                            callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'),
                                       PycharmProgressBar()])
        trainer.fit(model,
                    train_dataloaders=train_dataloader,
                    val_dataloaders=val_dataloader)

        # Re-evaluate that best checkpoint on both the validation and the test split.
        val_metrics = trainer.test(ckpt_path='best', dataloaders=val_dataloader)[0]
        test_metrics = trainer.test(ckpt_path='best', dataloaders=test_dataloader)[0]
        logging.info(f'Validation metrics: {val_metrics}')
        logging.info(f'Test metrics: {test_metrics}')

        for metric_name, metric_value in val_metrics.items():
            metrics.setdefault('validation', {}).setdefault(metric_name, []).append(metric_value)
        for metric_name, metric_value in test_metrics.items():
            metrics.setdefault('test', {}).setdefault(metric_name, []).append(metric_value)

        # NOTE(review): reset() semantics are defined in mamkit; kept as in the original run.
        processor.reset()


In [None]:
# Averaging
# Values were appended seed-major (outer loop over config.seeds, inner loop over
# splits), so reshaping each list to (len(config.seeds), -1) puts one seed per row.
# Only raw metric names are aggregated: the 'avg_*' / 'per_seed_avg_*' keys this
# cell adds are skipped, so re-running the cell recomputes the same summaries
# instead of trying to aggregate its own (mean, std) tuples.
metric_names = [name for name in metrics['validation']
                if not name.startswith(('per_seed_avg_', 'avg_'))]
for split_name in ['validation', 'test']:
    for metric_name in metric_names:
        metric_values = np.array(metrics[split_name][metric_name]).reshape(len(config.seeds), -1)
        # Mean / std over the splits of each seed ...
        per_seed_avg = metric_values.mean(axis=-1)
        per_seed_std = metric_values.std(axis=-1)
        # ... then mean / std of those per-seed means across seeds.
        avg = per_seed_avg.mean(axis=-1)
        std = per_seed_avg.std(axis=-1)
        metrics[split_name][f'per_seed_avg_{metric_name}'] = (per_seed_avg, per_seed_std)
        metrics[split_name][f'avg_{metric_name}'] = (avg, std)

logging.info(metrics)
# The dict is stored as a pickled 0-d object array; reload it with
# np.load(path, allow_pickle=True).item().
np.save(save_path.joinpath('metrics.npy').as_posix(), metrics)
