# In [None]:
# =============================================================================
# 04_model_explainability.ipynb
# =============================================================================

import pandas as pd
import numpy as np
import pickle
import shap # Make sure shap is installed
import matplotlib.pyplot as plt
import seaborn as sns

# Install a blanket "ignore" filter so that no warning of any category is
# shown for the remainder of the session.
import warnings
warnings.simplefilter('ignore')

# Confirmation that every import above succeeded.
print("Libraries loaded successfully!")

# --- Load Best Models and Test Data ---
print("\n--- Loading Best Models and Test Data ---")


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    SECURITY: ``pickle.load`` can execute arbitrary code while
    deserialising. Only use this on artefacts written by this project's own
    notebooks, never on files obtained from an untrusted source.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _as_named_frame(features, feature_names):
    """Return *features* as a DataFrame whose columns are *feature_names*.

    A DataFrame is relabelled in place (same behaviour as assigning to
    ``.columns`` directly). Anything else, e.g. a NumPy array, is wrapped in
    a new DataFrame, because the later SHAP cells rely on ``.index`` and
    ``.iloc``.
    """
    if isinstance(features, pd.DataFrame):
        features.columns = feature_names
        return features
    return pd.DataFrame(features, columns=feature_names)


try:
    best_model_ecommerce = _load_pickle('../models/ecommerce_best_model.pkl')
    print("E-commerce best model loaded.")

    ecommerce_data = _load_pickle('../data/processed_ecommerce_data.pkl')

    best_model_bank = _load_pickle('../models/bank_best_model.pkl')
    print("Bank best model loaded.")

    bank_data = _load_pickle('../data/processed_bank_data.pkl')
except FileNotFoundError:
    print("Error: Models or processed data not found. Please ensure Task 1 and 2 notebooks were run and models/data were saved correctly.")
    raise  # Re-raise to stop execution if data is missing

# Each processed-data pickle is read as a dict with 'X_test', 'y_test' and
# 'feature_names' entries; a missing key surfaces as a KeyError here.
y_test_eco = ecommerce_data['y_test']
ecommerce_feature_names = ecommerce_data['feature_names']  # For SHAP plotting consistency
y_test_bank = bank_data['y_test']
bank_feature_names = bank_data['feature_names']  # For SHAP plotting consistency

# Ensure the test matrices are DataFrames carrying the stored feature names,
# so SHAP plots are labelled with real column names.
X_test_eco = _as_named_frame(ecommerce_data['X_test'], ecommerce_feature_names)
X_test_bank = _as_named_frame(bank_data['X_test'], bank_feature_names)

print("Preprocessed E-commerce and Bank test data loaded successfully.")

print("\nProceeding with SHAP interpretation for E-commerce and Bank Transaction models.")

# =============================================================================
# Task 3 - Model Explainability
# =============================================================================

# =============================================================================
# 3.1 SHAP Interpretation - E-commerce Model
# =============================================================================
print("\n--- SHAP Interpretation for E-commerce Fraud Model ---")

# Build a tree-specific SHAP explainer for the loaded e-commerce model
# (a LightGBM model according to the original notebook notes).
explainer_ecommerce = shap.TreeExplainer(best_model_ecommerce)

# Calculate SHAP values for the E-commerce test set
print("Calculating SHAP values for E-commerce test set (this may take a moment)...")
try:
    shap_values_ecommerce = explainer_ecommerce.shap_values(X_test_eco)
except Exception as e:
    # Repeating the identical call cannot succeed, so report and stop here
    # instead of paying for the computation twice and failing anyway.
    print(f"Error calculating SHAP values for E-commerce: {e}")
    raise

# Reduce the explainer output to a 2-D (rows, features) matrix for the
# positive class (fraud). Depending on the SHAP version and model type the
# raw output is one of:
#   * a list with one array per class      -> take element 1
#   * a 3-D array (rows, features, classes) -> take the last-axis slice 1
#   * a 2-D array already for the positive class / single output
if isinstance(shap_values_ecommerce, list) and len(shap_values_ecommerce) == 2:
    shap_values_ecommerce_class1 = shap_values_ecommerce[1]
elif (getattr(shap_values_ecommerce, "ndim", 0) == 3
        and shap_values_ecommerce.shape[-1] == 2):
    shap_values_ecommerce_class1 = shap_values_ecommerce[:, :, 1]
else:
    shap_values_ecommerce_class1 = shap_values_ecommerce  # Already just for class 1


