# Digital Forensics and Biometrics Project
Scalco Riccardo (id: 2155352)\
Ferrari Luca (id: 2166294)

In [None]:
from pathlib import Path
import pandas as pd
import torchaudio
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader

import torchaudio.transforms as T
import torchaudio.functional as AF
import random

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import IPython.display as ipd

from tqdm import tqdm
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_curve, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_fscore_support, classification_report, confusion_matrix
import seaborn as sns
import torch.optim.lr_scheduler as lr_scheduler
from torch.amp import GradScaler, autocast
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [None]:
# Global configuration for the experiment.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed passed to every DataFrame.sample call, for reproducible subsets.
RANDOM_STATE = 42

# Rows drawn per subset (LA and PA separately), half spoof and half bonafide.
N_SAMPLE_TRAIN = 5000
N_SAMPLE_VAL = 4000

# Audio / feature settings.
SAMPLE_RATE = 16000
N_MELS = 80
N_LFCC = 60

# Training schedule: maximum epochs and early-stopping patience.
EPOCH = 50
PATIENCE = 5

# Probability of applying each individual augmentation.
PROBABILITY_AUGMENT = 0.3

In [None]:
# Root of the ASVspoof 2019 copy on Kaggle; each subset is nested twice (LA/LA, PA/PA).
DATA_ROOT = Path("/kaggle/input/asvpoof-2019-dataset")

_LA_ROOT = DATA_ROOT / "LA" / "LA"
_PA_ROOT = DATA_ROOT / "PA" / "PA"
_LA_PROTOCOLS = _LA_ROOT / "ASVspoof2019_LA_cm_protocols"
_PA_PROTOCOLS = _PA_ROOT / "ASVspoof2019_PA_cm_protocols"

# LA subset: audio folders and countermeasure protocol files (train / dev / eval).
AUDIO_DIR_LA = _LA_ROOT / "ASVspoof2019_LA_train" / "flac"
PROTO_FILE_LA = _LA_PROTOCOLS / "ASVspoof2019.LA.cm.train.trn.txt"

AUDIO_DIR_VAL_LA = _LA_ROOT / "ASVspoof2019_LA_dev" / "flac"
PROTO_FILE_VAL_LA = _LA_PROTOCOLS / "ASVspoof2019.LA.cm.dev.trl.txt"

AUDIO_DIR_TEST_LA = _LA_ROOT / "ASVspoof2019_LA_eval" / "flac"
PROTO_FILE_TEST_LA = _LA_PROTOCOLS / "ASVspoof2019.LA.cm.eval.trl.txt"

# PA subset: audio folders and countermeasure protocol files (train / dev / eval).
AUDIO_DIR_PA = _PA_ROOT / "ASVspoof2019_PA_train" / "flac"
PROTO_FILE_PA = _PA_PROTOCOLS / "ASVspoof2019.PA.cm.train.trn.txt"

AUDIO_DIR_VAL_PA = _PA_ROOT / "ASVspoof2019_PA_dev" / "flac"
PROTO_FILE_VAL_PA = _PA_PROTOCOLS / "ASVspoof2019.PA.cm.dev.trl.txt"

AUDIO_DIR_TEST_PA = _PA_ROOT / "ASVspoof2019_PA_eval" / "flac"
PROTO_FILE_TEST_PA = _PA_PROTOCOLS / "ASVspoof2019.PA.cm.eval.trl.txt"

