# Churn Prediction Model with SHAP Explainability

This notebook builds machine learning models to predict customer churn, evaluates model performance, and uses SHAP (SHapley Additive exPlanations) to explain predictions and understand feature importance.

## 1. Import Required Libraries

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import bz2
import warnings

from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, precision_score, recall_score, 
                            f1_score, roc_auc_score, roc_curve, confusion_matrix,
                            classification_report)
import shap

# Notebook-wide display settings: hide every warning category, switch seaborn
# to its white-grid theme, and make 12x6 inches the default figure size.
warnings.simplefilter('ignore')
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

print("Libraries imported successfully!")

## 2. Load and Prepare Data

In [None]:
# Load the raw Telco customer-churn dataset from the project's data folder.
df = pd.read_csv('../data/WA_Fn-UseC_-Telco-Customer-Churn.csv')

# Data preparation
# Parse TotalCharges as numeric; errors='coerce' turns any value that cannot be
# parsed as a number into NaN instead of raising.
df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce')

# Replace those NaNs with the row's MonthlyCharges value.
# The result is assigned back to the column rather than calling
# `df['TotalCharges'].fillna(..., inplace=True)`: that chained in-place form
# operates on a temporary object under pandas Copy-on-Write and leaves `df`
# unchanged, and the accompanying warning would be hidden by the notebook's
# global warnings filter.
df['TotalCharges'] = df['TotalCharges'].fillna(df['MonthlyCharges'])

# Guard against NaNs silently reaching the models.
assert df['TotalCharges'].notna().all(), "TotalCharges still contains missing values"

# Recode the 0/1 SeniorCitizen flag as 'No'/'Yes' strings.
df['SeniorCitizen'] = df['SeniorCitizen'].map({0: 'No', 1: 'Yes'})

print(f"Dataset Shape: {df.shape}")
print(f"Churn Distribution:\n{df['Churn'].value_counts()}\n")
df.head()

## 3. Feature Engineering and Encoding

In [None]:
# Remove the customer identifier column before modelling.
df_model = df.drop(columns=['customerID'])

# Turn the target into integers: 'Yes' -> 1, 'No' -> 0.
churn_codes = {'Yes': 1, 'No': 0}
df_model['Churn'] = df_model['Churn'].map(churn_codes)

# Column groups: every remaining object-dtype column is treated as categorical;
# the numerical columns are listed explicitly.
categorical_cols = list(df_model.select_dtypes(include=['object']).columns)
numerical_cols = ['tenure', 'MonthlyCharges', 'TotalCharges']

print(f"Categorical features: {len(categorical_cols)}")
print(f"Numerical features: {len(numerical_cols)}")

# Integer-encode each categorical column on a copy of the frame, keeping the
# fitted encoder per column so the same mapping can be reused later.
df_encoded = df_model.copy()
label_encoders = {}

for col in categorical_cols:
    encoder = LabelEncoder()
    encoded_values = encoder.fit_transform(df_encoded[col])
    df_encoded[col] = encoded_values
    label_encoders[col] = encoder

print(f"\n✓ Encoded {len(categorical_cols)} categorical features")
print(f"✓ Final dataset shape: {df_encoded.shape}")

## 4. Train-Test Split and Scaling

In [None]:
# Separate the predictors from the target column.
y = df_encoded['Churn']
X = df_encoded.drop(columns='Churn')

# Hold out 20% of the rows for testing; stratify on the target so both splits
# keep the same churn ratio, and fix the seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)

# Standardise the numerical columns. The scaler is fitted on the training
# split only and then applied to both splits.
scaler = StandardScaler()
scaler.fit(X_train[numerical_cols])
for split in (X_train, X_test):
    split[numerical_cols] = scaler.transform(split[numerical_cols])

# Summary of the resulting split.
print("Dataset Split:")
print("=" * 80)
print(f"Training set: {X_train.shape[0]:,} samples ({X_train.shape[0]/len(X)*100:.1f}%)")
print(f"Testing set:  {X_test.shape[0]:,} samples ({X_test.shape[0]/len(X)*100:.1f}%)")
print(f"Features: {X_train.shape[1]}")
print(f"\nChurn distribution in training set:")
print(y_train.value_counts(normalize=True) * 100)
print(f"\nChurn distribution in testing set:")
print(y_test.value_counts(normalize=True) * 100)

