# SHAP-Based Prior Extraction

**Goal:** Extract Bayesian priors (β₀, Σ₀) from a trained CatBoost model using SHAP values.

This notebook implements **Algorithm 4.2** from Section 4.2.3 (Prior Distribution Extraction).

## What We'll Do

1. Load trained CatBoost model
2. Compute SHAP values on validation set
3. Extract prior means (β₀) from normalized SHAP values
4. Compute prior variances (Σ₀) from cross-dataset heterogeneity
5. Generate **Table 4.6** (Extracted Prior Distributions)
6. Validate priors through predictive checks (**Table 4.7**)
7. Save priors for hierarchical model

---

## Setup

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import pickle
import warnings

# SHAP for explainability
import shap

# CatBoost
from catboost import CatBoostClassifier

# SmallML framework
from src.layer1_transfer.shap_extractor import SHAPPriorExtractor

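warnings.filterwarnings('ignore')
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)

# Make sure the output directories exist before any figure, table, or prior is saved
for out_dir in ('../results/figures', '../results/tables', '../models/transfer_learning'):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

print("✓ All imports successful")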

## 1. Load Trained Model and Data

In [None]:
# Load trained CatBoost model from Day 10
model_path = '../models/transfer_learning/catboost_base.cbm'
model = CatBoostClassifier()
model.load_model(model_path)

print(f"✓ Model loaded: {model.tree_count_} trees")

In [None]:
# Load training data (for normalization statistics)
X_train = pd.read_csv('../data/harmonized/X_train.csv')
y_train = pd.read_csv('../data/harmonized/y_train.csv')['churned']

print(f"✓ Training data: {X_train.shape[0]:,} samples, {X_train.shape[1]} features")

In [None]:
# Load validation data
X_val = pd.read_csv('../data/harmonized/X_val.csv')
y_val = pd.read_csv('../data/harmonized/y_val.csv')['churned']

print(f"✓ Validation data: {X_val.shape[0]:,} samples, {X_val.shape[1]} features")

In [None]:
# Load dataset source labels (need to recreate split from D_public_processed.csv)
D_public = pd.read_csv('../data/harmonized/D_public_processed.csv')

print(f"✓ Full processed dataset: {D_public.shape[0]:,} samples")

In [None]:
# Recreate split to extract dataset_source labels
from sklearn.model_selection import train_test_split

X_full = D_public.drop(columns=['churned', 'dataset_source'])
y_full = D_public['churned']
dataset_source_full = D_public['dataset_source']

# Same split as Day 9 (random_state=42)
X_train_check, X_val_check, y_train_check, y_val_check, _, dataset_source_val = train_test_split(
    X_full, y_full, dataset_source_full,
    test_size=0.2,
    stratify=y_full,
    random_state=42
)

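# Verify split matches (row count and label order, not just length)
assert len(X_val_check) == len(X_val), "Split mismatch!"
assert (y_val_check.to_numpy() == y_val.to_numpy()).all(), "Label mismatch: split differs from Day 9"

# X_val was loaded from CSV with a fresh 0..n-1 index, so align the labels positionally
dataset_source_val = dataset_source_val.reset_index(drop=True)

print(f"✓ Dataset source labels extracted: {len(dataset_source_val):,} validation samples")
print("  Dataset distribution:")
print(dataset_source_val.value_counts())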

## 2. Initialize SHAP Prior Extractor

In [None]:
# Initialize extractor with λ=1.0 (doubles empirical variance for conservatism)
extractor = SHAPPriorExtractor(
    model=model,
    X_train=X_train,
    lambda_scale=1.0,
    random_seed=42
)

print("✓ SHAPPriorExtractor initialized")
print(f"  Features: {len(extractor.feature_names_)}")
print(f"  Scaling factor: λ = {extractor.lambda_scale}")

## 3. Compute SHAP Values (Algorithm 4.2, Step 1)

**Note:** This step takes 10-20 minutes depending on your CPU. Go grab a coffee! ☕
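
For intuition about the array this step returns (one row per validation sample, one column per feature), here is a small sketch on a toy linear model rather than CatBoost. It relies on the standard result that, for a linear model with independent features, the SHAP value of feature j is `w_j * (x_j - mean(x_j))`. The names `X_demo`, `weights`, and `bias` are illustrative and not part of the SmallML framework.

```python
import numpy as np

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(6, 3))        # 6 samples, 3 features
weights = np.array([2.0, -1.0, 0.0])    # third feature has no effect
bias = 0.5

predictions = bias + X_demo @ weights

# Linear-model SHAP values (assumes independent features)
shap_demo = weights * (X_demo - X_demo.mean(axis=0))   # shape (n_samples, n_features)

# Additivity: base value + row sum of SHAP values recovers each prediction
reconstructed = predictions.mean() + shap_demo.sum(axis=1)

print(shap_demo.shape)
print(np.allclose(reconstructed, predictions))
```

A feature the model ignores gets a SHAP column of zeros, which is why the average absolute SHAP value is a usable measure of effect magnitude in the next step.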

In [None]:
# Compute SHAP values on validation set
shap_values = extractor.compute_shap_values(X_val, verbose=True)

### Visualize SHAP Summary Plot

In [None]:
# SHAP summary plot: Top 20 features
plt.figure(figsize=(10, 8))
shap.summary_plot(
    shap_values,
    X_val,
    max_display=20,
    show=False
)
plt.title('SHAP Summary Plot: Top 20 Features', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('../results/figures/shap_summary.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved: results/figures/shap_summary.png")

## 4. Extract Prior Means (Algorithm 4.2, Steps 2-3)

Transform SHAP values to coefficient-scale priors:
- φ_j = mean(|SHAP_j|)
- β₀_j = φ_j / std(x_j)
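
The two formulas above can be sketched in NumPy as follows. This is a minimal illustration, not the code inside `extract_prior_means`: the helper name `prior_means_from_shap`, the use of the sample standard deviation (`ddof=1`), and the guard for constant features are all assumptions. Because φ_j averages absolute values, the result here is non-negative; how the extractor attaches a sign to β₀_j is not shown in this notebook.

```python
import numpy as np

def prior_means_from_shap(shap_values, X, eps=1e-12):
    """Hypothetical helper: beta0_j = mean(|SHAP_j|) / std(x_j)."""
    shap_values = np.asarray(shap_values, dtype=float)
    X = np.asarray(X, dtype=float)
    phi = np.abs(shap_values).mean(axis=0)     # phi_j, always >= 0
    feature_std = X.std(axis=0, ddof=1)        # std(x_j); ddof=1 is an assumption
    # Assumption: leave phi_j unscaled for constant features instead of dividing by 0
    safe_std = np.where(feature_std > eps, feature_std, 1.0)
    return phi / safe_std

# Toy data: 4 samples, 2 features on very different scales
shap_toy = np.array([[0.2, -0.1], [-0.4, 0.3], [0.6, -0.1], [-0.4, 0.1]])
X_toy = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])

beta_0_toy = prior_means_from_shap(shap_toy, X_toy)
print(beta_0_toy)
```

Dividing by the feature's standard deviation puts a feature measured in large units (the second column) on the same per-unit footing as one measured in small units.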

In [None]:
# Extract prior means
beta_0 = extractor.extract_prior_means(verbose=True)

### Visualize Prior Mean Distribution

In [None]:
# Plot distribution of prior means
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram
axes[0].hist(beta_0, bins=30, edgecolor='black', alpha=0.7)
axes[0].axvline(0, color='red', linestyle='--', linewidth=2, label='Zero')
axes[0].set_xlabel('Prior Mean (β₀_j)', fontsize=12)
axes[0].set_ylabel('Frequency', fontsize=12)
axes[0].set_title('Distribution of Prior Means', fontsize=13, fontweight='bold')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Top 15 features by absolute value
top_indices = np.argsort(np.abs(beta_0))[-15:][::-1]
top_features = [extractor.feature_names_[i] for i in top_indices]
top_values = beta_0[top_indices]

colors = ['red' if v > 0 else 'blue' for v in top_values]
axes[1].barh(range(len(top_features)), top_values, color=colors, alpha=0.7, edgecolor='black')
axes[1].set_yticks(range(len(top_features)))
axes[1].set_yticklabels(top_features, fontsize=10)
axes[1].axvline(0, color='black', linestyle='-', linewidth=1)
axes[1].set_xlabel('Prior Mean (β₀_j)', fontsize=12)
axes[1].set_title('Top 15 Features by |β₀|', fontsize=13, fontweight='bold')
axes[1].grid(alpha=0.3, axis='x')
axes[1].invert_yaxis()

plt.tight_layout()
plt.savefig('../results/figures/prior_means_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved: results/figures/prior_means_distribution.png")

## 5. Extract Prior Variances (Algorithm 4.2, Steps 4-6)

Measure cross-dataset SHAP heterogeneity:
- Compute dataset-specific SHAP: φ_j^(k) for k ∈ {telco, bank, ecomm}
- Calculate variance: σ²_j = Var(φ_j^(1), φ_j^(2), φ_j^(3))
- Apply scaling: Σ₀ = diag(σ²_j × (1 + λ))
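
The three steps above can be sketched in NumPy as follows. This follows the bullets literally and is not the code inside `extract_prior_variances`: the helper name `prior_covariance_from_shap` and the sample variance (`ddof=1`) are assumptions, and any rescaling to the coefficient scale is left out.

```python
import numpy as np

def prior_covariance_from_shap(shap_values, dataset_labels, lambda_scale=1.0):
    """Hypothetical helper: Sigma_0 = diag(Var_k(phi_j^(k)) * (1 + lambda))."""
    shap_values = np.asarray(shap_values, dtype=float)
    dataset_labels = np.asarray(dataset_labels)
    # phi_j^(k): mean |SHAP| per feature within each source dataset
    phi_by_dataset = np.stack([
        np.abs(shap_values[dataset_labels == name]).mean(axis=0)
        for name in np.unique(dataset_labels)
    ])
    # sigma^2_j: variance across datasets (ddof=1 is an assumption)
    sigma2 = phi_by_dataset.var(axis=0, ddof=1)
    return np.diag(sigma2 * (1.0 + lambda_scale))

# Toy data: feature 0 varies across datasets, feature 1 does not
shap_toy = np.array([
    [0.1, 0.5], [-0.3, -0.5],   # telco
    [0.4, 0.5], [-0.4, 0.5],    # bank
    [0.6, -0.5], [0.6, 0.5],    # ecomm
])
labels_toy = np.array(["telco", "telco", "bank", "bank", "ecomm", "ecomm"])

Sigma_toy = prior_covariance_from_shap(shap_toy, labels_toy, lambda_scale=1.0)
print(Sigma_toy)
```

With only three source datasets each variance is estimated from three numbers, which is one reason to inflate it by the (1 + λ) factor rather than trust it as is.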

In [None]:
# Extract prior variances
Sigma_0 = extractor.extract_prior_variances(
    X_val,
    dataset_source_val,
    verbose=True
)

### Visualize Cross-Dataset Variance

In [None]:
# Compute per-dataset SHAP for visualization
datasets = dataset_source_val.unique()
shap_by_dataset = {}

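for dataset in datasets:
    # Convert to a plain boolean array so the mask is applied by position, not by index label
    mask = (dataset_source_val == dataset).to_numpy()
    shap_by_dataset[dataset] = np.abs(shap_values[mask]).mean(axis=0)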

# Convert to DataFrame
shap_df = pd.DataFrame(shap_by_dataset, index=extractor.feature_names_)

# Plot variance across datasets for top 15 features
prior_stds = np.sqrt(np.diag(Sigma_0))
top_var_indices = np.argsort(prior_stds)[-15:][::-1]
top_var_features = [extractor.feature_names_[i] for i in top_var_indices]

shap_df_top = shap_df.loc[top_var_features]

fig, ax = plt.subplots(figsize=(12, 7))
shap_df_top.plot(kind='barh', ax=ax, width=0.7, edgecolor='black')
ax.set_xlabel('Average |SHAP| per Dataset', fontsize=12)
ax.set_ylabel('')
ax.set_title('Cross-Dataset SHAP Variation (Top 15 Most Uncertain Features)', fontsize=13, fontweight='bold')
ax.legend(title='Dataset', fontsize=10, title_fontsize=11)
ax.grid(alpha=0.3, axis='x')
ax.invert_yaxis()

plt.tight_layout()
plt.savefig('../results/figures/cross_dataset_variance.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved: results/figures/cross_dataset_variance.png")

## 6. Generate Table 4.6: Extracted Prior Distributions

In [None]:
# Load feature importances from Table 4.4 (Day 10)
table_4_4 = pd.read_csv('../results/tables/table_4_4.csv')
print("Feature importances from Table 4.4:")
print(table_4_4.head())

In [None]:
# Generate Table 4.6
table_4_6 = extractor.generate_table_4_6(
    top_n=5,
    feature_importances=table_4_4
)

print("\nTable 4.6: Extracted Prior Distributions for Top 5 Features")
print("="*80)
print(table_4_6.to_string(index=False))
print("="*80)

### Interpretation of Table 4.6

For each feature:
- **Importance (w_j)**: CatBoost feature importance from Day 10
- **Avg SHAP (φ_j)**: Average absolute SHAP value (feature effect magnitude)
- **Prior Mean (β₀_j)**: Expected coefficient value for hierarchical model
  - Positive → increases churn probability
  - Negative → decreases churn probability
- **Prior Std (√Σ₀_jj)**: Uncertainty in transferability
  - Small → consistent effect across datasets (tight prior)
  - Large → heterogeneous effect (diffuse prior, allows SME adaptation)
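
One way to read the prior standard deviation is as the width of the range the hierarchical model considers plausible before seeing SME data. The sketch below computes an approximate 95% interval, β₀_j ± 1.96·√Σ₀_jj, under the assumption that each prior is normal; the helper `prior_interval` and the toy numbers are illustrative only.

```python
import numpy as np

def prior_interval(beta_0, Sigma_0, z=1.96):
    """Hypothetical helper: per-feature interval beta0_j +/- z * sqrt(Sigma0_jj)."""
    beta_0 = np.asarray(beta_0, dtype=float)
    prior_std = np.sqrt(np.diag(Sigma_0))
    return beta_0 - z * prior_std, beta_0 + z * prior_std

# Two features with the same prior mean but different cross-dataset consistency
beta_0_toy = np.array([0.5, 0.5])
Sigma_0_toy = np.diag([0.01, 1.0])     # prior std 0.1 (tight) vs 1.0 (diffuse)

lower, upper = prior_interval(beta_0_toy, Sigma_0_toy)
for j in range(len(beta_0_toy)):
    print(f"feature {j}: [{lower[j]:.3f}, {upper[j]:.3f}]")
```

The tight prior keeps the coefficient positive, while the diffuse prior spans zero and lets SME data decide both size and direction.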

In [None]:
# Save Table 4.6
table_4_6.to_csv('../results/tables/table_4_6.csv', index=False)

# Save as Markdown
with open('../results/tables/table_4_6.md', 'w', encoding='utf-8') as f:
    f.write("# Table 4.6: Extracted Prior Distributions for Top 5 Features\n\n")
    f.write(table_4_6.to_markdown(index=False))
    f.write("\n\n*Generated from SHAP values on validation set (Algorithm 4.2)*\n")

print("✓ Table 4.6 saved to results/tables/")

## 7. Prior Predictive Check (Table 4.7)

Validate the priors by comparing validation AUC across three predictors:
1. Random coefficients, β ~ N(0, 1) (baseline)
2. Prior-only predictions, β ~ N(β₀, Σ₀)
3. Trained CatBoost (reference)

**Expected:** Prior-only should outperform random coefficients but underperform the fully trained model.
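
The comparison can be sketched on synthetic data as follows. This is an illustration under stated assumptions, not the code inside `prior_predictive_check`: it assumes a linear score `X @ β`, draws coefficients from a normal distribution, and averages the AUC over the draws. The helpers `auc_score` and `mean_auc_under_prior` are hypothetical.

```python
import numpy as np

def auc_score(y_true, scores):
    """Pairwise AUC: P(score_pos > score_neg), counting ties as 0.5."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    diff = pos[:, None] - neg[None, :]   # O(n_pos * n_neg) memory, fine for a sketch
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

def mean_auc_under_prior(X, y, mean, cov, n_samples=100, seed=42):
    """Average AUC of linear scores X @ beta for beta ~ N(mean, cov)."""
    rng = np.random.default_rng(seed)
    betas = rng.multivariate_normal(mean, cov, size=n_samples)
    # AUC depends only on the ranking of scores, so the sigmoid can be skipped
    return float(np.mean([auc_score(y, X @ b) for b in betas]))

# Synthetic data with a known coefficient vector
rng = np.random.default_rng(0)
n, p = 500, 5
X_toy = rng.normal(size=(n, p))
true_beta = np.array([1.5, -1.0, 0.5, 0.0, 0.0])
y_toy = (rng.random(n) < 1.0 / (1.0 + np.exp(-X_toy @ true_beta))).astype(int)

random_auc = mean_auc_under_prior(X_toy, y_toy, np.zeros(p), np.eye(p))
prior_auc = mean_auc_under_prior(X_toy, y_toy, true_beta, 0.05 * np.eye(p))
print(f"random: {random_auc:.3f}, prior-only: {prior_auc:.3f}")
```

Random coefficients are as likely to point the wrong way as the right way, so their average AUC sits near 0.5; an informative prior concentrates the draws around useful directions.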

In [None]:
# Run prior predictive check
results = extractor.prior_predictive_check(
    X_val,
    y_val,
    n_samples=100,
    verbose=True
)

In [None]:
# Create Table 4.7
# The 'Interpretation' strings describe the expected outcome; revise them if the AUCs differ
table_4_7 = pd.DataFrame([
    {
        'Model': 'Random coefficients β ~ N(0, 1)',
        'AUC': results['random_coefficients'],
        'Interpretation': 'Barely better than chance'
    },
    {
        'Model': 'Prior-only β ~ N(β₀, Σ₀)',
        'AUC': results['prior_only'],
        'Interpretation': 'Substantial signal from transfer learning'
    },
    {
        'Model': 'Fully-trained CatBoost',
        'AUC': results['trained_catboost'],
        'Interpretation': 'Full model performance'
    }
])

print("\nTable 4.7: Prior Predictive Performance on Validation Data")
print("="*80)
print(table_4_7.to_string(index=False))
print("="*80)

In [None]:
# Visualize Table 4.7
fig, ax = plt.subplots(figsize=(10, 6))

models = ['Random\nCoefficients', 'Prior-Only\n(Transfer)', 'CatBoost\n(Full Model)']
aucs = [
    results['random_coefficients'],
    results['prior_only'],
    results['trained_catboost']
]
colors = ['red', 'orange', 'green']

bars = ax.bar(models, aucs, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
ax.set_ylabel('AUC-ROC', fontsize=13)
ax.set_title('Prior Predictive Check: Model Comparison (Table 4.7)', fontsize=14, fontweight='bold')
ax.set_ylim([0.4, 1.0])
ax.axhline(0.5, color='gray', linestyle='--', linewidth=1, label='Random Guess')
ax.grid(alpha=0.3, axis='y')

# Add value labels on bars
for bar, auc in zip(bars, aucs):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
            f'{auc:.4f}',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.legend(fontsize=11)
plt.tight_layout()
plt.savefig('../results/figures/prior_predictive_check.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved: results/figures/prior_predictive_check.png")

In [None]:
# Save Table 4.7
table_4_7.to_csv('../results/tables/table_4_7.csv', index=False)

# Save as Markdown
with open('../results/tables/table_4_7.md', 'w', encoding='utf-8') as f:
    f.write("# Table 4.7: Prior Predictive Performance on Validation Data\n\n")
    f.write(table_4_7.to_markdown(index=False))
    f.write("\n\n*Prior predictive check validates that extracted priors encode transferable knowledge.*\n")

print("✓ Table 4.7 saved to results/tables/")

## 8. Save Priors for Hierarchical Model
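
The verification cell in this section reads the keys `beta_0`, `Sigma_0`, `feature_names`, and `lambda_scale` from the saved file. Assuming the file is a plain pickled dictionary with those keys (the on-disk format used by `save_priors` is not shown in this notebook), a round trip looks roughly like the sketch below. The feature names and values are made up for illustration.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical prior dictionary with the keys read back by the verification cell
priors = {
    "beta_0": np.array([0.3, -0.1]),
    "Sigma_0": np.diag([0.08, 0.02]),
    "feature_names": ["feature_a", "feature_b"],   # made-up names
    "lambda_scale": 1.0,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "priors.pkl"
    with open(path, "wb") as f:
        pickle.dump(priors, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(sorted(restored.keys()))
print(restored["beta_0"].shape, restored["Sigma_0"].shape)
```

Storing `feature_names` next to the arrays lets the hierarchical model confirm that its columns are in the same order as the priors before using them.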

In [None]:
# Save priors
extractor.save_priors(
    '../models/transfer_learning/priors.pkl',
    include_metadata=True
)

In [None]:
# Verify saved priors
loaded_priors = SHAPPriorExtractor.load_priors('../models/transfer_learning/priors.pkl')

print("\nSaved prior structure:")
print(f"  - beta_0: {loaded_priors['beta_0'].shape}")
print(f"  - Sigma_0: {loaded_priors['Sigma_0'].shape}")
print(f"  - feature_names: {len(loaded_priors['feature_names'])} features")
print(f"  - lambda_scale: {loaded_priors['lambda_scale']}")
if 'metadata' in loaded_priors:
    print(f"  - metadata: {list(loaded_priors['metadata'].keys())}")

## 9. Summary Statistics

In [None]:
print("="*80)
print("COMPLETE: SHAP-Based Prior Extraction")
print("="*80)
print(f"\nPrior Means (β₀):")
print(f"  - Shape: {beta_0.shape}")
print(f"  - Mean |β₀|: {np.abs(beta_0).mean():.4f}")
print(f"  - Max |β₀|: {np.abs(beta_0).max():.4f}")
print(f"  - Positive coefficients: {(beta_0 > 0).sum()}")
print(f"  - Negative coefficients: {(beta_0 < 0).sum()}")

prior_stds = np.sqrt(np.diag(Sigma_0))
print(f"\nPrior Covariance (Σ₀):")
print(f"  - Shape: {Sigma_0.shape}")
print(f"  - Mean σ₀: {prior_stds.mean():.4f}")
print(f"  - Median σ₀: {np.median(prior_stds):.4f}")
print(f"  - Max σ₀: {prior_stds.max():.4f}")

print(f"\nPrior Predictive Check:")
print(f"  - Random AUC: {results['random_coefficients']:.4f}")
print(f"  - Prior-only AUC: {results['prior_only']:.4f}")
print(f"  - CatBoost AUC: {results['trained_catboost']:.4f}")
print(f"  - Prior improvement: {results['prior_only'] - results['random_coefficients']:.4f}")
print(f"  - Remaining gap: {results['trained_catboost'] - results['prior_only']:.4f}")