print("SHAP values calculated.")

# --- SHAP Summary Plot (Global Feature Importance) ---
# shap.summary_plot() displays and finishes the figure itself unless
# show=False is passed. Without it, the plt.title() call that follows would
# be applied to a brand-new empty figure and the SHAP plot would stay
# untitled. With show=False the title is added to the SHAP figure and
# plt.show() renders it once.
print("\nGenerating SHAP Summary Plot (dot plot) for E-commerce...")
shap.summary_plot(shap_values_ecommerce_class1, X_test_eco, plot_type="dot", show=False)
plt.title("SHAP Summary Plot for E-commerce Fraud Detection")
plt.show()

# SHAP Bar Plot (Mean Absolute SHAP Value)
print("\nGenerating SHAP Bar Plot (Mean Absolute SHAP Value) for E-commerce...")
shap.summary_plot(shap_values_ecommerce_class1, X_test_eco, plot_type="bar", show=False)
plt.title("SHAP Feature Importance (Mean |SHAP Value|) for E-commerce Fraud Detection")
plt.show()

# --- Interpreting the plots ---
# Static narrative text for the report. Nothing below is computed from the
# SHAP values; it is printed verbatim, one entry per output line.
_ecommerce_interpretation_lines = (
    "\n--- Interpretation of E-commerce SHAP Plots ---",
    "The SHAP Summary Plot shows the distribution of SHAP values for each feature.",
    "- Each dot is an instance from the dataset.",
    "- The position on the x-axis indicates the SHAP value (impact on model output, pushing towards fraud or non-fraud).",
    "- Color indicates the feature value (red=high, blue=low).",
    "- Features are ordered by their importance (mean absolute SHAP value).",
    "\nKey insights from E-commerce SHAP Summary Plot:",
    "- **'purchase_value'**: High purchase values tend to have positive SHAP values, indicating higher purchase values contribute to a higher fraud probability. This suggests fraudulent transactions are often for larger amounts.",
    "- **'time_since_signup_hours'**: Lower values (shorter time since signup, represented by blue dots on the right) have positive SHAP values. This implies that accounts making purchases very quickly after signup are more likely to be fraudulent.",
    "- **'ip_transaction_count'**: High counts of transactions from a single IP address (red dots on the right) contribute positively to the fraud score, suggesting a single source might be conducting multiple fraudulent activities.",
    "- **'user_unique_devices' / 'device_unique_users'**: Anomalous numbers (either very high or very low, specific to how fraudsters operate) in these features can be strong indicators.",
    "- **'country_XYZ' (one-hot encoded country features)**: Specific countries identified by one-hot encoding (e.g., countries with a high fraud rate) will show positive SHAP values for their respective feature, indicating transactions from these locations increase fraud probability.",
    "The Bar Plot reinforces the overall average impact of each feature.",
)
for _text in _ecommerce_interpretation_lines:
    print(_text)


# --- SHAP Force Plot (Local Feature Importance for individual predictions) ---
print("\nGenerating SHAP Force Plots for E-commerce (example cases)...")

# Load the SHAP JavaScript bundle once; both interactive plots below use it.
shap.initjs()

# The explainer's expected_value is a scalar for single-output explainers and
# a per-class sequence otherwise. Flattening and taking the last element
# gives the positive-class (fraud) base value in both cases, matching the
# class selected for the SHAP values.
base_value_ecommerce = float(np.ravel(explainer_ecommerce.expected_value)[-1])

# Work with row POSITIONS rather than index labels: the SHAP value matrix is
# positional, and Index.get_loc() returns a slice/mask instead of an int when
# the index has duplicate labels.
# Assumes y_test_eco rows are in the same order as X_test_eco rows.
labels_eco = np.asarray(y_test_eco).ravel()
fraud_positions_eco = np.flatnonzero(labels_eco == 1)
legit_positions_eco = np.flatnonzero(labels_eco == 0)

# Index labels of actual fraudulent and legitimate transactions in the test set
actual_fraud_indices_eco = X_test_eco.index[fraud_positions_eco]
actual_legit_indices_eco = X_test_eco.index[legit_positions_eco]

# NOTE: the explanatory sentences printed after each plot are report
# templates. The bracketed items are placeholders to be filled in manually
# from the plot, and the wording is not checked against the model's actual
# prediction for the sampled row.

