In [15]:
# Dependencies
import os
import shutil
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import random
import numpy as np

In [16]:
# Constants and Paths
DATA_ROOT = "./data"
ASV5_DIR = os.path.join(DATA_ROOT, "ASVspoof5")
RAW_AUDIO_DIR = os.path.join(ASV5_DIR, "raw_audio")
PROTOCOL_EVAL = os.path.join(ASV5_DIR, "protocols", "ASVspoof5.eval.track_2.trial.tsv")
TARGET_EVAL_DIR = os.path.join(ASV5_DIR, "eval")
TARGET_TRAIN_DIR = os.path.join(ASV5_DIR, "train")
MAX_PER_CLASS = 500

In [17]:
# Organize audio based on protocol

def organize_asvspoof5_data(protocol_file, output_dir):
    os.makedirs(os.path.join(output_dir, "bonafide"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "spoof"), exist_ok=True)

    bonafide_count, spoof_count = 0, 0
    
    if not os.path.exists(protocol_file):
        print(f"Protocol file not found: {protocol_file}")
        return
        
    with open(protocol_file, 'r') as f:
        for line in f:
            # Split by space instead of tab
            parts = line.strip().split(' ')
            
            # From the debug output, we can see format is:
            # Speaker_ID File_ID Gender Type Label
            if len(parts) < 3:
                continue
                
            file_id = parts[1]  # Second element is the audio file ID
            
            # Check if bonafide or spoof - looking at labels in positions 3/4
            is_bonafide = 'bonafide' in line
            
            for ext in [".flac", ".wav"]:
                src_path = os.path.join(RAW_AUDIO_DIR, file_id + ext)
                if os.path.exists(src_path):
                    if is_bonafide and bonafide_count < MAX_PER_CLASS:
                        dst_path = os.path.join(output_dir, "bonafide", file_id + ext)
                        shutil.copyfile(src_path, dst_path)
                        bonafide_count += 1
                    elif not is_bonafide and spoof_count < MAX_PER_CLASS:
                        dst_path = os.path.join(output_dir, "spoof", file_id + ext)
                        shutil.copyfile(src_path, dst_path)
                        spoof_count += 1
                    break
                    
            if bonafide_count >= MAX_PER_CLASS and spoof_count >= MAX_PER_CLASS:
                break
                
    print(f"Organized {bonafide_count} bonafide and {spoof_count} spoof files in {output_dir}.")
    
    if bonafide_count == 0 and spoof_count == 0:
        print(f"No files were found. Please check:")
        print(f"1. Raw audio directory exists: {os.path.exists(RAW_AUDIO_DIR)}")
        print(f"2. Audio file names match protocol format")
        if os.path.exists(RAW_AUDIO_DIR):
            sample_files = os.listdir(RAW_AUDIO_DIR)[:5] if os.listdir(RAW_AUDIO_DIR) else []
            print(f"   Sample files in RAW_AUDIO_DIR: {sample_files}")

In [18]:
# Import RawNet2 model
from model.rawnet2 import RawNet2
from model.model_config import ModelConfig

In [19]:
class AudioDataset(Dataset):
    def __init__(self, root_dir, max_samples=MAX_PER_CLASS, target_length=64000):
        self.files = []
        self.labels = []
        self.target_length = target_length
        for label, folder in enumerate(["bonafide", "spoof"]):
            full_path = os.path.join(root_dir, folder)
            if not os.path.exists(full_path): continue
            files = [os.path.join(full_path, f) for f in os.listdir(full_path) if f.endswith(".wav") or f.endswith(".flac")]
            self.files.extend(files[:max_samples])
            self.labels.extend([label] * len(files[:max_samples]))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(self.files[idx])
        waveform = waveform.mean(dim=0, keepdim=True)  # mono
        waveform = Resample(sr, 16000)(waveform)  # resample to 16kHz

        # Standardize length
        current_length = waveform.shape[1]
        if current_length < self.target_length:
            padding = torch.zeros(1, self.target_length - current_length)
            waveform = torch.cat([waveform, padding], dim=1)
        elif current_length > self.target_length:
            waveform = waveform[:, :self.target_length]

        return waveform, self.labels[idx]


In [20]:
# Training Function

def train_model(model, dataloader, optimizer, criterion, epochs=5, device="cpu"):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in dataloader:
            x, y = batch
            x, y = x.to(device), torch.tensor(y).to(device)
            optimizer.zero_grad()
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

In [21]:
# Evaluation Function

def evaluate_model(model, dataloader, device="cpu"):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y_hat = model(x)
            preds.extend(torch.argmax(y_hat, dim=1).cpu().numpy())
            trues.extend(y)
    acc = accuracy_score(trues, preds)
    print("\nEvaluation Metrics")
    print("------------------")
    print(f"Accuracy: {acc:.4f}")
    print("\nClassification Report:")
    print(classification_report(trues, preds, target_names=["bonafide", "spoof"]))
    print("Confusion Matrix:")
    print(confusion_matrix(trues, preds))

In [22]:
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = ModelConfig()
    model = RawNet2(config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print("\nOrganizing evaluation data from E subset only...")
    organize_asvspoof5_data(PROTOCOL_EVAL, TARGET_EVAL_DIR)

    print("\nDuplicating eval set as train set for light fine-tuning...")
    shutil.rmtree(TARGET_TRAIN_DIR, ignore_errors=True)
    shutil.copytree(TARGET_EVAL_DIR, TARGET_TRAIN_DIR)

    print("\nLoading training and evaluation datasets...")
    train_dataset = AudioDataset(TARGET_TRAIN_DIR)
    eval_dataset = AudioDataset(TARGET_EVAL_DIR)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

    print("\nStarting model training...")
    train_model(model, train_loader, optimizer, criterion, epochs=10, device=device)

    print("\nRunning evaluation...")
    evaluate_model(model, eval_loader, device=device)


Organizing evaluation data from E subset only...
Organized 500 bonafide and 500 spoof files in ./data\ASVspoof5\eval.

Duplicating eval set as train set for light fine-tuning...

Loading training and evaluation datasets...

Starting model training...


  x, y = x.to(device), torch.tensor(y).to(device)


Epoch 1, Loss: 179.2521
Epoch 2, Loss: 158.4937
Epoch 3, Loss: 194.0273
Epoch 4, Loss: 159.2617
Epoch 5, Loss: 158.1174
Epoch 6, Loss: 171.1250
Epoch 7, Loss: 164.2781
Epoch 8, Loss: 154.5140
Epoch 9, Loss: 152.4674
Epoch 10, Loss: 151.9663

Running evaluation...

Evaluation Metrics
------------------
Accuracy: 0.5760

Classification Report:
              precision    recall  f1-score   support

    bonafide       0.00      0.00      0.00       368
       spoof       0.58      1.00      0.73       500

    accuracy                           0.58       868
   macro avg       0.29      0.50      0.37       868
weighted avg       0.33      0.58      0.42       868

Confusion Matrix:
[[  0 368]
 [  0 500]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