In [None]:
def read_proto_file(proto_file, audio_dir, audio_ext, tag):
    """Parse an ASVspoof-style protocol file into a DataFrame.

    Each non-blank line is whitespace-separated. The first field is the
    speaker id, the second the audio file name (without extension), and the
    last field is the label, ``"spoof"`` or ``"bonafide"``. Any fields in
    between are ignored.

    Args:
        proto_file: path of the protocol text file.
        audio_dir: directory holding the audio files (str or Path).
        audio_ext: audio file extension without the dot, e.g. ``"flac"``.
        tag: subset tag stored in every row, e.g. ``"LA"`` or ``"PA"``.

    Returns:
        DataFrame with columns ``speaker_id``, ``audio_file_name``, ``label``
        (1 = spoof, 0 = bonafide), ``path`` (a Path) and ``tag``. The columns
        are present even when the file has no entries.

    Raises:
        ValueError: if a line has fewer than three fields or an unknown label.
    """
    columns = ["speaker_id", "audio_file_name", "label", "path", "tag"]
    label_map = {"spoof": 1, "bonafide": 0}
    audio_dir = Path(audio_dir)  # accept str as well as Path

    rows = []
    with open(proto_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # tolerate blank / trailing lines
            if len(parts) < 3:
                raise ValueError(
                    f"{proto_file}:{line_no}: expected at least 3 fields, got {len(parts)}"
                )

            speaker_id, audio_file_name, label = parts[0], parts[1], parts[-1]
            if label not in label_map:
                # Fail loudly instead of silently treating a typo as bonafide.
                raise ValueError(f"{proto_file}:{line_no}: unknown label {label!r}")

            rows.append({
                "speaker_id": speaker_id,
                "audio_file_name": audio_file_name,
                "label": label_map[label],
                "path": audio_dir / f"{audio_file_name}.{audio_ext}",
                "tag": tag,
            })
    return pd.DataFrame(rows, columns=columns)

In [None]:
def _balanced_split(frame, n_total):
    """Return ``(spoof_rows, bonafide_rows)``, each ``n_total // 2`` rows sampled from *frame*."""
    half = n_total // 2
    spoof_rows = frame[frame["label"] == 1].sample(n=half, random_state=RANDOM_STATE)
    bonafide_rows = frame[frame["label"] == 0].sample(n=half, random_state=RANDOM_STATE)
    return spoof_rows, bonafide_rows


# Training set: a balanced sample from each of LA and PA, concatenated.
la_train = read_proto_file(PROTO_FILE_LA, AUDIO_DIR_LA, "flac", "LA")
la_t_1, la_t_0 = _balanced_split(la_train, N_SAMPLE_TRAIN)
la_train_reduced = pd.concat([la_t_1, la_t_0], ignore_index=True)

pa_train = read_proto_file(PROTO_FILE_PA, AUDIO_DIR_PA, "flac", "PA")
pa_t_1, pa_t_0 = _balanced_split(pa_train, N_SAMPLE_TRAIN)
pa_train_reduced = pd.concat([pa_t_1, pa_t_0], ignore_index=True)

train_df = pd.concat([la_train_reduced, pa_train_reduced], ignore_index=True)

# Validation set: same construction on the dev protocols.
la_validation = read_proto_file(PROTO_FILE_VAL_LA, AUDIO_DIR_VAL_LA, "flac", "LA")
la_v_1, la_v_0 = _balanced_split(la_validation, N_SAMPLE_VAL)
la_val_reduced = pd.concat([la_v_1, la_v_0], ignore_index=True)

pa_validation = read_proto_file(PROTO_FILE_VAL_PA, AUDIO_DIR_VAL_PA, "flac", "PA")
pa_v_1, pa_v_0 = _balanced_split(pa_validation, N_SAMPLE_VAL)
pa_val_reduced = pd.concat([pa_v_1, pa_v_0], ignore_index=True)

validation_df = pd.concat([la_val_reduced, pa_val_reduced], ignore_index=True)

In [None]:
# Show full cell contents (e.g. long file paths) instead of truncating them.
pd.set_option('display.max_colwidth', None)
# The notebook renders this last expression: the first rows of the training table.
train_df.head()

In [None]:
# Same display option as the previous cell, repeated so this cell also works on its own.
pd.set_option('display.max_colwidth', None)
# The notebook renders this last expression: the last rows of the training table.
train_df.tail()

In [None]:
# Class balance of the training table (label 1 = spoof, 0 = bonafide).
train_df['label'].value_counts()

In [None]:
# Class balance of the validation table (label 1 = spoof, 0 = bonafide).
validation_df['label'].value_counts()

## Exploring possible augmentations

The following cells try out the transformations we can apply to our samples to make training more robust: additive noise, time-stretch, pitch shift, time masking and frequency masking.

In [None]:
def plot_waveform_and_spectrogram(waveform, sr, title="Audio Analysis"):
    """Draw the first channel of *waveform* as a waveform and a spectrogram side by side.

    Args:
        waveform: tensor indexed by channel first; only ``waveform[0]`` is drawn.
        sr: sampling rate in Hz, used for the spectrogram's time/frequency axes.
        title: figure super-title.
    """
    signal = waveform[0].numpy()
    fig, (ax_wave, ax_spec) = plt.subplots(1, 2, figsize=(14, 4))

    # Left panel: raw amplitude against sample index.
    ax_wave.plot(signal)
    ax_wave.set_title("Waveform")
    ax_wave.set_xlabel("Samples")
    ax_wave.set_ylabel("Amplitude")

    # Right panel: specgram returns (Pxx, freqs, bins, image); only the image is needed.
    *_, image = ax_spec.specgram(signal, Fs=sr)
    ax_spec.set_title("Spectrogram")
    ax_spec.set_xlabel("Time (s)")
    ax_spec.set_ylabel("Frequency (Hz)")

    fig.colorbar(image, ax=ax_spec, label="dB")

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

In [None]:
# Load the first training file as a running example for the augmentation demos.
sample_path = train_df.loc[0, 'path']
tag = train_df.loc[0, 'tag']
waveform, sr = torchaudio.load(sample_path)  # waveform is (channels, samples)
print("Sampling rate:", sr)
print("Waveform shape:", waveform.shape)
plot_waveform_and_spectrogram(waveform, sr, "Original Audio")

### Noise

In [None]:
# Additive Gaussian noise at a random SNR between 15 and 30 dB.
wav = waveform
noise = torch.randn_like(wav)
snr_db = random.uniform(15, 30)
snr = 10 ** (snr_db / 20)  # amplitude ratio corresponding to snr_db
# Scale the noise so that std(signal) / std(noise) is approximately snr;
# the result must be assigned, otherwise unit-variance noise is added.
noise = noise * wav.std() / (snr * noise.std() + 1e-8)
wav = wav + noise

plot_waveform_and_spectrogram(waveform, sr, "Original Audio")
plot_waveform_and_spectrogram(wav, sr, "Noisy Audio")

print("Original:")
ipd.display(ipd.Audio(waveform.numpy(), rate=sr))

print("Noisy:")
ipd.display(ipd.Audio(wav.numpy(), rate=sr))

### Time-Stretch

In [None]:
# Speed perturbation: linearly resample the waveform to int(T / rate) samples.
# NOTE(review): played back at the same sr, this changes pitch as well as
# duration, so it is not a pitch-preserving time-stretch.
wav = waveform
rate = random.uniform(0.9, 1.1)
length = int(wav.shape[1] / rate)
# interpolate works on (batch, channels, time), hence the unsqueeze/squeeze pair.
wav = torch.nn.functional.interpolate(
    wav.unsqueeze(0), size=length, mode="linear", align_corners=False
).squeeze(0)

plot_waveform_and_spectrogram(waveform, sr, "Original Audio")
plot_waveform_and_spectrogram(wav, sr, "Time-Stretch Audio")

print("Original:")
ipd.display(ipd.Audio(waveform.numpy(), rate=sr))

print("Time-Stretch:")
ipd.display(ipd.Audio(wav.numpy(), rate=sr))

### Pitch Shift

In [None]:
# Frequency shift: rotate the rFFT bins by +/- 1/12 of the bin count, then invert.
# NOTE(review): this moves every component by the same number of Hz (about sr/24
# per step), and torch.roll wraps bins around the band edges, so it is not a
# semitone pitch shift (which would scale frequencies multiplicatively).
wav = waveform
n_steps = random.choice([-1, 1])  # shift direction
fft = torch.fft.rfft(wav)
shift = int(n_steps * fft.shape[-1] / 12)  
fft = torch.roll(fft, shifts=shift, dims=-1)
wav = torch.fft.irfft(fft, n=wav.shape[-1])  # n= keeps the original length

plot_waveform_and_spectrogram(waveform, sr, "Original Audio")
plot_waveform_and_spectrogram(wav, sr, "Pitch Shift Audio")

print("Original:")
ipd.display(ipd.Audio(waveform.numpy(), rate=sr))

print("Pitch Shift:")
ipd.display(ipd.Audio(wav.numpy(), rate=sr))

### Time Mask

In [None]:
# Time masking: zero one random contiguous span of samples.
time_mask_param = int(0.1 * waveform.shape[1])  # maximum span: 10% of the clip
t0 = random.randint(0, waveform.shape[1] - time_mask_param)  # span start
t_len = random.randint(50, time_mask_param)  # span length; ValueError for clips under 500 samples

wav = waveform.clone()  # clone so the original waveform is left untouched
wav[:, t0:t0+t_len] = 0.0  

plot_waveform_and_spectrogram(waveform, sr, "Original Audio")
plot_waveform_and_spectrogram(wav, sr, "Audio with Random Time-Mask")

print("Original:")
ipd.display(ipd.Audio(waveform.numpy(), rate=sr))

print("Random Time-Mask:")
ipd.display(ipd.Audio(wav.numpy(), rate=sr))

### Frequency Mask

In [None]:
# Frequency masking demo: zero a band of mel bins, then reconstruct audio
# (InverseMelScale + Griffin-Lim) so the effect can be heard.
wav = waveform[0].unsqueeze(0)  # channel 0 only, shape (1, samples)
n_mels = 64
n_fft = 1024
hop_length = 256

mel_spectrogram = T.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=n_fft,
    hop_length=hop_length,
    n_mels=n_mels
)
mel_spec = mel_spectrogram(wav)

