# wav2vec2 training for pd classification

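fine-tune wav2vec2-base (checkpoint `facebook/wav2vec2-base-960h`) to detect parkinson's disease (pd) from voice recordings.
a gpu runtime is strongly recommended: the code falls back to cpu (with fp16 disabled) when cuda is unavailable, but cpu training is likely to be very slow.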

## 1. setup and installation

In [None]:
import sys

!nvidia-smi

in_colab = 'google.colab' in sys.modules
print(f"running in colab: {in_colab}")

In [None]:
if in_colab:
    from google.colab import drive
    drive.mount('/content/drive')
    
    project_path = '/content/drive/MyDrive/pd-interpretability'
    
    !pip install -q -r {project_path}/requirements-colab.txt
else:
    project_path = '.'

In [None]:
import os

# switch to the project root so the relative paths used later
# (e.g. 'data/raw', 'results/checkpoints') resolve against it
os.chdir(project_path)

# make the project's `src` package importable. use the absolute path rather than
# project_path itself: a relative entry such as '.' is re-resolved against whatever
# the cwd is at import time, and only insert when it is not already first so that
# re-running this cell does not keep stacking duplicate sys.path entries
project_root = os.getcwd()
if not sys.path or sys.path[0] != project_root:
    sys.path.insert(0, project_root)

print(f"working directory: {os.getcwd()}")

In [None]:
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

from src.data import ItalianPVSDataset, MDVRKCLDataset
from src.models import (
    Wav2Vec2PDClassifier,
    PDClassifierTrainer,
    DataCollatorWithPadding,
    create_training_args,
    evaluate_model_on_dataset
)

# report the torch build and, when present, the cuda device that will be used
cuda_available = torch.cuda.is_available()
print(f"pytorch version: {torch.__version__}")
print(f"cuda available: {cuda_available}")
if cuda_available:
    print(f"cuda device: {torch.cuda.get_device_name(0)}")

## 2. configure experiment

In [None]:
# experiment configuration. the whole dict is written out as json later on,
# so the keys below keep their order in the saved files.
config = dict(
    # data selection / preprocessing
    model_name='facebook/wav2vec2-base-960h',
    dataset='italian_pvs',
    task='vowel_a',           # only passed to the italian_pvs loader
    max_duration=10.0,        # passed to the dataset loader; presumably seconds — confirm
    target_sr=16000,

    # model freezing / regularisation
    freeze_feature_extractor=True,
    freeze_encoder_layers=None,   # NOTE(review): presumably None = no encoder layers frozen — confirm
    dropout=0.1,

    # optimisation; presumably effective batch = batch_size * gradient_accumulation_steps
    num_epochs=20,
    batch_size=8,
    learning_rate=1e-4,
    warmup_ratio=0.1,
    gradient_accumulation_steps=4,

    # data split and reproducibility
    test_size=0.2,
    val_size=0.1,
    random_seed=42,
)

# unique run name: model, dataset, task and a second-resolution timestamp
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
experiment_name = "_".join(("wav2vec2", config['dataset'], config['task'], run_timestamp))
output_dir = Path('results/checkpoints') / experiment_name

print(f"experiment: {experiment_name}")
print(f"output dir: {output_dir}")

## 3. load and split dataset

In [None]:
data_root = Path('data/raw')

# loader options shared by both supported datasets
shared_loader_kwargs = dict(
    target_sr=config['target_sr'],
    max_duration=config['max_duration'],
    normalize_audio=True,
)

selected_dataset = config['dataset']
if selected_dataset == 'italian_pvs':
    # the italian pvs loader additionally selects a recording task (e.g. 'vowel_a')
    dataset = ItalianPVSDataset(
        root_dir=data_root / 'italian_pvs',
        task=config['task'],
        **shared_loader_kwargs,
    )
elif selected_dataset == 'mdvr_kcl':
    dataset = MDVRKCLDataset(
        root_dir=data_root / 'mdvr_kcl',
        **shared_loader_kwargs,
    )
else:
    raise ValueError(f"unknown dataset: {config['dataset']}")

print(f"total samples: {len(dataset)}")
print(f"subjects: {dataset.get_subject_count()}")
print(f"label distribution: {dataset.get_label_distribution()}")

In [None]:
# train/val/test split via the dataset's subject-level splitter
# (presumably each subject lands in exactly one split, avoiding speaker leakage — confirm)
train_dataset, val_dataset, test_dataset = dataset.get_subject_split(
    test_size=config['test_size'],
    val_size=config['val_size'],
    random_state=config['random_seed'],
    stratify=True
)

for split_name, split_subset in (('train', train_dataset), ('val', val_dataset), ('test', test_dataset)):
    print(f"{split_name}: {len(split_subset)} samples")

# class balance of the training split.
# NOTE(review): assumes train_dataset exposes `.indices` into `dataset` (torch Subset-like).
# indexing dataset[idx] presumably loads each audio clip just to read its label,
# which may be slow on large datasets — check whether a cheaper label lookup exists.
train_labels = [dataset[idx]['label'] for idx in train_dataset.indices]
train_pd_ratio = sum(train_labels) / len(train_labels)
print(f"\ntrain pd ratio: {train_pd_ratio:.2%}")

## 4. initialize model

In [None]:
# place the classifier on the gpu when one is available, otherwise on the cpu
target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Wav2Vec2PDClassifier(
    model_name=config['model_name'],
    num_labels=2,  # binary task: healthy vs parkinson
    freeze_feature_extractor=config['freeze_feature_extractor'],
    freeze_encoder_layers=config['freeze_encoder_layers'],
    dropout=config['dropout'],
    device=target_device
)

# summarise how much of the network will actually be trained
param_counts = model.count_parameters()
count_captions = (
    ('total parameters', 'total'),
    ('trainable parameters', 'trainable'),
    ('frozen parameters', 'frozen'),
)
for caption, count_key in count_captions:
    print(f"{caption}: {param_counts[count_key]:,}")
print(f"trainable: {param_counts['trainable_percent']:.1f}%")

## 5. setup training

In [None]:
# optimisation settings whose config keys share their names with create_training_args' parameters
schedule_kwargs = {
    name: config[name]
    for name in (
        'num_epochs',
        'batch_size',
        'learning_rate',
        'warmup_ratio',
        'gradient_accumulation_steps',
    )
}

training_args = create_training_args(
    output_dir=output_dir,
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_steps=10,
    fp16=torch.cuda.is_available(),  # mixed precision only when a cuda device is present
    seed=config['random_seed'],
    **schedule_kwargs,
)

# the collator reuses the model's own feature extractor (presumably to pad variable-length audio)
data_collator = DataCollatorWithPadding(
    feature_extractor=model.feature_extractor
)

# the validation split is handed to the trainer as its evaluation set
trainer = PDClassifierTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    training_args=training_args,
    data_collator=data_collator
)

print("trainer initialized")

## 6. train model

In [None]:
import json

# persist the exact configuration next to the checkpoints before training starts
config_path = output_dir / 'config.json'
output_dir.mkdir(parents=True, exist_ok=True)
config_path.write_text(json.dumps(config, indent=2))

print("starting training...")
train_metrics = trainer.train()

# NOTE(review): assumes PDClassifierTrainer.train() returns a mapping with a
# 'train_loss' entry — confirm against src.models
print("\ntraining complete!")
print(f"train loss: {train_metrics['train_loss']:.4f}")

## 7. evaluate model

In [None]:
val_metrics = trainer.evaluate(val_dataset)

# floats get four decimals; every other value is printed as-is
print("validation metrics:")
for metric_name, metric_value in val_metrics.items():
    rendered = f"{metric_value:.4f}" if isinstance(metric_value, float) else f"{metric_value}"
    print(f"  {metric_name}: {rendered}")

In [None]:
import numbers

# evaluate the trained model on the held-out test split
test_metrics = evaluate_model_on_dataset(
    model,
    test_dataset,
    batch_size=16
)

print("test metrics:")
for key, value in test_metrics.items():
    # print scalar metrics only; non-scalar entries (e.g. the confusion matrix)
    # are skipped. the numbers ABCs also match numpy scalar types such as
    # np.int64 / np.float32, which a plain (float, int) check silently drops.
    if isinstance(value, numbers.Integral):
        print(f"  {key}: {value}")
    elif isinstance(value, numbers.Real):
        print(f"  {key}: {value:.4f}")

In [None]:
def _to_jsonable(obj):
    """json `default` hook: convert numpy scalars/arrays into plain python values.

    Metrics may contain numpy types, which the json module cannot serialize.
    Any other unsupported object still raises TypeError, as json itself would,
    so unexpected values are not silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"object of type {type(obj).__name__} is not JSON serializable")


results_summary = {
    'config': config,
    'train_metrics': train_metrics,
    'val_metrics': val_metrics,
    # the confusion matrix is kept out of the json summary
    'test_metrics': {k: v for k, v in test_metrics.items() if k != 'confusion_matrix'}
}

results_path = output_dir / 'results.json'
# serialize before opening the file so a serialization error cannot leave a
# truncated results.json behind
results_json = json.dumps(results_summary, indent=2, default=_to_jsonable)
with open(results_path, 'w') as f:
    f.write(results_json)

print(f"results saved to {results_path}")

## 8. visualize results

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

class_names = ['healthy', 'parkinson']
fig, (cm_ax, score_ax) = plt.subplots(1, 2, figsize=(14, 5))

# left panel: test-set confusion matrix with integer counts
confusion = np.array(test_metrics['confusion_matrix'])
sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues', ax=cm_ax)
cm_ax.set(xlabel='predicted', ylabel='actual', title='confusion matrix (test set)')
cm_ax.set_xticklabels(class_names)
cm_ax.set_yticklabels(class_names)

# right panel: headline scalar metrics as horizontal bars on a 0-1 axis
score_names = ['accuracy', 'precision', 'recall', 'f1', 'auc']
scores = [test_metrics[name] for name in score_names]

score_ax.barh(score_names, scores)
score_ax.set(xlabel='score', title='test set metrics', xlim=(0, 1))

# annotate each bar with its value just to the right of the bar end
for row, score in enumerate(scores):
    score_ax.text(score + 0.01, row, f'{score:.3f}', va='center')

plt.tight_layout()
plt.savefig(output_dir / 'evaluation_results.png', dpi=300, bbox_inches='tight')
plt.show()

## 9. save final model

In [None]:
final_model_path = output_dir / 'final_model'
model.save(final_model_path)

# confirmation plus a copy-pasteable snippet for reloading the saved model
reload_instructions = [
    f"model saved to {final_model_path}",
    "\nto load this model later:",
    "  from src.models import Wav2Vec2PDClassifier",
    f"  model = Wav2Vec2PDClassifier.load('{final_model_path}')",
]
print("\n".join(reload_instructions))

## 10. training complete

next steps:
- extract activations from all layers for interpretability analysis
- run probing experiments to identify how clinical features are encoded
- perform activation patching to find causal circuits

In [None]:
print("training complete!")

# headline test metrics, each printed with its own precision
summary_templates = (
    ("\ntest accuracy: {:.1%}", 'accuracy'),
    ("test f1: {:.3f}", 'f1'),
    ("test auc: {:.3f}", 'auc'),
)
for template, metric_key in summary_templates:
    print(template.format(test_metrics[metric_key]))

print(f"\nmodel checkpoint: {final_model_path}")