In [None]:
# mount drive and setup
import os
import sys

from google.colab import drive

# Mount Google Drive so the project checkout under MyDrive is reachable.
drive.mount('/content/drive')

PROJECT_ROOT = '/content/drive/MyDrive/pd-interpretability'

# Work from the project root so relative paths (requirements file, results/,
# data/) resolve against the repository.
os.chdir(PROJECT_ROOT)

# Make the project's `src` package importable. Guarded so that re-running this
# cell does not keep prepending duplicate entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

!pip install -q -r requirements-colab.txt

In [None]:
# imports
import torch
import numpy as np
import json
from pathlib import Path
from datetime import datetime
from tqdm import tqdm

from src.data.datasets import ItalianPVSDataset, MDVRKCLDataset
from src.models.classifier import Wav2Vec2PDClassifier
from src.interpretability.extraction import (
    Wav2Vec2ActivationExtractor,
    extract_activations_from_dataset,
    load_activations_memmap,
    AttentionExtractor
)

# report the runtime environment up front
torch_version = torch.__version__
cuda_available = torch.cuda.is_available()
print(f'pytorch: {torch_version}')
print(f'cuda: {cuda_available}')

## configuration

In [None]:
# configuration: every knob used by the cells below lives in this one dict
config = dict(
    # checkpoint location, joined onto PROJECT_ROOT when it is resolved
    model_checkpoint='results/checkpoints/wav2vec2_pd_italian_v1/final_model',
    # dataset tag; also embedded in the output file names
    dataset='italian_pvs',
    # forwarded to the dataset constructor
    max_duration=10.0,
    target_sr=16000,
    # forwarded to the activation extractor
    pooling='mean',
    batch_size=8,
    # toggles the attention-summary cell
    extract_attention=True,
)

# Resolve the checkpoint to load. The configured path is used when it exists;
# otherwise fall back to the most recently modified `wav2vec2_pd_*` run
# directory under results/checkpoints.
checkpoint_path = Path(PROJECT_ROOT) / config['model_checkpoint']
if not checkpoint_path.exists():
    checkpoints_dir = Path(PROJECT_ROOT) / 'results' / 'checkpoints'

    candidates = []
    if checkpoints_dir.exists():
        # Only directories can be run folders. Rank by modification time
        # rather than by name: a lexicographic sort would place `..._v10`
        # before `..._v2`. The name only breaks ties.
        candidates = sorted(
            (p for p in checkpoints_dir.glob('wav2vec2_pd_*') if p.is_dir()),
            key=lambda p: (p.stat().st_mtime, p.name),
        )

    if candidates:
        latest = candidates[-1]
        # Prefer the `final_model` subdirectory; otherwise use the run
        # directory itself.
        checkpoint_path = latest / 'final_model'
        if not checkpoint_path.exists():
            checkpoint_path = latest
        # Keep config in sync so later cells (and the saved metadata) record
        # the checkpoint that is actually used.
        config['model_checkpoint'] = str(checkpoint_path.relative_to(PROJECT_ROOT))
        print(f'using checkpoint: {checkpoint_path}')
    else:
        # Nothing to fall back to: say so here instead of staying silent.
        print(f'warning: no checkpoint found at {checkpoint_path} '
              f'or under {checkpoints_dir}')

# output paths: everything this notebook writes goes under data/activations
output_dir = Path(PROJECT_ROOT).joinpath('data', 'activations')
output_dir.mkdir(parents=True, exist_ok=True)

# one timestamp per run, embedded in the output file names
run_started = datetime.now()
timestamp = run_started.strftime('%Y%m%d_%H%M%S')
activations_filename = f"activations_{config['dataset']}_{timestamp}.dat"
activations_path = output_dir / activations_filename

## load model and dataset

In [None]:
# load fine-tuned model
model_path = Path(PROJECT_ROOT) / config['model_checkpoint']

# Fail early with a readable message rather than an opaque loader error.
if not model_path.exists():
    raise FileNotFoundError(f'model checkpoint not found: {model_path}')

