## Task 3 - Model Explainability

This notebook provides comprehensive explainability for the **best Task 2 model** using SHAP:

### Objectives:
1. **Feature Importance Baseline**: Extract and visualize the built-in feature importance from the ensemble model
2. **SHAP Analysis**: 
   - Global feature importance (SHAP summary plot)
   - Local explanations for individual predictions (TP, FP, FN)
3. **Interpretation**: Compare SHAP with built-in importance, identify top drivers
4. **Business Recommendations**: Actionable insights based on SHAP analysis

### Prereqs

1) Run Task 2 first (so models exist):

```bash
python -m scripts.task2_train --dataset all
```

2) Install Task 3 dependency:

```bash
pip install -r requirements-task3.txt
```



In [None]:
from pathlib import Path
import sys

# Ensure repo root is on PYTHONPATH so `import src...` works in Jupyter
sys.path.insert(0, str(Path("..").resolve()))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import shap

from src.modeling.task3_shap import Task3Paths, explain_task3

# Repository folders, addressed relative to the directory this notebook runs in.
_REPO_ROOT = Path("..")
RAW_DIR = _REPO_ROOT / "data" / "raw"
REPORTS_DIR = _REPO_ROOT / "reports"
MODELS_DIR = _REPO_ROOT / "models"

# Bundle the folders into the object handed to explain_task3 as `paths`.
paths = Task3Paths(
    models_dir=MODELS_DIR,
    reports_dir=REPORTS_DIR,
    raw_dir=RAW_DIR,
)

# Initialise SHAP's JavaScript support for plots rendered in the notebook.
shap.initjs()

# Plot appearance: matplotlib style sheet plus seaborn's "husl" colour palette.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")



In [None]:
# Run the Task 3 explanation pipeline for the best Fraud_Data model

print("Analyzing Fraud_Data dataset...")
res_fraud = explain_task3(explain_size=200, paths=paths, dataset="fraud")

# One print call, one line per summary field.
print(
    f"Best model: {res_fraud['model_name']}",
    f"Examples found: {res_fraud['examples']}",
    f"Test samples explained: {res_fraud['n_test_sample_explained']}",
    sep="\n",
)

# Bare last expression: the notebook displays the full result object.
res_fraud



In [None]:
## 1. Feature Importance Baseline (Fraud_Data)

# Extract and visualize built-in feature importance from the ensemble model.



In [None]:
# Built-in importances reported for the best Fraud_Data model
builtin_importance = res_fraud["builtin_importance"]
# Keep only as many names as there are importance scores
feature_names = res_fraud["feature_names"][:len(builtin_importance)]

# Table of (feature, importance), ordered from most to least important
importance_df = pd.DataFrame(
    {"feature": feature_names, "importance": builtin_importance}
).sort_values("importance", ascending=False)
top_10 = importance_df.head(10)

# Horizontal bar chart of the ten highest-ranked features
fig, ax = plt.subplots(figsize=(10, 6))
bar_positions = range(len(top_10))
ax.barh(bar_positions, top_10["importance"].values)
ax.set_yticks(bar_positions)
ax.set_yticklabels(top_10["feature"].values)
ax.set_xlabel("Feature Importance")
ax.set_title(f"Top 10 Built-in Feature Importance (Fraud_Data - {res_fraud['model_name']})")
ax.invert_yaxis()  # largest importance at the top
fig.tight_layout()
plt.show()

print("\nTop 10 Features by Built-in Importance:")
print(top_10.to_string(index=False))


## 2. SHAP Analysis (Fraud_Data)

### 2.1 Global Feature Importance - SHAP Summary Plot


In [None]:
# Global view (Fraud_Data): beeswarm of SHAP values, limited to 15 displayed features.
# show=False keeps the figure open so a title can be attached before rendering.
shap.plots.beeswarm(res_fraud["shap_values"], show=False, max_display=15)
plt.gca().set_title(
    "SHAP Summary Plot - Global Feature Importance (Fraud_Data)",
    fontsize=14,
    pad=20,
)
plt.gcf().tight_layout()
plt.show()


### 2.2 Local Explanations - Individual Predictions

SHAP waterfall plots for specific cases: True Positive, False Positive, and False Negative.


In [None]:
# Local explanations (Fraud_Data): TP / FP / FN

def show_case(res, idx, title, threshold=0.5):
    """Print a short summary and draw a SHAP waterfall plot for one explained sample.

    Parameters
    ----------
    res : mapping
        Explanation result; the keys ``'y_test_sample'``, ``'y_proba_sample'``
        and ``'shap_values'`` are read, each indexed with ``idx``.
    idx : int or None
        Index of the sample within the explained subset. ``None`` means no such
        case was found: a hint is printed and nothing is plotted.
    title : str
        Heading printed above the summary and used as the plot title.
    threshold : float, default 0.5
        Probability at or above which the printed predicted class is 1.
        NOTE(review): the TP/FP/FN indices are chosen upstream; confirm they
        were derived with the same threshold before passing a non-default value.

    Returns
    -------
    None
    """
    if idx is None:
        print(f"{title}: not found in the explained sample (try increasing explain_size)")
        return

    banner = "=" * 60
    proba = res['y_proba_sample'][idx]  # looked up once, used for both lines below

    print(f"\n{banner}")
    print(f"{title}")
    print(f"{banner}")
    print(f"Actual label: {res['y_test_sample'][idx]}")
    print(f"Predicted probability: {proba:.4f}")
    print(f"Predicted class: {1 if proba >= threshold else 0}")
    print()

    # Waterfall plot (works in most environments); show=False so the title
    # can be set on the current axes before the figure is rendered.
    shap.plots.waterfall(res["shap_values"][idx], max_display=15, show=False)
    plt.title(title, fontsize=12, pad=10)
    plt.tight_layout()
    plt.show()

# One example per outcome type; each index is read from res_fraud["examples"].
fraud_cases = (
    ("tp_index", "True Positive (TP) - Fraud Correctly Flagged"),
    ("fp_index", "False Positive (FP) - Legitimate Transaction Flagged as Fraud"),
    ("fn_index", "False Negative (FN) - Missed Fraud"),
)
for example_key, case_title in fraud_cases:
    show_case(res_fraud, res_fraud["examples"][example_key], case_title)



## 3. Interpretation & Comparison (Fraud_Data)

### 3.1 Compare SHAP Importance with Built-in Feature Importance


In [None]:
# SHAP importance: absolute SHAP values averaged over axis 0
# (presumably the sample axis, leaving one score per feature - confirm upstream)
shap_values_array = res_fraud["shap_values"].values
shap_importance = np.abs(shap_values_array).mean(axis=0)

# Truncate names and both score vectors to a shared length so rows line up
n_features = min(map(len, (feature_names, builtin_importance, shap_importance)))
feature_names_aligned = feature_names[:n_features]
builtin_aligned = builtin_importance[:n_features]
shap_aligned = shap_importance[:n_features]

# Rescale each score vector to sum to 1 so the two measures share a scale
builtin_norm = builtin_aligned / builtin_aligned.sum()
shap_norm = shap_aligned / shap_aligned.sum()

# Side-by-side table, ordered by SHAP importance (highest first)
comparison_df = pd.DataFrame(
    {
        "feature": feature_names_aligned,
        "builtin_importance": builtin_norm,
        "shap_importance": shap_norm,
    }
).sort_values("shap_importance", ascending=False)

# Two selections of ten features: best by built-in score, best by SHAP score
top_builtin = comparison_df.nlargest(10, "builtin_importance")
top_shap = comparison_df.head(10)

# One panel per selection; both score series are drawn at the same bar positions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
panels = (
    (ax1, top_builtin, "Top 10 Features: Built-in vs SHAP Importance"),
    (ax2, top_shap, "Top 10 Features by SHAP Importance"),
)
series_to_draw = (("builtin_importance", "Built-in"), ("shap_importance", "SHAP"))

for axis, subset, panel_title in panels:
    bar_rows = range(len(subset))
    for column, legend_label in series_to_draw:
        axis.barh(bar_rows, subset[column].values, label=legend_label, alpha=0.7)
    axis.set_yticks(bar_rows)
    axis.set_yticklabels(subset["feature"].values)
    axis.set_xlabel("Normalized Importance")
    axis.set_title(panel_title)
    axis.legend()
    axis.invert_yaxis()  # first row of the selection at the top

plt.tight_layout()
plt.show()

print("\nTop 5 Drivers of Fraud Predictions (by SHAP importance):")
print(comparison_df.head(5)[["feature", "shap_importance", "builtin_importance"]].to_string(index=False))


### 3.2 Key Findings

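**Top 5 Drivers of Fraud Predictions** — the next cell:
1. Lists the top 5 features driving fraud predictions, ranked by normalized SHAP importance
2. Compares each with its built-in importance and marks pairs whose normalized scores differ by 0.05 or more
3. Reports surprising or counterintuitive findings: features ranked high by one measure but low by the other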


In [None]:
# Identify top 5 drivers: comparison_df is already ordered by SHAP importance,
# so the first five rows are the five highest-ranked features.
top_5_drivers = comparison_df.head(5)

print("="*60)
print("TOP 5 DRIVERS OF FRAUD PREDICTIONS (Fraud_Data)")
print("="*60)
# sort_values keeps the original index labels, so a row's label is its
# pre-sort position, not its rank. Number the drivers by display order.
for rank, (_, row) in enumerate(top_5_drivers.iterrows(), start=1):
    print(f"\n{rank}. {row['feature']}")
    print(f"   SHAP Importance: {row['shap_importance']:.4f}")
    print(f"   Built-in Importance: {row['builtin_importance']:.4f}")
    # The two normalized scores "agree" when they differ by less than 0.05.
    print(f"   Agreement: {'✓' if abs(row['shap_importance'] - row['builtin_importance']) < 0.05 else '⚠'}")

# Look for features on which the two importance measures disagree
print("\n" + "="*60)
print("SURPRISING FINDINGS")
print("="*60)

shap_scores = comparison_df["shap_importance"]
builtin_scores = comparison_df["builtin_importance"]
report_columns = ["feature", "shap_importance", "builtin_importance"]

# Above the 75th percentile by SHAP, yet below the median by built-in importance
high_shap_low_builtin = comparison_df[
    (shap_scores > shap_scores.quantile(0.75))
    & (builtin_scores < builtin_scores.quantile(0.5))
]

if high_shap_low_builtin.empty:
    print("\nNo major discrepancies found between SHAP and built-in importance.")
else:
    print("\nFeatures with HIGH SHAP importance but LOW built-in importance:")
    print("(These may have complex interactions that SHAP captures better)")
    print(high_shap_low_builtin[report_columns].to_string(index=False))

# Mirror case: below the median by SHAP, yet above the 75th percentile by built-in
low_shap_high_builtin = comparison_df[
    (shap_scores < shap_scores.quantile(0.5))
    & (builtin_scores > builtin_scores.quantile(0.75))
]

if not low_shap_high_builtin.empty:
    print("\nFeatures with LOW SHAP importance but HIGH built-in importance:")
    print("(These may have less direct impact on individual predictions)")
    print(low_shap_high_builtin[report_columns].to_string(index=False))


## 4. Business Recommendations (Fraud_Data)

Based on SHAP analysis, here are actionable business recommendations:


In [None]:
# Analyze SHAP values to generate recommendations
shap_vals = res_fraud["shap_values"].values
feature_data = res_fraud["shap_values"].data
feature_names_shap = res_fraud["shap_values"].feature_names


def _group_shap_impact(keywords):
    """Return ``(names, impact)`` for features whose lower-cased name contains any keyword.

    ``names`` lists the matching entries of ``feature_names_shap`` in column
    order. ``impact`` is the mean absolute SHAP value taken over every row and
    every matching column of ``shap_vals`` (a single scalar). When no feature
    name matches, ``([], None)`` is returned.
    """
    columns = [
        col for col, name in enumerate(feature_names_shap)
        if any(keyword in name.lower() for keyword in keywords)
    ]
    if not columns:
        return [], None
    names = [feature_names_shap[col] for col in columns]
    return names, np.abs(shap_vals[:, columns]).mean()


print("="*60)
print("BUSINESS RECOMMENDATIONS (Based on SHAP Analysis)")
print("="*60)

# NOTE: each recommendation below is fixed text, printed whenever at least one
# feature name matches its keyword group. Only the reported mean |SHAP| figure
# comes from the data; the numeric thresholds in the text are not derived here.

# Recommendation 1: Time-based features
time_features, avg_impact = _group_shap_impact(['time', 'hour', 'day', 'signup'])
if time_features:
    print("\n1. TRANSACTION TIMING & SIGNUP WINDOW")
    print(f"   Insight: Time-based features ({', '.join(time_features[:3])}) show significant impact")
    print("   Recommendation: Transactions within 24-48 hours of signup should receive")
    print("                    additional verification (OTP/2FA) due to higher fraud risk.")
    print(f"   SHAP Evidence: Time features contribute {avg_impact:.4f} average absolute SHAP value")

# Recommendation 2: Transaction velocity
velocity_features, avg_impact = _group_shap_impact(['count', 'velocity', 'txn'])
if velocity_features:
    print("\n2. TRANSACTION VELOCITY MONITORING")
    print(f"   Insight: Velocity features ({', '.join(velocity_features[:2])}) are key fraud indicators")
    print("   Recommendation: Implement real-time velocity checks:")
    print("                    - Flag users with >3 transactions in 1 hour for manual review")
    print("                    - Block users with >10 transactions in 24 hours until verified")
    print(f"   SHAP Evidence: Velocity features contribute {avg_impact:.4f} average absolute SHAP value")

# Recommendation 3: Device/User patterns
device_features, avg_impact = _group_shap_impact(['device', 'user', 'unique'])
if device_features:
    print("\n3. DEVICE & USER BEHAVIOR PATTERNS")
    print(f"   Insight: Device/user aggregation features ({', '.join(device_features[:2])}) reveal fraud patterns")
    print("   Recommendation: Monitor device-user relationships:")
    print("                    - Flag devices associated with >5 unique users in 30 days")
    print("                    - Require verification for users switching devices frequently")
    print(f"   SHAP Evidence: Device/user features contribute {avg_impact:.4f} average absolute SHAP value")

# Recommendation 4: Geographic/Country risk
country_features, avg_impact = _group_shap_impact(['country'])
if country_features:
    print("\n4. GEOGRAPHIC RISK ASSESSMENT")
    print("   Insight: Country features show varying fraud risk levels")
    print("   Recommendation: Implement country-based risk scoring:")
    print("                    - High-risk countries: Require additional verification")
    print("                    - Mismatch between IP country and billing address: Flag for review")
    print(f"   SHAP Evidence: Country features contribute {avg_impact:.4f} average absolute SHAP value")

# Recommendation 5: Purchase value
value_features, avg_impact = _group_shap_impact(['value', 'amount', 'purchase'])
if value_features:
    print("\n5. TRANSACTION VALUE THRESHOLDS")
    print(f"   Insight: Purchase value features ({', '.join(value_features[:1])}) impact fraud probability")
    print("   Recommendation: Implement tiered verification based on transaction value:")
    print("                    - Low value (<$50): Standard processing")
    print("                    - Medium value ($50-$500): Additional verification if combined with other risk factors")
    print("                    - High value (>$500): Always require step-up authentication")
    print(f"   SHAP Evidence: Value features contribute {avg_impact:.4f} average absolute SHAP value")

print("\n" + "="*60)
print("Note: These recommendations should be tested in a controlled environment")
print("      and adjusted based on business constraints and false positive tolerance.")
print("="*60)


In [None]:
# Run the Task 3 explanation pipeline for the best creditcard model

print("Analyzing creditcard dataset...")
res_cc = explain_task3(explain_size=200, paths=paths, dataset="creditcard")

# One print call, one line per summary field.
print(
    f"Best model: {res_cc['model_name']}",
    f"Examples found: {res_cc['examples']}",
    f"Test samples explained: {res_cc['n_test_sample_explained']}",
    sep="\n",
)

# Bare last expression: the notebook displays the full result object.
res_cc



In [None]:
### Feature Importance Baseline (CreditCard)



In [None]:
# Built-in importances reported for the best creditcard model
builtin_importance_cc = res_cc["builtin_importance"]
# Keep only as many names as there are importance scores
feature_names_cc = res_cc["feature_names"][:len(builtin_importance_cc)]

# Table of (feature, importance), ordered from most to least important
importance_df_cc = pd.DataFrame(
    {"feature": feature_names_cc, "importance": builtin_importance_cc}
).sort_values("importance", ascending=False)
top_10_cc = importance_df_cc.head(10)

# Horizontal bar chart of the ten highest-ranked features
fig_cc, ax_cc = plt.subplots(figsize=(10, 6))
bar_positions_cc = range(len(top_10_cc))
ax_cc.barh(bar_positions_cc, top_10_cc["importance"].values)
ax_cc.set_yticks(bar_positions_cc)
ax_cc.set_yticklabels(top_10_cc["feature"].values)
ax_cc.set_xlabel("Feature Importance")
ax_cc.set_title(f"Top 10 Built-in Feature Importance (CreditCard - {res_cc['model_name']})")
ax_cc.invert_yaxis()  # largest importance at the top
fig_cc.tight_layout()
plt.show()

print("\nTop 10 Features by Built-in Importance:")
print(top_10_cc.to_string(index=False))


### SHAP Summary Plot (CreditCard)


In [None]:
# Global view (creditcard): beeswarm of SHAP values, limited to 15 displayed features.
# show=False keeps the figure open so a title can be attached before rendering.
shap.plots.beeswarm(res_cc["shap_values"], show=False, max_display=15)
plt.gca().set_title(
    "SHAP Summary Plot - Global Feature Importance (CreditCard)",
    fontsize=14,
    pad=20,
)
plt.gcf().tight_layout()
plt.show()


In [None]:
# Local explanations (creditcard): TP / FP / FN

# One example per outcome type; each index is read from res_cc["examples"].
cc_cases = (
    ("tp_index", "True Positive (TP) - Fraud Correctly Flagged"),
    ("fp_index", "False Positive (FP) - Legitimate Transaction Flagged as Fraud"),
    ("fn_index", "False Negative (FN) - Missed Fraud"),
)
for cc_example_key, cc_case_title in cc_cases:
    show_case(res_cc, res_cc["examples"][cc_example_key], cc_case_title)



### Interpretation & Business Recommendations (CreditCard)


In [None]:
# SHAP importance for creditcard: absolute SHAP values averaged over axis 0
# (presumably the sample axis, leaving one score per feature - confirm upstream)
shap_values_array_cc = res_cc["shap_values"].values
shap_importance_cc = np.abs(shap_values_array_cc).mean(axis=0)

# Truncate names and both score vectors to a shared length so rows line up
n_features_cc = min(map(len, (feature_names_cc, builtin_importance_cc, shap_importance_cc)))
feature_names_aligned_cc = feature_names_cc[:n_features_cc]
builtin_aligned_cc = builtin_importance_cc[:n_features_cc]
shap_aligned_cc = shap_importance_cc[:n_features_cc]

# Rescale each score vector to sum to 1 so the two measures share a scale
builtin_norm_cc = builtin_aligned_cc / builtin_aligned_cc.sum()
shap_norm_cc = shap_aligned_cc / shap_aligned_cc.sum()

# Side-by-side table, ordered by SHAP importance (highest first)
comparison_df_cc = pd.DataFrame(
    {
        "feature": feature_names_aligned_cc,
        "builtin_importance": builtin_norm_cc,
        "shap_importance": shap_norm_cc,
    }
).sort_values("shap_importance", ascending=False)

top_5_cc = comparison_df_cc.head(5)
print("\nTop 5 Drivers of Fraud Predictions (CreditCard - by SHAP importance):")
print(top_5_cc[["feature", "shap_importance", "builtin_importance"]].to_string(index=False))

# Fixed recommendation text for the creditcard dataset, emitted one line per print call.
cc_recommendation_lines = (
    "\n" + "="*60,
    "BUSINESS RECOMMENDATIONS (CreditCard Dataset)",
    "="*60,
    "\nNote: CreditCard dataset uses PCA features (V1-V28), making direct",
    "      business interpretation challenging. Recommendations focus on:",
    "      - Transaction amount monitoring",
    "      - Time-based patterns",
    "      - Anomaly detection thresholds",
    "\n1. TRANSACTION AMOUNT MONITORING",
    "   - Implement dynamic thresholds based on user history",
    "   - Flag transactions >2 standard deviations from user's average",
    "\n2. TIME-BASED PATTERNS",
    "   - Monitor transactions outside normal user activity windows",
    "   - Flag rapid successive transactions",
    "\n3. ANOMALY DETECTION",
    "   - Use PCA feature combinations as anomaly indicators",
    "   - Implement real-time scoring with adaptive thresholds",
)
for cc_line in cc_recommendation_lines:
    print(cc_line)
