# phase 3: wav2vec2 fine-tuning for pd classification

fine-tune wav2vec2-base-960h on parkinson's disease voice detection using
leave-one-subject-out (loso) cross-validation for rigorous evaluation.

**methodology:**
- loso cv: same protocol as clinical baseline (88.3% accuracy) for fair comparison
- freeze cnn feature extractor + first 4 transformer layers (small dataset)
- gradient checkpointing for memory efficiency
- early stopping to prevent overfitting

**expected results:**
- target accuracy: 80-90% (competitive with clinical baseline)
- comparison with 17-feature clinical model establishes deep learning value

**hardware support:**
- nvidia gpu (cuda) - recommended
- apple silicon (mps) - supported but slower
- cpu - not recommended (10-20+ hours)

## 1. setup and configuration

In [1]:
import sys
from pathlib import Path
import os

# set project root
project_root = Path('/Volumes/usb drive/pd-interpretability')
assert project_root.exists(), f"project root not found: {project_root}"

os.chdir(project_root)
sys.path.insert(0, str(project_root))

print(f"project root: {project_root}")
print(f"working directory: {os.getcwd()}")

project root: /Volumes/usb drive/pd-interpretability
working directory: /Volumes/usb drive/pd-interpretability


In [2]:
import torch
import numpy as np
import pandas as pd
import json
import warnings
from datetime import datetime
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

warnings.filterwarnings('ignore')

def detect_device():
    """detect best available compute device."""
    if torch.cuda.is_available():
        device = 'cuda'
        device_name = torch.cuda.get_device_name(0)
        memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"detected: nvidia gpu ({device_name})")
        print(f"vram: {memory_gb:.1f} gb")
        return device, True
    elif torch.backends.mps.is_available():
        device = 'mps'
        print("detected: apple silicon (mps)")
        return device, True
    else:
        device = 'cpu'
        print("warning: no gpu detected, using cpu (very slow)")
        return device, False

device, has_accelerator = detect_device()
print(f"pytorch version: {torch.__version__}")
print(f"device: {device}")

detected: apple silicon (mps)
pytorch version: 2.2.0
device: mps


In [3]:
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW

from transformers import (
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    get_linear_schedule_with_warmup
)

from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)

from src.data.datasets import ItalianPVSDataset
from src.models.classifier import DataCollatorWithPadding

print("imports complete")

imports complete


In [4]:
# experiment configuration
config = {
    # model
    'model_name': 'facebook/wav2vec2-base-960h',
    'num_labels': 2,
    'freeze_feature_extractor': True,
    'freeze_encoder_layers': 4,  # freeze first 4 transformer layers
    'dropout': 0.15,
    'gradient_checkpointing': True,
    
    # audio
    'max_duration': 10.0,
    'target_sr': 16000,
    
    # training
    'num_epochs': 15,
    'learning_rate': 5e-5,  # lower for small dataset stability
    'warmup_ratio': 0.1,
    'weight_decay': 0.01,
    'early_stopping_patience': 3,
    
    # loso cv
    'max_folds': 3,  # set to 3 for quick test, None for full CV
    
    # random seed
    'random_seed': 42
}

# device-specific settings
if device == 'cuda':
    config['batch_size'] = 8
    config['gradient_accumulation_steps'] = 4
    config['fp16'] = True
elif device == 'mps':
    config['batch_size'] = 4
    config['gradient_accumulation_steps'] = 8
    config['fp16'] = False  # mps fp16 unstable
else:
    config['batch_size'] = 2
    config['gradient_accumulation_steps'] = 16
    config['fp16'] = False

effective_batch = config['batch_size'] * config['gradient_accumulation_steps']
print(f"batch size: {config['batch_size']} (effective: {effective_batch})")
print(f"learning rate: {config['learning_rate']}")
print(f"epochs: {config['num_epochs']}")
print(f"frozen layers: cnn + first {config['freeze_encoder_layers']} transformer")
print(f"max folds: {config['max_folds']} (set to None for full LOSO CV)")

batch size: 4 (effective: 32)
learning rate: 5e-05
epochs: 15
frozen layers: cnn + first 4 transformer
max folds: 3 (set to None for full LOSO CV)


## 2. load dataset

In [5]:
# load dataset
data_root = project_root / 'data' / 'raw'

dataset = ItalianPVSDataset(
    root_dir=str(data_root / 'italian_pvs'),
    task=None,  # all tasks
    target_sr=config['target_sr'],
    max_duration=config['max_duration']
)

