In [None]:
# mount google drive and make the project directory importable
import sys
from google.colab import drive

drive.mount('/content/drive')

# prepended (index 0) so modules under this directory take import precedence
project_root = '/content/drive/MyDrive/pd-interpretability'
sys.path.insert(0, project_root)

In [None]:
# install dependencies
!pip install -q transformers datasets librosa scipy tqdm scikit-learn

In [None]:
import torch
import numpy as np
import json
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns

# report whether a CUDA device is visible and pick the compute device accordingly
cuda_ok = torch.cuda.is_available()
print(f"GPU available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

device = 'cuda' if cuda_ok else 'cpu'

# seed the numpy and torch global random number generators
np.random.seed(42)
torch.manual_seed(42)

## 1. Configuration

In [None]:
# storage locations on the mounted drive
_path_settings = {
    'data_path': '/content/drive/MyDrive/pd-interpretability/data',
    'output_path': '/content/drive/MyDrive/pd-interpretability/results/generalization',
}

# training parameters
_training_settings = {
    'batch_size': 8,
    'epochs': 3,  # fewer epochs per dataset for quick analysis
    'learning_rate': 1e-5,
    'warmup_ratio': 0.1,
}

# probing parameters
_probe_settings = {
    'probe_epochs': 50,
    'probe_lr': 0.01,
}

# single flat settings dict; key order follows the groups above
CONFIG = {
    **_path_settings,
    **_training_settings,
    **_probe_settings,
    # clinical features for alignment
    'clinical_features': ['jitter', 'shimmer', 'hnr', 'speech_rate'],
    'random_seed': 42,
}

# create the results directory (and any missing parents) up front
output_path = Path(CONFIG['output_path'])
output_path.mkdir(parents=True, exist_ok=True)

## 2. Load All Datasets

In [None]:
from src.data import (
    ItalianPVSDataset,
    ArkansasDataset,
    MDVRKCLDataset,
    UCIParkinsonDataset
)

data_path = Path(CONFIG['data_path'])

# one entry per corpus: (key used in `datasets`, directory name under data/raw, dataset class)
dataset_specs = [
    ('italian_pvs', 'italian_pvs', ItalianPVSDataset),
    ('arkansas', 'arkansas (figshare)', ArkansasDataset),
    ('mdvr_kcl', 'mdvr-kcl', MDVRKCLDataset),
]

# load available datasets; a corpus whose directory is absent is reported and skipped
datasets = {}

for ds_key, dir_name, dataset_cls in dataset_specs:
    raw_dir = data_path / 'raw' / dir_name
    if not raw_dir.exists():
        # say so explicitly: a silent skip would hide a mistyped or unsynced path
        print(f"skipping {ds_key}: {raw_dir} not found")
        continue
    datasets[ds_key] = dataset_cls(raw_dir, max_duration=10.0)
    print(f"loaded {ds_key}: {len(datasets[ds_key])} samples")

print(f"\ntotal datasets loaded: {len(datasets)}")

In [None]:
# per-dataset sample counts, split by label
print("\ndataset statistics:")
print("-" * 60)

for name, ds in datasets.items():
    # element 1 of every item is taken as the label
    # (assumes labels are 0/1 with 1 = PD, since their sum is reported as the PD count — TODO confirm)
    sample_labels = []
    for idx in range(len(ds)):
        sample_labels.append(ds[idx][1])
    n_pd = sum(sample_labels)
    n_hc = len(sample_labels) - n_pd
    print(f"{name:15s}: {len(ds):4d} samples (PD: {n_pd:3d}, HC: {n_hc:3d})")

## 3. Train Dataset-Specific Models

Train a separate model on each dataset.

In [None]:
from src.models import DatasetSpecificTrainer, Wav2Vec2PDClassifier
from torch.utils.data import DataLoader, random_split

# training settings taken from CONFIG and shared by every per-dataset run
trainer_kwargs = {
    'epochs': CONFIG['epochs'],
    'learning_rate': CONFIG['learning_rate'],
    'batch_size': CONFIG['batch_size'],
    'device': device,
}
trainer = DatasetSpecificTrainer(model_class=Wav2Vec2PDClassifier, **trainer_kwargs)

In [None]:
# train models on each dataset
print("training dataset-specific models...")
print("=" * 60)

models = {}
training_metrics = {}

# share of each dataset used for training (rest is validation); overridable via CONFIG
train_fraction = CONFIG.get('train_fraction', 0.8)
split_seed = CONFIG.get('random_seed', 42)

for name, ds in datasets.items():
    print(f"\ntraining on {name}...")

    # split dataset
    train_size = int(train_fraction * len(ds))
    val_size = len(ds) - train_size
    # a dedicated, freshly seeded generator keeps the split identical when this cell
    # is re-run and independent of whatever consumed the global torch RNG earlier
    split_rng = torch.Generator().manual_seed(split_seed)
    train_ds, val_ds = random_split(ds, [train_size, val_size], generator=split_rng)

    # train
    model, metrics = trainer.train(
        train_dataset=train_ds,
        val_dataset=val_ds,
        dataset_name=name
    )

    models[name] = model
    training_metrics[name] = metrics

    print(f"  final train accuracy: {metrics['train_accuracy']:.3f}")
    print(f"  final val accuracy: {metrics['val_accuracy']:.3f}")

print("\n" + "=" * 60)
print(f"trained {len(models)} models")

## 4. Build Cross-Dataset Evaluation Matrix

Evaluate each model on all datasets to build an N×N performance matrix.

In [None]:
from src.models import CrossDatasetEvaluator

evaluator = CrossDatasetEvaluator(device=device)

# score every trained model against every loaded dataset
print("building cross-dataset evaluation matrix...")

# NOTE(review): `datasets` holds the full datasets, including the samples each model
# was trained on, so same-dataset scores are presumably not purely held-out —
# confirm how evaluate_all treats the in-domain case.
eval_kwargs = {
    'models': models,
    'datasets': datasets,
    'batch_size': CONFIG['batch_size'],
}
results = evaluator.evaluate_all(**eval_kwargs)

print("\nevaluation complete!")

In [None]:
# display accuracy matrix
dataset_names = list(datasets.keys())
n_datasets = len(dataset_names)

print("\naccuracy matrix (row=train, col=test):")
print("-" * 60)

# header and rows share one layout so the columns line up:
# a fixed-width label column followed by right-aligned value columns
label_width = 13
col_width = 8
col_sep = "  "

# header
corner = "train\\test"
header = corner.ljust(label_width) + col_sep.join(
    f"{n[:col_width]:>{col_width}s}" for n in dataset_names
)
print(header)

# rows
for train_name in dataset_names:
    row_values = []
    for test_name in dataset_names:
        acc = results.accuracy_matrix.get((train_name, test_name))
        if acc is None:
            # pair was never evaluated: show that instead of a misleading 0.000
            row_values.append(f"{'n/a':>{col_width}s}")
        else:
            row_values.append(f"{acc:{col_width}.3f}")
    # label is truncated one short of the column width so a space always separates it
    print(f"{train_name[:label_width - 1]:{label_width}s}" + col_sep.join(row_values))

In [None]:
# visualize as heatmap
# pairs that were never evaluated stay NaN so they are not drawn as 0% accuracy
acc_matrix = np.full((n_datasets, n_datasets), np.nan)

for i, train_name in enumerate(dataset_names):
    for j, test_name in enumerate(dataset_names):
        acc_matrix[i, j] = results.accuracy_matrix.get((train_name, test_name), np.nan)

fig, ax = plt.subplots(figsize=(10, 8))

# colour scale spans 0.4-1.0; values outside that range are clipped to the end colours
im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0.4, vmax=1.0)