## 5. Build and Train Models

In [None]:
# Candidate classifiers, keyed by the display name used in tables and plots.
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1),
    'Gradient Boosting': GradientBoostingClassifier(n_estimators=100, random_state=42)
}

# Fit each model on the training split and score it on the test split.
# `results` keeps the fitted model, its metrics, and its raw test predictions.
results = {}

for name, model in models.items():
    print(f"\nTraining {name}...")

    model.fit(X_train, y_train)

    # Hard labels plus the predicted probability of the positive class.
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]

    results[name] = dict(
        model=model,
        accuracy=accuracy_score(y_test, y_pred),
        precision=precision_score(y_test, y_pred),
        recall=recall_score(y_test, y_pred),
        f1=f1_score(y_test, y_pred),
        roc_auc=roc_auc_score(y_test, y_pred_proba),
        predictions=y_pred,
        probabilities=y_pred_proba,
    )

    print(f"✓ {name} trained successfully")

print("\n" + "=" * 80)
print("MODEL PERFORMANCE COMPARISON")
print("=" * 80)

# One row per model, one column per metric, rounded to four decimals.
metric_columns = {
    'Accuracy': 'accuracy',
    'Precision': 'precision',
    'Recall': 'recall',
    'F1-Score': 'f1',
    'ROC-AUC': 'roc_auc',
}
performance_rows = [
    {'Model': model_name,
     **{column: scores[key] for column, key in metric_columns.items()}}
    for model_name, scores in results.items()
]
performance_df = pd.DataFrame(performance_rows).round(4)
print(performance_df.to_string(index=False))

## 6. Model Evaluation Visualizations

In [None]:
# Two-panel evaluation figure: ROC curves on the left, metric bars on the right.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
roc_ax, bar_ax = axes[0], axes[1]

# Left panel: one ROC curve per model, labelled with its AUC.
for name, scores in results.items():
    fpr, tpr, _ = roc_curve(y_test, scores['probabilities'])
    roc_ax.plot(fpr, tpr, label=f'{name} (AUC = {scores["roc_auc"]:.3f})', linewidth=2.5)

# Diagonal reference line for a classifier that guesses at random.
roc_ax.plot([0, 1], [0, 1], 'k--', label='Random Classifier', linewidth=1.5)
roc_ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
roc_ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
roc_ax.set_title('ROC Curves Comparison', fontsize=14, fontweight='bold')
roc_ax.legend(fontsize=10)
roc_ax.grid(True, alpha=0.3)

# Right panel: grouped bars, one group per metric and one bar per model.
metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC']
metric_keys = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
x = np.arange(len(metrics))
width = 0.25

for idx, (name, scores) in enumerate(results.items()):
    values = [scores[key] for key in metric_keys]
    # Shift each model's bars by one bar width so they sit side by side.
    bar_ax.bar(x + idx*width, values, width, label=name, alpha=0.8)

bar_ax.set_xlabel('Metrics', fontsize=12, fontweight='bold')
bar_ax.set_ylabel('Score', fontsize=12, fontweight='bold')
bar_ax.set_title('Model Performance Metrics', fontsize=14, fontweight='bold')
# Tick positions are offset by one bar width from each group's first bar.
bar_ax.set_xticks(x + width)
bar_ax.set_xticklabels(metrics)
bar_ax.legend(fontsize=10)
bar_ax.grid(axis='y', alpha=0.3)
bar_ax.set_ylim([0, 1.0])

plt.tight_layout()
plt.savefig('../static/images/model_evaluation.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# One confusion-matrix heatmap per trained model, laid out in a single row.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
class_labels = ['No Churn', 'Churn']

for idx, name in enumerate(results):
    ax = axes[idx]
    cm = confusion_matrix(y_test, results[name]['predictions'])

    # Annotate each cell with its integer count.
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        ax=ax,
        cbar=True,
        square=True,
        linewidths=1,
        linecolor='black',
    )
    ax.set_xlabel('Predicted', fontsize=12, fontweight='bold')
    ax.set_ylabel('Actual', fontsize=12, fontweight='bold')
    ax.set_title(f'{name}\nConfusion Matrix', fontsize=13, fontweight='bold')
    ax.set_xticklabels(class_labels)
    ax.set_yticklabels(class_labels)