# Use the GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on a CPU-only runtime.
model_device = 'cuda' if torch.cuda.is_available() else 'cpu'
if model_device == 'cpu':
    print('warning: cuda not available, loading model on cpu')

classifier = Wav2Vec2PDClassifier.load(
    checkpoint_dir=model_path,
    device=model_device
)

print(f'model loaded from {model_path}')
print(f'number of layers: {len(classifier.model.wav2vec2.encoder.layers)}')
print(f'hidden size: {classifier.model.config.hidden_size}')

In [None]:
# load dataset from <project>/data/raw/italian_pvs
data_root = Path(PROJECT_ROOT).joinpath('data', 'raw')

# constructor arguments gathered in one place; duration and sample rate come
# from the config cell
dataset_kwargs = dict(
    root_dir=str(data_root / 'italian_pvs'),
    task=None,
    max_duration=config['max_duration'],
    target_sr=config['target_sr'],
)
dataset = ItalianPVSDataset(**dataset_kwargs)

n_dataset_samples = len(dataset)
print(f'dataset loaded: {n_dataset_samples} samples')

## extract activations

In [None]:
# create extractor
# Follow the device the model weights are actually on rather than hard-coding
# 'cuda', so the extractor cannot end up on a different device than the model.
# NOTE(review): assumes classifier.model is a torch.nn.Module (exposes
# .parameters()) - confirm against Wav2Vec2PDClassifier.
extractor_device = next(classifier.model.parameters()).device.type

extractor = Wav2Vec2ActivationExtractor(
    model=classifier.model,
    device=extractor_device
)

print(f'extractor initialized')
print(f'  layers: {extractor.num_layers}')
print(f'  hidden size: {extractor.hidden_size}')

In [None]:
# smoke-test the extractor on the first sample before the full run
first_sample = dataset[0]
test_activations = extractor.extract(first_sample['input_values'], pooling=config['pooling'])

print('test extraction:')
# only entries whose key starts with 'layer_' are reported
layer_entries = [
    (key, value)
    for key, value in test_activations.items()
    if key.startswith('layer_')
]
for layer_name, layer_act in layer_entries:
    print(f'  {layer_name}: shape {layer_act.shape}')

In [None]:
# gather every sample's input values plus bookkeeping fields ahead of extraction
n_total = len(dataset)
print(f'extracting activations for {n_total} samples...')
print(f'output path: {activations_path}')

# inputs and their metadata are kept in two parallel, index-aligned lists
input_values_list, sample_metadata = [], []
metadata_keys = ('label', 'subject_id', 'path')

for idx in tqdm(range(n_total), desc='loading samples'):
    item = dataset[idx]
    input_values_list.append(item['input_values'])

    # record the position in the dataset alongside the per-sample fields
    record = {'idx': idx}
    record.update((key, item[key]) for key in metadata_keys)
    sample_metadata.append(record)

print(f'loaded {len(input_values_list)} samples')

In [None]:
# extract to memmap: hand the collected inputs to the extractor, passing the
# .dat path as output_path and the pooling / batch size from config
memmap_kwargs = {
    'input_values_list': input_values_list,
    'output_path': str(activations_path),
    'pooling': config['pooling'],
    'batch_size': config['batch_size'],
}
activations_memmap = extractor.extract_to_memmap(**memmap_kwargs)

# size reported in megabytes (bytes / 1e6)
size_mb = activations_memmap.nbytes / 1e6
print(f'\nactivations extracted')
print(f'  shape: {activations_memmap.shape}')
print(f'  size: {size_mb:.2f} MB')

In [None]:
# save enhanced metadata
# The metadata file sits next to the .dat file as `<stem>_metadata.json`.
# Build the name from the final path component only: str.replace('.dat', ...)
# would also rewrite a '.dat' occurring anywhere else in the path.
# NOTE(review): presumably extract_to_memmap writes this file - it is opened
# for reading below; confirm the naming convention against the extractor.
metadata_path = str(activations_path.with_name(f'{activations_path.stem}_metadata.json'))


