In [None]:

import logging
from pathlib import Path

import lightning as L
import numpy as np
import torch as th
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection
from torchmetrics.classification.f_beta import F1Score

from mamkit.configs.base import ConfigKey
from mamkit.configs.text_audio import MMTransformerConfig
from mamkit.data.collators import MultimodalCollator, TextTransformerCollator, AudioCollatorOutput
from mamkit.data.datasets import MMUSEDFallacy, InputMode
from mamkit.data.processing import MultimodalProcessor, AudioTransformer
from mamkit.models.text_audio import MMTransformer
from mamkit.utility.callbacks import PycharmProgressBar

# Emit INFO-level records so the per-split scores and the final summary
# reported through logging.info() further down are actually displayed.
logging.basicConfig(level=logging.INFO)


In [None]:

# Resolve the project root once and derive the results/data folders from it.
# ``__file__`` exists only when this code runs as a script; a notebook kernel
# does not define it, so fall back to the parent of the working directory.
# NOTE(review): the fallback assumes the kernel was started in the notebook's
# own folder (one level below the project root) — confirm for your setup.
try:
    project_root = Path(__file__).parent.parent.resolve()
except NameError:
    project_root = Path.cwd().parent.resolve()

# Output folder for this experiment; created eagerly so the final np.save works.
save_path = project_root.joinpath('results', 'mmused-fallacy', 'afc', 'text_audio_late_fusion_logreg')
save_path.mkdir(parents=True, exist_ok=True)
base_data_path = project_root.joinpath('data')

# Dataset handle for the 'afc' task in text+audio mode, rooted at base_data_path.
loader = MMUSEDFallacy(task_name='afc',
                       input_mode=InputMode.TEXT_AUDIO,
                       base_data_path=base_data_path)


In [None]:

# Fetch the MMTransformer configuration registered under this dataset / input
# mode / task / tag combination.
config_key = ConfigKey(
    dataset='mmused-fallacy',
    input_mode=InputMode.TEXT_AUDIO,
    task_name='afc',
    tags={'anonymous', 'roberta', 'wav2vec2'},
)
config = MMTransformerConfig.from_config(key=config_key)

# Override the audio backbone and set the same sampling rate (16000) in both
# the processor arguments and the audio-model arguments.
config.audio_model_card = "facebook/wav2vec2-base-960h"
for audio_args in (config.processor_args, config.audio_model_args):
    audio_args['sampling_rate'] = 16000

# Trainer-style settings. The cells visible here do not read this dict
# (the DataLoaders take their batch size from ``config.batch_size``).
trainer_args = dict(
    accelerator='auto',
    devices=1,
    batch_size=16,
    max_epochs=5,
)

In [None]:

def extract_features_and_labels(dataloader, model):
    """Encode every batch with ``model`` and return fused features plus labels.

    The model is switched to eval mode and each batch is encoded with gradient
    tracking disabled. The text and audio encodings of a batch are concatenated
    along dimension 1, giving one fused feature row per sample.

    Args:
        dataloader: Iterable of batches. Every batch is indexed with the keys
            ``'text_input_ids'``, ``'text_attention_mask'``,
            ``'audio_features'`` and ``'labels'``.
        model: Object exposing ``eval()`` and a ``model`` attribute that
            provides ``encode_text`` and ``encode_audio``.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays: per-batch fused
        features stacked with ``np.vstack`` and per-batch labels joined with
        ``np.concatenate``.
    """
    model.eval()
    feature_chunks, label_chunks = [], []
    for batch in dataloader:
        with th.no_grad():
            text_emb = model.model.encode_text(
                batch['text_input_ids'],
                batch['text_attention_mask'],
            )
            audio_emb = model.model.encode_audio(batch['audio_features'])
            joint = th.cat((text_emb, audio_emb), dim=1)
            # Move to host memory before converting so NumPy can hold the data.
            feature_chunks.append(joint.cpu().numpy())
            label_chunks.append(batch['labels'].cpu().numpy())
    stacked_features = np.vstack(feature_chunks)
    stacked_labels = np.concatenate(label_chunks)
    return stacked_features, stacked_labels