plt.tight_layout()
plt.savefig('../static/images/confusion_matrices.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Hyperparameter Tuning - Random Forest

In [None]:
# Exhaustive grid search over Random Forest hyperparameters, scored by ROC-AUC
# with 5-fold cross-validation on the training split.
print("Performing Grid Search for Random Forest...")

param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [10, 20, None],
    'min_samples_split': [2, 5],
    'min_samples_leaf': [1, 2]
}

rf = RandomForestClassifier(random_state=42, n_jobs=-1)
grid_search = GridSearchCV(
    estimator=rf,
    param_grid=param_grid,
    cv=5,
    scoring='roc_auc',
    verbose=1,
    n_jobs=-1,
)
grid_search.fit(X_train, y_train)

print("\n" + "=" * 80)
print("BEST HYPERPARAMETERS")
print("=" * 80)
print(grid_search.best_params_)

# Take the estimator exposed by the search as the final model and score it on
# the held-out test split (no further fitting happens here).
best_rf = grid_search.best_estimator_
y_pred_best = best_rf.predict(X_test)
y_pred_proba_best = best_rf.predict_proba(X_test)[:, 1]

print("\n" + "=" * 80)
print("BEST RANDOM FOREST PERFORMANCE")
print("=" * 80)
print(f"Accuracy:  {accuracy_score(y_test, y_pred_best):.4f}")
print(f"Precision: {precision_score(y_test, y_pred_best):.4f}")
print(f"Recall:    {recall_score(y_test, y_pred_best):.4f}")
print(f"F1-Score:  {f1_score(y_test, y_pred_best):.4f}")
print(f"ROC-AUC:   {roc_auc_score(y_test, y_pred_proba_best):.4f}")

print("\n" + "=" * 80)
print("CLASSIFICATION REPORT")
print("=" * 80)
report_text = classification_report(
    y_test, y_pred_best, target_names=['No Churn', 'Churn']
)
print(report_text)

## 8. Feature Importance Analysis

In [None]:
# Pair every feature name with the tuned Random Forest's importance score and
# order the table from most to least important.
feature_importance = (
    pd.DataFrame({
        'Feature': X.columns,
        'Importance': best_rf.feature_importances_
    })
    .sort_values('Importance', ascending=False)
)

# Horizontal bar chart of the 15 highest-ranked features.
top_features = feature_importance.head(15)
bar_positions = range(len(top_features))

