# Phase 06: Cross-Dataset Generalization Analysis

**Objective**: Test Hypothesis 3: models with higher clinical alignment generalize better across datasets.

This notebook implements a cross-dataset evaluation in which one model is trained per dataset and evaluated on every dataset (a single-source variant of the Leave-One-Dataset-Out, LODO, protocol), comprising:
- Dataset-specific model training (Italian PVS, MDVR-KCL, Arkansas)
- N×N cross-dataset evaluation matrices
- Clinical alignment analysis via layerwise probing
- Statistical correlation between alignment and generalization
- Domain shift quantification
- Publication-grade LaTeX visualizations

**Methods based on**:
- Leave-one-dataset-out CV protocol (PMC10388213)
- Spearman correlation for non-parametric analysis
- Bootstrap confidence intervals (95% CI)
- STARD-AI reporting standards

In [None]:
# make Google Drive (project code + data) available inside the Colab VM
from google.colab import drive
import sys

drive.mount('/content/drive')

# put the project root first on the import path so `src.*` resolves
PROJECT_ROOT = '/content/drive/MyDrive/pd-interpretability'
sys.path.insert(0, PROJECT_ROOT)

In [None]:
# install dependencies
# (`!` is an IPython shell escape: pip runs in the notebook VM, quietly)
!pip install -q transformers datasets librosa scipy tqdm scikit-learn praat-parselmouth statsmodels

In [None]:
# install latex for publication-grade figures
# (required because the plotting configuration sets text.usetex=True;
# `!` runs the command in the notebook VM's shell)
!apt-get install -y dvipng texlive-latex-extra texlive-fonts-recommended cm-super texlive-science
from IPython.display import clear_output
# hide the long apt log. NOTE: this also hides install errors — if the LaTeX
# self-test in the plotting cell falls back to standard fonts, rerun this cell
# without clear_output() to see what went wrong
clear_output()

In [None]:
import json
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch

warnings.filterwarnings('ignore')

# scientific computing
import scipy
from scipy import stats as scipy_stats
# wasserstein_distance is provided by scipy.stats (scipy.spatial.distance has
# no such function, so importing it from there raises ImportError)
from scipy.stats import wasserstein_distance
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, roc_auc_score
from sklearn.utils import resample

# visualization
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# verify gpu (queried once and reused)
cuda_available = torch.cuda.is_available()
print(f"GPU available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

device = 'cuda' if cuda_available else 'cpu'

# global seeds for numpy and torch (torch.manual_seed covers all devices)
np.random.seed(42)
torch.manual_seed(42)

print(f"\nPyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"SciPy version: {scipy.__version__}")

## Configure LaTeX Rendering for Publication-Grade Figures

In [None]:
# set the seaborn theme FIRST: sns.set_theme() rewrites rcParams such as the
# font family, spine visibility and font sizes, so applying it after the
# publication settings below would silently revert them to seaborn defaults.
sns.set_theme(style='whitegrid', palette='colorblind')

# configure matplotlib for latex rendering and publication quality
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "axes.labelsize": 13,
    "font.size": 13,
    "legend.fontsize": 11,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "figure.titlesize": 15,
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "axes.spines.top": False,
    "axes.spines.right": False
})

# define consistent color palette
palette = sns.color_palette('colorblind')
COLORS = {
    'HC': palette[0],  # blue
    'PD': palette[1],  # orange
    'Neutral': 'gray',
    'Primary': palette[2],  # green
    'Secondary': palette[3],  # red
}

# verify latex rendering.
# usetex failures are raised when a figure is drawn, not when text is added, so
# draw explicitly inside the try; otherwise the error would surface when the
# notebook backend renders the figure, outside this handler.
fig = None
try:
    fig, ax = plt.subplots(figsize=(6, 2))
    ax.text(0.5, 0.5, r'LaTeX Test: $\mathcal{H}_0: \rho = 0$, $\alpha = 0.05$',
            ha='center', va='center', fontsize=14)
    ax.set_title(r'\textbf{Publication-Grade Rendering Verified}')
    ax.axis('off')
    fig.canvas.draw()
    plt.show()
    print("\u2713 LaTeX rendering configured successfully")
except Exception as e:
    # discard the half-rendered figure so the backend does not try to draw it again
    if fig is not None:
        plt.close(fig)
    print(f"\u26a0\ufe0f Warning: LaTeX rendering failed: {e}")
    print("Falling back to standard fonts...")
    plt.rcParams.update({
        "text.usetex": False,
        "font.family": "sans-serif"
    })

## 1. Configuration

In [None]:
CONFIG = {
    # locations on Google Drive
    'project_path': '/content/drive/MyDrive/pd-interpretability',
    'data_path': '/content/drive/MyDrive/pd-interpretability/data',
    'output_path': '/content/drive/MyDrive/pd-interpretability/results/generalization',

    # fine-tuning hyperparameters
    'model_name': 'facebook/wav2vec2-base-960h',
    'batch_size': 8,
    'epochs': 5,
    'learning_rate': 1e-5,
    'warmup_ratio': 0.1,
    'weight_decay': 0.01,
    'freeze_feature_extractor': True,

    # evaluation: bootstrap resamples and confidence-interval level
    'bootstrap_iterations': 1000,
    'confidence_level': 0.95,

    # layer-wise probing
    'n_layers': 12,
    'hidden_size': 768,
    'probe_epochs': 100,
    'probe_lr': 0.01,
    'probe_cv_folds': 5,

    # clinical acoustic features used as alignment targets
    'clinical_features': ['jitter_local', 'shimmer_local', 'hnr_mean', 'f0_std'],

    # figure export
    'fig_format': ['pdf', 'png', 'svg'],
    'fig_dpi': 300,

    'random_seed': 42
}

output_path = Path(CONFIG['output_path'])
output_path.mkdir(parents=True, exist_ok=True)

# artefact subdirectories; their parent was created just above
figures_path = output_path / 'figures'
models_path = output_path / 'models'
for subdir in (figures_path, models_path):
    subdir.mkdir(exist_ok=True)

print("Configuration loaded:")
for label, value in (
    ("Output path", output_path),
    ("Training epochs", CONFIG['epochs']),
    ("Bootstrap iterations", CONFIG['bootstrap_iterations']),
    ("Clinical features", CONFIG['clinical_features']),
):
    print(f"  {label}: {value}")

## 2. Load All Datasets

Loading three cross-linguistic, multi-institutional datasets:
- **Italian PVS**: Italian, elderly cohort, university hospital
- **MDVR-KCL**: English (UK), King's College London
- **Arkansas**: English (US), University of Arkansas

In [None]:
from src.data import (
    ItalianPVSDataset,
    ArkansasDataset,
    MDVRKCLDataset
)

data_path = Path(CONFIG['data_path'])

# load available datasets
datasets = {}
dataset_metadata = {}

italian_path = data_path / 'raw' / 'italian_pvs'
arkansas_path = data_path / 'raw' / 'arkansas (figshare)'
mdvr_path = data_path / 'raw' / 'mdvr-kcl'

# (key, directory, loader class, display label, descriptive metadata).
# List order fixes the insertion order of `datasets`, which later cells use
# as the row/column order of every dataset-by-dataset matrix.
DATASET_SPECS = [
    ('italian_pvs', italian_path, ItalianPVSDataset, 'Italian PVS', {
        'language': 'Italian',
        'institution': 'University Hospital',
        'country': 'Italy',
        'task': 'Vowel prolongation'
    }),
    ('arkansas', arkansas_path, ArkansasDataset, 'Arkansas', {
        'language': 'English (US)',
        'institution': 'University of Arkansas',
        'country': 'USA',
        'task': 'Speech tasks'
    }),
    ('mdvr_kcl', mdvr_path, MDVRKCLDataset, 'MDVR-KCL', {
        'language': 'English (UK)',
        'institution': 'Kings College London',
        'country': 'UK',
        'task': 'Voice recordings'
    }),
]

print("Loading datasets...\n")
print("=" * 80)

for key, root, loader_cls, label, meta in DATASET_SPECS:
    if not root.exists():
        print(f"  ✗ {label} dataset not found")
        continue
    print(f"\nLoading {label} dataset...")
    datasets[key] = loader_cls(root, max_duration=10.0)
    dataset_metadata[key] = meta
    print(f"  ✓ Loaded {len(datasets[key])} samples")

print("\n" + "=" * 80)
print(f"\nTotal datasets loaded: {len(datasets)}\n")

# the cross-dataset analysis needs at least one train/test pair of distinct datasets
if len(datasets) < 2:
    raise ValueError(f"Need at least 2 datasets for cross-dataset analysis, got {len(datasets)}")

In [None]:
# comprehensive dataset statistics
print("\n" + "=" * 80)
print("DATASET STATISTICS")
print("=" * 80)

dataset_stats = {}

_DURATION_KEYS = ('mean_duration', 'std_duration', 'min_duration', 'max_duration')

for name, ds in datasets.items():
    n_total = len(ds)
    labels = []
    durations = []

    for i in range(n_total):
        sample = ds[i]
        # int() turns numpy/torch scalar labels into plain ints; the sum below
        # counts label 1 as PD
        labels.append(int(sample['label']))
        # estimate duration from input_values length (assuming 16kHz)
        durations.append(len(sample['input_values']) / 16000.0)

    n_pd = sum(labels)
    n_hc = n_total - n_pd

    # np.min / np.max raise on empty input, so an empty dataset reports NaN
    # durations instead of aborting the whole cell
    if durations:
        dur = np.asarray(durations, dtype=float)
        duration_stats = dict(zip(_DURATION_KEYS, (np.mean(dur), np.std(dur), np.min(dur), np.max(dur))))
    else:
        duration_stats = dict.fromkeys(_DURATION_KEYS, np.nan)

    stats = {
        'n_total': n_total,
        'n_pd': n_pd,
        'n_hc': n_hc,
        'pd_ratio': n_pd / n_total if n_total > 0 else 0,
        **duration_stats,
    }
    dataset_stats[name] = stats

    metadata = dataset_metadata.get(name, {})

    print(f"\n{name.upper().replace('_', ' ')}:")
    print(f"  Language: {metadata.get('language', 'Unknown')}")
    print(f"  Institution: {metadata.get('institution', 'Unknown')}")
    print(f"  Total samples: {stats['n_total']}")
    print(f"  PD: {stats['n_pd']} ({stats['pd_ratio']:.1%})")
    print(f"  HC: {stats['n_hc']} ({(1-stats['pd_ratio']):.1%})")
    print(f"  Duration: {stats['mean_duration']:.2f}±{stats['std_duration']:.2f}s (range: {stats['min_duration']:.2f}-{stats['max_duration']:.2f}s)")

print("\n" + "=" * 80)

## 3. Extract Clinical Features for All Datasets

Extract clinical acoustic features using Praat-Parselmouth for clinical alignment analysis.

In [None]:
from src.features.clinical import ClinicalFeatureExtractor
import librosa

# initialize extractor
clinical_extractor = ClinicalFeatureExtractor()

# NOTE(review): every dataset is assumed to yield 16 kHz waveforms — confirm
# against the dataset classes
sample_rate = 16000

# extract features for all datasets
clinical_data = {}

print("\nExtracting clinical features...\n")
print("=" * 80)

for dataset_name, dataset in datasets.items():
    print(f"\nProcessing {dataset_name}...")

    features_dict = {feat: [] for feat in CONFIG['clinical_features']}
    sample_ids = []

    for i in range(len(dataset)):
        if i % 50 == 0:
            print(f"  Progress: {i}/{len(dataset)} samples")

        try:
            sample = dataset[i]
            audio = sample['input_values'].numpy() if torch.is_tensor(sample['input_values']) else sample['input_values']
            features = clinical_extractor.extract_from_array(audio, sample_rate)
            # build the whole row first so a failure can never leave the
            # per-feature lists with different lengths
            row = [features.get(feat_name, np.nan) for feat_name in CONFIG['clinical_features']]
        except Exception as e:
            # best effort: a failed sample is kept as an all-NaN row (not
            # dropped) so sample ids stay aligned with dataset indices
            print(f"  Warning: Failed to extract features for sample {i}: {e}")
            row = [np.nan] * len(CONFIG['clinical_features'])

        for feat_name, value in zip(CONFIG['clinical_features'], row):
            features_dict[feat_name].append(value)
        sample_ids.append(f"{dataset_name}_{i}")

    # dtype=float turns missing values such as None into NaN, so the
    # isnan-based statistics below (and later probing) get a numeric array
    for feat_name in CONFIG['clinical_features']:
        features_dict[feat_name] = np.array(features_dict[feat_name], dtype=float)

    clinical_data[dataset_name] = {
        'features': features_dict,
        'sample_ids': sample_ids
    }

    # print statistics
    print(f"\n  Feature statistics for {dataset_name}:")
    for feat_name in CONFIG['clinical_features']:
        values = features_dict[feat_name]
        valid_values = values[~np.isnan(values)]
        if len(valid_values) > 0:
            print(f"    {feat_name:20s}: mean={np.mean(valid_values):.4f}, std={np.std(valid_values):.4f}, valid={len(valid_values)}/{len(values)}")
        else:
            print(f"    {feat_name:20s}: NO VALID VALUES")

print("\n" + "=" * 80)
print(f"\nClinical features extracted for {len(clinical_data)} datasets")

# save clinical features
clinical_features_path = Path(CONFIG['data_path']) / 'clinical_features'
clinical_features_path.mkdir(parents=True, exist_ok=True)

for dataset_name, data in clinical_data.items():
    save_dict = {
        'features': {k: v.tolist() for k, v in data['features'].items()},
        'sample_ids': data['sample_ids']
    }
    save_path = clinical_features_path / f'{dataset_name}_features.json'
    # NaN values are written as the bare token NaN (json's allow_nan default):
    # Python's json reads them back, strict JSON parsers do not
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(save_dict, f, indent=2)
    print(f"Saved: {save_path}")

## 4. Quantify Domain Shift Between Datasets

For each clinical feature, compute the pairwise one-dimensional Wasserstein distance between the per-dataset distributions of that feature to quantify dataset dissimilarity.

In [None]:
# compute pairwise Wasserstein distances
dataset_names = list(datasets.keys())
n_datasets = len(dataset_names)


def _valid_feature_values(dataset_name, feat_name):
    """Return the non-NaN values of one clinical feature for one dataset."""
    vals = clinical_data[dataset_name]['features'][feat_name]
    return vals[~np.isnan(vals)]


def _print_shift_matrix(matrix, names):
    """Print a square dataset-by-dataset matrix, showing NaN cells as N/A."""
    # each row starts with 2 spaces + an 8-char name + 2 spaces, so the header
    # needs the same 10-character lead-in to line labels up with the values
    print("  " + " " * 10 + "  ".join([f"{n[:8]:>10s}" for n in names]))
    for i, row_name in enumerate(names):
        cells = "  ".join(
            "       N/A" if np.isnan(value) else f"{value:10.4f}"
            for value in matrix[i]
        )
        print("  " + f"{row_name[:8]:8s}  " + cells)


# raw per-feature distance matrices (in each feature's own units)
domain_shift_matrices = {}
# the same distances divided by the feature's pooled std across all datasets,
# so features measured in different units can be averaged meaningfully
normalized_shift_matrices = {}

print("\nComputing domain shift (Wasserstein distances)...\n")
print("=" * 80)

for feat_name in CONFIG['clinical_features']:
    valid = [_valid_feature_values(name, feat_name) for name in dataset_names]

    shift_matrix = np.zeros((n_datasets, n_datasets))
    # the 1-D Wasserstein distance is symmetric: fill the upper triangle and mirror
    for i in range(n_datasets):
        for j in range(i + 1, n_datasets):
            if len(valid[i]) > 0 and len(valid[j]) > 0:
                dist = wasserstein_distance(valid[i], valid[j])
            else:
                dist = np.nan
            shift_matrix[i, j] = shift_matrix[j, i] = dist
    domain_shift_matrices[feat_name] = shift_matrix

    pooled = np.concatenate(valid) if valid else np.array([])
    scale = float(np.std(pooled)) if pooled.size else 0.0
    normalized_shift_matrices[feat_name] = (
        shift_matrix / scale if scale > 0 else shift_matrix.copy()
    )

    print(f"\n{feat_name.upper()}:")
    _print_shift_matrix(shift_matrix, dataset_names)

# average domain shift: mean of the std-normalized distances over features.
# A feature that is unavailable for a pair (NaN) is excluded rather than
# counted as zero shift; a pair with no usable feature stays NaN.
if normalized_shift_matrices:
    avg_domain_shift = np.nanmean(np.stack(list(normalized_shift_matrices.values())), axis=0)
else:
    avg_domain_shift = np.zeros((n_datasets, n_datasets))

print("\n" + "=" * 80)
print("\nAVERAGE DOMAIN SHIFT (mean of std-normalized distances across features):")
_print_shift_matrix(avg_domain_shift, dataset_names)

print("\n" + "=" * 80)

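## 5. Train Dataset-Specific Models (Single-Source, LODO-Style Protocol)

Train one Wav2Vec2 model per dataset on a seeded 80/20 train/validation split. Each model is later evaluated on every dataset, so each dataset serves in turn as an unseen test domain for the models trained on the other datasets.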

In [None]:
from src.models import DatasetSpecificTrainer
from torch.utils.data import random_split

# initialize trainer
trainer = DatasetSpecificTrainer(
    model_name=CONFIG['model_name'],
    num_labels=2,
    learning_rate=CONFIG['learning_rate'],
    epochs=CONFIG['epochs'],
    batch_size=CONFIG['batch_size'],
    warmup_ratio=CONFIG['warmup_ratio'],
    weight_decay=CONFIG['weight_decay'],
    freeze_feature_extractor=CONFIG['freeze_feature_extractor'],
    device=device
)

print("\nTraining dataset-specific models (LODO protocol)...\n")
print("=" * 80)

models = {}
training_histories = {}
# held-out validation split per dataset, kept so in-domain performance can be
# measured on samples the model never trained on
val_splits = {}

for dataset_name, dataset in datasets.items():
    print(f"\n{'='*80}")
    print(f"TRAINING MODEL ON: {dataset_name.upper()}")
    print(f"{'='*80}\n")

    # re-seed before every run so each model's result does not depend on how
    # many models were trained before it (covers torch/numpy randomness used
    # inside the trainer, presumably head initialization and shuffling)
    np.random.seed(CONFIG['random_seed'])
    torch.manual_seed(CONFIG['random_seed'])

    # split dataset (80-20 train-val) with a seeded generator.
    # NOTE(review): this splits by recording; if a dataset holds several
    # recordings per speaker, a speaker can land in both train and val and
    # inflate validation scores — a speaker-grouped split would avoid that.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size

    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(CONFIG['random_seed'])
    )
    val_splits[dataset_name] = val_dataset

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}\n")

    # train model
    model, metrics = trainer.train_on_dataset(
        dataset_name=dataset_name,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        output_dir=models_path
    )

    models[dataset_name] = model
    training_histories[dataset_name] = metrics

    print(f"\n{'='*80}")
    print(f"COMPLETED: {dataset_name.upper()}")
    print(f"  Best validation accuracy: {metrics['best_accuracy']:.4f}")
    print(f"  Final test accuracy: {metrics['final_accuracy']:.4f}")
    print(f"  Final F1-score: {metrics['final_f1']:.4f}")
    print(f"  Final AUC-ROC: {metrics['final_auc']:.4f}")
    print(f"{'='*80}")

print(f"\n\n{'='*80}")
print("TRAINING SUMMARY")
print(f"{'='*80}\n")
print(f"Total models trained: {len(models)}\n")

# iterate the histories themselves, in training order
for dataset_name, metrics in training_histories.items():
    print(f"{dataset_name:15s}: Acc={metrics['final_accuracy']:.3f}, F1={metrics['final_f1']:.3f}, AUC={metrics['final_auc']:.3f}")

print(f"\n{'='*80}")

## 6. Build Cross-Dataset Evaluation Matrix with Confidence Intervals

Evaluate each model on every dataset to construct an N×N performance matrix (rows: training dataset, columns: test dataset) with bootstrap confidence intervals.

In [None]:
from src.models import CrossDatasetEvaluator

RULE = "=" * 80

# NOTE(review): `datasets` are the full datasets, while each model was trained
# on an 80% split of its own dataset. Unless CrossDatasetEvaluator holds data
# out internally, the in-domain (diagonal) scores include training samples and
# are optimistic — confirm before interpreting generalization gaps.
evaluator = CrossDatasetEvaluator(datasets=datasets, device=device)

print("\nBuilding cross-dataset evaluation matrix...\n")
print(RULE)

# every model x every dataset
results = evaluator.build_evaluation_matrix(models=models, batch_size=CONFIG['batch_size'])

print("\n" + RULE)
print("Evaluation complete!")
print(RULE)

In [None]:
# display accuracy matrix
print("\n" + "=" * 80)
print("CROSS-DATASET ACCURACY MATRIX")
print("=" * 80)
print("\n(Rows = Training Dataset, Columns = Test Dataset)\n")

# header: each row begins with a 14-char name plus a 2-space gap, so the
# corner label is padded to 16 chars to line column labels up with values
header = "Train\\Test".ljust(16) + "  ".join([f"{n[:10]:>12s}" for n in dataset_names])
print(header)
print("-" * 80)

# rows; a pair missing from the matrix is shown as N/A rather than as 0.0
# accuracy, which would be indistinguishable from a real (terrible) result
for train_name in dataset_names:
    row = results.accuracy_matrix.get(train_name, {})
    row_values = []
    for test_name in dataset_names:
        acc = row.get(test_name)
        row_values.append(f"{acc:12.4f}" if acc is not None else f"{'N/A':>12s}")
    print(f"{train_name[:14]:14s}  " + "  ".join(row_values))

print("\n" + "=" * 80)

In [None]:
# compute and display generalization gaps
print("\n" + "=" * 80)
print("GENERALIZATION GAPS")
print("=" * 80)
print("\n(Gap = In-Domain Accuracy - Mean Out-of-Domain Accuracy)\n")

for train_name in dataset_names:
    row = results.accuracy_matrix.get(train_name, {})

    # a missing entry means "not evaluated", not 0% accuracy: in-domain becomes
    # NaN and missing test sets are left out of the out-of-domain statistics
    in_domain = row.get(train_name, np.nan)
    out_domain_accs = [
        row[test_name]
        for test_name in dataset_names
        if test_name != train_name and test_name in row
    ]
    out_domain_mean = np.mean(out_domain_accs) if out_domain_accs else np.nan
    out_domain_std = np.std(out_domain_accs) if out_domain_accs else np.nan

    # prefer the evaluator's own gap; fall back to the definition printed above
    gap = results.generalization_gaps.get(train_name, in_domain - out_domain_mean)

    print(f"{train_name:15s}:")
    print(f"  In-domain:     {in_domain:.4f}")
    print(f"  Out-of-domain: {out_domain_mean:.4f} ± {out_domain_std:.4f}")
    print(f"  Gap:           {gap:+.4f}")
    print()

print("=" * 80)

## 7. Extract Layer-Wise Activations for Clinical Alignment

Extract hidden states from all trained models for clinical alignment analysis via probing.

In [None]:
from src.interpretability.extraction import ActivationExtractor

extractor = ActivationExtractor(device=device)

# keyed by "<model dataset>_<eval dataset>"
activations_dict = {}
activation_sample_ids_dict = {}

print("\nExtracting layer-wise activations...\n")
print("=" * 80)

for model_name, model in models.items():
    print(f"\nExtracting activations from {model_name} model...")

    for dataset_name, dataset in datasets.items():
        print(f"  On {dataset_name} dataset...")

        # NOTE(review): the layout (samples, layers, hidden) is assumed by the
        # probing step — confirm how many layers layers='all' returns (e.g.
        # whether the CNN/embedding output is included as an extra layer)
        activations, sample_ids = extractor.extract_from_dataset(
            model=model,
            dataset=dataset,
            batch_size=CONFIG['batch_size'],
            layers='all',
            pooling='mean',  # average over the time axis
        )

        key = f"{model_name}_{dataset_name}"
        activations_dict[key] = activations
        activation_sample_ids_dict[key] = sample_ids

        print(f"    Shape: {activations.shape} (samples × layers × hidden_dim)")
        print(f"    Samples: {len(sample_ids)}")

print("\n" + "=" * 80)
print(f"\nExtracted activations for {len(activations_dict)} model-dataset combinations")
print("=" * 80)

## 8. Clinical Alignment Analysis

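Compute layer-wise probing accuracy for each clinical feature to quantify clinical alignment. Each feature is binarized at its median, and a logistic-regression probe is cross-validated on each layer's mean-pooled activations. A model's alignment score on a dataset is the mean, over features, of its best-layer probing accuracy.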

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score

def _probe_layer(X, y: np.ndarray, cv_folds: int, seed: int = 42) -> float:
    """Return the mean cross-validated accuracy of a linear probe predicting y from X.

    Standardization is fit inside each training fold (via a pipeline) so held-out
    folds never influence the scaler, and folds are shuffled so that any ordering
    of the dataset (e.g. grouped by class or speaker) does not bias the split.
    Returns 0.5 (chance) when the classes are too small for two stratified folds.
    """
    # imported here so the caller's early-exit paths do not need scikit-learn
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold, cross_val_score
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    # every stratified fold needs at least one sample of each class
    n_splits = min(cv_folds, len(y) // 2, int(np.bincount(y).min()))
    if n_splits < 2:
        return 0.5

    probe = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=1000, random_state=seed, class_weight='balanced'),
    )
    folds = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    scores = cross_val_score(probe, X, y, cv=folds, scoring='accuracy')
    return float(np.mean(scores))


def compute_layerwise_probing_accuracy(
    activations: np.ndarray,
    activation_sample_ids: List[str],
    clinical_features: Dict[str, np.ndarray],
    feature_sample_ids: List[str],
    n_layers: int = 12,
    cv_folds: int = 5
) -> Dict[str, Dict[int, float]]:
    """
    compute layer-wise probing accuracy for all clinical features.

    each feature is binarized at its median (over samples with a valid value)
    and a logistic-regression probe is cross-validated on every layer.

    args:
        activations: [n_samples, n_layers, hidden_dim]
        activation_sample_ids: sample identifiers for activations
        clinical_features: dict of feature_name -> values (NaN = missing)
        feature_sample_ids: sample identifiers for features
        n_layers: number of layers to probe; clamped to activations.shape[1]
        cv_folds: maximum number of cross-validation folds

    returns:
        dict of feature_name -> {layer_idx -> accuracy}; empty when fewer than
        20 samples can be matched by id. Features with fewer than 20 valid
        values are omitted; layers that cannot be probed (single class, probe
        failure) get 0.5.
    """
    # pair activation rows with feature rows through the shared sample ids
    sample_to_idx = {sid: i for i, sid in enumerate(feature_sample_ids)}
    pairs = [
        (act_idx, sample_to_idx[sid])
        for act_idx, sid in enumerate(activation_sample_ids)
        if sid in sample_to_idx
    ]

    if len(pairs) < 20:
        print(f"  Warning: Only {len(pairs)} matching samples found")
        return {}

    print(f"  Found {len(pairs)} matching samples")
    valid_indices = [act_idx for act_idx, _ in pairs]
    valid_feature_indices = [feat_idx for _, feat_idx in pairs]

    # never index past the layers actually present in `activations`.
    # NOTE(review): if the extractor puts the CNN/embedding output at index 0
    # (13 entries for a 12-layer model), probing 0..n_layers-1 skips the last
    # transformer layer — confirm the extractor's layer layout.
    n_layers = min(n_layers, activations.shape[1])

    results = {}

    for feat_name, feat_values in clinical_features.items():
        print(f"\n  Probing {feat_name}...")

        feat_vals = np.asarray(feat_values, dtype=float)[valid_feature_indices]

        # the NaN mask and labels do not depend on the layer: compute them once
        valid_mask = ~np.isnan(feat_vals)
        valid_feat = feat_vals[valid_mask]
        if len(valid_feat) < 20:
            print(f"    Skipping: only {len(valid_feat)} valid feature values")
            continue

        # binarize at median for probing
        y = (valid_feat > np.median(valid_feat)).astype(int)
        rows = [idx for idx, keep in zip(valid_indices, valid_mask) if keep]
        # e.g. a constant feature puts every sample in one class
        single_class = len(np.unique(y)) < 2

        layer_accuracies: Dict[int, float] = {}
        for layer_idx in range(n_layers):
            if single_class:
                layer_accuracies[layer_idx] = 0.5
                continue
            X = activations[rows, layer_idx, :]
            try:
                layer_accuracies[layer_idx] = _probe_layer(X, y, cv_folds)
            except Exception as e:
                # keep probing the other layers, but do not hide the failure
                print(f"    Warning: probe failed on layer {layer_idx}: {e}")
                layer_accuracies[layer_idx] = 0.5

        results[feat_name] = layer_accuracies

        if layer_accuracies:
            best_layer = max(layer_accuracies, key=layer_accuracies.get)
            best_acc = layer_accuracies[best_layer]
            print(f"    Best layer: L{best_layer} (acc={best_acc:.3f})")

    return results

In [None]:
# build one clinical-alignment profile per (trained model, evaluation dataset) pair
alignment_profiles = {}

print("\nComputing clinical alignment profiles...\n")
print("=" * 80)

for model_name, dataset_name in ((m, d) for m in models for d in datasets):
    key = f"{model_name}_{dataset_name}"
    print(f"\nModel: {model_name} → Dataset: {dataset_name}")

    # both activations and clinical features are required for probing
    skip_reason = None
    if key not in activations_dict:
        skip_reason = "no activations found"
    elif dataset_name not in clinical_data:
        skip_reason = "no clinical features found"
    if skip_reason is not None:
        print(f"  Skipping: {skip_reason}")
        continue

    activations = activations_dict[key]
    activation_ids = activation_sample_ids_dict[key]
    clinical_feats = clinical_data[dataset_name]['features']
    clinical_ids = clinical_data[dataset_name]['sample_ids']

    layerwise_probing = compute_layerwise_probing_accuracy(
        activations=activations,
        activation_sample_ids=activation_ids,
        clinical_features=clinical_feats,
        feature_sample_ids=clinical_ids,
        n_layers=CONFIG['n_layers'],
        cv_folds=CONFIG['probe_cv_folds']
    )

    if not layerwise_probing:
        print("  Skipping: probing failed")
        continue

    # per-feature score = accuracy of its best layer (chance if no layers)
    feature_scores = {
        feat_name: (max(layer_accs.values()) if layer_accs else 0.5)
        for feat_name, layer_accs in layerwise_probing.items()
    }

    # overall score = mean of the per-feature best-layer accuracies
    overall_alignment = np.mean(list(feature_scores.values())) if feature_scores else 0.5

    alignment_profiles[key] = {
        'model_name': model_name,
        'dataset_name': dataset_name,
        'layerwise_probing': layerwise_probing,
        'feature_scores': feature_scores,
        'overall_alignment': overall_alignment
    }

    print(f"  Overall alignment: {overall_alignment:.4f}")
    print("  Feature scores: " + ", ".join(f"{k}={v:.3f}" for k, v in feature_scores.items()))

print("\n" + "=" * 80)
print(f"\nComputed {len(alignment_profiles)} alignment profiles")
print("=" * 80)