In [None]:
# Publication-quality LaTeX figure setup and helpers (same as cpu notebook)
import matplotlib as mpl
import shutil

def setup_publication_figures():
    """Configure matplotlib rcParams for publication-quality figures.

    Full LaTeX rendering (``text.usetex``) is enabled only when the external
    tools matplotlib's usetex pipeline needs are on PATH; otherwise the
    built-in mathtext renderer is used with Computer Modern math fonts.

    Returns:
        bool: True if ``text.usetex`` was enabled, False if the mathtext
        fallback is in use.
    """
    # matplotlib's usetex runs the plain `latex` binary (DVI output) and needs
    # `dvipng` to rasterise text for PNG/Agg output. Checking only for
    # pdflatex/xelatex/lualatex could enable usetex on a system where every
    # PNG save then fails.
    tex_exe = shutil.which('latex')
    dvipng_exe = shutil.which('dvipng')
    latex_available = bool(tex_exe and dvipng_exe)

    # settings shared by both rendering paths
    common = {
        'font.family': 'serif',
        'axes.labelsize': 11,
        'font.size': 11,
        'axes.titlesize': 12,
        'legend.fontsize': 10,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'figure.dpi': 300,
        # embed TrueType (Type 42) rather than Type 3 fonts in PDF/PS output
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    }

    if latex_available:
        mpl.rcParams.update(common)
        mpl.rcParams.update({
            'text.usetex': True,
            'font.serif': ['Times'],
            # raw string with SINGLE backslashes: LaTeX must see \usepackage,
            # not \\ (a line break) followed by stray text
            # NOTE(review): siunitx/bm must be installed in the TeX distribution
            'text.latex.preamble': r'\usepackage{amsmath}\usepackage{siunitx}\usepackage{bm}',
        })
        print(f"latex found: {tex_exe} -> enabled text.usetex")
    else:
        mpl.rcParams.update(common)
        mpl.rcParams.update({
            'text.usetex': False,
            'mathtext.fontset': 'cm',
        })
        if tex_exe and not dvipng_exe:
            print("latex found but dvipng missing: falling back to matplotlib mathtext. Install dvipng to enable full LaTeX rendering.")
        else:
            print("latex not found: falling back to matplotlib mathtext. Install a TeX distribution to enable full LaTeX rendering.")
    return latex_available


def save_pub_fig(path_without_ext, fig=None, formats=('pdf', 'svg', 'png')):
    """Save one figure in several formats that share a base path.

    Args:
        path_without_ext: Target path without extension (str or Path). Dots
            already present in the file name are kept, e.g. ``fig_1.5`` is
            written as ``fig_1.5.pdf`` rather than ``fig_1.pdf``.
        fig: Figure to save; defaults to the current pyplot figure.
        formats: Iterable of format names such as ``('pdf', 'png')``, or a
            single format string. A leading dot (``'.png'``) is tolerated.

    Returns:
        list[str]: File names (not full paths) of the written files, in order.
    """
    from pathlib import Path as _Path

    if fig is None:
        # imported lazily: callers passing an explicit figure need no pyplot state
        import matplotlib.pyplot as _plt
        fig = _plt.gcf()
    if isinstance(formats, str):
        # a bare string would otherwise be iterated character by character
        formats = (formats,)

    base = _Path(path_without_ext)
    base.parent.mkdir(parents=True, exist_ok=True)

    saved = []
    for fmt in formats:
        ext = fmt.lstrip('.')
        # append the extension instead of using Path.with_suffix, which would
        # replace everything after the last dot in the file name
        target = base.with_name(f"{base.name}.{ext}")
        fig.savefig(target, dpi=300, bbox_inches='tight', format=ext)
        saved.append(target.name)
    return saved

In [None]:
# mount google drive so the project code and data are reachable
from google.colab import drive
drive.mount('/content/drive')

import sys

# make the project package (`src.*`) importable; the membership guard keeps
# re-runs of this cell from stacking duplicate entries onto sys.path
_project_dir = '/content/drive/MyDrive/pd-interpretability'
if _project_dir not in sys.path:
    sys.path.insert(0, _project_dir)

In [None]:
# install dependencies
!pip install -q transformers datasets librosa scipy tqdm

In [None]:
import torch
import numpy as np
import json
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm  # widget progress bars in notebooks, plain text elsewhere
import warnings

# suppress library deprecation notices for clean output
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# verify gpu (queried once and reused below)
cuda_available = torch.cuda.is_available()
print(f"gpu available: {cuda_available}")
if cuda_available:
    print(f"gpu: {torch.cuda.get_device_name(0)}")

device = 'cuda' if cuda_available else 'cpu'

# setup publication-quality figures early (setup_publication_figures is defined in the first cell)
print("setting up publication-quality figures...")
latex_enabled = setup_publication_figures()
print(f"latex rendering: {'enabled' if latex_enabled else 'disabled (using mathtext fallback)'}")

## 1. Load Fine-tuned Model and Dataset

In [None]:
# configuration: every tunable path and experiment size lives here
import random

project_root = Path('/content/drive/MyDrive/pd-interpretability')

CONFIG = {
    'model_path': project_root / 'results' / 'checkpoints' / 'best_model',
    'data_path': project_root / 'data',
    'output_path': project_root / 'results' / 'patching',
    'figures_path': project_root / 'results' / 'figures',
    'n_pairs': 100,  # number of (HC, PD) pairs for patching
    'batch_size': 8,
    'random_seed': 42
}

# create output directories
CONFIG['output_path'].mkdir(parents=True, exist_ok=True)
CONFIG['figures_path'].mkdir(parents=True, exist_ok=True)

# seed every RNG the pipeline may draw from; python's `random` is included
# because project helpers (e.g. pair sampling) might use it instead of numpy
random.seed(CONFIG['random_seed'])
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['random_seed'])

print("configuration:")
print(f"  model: {CONFIG['model_path']}")
print(f"  data: {CONFIG['data_path']}")
print(f"  output: {CONFIG['output_path']}")
print(f"  figures: {CONFIG['figures_path']}")
print(f"  random seed: {CONFIG['random_seed']}")

In [None]:
from src.models import Wav2Vec2PDClassifier
from src.data import ItalianPVSDataset

# Load the fine-tuned classifier when its checkpoint exists; otherwise fall
# back to an untrained base model so the pipeline can still be smoke-tested.
# NOTE(review): patching results computed on the fallback model are not
# meaningful for the PD analysis.
model_path = Path(CONFIG['model_path'])
have_checkpoint = model_path.exists()

if not have_checkpoint:
    print("fine-tuned model not found, loading base model")
    classifier = Wav2Vec2PDClassifier(num_labels=2)
else:
    classifier = Wav2Vec2PDClassifier.load(model_path)
    print(f"loaded model from {model_path}")

# dataset of Italian PVS recordings (max_duration presumably caps clip length in seconds — confirm)
data_root = Path(CONFIG['data_path']) / 'raw' / 'italian_pvs'
dataset = ItalianPVSDataset(root_dir=data_root, max_duration=10.0)

print(f"dataset size: {len(dataset)} samples")
encoder_layers = classifier.model.wav2vec2.encoder.layers
print(f"model layers: {len(encoder_layers)}")

## 2. Create Minimal Pairs for Patching

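Match each healthy-control (HC) sample with an acoustically similar Parkinson's disease (PD) sample using MFCC distance. If MFCC matching yields fewer than half of the requested pairs, fall back to random pairing.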

In [None]:
from src.interpretability import create_mfcc_matched_pairs, create_minimal_pairs

# build (HC, PD) minimal pairs; each pair is indexed below as
# (clean_input, corrupted_input, label[, mfcc_distance])
print("creating mfcc-matched minimal pairs...")
pairs = create_mfcc_matched_pairs(
    dataset,
    n_pairs=CONFIG['n_pairs'],
    same_task=True
)

# fall back to random matching when MFCC matching yields fewer than half the requested pairs
if len(pairs) < CONFIG['n_pairs'] // 2:
    print("falling back to random matching")
    pairs = create_minimal_pairs(dataset, n_pairs=CONFIG['n_pairs'])

print(f"created {len(pairs)} minimal pairs")

if not pairs:
    # every downstream cell indexes into these lists; fail here with a clear
    # message instead of an IndexError on pairs[0]
    raise RuntimeError("no minimal pairs could be created; check the dataset and pairing settings")

# separate components
clean_inputs = [p[0] for p in pairs]
corrupted_inputs = [p[1] for p in pairs]
labels = [p[2] for p in pairs]

# show distance distribution when pairs carry a 4th (distance) element
if len(pairs[0]) > 3:
    distances = [p[3] for p in pairs]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.hist(distances, bins=30, edgecolor='black', color='steelblue', alpha=0.7)
    ax.set_xlabel('mfcc distance', fontsize=11)
    ax.set_ylabel('count', fontsize=11)
    ax.set_title('distribution of mfcc distances in minimal pairs', fontsize=12)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    save_pub_fig(CONFIG['figures_path'] / 'fig_p3_01_mfcc_distances', fig=fig)
    plt.show()

    print(f"mfcc distance stats: mean={np.mean(distances):.3f}, std={np.std(distances):.3f}")

## 3. Layer-Level Patching

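For each transformer layer (numbered from 0, matching the plots and printed results):
- Run the model on the HC sample and cache that layer's activations
- Run the model on the matched PD sample with the cached HC activations patched in at that layer
- Measure how far the prediction shifts back toward HC (logit-difference recovery)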

In [None]:
from src.interpretability import ActivationPatcher

# wrap the classifier's underlying model so layer/head activations can be cached and patched
print("initializing activation patcher...")
patcher = ActivationPatcher(classifier.model, device=device)

layer_count = patcher.num_layers
heads_per_layer = patcher.num_heads
print(f"  number of layers: {layer_count}")
print(f"  number of attention heads per layer: {heads_per_layer}")
print(f"  hidden size: {patcher.hidden_size}")
print(f"  total attention heads: {layer_count * heads_per_layer}")

In [None]:
# layer-level activation patching over all minimal pairs
pair_count = len(clean_inputs)
print("running layer-level activation patching...")
print(f"processing {pair_count} minimal pairs across {patcher.num_layers} layers")
print("(this may take several minutes)\n")

# silence library warnings only for the duration of the patching run
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    patching_results = patcher.run_batch_patching(
        clean_inputs, corrupted_inputs, labels,
        include_heads=False,  # layer-level only; heads are patched in a later cell
    )

layer_patching = patching_results['layer_patching']

# per-layer summary table
rule = "=" * 60
print("\n" + rule)
print("layer-level patching results:")
print(rule)
for layer_idx in range(patcher.num_layers):
    layer_stats = layer_patching[layer_idx]
    mean_r, std_r = layer_stats['mean_recovery'], layer_stats['std_recovery']
    print(f"layer {layer_idx:2d}: mean recovery = {mean_r:+.3f} ± {std_r:.3f}")
print(rule)

In [None]:
# visualize layer-level patching results
RECOVERY_THRESHOLD = 0.1  # a layer counts as "important" above this mean recovery

layers = list(range(patcher.num_layers))
mean_recoveries = [layer_patching[l]['mean_recovery'] for l in layers]
std_recoveries = [layer_patching[l]['std_recovery'] for l in layers]
is_important = [r > RECOVERY_THRESHOLD for r in mean_recoveries]

fig, ax = plt.subplots(figsize=(12, 6))

bars = ax.bar(layers, mean_recoveries, yerr=std_recoveries, capsize=3,
              color='steelblue', edgecolor='black', alpha=0.8)

# highlight important layers; set_facecolor keeps the black outline
# (set_color would overwrite the edge colour as well)
for bar, important in zip(bars, is_important):
    if important:
        bar.set_facecolor('darkred')

ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
ax.axhline(y=RECOVERY_THRESHOLD, color='red', linestyle='--', linewidth=1, alpha=0.7,
           label=f'importance threshold ({RECOVERY_THRESHOLD})')

ax.set_xlabel('layer', fontsize=12)
ax.set_ylabel('logit difference recovery', fontsize=12)
# a real newline and a single TeX backslash: doubled escapes would print a
# literal "\n" and make LaTeX (usetex) rendering fail
ax.set_title('layer-level activation patching results\n(hc $\\rightarrow$ pd)', fontsize=14)
ax.set_xticks(layers)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
save_pub_fig(CONFIG['figures_path'] / 'fig_p3_02_layer_patching_results', fig=fig)
plt.show()

# identify important layers
important_layers = [l for l, important in zip(layers, is_important) if important]
print(f"\nimportant layers (recovery > {RECOVERY_THRESHOLD}): {important_layers}")
print(f"peak recovery: layer {np.argmax(mean_recoveries)} ({max(mean_recoveries):.3f})")

## 4. Attention Head-Level Patching

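For the layers identified as important (or for all layers if none pass the threshold), patch each attention head individually, using a subset of at most 50 pairs to keep the computation tractable.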

In [None]:
# run head-level patching on important layers over a capped subset of pairs
N_HEAD_PAIRS = min(50, len(clean_inputs))  # subset size for computational efficiency
target_layers = important_layers if important_layers else None  # None -> all layers

print(f"running head-level patching on layers: {important_layers if important_layers else 'all'}")
print(f"processing {N_HEAD_PAIRS} pairs (subset for computational efficiency)\n")

head_results = []
n_failed = 0

for clean, corrupted, label in tqdm(
    zip(clean_inputs[:N_HEAD_PAIRS], corrupted_inputs[:N_HEAD_PAIRS], labels[:N_HEAD_PAIRS]),
    total=N_HEAD_PAIRS,
    desc="head patching",
    ncols=100
):
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = patcher.run_head_patching(
                clean, corrupted, label,
                target_layers=target_layers
            )
        head_results.append(result)
    except Exception as e:
        # best-effort: skip the failing pair but count it; tqdm.write keeps
        # the progress bar intact where print() would break it
        n_failed += 1
        tqdm.write(f"warning: head patching failed for one pair: {e}")

print(f"\ncompleted {len(head_results)}/{N_HEAD_PAIRS} head patching experiments")
if n_failed:
    print(f"{n_failed} pair(s) failed and were skipped")

In [None]:
from src.interpretability import HeadImportanceRanking

# aggregate per-pair head patching results into a ranking; `important_heads`
# presumably holds heads whose score exceeds `threshold` — confirm in the class
head_ranking = HeadImportanceRanking.from_patching_results(
    head_results, top_k=20, threshold=0.05,
)

separator = "=" * 60
print("\n" + separator)
print("top 20 important attention heads:")
print(separator)
for rank, (layer, head, score) in enumerate(head_ranking.head_rankings[:20], start=1):
    print(f"{rank:2d}. layer {layer:2d}, head {head:2d}: recovery = {score:+.4f}")
print(separator)

n_important = len(head_ranking.important_heads)
print(f"\nidentified {n_important} important heads (threshold: 0.05)")

In [None]:
# visualize head importance as a (layer x head) heatmap
n_layers = patcher.num_layers
n_heads = patcher.num_heads

# cells for heads without a score (e.g. layers that were not patched) stay at 0
head_matrix = np.zeros((n_layers, n_heads))
for (layer, head), score in head_ranking.head_scores.items():
    head_matrix[layer, head] = score

# symmetric colour scale centred on 0: at least +/-0.2, widened when a score
# exceeds that so strong heads are not clipped to the end colour
finite_abs = np.abs(head_matrix[np.isfinite(head_matrix)])
color_limit = max(0.2, float(finite_abs.max()) if finite_abs.size else 0.0)

fig, ax = plt.subplots(figsize=(14, 8))

im = ax.imshow(head_matrix, cmap='RdBu_r', aspect='auto', vmin=-color_limit, vmax=color_limit)

ax.set_xlabel('attention head', fontsize=12)
ax.set_ylabel('layer', fontsize=12)
# real newline: a doubled '\\n' prints literally and breaks usetex rendering
ax.set_title('attention head patching results\n(logit difference recovery)', fontsize=14)

ax.set_xticks(range(n_heads))
ax.set_yticks(range(n_layers))

cbar = fig.colorbar(im, ax=ax)
cbar.set_label('recovery score', fontsize=11)

# mark important heads with black boxes centred on their heatmap cells
for layer, head in head_ranking.important_heads:
    ax.add_patch(plt.Rectangle((head - 0.5, layer - 0.5), 1, 1,
                               fill=False, edgecolor='black', linewidth=2))

plt.tight_layout()
save_pub_fig(CONFIG['figures_path'] / 'fig_p3_03_head_patching_heatmap', fig=fig)
plt.show()

print("important heads marked with black boxes")

## 5. Mean Ablation Validation

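Complement patching with mean ablation: replace each layer's activations with their mean over the dataset and measure the effect on the model's prediction. Layers that matter causally should rank highly under both methods, which the concordance analysis below checks with a Spearman rank correlation.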

In [None]:
# dataset-mean activations per layer: the baseline substituted during mean ablation
print("computing mean activations for ablation baseline...")
print("(processing 200 samples - this may take a few minutes)\n")

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    mean_acts = patcher.compute_mean_activations(dataset, max_samples=200)

print(f"computed mean activations for {len(mean_acts)} layers")
if mean_acts:
    first_shape = next(iter(mean_acts.values())).shape
else:
    first_shape = 'N/A'
print(f"mean activation shape per layer: {first_shape}")

In [None]:
# mean-ablation check: do layers that matter under patching also matter under ablation?
print("validating with mean ablation...")
print("(processing 100 samples)\n")

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    ablation_results = patcher.validate_with_ablation(dataset, patching_results, max_samples=100)

divider = "=" * 60
print("\n" + divider)
print("ablation effects per layer:")
print(divider)
for layer_idx, effect in ablation_results['ablation_effects'].items():
    print(f"layer {layer_idx:2d}: mean effect = {effect['mean_effect']:+.4f} ± {effect['std_effect']:.4f}")
print(divider)

concordance = ablation_results['concordance']
print("\nconcordance analysis (patching vs ablation):")
print(f"  spearman correlation: {concordance['spearman_correlation']:.3f}")
print(f"  p-value: {concordance['p_value']:.4f}")
print(f"  interpretation: {concordance['interpretation']}")

In [None]:
# visualize patching vs ablation concordance
n_layers = patcher.num_layers  # re-derived so this cell does not depend on the heatmap cell
patching_scores = [layer_patching[l]['mean_recovery'] for l in range(n_layers)]
ablation_effects = ablation_results['ablation_effects']
# layers missing from the ablation results are plotted as 0
ablation_scores = [ablation_effects.get(l, {}).get('mean_effect', 0) for l in range(n_layers)]
spearman_rho = ablation_results['concordance']['spearman_correlation']

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# left panel: side-by-side bars per layer
x = np.arange(n_layers)
width = 0.35

ax1.bar(x - width/2, patching_scores, width, label='patching recovery', color='steelblue')
ax1.bar(x + width/2, ablation_scores, width, label='ablation effect', color='coral')

ax1.set_xlabel('layer', fontsize=11)
ax1.set_ylabel('effect size', fontsize=11)
ax1.set_title('patching vs. ablation effects', fontsize=12)
ax1.set_xticks(x)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3, axis='y')

# right panel: one point per layer, annotated with its index
ax2.scatter(patching_scores, ablation_scores, s=100, c='darkgreen', alpha=0.7)

for layer_idx, (px, ay) in enumerate(zip(patching_scores, ablation_scores)):
    ax2.annotate(f'L{layer_idx}', (px, ay), xytext=(5, 5), textcoords='offset points', fontsize=9)

# least-squares trend line; skipped when there are fewer than two distinct
# finite x values, where a linear fit is undefined / ill-conditioned
px_arr = np.asarray(patching_scores, dtype=float)
ay_arr = np.asarray(ablation_scores, dtype=float)
finite = np.isfinite(px_arr) & np.isfinite(ay_arr)
if np.unique(px_arr[finite]).size >= 2:
    trend = np.poly1d(np.polyfit(px_arr[finite], ay_arr[finite], 1))
    x_line = np.linspace(px_arr[finite].min(), px_arr[finite].max(), 100)
    ax2.plot(x_line, trend(x_line), 'r--', alpha=0.7, label='trend')

ax2.set_xlabel('patching recovery', fontsize=11)
ax2.set_ylabel('ablation effect', fontsize=11)
# the reported statistic is Spearman's rank correlation (rho), not Pearson's r
ax2.set_title(f'concordance (spearman $\\rho$ = {spearman_rho:.3f})', fontsize=12)
if ax2.get_legend_handles_labels()[0]:
    ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
save_pub_fig(CONFIG['figures_path'] / 'fig_p3_04_patching_ablation_concordance', fig=fig)
plt.show()

## 6. Clinical Feature Path Patching

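Test whether the causally important attention heads behave differently depending on clinical voice features. Pairs are stratified by a feature (here local jitter, a key PD voice biomarker) into low, medium, and high strata, and the patching effect of the top heads is compared across strata (high minus low).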

In [None]:
# load clinical voice features for the stratified patching analysis
clinical_path = Path(CONFIG['data_path']) / 'clinical_features' / 'italian_pvs_features.csv'
clinical_feature_cols = ['jitter_local', 'shimmer_local', 'hnr_mean', 'f0_std']

# both stay None unless at least one usable feature column is found
clinical_features = None
sample_ids = None

if clinical_path.exists():
    import pandas as pd
    clinical_df = pd.read_csv(clinical_path)

    # keep only the expected feature columns actually present in the csv
    available_features = {
        name: clinical_df[name].values for name in clinical_feature_cols if name in clinical_df.columns
    }

    if available_features:
        clinical_features = available_features
        sample_ids = clinical_df['subject_id'].values if 'subject_id' in clinical_df.columns else None
        print(f"loaded clinical features: {list(clinical_features.keys())}")
        print(f"total samples: {len(clinical_df)}")
    else:
        # an empty dict would otherwise let the stratified cell run on nothing
        print(f"none of the expected clinical feature columns found in {clinical_path}")
        print(f"expected columns: {clinical_feature_cols}")
        print("skipping stratified analysis")
else:
    print("clinical features not found at expected path")
    print(f"expected: {clinical_path}")
    print("skipping stratified analysis")

In [None]:
if clinical_features is None:
    print("skipping stratified analysis (clinical features not loaded)")
else:
    from src.interpretability import ClinicalStratifiedPatcher

    print("running clinical feature-stratified patching analysis...")
    print("(tests if heads encoding clinical features causally affect predictions)\n")

    stratified_patcher = ClinicalStratifiedPatcher(patcher, clinical_features, sample_ids)

    # local jitter (a key PD biomarker) is the only feature stratified here
    if 'jitter_local' in clinical_features:
        print("analyzing jitter_local stratification...")

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            jitter_results = stratified_patcher.run_stratified_head_patching(
                dataset,
                feature_name='jitter_local',
                # first 10 important heads; assumes important_heads is ordered by score — TODO confirm
                target_heads=head_ranking.important_heads[:10],
                n_pairs_per_stratum=15,
            )

        rule_line = "=" * 60
        print("\n" + rule_line)
        print("jitter-stratified patching results:")
        print(rule_line)
        strata_of_interest = ('low', 'medium', 'high')
        for stratum, stratum_data in jitter_results.items():
            if stratum not in strata_of_interest:
                continue
            print(f"\n{stratum.upper()} jitter stratum:")
            for head, head_stats in stratum_data.get('head_effects', {}).items():
                print(f"  {head}: recovery = {head_stats['mean_recovery']:+.4f}")

        # differential effects (high - low)
        print("\n" + rule_line)
        print("differential effects (high jitter - low jitter):")
        print(rule_line)
        for head, diff in jitter_results.get('differential_effects', {}).items():
            print(f"  {head}: {diff:+.4f}")
        print(rule_line)

## 7. Save Results

In [None]:
# compile all results for saving

def _json_key(key):
    """Convert a dict key into a type json can encode (str/int/float/bool/None)."""
    if isinstance(key, (str, int, float, bool)) or key is None:
        return key
    if isinstance(key, np.generic):
        return key.item()
    if isinstance(key, tuple):
        # e.g. (layer, head) -> "3_7"
        return "_".join(str(_json_key(k)) for k in key)
    return str(key)


def _to_jsonable(obj):
    """Recursively convert results into JSON-serializable builtins.

    numpy arrays/scalars and torch tensors become plain lists/numbers; with only
    ``default=str`` they would be written as (truncated) strings. Dict keys are
    converted too, because json's ``default`` hook is never applied to keys.
    """
    if isinstance(obj, dict):
        return {_json_key(k): _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, Path):
        return str(obj)
    return obj


full_results = {
    'config': CONFIG,
    'layer_patching': patching_results['layer_patching'],
    'head_patching': head_ranking.to_dict(),
    'ablation_validation': ablation_results,
    'important_layers': important_layers,
    'important_heads': [{'layer': l, 'head': h} for l, h in head_ranking.important_heads]
}

# jitter_results only exists when the stratified cell ran with jitter_local available
if clinical_features is not None and 'jitter_local' in clinical_features:
    full_results['clinical_stratified'] = {'jitter_local': jitter_results}

# save to json; default=str remains a last resort for any other custom objects
results_path = CONFIG['output_path'] / 'patching_results.json'
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(_to_jsonable(full_results), f, indent=2, default=str)

print(f"results saved to {results_path}")

# also save a summary text file
summary_path = CONFIG['output_path'] / 'patching_summary.txt'
with open(summary_path, 'w', encoding='utf-8') as f:
    f.write("activation patching analysis summary\n")
    f.write("="*60 + "\n\n")
    f.write(f"total minimal pairs tested: {len(pairs)}\n")
    f.write(f"important layers: {important_layers}\n")
    f.write(f"important attention heads: {len(head_ranking.important_heads)}\n")
    f.write(f"patching-ablation concordance: {ablation_results['concordance']['spearman_correlation']:.3f}\n")

print(f"summary saved to {summary_path}")

In [None]:
# final summary of the whole patching analysis
divider = "=" * 60

print("\n" + divider)
print("activation patching analysis complete")
print(divider)
print(f"\ntotal minimal pairs tested: {len(pairs)}")
print(f"\nimportant layers (recovery > 0.1): {important_layers}")
print(f"number of important attention heads: {len(head_ranking.important_heads)}")

print("\ntop 5 attention heads:")
for layer, head, score in head_ranking.head_rankings[:5]:
    print(f"  layer {layer}, head {head}: {score:+.4f}")

concordance = ablation_results['concordance']
print(f"\npatching-ablation concordance: {concordance['spearman_correlation']:.3f}")
print(f"concordance interpretation: {concordance['interpretation']}")

print("\n" + divider)
print("all results saved to:")
for kind, location in (("figures", CONFIG['figures_path']), ("data", CONFIG['output_path'])):
    print(f"  {kind}: {location}")
print(divider)