In [None]:

# Late-fusion experiment: for every seed and every 'mm-argfallacy-2025' split,
# extract concatenated text+audio features with an MMTransformer that is never
# trained in this notebook, fit a logistic-regression classifier on the
# training features, and record macro-F1 on the validation and test sets.
metrics = {}
for seed in config.seeds:
    seed_everything(seed=seed)
    for split_info in loader.get_splits(key='mm-argfallacy-2025'):
        # Build the multimodal processor; its audio side is an AudioTransformer
        # configured from ``config``.
        processor = MultimodalProcessor(audio_processor=AudioTransformer(
            model_card=config.audio_model_card,
            processor_args=config.processor_args,
            model_args=config.audio_model_args,
            aggregate=config.aggregate,
            downsampling_factor=config.downsampling_factor,
            sampling_rate=config.sampling_rate,
            max_duration=5  # presumably seconds — TODO confirm
        ))
        # Fit on the training partition only, then apply to all three partitions.
        processor.fit(train_data=split_info.train)
        split_info.train = processor(split_info.train)
        split_info.val = processor(split_info.val)
        split_info.test = processor(split_info.test)
        processor.clear()

        collator = MultimodalCollator(
            text_collator=TextTransformerCollator(model_card=config.text_model_card,
                                                  tokenizer_args=config.tokenizer_args),
            audio_collator=AudioCollatorOutput(),
            label_collator=lambda labels: th.tensor(labels)
        )

        # No shuffling: the loaders are only consumed by a single
        # feature-extraction pass each.
        train_dataloader = DataLoader(split_info.train, batch_size=config.batch_size, shuffle=False, collate_fn=collator)
        val_dataloader = DataLoader(split_info.val, batch_size=config.batch_size, shuffle=False, collate_fn=collator)
        test_dataloader = DataLoader(split_info.test, batch_size=config.batch_size, shuffle=False, collate_fn=collator)

        # NOTE(review): extract_features_and_labels indexes each batch by string
        # key and calls ``model.model.encode_text`` / ``encode_audio``; confirm
        # that MultimodalCollator's output and MMTransformer expose exactly this
        # interface.
        model = MMTransformer(
            model_card=config.text_model_card,
            head=config.head,
            is_transformer_trainable=config.is_transformer_trainable,
            lstm_weights=config.lstm_weights,
            audio_embedding_dim=config.audio_embedding_dim
        )
        model.eval()

        X_train, y_train = extract_features_and_labels(train_dataloader, model)
        X_val, y_val = extract_features_and_labels(val_dataloader, model)
        X_test, y_test = extract_features_and_labels(test_dataloader, model)

        # ``multi_class`` is deprecated since scikit-learn 1.5 and scheduled for
        # removal. With the default lbfgs solver a multi-class target is already
        # fitted as a multinomial model, so the argument is simply dropped.
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X_train, y_train)

        y_val_pred = clf.predict(X_val)
        y_test_pred = clf.predict(X_test)

        val_f1 = f1_score(y_val, y_val_pred, average='macro')
        test_f1 = f1_score(y_test, y_test_pred, average='macro')
        logging.info(f'Seed {seed} | Validation F1: {val_f1:.4f} | Test F1: {test_f1:.4f}')

        # One entry per (seed, split) pair; aggregated in the summary cell.
        metrics.setdefault('validation', []).append(val_f1)
        metrics.setdefault('test', []).append(test_f1)

        processor.reset()


In [None]:

# Aggregate the per-run macro-F1 scores into mean / standard deviation.
val_scores = np.asarray(metrics['validation'])
test_scores = np.asarray(metrics['test'])

val_avg, val_std = val_scores.mean(), val_scores.std()
test_avg, test_std = test_scores.mean(), test_scores.std()

summary = dict(
    validation_avg=val_avg,
    validation_std=val_std,
    test_avg=test_avg,
    test_std=test_std,
)

logging.info(summary)
# The summary dict is handed to np.save as-is and written to metrics.npy
# inside the experiment's results folder.
metrics_file = save_path.joinpath('metrics.npy')
np.save(metrics_file.as_posix(), summary)