plt.figure(figsize=(12, 8))
plt.barh(bar_positions, top_features['Importance'], color='steelblue', alpha=0.8)
plt.yticks(bar_positions, top_features['Feature'])
plt.xlabel('Importance Score', fontsize=12, fontweight='bold')
plt.ylabel('Features', fontsize=12, fontweight='bold')
plt.title('Top 15 Feature Importances (Random Forest)', fontsize=14, fontweight='bold')
# Flip the y-axis so the most important feature is drawn at the top.
plt.gca().invert_yaxis()
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig('../static/images/feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

# Same ranking as plain text.
print("\nTop 15 Most Important Features:")
print("=" * 80)
for feature, importance in zip(top_features['Feature'], top_features['Importance']):
    print(f"{feature:30s}: {importance:.4f}")

## 9. SHAP Analysis for Model Explainability

SHAP (SHapley Additive exPlanations) values provide a unified measure of feature importance and show how each feature contributes to individual predictions.

In [None]:
# Build a tree-based SHAP explainer for the tuned Random Forest and compute
# SHAP values for every row of the test split.
print("Creating SHAP explainer (this may take a few minutes)...")
explainer = shap.TreeExplainer(best_rf)
shap_values = explainer.shap_values(X_test)

# Extract the SHAP values for class 1 (churn) as a 2-D (samples, features)
# array. The return type of `shap_values()` for a classifier depends on the
# installed SHAP version:
#   * a list with one 2-D array per class  -> take element 1
#   * one 3-D array (samples, features, classes) -> slice the last axis at 1
#   * a single 2-D array                   -> use as-is
# Without the 3-D case the full per-class array would be handed to the plotting
# functions below, which expect one value per sample and feature.
if isinstance(shap_values, list):
    shap_values_churn = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_values_churn = shap_values[:, :, 1]
else:
    shap_values_churn = shap_values

# Fail early, with a clear message, if the layout is still not what the
# summary and dependence plots need.
assert np.shape(shap_values_churn) == X_test.shape, (
    f"Unexpected SHAP values shape {np.shape(shap_values_churn)}; "
    f"expected {X_test.shape}"
)

print("✓ SHAP explainer created successfully")

In [None]:
def _save_shap_summary(title, xlabel, out_path, **summary_kwargs):
    """Draw one SHAP summary plot for the churn class, title it, and save it.

    `summary_kwargs` are forwarded to `shap.summary_plot` (for example
    `plot_type="bar"`). The figure is written to `out_path` and then shown.
    """
    plt.figure(figsize=(12, 8))
    shap.summary_plot(shap_values_churn, X_test, show=False, **summary_kwargs)
    plt.title(title, fontsize=14, fontweight='bold', pad=20)
    plt.xlabel(xlabel, fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()


# SHAP Summary Plot - Feature Importance (bar chart of mean |SHAP| per feature)
_save_shap_summary(
    'SHAP Feature Importance (Mean Absolute SHAP Values)',
    'Mean |SHAP Value|',
    '../static/images/shap_importance.png',
    plot_type="bar",
)

# SHAP Summary Plot - Feature Impact (default summary plot type)
_save_shap_summary(
    'SHAP Feature Impact on Churn Prediction',
    'SHAP Value (Impact on Model Output)',
    '../static/images/shap_summary.png',
)

print("✓ SHAP visualizations created")

In [None]:
# SHAP dependence plots for the four features ranked highest by the Random
# Forest importance table, drawn on a 2x2 grid.
top_4_features = feature_importance['Feature'].head(4).tolist()

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()  # 2x2 grid -> flat sequence of four axes

for panel, feature in zip(axes, top_4_features):
    shap.dependence_plot(
        feature,
        shap_values_churn,
        X_test,
        ax=panel,
        show=False,
    )
    panel.set_title(f'SHAP Dependence: {feature}', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('../static/images/shap_dependence.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ SHAP dependence plots created for top {len(top_4_features)} features")

## 10. Save Models and Artifacts for Deployment

In [None]:
def _dump_pickle(obj, path, opener=open, mode='wb'):
    """Pickle `obj` into the file obtained from `opener(path, mode)`."""
    with opener(path, mode) as handle:
        pickle.dump(obj, handle)


# Tuned Random Forest model
_dump_pickle(best_rf, '../model.pkl')
print("✓ Best Random Forest model saved to '../model.pkl'")

# SHAP explainer, written through a bz2-compressed file object
_dump_pickle(explainer, '../explainer.bz2', opener=bz2.BZ2File, mode='w')
print("✓ SHAP explainer saved to '../explainer.bz2'")

# Preprocessing state: the fitted scaler, the per-column label encoders, and
# the column lists used during training
artifacts = dict(
    scaler=scaler,
    label_encoders=label_encoders,
    feature_names=X.columns.tolist(),
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
)
_dump_pickle(artifacts, '../preprocessing_artifacts.pkl')
print("✓ Preprocessing artifacts saved to '../preprocessing_artifacts.pkl'")

# Test-set metrics of the tuned model, its hyperparameters, and the
# feature-importance table as a list of records
performance_summary = dict(
    test_accuracy=accuracy_score(y_test, y_pred_best),
    test_precision=precision_score(y_test, y_pred_best),
    test_recall=recall_score(y_test, y_pred_best),
    test_f1=f1_score(y_test, y_pred_best),
    test_roc_auc=roc_auc_score(y_test, y_pred_proba_best),
    best_params=grid_search.best_params_,
    feature_importance=feature_importance.to_dict('records'),
)
_dump_pickle(performance_summary, '../model_performance.pkl')
print("✓ Model performance summary saved to '../model_performance.pkl'")

# Closing summary of everything written above
print("\n" + "=" * 80)
print("ALL MODELS AND ARTIFACTS SAVED SUCCESSFULLY")
print("=" * 80)
print("\nFiles created:")
for created_line in (
    "  - model.pkl (Random Forest classifier)",
    "  - explainer.bz2 (SHAP explainer)",
    "  - preprocessing_artifacts.pkl (scalers and encoders)",
    "  - model_performance.pkl (metrics and parameters)",
):
    print(created_line)
print("\nThe Flask app can now use these trained models for predictions!")