print(f"samples: {len(dataset)}")

# extract labels and subject ids for loso cv
labels = np.array([s['label'] for s in dataset.samples])
subject_ids = np.array([s['subject_id'] for s in dataset.samples])

# unique subjects
unique_subjects = np.unique(subject_ids)
n_subjects = len(unique_subjects)

# class distribution
n_pd = np.sum(labels)
n_hc = len(labels) - n_pd
print(f"class distribution: {n_hc} hc, {n_pd} pd")
print(f"subjects: {n_subjects}")
print(f"loso cv folds: {n_subjects}")

samples: 831
class distribution: 394 hc, 437 pd
subjects: 61
loso cv folds: 61


In [6]:
# create output directory
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
experiment_name = f"wav2vec2_loso_{timestamp}"
output_dir = project_root / 'results' / 'checkpoints' / experiment_name
output_dir.mkdir(parents=True, exist_ok=True)

# save config
config_path = output_dir / 'config.json'
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

print(f"experiment: {experiment_name}")
print(f"output: {output_dir}")

experiment: wav2vec2_loso_20260102_125342
output: /Volumes/usb drive/pd-interpretability/results/checkpoints/wav2vec2_loso_20260102_125342


## 3. training utilities

In [7]:
def create_model(config: dict, device: str):
    """create fresh wav2vec2 model with specified freezing strategy."""
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        config['model_name'],
        num_labels=config['num_labels'],
        classifier_proj_size=256,
        hidden_dropout=config['dropout'],
        attention_dropout=config['dropout'],
        final_dropout=config['dropout']
    )
    
    # enable gradient checkpointing
    if config['gradient_checkpointing']:
        model.gradient_checkpointing_enable()
    
    # freeze cnn feature extractor
    if config['freeze_feature_extractor']:
        model.freeze_feature_encoder()
    
    # freeze first n transformer layers
    if config['freeze_encoder_layers'] > 0:
        for i, layer in enumerate(model.wav2vec2.encoder.layers):
            if i < config['freeze_encoder_layers']:
                for param in layer.parameters():
                    param.requires_grad = False
    
    return model.to(device)


