In [1]:
# Dependencies
import os
import shutil
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import random
import numpy as np

In [2]:
import os

def count_files(base_dir):
    bonafide_dir = os.path.join(base_dir, "bonafide")
    spoof_dir = os.path.join(base_dir, "spoof")

    bonafide_count = len([f for f in os.listdir(bonafide_dir) if f.endswith(".flac")]) if os.path.exists(bonafide_dir) else 0
    spoof_count = len([f for f in os.listdir(spoof_dir) if f.endswith(".flac")]) if os.path.exists(spoof_dir) else 0

    return bonafide_count, spoof_count

train_counts = count_files("./data/ASVspoof5/train")
eval_counts = count_files("./data/ASVspoof5/eval")

print("Train Set:")
print(f"  Bonafide: {train_counts[0]}")
print(f"  Spoof:    {train_counts[1]}")

print("\nEval Set:")
print(f"  Bonafide: {eval_counts[0]}")
print(f"  Spoof:    {eval_counts[1]}")


Train Set:
  Bonafide: 300
  Spoof:    300

Eval Set:
  Bonafide: 300
  Spoof:    300


In [3]:
# Constants and Paths
DATA_ROOT = "./data"
ASV5_DIR = os.path.join(DATA_ROOT, "ASVspoof5")
RAW_AUDIO_DIR = os.path.join(ASV5_DIR, "raw_audio")
PROTOCOL_EVAL = os.path.join(ASV5_DIR, "protocols", "ASVspoof5.eval.track_2.trial.tsv")
TARGET_TRAIN_DIR = os.path.join(ASV5_DIR, "train")  # Still used to load some training samples
TARGET_EVAL_DIR = os.path.join(ASV5_DIR, "eval")
MAX_PER_CLASS = 500

In [4]:
# Organize audio based on protocol

def organize_asvspoof5_eval_data(protocol_file, output_dir):
    os.makedirs(os.path.join(output_dir, "bonafide"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "spoof"), exist_ok=True)

    bonafide_count, spoof_count = 0, 0
    with open(protocol_file, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 4:
                continue
            file_id, label = parts[1], parts[-1].strip()
            src_path = os.path.join(RAW_AUDIO_DIR, file_id + ".flac")
            if not os.path.exists(src_path):
                continue

            if label == "bonafide" and bonafide_count < MAX_PER_CLASS:
                dst_path = os.path.join(output_dir, "bonafide", file_id + ".flac")
                shutil.copyfile(src_path, dst_path)
                bonafide_count += 1
            elif label == "spoof" and spoof_count < MAX_PER_CLASS:
                dst_path = os.path.join(output_dir, "spoof", file_id + ".flac")
                shutil.copyfile(src_path, dst_path)
                spoof_count += 1
            if bonafide_count >= MAX_PER_CLASS and spoof_count >= MAX_PER_CLASS:
                break
    print(f"Organized {bonafide_count} bonafide and {spoof_count} spoof files in {output_dir}.")

In [5]:
# Import RawNet2 model
from model.rawnet2 import RawNet2
from model.model_config import ModelConfig

In [6]:
class AudioDataset(Dataset):
    def __init__(self, root_dir, max_samples=100, target_length=64000):
        self.files = []
        self.labels = []
        self.target_length = target_length  # 4 seconds at 16kHz
        
        for label, folder in enumerate(["bonafide", "spoof"]):
            full_path = os.path.join(root_dir, folder)
            if not os.path.exists(full_path): continue
            files = [os.path.join(full_path, f) for f in os.listdir(full_path) if f.endswith(".wav") or f.endswith(".flac")]
            self.files.extend(files[:max_samples])
            self.labels.extend([label] * len(files[:max_samples]))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(self.files[idx])
        waveform = waveform.mean(dim=0, keepdim=True)  # mono
        waveform = Resample(sr, 16000)(waveform)  # resample to 16kHz
        
        # Standardize length
        current_length = waveform.shape[1]
        
        if current_length < self.target_length:
            # Pad with zeros if audio is too short
            padding = torch.zeros(1, self.target_length - current_length)
            waveform = torch.cat([waveform, padding], dim=1)
        elif current_length > self.target_length:
            # Truncate if audio is too long
            waveform = waveform[:, :self.target_length]
            
        return waveform, self.labels[idx]

In [7]:
# Training Function

def train_model(model, dataloader, optimizer, criterion, epochs=5, device="cpu"):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in dataloader:
            x, y = batch
            x, y = x.to(device), torch.tensor(y).to(device)
            optimizer.zero_grad()
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

In [8]:
# Evaluation Function

def evaluate_model(model, dataloader, device="cpu"):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y_hat = model(x)
            preds.extend(torch.argmax(y_hat, dim=1).cpu().numpy())
            trues.extend(y)
    acc = accuracy_score(trues, preds)
    print("\nEvaluation Metrics")
    print("------------------")
    print(f"Accuracy: {acc:.4f}")
    print("\nClassification Report:")
    print(classification_report(trues, preds, target_names=["bonafide", "spoof"]))
    print("Confusion Matrix:")
    print(confusion_matrix(trues, preds))

In [9]:
# Main Script
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    
    config = ModelConfig()
    model = RawNet2(config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print("\nOrganizing evaluation data from E subset only...")
    organize_asvspoof5_eval_data(PROTOCOL_EVAL, TARGET_EVAL_DIR)

    print("\nDuplicating eval set as train set for light fine-tuning...")
    if not os.path.exists(TARGET_TRAIN_DIR):
        shutil.copytree(TARGET_EVAL_DIR, TARGET_TRAIN_DIR)

    print("\nLoading training and evaluation datasets...")
    train_dataset = AudioDataset(TARGET_TRAIN_DIR)
    eval_dataset = AudioDataset(TARGET_EVAL_DIR)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

    print("\nStarting model training...")
    train_model(model, train_loader, optimizer, criterion, epochs=3, device=device)

    print("\nRunning evaluation...")
    evaluate_model(model, eval_loader, device=device)

Using device: cpu

Organizing evaluation data from E subset only...
Organized 0 bonafide and 0 spoof files in ./data\ASVspoof5\eval.

Duplicating eval set as train set for light fine-tuning...

Loading training and evaluation datasets...

Starting model training...


  x, y = x.to(device), torch.tensor(y).to(device)


KeyboardInterrupt: 