freq_mask_param = int(n_mels * 0.2)  # band width: 20% of the mel bins
f_start = random.randint(0, n_mels - freq_mask_param)
mel_spec[:, f_start:f_start+freq_mask_param, :] = 0  # in place: mel_spec stays masked afterwards
# NOTE(review): InverseMelScale is built without sample_rate, so it uses
# torchaudio's default; confirm that matches SAMPLE_RATE.
inverse_mel = T.InverseMelScale(n_stft=n_fft//2 + 1, n_mels=n_mels)
linear_spec = inverse_mel(mel_spec)
# Griffin-Lim estimates the missing phase, so the reconstruction is approximate.
griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length)
wav = griffin_lim(linear_spec)

plot_waveform_and_spectrogram(waveform, sr, "Original Audio")
plot_waveform_and_spectrogram(wav, sr, "Audio with Random Frequency-Mask")

print("Original:")
ipd.display(ipd.Audio(waveform.numpy(), rate=sr))

print("Random Frequency-Mask:")
ipd.display(ipd.Audio(wav.numpy(), rate=sr))

In this last case it is useful to also plot the mel spectrogram before the inverse transformation, to see clearly which frequency band was removed.

In [None]:
# Plot the masked mel spectrogram from the previous cell (it was zeroed in place).
# torchaudio's MelSpectrogram returns power by default, so 10*log10 gives real dB,
# matching the colorbar label. The small floor keeps the all-zero masked band
# finite instead of -inf.
mel_db = 10 * torch.log10(mel_spec[0].clamp(min=1e-10))
plt.figure(figsize=(10, 4))
plt.imshow(mel_db.numpy(), aspect='auto', origin='lower', cmap='viridis')
plt.colorbar(format='%+2.0f dB')
plt.title("MelSpectrogram con Frequency Mask")
plt.xlabel("Time Frames")
plt.ylabel("Mel Bands")
plt.tight_layout()
plt.show()