# Example 1: A randomly selected actual fraudulent transaction
if fraud_positions_eco.size:
    row_idx_in_test_set = np.random.choice(fraud_positions_eco)
    sample_idx = X_test_eco.index[row_idx_in_test_set]

    print(f"\nForce Plot for a randomly selected actual fraudulent E-commerce transaction (original index: {sample_idx}):")
    display(shap.force_plot(base_value_ecommerce,
                            shap_values_ecommerce_class1[row_idx_in_test_set],
                            X_test_eco.iloc[row_idx_in_test_set]))
    # For reporting, manually interpret this plot:
    print(f"This specific fraudulent transaction (index {sample_idx}) was flagged because features like [High Purchase Value], [Short Time Since Signup], and [Specific Country] collectively pushed its prediction significantly towards fraud.")
else:
    print("No actual fraudulent transactions found in the E-commerce test set to plot.")


# Example 2: A randomly selected actual legitimate transaction
if legit_positions_eco.size:
    row_idx_in_test_set = np.random.choice(legit_positions_eco)
    sample_idx = X_test_eco.index[row_idx_in_test_set]

    print(f"\nForce Plot for a randomly selected actual legitimate E-commerce transaction (original index: {sample_idx}):")
    display(shap.force_plot(base_value_ecommerce,
                            shap_values_ecommerce_class1[row_idx_in_test_set],
                            X_test_eco.iloc[row_idx_in_test_set]))
    # For reporting, manually interpret this plot:
    print(f"This specific legitimate transaction (index {sample_idx}) was correctly identified as non-fraudulent because features like [Normal Purchase Value], [Longer Time Since Signup], and [Low Risk Country] pushed its prediction significantly away from fraud.")
else:
    print("No actual legitimate transactions found in the E-commerce test set to plot.")


# =============================================================================
# 3.2 SHAP Interpretation - Bank Transaction Model
# =============================================================================
print("\n--- SHAP Interpretation for Bank Transaction Fraud Model ---")

# Build a tree-specific SHAP explainer for the loaded bank model
# (a LightGBM model according to the original notebook notes).
explainer_bank = shap.TreeExplainer(best_model_bank)

# Calculate SHAP values for the Bank test set
print("Calculating SHAP values for Bank test set (this may take a moment)...")
try:
    shap_values_bank = explainer_bank.shap_values(X_test_bank)
except Exception as e:
    # Repeating the identical call cannot succeed, so report and stop here
    # instead of paying for the computation twice and failing anyway.
    print(f"Error calculating SHAP values for Bank: {e}")
    raise

# Reduce the explainer output to a 2-D (rows, features) matrix for the
# positive class (fraud). Depending on the SHAP version and model type the
# raw output is one of:
#   * a list with one array per class      -> take element 1
#   * a 3-D array (rows, features, classes) -> take the last-axis slice 1
#   * a 2-D array already for the positive class / single output
if isinstance(shap_values_bank, list) and len(shap_values_bank) == 2:
    shap_values_bank_class1 = shap_values_bank[1]
elif (getattr(shap_values_bank, "ndim", 0) == 3
        and shap_values_bank.shape[-1] == 2):
    shap_values_bank_class1 = shap_values_bank[:, :, 1]
else:
    shap_values_bank_class1 = shap_values_bank

print("SHAP values calculated.")

# --- SHAP Summary Plot (Global Feature Importance) ---
# shap.summary_plot() displays and finishes the figure itself unless
# show=False is passed. Without it, the plt.title() call that follows would
# be applied to a brand-new empty figure and the SHAP plot would stay
# untitled. With show=False the title is added to the SHAP figure and
# plt.show() renders it once.
print("\nGenerating SHAP Summary Plot (dot plot) for Bank Transactions...")
shap.summary_plot(shap_values_bank_class1, X_test_bank, plot_type="dot", show=False)
plt.title("SHAP Summary Plot for Bank Transaction Fraud Detection")
plt.show()

# SHAP Bar Plot (Mean Absolute SHAP Value)
print("\nGenerating SHAP Bar Plot (Mean Absolute SHAP Value) for Bank Transactions...")
shap.summary_plot(shap_values_bank_class1, X_test_bank, plot_type="bar", show=False)
plt.title("SHAP Feature Importance (Mean |SHAP Value|) for Bank Transaction Fraud Detection")
plt.show()