def _json_default(obj):
    """Convert values that the json module cannot encode natively.

    Objects exposing ``tolist()`` (e.g. numpy scalars/arrays, torch tensors)
    become plain Python values and ``Path`` objects become strings. Anything
    else raises ``TypeError``, as json itself would.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


with open(metadata_path, 'r') as f:
    metadata = json.load(f)

# add sample info
metadata['samples'] = sample_metadata
metadata['labels'] = [s['label'] for s in sample_metadata]
metadata['subject_ids'] = [s['subject_id'] for s in sample_metadata]
metadata['config'] = config
metadata['timestamp'] = timestamp
metadata['model_checkpoint'] = config['model_checkpoint']

# Serialize fully before opening the file for writing: if encoding fails, the
# existing metadata file is left untouched instead of being truncated.
metadata_json = json.dumps(metadata, indent=2, default=_json_default)
with open(metadata_path, 'w') as f:
    f.write(metadata_json)

print(f'metadata saved to {metadata_path}')

## extract attention weights (optional)

In [None]:
# Summarize attention for a subset of samples: per layer, the mean row entropy
# and the maximum weight of the head-averaged attention.
if config['extract_attention']:
    print('extracting attention weights...')

    # Follow the device the model weights are actually on rather than
    # hard-coding 'cuda'.
    # NOTE(review): assumes classifier.model is a torch.nn.Module - confirm.
    attention_device = next(classifier.model.parameters()).device.type

    attention_extractor = AttentionExtractor(
        model=classifier.model,
        device=attention_device
    )

    # extract for subset of samples (attention is memory intensive); the cap
    # can be overridden with config['n_attention_samples'] and defaults to 50
    n_attention_samples = min(config.get('n_attention_samples', 50), len(dataset))
    attention_data = []
    failed_samples = []

    for i in tqdm(range(n_attention_samples), desc='extracting attention'):
        sample = dataset[i]

        try:
            attentions = attention_extractor.extract_attention(
                sample['input_values'],
                layer_idx=None
            )

            # store summary stats (full attention matrices are too large)
            attention_summary = {
                'idx': i,
                'label': sample['label'],
                'layer_entropy': [],
                'layer_max_attention': []
            }

            for layer_att in attentions:
                # average over heads
                # NOTE(review): assumes axis 0 of each layer's attention is
                # the head axis - confirm against AttentionExtractor.
                avg_att = layer_att.mean(axis=0)

                # compute entropy (measure of attention spread); rows are
                # renormalized first, with 1e-9 guarding division and log(0)
                att_probs = avg_att / (avg_att.sum(axis=-1, keepdims=True) + 1e-9)
                entropy = -np.sum(att_probs * np.log(att_probs + 1e-9), axis=-1).mean()

                # max attention weight
                max_att = avg_att.max()

                attention_summary['layer_entropy'].append(float(entropy))
                attention_summary['layer_max_attention'].append(float(max_att))

            attention_data.append(attention_summary)

        except Exception as e:
            # best effort: one bad sample must not abort the run, but keep
            # track of it so the gap in the saved file is visible
            failed_samples.append(i)
            print(f'failed on sample {i}: {e}')

    if failed_samples:
        print(f'warning: attention extraction failed for '
              f'{len(failed_samples)}/{n_attention_samples} samples: {failed_samples}')

    # save attention summaries
    attention_path = output_dir / f"attention_summary_{config['dataset']}_{timestamp}.json"
    with open(attention_path, 'w') as f:
        json.dump(attention_data, f, indent=2)

    print(f'attention summaries saved to {attention_path}')

## verification

In [None]:
# reload what was written to disk and sanity-check it
loaded_activations, loaded_metadata = load_activations_memmap(activations_path)

print('verification:')
print(f'  loaded shape: {loaded_activations.shape}')
# echo the dimensions recorded in the metadata
for field in ('n_samples', 'n_layers', 'hidden_size'):
    print(f'  {field}: {loaded_metadata[field]}')

# count nan and inf entries in the stored activations
n_nan = np.isnan(loaded_activations).sum()
n_inf = np.isinf(loaded_activations).sum()
print(f'  nan values: {n_nan}')
print(f'  inf values: {n_inf}')

has_invalid = not (n_nan <= 0 and n_inf <= 0)
if has_invalid:
    print('  warning: invalid values detected!')

In [None]:
# per-layer summary statistics of the stored activations
print('activation statistics by layer:')
print('-' * 50)

for layer_idx in range(loaded_metadata['n_layers']):
    # axis 1 selects the layer
    layer_acts = loaded_activations[:, layer_idx, :]

    act_mean = layer_acts.mean()
    act_std = layer_acts.std()
    act_min = layer_acts.min()
    act_max = layer_acts.max()

    print(f'layer {layer_idx:2d}: mean={act_mean:.4f}, std={act_std:.4f}, '
          f'min={act_min:.4f}, max={act_max:.4f}')

In [None]:
# check class separability (quick sanity check)
from sklearn.metrics.pairwise import cosine_similarity

labels = np.array(loaded_metadata['labels'])

print('\nclass separability (cosine similarity within vs between classes):')


def _mean_offdiag_similarity(acts):
    """Mean pairwise cosine similarity between *distinct* rows of ``acts``.

    The diagonal of the similarity matrix (each sample compared with itself)
    is excluded. Including those self-pairs inflates the within-class value
    relative to the between-class value, which contains no self-pairs.
    Returns nan when there are fewer than two rows.
    """
    n_rows = acts.shape[0]
    if n_rows < 2:
        return float('nan')
    sim = cosine_similarity(acts)
    return (sim.sum() - np.trace(sim)) / (n_rows * (n_rows - 1))


# label 1 -> pd group, label 0 -> hc group
pd_mask = labels == 1
hc_mask = labels == 0

# layers to inspect; indices beyond the stored layer axis are skipped rather
# than raising IndexError
n_layers_available = loaded_activations.shape[1]
layers_to_check = [idx for idx in (0, 5, 11) if idx < n_layers_available]

if not pd_mask.any() or not hc_mask.any():
    # cosine_similarity cannot be computed against an empty group
    print('  skipped: need at least one sample of each class')
else:
    for layer_idx in layers_to_check:
        layer_acts = loaded_activations[:, layer_idx, :]

        pd_acts = layer_acts[pd_mask]
        hc_acts = layer_acts[hc_mask]

        # within-class similarity (self-pairs excluded)
        pd_sim = _mean_offdiag_similarity(pd_acts)
        hc_sim = _mean_offdiag_similarity(hc_acts)

        # between-class similarity
        cross_sim = cosine_similarity(pd_acts, hc_acts).mean()

        print(f'layer {layer_idx:2d}: '
              f'pd-pd={pd_sim:.3f}, '
              f'hc-hc={hc_sim:.3f}, '
              f'pd-hc={cross_sim:.3f}')

## summary

In [None]:
# final report: what was run, what was produced, and where it was written
banner = '=' * 60

print(banner)
print('EXTRACTION SUMMARY')
print(banner)
print(f'model: {config["model_checkpoint"]}')
print(f'dataset: {config["dataset"]} ({len(dataset)} samples)')
print(f'pooling: {config["pooling"]}')
print()
print(f'activations shape: {loaded_activations.shape}')
print(f'activations size: {loaded_activations.nbytes / 1e6:.2f} MB')
print()

# the attention summary only exists when attention extraction was enabled
output_files = [activations_path, metadata_path]
if config['extract_attention']:
    output_files.append(attention_path)

print('output files:')
for output_file in output_files:
    print(f'  {output_file}')
print(banner)
print()

for line in (
    'next steps:',
    '1. download activations to local machine',
    '2. run probing experiments (notebooks/local/03_probing_analysis.ipynb)',
    '3. run patching experiments',
):
    print(line)

In [None]:
# cleanup: run the garbage collector first, then empty torch's CUDA cache
import gc

for release in (gc.collect, torch.cuda.empty_cache):
    release()
print('memory cleared')