## Define the Augment class

In [None]:
class MelFeatureExtractor:
    """Callable that turns a waveform tensor into a dB-scaled mel spectrogram.

    With the defaults (16 kHz, 512-point FFT, 160-sample hop, 400-sample
    window, 80 mel bands) frames are 25 ms long with a 10 ms hop. The input
    is assumed to already be at ``target_sr``; no resampling happens here.
    """

    def __init__(self,
                 target_sr=16000,
                 n_fft=512,
                 hop_length=160,
                 win_length=400,
                 n_mels=80):
        # Keep the configuration available for inspection by callers.
        self.target_sr, self.n_fft = target_sr, n_fft
        self.hop_length, self.win_length = hop_length, win_length
        self.n_mels = n_mels

        self.mel_transform = T.MelSpectrogram(
            sample_rate=target_sr,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mels,
        )
        self.db_transform = T.AmplitudeToDB()

    def __call__(self, wav: torch.Tensor):
        """Return the mel spectrogram of *wav* converted to decibels."""
        return self.db_transform(self.mel_transform(wav))

In [None]:
class Augment(torch.nn.Module):
    """Randomly augment a waveform and return its mel spectrogram.

    Each augmentation is applied independently with probability
    ``probability``. When ``enable`` is False the waveform is only converted
    by ``MelFeatureExtractor``.

    Waveform-domain augmentations, in order:
      * additive Gaussian noise at 15-30 dB SNR (only when ``tag == "LA"``);
      * speed perturbation by linear resampling (rate 0.9-1.1);
      * shift of the rFFT bins by +/- 1/12 of the bin count (torch.roll, so
        bins wrap around the band edges);
      * time mask: one contiguous span of samples set to zero.
    Spectrogram-domain augmentation:
      * frequency mask: a band of 20% of the mel bins set to zero.
    """

    def __init__(self, enable=True, probability=0.2):
        super().__init__()
        self.enable = enable
        self.probability = probability
        self.mel_representation = MelFeatureExtractor()

    def forward(self, wav: torch.Tensor, tag: str = "LA"):
        """Augment *wav* (time on the last axis, e.g. ``(channels, samples)``) and return its mel spectrogram.

        Args:
            wav: waveform tensor as returned by ``torchaudio.load``.
            tag: subset tag; additive noise is only used for ``"LA"``.
        """
        if not self.enable:
            return self.mel_representation(wav)

        # Short-circuit keeps the RNG untouched for non-LA samples.
        if tag == "LA" and random.random() < self.probability:
            noise = torch.randn_like(wav)
            snr = random.uniform(15, 30)
            wav = wav + self._scale_to_snr(wav, noise, snr)

        if random.random() < self.probability:
            rate = random.uniform(0.9, 1.1)
            length = int(wav.shape[-1] / rate)
            # interpolate expects (batch, channels, time).
            wav = torch.nn.functional.interpolate(
                wav.unsqueeze(0), size=length, mode="linear", align_corners=False
            ).squeeze(0)

        if random.random() < self.probability:
            n_steps = random.choice([-1, 1])
            fft = torch.fft.rfft(wav)
            shift = int(n_steps * fft.shape[-1] / 12)
            fft = torch.roll(fft, shifts=shift, dims=-1)
            wav = torch.fft.irfft(fft, n=wav.shape[-1])

        if random.random() < self.probability:
            # Mask the current (already augmented) waveform, not some other tensor.
            wav = self._time_mask(wav)

        mel_rep = self.mel_representation(wav)

        if random.random() < self.probability:
            # Band width is derived from the actual number of mel bins.
            mel_rep = self._freq_mask(mel_rep)

        return mel_rep

    @staticmethod
    def _time_mask(wav, max_fraction=0.1, min_len=50):
        """Return a copy of *wav* with one random span along the last axis set to zero.

        The span length is drawn from ``[min(min_len, max_len), max_len]``
        with ``max_len = int(max_fraction * n_samples)``. Clips too short
        for a non-empty mask are returned unchanged.
        """
        n_samples = wav.shape[-1]
        max_len = int(max_fraction * n_samples)
        if max_len <= 0:
            return wav
        t_len = random.randint(min(min_len, max_len), max_len)
        t0 = random.randint(0, n_samples - t_len)
        masked = wav.clone()
        masked[..., t0:t0 + t_len] = 0.0
        return masked

    @staticmethod
    def _freq_mask(mel, max_fraction=0.2):
        """Zero a random band of ``int(max_fraction * n_bins)`` rows on the second-to-last axis.

        *mel* is modified in place and returned.
        """
        n_bins = mel.shape[-2]
        width = int(n_bins * max_fraction)
        if width <= 0:
            return mel
        f_start = random.randint(0, n_bins - width)
        mel[..., f_start:f_start + width, :] = 0
        return mel

    @staticmethod
    def _scale_to_snr(clean, noise, snr_db):
        """Scale *noise* so that ``std(clean) / std(noise)`` is about ``10 ** (snr_db / 20)``."""
        snr = 10 ** (snr_db / 20)
        return noise * clean.std() / (snr * noise.std() + 1e-8)