# --- Interpreting the plots ---
# Static narrative text for the report. Nothing below is computed from the
# SHAP values; it is printed verbatim, one entry per output line.
_bank_interpretation_lines = (
    "\n--- Interpretation of Bank Transaction SHAP Plots ---",
    "The SHAP Summary Plot for bank transactions reveals the most impactful anonymized features (V-features) and 'Amount' and 'Time'.",
    "\nKey insights from Bank Transaction SHAP Summary Plot:",
    "- **'Amount'**: As expected, high transaction amounts (red dots to the right) are a very strong indicator of fraud, consistently pushing the prediction towards fraudulent. This aligns with the understanding that fraudsters often aim for larger sums.",
    "- **'V17', 'V14', 'V12', 'V10' etc.**: These anonymized V-features are highly impactful. While their direct meaning is not known, the plot shows that certain values (e.g., extreme positive or negative values) within these features strongly contribute to fraud predictions. This means the PCA transformation successfully captured underlying patterns in the raw data that are indicative of fraud.",
    "- **'Time'**: The 'Time' feature's impact shows specific periods or relative timings (e.g., start or end of the transaction window, or unusual spikes) where fraud is more prevalent. It suggests that the timing of a transaction plays a role.",
    "The Bar Plot provides a clear ranking of these features by their overall contribution to the model's output, with 'Amount' and some V-features typically dominating.",
)
for _text in _bank_interpretation_lines:
    print(_text)

# --- SHAP Force Plot (Local Feature Importance for individual predictions) ---
print("\nGenerating SHAP Force Plots for Bank Transactions (example cases)...")

# Load the SHAP JavaScript bundle once; both interactive plots below use it.
shap.initjs()

# The explainer's expected_value is a scalar for single-output explainers and
# a per-class sequence otherwise. Flattening and taking the last element
# gives the positive-class (fraud) base value in both cases, matching the
# class selected for the SHAP values.
base_value_bank = float(np.ravel(explainer_bank.expected_value)[-1])

# Work with row POSITIONS rather than index labels: the SHAP value matrix is
# positional, and Index.get_loc() returns a slice/mask instead of an int when
# the index has duplicate labels.
# Assumes y_test_bank rows are in the same order as X_test_bank rows.
labels_bank = np.asarray(y_test_bank).ravel()
fraud_positions_bank = np.flatnonzero(labels_bank == 1)
legit_positions_bank = np.flatnonzero(labels_bank == 0)

# Index labels of actual fraudulent and legitimate transactions in the test set
actual_fraud_indices_bank = X_test_bank.index[fraud_positions_bank]
actual_legit_indices_bank = X_test_bank.index[legit_positions_bank]

# NOTE: the explanatory sentences printed after each plot are report
# templates. The bracketed items are placeholders to be filled in manually
# from the plot, and the wording is not checked against the model's actual
# prediction for the sampled row.

# Example 1: A randomly selected actual fraudulent transaction
if fraud_positions_bank.size:
    row_idx_in_test_set = np.random.choice(fraud_positions_bank)
    sample_idx = X_test_bank.index[row_idx_in_test_set]

    print(f"\nForce Plot for a randomly selected actual fraudulent Bank transaction (original index: {sample_idx}):")
    display(shap.force_plot(base_value_bank,
                            shap_values_bank_class1[row_idx_in_test_set],
                            X_test_bank.iloc[row_idx_in_test_set]))
    print(f"This specific fraudulent transaction (index {sample_idx}) was flagged primarily due to its [High Amount] and specific anomalous values in anonymized features like [V14] and [V17], which are characteristic of known fraud patterns.")
else:
    print("No actual fraudulent transactions found in the Bank test set to plot.")

# Example 2: A randomly selected actual legitimate transaction
if legit_positions_bank.size:
    row_idx_in_test_set = np.random.choice(legit_positions_bank)
    sample_idx = X_test_bank.index[row_idx_in_test_set]

    print(f"\nForce Plot for a randomly selected actual legitimate Bank transaction (original index: {sample_idx}):")
    display(shap.force_plot(base_value_bank,
                            shap_values_bank_class1[row_idx_in_test_set],
                            X_test_bank.iloc[row_idx_in_test_set]))
    print(f"This specific legitimate transaction (index {sample_idx}) was correctly identified as non-fraudulent because its [Low Amount] and common values in features like [V3] and [V5] pushed the prediction towards legitimacy.")
else:
    print("No actual legitimate transactions found in the Bank test set to plot.")

print("\nModel Explainability with SHAP Complete!")