# phase 3: wav2vec2 fine-tuning for pd classification

fine-tune wav2vec2-base-960h to detect parkinson's disease (pd) from voice recordings,
evaluated with leave-one-subject-out (loso) cross-validation so that no subject
appears in both the training and the test data of a fold.

**methodology:**
- loso cv: same protocol as clinical baseline (88.3% accuracy) for fair comparison
- freeze cnn feature extractor + first 4 transformer layers (small dataset)
- gradient checkpointing for memory efficiency
- early stopping to prevent overfitting

**expected results:**
- target accuracy: 80-90% (competitive with clinical baseline)
- comparing against the 17-feature clinical model shows whether deep learning adds value over the clinical features

**hardware support:**
- nvidia gpu (cuda) - recommended
- apple silicon (mps) - supported but slower
- cpu - not recommended (10-20+ hours)

## 1. setup and configuration

In [None]:
import sys
from pathlib import Path
import os

# resolve the project root: the optional PD_PROJECT_ROOT environment variable
# overrides the default mount point, so the notebook can run on another
# machine without editing this cell
project_root = Path(
    os.environ.get('PD_PROJECT_ROOT', '/Volumes/usb drive/pd-interpretability')
)

# raise explicitly instead of asserting: asserts are stripped under `python -O`
if not project_root.exists():
    raise FileNotFoundError(f"project root not found: {project_root}")

# run from the project root and put it first on sys.path
# (later cells import from the `src` package and use project-relative paths)
os.chdir(project_root)
sys.path.insert(0, str(project_root))

print(f"project root: {project_root}")
print(f"working directory: {os.getcwd()}")

In [None]:
import torch
import numpy as np
import pandas as pd
import json
import warnings
from datetime import datetime
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

# silence every python warning for the rest of the session; this keeps cell
# output compact but also hides deprecation and numerical warnings
warnings.filterwarnings('ignore')

def detect_device():
    """pick the compute backend, preferring cuda, then apple mps, then cpu.

    prints what was detected and returns a ``(device, has_accelerator)``
    tuple: ``device`` is one of ``'cuda'``, ``'mps'`` or ``'cpu'`` and
    ``has_accelerator`` is ``False`` only for the cpu fallback.
    """
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        # total memory of gpu 0, reported in decimal gigabytes
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"detected: nvidia gpu ({gpu_name})")
        print(f"vram: {vram_gb:.1f} gb")
        return 'cuda', True

    if torch.backends.mps.is_available():
        print("detected: apple silicon (mps)")
        return 'mps', True

    print("warning: no gpu detected, using cpu (very slow)")
    return 'cpu', False

# select the compute backend once; `device` and `has_accelerator` stay
# available as module-level names for the cells below
device, has_accelerator = detect_device()
print(f"pytorch version: {torch.__version__}")
print(f"device: {device}")

In [None]:
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW

from transformers import (
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    get_linear_schedule_with_warmup
)

from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)

from src.data.datasets import ItalianPVSDataset
from src.models.classifier import DataCollatorWithPadding

# reaching this line means every import in this cell resolved
print("imports complete")

In [None]:
# experiment configuration: one flat dict, later dumped to config.json
config = dict(
    # model
    model_name='facebook/wav2vec2-base-960h',
    num_labels=2,
    freeze_feature_extractor=True,
    freeze_encoder_layers=4,  # freeze first 4 transformer layers
    dropout=0.15,
    gradient_checkpointing=True,

    # audio
    max_duration=10.0,
    target_sr=16000,

    # training
    num_epochs=15,
    learning_rate=5e-5,  # lower for small dataset stability
    warmup_ratio=0.1,
    weight_decay=0.01,
    early_stopping_patience=3,

    # loso cv
    max_folds=None,  # none = all subjects, set to 5-10 for testing

    # random seed
    random_seed=42,
)

# device-specific settings: smaller per-step batches are paired with more
# accumulation steps, so the effective batch stays at 32 on every backend
_device_overrides = {
    'cuda': {'batch_size': 8, 'gradient_accumulation_steps': 4, 'fp16': True},
    # mps fp16 unstable
    'mps': {'batch_size': 4, 'gradient_accumulation_steps': 8, 'fp16': False},
}
_cpu_overrides = {'batch_size': 2, 'gradient_accumulation_steps': 16, 'fp16': False}
config.update(_device_overrides.get(device, _cpu_overrides))

effective_batch = config['batch_size'] * config['gradient_accumulation_steps']
print(f"batch size: {config['batch_size']} (effective: {effective_batch})")
print(f"learning rate: {config['learning_rate']}")
print(f"epochs: {config['num_epochs']}")
print(f"frozen layers: cnn + first {config['freeze_encoder_layers']} transformer")

## 2. load dataset

In [None]:
# raw recordings live under <project_root>/data/raw
data_root = project_root / 'data' / 'raw'

dataset = ItalianPVSDataset(
    root_dir=str(data_root / 'italian_pvs'),
    task=None,  # all tasks
    target_sr=config['target_sr'],
    max_duration=config['max_duration']
)
print(f"samples: {len(dataset)}")

# per-sample label and subject id; the subject ids are the loso groups
labels = np.array([sample['label'] for sample in dataset.samples])
subject_ids = np.array([sample['subject_id'] for sample in dataset.samples])

# one loso fold per distinct subject
unique_subjects = np.unique(subject_ids)
n_subjects = len(unique_subjects)

# class distribution (label 1 = pd, so the label sum is the pd count)
n_pd = labels.sum()
n_hc = len(labels) - n_pd
print(f"class distribution: {n_hc} hc, {n_pd} pd")
print(f"subjects: {n_subjects}")
print(f"loso cv folds: {n_subjects}")

In [None]:
# each run writes into its own timestamped folder under results/checkpoints
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
experiment_name = f"wav2vec2_loso_{timestamp}"
output_dir = project_root / 'results' / 'checkpoints' / experiment_name
output_dir.mkdir(parents=True, exist_ok=True)

# keep a copy of the exact configuration next to the run outputs
config_path = output_dir / 'config.json'
config_path.write_text(json.dumps(config, indent=2))

print(f"experiment: {experiment_name}")
print(f"output: {output_dir}")

## 3. training utilities

In [None]:
def create_model(config: dict, device: str):
    """load pretrained wav2vec2 with a new classification head and apply freezing.

    reads ``model_name``, ``num_labels``, ``dropout``,
    ``gradient_checkpointing``, ``freeze_feature_extractor`` and
    ``freeze_encoder_layers`` from ``config``; the single dropout value is
    used for the hidden, attention and final dropout settings.

    returns the model moved to ``device``.
    """
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        config['model_name'],
        num_labels=config['num_labels'],
        classifier_proj_size=256,
        hidden_dropout=config['dropout'],
        attention_dropout=config['dropout'],
        final_dropout=config['dropout'],
    )

    # trade compute for memory during the backward pass
    if config['gradient_checkpointing']:
        model.gradient_checkpointing_enable()

    # keep the convolutional front end fixed
    if config['freeze_feature_extractor']:
        model.freeze_feature_encoder()

    # switch off gradients for the lowest n transformer layers
    n_frozen = config['freeze_encoder_layers']
    if n_frozen > 0:
        for layer_idx, layer in enumerate(model.wav2vec2.encoder.layers):
            if layer_idx >= n_frozen:
                break
            for weight in layer.parameters():
                weight.requires_grad = False

    return model.to(device)


def count_parameters(model):
    """return ``(total, trainable, frozen)`` parameter element counts.

    a parameter counts as trainable when its ``requires_grad`` flag is set;
    ``frozen`` is the remainder.
    """
    n_total = 0
    n_trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        n_total += size
        if tensor.requires_grad:
            n_trainable += size
    return n_total, n_trainable, n_total - n_trainable

In [None]:
def train_epoch(model, loader, optimizer, scheduler, scaler, device, accumulation_steps):
    """train for one epoch with gradient accumulation.

    args:
        model: called as ``model(input_values, attention_mask=..., labels=...)``;
            the returned object must expose ``.loss``.
        loader: iterable of batches with ``'input_values'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: optimizer over the trainable parameters.
        scheduler: learning-rate scheduler, stepped as described below.
        scaler: ``GradScaler`` for fp16 training, or ``None`` for full precision.
        device: device the batch tensors are moved to.
        accumulation_steps: number of batches whose gradients are summed
            before one optimizer update.

    returns:
        mean un-scaled loss per batch, or ``0.0`` for an empty loader.

    gradients left over from a trailing partial accumulation group are applied
    at the end of the epoch instead of being discarded. the scheduler advances
    once per full group, plus once when the epoch is shorter than a single
    group, i.e. ``max(1, len(loader) // accumulation_steps)`` times per
    non-empty epoch, which is how the callers in this file size the schedule.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    n_full_updates = 0
    has_pending_grads = False
    optimizer.zero_grad()

    def _apply_update(advance_schedule):
        # one optimizer update from the accumulated gradients, then reset them
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        if advance_schedule:
            scheduler.step()
        optimizer.zero_grad()

    for step, batch in enumerate(loader):
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # divide by accumulation_steps so the summed gradient is an average
        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(input_values, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss / accumulation_steps
            scaler.scale(loss).backward()
        else:
            outputs = model(input_values, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss / accumulation_steps
            loss.backward()

        # undo the scaling so the reported loss is the plain per-batch loss
        total_loss += loss.item() * accumulation_steps
        n_batches += 1
        has_pending_grads = True

        if (step + 1) % accumulation_steps == 0:
            _apply_update(advance_schedule=True)
            n_full_updates += 1
            has_pending_grads = False

    # flush the trailing partial group; without this those batches never
    # influence the weights and their gradients stay on the parameters
    if has_pending_grads:
        _apply_update(advance_schedule=(n_full_updates == 0))

    return total_loss / n_batches if n_batches > 0 else 0.0

@torch.no_grad()
def evaluate(model, loader, device):
    """run the model over ``loader`` without gradients and collect metrics.

    returns a dict with:
        ``loss``: mean of the per-batch losses (``0`` for an empty loader),
        ``accuracy``: fraction of argmax predictions equal to the labels,
        ``predictions`` / ``labels``: arrays with one entry per sample,
        ``probabilities``: softmax probability of class index 1 per sample.
    """
    model.eval()
    loss_sum = 0
    predicted = []
    targets = []
    positive_probs = []

    for batch in loader:
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        batch_labels = batch['labels'].to(device)

        outputs = model(input_values, attention_mask=attention_mask, labels=batch_labels)
        loss_sum += outputs.loss.item()

        logits = outputs.logits
        class_probs = torch.softmax(logits, dim=-1)

        predicted.extend(logits.argmax(dim=-1).cpu().numpy())
        targets.extend(batch_labels.cpu().numpy())
        positive_probs.extend(class_probs[:, 1].cpu().numpy())

    n_batches = len(loader)
    mean_loss = loss_sum / n_batches if n_batches > 0 else 0

    return {
        'loss': mean_loss,
        'accuracy': accuracy_score(targets, predicted),
        'predictions': np.array(predicted),
        'labels': np.array(targets),
        'probabilities': np.array(positive_probs)
    }

In [None]:
## 4. loso cross-validation training

In [None]:
def _split_validation_subjects(dataset, train_indices, val_fraction, rng):
    """hold out whole subjects, per class, from the training indices.

    returns ``(fit_indices, val_indices)``. ``val_indices`` is empty when
    ``val_fraction`` is not positive or when no class has at least two
    subjects, so a class never loses its only training subject.
    """
    train_indices = np.asarray(train_indices)
    if val_fraction <= 0 or len(train_indices) == 0:
        return train_indices, train_indices[:0]

    groups = []
    subject_label = {}
    for idx in train_indices:
        sample = dataset.samples[int(idx)]
        groups.append(sample['subject_id'])
        subject_label.setdefault(sample['subject_id'], sample['label'])

    val_subjects = set()
    for cls in sorted(set(subject_label.values())):
        members = sorted(s for s, lab in subject_label.items() if lab == cls)
        if len(members) < 2:
            continue
        n_val = min(len(members) - 1, max(1, round(len(members) * val_fraction)))
        picked = rng.choice(len(members), size=n_val, replace=False)
        val_subjects.update(members[int(j)] for j in picked)

    is_val = np.array([g in val_subjects for g in groups], dtype=bool)
    return train_indices[~is_val], train_indices[is_val]


def train_fold(
    dataset,
    train_indices: np.ndarray,
    test_indices: np.ndarray,
    config: dict,
    device: str,
    fold_idx: int,
    output_dir: Path
) -> Dict:
    """train a fresh model on one loso fold and evaluate it on the held-out subject.

    early stopping and model selection use a validation set made of whole
    subjects carved out of ``train_indices``; the test subject is evaluated
    exactly once, after the best validation weights have been restored. the
    previous behaviour (early stopping on the test loss) leaked test
    information into the training procedure.

    args:
        dataset: dataset whose ``samples`` entries carry ``'subject_id'`` and
            ``'label'``.
        train_indices: sample indices available for training in this fold.
        test_indices: sample indices of the held-out subject.
        config: experiment configuration. optional keys:
            ``'val_subject_fraction'`` (default ``0.1``; ``0`` disables the
            validation split and early stopping) and ``'random_seed'``.
        device: device string passed to the model and batches.
        fold_idx: fold number, stored in the result and mixed into the seed.
        output_dir: currently unused; kept for interface compatibility.

    returns:
        dict with ``fold``, ``train_samples`` (samples actually trained on),
        ``val_samples``, ``test_samples``, ``accuracy``, ``predictions``,
        ``labels`` and ``probabilities`` for the test subject.
    """
    # per-fold seed so the validation split, shuffling and head
    # initialisation are reproducible across runs
    seed = config.get('random_seed')
    fold_seed = None if seed is None else seed + fold_idx
    if fold_seed is not None:
        torch.manual_seed(fold_seed)
    rng = np.random.default_rng(fold_seed)

    fit_indices, val_indices = _split_validation_subjects(
        dataset, train_indices, config.get('val_subject_fraction', 0.1), rng
    )

    # feature extractor and collator
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['model_name'])
    data_collator = DataCollatorWithPadding(
        feature_extractor=feature_extractor,
        max_length=int(config['max_duration'] * config['target_sr'])
    )

    def _make_loader(indices, batch_size, shuffle):
        # shared loader settings for the fit / validation / test subsets
        return DataLoader(
            Subset(dataset, np.asarray(indices).tolist()),
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=data_collator,
            num_workers=0,
            pin_memory=(device == 'cuda')
        )

    eval_batch_size = config['batch_size'] * 2
    train_loader = _make_loader(fit_indices, config['batch_size'], shuffle=True)
    test_loader = _make_loader(test_indices, eval_batch_size, shuffle=False)
    val_loader = (
        _make_loader(val_indices, eval_batch_size, shuffle=False)
        if len(val_indices) > 0 else None
    )

    # create fresh model
    model = create_model(config, device)

    # optimizer over the non-frozen parameters only
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # linear warmup / decay schedule sized in optimizer updates
    steps_per_epoch = max(1, len(train_loader) // config['gradient_accumulation_steps'])
    total_steps = steps_per_epoch * config['num_epochs']
    warmup_steps = int(total_steps * config['warmup_ratio'])

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # gradient scaler for fp16
    scaler = torch.cuda.amp.GradScaler() if config['fp16'] and device == 'cuda' else None

    # training loop with early stopping on the validation loss
    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0

    for epoch in range(config['num_epochs']):
        train_epoch(
            model, train_loader, optimizer, scheduler, scaler,
            device, config['gradient_accumulation_steps']
        )

        if val_loader is None:
            continue  # no validation data: train for the full epoch budget

        val_loss = evaluate(model, val_loader, device)['loss']
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # snapshot on cpu so the copy does not occupy accelerator memory
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= config['early_stopping_patience']:
                break

    # evaluate the best validation checkpoint, not the last epoch
    if best_state is not None:
        model.load_state_dict(best_state)

    # single, final evaluation on the held-out subject
    final_metrics = evaluate(model, test_loader, device)

    # cleanup so consecutive folds do not accumulate accelerator memory
    del model, optimizer, scheduler, scaler, best_state
    if device == 'cuda':
        torch.cuda.empty_cache()
    elif device == 'mps':
        torch.mps.empty_cache()

    return {
        'fold': fold_idx,
        'train_samples': len(fit_indices),
        'val_samples': len(val_indices),
        'test_samples': len(test_indices),
        'accuracy': final_metrics['accuracy'],
        'predictions': final_metrics['predictions'],
        'labels': final_metrics['labels'],
        'probabilities': final_metrics['probabilities']
    }

In [None]:
# run loso cross-validation: every fold holds out all samples of one subject
logo = LeaveOneGroupOut()
n_folds = logo.get_n_splits(groups=subject_ids)

# optional cap on the number of folds
if config['max_folds']:
    n_folds = min(n_folds, config['max_folds'])

print(f"running loso cv: {n_folds} folds")
print(f"device: {device}")
print("=" * 60)

# per-fold result dicts plus flat, sample-level accumulators
fold_results = []
all_predictions, all_labels = [], []
all_probabilities, all_subject_ids = [], []

start_time = datetime.now()

fold_splits = logo.split(np.arange(len(dataset)), labels, subject_ids)
for fold_idx, (train_idx, test_idx) in enumerate(fold_splits):
    if config['max_folds'] and fold_idx >= config['max_folds']:
        break

    # all test samples share one subject, so the first one identifies the fold
    test_subject = subject_ids[test_idx[0]]
    test_label = labels[test_idx[0]]
    label_str = "pd" if test_label == 1 else "hc"

    print(f"\nfold {fold_idx + 1}/{n_folds}: {test_subject} ({label_str})")
    print(f"  train: {len(train_idx)}, test: {len(test_idx)}")

    result = train_fold(
        dataset=dataset,
        train_indices=train_idx,
        test_indices=test_idx,
        config=config,
        device=device,
        fold_idx=fold_idx,
        output_dir=output_dir
    )
    fold_results.append(result)

    # append this fold's sample-level outputs to the pooled lists
    n_fold_predictions = len(result['predictions'])
    all_predictions.extend(result['predictions'])
    all_labels.extend(result['labels'])
    all_probabilities.extend(result['probabilities'])
    all_subject_ids.extend([test_subject] * n_fold_predictions)

    print(f"  accuracy: {result['accuracy']:.1%}")

elapsed = datetime.now() - start_time
print(f"\n{'=' * 60}")
print(f"loso cv complete in {elapsed}")

## 5. aggregate results

In [None]:
# convert the pooled per-sample lists to numpy arrays
all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_probabilities = np.array(all_probabilities)

# overall metrics on the predictions pooled across all folds
overall_accuracy = accuracy_score(all_labels, all_predictions)
overall_precision = precision_score(all_labels, all_predictions, zero_division=0)
overall_recall = recall_score(all_labels, all_predictions, zero_division=0)
overall_f1 = f1_score(all_labels, all_predictions, zero_division=0)

# roc auc raises ValueError when it is undefined (e.g. only one class among
# the pooled labels); catch only that and say so, instead of swallowing
# every exception silently
try:
    overall_auc = roc_auc_score(all_labels, all_probabilities)
except ValueError as err:
    print(f"warning: auc-roc could not be computed ({err}); reporting 0.5")
    overall_auc = 0.5

# per-fold accuracy
fold_accuracies = [r['accuracy'] for r in fold_results]
mean_accuracy = np.mean(fold_accuracies)
std_accuracy = np.std(fold_accuracies)

# confusion matrix; labels are pinned to [0, 1] so the matrix is always 2x2,
# even when a class is missing from the evaluated folds (e.g. a small max_folds)
cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])

print("=" * 60)
print("LOSO CROSS-VALIDATION RESULTS")
print("=" * 60)
print(f"\noverall metrics (aggregated across all folds):")
print(f"  accuracy: {overall_accuracy:.1%}")
print(f"  precision: {overall_precision:.3f}")
print(f"  recall: {overall_recall:.3f}")
print(f"  f1 score: {overall_f1:.3f}")
print(f"  auc-roc: {overall_auc:.3f}")

print(f"\nper-fold statistics:")
print(f"  mean accuracy: {mean_accuracy:.1%} ± {std_accuracy:.1%}")
print(f"  min: {min(fold_accuracies):.1%}, max: {max(fold_accuracies):.1%}")

print(f"\nconfusion matrix:")
print(f"           predicted")
print(f"            hc    pd")
print(f"actual hc  {cm[0,0]:4d}  {cm[0,1]:4d}")
print(f"       pd  {cm[1,0]:4d}  {cm[1,1]:4d}")

# per-subject accuracy: tally correct / total predictions for each subject
subject_results = {}
for subj, pred, label in zip(all_subject_ids, all_predictions, all_labels):
    if subj not in subject_results:
        subject_results[subj] = {'correct': 0, 'total': 0, 'label': label}
    subject_results[subj]['total'] += 1
    if pred == label:
        subject_results[subj]['correct'] += 1

subject_accuracies = []
subject_data = []

for subj, data in subject_results.items():
    acc = data['correct'] / data['total']
    subject_accuracies.append(acc)
    diagnosis = 'pd' if data['label'] == 1 else 'hc'
    subject_data.append({
        'subject_id': subj,
        'diagnosis': diagnosis,
        'accuracy': acc,
        'correct': data['correct'],
        'total': data['total']
    })

# one row per subject, best-classified subjects first
subject_df = pd.DataFrame(subject_data)
subject_df = subject_df.sort_values('accuracy', ascending=False)

print(f"\nper-subject accuracy:")
print(f"  mean: {np.mean(subject_accuracies):.1%}")
print(f"  median: {np.median(subject_accuracies):.1%}")
print(f"  subjects with 100% accuracy: {sum(1 for a in subject_accuracies if a == 1.0)}/{len(subject_accuracies)}")
print(f"  subjects with <50% accuracy: {sum(1 for a in subject_accuracies if a < 0.5)}/{len(subject_accuracies)}")

In [None]:
## 6. comparison with clinical baseline

In [None]:
# load clinical baseline results for comparison
baseline_path = project_root / 'results' / 'clinical_baseline_results.json'

# absolute accuracy gap below which the two models are reported as comparable.
# applied symmetrically: previously any positive gap counted as a win for
# wav2vec2, while a deficit of up to 5 points still counted as "comparable"
comparison_margin = 0.05

if baseline_path.exists():
    with open(baseline_path) as f:
        baseline_results = json.load(f)

    clinical_acc = baseline_results['svm_results']['mean_accuracy']
    clinical_std = baseline_results['svm_results']['std_accuracy']

    print("=" * 60)
    print("COMPARISON WITH CLINICAL BASELINE")
    print("=" * 60)
    print(f"\nclinical baseline (svm, 17 features):")
    print(f"  accuracy: {clinical_acc:.1%} ± {clinical_std:.1%}")

    print(f"\nwav2vec2 (loso cv):")
    print(f"  accuracy: {overall_accuracy:.1%}")
    print(f"  per-fold mean: {mean_accuracy:.1%} ± {std_accuracy:.1%}")

    # positive diff means wav2vec2 is ahead of the baseline
    diff = overall_accuracy - clinical_acc
    print(f"\ndifference: {diff:+.1%}")

    if diff > comparison_margin:
        print("  wav2vec2 outperforms clinical baseline")
    elif diff < -comparison_margin:
        print("  clinical baseline outperforms wav2vec2")
    else:
        print("  comparable performance")
else:
    print("clinical baseline results not found")

## 7. save results

In [None]:
# per-fold results with numpy arrays turned into plain lists for json
serializable_folds = []
for fold in fold_results:
    serializable_folds.append({
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in fold.items()
    })

# everything needed to reproduce and interpret the run, in one document
results = {
    'experiment': experiment_name,
    'config': config,
    'device': device,
    'n_folds': len(fold_results),
    'n_subjects': n_subjects,
    'n_samples': len(dataset),

    'overall_metrics': {
        'accuracy': float(overall_accuracy),
        'precision': float(overall_precision),
        'recall': float(overall_recall),
        'f1': float(overall_f1),
        'auc': float(overall_auc)
    },

    'per_fold_stats': {
        'mean_accuracy': float(mean_accuracy),
        'std_accuracy': float(std_accuracy),
        'min_accuracy': float(min(fold_accuracies)),
        'max_accuracy': float(max(fold_accuracies))
    },

    'confusion_matrix': cm.tolist(),

    'fold_results': serializable_folds,

    'timestamp': datetime.now().isoformat(),
    'elapsed_time': str(elapsed)
}

# serialize once; the same text goes to the run folder and the shared folder
results_json = json.dumps(results, indent=2)

results_path = output_dir / 'wav2vec2_loso_results.json'
results_path.write_text(results_json)

# subject-level table
subject_path = output_dir / 'wav2vec2_subject_accuracy.csv'
subject_df.to_csv(subject_path, index=False)

# copy in the main results folder for easy access (replaced on every run)
main_results_path = project_root / 'results' / 'wav2vec2_loso_results.json'
main_results_path.write_text(results_json)

print(f"results saved to:")
for saved_path in (results_path, subject_path, main_results_path):
    print(f"  {saved_path}")

## 8. visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# publication style: serif fonts, small text, high-resolution output
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 10,
    'axes.titlesize': 11,
    'axes.labelsize': 10,
    'figure.dpi': 150,
    'savefig.dpi': 300
})

# three panels side by side: per-fold bars, confusion matrix, subject histogram
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
ax1, ax2, ax3 = axes[0], axes[1], axes[2]

# panel 1: accuracy of every fold with the mean as a dashed reference line
fold_positions = range(len(fold_accuracies))
ax1.bar(fold_positions, fold_accuracies, color='steelblue', alpha=0.7)
ax1.axhline(
    y=mean_accuracy, color='red', linestyle='--', linewidth=2,
    label=f'mean: {mean_accuracy:.1%}'
)
ax1.set_xlabel('fold')
ax1.set_ylabel('accuracy')
ax1.set_title('per-fold accuracy')
ax1.legend()
ax1.set_ylim(0, 1.05)

# panel 2: pooled confusion matrix annotated with raw counts
class_names = ['hc', 'pd']
sns.heatmap(
    cm, annot=True, fmt='d', cmap='Blues', ax=ax2,
    xticklabels=class_names, yticklabels=class_names
)
ax2.set_xlabel('predicted')
ax2.set_ylabel('actual')
ax2.set_title(f'confusion matrix (acc: {overall_accuracy:.1%})')

# panel 3: distribution of per-subject accuracy with its mean marked
mean_subject_accuracy = np.mean(subject_accuracies)
ax3.hist(subject_accuracies, bins=10, color='steelblue', alpha=0.7, edgecolor='white')
ax3.axvline(
    x=mean_subject_accuracy, color='red', linestyle='--',
    label=f'mean: {mean_subject_accuracy:.1%}'
)
ax3.set_xlabel('accuracy')
ax3.set_ylabel('count')
ax3.set_title('per-subject accuracy distribution')
ax3.legend()

summary_figure_path = output_dir / 'wav2vec2_results_summary.png'
plt.tight_layout()
plt.savefig(summary_figure_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"figure saved to {summary_figure_path}")

In [None]:
# comparison bar chart; drawn only when the clinical baseline file exists
# (uses clinical_acc / clinical_std loaded by the comparison cell)
if baseline_path.exists():
    fig, ax = plt.subplots(figsize=(8, 5))

    models = ['clinical baseline\n(svm, 17 features)', 'wav2vec2\n(fine-tuned)']
    accuracies = [clinical_acc, overall_accuracy]
    stds = [clinical_std, std_accuracy]
    colors = ['#2ecc71', '#3498db']

    bars = ax.bar(models, accuracies, yerr=stds, capsize=8, color=colors, alpha=0.8)

    # write each accuracy as a percentage just above its bar
    for bar, acc in zip(bars, accuracies):
        label_x = bar.get_x() + bar.get_width()/2
        label_y = bar.get_height() + 0.02
        ax.text(
            label_x, label_y, f'{acc:.1%}',
            ha='center', va='bottom', fontsize=12, fontweight='bold'
        )

    ax.set_ylabel('accuracy')
    ax.set_title('model comparison: loso cross-validation')
    ax.set_ylim(0, 1.1)
    # dotted reference line at 50% accuracy
    ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5, label='chance')
    ax.legend()

    plt.tight_layout()
    plt.savefig(output_dir / 'model_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"comparison figure saved")

## 9. summary and next steps

In [None]:
# final run summary, assembled line by line and printed in one call
banner = "=" * 60
summary_lines = [
    banner,
    "PHASE 3 COMPLETE: WAV2VEC2 FINE-TUNING",
    banner,

    f"\nmodel: {config['model_name']}",
    f"device: {device}",
    f"loso cv folds: {len(fold_results)}",
    f"training time: {elapsed}",

    f"\nresults:",
    f"  accuracy: {overall_accuracy:.1%}",
    f"  precision: {overall_precision:.3f}",
    f"  recall: {overall_recall:.3f}",
    f"  f1 score: {overall_f1:.3f}",
    f"  auc-roc: {overall_auc:.3f}",

    f"\nper-fold accuracy: {mean_accuracy:.1%} ± {std_accuracy:.1%}",

    f"\nnext steps:",
    f"  1. phase 4: activation extraction (notebook 04)",
    f"  2. phase 5: probing experiments (notebook 05)",
    f"  3. phase 6: activation patching (notebook 06)",

    f"\noutputs saved to: {output_dir}",
]
print("\n".join(summary_lines))

## 10. train final model for activation extraction

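train a single model for use in the probing and patching experiments.
the code below trains on 80% of the data and holds out the remaining 20% for testing
(it does not train on all data); the saved model is used to extract activations in phase 4.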

In [None]:
# train final model on 80% of data (hold out 20% for testing)
train_subset, _, test_subset = dataset.get_subject_split(
    test_size=0.2,
    val_size=0.0,
    random_state=config['random_seed']
)

print(f"training final model...")
print(f"  train samples: {len(train_subset)}")
print(f"  test samples: {len(test_subset)}")

# create model
final_model = create_model(config, device)

# feature extractor and data collator
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['model_name'])
data_collator = DataCollatorWithPadding(
    feature_extractor=feature_extractor,
    max_length=int(config['max_duration'] * config['target_sr'])
)

# dataloaders
train_loader = DataLoader(
    train_subset,
    batch_size=config['batch_size'],
    shuffle=True,
    collate_fn=data_collator,
    num_workers=0,
    pin_memory=(device == 'cuda')
)

test_loader = DataLoader(
    test_subset,
    batch_size=config['batch_size'] * 2,
    shuffle=False,
    collate_fn=data_collator,
    num_workers=0,
    pin_memory=(device == 'cuda')
)

# optimizer and scheduler
optimizer = AdamW(
    [p for p in final_model.parameters() if p.requires_grad],
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

steps_per_epoch = max(1, len(train_loader) // config['gradient_accumulation_steps'])
total_steps = steps_per_epoch * config['num_epochs']
warmup_steps = int(total_steps * config['warmup_ratio'])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

scaler = torch.cuda.amp.GradScaler() if config['fp16'] and device == 'cuda' else None

# checkpoint location, defined before the loop so the final report can never
# reference an undefined name
checkpoint_path = output_dir / 'final_model'

# training loop
# starting from -inf guarantees the first epoch is checkpointed, so a final
# model exists on disk even if accuracy never rises above zero
# NOTE(review): the checkpoint is selected on the test split, so the accuracy
# reported below is optimistic and should not be quoted as a held-out result
best_acc = float('-inf')
checkpoint_saved = False
for epoch in range(config['num_epochs']):
    train_loss = train_epoch(
        final_model, train_loader, optimizer, scheduler, scaler,
        device, config['gradient_accumulation_steps']
    )

    test_metrics = evaluate(final_model, test_loader, device)

    if test_metrics['accuracy'] > best_acc:
        best_acc = test_metrics['accuracy']
        # save checkpoint (model weights + matching feature extractor)
        final_model.save_pretrained(checkpoint_path)
        feature_extractor.save_pretrained(checkpoint_path)
        checkpoint_saved = True

    print(f"epoch {epoch+1}/{config['num_epochs']}: loss={train_loss:.4f}, acc={test_metrics['accuracy']:.1%}")

if checkpoint_saved:
    print(f"\nfinal model saved to: {checkpoint_path}")
    print(f"test accuracy: {best_acc:.1%}")
else:
    print("\nwarning: no training epoch ran, so no final model was saved")