In [None]:
# ============================================================
# phase 4: activation extraction for mechanistic interpretability
# ============================================================
# extracts intermediate layer representations from fine-tuned wav2vec2
# for probing classifiers and activation patching experiments.
#
# requirements:
# - trained wav2vec2 model checkpoint (from phase 3)
# - italian pvs or neurovoz dataset
# - gpu recommended (cuda or mps)
# ============================================================

# mount google drive
from google.colab import drive
drive.mount('/content/drive')

import os
import sys

project_root = '/content/drive/MyDrive/pd-interpretability'
os.chdir(project_root)
sys.path.insert(0, project_root)

print("=" * 60)
print("phase 4: activation extraction setup")
print("=" * 60)
print(f"project root: {project_root}")
print(f"working directory: {os.getcwd()}")

# install dependencies
print("\ninstalling dependencies...")
!pip install -q -r requirements-colab.txt
print("dependencies installed ✓")

In [None]:
# imports
import torch
import numpy as np
import pandas as pd
import json
import gc
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

from transformers import (
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Config
)

from src.data.datasets import ItalianPVSDataset, NeuroVozDataset
from src.interpretability.extraction import (
    Wav2Vec2ActivationExtractor,
    load_activations_memmap,
    AttentionExtractor
)

# report the runtime environment so the run log shows which hardware was used
rule_line = "=" * 60
cuda_ready = torch.cuda.is_available()

print(rule_line)
print("imports complete")
print(rule_line)
print(f"pytorch: {torch.__version__}")
print(f"cuda available: {cuda_ready}")
if cuda_ready:
    # device 0 is the only gpu queried here
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"cuda device: {torch.cuda.get_device_name(0)}")
    print(f"vram: {gpu_props.total_memory / 1e9:.1f} gb")
print(rule_line)

## configuration

In [None]:
# ============================================================
# configuration
# ============================================================
# corpus to extract from: 'italian_pvs' or 'neurovoz'
dataset_name = 'italian_pvs'  # set to 'neurovoz' to use the neurovoz dataset

# run settings read by the cells below; the whole dict is also written
# into the activations metadata file
config = dict(
    dataset=dataset_name,
    max_duration=10.0,
    target_sr=16000,
    pooling='mean',
    batch_size=8,
    extract_attention=True,
    n_attention_samples=100,  # attention is memory-intensive, so only a subset is used
)

# ============================================================
# find model checkpoint
# ============================================================
# search order:
#   1. results/final_model
#   2. results/checkpoints/wav2vec2_loso_*/final_model  (needs model.safetensors)
#   3. results/checkpoints/wav2vec2_loso_*              (needs model.safetensors)
project_path = Path(project_root)
results_dir = project_path / 'results'

primary_checkpoint = results_dir / 'final_model'
checkpoints_dir = results_dir / 'checkpoints'

# note: the primary location is accepted on existence alone; its contents are
# only checked when the model is actually loaded
checkpoint_path = primary_checkpoint if primary_checkpoint.exists() else None

if checkpoint_path is not None:
    print(f"found primary checkpoint: {checkpoint_path}")
else:
    print(f"primary checkpoint not found at: {primary_checkpoint}")

    if checkpoints_dir.exists():
        loso_checkpoints = sorted(checkpoints_dir.glob('wav2vec2_loso_*'))

        # walk the runs in reverse name order (newest first, presumably because
        # run names carry a timestamp - TODO confirm); within a run the
        # final_model/ subdirectory is preferred over the run directory itself
        for run_dir in loso_checkpoints[::-1]:
            for candidate in (run_dir / 'final_model', run_dir):
                if (candidate / 'model.safetensors').exists():
                    checkpoint_path = candidate
                    break
            if checkpoint_path is not None:
                print(f"found checkpoint: {checkpoint_path}")
                break

if checkpoint_path is None:
    raise FileNotFoundError(
        "no trained model checkpoint found!\n"
        "please ensure phase 3 training is complete and model is saved to:\n"
        f"  - {primary_checkpoint}\n"
        f"  - or {checkpoints_dir}/wav2vec2_loso_*/final_model/"
    )

# output location: <project>/data/activations/activations_<dataset>_<timestamp>.dat
output_dir = project_path / 'data' / 'activations'
output_dir.mkdir(parents=True, exist_ok=True)

# the timestamp keeps repeated runs from overwriting each other's files
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
activations_filename = "activations_{}_{}.dat".format(config['dataset'], timestamp)
activations_path = output_dir / activations_filename

# echo the effective settings
config_rows = [
    ("dataset:         ", config['dataset']),
    ("checkpoint:      ", checkpoint_path),
    ("output:          ", activations_path),
    ("pooling:         ", config['pooling']),
    ("batch size:      ", config['batch_size']),
    ("extract attention: ", config['extract_attention']),
]

print("=" * 60)
print("configuration")
print("=" * 60)
for row_label, row_value in config_rows:
    print(f"{row_label}{row_value}")
print("=" * 60)

## load model and dataset

In [None]:
# ============================================================
# load fine-tuned model
# ============================================================
print("=" * 60)
print("step 1: loading fine-tuned wav2vec2 model")
print("=" * 60)

# pick the best available device: cuda, then apple mps, then cpu.
# later cells branch on device == 'mps', so mps has to be detectable here.
# the getattr guard keeps this working on torch builds without an mps backend.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = 'cuda'
elif _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
else:
    device = 'cpu'

# load model config and weights; only the checkpoint reads are wrapped so the
# error text below is not shown for unrelated failures
try:
    model_config = Wav2Vec2Config.from_pretrained(checkpoint_path)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(checkpoint_path)
except Exception as e:
    raise RuntimeError(
        f"failed to load model from {checkpoint_path}\n"
        f"error: {str(e)}\n\n"
        f"possible causes:\n"
        f"  - checkpoint files are corrupted or missing\n"
        f"  - checkpoint was saved with incompatible transformers version\n"
        f"  - checkpoint directory is incomplete (missing config.json or model.safetensors)\n\n"
        f"please verify the checkpoint exists and is complete."
    ) from e

# moving to the device happens outside the try block: a device problem
# (e.g. out of memory) should surface as itself, not as a bad checkpoint
model = model.to(device)
model.eval()  # inference mode: disables dropout

# load feature extractor from the base model (not saved with the checkpoint)
# NOTE(review): `feature_extractor` is not referenced by any later cell visible
# in this notebook - confirm whether this download is still needed
try:
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        'facebook/wav2vec2-base-960h'
    )
except Exception as e:
    raise RuntimeError(
        f"failed to load wav2vec2 feature extractor\n"
        f"error: {str(e)}\n\n"
        f"this may indicate missing transformers library or network issues."
    ) from e

print(f"\nmodel loaded successfully ✓")
print(f"  checkpoint: {checkpoint_path.name}")
print(f"  device: {device}")
print(f"  num_labels: {model_config.num_labels}")
print(f"  num_layers: {len(model.wav2vec2.encoder.layers)}")
print(f"  hidden_size: {model_config.hidden_size}")
print(f"  num_attention_heads: {model_config.num_attention_heads}")

# count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"  total parameters: {total_params:,}")
print("=" * 60)

In [None]:
# ============================================================
# load dataset
# ============================================================
print("=" * 60)
print("step 2: loading dataset")
print("=" * 60)

data_root = project_path / 'data' / 'raw'

# dataset name -> loader class; the raw-data subdirectory has the same name
dataset_classes = {
    'italian_pvs': ItalianPVSDataset,
    'neurovoz': NeuroVozDataset,
}

if config['dataset'] not in dataset_classes:
    raise ValueError(f"unknown dataset: {config['dataset']}")

dataset_root = data_root / config['dataset']
dataset = dataset_classes[config['dataset']](
    root_dir=str(dataset_root),
    task=None,  # all tasks
    max_duration=config['max_duration'],
    target_sr=config['target_sr']
)

# fail early with a clear message: an empty dataset would otherwise surface as
# a ZeroDivisionError in the class-balance line below
if len(dataset) == 0:
    raise ValueError(
        f"dataset '{config['dataset']}' is empty: no samples found under {dataset_root}"
    )

# get dataset statistics (label 0 is reported as hc, label 1 as pd)
labels = [dataset.samples[i]['label'] for i in range(len(dataset))]
n_hc = sum(1 for l in labels if l == 0)
n_pd = sum(1 for l in labels if l == 1)

print(f"\ndataset: {config['dataset']}")
print(f"  total samples: {len(dataset):,}")
print(f"  hc samples: {n_hc:,}")
print(f"  pd samples: {n_pd:,}")
print(f"  class balance: {n_pd / len(dataset) * 100:.1f}% pd")
print("=" * 60)

## model validation

verify that the loaded model produces valid predictions before extraction.
a degenerate model (predicting all one class) should not be used for interpretability analysis.

In [None]:
# ============================================================
# model validation - check for degenerate predictions
# ============================================================
# runs the classifier on up to 100 evenly spaced samples and flags the model
# as degenerate when it predicts a single class or its auc is below 0.55.
# this cell only warns: it never stops the notebook.
print("=" * 60)
print("step 3: model validation")
print("=" * 60)

from sklearn.metrics import accuracy_score, roc_auc_score

# evenly spaced indices across the whole dataset
n_validation_samples = min(100, len(dataset))
validation_indices = np.linspace(0, len(dataset)-1, n_validation_samples, dtype=int)

all_preds, all_probs, all_labels = [], [], []

print(f"validating on {n_validation_samples} samples...")

with torch.no_grad():
    for idx in tqdm(validation_indices, desc="validating model"):
        sample = dataset[idx]

        input_values = sample['input_values']
        if isinstance(input_values, np.ndarray):
            input_values = torch.from_numpy(input_values)

        # add a leading batch dimension of size 1 and move to the model's device
        input_values = input_values.unsqueeze(0).to(device)

        outputs = model(input_values)
        probs = torch.softmax(outputs.logits, dim=-1)
        pred = torch.argmax(probs, dim=-1)

        all_preds.append(pred.cpu().item())
        all_probs.append(probs[0, 1].cpu().item())  # probability of class 1 (pd)
        all_labels.append(sample['label'])

all_preds = np.array(all_preds)
all_probs = np.array(all_probs)
all_labels = np.array(all_labels)

# metrics; auc needs both classes among the labels, otherwise 0.5 is reported
val_accuracy = accuracy_score(all_labels, all_preds)
if len(np.unique(all_labels)) > 1:
    val_auc = roc_auc_score(all_labels, all_probs)
else:
    val_auc = 0.5

# how many times each class was predicted
n_pred_0 = (all_preds == 0).sum()
n_pred_1 = (all_preds == 1).sum()

print(f"\nvalidation results:")
print(f"  accuracy: {val_accuracy:.1%}")
print(f"  auc-roc:  {val_auc:.3f}")
print(f"  predictions: {n_pred_0} hc, {n_pred_1} pd")
print(f"  labels:      {(all_labels == 0).sum()} hc, {(all_labels == 1).sum()} pd")

# degenerate-model checks; the single-class check takes precedence
predicts_single_class = n_pred_0 == 0 or n_pred_1 == 0
near_random = val_auc < 0.55
# plain python bool so the flag can be written to json later
is_degenerate = bool(predicts_single_class or near_random)

if predicts_single_class:
    print("\n⚠️  WARNING: model predicts only one class!")
    print("    this is a degenerate model - interpretability results will be meaningless")
elif near_random:
    print("\n⚠️  WARNING: model has near-random performance (auc < 0.55)")
    print("    consider using a better-trained checkpoint")
else:
    print("\n✓ model produces valid predictions")

print("=" * 60)

# informational only: no confirmation is requested, extraction proceeds
if is_degenerate:
    print("\nthe loaded model appears to be degenerate or poorly trained.")
    print("activation extraction will proceed, but results may not be meaningful.")
    print("consider waiting for full loso training to complete.")

# release the validation arrays and cached device memory before extraction
del all_preds, all_probs, all_labels
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("\n✓ gpu memory cleared after validation")
elif device == 'mps':
    torch.mps.empty_cache()
    print("\n✓ mps memory cleared after validation")

In [None]:
# ============================================================
# create activation extractor
# ============================================================
# wraps the loaded model; the extractor is reused by the extraction cells below
rule_line = "=" * 60

print(rule_line)
print("step 4: initializing activation extractor")
print(rule_line)

extractor = Wav2Vec2ActivationExtractor(model=model, device=device)

print("\nextractor initialized ✓")
print(f"  num_layers: {extractor.num_layers}")
print(f"  hidden_size: {extractor.hidden_size}")
print(f"  pooling: {config['pooling']}")
print(rule_line)

In [None]:
# ============================================================
# test extraction on single sample
# ============================================================
print("testing extraction on single sample...")

sample = dataset[0]
input_values = sample['input_values']

if isinstance(input_values, np.ndarray):
    input_values = torch.from_numpy(input_values)

test_activations = extractor.extract(
    input_values,
    pooling=config['pooling']
)

print("\ntest extraction results:")
for name, act in sorted(test_activations.items()):
    if name.startswith('layer_'):
        print(f"  {name}: shape {act.shape}, mean {act.mean():.4f}, std {act.std():.4f}")

print(f"\ncnn_features: shape {test_activations.get('cnn_features', np.array([])).shape}")
print("=" * 60)

In [None]:
# ============================================================
# prepare samples for extraction
# ============================================================
# materialises every sample in memory and records, in the same order, the
# per-sample metadata that is later saved next to the activations
print("=" * 60)
print("step 5: preparing samples for extraction")
print("=" * 60)

input_values_list, sample_metadata = [], []

print(f"\nloading {len(dataset)} samples...")

for i in tqdm(range(len(dataset)), desc="loading samples"):
    sample = dataset[i]

    input_values = sample['input_values']
    if isinstance(input_values, np.ndarray):
        input_values = torch.from_numpy(input_values)
    input_values_list.append(input_values)

    sample_metadata.append(dict(
        idx=i,
        label=sample['label'],
        subject_id=sample['subject_id'],
        path=str(sample['path']),
        task=sample.get('task', 'unknown'),  # 'unknown' when the sample has no task
    ))

print(f"\n✓ loaded {len(input_values_list)} samples")
for class_tag, class_id in (('hc', 0), ('pd', 1)):
    class_count = len([m for m in sample_metadata if m['label'] == class_id])
    print(f"  {class_tag}: {class_count}")
print("=" * 60)

In [None]:
# ============================================================
# extract activations to memmap
# ============================================================
# runs the full extraction and reports wall-clock time and throughput
import time

print("=" * 60)
print("step 6: extracting activations")
print("=" * 60)
print(f"\nextracting activations for {len(input_values_list)} samples...")
print(f"output path: {activations_path}")
print("this may take 10-30 minutes depending on dataset size and gpu.")

start_time = time.time()

activations_memmap = extractor.extract_to_memmap(
    input_values_list=input_values_list,
    output_path=str(activations_path),
    pooling=config['pooling'],
    batch_size=config['batch_size'],
)

extraction_time = time.time() - start_time

elapsed_minutes = extraction_time / 60
size_mb = activations_memmap.nbytes / 1e6
samples_per_sec = len(input_values_list) / extraction_time

print("\n✓ extraction complete!")
print(f"  time elapsed: {elapsed_minutes:.1f} minutes")
print(f"  shape: {activations_memmap.shape}")
print(f"  size: {size_mb:.2f} MB")
print(f"  rate: {samples_per_sec:.1f} samples/sec")
print("=" * 60)

## save enhanced metadata

save sample metadata alongside activations for probing experiments.

In [None]:
# ============================================================
# save enhanced metadata
# ============================================================
# reads the metadata json that sits next to the activations file (or builds a
# minimal one), adds per-sample info, run config and validation results, and
# writes it back to the same path
print("=" * 60)
print("step 7: saving enhanced metadata")
print("=" * 60)

metadata_path = str(activations_path).replace('.dat', '_metadata.json')

try:
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
except FileNotFoundError:
    print(f"⚠️  warning: metadata file not found at {metadata_path}")
    print("    creating new metadata from scratch...")

    # minimal stand-in; presumably mirrors what extract_to_memmap writes - TODO confirm
    metadata = dict(
        shape=list(activations_memmap.shape),
        n_samples=len(input_values_list),
        n_layers=extractor.num_layers,
        hidden_size=extractor.hidden_size,
        pooling=config['pooling'],
        dtype='float32',
    )
except json.JSONDecodeError as e:
    raise RuntimeError(
        f"metadata file exists but contains invalid json: {metadata_path}\n"
        f"error: {str(e)}\n\n"
        f"the file may be corrupted. consider deleting it and re-running extraction."
    ) from e
else:
    print(f"✓ loaded existing metadata from {metadata_path}")

# per-sample info: the full records plus flat lists for convenient indexing
metadata['samples'] = sample_metadata
for list_key, record_field in (('labels', 'label'),
                               ('subject_ids', 'subject_id'),
                               ('tasks', 'task')):
    metadata[list_key] = [record[record_field] for record in sample_metadata]

# run configuration
metadata.update({
    'config': config,
    'timestamp': timestamp,
    'model_checkpoint': str(checkpoint_path),
    'dataset_name': config['dataset'],
})

# validation results; the lookups fall back to defaults when the validation
# cell was not run in this session
_val_accuracy = globals().get('val_accuracy')
_val_auc = globals().get('val_auc')
_n_val_samples = globals().get('n_validation_samples', 0)
_is_degenerate = globals().get('is_degenerate')

metadata['validation'] = {
    'accuracy': None if _val_accuracy is None else float(_val_accuracy),
    'auc': None if _val_auc is None else float(_val_auc),
    'n_samples': int(_n_val_samples),
    'is_degenerate': _is_degenerate,
}

# write the enriched metadata back
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"✓ metadata saved to {metadata_path}")
for summary_key in ('n_samples', 'n_layers', 'hidden_size'):
    print(f"  {summary_key}: {metadata[summary_key]}")
print("=" * 60)

## extract attention weights (optional)

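compute per-layer attention summary statistics (entropy, max, mean) for later visualization and analysis.
attention extraction is memory-intensive, so only a subset of samples is processed and the full attention matrices are not stored.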

In [None]:
# ============================================================
# extract attention weights
# ============================================================
# stores per-layer summary statistics of the attention maps for a subset of
# samples; the full attention matrices are too large to keep
if config['extract_attention']:
    print("=" * 60)
    print("step 8: extracting attention weights")
    print("=" * 60)

    attention_extractor = AttentionExtractor(
        model=model,
        device=device
    )

    n_attention_samples = max(0, min(config['n_attention_samples'], len(dataset)))

    # spread the subset evenly over the dataset (same scheme as the validation
    # cell) instead of taking the first n samples: if the dataset is ordered by
    # class, the first n samples could all come from a single class.
    # indices are cast to plain int so the saved 'idx' is json-serialisable.
    if n_attention_samples > 0:
        attention_indices = [
            int(i) for i in np.linspace(0, len(dataset) - 1, n_attention_samples, dtype=int)
        ]
    else:
        attention_indices = []

    print(f"\nextracting attention for {n_attention_samples} samples...")

    attention_data = []
    failed_attention = 0

    for i in tqdm(attention_indices, desc="extracting attention"):
        sample = dataset[i]

        try:
            input_values = sample['input_values']
            if isinstance(input_values, np.ndarray):
                input_values = torch.from_numpy(input_values)

            attentions = attention_extractor.extract_attention(
                input_values,
                layer_idx=None  # all layers
            )

            # compute summary statistics (full matrices too large to store)
            attention_summary = {
                'idx': i,
                'label': sample['label'],
                'subject_id': sample['subject_id'],
                'layer_entropy': [],
                'layer_max_attention': [],
                'layer_mean_attention': []
            }

            for layer_att in attentions:
                # average over the first axis
                # (assumed [num_heads, seq_len, seq_len] -> [seq_len, seq_len] - TODO confirm)
                avg_att = layer_att.mean(axis=0)

                # entropy of the averaged map after renormalising the whole
                # matrix to sum to 1, i.e. treated as one distribution over all
                # entries; 1e-9 guards against division by zero and log(0)
                att_flat = avg_att.flatten()
                att_probs = att_flat / (att_flat.sum() + 1e-9)
                entropy = -np.sum(att_probs * np.log(att_probs + 1e-9))

                attention_summary['layer_entropy'].append(float(entropy))
                attention_summary['layer_max_attention'].append(float(avg_att.max()))
                attention_summary['layer_mean_attention'].append(float(avg_att.mean()))

            attention_data.append(attention_summary)

        except Exception as e:
            # best effort: count the failure and keep going; only the first
            # three failures are printed to avoid flooding the output
            failed_attention += 1
            if failed_attention <= 3:
                print(f"  warning: failed on sample {i}: {str(e)[:50]}")

    # save attention summaries
    attention_path = output_dir / f"attention_summary_{config['dataset']}_{timestamp}.json"

    with open(attention_path, 'w') as f:
        json.dump({
            'data': attention_data,
            'n_samples': len(attention_data),
            'n_failed': failed_attention,
            'n_layers': extractor.num_layers,
            'timestamp': timestamp,
            'dataset': config['dataset']
        }, f, indent=2)

    print(f"\n✓ attention summaries saved to {attention_path}")
    print(f"  extracted: {len(attention_data)} samples")
    print(f"  failed: {failed_attention} samples")
    print("=" * 60)
else:
    # defined either way so that later cells can test it
    attention_path = None
    print("skipping attention extraction (disabled in config)")

In [None]:
# ============================================================
# verification: reload and check activations
# ============================================================
# reads the saved activations back from disk and scans them for nan / inf
print("=" * 60)
print("step 9: verification")
print("=" * 60)

loaded_activations, loaded_metadata = load_activations_memmap(activations_path)

print("\nloaded activations:")
print(f"  shape: {loaded_activations.shape}")
for meta_key in ('n_samples', 'n_layers', 'hidden_size'):
    print(f"  {meta_key}: {loaded_metadata[meta_key]}")

# count non-finite entries
n_nan = int(np.count_nonzero(np.isnan(loaded_activations)))
n_inf = int(np.count_nonzero(np.isinf(loaded_activations)))

print("\ndata quality:")
print(f"  nan values: {n_nan}")
print(f"  inf values: {n_inf}")

# warning only: the notebook continues even if invalid values are present
if n_nan == 0 and n_inf == 0:
    print("  ✓ all values valid")
else:
    print("  ⚠️ warning: invalid values detected!")

print("=" * 60)

In [None]:
# ============================================================
# activation statistics by layer
# ============================================================
# one line per layer; axis 1 of the loaded array is indexed by layer
print("activation statistics by layer:")
print("-" * 60)

for layer_idx in range(loaded_metadata['n_layers']):
    layer_acts = loaded_activations[:, layer_idx, :]

    layer_stats = (
        f"mean={layer_acts.mean():+.4f}",
        f"std={layer_acts.std():.4f}",
        f"min={layer_acts.min():+.4f}",
        f"max={layer_acts.max():+.4f}",
    )
    print(f"layer {layer_idx:2d}: " + ", ".join(layer_stats))

print("-" * 60)

## class separability analysis

quick sanity check: measure cosine similarity within and between classes.
if the model learned meaningful representations, within-class similarity should exceed between-class similarity.

In [None]:
# ============================================================
# class separability analysis
# ============================================================
# for a handful of layers, compares the mean cosine similarity within each
# class (pd-pd, hc-hc) with the mean similarity between classes (pd-hc).
# separability = average within-class similarity - between-class similarity.
print("=" * 60)
print("class separability (cosine similarity)")
print("=" * 60)

from sklearn.metrics.pairwise import cosine_similarity


def _mean_pairwise_similarity(acts):
    """mean cosine similarity over all pairs of *distinct* rows of `acts`.

    the diagonal (each row against itself, always 1.0) is excluded; including
    it would inflate within-class similarity relative to the between-class
    value, which has no self-pairs. needs at least two rows.
    """
    sim = cosine_similarity(acts)
    off_diagonal = ~np.eye(sim.shape[0], dtype=bool)
    return float(sim[off_diagonal].mean())


labels = np.array(loaded_metadata['labels'])

# only keep layer indices that exist in the stored array
n_stored_layers = loaded_activations.shape[1]
layers_to_check = [idx for idx in [0, 3, 6, 9, 11] if idx < n_stored_layers]

# subsample each class to the same size (at most 50) to bound the compute.
# a fixed seed makes the table reproducible, and drawing the subset once means
# every layer is scored on the same samples.
separability_rng = np.random.default_rng(42)
pd_indices = np.flatnonzero(labels == 1)
hc_indices = np.flatnonzero(labels == 0)
n_sample = min(50, len(pd_indices), len(hc_indices))

print(f"\nlayer   pd-pd   hc-hc   pd-hc   separability")
print("-" * 50)

separability_scores = []

if n_sample < 2:
    # pairwise similarity within a class needs at least two samples
    print("not enough samples per class (need at least 2 of each) - skipping")
else:
    pd_chosen = np.sort(separability_rng.choice(pd_indices, n_sample, replace=False))
    hc_chosen = np.sort(separability_rng.choice(hc_indices, n_sample, replace=False))

    for layer_idx in layers_to_check:
        layer_acts = loaded_activations[:, layer_idx, :]

        pd_sample = np.asarray(layer_acts[pd_chosen])
        hc_sample = np.asarray(layer_acts[hc_chosen])

        pd_sim = _mean_pairwise_similarity(pd_sample)
        hc_sim = _mean_pairwise_similarity(hc_sample)
        cross_sim = float(cosine_similarity(pd_sample, hc_sample).mean())

        # separability: how much more similar within-class than between-class
        avg_within = (pd_sim + hc_sim) / 2
        separability = avg_within - cross_sim
        separability_scores.append((layer_idx, separability))

        print(f"layer {layer_idx:2d}  {pd_sim:.3f}   {hc_sim:.3f}   {cross_sim:.3f}   {separability:+.4f}")

# find most separable layer (None when nothing could be scored)
print("-" * 50)
if separability_scores:
    best_layer = max(separability_scores, key=lambda x: x[1])
    print(f"best separability: layer {best_layer[0]} (score: {best_layer[1]:+.4f})")
else:
    best_layer = None
    print("best separability: n/a (no layers scored)")
print("=" * 60)

In [None]:
# ============================================================
# extraction summary
# ============================================================
# final recap of what was run and which files were produced
print("\n" + "=" * 60)
print("PHASE 4 EXTRACTION COMPLETE")
print("=" * 60)
print(f"\nmodel checkpoint: {checkpoint_path}")
print(f"dataset: {config['dataset']} ({len(dataset)} samples)")
print(f"pooling: {config['pooling']}")

# report validation results if the validation cell was run in this session
if globals().get('val_accuracy') is not None:
    print(f"\nmodel validation:")
    print(f"  accuracy: {val_accuracy:.1%}")
    print(f"  auc-roc:  {val_auc:.3f}")
    if globals().get('is_degenerate'):
        print(f"  ⚠️ WARNING: model appears degenerate!")
else:
    print(f"\nmodel validation: skipped")

print(f"\nactivations shape: {loaded_activations.shape}")
print(f"activations size: {loaded_activations.nbytes / 1e6:.2f} MB")

# list the generated files; the attention summary exists only when attention
# extraction ran, and the last entry closes the tree with '└──'
generated_files = [
    activations_path.name,
    activations_path.name.replace('.dat', '_metadata.json'),
]
if globals().get('attention_path'):
    generated_files.append(attention_path.name)

print(f"\ngenerated files:")
for position, generated_name in enumerate(generated_files):
    branch = '└──' if position == len(generated_files) - 1 else '├──'
    print(f"  {branch} {generated_name}")

# step 2 is labelled to match the notebook it points at (activation patching)
print(f"\nnext steps:")
print(f"  1. download activations to local machine (if running on colab)")
print(f"  2. run activation patching experiments (notebooks/colab/05_activation_patching.ipynb)")
print(f"  3. run probing classifiers (notebooks/colab/07_probing_experiments.ipynb)")
print("=" * 60)

# cleanup gpu memory; the message is printed only when a cuda cache was cleared
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("\ngpu memory cleared ✓")