## Dataset & Dataloader

In [None]:
class SpoofDataset(Dataset):
    """Dataset of precomputed mel tensors stored as ``<mel_dir>/<audio_file_name>.pt``.

    Args:
        df: DataFrame with at least ``audio_file_name`` and ``label`` columns.
        mel_dir: directory (str or Path) the tensors were saved to.
    """

    def __init__(self, df, mel_dir):
        self.df = df
        self.mel_dir = mel_dir

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(mel_tensor, label)`` for row position *idx*; the label is a long scalar tensor."""
        record = self.df.iloc[idx]
        tensor_path = Path(self.mel_dir) / f"{record['audio_file_name']}.pt"
        features = torch.load(tensor_path)
        return features, torch.tensor(record["label"], dtype=torch.long)

In [None]:
def collate_fn(batch):
    """Right-pad variable-length mel tensors to a common length and stack them.

    Args:
        batch: sequence of ``(mel, label)`` pairs; mels may differ only in
            their last (time) dimension.

    Returns:
        ``(mels, labels)``: mels gains a leading batch dimension, with the
        time axis zero-padded to the longest item; labels is a tensor.

    NOTE(review): if the features are in dB, zero padding is 0 dB rather than
    silence; confirm this is intended.
    """
    mels = [mel for mel, _ in batch]
    labels = tuple(label for _, label in batch)
    target_len = max(mel.shape[-1] for mel in mels)
    padded = [torch.nn.functional.pad(mel, (0, target_len - mel.shape[-1])) for mel in mels]
    return torch.stack(padded), torch.tensor(labels)

In [None]:
def save_mel(df, output_path, augmentation=False, probability_of_augment=0.2, target_sr=16000):
    """Compute (optionally augmented) mel spectrograms for every row of *df* and save them.

    Each file ``row['path']`` is loaded, resampled to *target_sr* if needed,
    passed through ``Augment`` and written to
    ``<output_path>/<audio_file_name>.pt``. Augmentation is drawn once per
    file here, so the saved features stay the same across training epochs.

    Args:
        df: DataFrame with ``path``, ``tag`` and ``audio_file_name`` columns.
        output_path: destination directory (created if missing).
        augmentation: passed to ``Augment(enable=...)``.
        probability_of_augment: per-augmentation probability for ``Augment``.
        target_sr: sampling rate the features are computed at. ``Augment``
            builds ``MelFeatureExtractor()`` with its 16 kHz default.
    """
    output_dir = Path(output_path)
    output_dir.mkdir(exist_ok=True, parents=True)

    aug = Augment(enable=augmentation, probability=probability_of_augment)
    total = len(df)

    # Positional counter: progress is correct even if df's index is not 0..n-1.
    for done, (_, row) in enumerate(df.iterrows(), start=1):
        waveform, sr = torchaudio.load(row['path'])
        if sr != target_sr:
            # The mel filterbank assumes target_sr; resample rather than
            # silently computing features on the wrong frequency axis.
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
        mel_db = aug(waveform, row['tag'])

        torch.save(mel_db, output_dir / f"{row['audio_file_name']}.pt")

        if done % 5000 == 0 or done == total:
            print(f"{done}/{total} Mel salvati")

In [None]:
# Precompute the (augmented) training features once; the training Dataset reads them back.
save_mel(train_df, "mel_train", augmentation=True, probability_of_augment=PROBABILITY_AUGMENT)

In [None]:
# Validation features are not augmented, so probability_of_augment has no effect here.
save_mel(validation_df, "mel_val", augmentation=False, probability_of_augment=PROBABILITY_AUGMENT)

In [None]:
# Datasets read the saved .pt features; collate_fn pads each batch in time.
train_dataset = SpoofDataset(train_df, "mel_train")
val_dataset = SpoofDataset(validation_df, "mel_val")

_loader_kwargs = dict(batch_size=32, num_workers=2, pin_memory=True, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

## Model

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation: rescale each channel by a learned gate in (0, 1).

    The gate comes from the spatially averaged feature map, passed through a
    two-layer MLP (channels -> channels // reduction -> channels) ending in
    a sigmoid.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Gate a 4-D ``(batch, channels, H, W)`` tensor; the output has the same shape."""
        batch, channels, _, _ = x.size()
        squeezed = self.global_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate  # broadcasts the per-channel gate over H and W

In [None]:
class ResidualSEBlock(nn.Module):
    """Two 3x3 conv-BN layers with an SE gate and a residual connection.

    The shortcut is a 1x1 conv + BN projection whenever the stride or the
    channel count changes; otherwise the input is added back unchanged.
    Channel dropout (p=0.1) follows the first conv-BN-ReLU.
    """

    def __init__(self, in_channels, out_channels, stride=1, reduction=16):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.se = SEBlock(out_channels, reduction)

        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Sequential()  # identity
        )

        self.dropout = nn.Dropout2d(0.1)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.bn1(self.conv1(x))))
        hidden = self.se(self.bn2(self.conv2(hidden)))
        return F.relu(hidden + self.shortcut(x))