def count_parameters(model):
    """count trainable and frozen parameters."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, total - trainable


def create_collate_fn(feature_extractor, max_length: int):
    """create collate function for wav2vec2 training."""
    def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
        # extract raw audio waveforms - dataset returns 1d tensors
        input_values = [item['input_values'] for item in batch]
        labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
        
        # pad sequences to max length in batch
        max_len = min(max(len(x) for x in input_values), max_length)
        
        padded_input = torch.zeros(len(input_values), max_len)
        attention_mask = torch.zeros(len(input_values), max_len)
        
        for i, wav in enumerate(input_values):
            length = min(len(wav), max_len)
            padded_input[i, :length] = wav[:length]
            attention_mask[i, :length] = 1.0
        
        return {
            'input_values': padded_input,
            'attention_mask': attention_mask,
            'labels': labels
        }
    
    return collate_fn

In [None]:
import gc

def train_epoch(model, loader, optimizer, scheduler, scaler, device, accumulation_steps):
    """train for one epoch with gradient accumulation and mps memory management."""
    model.train()
    total_loss = 0
    n_batches = 0
    optimizer.zero_grad()
    
    for step, batch in enumerate(loader):
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(input_values, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss / accumulation_steps
            scaler.scale(loss).backward()
        else:
            outputs = model(input_values, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss / accumulation_steps
            loss.backward()
        
        total_loss += loss.item() * accumulation_steps
        n_batches += 1
        
        if (step + 1) % accumulation_steps == 0:
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        # mps memory management - synchronize periodically to prevent fragmentation
        if device == 'mps' and (step + 1) % 50 == 0:
            torch.mps.synchronize()
        
        # cleanup batch tensors
        del input_values, attention_mask, labels, outputs, loss
    
    # end of epoch cleanup for mps
    if device == 'mps':
        torch.mps.synchronize()
        torch.mps.empty_cache()
        gc.collect()
    
    return total_loss / n_batches if n_batches > 0 else 0.0

In [None]:
@torch.no_grad()
def evaluate(model, loader, device):
    """evaluate model on dataset with memory management."""
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []
    
    for batch in loader:
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
        total_loss += outputs.loss.item()
        
        probs = torch.softmax(outputs.logits, dim=-1)
        preds = outputs.logits.argmax(dim=-1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs[:, 1].cpu().numpy())
        
        # cleanup batch tensors
        del input_values, attention_mask, labels, outputs, probs, preds
    
    # mps memory cleanup
    if device == 'mps':
        torch.mps.synchronize()
        torch.mps.empty_cache()
        gc.collect()
    
    n_batches = len(loader)
    
    return {
        'loss': total_loss / n_batches if n_batches > 0 else 0,
        'accuracy': accuracy_score(all_labels, all_preds),
        'predictions': np.array(all_preds),
        'labels': np.array(all_labels),
        'probabilities': np.array(all_probs)
    }

## 4. loso cross-validation training

In [None]:
def train_fold(
    dataset,
    train_indices: np.ndarray,
    test_indices: np.ndarray,
    config: dict,
    device: str,
    fold_idx: int,
    output_dir: Path
) -> Dict:
    """train model on single loso fold with robust memory management."""
    
    print(f"    [fold {fold_idx + 1}] creating data subsets...")
    # create data subsets
    train_subset = Subset(dataset, train_indices.tolist())
    test_subset = Subset(dataset, test_indices.tolist())
    
    # feature extractor and custom collate function
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['model_name'])
    max_length = int(config['max_duration'] * config['target_sr'])
    collate_fn = create_collate_fn(feature_extractor, max_length)
    
    print(f"    [fold {fold_idx + 1}] creating dataloaders (batch_size={config['batch_size']})...")
    # dataloaders
    train_loader = DataLoader(
        train_subset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=(device == 'cuda')
    )
    
    test_loader = DataLoader(
        test_subset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=(device == 'cuda')
    )
    
    print(f"    [fold {fold_idx + 1}] initializing model ({config['model_name']})...")
    # create fresh model
    model = create_model(config, device)
    
    # optimizer
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    
    # scheduler
    steps_per_epoch = max(1, len(train_loader) // config['gradient_accumulation_steps'])
    total_steps = steps_per_epoch * config['num_epochs']
    warmup_steps = int(total_steps * config['warmup_ratio'])
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    # gradient scaler for fp16
    scaler = torch.cuda.amp.GradScaler() if config['fp16'] and device == 'cuda' else None
    
    print(f"    [fold {fold_idx + 1}] starting training ({config['num_epochs']} epochs, {len(train_loader)} batches/epoch)...")
    # training loop with early stopping
    best_loss = float('inf')
    patience_counter = 0
    best_metrics = None
    
    for epoch in range(config['num_epochs']):
        epoch_start = datetime.now()
        
        train_loss = train_epoch(
            model, train_loader, optimizer, scheduler, scaler,
            device, config['gradient_accumulation_steps']
        )
        
        test_metrics = evaluate(model, test_loader, device)
        
        epoch_time = (datetime.now() - epoch_start).total_seconds()
        
        print(f"      epoch {epoch + 1}/{config['num_epochs']}: "
              f"train_loss={train_loss:.4f}, test_loss={test_metrics['loss']:.4f}, "
              f"test_acc={test_metrics['accuracy']:.1%}, time={epoch_time:.1f}s")
        
        # save best metrics
        if test_metrics['loss'] < best_loss:
            best_loss = test_metrics['loss']
            patience_counter = 0
            best_metrics = test_metrics.copy()
        else:
            patience_counter += 1
            if patience_counter >= config['early_stopping_patience']:
                print(f"      early stopping triggered at epoch {epoch + 1}")
                break
        
        # mps: aggressive cleanup after each epoch to prevent memory fragmentation
        if device == 'mps':
            torch.mps.synchronize()
            torch.mps.empty_cache()
            gc.collect()
    
    print(f"    [fold {fold_idx + 1}] final evaluation...")
    # final evaluation
    final_metrics = evaluate(model, test_loader, device) if best_metrics is None else best_metrics
    
    print(f"    [fold {fold_idx + 1}] cleaning up...")
    # aggressive cleanup
    del model, optimizer, scheduler, train_loader, test_loader
    del train_subset, test_subset, feature_extractor, collate_fn
    
    if device == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    elif device == 'mps':
        torch.mps.synchronize()
        torch.mps.empty_cache()
    
    gc.collect()
    
    return {
        'fold': fold_idx,
        'train_samples': len(train_indices),
        'test_samples': len(test_indices),
        'accuracy': final_metrics['accuracy'],
        'predictions': final_metrics['predictions'],
        'labels': final_metrics['labels'],
        'probabilities': final_metrics['probabilities']
    }

In [None]:
import time

# run loso cross-validation
logo = LeaveOneGroupOut()
n_folds = logo.get_n_splits(groups=subject_ids)

if config['max_folds']:
    n_folds = min(n_folds, config['max_folds'])

print("=" * 80)
print(f"STARTING LOSO CROSS-VALIDATION")
print("=" * 80)
print(f"total folds to run: {n_folds}")
print(f"device: {device}")
print(f"model: {config['model_name']}")
print(f"batch size: {config['batch_size']}")
print(f"learning rate: {config['learning_rate']}")
print(f"max epochs per fold: {config['num_epochs']}")
print(f"early stopping patience: {config['early_stopping_patience']}")
print("=" * 80)

fold_results = []
all_predictions = []
all_labels = []
all_probabilities = []
all_subject_ids = []

start_time = datetime.now()

for fold_idx, (train_idx, test_idx) in enumerate(
    logo.split(np.arange(len(dataset)), labels, subject_ids)
):
    if config['max_folds'] and fold_idx >= config['max_folds']:
        break
    
    fold_start = datetime.now()
    
    test_subject = subject_ids[test_idx[0]]
    test_label = labels[test_idx[0]]
    label_str = "pd" if test_label == 1 else "hc"
    
    print(f"\n{'=' * 80}")
    print(f"FOLD {fold_idx + 1}/{n_folds}")
    print(f"{'=' * 80}")
    print(f"  test subject: {test_subject} ({label_str})")
    print(f"  train samples: {len(train_idx)}, test samples: {len(test_idx)}")
    print(f"  starting training...")
    
    result = train_fold(
        dataset=dataset,
        train_indices=train_idx,
        test_indices=test_idx,
        config=config,
        device=device,
        fold_idx=fold_idx,
        output_dir=output_dir
    )
    
    fold_time = (datetime.now() - fold_start).total_seconds()
    elapsed_total = datetime.now() - start_time
    avg_time_per_fold = elapsed_total.total_seconds() / (fold_idx + 1)
    remaining_folds = n_folds - (fold_idx + 1)
    eta = remaining_folds * avg_time_per_fold
    
    fold_results.append(result)
    all_predictions.extend(result['predictions'])
    all_labels.extend(result['labels'])
    all_probabilities.extend(result['probabilities'])
    all_subject_ids.extend([test_subject] * len(result['predictions']))
    
    # calculate running accuracy
    running_acc = accuracy_score(all_labels, all_predictions)
    
    print(f"\n  FOLD {fold_idx + 1} COMPLETE:")
    print(f"    fold accuracy: {result['accuracy']:.1%}")
    print(f"    fold time: {fold_time:.1f}s ({fold_time/60:.1f}m)")
    print(f"    running overall accuracy: {running_acc:.1%}")
    print(f"    time elapsed: {elapsed_total}")
    print(f"    estimated time remaining: {eta/60:.1f}m ({eta/3600:.1f}h)")
    print(f"{'=' * 80}")
    
    # mps: add brief sleep between folds to allow system memory stabilization
    if device == 'mps' and remaining_folds > 0:
        print(f"  [mps] stabilizing memory before next fold...")
        torch.mps.synchronize()
        torch.mps.empty_cache()
        gc.collect()
        time.sleep(2)

elapsed = datetime.now() - start_time
print(f"\n{'=' * 80}")
print(f"LOSO CV COMPLETE")
print(f"{'=' * 80}")
print(f"total time: {elapsed} ({elapsed.total_seconds()/60:.1f}m)")
print(f"{'=' * 80}")

STARTING LOSO CROSS-VALIDATION
total folds to run: 3
device: mps
model: facebook/wav2vec2-base-960h
batch size: 4
learning rate: 5e-05
max epochs per fold: 15
early stopping patience: 3

FOLD 1/3
  test subject: HC_elderly_AGNESE_P (hc)
  train samples: 815, test samples: 16
  starting training...
    [fold 1] creating data subsets...
    [fold 1] creating dataloaders (batch_size=4)...
    [fold 1] initializing model (facebook/wav2vec2-base-960h)...


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['projector.bias', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'classifier.bias', 'wav2vec2.masked_spec_embed', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'projector.weight', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


    [fold 1] starting training (15 epochs, 204 batches/epoch)...
      epoch 1/15: train_loss=0.6868, test_loss=0.7901, test_acc=0.0%, time=518.2s
      epoch 2/15: train_loss=0.5461, test_loss=0.4938, test_acc=81.2%, time=613.3s
      epoch 3/15: train_loss=0.3458, test_loss=0.7751, test_acc=68.8%, time=570.6s
      epoch 4/15: train_loss=0.2105, test_loss=1.4799, test_acc=62.5%, time=546.6s
      epoch 5/15: train_loss=0.2099, test_loss=0.0795, test_acc=93.8%, time=505.2s
      epoch 6/15: train_loss=0.1794, test_loss=0.0501, test_acc=100.0%, time=585.6s
      epoch 7/15: train_loss=0.1374, test_loss=0.1589, test_acc=93.8%, time=526.1s
      epoch 8/15: train_loss=0.0930, test_loss=0.1946, test_acc=93.8%, time=483.0s
      epoch 9/15: train_loss=0.0757, test_loss=0.0134, test_acc=100.0%, time=496.0s
      epoch 10/15: train_loss=0.0550, test_loss=0.3135, test_acc=87.5%, time=492.6s
      epoch 11/15: train_loss=0.0628, test_loss=0.0714, test_acc=93.8%, time=480.2s
      epoch 12/15: 

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['projector.bias', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'classifier.bias', 'wav2vec2.masked_spec_embed', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'projector.weight', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


    [fold 2] starting training (15 epochs, 204 batches/epoch)...
      epoch 1/15: train_loss=0.6891, test_loss=0.7750, test_acc=0.0%, time=460.8s
      epoch 2/15: train_loss=0.6318, test_loss=0.6903, test_acc=87.5%, time=489.1s
      epoch 3/15: train_loss=0.4510, test_loss=0.8636, test_acc=68.8%, time=471.2s
      epoch 4/15: train_loss=0.3735, test_loss=0.3949, test_acc=87.5%, time=477.8s
      epoch 5/15: train_loss=0.2417, test_loss=0.4809, test_acc=87.5%, time=474.5s
      epoch 6/15: train_loss=0.1805, test_loss=0.4367, test_acc=87.5%, time=482.2s
      epoch 7/15: train_loss=0.1502, test_loss=0.4388, test_acc=87.5%, time=632.2s
      early stopping triggered at epoch 7
    [fold 2] final evaluation...
    [fold 2] cleaning up...

  FOLD 2 COMPLETE:
    fold accuracy: 87.5%
    fold time: 3495.4s (58.3m)
    running overall accuracy: 93.8%
    time elapsed: 3:08:15.408072
    estimated time remaining: 94.1m (1.6h)

FOLD 3/3
  test subject: HC_elderly_ANGELA_G (hc)
  train sampl

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['projector.bias', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'classifier.bias', 'wav2vec2.masked_spec_embed', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'projector.weight', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


    [fold 3] starting training (15 epochs, 204 batches/epoch)...
      epoch 1/15: train_loss=0.6902, test_loss=0.7941, test_acc=0.0%, time=580.4s
      epoch 2/15: train_loss=0.6244, test_loss=0.3095, test_acc=100.0%, time=595.0s
      epoch 3/15: train_loss=0.4144, test_loss=0.1840, test_acc=93.8%, time=664.6s
      epoch 4/15: train_loss=0.3523, test_loss=0.0727, test_acc=100.0%, time=577.2s
      epoch 5/15: train_loss=0.2704, test_loss=0.0541, test_acc=100.0%, time=644.4s
      epoch 6/15: train_loss=0.1939, test_loss=0.0914, test_acc=93.8%, time=14641.8s


KeyboardInterrupt: 

## 5. aggregate results

In [None]:
# convert to numpy arrays
all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_probabilities = np.array(all_probabilities)

# overall metrics
overall_accuracy = accuracy_score(all_labels, all_predictions)
overall_precision = precision_score(all_labels, all_predictions, zero_division=0)
overall_recall = recall_score(all_labels, all_predictions, zero_division=0)
overall_f1 = f1_score(all_labels, all_predictions, zero_division=0)

try:
    overall_auc = roc_auc_score(all_labels, all_probabilities)
except:
    overall_auc = 0.5

# per-fold accuracy
fold_accuracies = [r['accuracy'] for r in fold_results]
mean_accuracy = np.mean(fold_accuracies)
std_accuracy = np.std(fold_accuracies)

# confusion matrix
cm = confusion_matrix(all_labels, all_predictions)

print("=" * 60)
print("LOSO CROSS-VALIDATION RESULTS")
print("=" * 60)
print(f"\noverall metrics (aggregated across all folds):")
print(f"  accuracy: {overall_accuracy:.1%}")
print(f"  precision: {overall_precision:.3f}")
print(f"  recall: {overall_recall:.3f}")
print(f"  f1 score: {overall_f1:.3f}")
print(f"  auc-roc: {overall_auc:.3f}")

print(f"\nper-fold statistics:")
print(f"  mean accuracy: {mean_accuracy:.1%} ± {std_accuracy:.1%}")
print(f"  min: {min(fold_accuracies):.1%}, max: {max(fold_accuracies):.1%}")

print(f"\nconfusion matrix:")
print(f"           predicted")
print(f"            hc    pd")
print(f"actual hc  {cm[0,0]:4d}  {cm[0,1]:4d}")
print(f"       pd  {cm[1,0]:4d}  {cm[1,1]:4d}")

# per-subject accuracy analysis
subject_results = {}
for subj, pred, label in zip(all_subject_ids, all_predictions, all_labels):
    if subj not in subject_results:
        subject_results[subj] = {'correct': 0, 'total': 0, 'label': label}
    subject_results[subj]['total'] += 1
    if pred == label:
        subject_results[subj]['correct'] += 1

subject_accuracies = []
subject_data = []

for subj, data in subject_results.items():
    acc = data['correct'] / data['total']
    subject_accuracies.append(acc)
    diagnosis = 'pd' if data['label'] == 1 else 'hc'
    subject_data.append({
        'subject_id': subj,
        'diagnosis': diagnosis,
        'accuracy': acc,
        'correct': data['correct'],
        'total': data['total']
    })

subject_df = pd.DataFrame(subject_data)
subject_df = subject_df.sort_values('accuracy', ascending=False)

print(f"\nper-subject accuracy:")
print(f"  mean: {np.mean(subject_accuracies):.1%}")
print(f"  median: {np.median(subject_accuracies):.1%}")
print(f"  subjects with 100% accuracy: {sum(1 for a in subject_accuracies if a == 1.0)}/{len(subject_accuracies)}")
print(f"  subjects with <50% accuracy: {sum(1 for a in subject_accuracies if a < 0.5)}/{len(subject_accuracies)}")

In [None]:
## 6. comparison with clinical baseline

In [None]:
# load clinical baseline results for comparison
baseline_path = project_root / 'results' / 'clinical_baseline_results.json'

if baseline_path.exists():
    with open(baseline_path) as f:
        baseline_results = json.load(f)
    
    clinical_acc = baseline_results['svm_results']['mean_accuracy']
    clinical_std = baseline_results['svm_results']['std_accuracy']
    
    print("=" * 60)
    print("COMPARISON WITH CLINICAL BASELINE")
    print("=" * 60)
    print(f"\nclinical baseline (svm, 17 features):")
    print(f"  accuracy: {clinical_acc:.1%} ± {clinical_std:.1%}")
    
    print(f"\nwav2vec2 (loso cv):")
    print(f"  accuracy: {overall_accuracy:.1%}")
    print(f"  per-fold mean: {mean_accuracy:.1%} ± {std_accuracy:.1%}")
    
    diff = overall_accuracy - clinical_acc
    print(f"\ndifference: {diff:+.1%}")
    
    if diff > 0:
        print("  wav2vec2 outperforms clinical baseline")
    elif diff < -0.05:
        print("  clinical baseline outperforms wav2vec2")
    else:
        print("  comparable performance")
else:
    print("clinical baseline results not found")

## 7. save results

In [None]:
# save comprehensive results
results = {
    'experiment': experiment_name,
    'config': config,
    'device': device,
    'n_folds': len(fold_results),
    'n_subjects': n_subjects,
    'n_samples': len(dataset),
    
    'overall_metrics': {
        'accuracy': float(overall_accuracy),
        'precision': float(overall_precision),
        'recall': float(overall_recall),
        'f1': float(overall_f1),
        'auc': float(overall_auc)
    },
    
    'per_fold_stats': {
        'mean_accuracy': float(mean_accuracy),
        'std_accuracy': float(std_accuracy),
        'min_accuracy': float(min(fold_accuracies)),
        'max_accuracy': float(max(fold_accuracies))
    },
    
    'confusion_matrix': cm.tolist(),
    
    'fold_results': [
        {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in r.items()}
        for r in fold_results
    ],
    
    'timestamp': datetime.now().isoformat(),
    'elapsed_time': str(elapsed)
}

# save to json
results_path = output_dir / 'wav2vec2_loso_results.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

# save subject-level results
subject_path = output_dir / 'wav2vec2_subject_accuracy.csv'
subject_df.to_csv(subject_path, index=False)

# also save to main results folder for easy access
main_results_path = project_root / 'results' / 'wav2vec2_loso_results.json'
with open(main_results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"results saved to:")
print(f"  {results_path}")
print(f"  {subject_path}")
print(f"  {main_results_path}")

## 8. visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import os

# add latex to path
os.environ['PATH'] = '/Library/TeX/texbin:' + os.environ.get('PATH', '')

# publication-quality style with latex and times new roman
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}\usepackage{amssymb}',
    'font.family': 'serif',
    'font.serif': ['Times', 'Times New Roman', 'DejaVu Serif'],
    'font.size': 10,
    'axes.titlesize': 11,
    'axes.labelsize': 10,
    'axes.linewidth': 0.8,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    'legend.fontsize': 9,
    'legend.frameon': True,
    'legend.fancybox': False,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
})

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# fold accuracy distribution
ax1 = axes[0]
ax1.bar(range(1, len(fold_accuracies) + 1), fold_accuracies, color='#3498db', alpha=0.8, edgecolor='white')
ax1.axhline(y=mean_accuracy, color='#e74c3c', linestyle='--', linewidth=2, 
            label=rf'Mean: {mean_accuracy:.1%}')
ax1.set_xlabel(r'Fold')
ax1.set_ylabel(r'Accuracy')
ax1.set_title(r'Per-Fold Accuracy (LOSO CV)')
ax1.legend(loc='lower right')
ax1.set_ylim(0, 1.05)

# confusion matrix
ax2 = axes[1]
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax2,
            xticklabels=[r'HC', r'PD'], yticklabels=[r'HC', r'PD'],
            cbar_kws={'label': 'Count'})
ax2.set_xlabel(r'Predicted')
ax2.set_ylabel(r'Actual')
ax2.set_title(rf'Confusion Matrix (Accuracy: {overall_accuracy:.1%})')

# subject accuracy histogram
ax3 = axes[2]
ax3.hist(subject_accuracies, bins=10, color='#3498db', alpha=0.8, edgecolor='white')
ax3.axvline(x=np.mean(subject_accuracies), color='#e74c3c', linestyle='--', linewidth=2,
            label=rf'Mean: {np.mean(subject_accuracies):.1%}')
ax3.set_xlabel(r'Accuracy')
ax3.set_ylabel(r'Count')
ax3.set_title(r'Per-Subject Accuracy Distribution')
ax3.legend()

plt.tight_layout()

# save to both experiment dir and main figures dir
plt.savefig(output_dir / 'wav2vec2_results_summary.png', dpi=300, bbox_inches='tight')
plt.savefig(output_dir / 'wav2vec2_results_summary.pdf', dpi=300, bbox_inches='tight')

# also save to main figures folder
main_fig_dir = project_root / 'results' / 'figures'
plt.savefig(main_fig_dir / 'fig_p3_01_wav2vec2_loso_summary.png', dpi=300, bbox_inches='tight')
plt.savefig(main_fig_dir / 'fig_p3_01_wav2vec2_loso_summary.pdf', dpi=300, bbox_inches='tight')

plt.show()

print(f"figures saved to {output_dir} and {main_fig_dir}")

In [None]:
# comparison bar chart with clinical baseline
if baseline_path.exists():
    fig, ax = plt.subplots(figsize=(8, 5))
    
    models = [r'Clinical Baseline' + '\n' + r'(SVM, 17 features)', 
              r'Wav2Vec2' + '\n' + r'(Fine-tuned)']
    accuracies = [clinical_acc, overall_accuracy]
    stds = [clinical_std, std_accuracy]
    colors = ['#2ecc71', '#3498db']
    
    bars = ax.bar(models, accuracies, yerr=stds, capsize=8, color=colors, alpha=0.8,
                  edgecolor='white', linewidth=1.5)
    
    for bar, acc, std in zip(bars, accuracies, stds):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + std + 0.02,
                rf'{acc:.1%}', ha='center', va='bottom', fontsize=11, fontweight='bold')
    
    ax.set_ylabel(r'Accuracy')
    ax.set_title(r'Model Comparison: LOSO Cross-Validation')
    ax.set_ylim(0, 1.15)
    ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5, label=r'Chance Level')
    ax.legend(loc='upper right')
    
    plt.tight_layout()
    
    # save figures
    plt.savefig(output_dir / 'model_comparison.png', dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / 'model_comparison.pdf', dpi=300, bbox_inches='tight')
    
    main_fig_dir = project_root / 'results' / 'figures'
    plt.savefig(main_fig_dir / 'fig_p3_02_model_comparison.png', dpi=300, bbox_inches='tight')
    plt.savefig(main_fig_dir / 'fig_p3_02_model_comparison.pdf', dpi=300, bbox_inches='tight')
    
    plt.show()
    
    print(f"comparison figures saved")
else:
    print("clinical baseline results not found - skipping comparison")

## 9. summary and next steps

In [None]:
print("=" * 60)
print("PHASE 3 COMPLETE: WAV2VEC2 FINE-TUNING")
print("=" * 60)

print(f"\nmodel: {config['model_name']}")
print(f"device: {device}")
print(f"loso cv folds: {len(fold_results)}")
print(f"training time: {elapsed}")

print(f"\nresults:")
print(f"  accuracy: {overall_accuracy:.1%}")
print(f"  precision: {overall_precision:.3f}")
print(f"  recall: {overall_recall:.3f}")
print(f"  f1 score: {overall_f1:.3f}")
print(f"  auc-roc: {overall_auc:.3f}")

print(f"\nper-fold accuracy: {mean_accuracy:.1%} ± {std_accuracy:.1%}")

print(f"\nnext steps:")
print(f"  1. phase 4: activation extraction (notebook 04)")
print(f"  2. phase 5: probing experiments (notebook 05)")
print(f"  3. phase 6: activation patching (notebook 06)")

print(f"\noutputs saved to: {output_dir}")

## 10. train final model for activation extraction

train a single model on all data for use in probing and patching experiments.
this model will be used to extract activations in phase 4.

In [None]:
# train final model on 80% of data (hold out 20% for testing)
train_subset, _, test_subset = dataset.get_subject_split(
    test_size=0.2,
    val_size=0.0,
    random_state=config['random_seed']
)

print(f"training final model...")
print(f"  train samples: {len(train_subset)}")
print(f"  test samples: {len(test_subset)}")

# create model
final_model = create_model(config, device)

# feature extractor and data collator
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['model_name'])
data_collator = DataCollatorWithPadding(
    feature_extractor=feature_extractor,
    max_length=int(config['max_duration'] * config['target_sr'])
)

# dataloaders
train_loader = DataLoader(
    train_subset,
    batch_size=config['batch_size'],
    shuffle=True,
    collate_fn=data_collator,
    num_workers=0,
    pin_memory=(device == 'cuda')
)

test_loader = DataLoader(
    test_subset,
    batch_size=config['batch_size'] * 2,
    shuffle=False,
    collate_fn=data_collator,
    num_workers=0,
    pin_memory=(device == 'cuda')
)

# optimizer and scheduler
optimizer = AdamW(
    [p for p in final_model.parameters() if p.requires_grad],
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

steps_per_epoch = max(1, len(train_loader) // config['gradient_accumulation_steps'])
total_steps = steps_per_epoch * config['num_epochs']
warmup_steps = int(total_steps * config['warmup_ratio'])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

scaler = torch.cuda.amp.GradScaler() if config['fp16'] and device == 'cuda' else None

# training loop
best_acc = 0
for epoch in range(config['num_epochs']):
    train_loss = train_epoch(
        final_model, train_loader, optimizer, scheduler, scaler,
        device, config['gradient_accumulation_steps']
    )
    
    test_metrics = evaluate(final_model, test_loader, device)
    
    if test_metrics['accuracy'] > best_acc:
        best_acc = test_metrics['accuracy']
        # save checkpoint
        checkpoint_path = output_dir / 'final_model'
        final_model.save_pretrained(checkpoint_path)
        feature_extractor.save_pretrained(checkpoint_path)
    
    print(f"epoch {epoch+1}/{config['num_epochs']}: loss={train_loss:.4f}, acc={test_metrics['accuracy']:.1%}")

print(f"\nfinal model saved to: {checkpoint_path}")
print(f"test accuracy: {best_acc:.1%}")