In [None]:
# Colab bootstrap: attach Google Drive, then make the project checkout importable.
import sys

from google.colab import drive

# mount google drive
drive.mount('/content/drive')

# Insert at position 0 so this folder is searched before every other
# sys.path entry when the project's `src` package is imported later on.
sys.path.insert(0, '/content/drive/MyDrive/pd-interpretability')

In [None]:
# install dependencies
!pip install -q transformers datasets librosa praat-parselmouth scipy scikit-learn tqdm

In [None]:
# numerical / deep-learning stack and stdlib helpers used throughout the notebook
import numpy as np
import torch
import json
from pathlib import Path
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification

# project imports
# NOTE(review): presumably resolved through the Drive project folder that the
# first cell inserts into sys.path - confirm `src/` lives directly under it.
from src.interpretability.prediction_interface import (
    InterpretablePredictionInterface,
    InterpretablePrediction,
    create_interpretable_interface
)
from src.features.clinical import ClinicalFeatureExtractor

# librosa and IPython.display are imported later, in the cells that use them.
print("imports successful!")

## 1. Configuration

In [None]:
# Single source of truth for the project location: every other path in CONFIG
# is derived from it, so relocating the project only requires editing this line.
_PROJECT_ROOT = '/content/drive/MyDrive/pd-interpretability'
_RESULTS_DIR = f'{_PROJECT_ROOT}/results'

# Notebook-wide settings. Values are plain strings; cells wrap them in Path
# where path arithmetic is needed.
CONFIG = {
    'project_path': _PROJECT_ROOT,
    # fine-tuned checkpoint read by the model-loading cell
    'model_path': f'{_RESULTS_DIR}/checkpoints/best_model.pt',
    # root folder of the precomputed probing / patching result files
    'analysis_path': _RESULTS_DIR,
    'data_path': f'{_PROJECT_ROOT}/data',
    # use the GPU when torch reports one, otherwise fall back to the CPU
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

print(f"using device: {CONFIG['device']}")

## 2. Load Model and Create Interface

In [None]:
# load wav2vec2 processor (taken from the base checkpoint in both branches below)
processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base')

# location of the fine-tuned checkpoint, as configured in CONFIG
model_path = Path(CONFIG['model_path'])

if not model_path.exists():
    # fallback to pretrained with random classifier
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        'facebook/wav2vec2-base',
        num_labels=2
    )
    print("using pretrained model (no fine-tuned checkpoint found)")
else:
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from a trusted source. The code below calls
    # .to()/.eval() on the result, i.e. it assumes the file holds a whole
    # model object rather than a state dict - confirm against the training code.
    model = torch.load(
        model_path,
        map_location=CONFIG['device'],
        weights_only=False
    )
    print(f"loaded fine-tuned model from {model_path}")

# move to the configured device and switch to inference mode
model = model.to(CONFIG['device'])
model.eval()
print(f"model ready on {CONFIG['device']}")

In [None]:
# create clinical feature extractor
# Instantiated with no arguments (the class defaults); the object is later
# passed to create_interpretable_interface as `clinical_extractor`.
clinical_extractor = ClinicalFeatureExtractor()
print("clinical feature extractor ready")

In [None]:
# load precomputed analysis results (if available)
# Both result sets are optional: each stays None when its file is missing.
analysis_path = Path(CONFIG['analysis_path'])

probing_results = None
patching_results = None

# --- probing results -------------------------------------------------------
probing_file = analysis_path / 'probing' / 'probing_results.json'
if probing_file.exists():
    probing_data = json.loads(probing_file.read_text())

    # convert to expected format: {feature: {int(layer_key): score}}
    converted_probing = {}
    for feat, layers in probing_data.items():
        # non-dict top-level entries are skipped
        if not isinstance(layers, dict):
            continue
        per_layer = {}
        for layer_key, entry in layers.items():
            # a dict entry contributes its 'mean' value when it has one,
            # otherwise the entry is kept unchanged
            if isinstance(entry, dict):
                entry = entry.get('mean', entry)
            per_layer[int(layer_key)] = entry
        converted_probing[feat] = per_layer

    probing_results = converted_probing
    print(f"loaded probing results for {len(probing_results)} features")

# --- patching results ------------------------------------------------------
patching_file = analysis_path / 'patching' / 'head_importance.json'
if patching_file.exists():
    patching_data = json.loads(patching_file.read_text())

    # convert to expected format (tuple keys): "a,b" -> (a, b)
    # presumably the two numbers are (layer, head) - verify against the file writer
    converted_patching = {}
    for head_key, importance in patching_data.items():
        head_index = tuple(int(part) for part in head_key.split(','))
        converted_patching[head_index] = importance

    patching_results = converted_patching
    print(f"loaded patching results for {len(patching_results)} heads")

In [None]:
# create the interpretable prediction interface
# Everything the factory needs is gathered first so the wiring is visible at
# a glance; probing/patching results may be None when their files were absent.
interface_kwargs = dict(
    model=model,
    processor=processor,
    clinical_extractor=clinical_extractor,
    probing_results=probing_results,
    patching_results=patching_results,
    device=CONFIG['device'],
)
interface = create_interpretable_interface(**interface_kwargs)

print("\n=== Interpretable Prediction Interface Created ===")
# NOTE(review): the two lines below read underscore-prefixed attributes of the
# interface; switch to public accessors if the class offers them.
evidence_layers = interface._evidence_layers
print(f"Evidence layers: {evidence_layers}")
leading_heads = interface._key_heads[:5]
print(f"Key attention heads: {leading_heads}...")

## 3. Load Test Data

In [None]:
import librosa

# Rate every clip is resampled to on load, and the rate used for synthetic clips.
TEST_SAMPLE_RATE = 16000

# load some test samples
data_path = Path(CONFIG['data_path']) / 'raw' / 'italian_pvs'


def _subject_dirs(parent, limit):
    """Return at most ``limit`` sub-directories of ``parent``, in sorted order.

    Sorting makes the selection reproducible across runs, and filtering
    before slicing keeps stray files from using up one of the ``limit`` slots.
    """
    return [entry for entry in sorted(parent.iterdir()) if entry.is_dir()][:limit]


def _load_subject_sample(subject_dir, pattern, label):
    """Load the first file matching ``pattern`` in ``subject_dir`` as a test sample.

    Args:
        subject_dir: Path of one subject's folder; its name becomes the subject id.
        pattern: glob pattern selecting candidate recordings.
        label: class label stored with the sample (0 = HC, 1 = PD).

    Returns:
        A dict with keys 'audio', 'sample_rate', 'label' and 'subject_id', or
        None when nothing matches or the file cannot be decoded. Decode
        failures are printed instead of being silently discarded.
    """
    candidates = sorted(p for p in subject_dir.glob(pattern) if p.is_file())
    if not candidates:
        return None

    audio_path = candidates[0]
    try:
        audio, sr = librosa.load(audio_path, sr=TEST_SAMPLE_RATE)
    except Exception as exc:  # best-effort loader: report the problem and move on
        print(f"skipping {audio_path}: {type(exc).__name__}: {exc}")
        return None

    return {
        'audio': audio,
        'sample_rate': sr,
        'label': label,
        'subject_id': subject_dir.name
    }


test_samples = []

# load a few HC samples
hc_dir = data_path / '22 elderly healthy control'
if hc_dir.exists():
    for subject_dir in _subject_dirs(hc_dir, 3):
        # these are actually audio files with .txt extension
        loaded = _load_subject_sample(subject_dir, '*.txt', label=0)
        if loaded is not None:
            test_samples.append(loaded)

# load a few PD samples (one subject per sub-group folder)
pd_dir = data_path / '28 people with parkinson\'s disease'
if pd_dir.exists():
    for subgroup in sorted(pd_dir.iterdir()):
        if not subgroup.is_dir():
            continue
        for subject_dir in _subject_dirs(subgroup, 1):
            loaded = _load_subject_sample(subject_dir, '*.*', label=1)
            if loaded is not None:
                test_samples.append(loaded)

print(f"loaded {len(test_samples)} test samples")

# fallback to synthetic if no real data
if len(test_samples) == 0:
    print("using synthetic test samples")
    # Dedicated, seeded generator: the noise clips are identical on every run
    # and numpy's global random state is left untouched.
    synth_rng = np.random.default_rng(0)
    test_samples = [
        {
            # 48000 samples = 3 s of white noise at 16 kHz
            'audio': synth_rng.standard_normal(48000).astype(np.float32),
            'sample_rate': TEST_SAMPLE_RATE,
            'label': i % 2,
            'subject_id': f'synth_{i}'
        }
        for i in range(6)
    ]

## 4. Generate Interpretable Predictions

In [None]:
# run predictions
# One interpretable prediction per test sample, each followed by a short
# report: outcome, top feature contributions, selected clinical features and
# the layers / heads the interface points to.
print("Generating interpretable predictions...\n")
print("=" * 80)

# clinical measurements shown in the per-sample report, in display order
REPORTED_CLINICAL_FEATURES = ('jitter_local', 'shimmer_local', 'hnr_mean', 'f0_mean')


def _class_name(is_pd):
    """Map a truthy/falsy PD flag to the class label shown in the report."""
    return 'PD' if is_pd else 'HC'


for sample in test_samples:
    # make prediction
    prediction = interface.predict(
        audio=sample['audio'],
        sample_rate=sample['sample_rate'],
        include_clinical=True
    )

    # display results (0.5 is the decision threshold on the PD probability)
    true_label = _class_name(sample['label'] == 1)
    pred_label = _class_name(prediction.pd_probability >= 0.5)
    correct = '✗' if true_label != pred_label else '✓'

    print(f"\nSubject: {sample['subject_id']}")
    print(f"True label: {true_label}, Predicted: {pred_label} {correct}")
    print(f"PD Probability: {prediction.pd_probability:.3f}")
    print(f"Confidence: {prediction.confidence:.3f}")

    print("\nTop Feature Contributions:")
    for feat, score in prediction.get_top_features(3):
        print(f"  • {feat}: {score:+.3f}")

    if prediction.clinical_features:
        print("\nKey Clinical Features:")
        for feat in REPORTED_CLINICAL_FEATURES:
            if feat not in prediction.clinical_features:
                continue
            val = prediction.clinical_features[feat]
            # NaN values are left out of the report
            if np.isnan(val):
                continue
            print(f"  • {feat}: {val:.4f}")

    print(f"\nEvidence Layers: {prediction.evidence_layers[:5]}")
    print(f"Key Heads: {prediction.key_attention_heads[:3]}")
    print("-" * 80)

## 5. Examine Full Prediction Output

In [None]:
# show full JSON output for one sample
# The first test sample is predicted again here so this cell does not depend
# on what the previous loop left in `prediction`.
if len(test_samples) > 0:
    sample = test_samples[0]
    prediction = interface.predict(
        include_clinical=True,
        sample_rate=sample['sample_rate'],
        audio=sample['audio'],
    )

    # title and rule are emitted on separate lines by one print call
    print("Full Prediction Output (JSON format):", "=" * 50, sep="\n")
    print(prediction.to_json())

## 6. Generate Natural Language Explanation

In [None]:
# generate human-readable explanation
# Plain-text explanation of the interface's prediction for the first sample.
if len(test_samples) > 0:
    sample = test_samples[0]

    explanation = interface.explain_prediction(
        format='text',
        sample_rate=sample['sample_rate'],
        audio=sample['audio'],
    )

    # title and rule are emitted on separate lines by one print call
    print("Natural Language Explanation:", "=" * 50, sep="\n")
    print(explanation)

In [None]:
# markdown format explanation
# Same explanation as the text variant, requested as markdown and rendered
# through IPython's rich display instead of being printed.
from IPython.display import display, Markdown

if len(test_samples) > 0:
    first_sample = test_samples[0]
    explanation_md = interface.explain_prediction(
        format='markdown',
        sample_rate=first_sample['sample_rate'],
        audio=first_sample['audio'],
    )

    display(Markdown(explanation_md))

## 7. Batch Predictions

In [None]:
# batch prediction
audio_list = [s['audio'] for s in test_samples]

# batch_predict receives one sample_rate for the whole list, so take it from
# the samples themselves instead of repeating a hard-coded 16000, and refuse
# to continue if the loader ever produced clips at different rates.
sample_rates = {s['sample_rate'] for s in test_samples}
if len(sample_rates) > 1:
    raise ValueError(
        f"test samples have mixed sample rates {sorted(sample_rates)}; "
        "batch_predict is called with a single sample_rate"
    )
# 16000 is only used when there are no samples at all
batch_sample_rate = sample_rates.pop() if sample_rates else 16000

predictions = interface.batch_predict(
    audio_list=audio_list,
    sample_rate=batch_sample_rate,
    include_clinical=True,
    show_progress=True
)

print(f"\nGenerated {len(predictions)} predictions")

# summary statistics (`probs` and `confs` are reused by the save cell)
probs = [p.pd_probability for p in predictions]
confs = [p.confidence for p in predictions]

print("\nSummary:")
print(f"  Mean PD probability: {np.mean(probs):.3f}")
print(f"  Mean confidence: {np.mean(confs):.3f}")
# 0.5 is the decision threshold on the PD probability
print(f"  Predicted as PD: {sum(1 for p in probs if p >= 0.5)}")
print(f"  Predicted as HC: {sum(1 for p in probs if p < 0.5)}")

## 8. Save Results

In [None]:
# save all predictions
# Writes one JSON file per sample plus an aggregate summary. Relies on
# `predictions`, `probs` and `confs` produced by the batch-prediction cell.
output_dir = Path(CONFIG['project_path']) / 'results' / 'phase5_synthesis'
output_dir.mkdir(parents=True, exist_ok=True)

# save individual predictions, tagged with where they came from
for sample, prediction in zip(test_samples, predictions):
    subject_id = sample['subject_id']
    prediction.metadata['subject_id'] = subject_id
    prediction.metadata['true_label'] = sample['label']

    prediction_file = output_dir / f"{subject_id}_prediction.json"
    interface.save_prediction(prediction, prediction_file)

# save summary
# Both classes are counted explicitly so every count is a plain Python int,
# which json can serialise directly.
predicted_pd = [p for p in probs if p >= 0.5]
predicted_hc = [p for p in probs if p < 0.5]

summary = dict(
    n_samples=len(predictions),
    mean_pd_probability=float(np.mean(probs)),
    mean_confidence=float(np.mean(confs)),
    n_predicted_pd=len(predicted_pd),
    n_predicted_hc=len(predicted_hc),
)

summary_file = output_dir / 'prediction_summary.json'
summary_file.write_text(json.dumps(summary, indent=2))

print(f"Results saved to {output_dir}")

## Summary

This notebook demonstrated the **Interpretable Prediction Interface**, which:

1. **Synthesizes** all mechanistic interpretability analyses into a single interface
2. **Produces predictions** with probability, confidence, and explanations
3. **Identifies** which clinical features (jitter, shimmer, HNR) drive predictions
4. **Reveals** which transformer layers encode PD-relevant information
5. **Highlights** key attention heads with causal importance

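### Output Format

An illustrative, abridged example of a serialized prediction (the values are placeholders; the real output of `prediction.to_json()` may carry additional fields, such as the `confidence` and `clinical_features` shown earlier in this notebook):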

```json
{
    "pd_probability": 0.87,
    "feature_contributions": {
        "jitter_elevated": 0.34,
        "hnr_reduced": 0.28,
        "f0_unstable": 0.21
    },
    "evidence_layers": [3, 4, 7],
    "key_attention_heads": [[3, 4], [4, 2], [7, 8]]
}
```

This enables **transparent, explainable** PD detection from speech.