In [None]:
class AttentionPooling(nn.Module):
    """Pool a ``(batch, C, H, W)`` map to ``(batch, C)`` using a learned spatial gate.

    A 1x1-conv bottleneck (C -> C // 4 -> 1) with a sigmoid gives one weight
    per spatial position. The weighted map is then averaged over H and W.
    The weights are not normalised to sum to one.
    """

    def __init__(self, in_dim):
        super().__init__()
        bottleneck = in_dim // 4
        self.attention = nn.Sequential(
            nn.Conv2d(in_dim, bottleneck, 1),
            nn.ReLU(),
            nn.Conv2d(bottleneck, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        weights = self.attention(x)  # (batch, 1, H, W), broadcast over channels
        pooled = F.adaptive_avg_pool2d(x * weights, 1)
        return pooled[..., 0, 0]  # drop the two singleton spatial axes

In [None]:
class SpoofNet(nn.Module):
    """SE-ResNet style classifier over single-channel spectrogram inputs.

    Pipeline: 7x7 stride-2 conv stem + max-pool, then four residual SE
    stages (64 -> 128 -> 256 -> 512 -> 512 channels, two blocks each), then
    attention pooling, then an MLP head (512 -> 256 -> 128 -> num_classes)
    with BatchNorm and dropout.

    Args:
        input_features: accepted but not used. The adaptive pooling makes the
            network independent of the input height.
        num_classes: number of output logits.
        dropout: dropout probability in the classifier head.
    """

    def __init__(self, input_features=227, num_classes=2, dropout=0.3):
        super().__init__()

        self.input_conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.layer1 = self._make_layer(64, 128, 2, stride=1)
        self.layer2 = self._make_layer(128, 256, 2, stride=2)
        self.layer3 = self._make_layer(256, 512, 2, stride=2)
        self.layer4 = self._make_layer(512, 512, 2, stride=2)

        self.attention_pool = AttentionPooling(512)

        # Head: two Linear-BN-ReLU-Dropout groups followed by the output layer.
        head = []
        for in_f, out_f in ((512, 256), (256, 128)):
            head += [nn.Linear(in_f, out_f), nn.BatchNorm1d(out_f), nn.ReLU(), nn.Dropout(dropout)]
        head.append(nn.Linear(128, num_classes))
        self.classifier = nn.Sequential(*head)

        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack *num_blocks* residual SE blocks; only the first changes stride/width."""
        first = ResidualSEBlock(in_channels, out_channels, stride)
        rest = [ResidualSEBlock(out_channels, out_channels) for _ in range(num_blocks - 1)]
        return nn.Sequential(first, *rest)

    def _initialize_weights(self):
        """Kaiming-normal convs, identity BatchNorm2d, N(0, 0.01) Linear weights with zero bias."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Map ``(batch, H, W)`` or ``(batch, 1, H, W)`` inputs to ``(batch, num_classes)`` logits."""
        if x.dim() == 3:
            x = x.unsqueeze(1)  # add the single input channel

        x = self.input_conv(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return self.classifier(self.attention_pool(x))

## Training Loop

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over *loader* and return the per-sample mean loss."""
    model.train()
    running_loss = 0.0
    for mel_batch, labels in loader:
        mel_batch, labels = mel_batch.to(device), labels.to(device)

        outputs = model(mel_batch)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        running_loss += loss.item() * mel_batch.size(0)
    return running_loss / len(loader.dataset)


def _validate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradients.

    Returns:
        ``(mean_loss, labels, predictions, spoof_probabilities)``, where the
        last three are Python lists and the probability is that of class 1.
    """
    model.eval()
    total_loss = 0.0
    all_labels, all_preds, all_probs = [], [], []
    with torch.no_grad():
        for mel_batch, labels in loader:
            mel_batch, labels = mel_batch.to(device), labels.to(device)

            outputs = model(mel_batch)
            total_loss += criterion(outputs, labels).item() * mel_batch.size(0)

            all_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())  # P(class 1 = spoof)
            all_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return total_loss / len(loader.dataset), all_labels, all_preds, all_probs


def train_model(model, train_loader, val_loader, device,
                epochs=20, patience=5, lr=1e-3, plot_each_epoch=True):
    """Train *model* with Adam and cross-entropy, early-stopping on validation loss.

    After every epoch the model is evaluated on *val_loader*. Whenever the
    validation loss improves, the weights are saved to ``best_model.pth``.
    Training stops after *patience* consecutive epochs without improvement.
    *model* itself keeps the weights of the last epoch run, not the best ones.

    Args:
        model: network mapping a mel batch to 2-class logits.
        train_loader, val_loader: loaders yielding ``(mel_batch, labels)``.
        device: device the batches are moved to.
        epochs: maximum number of epochs.
        patience: epochs without validation-loss improvement before stopping.
        lr: Adam learning rate.
        plot_each_epoch: draw the confusion matrix and calibration curve after
            every epoch (default True); set False to skip the plots.

    Returns:
        dict of per-epoch lists: ``train_loss``, ``val_loss``, ``val_acc``,
        ``val_precision``, ``val_recall``, ``val_f1``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float("inf")
    patience_counter = 0

    history = {"train_loss": [], "val_loss": [],
               "val_acc": [], "val_precision": [], "val_recall": [], "val_f1": []}

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, all_labels, all_preds, all_probs = _validate(model, val_loader, criterion, device)

        if plot_each_epoch:
            plot_confusion_matrix(all_labels, all_preds)
            plot_calibration_curve(all_labels, all_probs)

        acc = accuracy_score(all_labels, all_preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, average="binary", zero_division=0
        )

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(acc)
        history["val_precision"].append(precision)
        history["val_recall"].append(recall)
        history["val_f1"].append(f1)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Acc: {acc:.4f} | "
              f"Prec: {precision:.4f} | "
              f"Rec: {recall:.4f} | "
              f"F1: {f1:.4f}")

        # Early stopping on validation loss; checkpoint every improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered!")
                break

    print("Training finished. Best model saved as best_model.pth")
    return history

In [None]:
def plot_confusion_matrix(y_true, y_pred, labels=(0, 1), class_names=("bonafide", "spoof")):
    """Draw an annotated confusion-matrix heatmap (rows = true, columns = predicted).

    Args:
        y_true: true class ids.
        y_pred: predicted class ids.
        labels: class ids fixing the row/column order. Classes absent from
            the data still get a (zero) row and column.
        class_names: tick labels matching *labels*.
    """
    # Tuple defaults avoid shared mutable default arguments; convert to lists for the libraries.
    cm = confusion_matrix(y_true, y_pred, labels=list(labels))
    tick_names = list(class_names)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=tick_names, yticklabels=tick_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

In [None]:
def plot_calibration_curve(y_true, y_probs, n_bins=10):
    """Plot a reliability diagram: observed positive fraction vs. mean predicted probability.

    Args:
        y_true: binary ground-truth labels.
        y_probs: predicted probability of the positive class (label 1).
        n_bins: number of bins passed to sklearn's ``calibration_curve``.
    """
    frac_positive, mean_predicted = calibration_curve(y_true, y_probs, n_bins=n_bins)
    diagonal = [0, 1]

    plt.figure(figsize=(6, 6))
    plt.plot(mean_predicted, frac_positive, marker='o', label="Model calibration")
    plt.plot(diagonal, diagonal, linestyle="--", color="gray", label="Perfectly calibrated")
    plt.xlabel("Mean predicted probability")
    plt.ylabel("Fraction of positives")
    plt.title("Calibration Curve (Reliability Diagram)")
    plt.legend()
    plt.show()

In [None]:
N_MFCC = 39
# SpoofNet accepts input_features but does not use it, so this sum has no effect
# on the network. NOTE(review): the literal 60 presumably stands for N_LFCC.
model = SpoofNet(input_features=N_MELS + N_MFCC + 60).to(device)

history = train_model(model, train_loader, val_loader, device,
                      epochs=EPOCH, patience=PATIENCE, lr=1e-3)

# Test

In [None]:
# best_model is not defined anywhere else: rebuild the architecture (same
# defaults as the trained model) and load the checkpoint written by train_model.
print("Loading best model...")
best_model = SpoofNet().to(device)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
best_model.load_state_dict(torch.load("best_model.pth", map_location=device))
best_model.eval()
print("Best model loaded successfully!")

In [None]:
N_SAMPLE_TEST = 4000  # rows per subset (LA / PA), half spoof and half bonafide

In [None]:
def _balanced_split(frame, n_total):
    """Return ``(spoof_rows, bonafide_rows)``, each ``n_total // 2`` rows sampled from *frame*."""
    half = n_total // 2
    spoof_rows = frame[frame["label"] == 1].sample(n=half, random_state=RANDOM_STATE)
    bonafide_rows = frame[frame["label"] == 0].sample(n=half, random_state=RANDOM_STATE)
    return spoof_rows, bonafide_rows


# Test set: a balanced sample of the eval protocols of LA and PA.
la_test = read_proto_file(PROTO_FILE_TEST_LA, AUDIO_DIR_TEST_LA, "flac", "LA")
la_t_1, la_t_0 = _balanced_split(la_test, N_SAMPLE_TEST)
la_test_reduced = pd.concat([la_t_1, la_t_0], ignore_index=True)

pa_test = read_proto_file(PROTO_FILE_TEST_PA, AUDIO_DIR_TEST_PA, "flac", "PA")
pa_t_1, pa_t_0 = _balanced_split(pa_test, N_SAMPLE_TEST)
pa_test_reduced = pd.concat([pa_t_1, pa_t_0], ignore_index=True)

test_df = pd.concat([la_test_reduced, pa_test_reduced], ignore_index=True)

In [None]:
# Test features: augmentation disabled.
save_mel(test_df, "mel_test", augmentation=False, probability_of_augment=0.0)

In [None]:
test_dataset = SpoofDataset(test_df, "mel_test")
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,  # keep a stable order for evaluation
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [None]:
def calculate_eer(y_true, y_scores):
    """Estimate the equal error rate (EER) from binary labels and scores.

    Picks the ROC operating point where the false-negative rate (1 - TPR)
    and the false-positive rate are closest, and reports their mean there.
    This is the usual discrete approximation when the two rates never cross
    exactly; reporting FPR alone would be biased toward whichever rate is
    lower at that point.

    Args:
        y_true: binary labels (1 = positive class).
        y_scores: scores, higher meaning more likely positive.

    Returns:
        ``(eer_percent, threshold)``: the EER in percent and the ROC
        threshold at the chosen operating point.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    fnr = 1 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))  # computed once, reused below
    eer = (fpr[idx] + fnr[idx]) / 2
    return eer * 100, thresholds[idx]

In [None]:
def plot_score_distribution(y_true, y_scores, title="Score Distribution"):
    """Plot score distribution for bonafide vs spoof.

    Draws density-normalised histograms of the scores of label-0 (bonafide)
    and label-1 (spoof) samples on one axis.

    Args:
        y_true: labels, 0 = bonafide and 1 = spoof (list or array).
        y_scores: one score per sample (list or array).
        title: plot title.
    """
    # Convert first: with plain lists, `y_true == 0` is a single bool and the
    # indexing below would pick one element instead of masking.
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    bonafide_scores = y_scores[y_true == 0]
    spoof_scores = y_scores[y_true == 1]

    plt.figure(figsize=(10, 6))
    plt.hist(bonafide_scores, bins=50, alpha=0.7, label='Bonafide', color='blue', density=True)
    plt.hist(spoof_scores, bins=50, alpha=0.7, label='Spoof', color='red', density=True)
    plt.xlabel('Prediction Score')
    plt.ylabel('Density')
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

In [None]:
def comprehensive_evaluation(model, test_loader, device, dataset_name="Test"):
    """Comprehensive evaluation pipeline.

    Runs *model* over *test_loader* without gradients and prints loss,
    accuracy, precision, recall, F1, ROC-AUC and EER, treating class 1
    (spoof) as positive. It then shows the score-distribution,
    confusion-matrix, ROC, calibration and precision-recall plots.

    Returns:
        dict with the scalar metrics plus ``predictions``, ``probabilities``
        (P(class 1)) and ``true_labels`` as NumPy arrays.
    """
    from sklearn.metrics import precision_recall_curve, average_precision_score

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    total_loss = 0.0
    criterion = nn.CrossEntropyLoss()

    print(f"Evaluating on {dataset_name} set...")

    with torch.no_grad():
        for mel_batch, labels in tqdm(test_loader, desc="Evaluating"):
            mel_batch, labels = mel_batch.to(device), labels.to(device)

            outputs = model(mel_batch)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * mel_batch.size(0)

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()

            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    avg_loss = total_loss / len(test_loader.dataset)
    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 matches train_model and avoids warnings when no spoof is predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', zero_division=0
    )
    auc = roc_auc_score(all_labels, all_probs)
    eer, eer_threshold = calculate_eer(all_labels, all_probs)

    print(f"\n{dataset_name} Set Evaluation Results:")
    print("="*50)
    print(f"Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"AUC-ROC: {auc:.4f}")
    print(f"EER: {eer:.2f}%")
    print(f"EER Threshold: {eer_threshold:.4f}")
    print("="*50)

    plot_score_distribution(all_labels, all_probs, f"{dataset_name} Set - Score Distribution")

    plt.figure(figsize=(8, 6))
    # Fix the label set so the matrix is always 2x2, matching the two tick labels.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Bonafide', 'Spoof'],
                yticklabels=['Bonafide', 'Spoof'])
    plt.title(f'{dataset_name} Set - Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'{dataset_name} Set - ROC Curve')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.show()

    prob_true, prob_pred = calibration_curve(all_labels, all_probs, n_bins=10)
    plt.figure(figsize=(8, 6))
    plt.plot(prob_pred, prob_true, marker='o', linewidth=2, label='Model calibration')
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly calibrated')
    plt.xlabel('Mean Predicted Probability')
    plt.ylabel('Fraction of Positives')
    plt.title(f'{dataset_name} Set - Calibration Curve')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    precision_curve, recall_curve, _ = precision_recall_curve(all_labels, all_probs)
    ap_score = average_precision_score(all_labels, all_probs)

    plt.figure(figsize=(8, 6))
    plt.plot(recall_curve, precision_curve, color='blue', lw=2,
             label=f'PR curve (AP = {ap_score:.4f})')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'{dataset_name} Set - Precision-Recall Curve')
    plt.legend(loc="lower left")
    plt.grid(True, alpha=0.3)
    plt.show()

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
        'eer': eer,
        'eer_threshold': eer_threshold,
        'predictions': all_preds,
        'probabilities': all_probs,
        'true_labels': all_labels
    }

In [None]:
# Final evaluation of the reloaded checkpoint on the held-out test subset.
test_results = comprehensive_evaluation(best_model, test_loader, device, dataset_name="Test")