# add text annotations
for i in range(n_datasets):
    for j in range(n_datasets):
        value = acc_matrix[i, j]
        if np.isnan(value):
            ax.text(j, i, 'n/a', ha='center', va='center', color='black', fontsize=12)
            continue
        # white text on the low-accuracy cells, black elsewhere
        color = 'white' if value < 0.7 else 'black'
        ax.text(j, i, f'{value:.3f}',
                ha='center', va='center', color=color, fontsize=12)

ax.set_xticks(range(n_datasets))
ax.set_yticks(range(n_datasets))
# underscores become line breaks so long dataset names fit under the ticks
ax.set_xticklabels([n.replace('_', '\n') for n in dataset_names], fontsize=10)
ax.set_yticklabels([n.replace('_', '\n') for n in dataset_names], fontsize=10)

ax.set_xlabel('Test Dataset', fontsize=12)
ax.set_ylabel('Train Dataset', fontsize=12)
ax.set_title('Cross-Dataset Generalization Matrix\n(Accuracy)', fontsize=14)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Accuracy', fontsize=11)

plt.tight_layout()
plt.savefig(output_path / 'cross_dataset_accuracy_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# report, per training dataset, same-dataset accuracy vs. mean accuracy on the others
print("\ngeneralization gaps (in-domain - out-of-domain):")
print("-" * 60)

acc_lookup = results.accuracy_matrix

for train_name in dataset_names:
    in_domain = acc_lookup.get((train_name, train_name), 0.0)

    # accuracies on every dataset other than the one trained on (missing pairs count as 0.0)
    other_accs = []
    for test_name in dataset_names:
        if test_name == train_name:
            continue
        other_accs.append(acc_lookup.get((train_name, test_name), 0.0))

    if other_accs:
        out_domain_mean = np.mean(other_accs)
    else:
        out_domain_mean = 0.0

    # use the evaluator's stored gap when present, otherwise derive it from the values above
    gap = results.generalization_gaps.get(train_name, in_domain - out_domain_mean)

    print(f"{train_name:15s}: in-domain={in_domain:.3f}, out-of-domain={out_domain_mean:.3f}, gap={gap:+.3f}")

## 5. Clinical Alignment Analysis

Compute layerwise probing accuracy for clinical features.

In [None]:
# read the per-dataset clinical feature files that exist on disk
clinical_data = {}
features_dir = Path(CONFIG['data_path']) / 'clinical_features'

for name in datasets:
    clinical_path = features_dir / f'{name}_features.json'

    # datasets without a feature file are reported and left out of clinical_data
    if not clinical_path.exists():
        print(f"no clinical features for {name}")
        continue

    with open(clinical_path, 'r') as f:
        clinical_data[name] = json.load(f)
    print(f"loaded clinical features for {name}")

print(f"\nclinical features available for {len(clinical_data)} datasets")

In [None]:
from src.models import ClinicalAlignmentAnalyzer

# (model name, dataset name) -> alignment profile; stays empty when nothing can be probed
alignment_profiles = {}

if clinical_data:
    alignment_analyzer = ClinicalAlignmentAnalyzer(
        probe_epochs=CONFIG['probe_epochs'],
        probe_lr=CONFIG['probe_lr'],
        device=device
    )

    # only datasets that have clinical features loaded are paired with the models
    probe_targets = [
        (d_name, d) for d_name, d in datasets.items() if d_name in clinical_data
    ]

    for model_name, model in models.items():
        for dataset_name, ds in probe_targets:
            print(f"\ncomputing alignment: {model_name} model → {dataset_name} data")

            # feature values and the sample ids stored alongside them for this dataset
            clinical_entry = clinical_data[dataset_name]
            feature_values = clinical_entry['features']
            feature_sample_ids = clinical_entry['sample_ids']

            profile = alignment_analyzer.compute_alignment_profile(
                model=model,
                dataset=ds,
                clinical_features=feature_values,
                sample_ids=feature_sample_ids
            )

            alignment_profiles[(model_name, dataset_name)] = profile
            print(f"  overall alignment score: {profile.overall_alignment:.3f}")

    print(f"\ncomputed {len(alignment_profiles)} alignment profiles")
else:
    print("skipping alignment analysis - no clinical features available")

In [None]:
if len(alignment_profiles) > 0:
    # visualize layerwise alignment for each feature
    # the grid is 2x2, so at most the first four configured features are drawn
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for ax, feature in zip(axes.flat, CONFIG['clinical_features'][:4]):
        n_curves = 0
        for (model_name, dataset_name), profile in alignment_profiles.items():
            if feature not in profile.layerwise_probing_accuracy:
                continue

            layer_accs = profile.layerwise_probing_accuracy[feature]
            # NOTE(review): assumes layer_accs is keyed by the integers 0..n-1;
            # any index that is absent is plotted at the 0.5 reference level — confirm
            layers = list(range(len(layer_accs)))
            accuracies = [layer_accs.get(l, 0.5) for l in layers]

            label = f"{model_name[:6]}→{dataset_name[:6]}"
            ax.plot(layers, accuracies, marker='o', label=label, alpha=0.7)
            n_curves += 1

        # dashed reference line at 0.5
        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
        ax.set_xlabel('Layer')
        ax.set_ylabel('Probing Accuracy')
        ax.set_title(f'{feature.upper()} Encoding by Layer')
        # a legend on an axis without labelled curves only triggers a matplotlib warning
        if n_curves > 0:
            ax.legend(fontsize=8)
        ax.set_ylim(0.4, 1.0)

    plt.tight_layout()
    plt.savefig(output_path / 'clinical_alignment_layers.png', dpi=150, bbox_inches='tight')
    plt.show()

## 6. Generalization-Interpretability Correlation

Test hypothesis: Higher clinical alignment → Better generalization

In [None]:
from src.models import GeneralizationInterpretabilityAnalyzer

# data points for the correlation, one per model; defined unconditionally so that
# later cells can inspect them even when the analysis below is skipped
alignment_scores = []
generalization_scores = []
labels = []
correlation_result = None

if len(alignment_profiles) > 0:
    correlation_analyzer = GeneralizationInterpretabilityAnalyzer()

    for model_name in models.keys():
        # overall alignment score for this model, averaged over the datasets it was probed on
        model_alignments = [
            profile.overall_alignment
            for (m, d), profile in alignment_profiles.items()
            if m == model_name
        ]

        if not model_alignments:
            continue

        avg_alignment = np.mean(model_alignments)

        # out-of-domain accuracy for this model (every dataset except its training one)
        out_domain_accs = [
            results.accuracy_matrix.get((model_name, test_name), 0.0)
            for test_name in datasets.keys() if test_name != model_name
        ]

        if not out_domain_accs:
            continue

        avg_generalization = np.mean(out_domain_accs)

        alignment_scores.append(avg_alignment)
        generalization_scores.append(avg_generalization)
        labels.append(model_name)

    # a correlation is undefined for fewer than two points, so only compute it when
    # there is something to correlate (same threshold the plotting/saving cells use)
    if len(alignment_scores) >= 2:
        correlation_result = correlation_analyzer.compute_correlation(
            alignment_scores=alignment_scores,
            generalization_scores=generalization_scores
        )

        print("\ngeneralization-interpretability correlation:")
        print("-" * 60)
        print(f"spearman correlation: {correlation_result['spearman_r']:.3f}")
        print(f"p-value: {correlation_result['p_value']:.4f}")
        print(f"interpretation: {correlation_result['interpretation']}")
    else:
        print(f"skipping correlation analysis - need at least 2 models with scores, got {len(alignment_scores)}")
else:
    print("skipping correlation analysis - no alignment profiles")

In [None]:
if len(alignment_profiles) > 0 and len(alignment_scores) >= 2:
    # scatter of per-model alignment vs. out-of-domain accuracy with a fitted line
    fig, ax = plt.subplots(figsize=(10, 8))

    # one colour per model, sampled evenly from the tab10 colormap
    colors = plt.cm.tab10(np.linspace(0, 1, len(labels)))

    points = zip(alignment_scores, generalization_scores, labels, colors)
    for x_val, y_val, model_label, rgba in points:
        ax.scatter(x_val, y_val, s=150, c=[rgba], label=model_label,
                   edgecolors='black', linewidth=1.5)

    # trend line: degree-1 least-squares fit (the guard above ensures >= 2 points)
    coeffs = np.polyfit(alignment_scores, generalization_scores, 1)
    trend = np.poly1d(coeffs)
    x_line = np.linspace(min(alignment_scores), max(alignment_scores), 100)
    ax.plot(x_line, trend(x_line), 'r--', alpha=0.7, linewidth=2, label='trend')

    ax.set_xlabel('Clinical Alignment Score', fontsize=12)
    ax.set_ylabel('Out-of-Domain Accuracy', fontsize=12)
    ax.set_title(f'Generalization vs. Clinical Alignment\n(ρ = {correlation_result["spearman_r"]:.3f}, p = {correlation_result["p_value"]:.4f})', 
                 fontsize=14)
    ax.legend(loc='best')

    plt.tight_layout()
    plt.savefig(output_path / 'generalization_interpretability_correlation.png', dpi=150, bbox_inches='tight')
    plt.show()

## 7. Save All Results

In [None]:
def _json_default(obj):
    """Fallback encoder for values the json module cannot serialise natively.

    numpy scalars, numpy arrays and torch tensors are converted to plain Python
    numbers / lists so they remain numeric in the output file; any other object
    falls back to its string form.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    return str(obj)


def _pair_keys_to_str(matrix):
    """Turn a {(train, test): value} mapping into {"train_to_test": value}."""
    return {f"{k[0]}_to_{k[1]}": v for k, v in matrix.items()}


# compile all results
full_results = {
    'config': CONFIG,
    'datasets': {name: len(ds) for name, ds in datasets.items()},
    'training_metrics': training_metrics,
    'cross_dataset_evaluation': {
        # tuple keys are not valid JSON object keys, so they are flattened to strings
        'accuracy_matrix': _pair_keys_to_str(results.accuracy_matrix),
        'f1_matrix': _pair_keys_to_str(results.f1_matrix),
        'auc_matrix': _pair_keys_to_str(results.auc_matrix),
        'generalization_gaps': results.generalization_gaps
    }
}

if len(alignment_profiles) > 0:
    full_results['clinical_alignment'] = {
        f"{m}_{d}": {
            'overall_alignment': profile.overall_alignment,
            'feature_scores': profile.feature_alignment_scores
        }
        for (m, d), profile in alignment_profiles.items()
    }

# the correlation is only stored when it was computed on at least two points
if len(alignment_profiles) > 0 and len(alignment_scores) >= 2:
    full_results['generalization_interpretability_correlation'] = correlation_result

# save to json
results_path = output_path / 'generalization_results.json'
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(full_results, f, indent=2, default=_json_default)

print(f"results saved to {results_path}")

In [None]:
# summary
print("\n" + "=" * 60)
print("CROSS-DATASET GENERALIZATION ANALYSIS SUMMARY")
print("=" * 60)

print(f"\ndatasets analyzed: {list(datasets.keys())}")
print(f"total models trained: {len(models)}")

print("\nbest generalizing model:")
if results.generalization_gaps:
    # the model with the smallest in-domain minus out-of-domain gap is reported as best
    best_model = min(results.generalization_gaps.items(), key=lambda x: x[1])
    print(f"  {best_model[0]}: generalization gap = {best_model[1]:+.3f}")
else:
    # min() raises on an empty mapping, e.g. when no dataset could be loaded
    print("  n/a (no generalization gaps available)")

if len(alignment_profiles) > 0:
    print("\nhighest clinical alignment:")
    best_alignment = max(
        ((f"{m}→{d}", p.overall_alignment) for (m, d), p in alignment_profiles.items()),
        key=lambda x: x[1]
    )
    print(f"  {best_alignment[0]}: alignment = {best_alignment[1]:.3f}")

# short-circuit: alignment_scores is only looked at when alignment profiles exist
if len(alignment_profiles) > 0 and len(alignment_scores) >= 2:
    print("\ngeneralization-interpretability correlation:")
    print(f"  spearman ρ = {correlation_result['spearman_r']:.3f} (p = {correlation_result['p_value']:.4f})")
    print(f"  {correlation_result['interpretation']}")

print("\n" + "=" * 60)