# Customer Churn Prediction with Machine Learning

**Course**: CS 5393 - Introduction to Machine Learning  
**Student**: Yaw Boateng  
**Date**: October 11th, 2025

---

## Project Overview

This project develops an end-to-end machine learning pipeline for predicting customer churn in the banking sector. Using a dataset of 10,000 bank customers, we will:

- build and compare multiple classification models to predict churn;
- extract probability scores for risk ranking and prioritization;
- evaluate the models with both classification metrics and regression-style metrics on the predicted probabilities;
- deploy the best-performing model through a web application for real-time predictions; and
- integrate Large Language Models (LLMs) to explain predictions and generate personalized retention emails for customers.


## Dataset Information

### Source and Description
- **Dataset**: Bank Customer Churn Dataset
- **Source**: Kaggle
- **URL**: https://www.kaggle.com/datasets/mathchi/churn-for-bank-customers
- **Size**: 10,000 rows (excluding header)
- **Features**: 14 columns total

### Dataset Structure
The dataset contains the following features:

**Independent Variables (13 columns, 3 of which are identifiers rather than predictors)**:
- `RowNumber`: Sequential row identifier
- `CustomerId`: Unique customer identifier
- `Surname`: Customer's last name
- `CreditScore`: Customer's credit score (300-850)
- `Geography`: Country of residence (France, Germany, Spain)
- `Gender`: Customer gender (Male, Female)
- `Age`: Customer age
- `Tenure`: Number of years as bank customer
- `Balance`: Account balance
- `NumOfProducts`: Number of bank products used
- `HasCrCard`: Credit card ownership (0/1)
- `IsActiveMember`: Active membership status (0/1)
- `EstimatedSalary`: Estimated annual salary

**Target Variable (1 column)**:
- `Exited`: Churn indicator (0 = stayed, 1 = churned)

### Data Types and Characteristics
- **Numerical Features**: CreditScore, Age, Tenure, Balance, NumOfProducts, EstimatedSalary
- **Categorical Features**: Geography, Gender
- **Binary Features (already encoded as 0/1)**: HasCrCard, IsActiveMember, Exited (target)


## Machine Learning Task Definition

### Dual Approach: Classification with Probability-Based Analysis

This project employs a **hybrid approach** that combines binary classification with a regression-style evaluation of the predicted probabilities, so that models are judged on both the quality of their decisions and the usefulness of their risk scores.

### Primary Task: Binary Classification

**Objective**: Predict customer churn as a binary classification problem (churned vs. stayed).

**Problem Formulation**:
- **Input**: Customer features (CreditScore, Geography, Gender, Age, Tenure, Balance, etc.)
- **Output**: Binary churn prediction (0 = stayed, 1 = churned)
- **Goal**: Maximize recall for churners (class 1) to catch as many at-risk customers as possible, while tracking precision so that the model does not simply flag every customer

**Justification for Classification Approach**:
1. **Direct Business Decision**: Classification provides clear, actionable predictions (churn or no churn) that directly inform retention strategies
2. **Class Imbalance Handling**: The dataset is imbalanced (churners are the minority class), which motivates techniques such as SMOTE to improve model performance on the minority class
3. **Business Priority**: For banking applications, recall (catching churners) is more critical than precision, as missing a churner has significant financial impact
4. **Model Interpretability**: Several of the models indicate which features drive predictions (e.g., Logistic Regression coefficients, tree-based feature importances), and the models can output probability scores for risk stratification
5. **Standard Practice**: Binary classification is the standard approach for churn prediction in industry, making our solution directly applicable

### Secondary Analysis: Probability-Based Evaluation

**Objective**: Extract probability scores from classification models and evaluate them using regression-style metrics.

**Approach**:
- Use `.predict_proba()` to extract churn probability scores (0.0 to 1.0) from classification models
- Calculate regression metrics (RMSE, MAE, R²) between the predicted probabilities and the actual 0/1 outcomes
- Enable risk ranking and prioritization of customers by churn probability
- Compare binary predictions vs probability-based risk assessment
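
As a minimal sketch of this secondary analysis, the snippet below fits a logistic regression to synthetic data and scores its churn probabilities against the 0/1 labels. The data comes from `make_classification` as a stand-in for the bank dataset, and the variable names and roughly 20% positive rate are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the churn data (assumed ~20% positives); not the real dataset.
X, y = make_classification(n_samples=1000, n_features=8, weights=[0.8, 0.2], random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Column 1 of predict_proba holds P(class 1), i.e. the churn probability.
churn_proba = model.predict_proba(X_test)[:, 1]

# Regression-style metrics comparing probabilities with the 0/1 outcomes.
mse = mean_squared_error(y_test, churn_proba)  # equals the Brier score for binary labels
rmse = float(np.sqrt(mse))
mae = mean_absolute_error(y_test, churn_proba)
r2 = r2_score(y_test, churn_proba)

# Indices of test customers ordered from highest to lowest predicted risk.
ranking = np.argsort(-churn_proba)
```

Because the targets are 0/1, these numbers describe how far the probabilities sit from the observed outcomes rather than from any "true" probability, so they are best read alongside a calibration check.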

**Why Probability-Based Analysis?**:
1. **Risk Stratification**: Probability scores enable banks to rank customers by churn risk and allocate resources efficiently
2. **Threshold Optimization**: Banks can set different intervention thresholds (e.g., >0.7 = high priority) based on business needs
3. **Resource Allocation**: Probability-based ranking helps prioritize retention efforts on highest-risk customers
4. **Comprehensive Evaluation**: Regression metrics on probabilities provide additional insights into model calibration and probability estimation quality
5. **Business Flexibility**: Allows banks to adjust retention strategies based on risk tiers (high/medium/low probability)
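
The tiering idea can be sketched in a few lines. The 0.7 cut-off echoes the example threshold in the list above; the 0.4 cut-off and the function name `assign_risk_tier` are hypothetical and would need to be chosen with the business.

```python
import numpy as np

# Hypothetical cut-offs: 0.7 follows the example above, 0.4 is an assumption.
HIGH_THRESHOLD = 0.7
MEDIUM_THRESHOLD = 0.4


def assign_risk_tier(probability: float) -> str:
    """Map a churn probability in [0, 1] to a high/medium/low risk tier."""
    if not 0.0 <= probability <= 1.0:
        raise ValueError(f"probability must be in [0, 1], got {probability}")
    if probability > HIGH_THRESHOLD:
        return "high"
    if probability >= MEDIUM_THRESHOLD:
        return "medium"
    return "low"


# Made-up probabilities for five customers.
example_probabilities = np.array([0.05, 0.42, 0.71, 0.93, 0.38])
tiers = [assign_risk_tier(p) for p in example_probabilities]

# Customer indices sorted from highest to lowest churn probability.
priority_order = np.argsort(-example_probabilities)
```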

### Success Metrics

**Classification Metrics (Primary)**:
- **Recall (Class 1)**: Primary metric - ability to catch customers who will actually churn
- **Precision**: Of predicted churners, how many actually churned
- **F1-Score**: Harmonic mean of precision and recall (balanced metric)
- **Accuracy**: Overall correctness of predictions
- **Confusion Matrix**: Detailed breakdown of predictions vs. actuals
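
To make the definitions concrete, the following sketch derives each metric from the confusion matrix on ten made-up labels and cross-checks the results against scikit-learn's metric functions. The labels are illustrative only.

```python
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)

# Made-up labels: 1 = churned, 0 = stayed.
y_true = [1, 0, 0, 1, 1, 0, 0, 0, 1, 0]
y_pred = [1, 0, 1, 0, 1, 0, 0, 0, 1, 0]

# For binary labels, ravel() yields the counts in the order tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

recall = tp / (tp + fn)        # share of actual churners that were caught
precision = tp / (tp + fp)     # share of predicted churners that really churned
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + tn + fp + fn)

# Cross-check the hand-computed values against scikit-learn.
assert abs(recall - recall_score(y_true, y_pred)) < 1e-12
assert abs(precision - precision_score(y_true, y_pred)) < 1e-12
assert abs(f1 - f1_score(y_true, y_pred)) < 1e-12
assert abs(accuracy - accuracy_score(y_true, y_pred)) < 1e-12
```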

**Probability-Based Metrics (Secondary)**:
- **RMSE (Root Mean Squared Error)**: Measures how closely probability predictions match the actual 0/1 outcomes (its square is the Brier score)
- **MAE (Mean Absolute Error)**: Average absolute difference between predicted probabilities and actual 0/1 outcomes
- **R² Score**: Proportion of variance in the churn outcome explained by the probability predictions
- **Probability Calibration**: How well-calibrated the probability estimates are

**Business Metrics**:
- Maximize identification of at-risk customers for proactive retention efforts
- Enable risk-based customer ranking and prioritization
- Support tiered intervention strategies based on churn probability


## Project Motivation

### Learning Objectives
This project serves multiple educational and practical purposes:

1. **Practical ML Application**: Apply machine learning techniques to a real-world business problem that affects millions of customers globally
2. **End-to-End Pipeline Development**: Learn to build complete ML workflows from data preprocessing to model deployment
3. **Algorithm Comparison**: Gain hands-on experience comparing different machine learning algorithms and understanding their strengths/weaknesses
4. **Business Analytics Integration**: Understand how ML models can drive business decisions and create value

### Business Relevance
Customer churn prediction is critical for the banking industry:

- **Financial Impact**: Acquiring a new customer is commonly cited as costing 5-25x more than retaining an existing one
- **Revenue Protection**: Preventing churn directly protects revenue streams
- **Competitive Advantage**: Proactive retention strategies improve customer satisfaction and loyalty
- **Resource Optimization**: Targeted interventions are more cost-effective than blanket retention programs

### Technical Interest
This project offers rich learning opportunities:

- **Data Preprocessing**: Handle mixed data types (numerical, categorical, binary)
- **Feature Engineering**: Create meaningful features from customer behavior data
- **Model Selection**: Compare linear, tree-based, and ensemble methods
- **Hyperparameter Optimization**: Learn systematic approaches to model tuning
- **Deployment**: Build a web application for real-time predictions

### Real-World Impact
The project addresses a genuine business need:

- **Banking Sector**: Help banks identify at-risk customers before they churn
- **Customer Experience**: Enable personalized retention strategies
- **Operational Efficiency**: Optimize customer service resource allocation
- **Strategic Planning**: Provide data-driven insights for customer lifecycle management


## Proposed Approach

### 1. Data Preprocessing

**Data Cleaning**:
- Handle missing values using appropriate imputation strategies
- Detect and treat outliers using IQR method or domain knowledge
- Remove irrelevant features (RowNumber, CustomerId, Surname)

**Feature Engineering**:
- One-hot encode categorical variables (Geography, Gender)
- Create interaction features (Age × Balance, Tenure × NumOfProducts)
- Standardize numerical features using StandardScaler (fit on the training split only, to avoid data leakage)
- Handle class imbalance using SMOTE (Synthetic Minority Oversampling Technique), applied to the training data only
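
One possible shape for these steps is sketched below on a tiny hand-made frame that reuses the dataset's column names. The values are invented, and the interaction column names (`Age_x_Balance`, `Tenure_x_NumOfProducts`) are assumptions. In the actual pipeline the transformer would be fitted on the training split only.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Tiny illustrative sample using the dataset's column names; values are made up.
df = pd.DataFrame({
    "RowNumber": [1, 2, 3, 4],
    "CustomerId": [101, 102, 103, 104],
    "Surname": ["A", "B", "C", "D"],
    "CreditScore": [620, 580, 710, 655],
    "Geography": ["France", "Spain", "France", "Germany"],
    "Gender": ["Female", "Female", "Male", "Male"],
    "Age": [42, 35, 51, 29],
    "Tenure": [2, 1, 8, 4],
    "Balance": [0.0, 80000.0, 150000.0, 20000.0],
    "NumOfProducts": [1, 1, 3, 2],
    "HasCrCard": [1, 0, 1, 0],
    "IsActiveMember": [1, 1, 0, 0],
    "EstimatedSalary": [100000.0, 60000.0, 110000.0, 45000.0],
    "Exited": [1, 0, 1, 0],
})

# Drop identifier columns and separate the target.
X = df.drop(columns=["RowNumber", "CustomerId", "Surname", "Exited"])
y = df["Exited"]

# Interaction features named in the proposal (column names are assumptions).
X["Age_x_Balance"] = X["Age"] * X["Balance"]
X["Tenure_x_NumOfProducts"] = X["Tenure"] * X["NumOfProducts"]

numeric_cols = ["CreditScore", "Age", "Tenure", "Balance", "NumOfProducts",
                "EstimatedSalary", "Age_x_Balance", "Tenure_x_NumOfProducts"]
categorical_cols = ["Geography", "Gender"]
binary_cols = ["HasCrCard", "IsActiveMember"]

preprocessor = ColumnTransformer([
    ("num", StandardScaler(), numeric_cols),                          # standardize
    ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_cols),  # one-hot encode
    ("bin", "passthrough", binary_cols),                              # already 0/1
])

X_processed = preprocessor.fit_transform(X)
```

Wrapping the transformer and a classifier in a single scikit-learn `Pipeline` is one way to guarantee that scaling statistics are learned from training folds only during cross-validation.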

**Data Splitting**:
- Train/Test split: 80%/20%
- Stratified sampling to maintain churn distribution
- Random state (42) for reproducibility
- Cross-validation for robust model evaluation
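
The splitting strategy might look like the following sketch, which uses synthetic data in place of the bank dataset. The classifier and the sample size are placeholders chosen for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split

# Synthetic stand-in with an assumed ~20% positive class.
X, y = make_classification(n_samples=1000, n_features=8, weights=[0.8, 0.2], random_state=42)

# 80/20 split; stratify=y keeps the churn rate similar in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
train_rate = y_train.mean()
test_rate = y_test.mean()

# 5-fold stratified cross-validation on the training portion only,
# scored on recall for the positive (churn) class.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_recall = cross_val_score(
    LogisticRegression(max_iter=1000), X_train, y_train, cv=cv, scoring="recall"
)
```

Keeping the test set out of cross-validation leaves it available for a single final evaluation of the selected model.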

### 2. Model Selection (7 Models)

We will implement and compare the following classification models:

1. **Logistic Regression**
   - Baseline model for comparison
   - Fast training and interpretable coefficients
   - Provides probability scores for churn prediction
   - Good starting point for binary classification problems

2. **XGBoost Classifier**
   - Gradient boosting ensemble method
   - Often achieves state-of-the-art performance
   - Built-in regularization and feature importance
   - Handles non-linear relationships effectively

3. **Decision Tree Classifier**
   - Simple, interpretable tree-based model
   - Can capture non-linear patterns
   - Provides clear decision rules
   - Prone to overfitting without regularization

4. **Random Forest Classifier**
   - Ensemble method combining multiple decision trees
   - Reduces overfitting through averaging
   - Handles non-linear relationships and feature interactions
   - Provides feature importance rankings

5. **Gaussian Naive Bayes**
   - Probabilistic classifier based on Bayes' theorem
   - Fast and works well with small datasets
   - Assumes features are conditionally independent given the class
   - Good baseline for comparison

6. **K-Nearest Neighbors (KNN) Classifier**
   - Instance-based learning algorithm
   - Classifies based on similarity to training examples
   - Non-parametric method
   - Sensitive to feature scaling (which we address with StandardScaler)

7. **Support Vector Machine (SVM) Classifier**
   - Finds optimal decision boundary
   - Effective with high-dimensional data
   - Can use different kernel functions for non-linear patterns
   - Soft margin (controlled by C) provides some tolerance to outliers and noisy points
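
A comparison loop over these models could be sketched as follows. XGBoost is left out so that the example depends only on scikit-learn, the data is synthetic, and all hyperparameter values are untuned placeholders.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the churn data (assumed ~20% positives).
X, y = make_classification(n_samples=600, n_features=8, weights=[0.8, 0.2], random_state=42)

# Scale-sensitive models are wrapped in a pipeline with StandardScaler.
models = {
    "Logistic Regression": make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)),
    "Decision Tree": DecisionTreeClassifier(max_depth=5, random_state=42),
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
    "Gaussian Naive Bayes": GaussianNB(),
    "KNN": make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5)),
    "SVM": make_pipeline(StandardScaler(), SVC()),
}

# Mean 5-fold recall on the positive class for each model.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
recall_by_model = {
    name: cross_val_score(model, X, y, cv=cv, scoring="recall").mean()
    for name, model in models.items()
}

best_model_name = max(recall_by_model, key=recall_by_model.get)
```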

### 3. Hyperparameter Tuning

**Optimization Strategy**:
- Use GridSearchCV for smaller parameter spaces
- Use RandomizedSearchCV for larger parameter spaces
- 5-fold cross-validation for robust evaluation
- Optimize for recall (class 1) as primary metric, given the business priority of catching churners
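
As an illustration of this strategy, the sketch below runs `GridSearchCV` with `scoring="recall"` over a small logistic regression grid on synthetic data. The grid values are assumptions, not tuned recommendations.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the churn data (assumed ~20% positives).
X, y = make_classification(n_samples=500, n_features=8, weights=[0.8, 0.2], random_state=42)

pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=1000)),
])

# Small illustrative grid; the "clf__" prefix targets the classifier step.
param_grid = {
    "clf__C": [0.01, 0.1, 1.0, 10.0],
    "clf__class_weight": [None, "balanced"],
}

# 5-fold cross-validation, selecting the combination with the best class-1 recall.
search = GridSearchCV(pipeline, param_grid, scoring="recall", cv=5)
search.fit(X, y)

best_params = search.best_params_
best_recall = search.best_score_
```

`RandomizedSearchCV` accepts the same pipeline and scoring argument and samples a fixed number of combinations instead of trying them all, which suits the larger search spaces.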

**Key Hyperparameters**:
- **Logistic Regression**: C (regularization strength), penalty (L1/L2)
- **KNN**: n_neighbors, weights, distance metric
- **SVM**: C, gamma, kernel type
- **Random Forest**: n_estimators, max_depth, min_samples_split, class_weight
- **XGBoost**: learning_rate, max_depth, n_estimators, subsample, scale_pos_weight (for class imbalance)
- **Decision Tree**: max_depth, min_samples_split, min_samples_leaf, class_weight
- **Gaussian Naive Bayes**: var_smoothing

### 4. Model Evaluation

**Classification Performance Metrics**:
- **Recall (Class 1)**: Primary metric - ability to catch customers who will churn
- **Precision**: Of predicted churners, how many actually churned
- **F1-Score**: Harmonic mean of precision and recall (balanced metric)
- **Accuracy**: Overall correctness of predictions
- **Confusion Matrix**: Detailed breakdown of predictions vs. actuals
- **Cross-validation**: 5-fold CV for robust performance estimates

**Probability-Based Regression Metrics**:
- **RMSE (Root Mean Squared Error)**: Evaluate how well probability predictions match actual churn outcomes
- **MAE (Mean Absolute Error)**: Average absolute difference between predicted probabilities and actual outcomes
- **R² Score**: Measure how well probability predictions explain variance in churn behavior
- **Probability Distribution Analysis**: Visualize and analyze the distribution of churn probabilities
- **Calibration Analysis**: Assess how well-calibrated probability estimates are (using calibration curves)

**Probability Extraction and Analysis**:
- Extract probability scores using `.predict_proba()` from all classification models (scikit-learn's `SVC` exposes it only when constructed with `probability=True`)
- Rank customers by churn probability to enable risk-based prioritization
- Create risk tiers (high/medium/low) based on probability thresholds
- Compare probability distributions across different customer segments
- Analyze probability calibration to ensure reliable risk estimates
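
A calibration check might be sketched as below, using scikit-learn's `calibration_curve` on synthetic data. The choice of five quantile bins and the summary statistic `max_calibration_gap` are assumptions made for the example.

```python
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the churn data (assumed ~20% positives).
X, y = make_classification(n_samples=2000, n_features=8, weights=[0.8, 0.2], random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
churn_proba = model.predict_proba(X_test)[:, 1]

# For each bin: observed churn rate vs. mean predicted probability.
frac_positive, mean_predicted = calibration_curve(
    y_test, churn_proba, n_bins=5, strategy="quantile"
)

# A well-calibrated model keeps the two close in every bin.
max_calibration_gap = float(np.max(np.abs(frac_positive - mean_predicted)))
```

Plotting `mean_predicted` against `frac_positive` gives the calibration curve; points near the diagonal indicate that a predicted 0.7 corresponds to roughly 70% of such customers churning.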

**Model Comparison**:
- Compare recall scores across all models (primary criterion for classification)
- Compare RMSE and MAE for probability predictions (evaluation of probability estimation quality)
- Statistical significance testing between models (e.g., paired tests on cross-validation fold scores)
- Learning curves to assess bias-variance tradeoff
- Feature importance analysis for business insights (using XGBoost)
- Confusion matrix analysis to understand prediction patterns
- Probability calibration comparison across models

### 5. Deployment Strategy

**Web Application Development**:
- Build interactive web interface using Streamlit
- User-friendly input forms for customer data
- Real-time churn probability predictions
- Visualization of prediction confidence and feature contributions

**Model Serving**:
- Save trained models using joblib/pickle
- Implement model versioning and A/B testing capabilities
- Add input validation and error handling
- Create API endpoints for integration with other systems
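
The persistence and validation steps could look like this sketch. The file name `churn_model_v1.joblib`, the helper `validate_customer`, and the age bounds are hypothetical; the set of valid countries comes from the dataset description.

```python
import os
import tempfile

import joblib
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Train a small placeholder model on synthetic data.
X, y = make_classification(n_samples=200, n_features=4, random_state=42)
model = LogisticRegression(max_iter=1000).fit(X, y)

# Save the model under a versioned file name, then load it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "churn_model_v1.joblib")  # hypothetical name
    joblib.dump(model, model_path)
    restored = joblib.load(model_path)

VALID_GEOGRAPHIES = {"France", "Germany", "Spain"}


def validate_customer(record: dict) -> list:
    """Return a list of validation problems; an empty list means the input is usable."""
    errors = []
    if record.get("Geography") not in VALID_GEOGRAPHIES:
        errors.append("Geography must be France, Germany, or Spain")
    age = record.get("Age")
    # The 18-120 range is an assumed sanity bound, not a rule from the dataset.
    if not isinstance(age, (int, float)) or not 18 <= age <= 120:
        errors.append("Age must be a number between 18 and 120")
    if record.get("HasCrCard") not in (0, 1):
        errors.append("HasCrCard must be 0 or 1")
    return errors


valid_errors = validate_customer({"Geography": "France", "Age": 40, "HasCrCard": 1})
invalid_errors = validate_customer({"Geography": "Italy", "Age": -3, "HasCrCard": 2})
```

Files produced by joblib or pickle should only be loaded from trusted sources, since unpickling can execute arbitrary code.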

### 6. Advanced Techniques

**Class Imbalance Handling**:
- Apply SMOTE to balance the training dataset
- Generate synthetic samples of the minority class (churners)
- Improve model's ability to learn patterns from churned customers
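
The core idea behind SMOTE, interpolating between a minority sample and one of its nearest minority neighbors, can be illustrated with NumPy alone. This is a simplified teaching sketch rather than the imbalanced-learn implementation that the project would likely use, and the function name `smote_like_oversample` is hypothetical.

```python
import numpy as np


def smote_like_oversample(X_minority, n_new, k=3, seed=42):
    """Create n_new synthetic rows by interpolating between minority neighbors."""
    rng = np.random.default_rng(seed)
    n = len(X_minority)

    # Pairwise Euclidean distances between minority samples.
    diffs = X_minority[:, None, :] - X_minority[None, :, :]
    dist = np.linalg.norm(diffs, axis=2)
    np.fill_diagonal(dist, np.inf)  # a point is not its own neighbor

    # Indices of the k nearest minority neighbors of each sample.
    neighbors = np.argsort(dist, axis=1)[:, :k]

    # Pick a random base sample and one of its neighbors for each new row.
    base_idx = rng.integers(0, n, size=n_new)
    neighbor_idx = neighbors[base_idx, rng.integers(0, k, size=n_new)]

    # Place the new point at a random position on the segment between them.
    gap = rng.random((n_new, 1))
    return X_minority[base_idx] + gap * (X_minority[neighbor_idx] - X_minority[base_idx])


# Five made-up minority-class (churner) rows with two features.
X_minority = np.array([[1.0, 2.0], [1.5, 1.8], [2.0, 2.2], [1.2, 2.5], [1.8, 1.6]])
synthetic = smote_like_oversample(X_minority, n_new=10)
```

Because every synthetic row lies on a segment between two real churners, oversampling must happen after the train/test split; otherwise information from test rows would leak into training.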

**Ensembling**:
- Implement voting classifiers to combine multiple models
- Use hard voting and soft voting strategies
- Leverage strengths of different algorithms for improved performance
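
A hedged sketch of both voting strategies with scikit-learn's `VotingClassifier` follows. The three base models and their settings are placeholders, and the data is synthetic.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the churn data (assumed ~20% positives).
X, y = make_classification(n_samples=800, n_features=8, weights=[0.8, 0.2], random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

estimators = [
    ("lr", make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))),
    ("rf", RandomForestClassifier(n_estimators=50, random_state=42)),
    ("nb", GaussianNB()),
]

# Hard voting: majority vote over predicted labels.
hard_vote = VotingClassifier(estimators=estimators, voting="hard").fit(X_train, y_train)
# Soft voting: average the predicted class probabilities, then pick the larger one.
soft_vote = VotingClassifier(estimators=estimators, voting="soft").fit(X_train, y_train)

hard_recall = recall_score(y_test, hard_vote.predict(X_test))
soft_recall = recall_score(y_test, soft_vote.predict(X_test))

# Only the soft-voting ensemble yields probabilities for risk ranking.
soft_proba = soft_vote.predict_proba(X_test)[:, 1]
```

Since the project relies on probability scores for risk ranking, soft voting fits the rest of the pipeline more naturally than hard voting.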

### 7. LLM Integration for Explainability and Personalization

**Prediction Explanation**:
- Use an LLM served through the Groq API to explain model predictions in natural language
- Explain feature contributions and their impact on churn probability
- Generate human-readable interpretations of technical model outputs
- Leverage Groq's fast inference for real-time explanations

**Personalized Email Generation**:
- Generate personalized retention emails based on churn probability predictions
- Include specific recommendations based on customer features and risk factors
- Create actionable retention strategies tailored to each customer
- Utilize Groq's low-cost, high-speed inference for efficient email generation

**Implementation Approach**:
- Integrate Groq API (OpenAI-compatible, easy to use)
- Design effective prompts for explanation and email generation
- Extract feature importance and contributions from models
- Format customer data and predictions for LLM processing
- Display explanations and emails in web application
- Take advantage of Groq's speed and affordability for responsive user experience
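
The prompt-formatting step might be sketched as below. The function name `build_explanation_prompt`, the prompt wording, and the example customer are all hypothetical. The snippet only builds chat messages in the role/content format used by OpenAI-compatible APIs and does not call any service.

```python
import json


def build_explanation_prompt(customer: dict, churn_probability: float, top_factors: list) -> list:
    """Return chat messages (role/content dicts) asking an LLM to explain a prediction."""
    system_message = (
        "You are a banking analyst who explains churn predictions in plain language."
    )
    user_message = (
        f"Customer data: {json.dumps(customer)}\n"
        f"Predicted churn probability: {churn_probability:.0%}\n"
        f"Most influential features: {', '.join(top_factors)}\n"
        "Explain in three sentences why this customer may or may not churn."
    )
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]


# Made-up customer and model output for illustration.
messages = build_explanation_prompt(
    customer={"Age": 52, "Geography": "Germany", "NumOfProducts": 1, "IsActiveMember": 0},
    churn_probability=0.72,
    top_factors=["Age", "IsActiveMember", "NumOfProducts"],
)
```

The resulting list would be passed as the messages argument of a chat-completion request. A retention-email prompt can reuse the same structure with a different instruction in the user message.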


## Expected Outcomes

### Technical Deliverables

1. **Trained Models**
   - Seven optimized classification models with hyperparameter tuning
   - Performance comparison across all models
   - Best-performing model selected based on recall (class 1) and business criteria
   - Models saved for deployment (XGBoost, Random Forest, Decision Tree, Naive Bayes, KNN, SVM, Logistic Regression)

2. **Model Performance Analysis**
   - **Classification Metrics**: Comprehensive evaluation (Recall, Precision, F1-Score, Accuracy)
   - **Probability-Based Metrics**: RMSE, MAE, R² scores on probability predictions
   - Cross-validation results for robust performance estimates
   - Feature importance rankings and business insights (using XGBoost)
   - Confusion matrix analysis for each model
   - Probability calibration analysis and visualization
   - Comparison of models with and without SMOTE
   - Risk ranking and stratification based on probability scores

3. **Data Insights**
   - Key factors driving customer churn
   - Customer segmentation based on churn risk (binary classification)
   - **Probability-based risk ranking**: Customers ranked by churn probability for prioritization
   - **Risk tier analysis**: High/medium/low risk customer categorization
   - Recommendations for retention strategies
   - Visualization of model performance and feature contributions
   - Probability distribution analysis across customer segments

4. **Web Application**
   - Interactive interface for churn prediction
   - Real-time model inference capabilities
   - User-friendly input forms and result visualization
   - Display churn prediction (churned/stayed) with probability scores
   - **Probability-based risk ranking**: Sort and filter customers by churn probability
   - **Risk tier visualization**: Display high/medium/low risk categories
   - LLM-powered prediction explanations
   - Personalized retention email generation
   - API endpoints for integration with other systems

5. **LLM Integration**
   - Natural language explanations of model predictions
   - Feature contribution explanations
   - Personalized retention email generation
   - Actionable recommendations for customer retention

### Business Value

1. **Customer Retention Strategy**
   - Identify high-risk customers for proactive intervention
   - **Probability-based prioritization**: Rank customers by churn probability to focus on highest-risk individuals
   - **Tiered intervention strategies**: Different retention approaches for high/medium/low probability customers
   - Optimize resource allocation for customer service based on risk levels
   - Focus on maximizing recall to catch as many potential churners as possible
   - Enable data-driven decision making through both binary predictions and probability scores

2. **Risk Management**
   - Early warning system for customer churn
   - Data-driven approach to customer lifecycle management
   - Improved customer satisfaction through targeted interventions

3. **Operational Efficiency**
   - Automated churn prediction reduces manual analysis
   - Scalable solution for large customer bases
   - Integration capabilities with existing banking systems
   - LLM-generated explanations make predictions accessible to non-technical stakeholders
   - Automated email generation reduces manual content creation

### Learning Outcomes

1. **Technical Skills**
   - End-to-end machine learning pipeline development
   - Model comparison and hyperparameter optimization
   - Web application development and model deployment
   - Data preprocessing and feature engineering

2. **Business Understanding**
   - Application of ML to real-world business problems
   - Customer analytics and retention strategies
   - Performance metrics interpretation for business decisions
   - LLM integration for explainable AI and natural language generation

3. **Project Management**
   - Structured approach to ML project development
   - Documentation and reproducibility best practices
   - Model validation and deployment considerations

4. **Advanced Skills**
   - LLM API integration and prompt engineering
   - Explainable AI techniques
   - Natural language generation for business applications


## Project Timeline and Milestones

### Phase 1: Data Exploration and Preprocessing (Week 1)
- [ ] Load and explore the dataset
- [ ] Perform exploratory data analysis (EDA)
- [ ] Handle missing values and outliers
- [ ] Engineer features and encode categorical variables
- [ ] Split data into stratified train/test sets (80%/20%), with cross-validation on the training set

### Phase 2: Model Development (Week 2)
- [ ] Implement Logistic Regression baseline model
- [ ] Implement XGBoost Classifier
- [ ] Implement Decision Tree Classifier
- [ ] Implement Random Forest Classifier
- [ ] Implement Gaussian Naive Bayes
- [ ] Implement K-Nearest Neighbors Classifier
- [ ] Implement Support Vector Machine Classifier

### Phase 3: Model Optimization (Week 3)
- [ ] Perform hyperparameter tuning for all models (optimize for recall)
- [ ] Cross-validation and performance evaluation
- [ ] Model comparison and selection (focus on recall for class 1)
- [ ] Feature importance analysis using XGBoost
- [ ] Implement SMOTE for class imbalance handling
- [ ] Compare models with and without SMOTE
- [ ] Implement ensemble methods (voting classifiers)
- [ ] Extract probability scores from all models using `.predict_proba()`
- [ ] Calculate regression metrics (RMSE, MAE, R²) on probability predictions
- [ ] Perform probability calibration analysis
- [ ] Create risk ranking and tiering based on probability scores

### Phase 4: Deployment and LLM Integration (Week 4)
- [ ] Build web application for model inference
- [ ] Integrate LLM API for prediction explanations
- [ ] Implement personalized email generation
- [ ] Create user interface with explanation and email features
- [ ] Test and validate deployed model and LLM integration

### Phase 5: Documentation and Finalization (Week 5)
- [ ] Document results and create visualizations
- [ ] Document both classification and probability-based analysis results
- [ ] Create visualizations comparing binary predictions vs probability scores
- [ ] Test LLM explanations and email quality
- [ ] Prepare final presentation and report
- [ ] Document LLM integration and prompt engineering approach
- [ ] Document risk ranking methodology and business applications

### Optional Challenges
- [ ] **Challenge 1**: Retrain ML models with different feature engineering and preprocessing techniques to improve recall and overall performance
- [ ] **Challenge 2**: Experiment with additional ensemble techniques (StackingClassifier, etc.) and compare performance
- [ ] **Challenge 3**: Experiment with different LLMs and prompting techniques to generate better explanations and emails
- [ ] **Challenge 4**: Host ML models on cloud with API serving capabilities
- [ ] **Challenge 5**: Train ML models on different churn datasets and compare the best achievable performance

---

## Conclusion

This project will provide comprehensive experience in applying machine learning to solve real-world business problems. By developing a complete pipeline from data preprocessing to model deployment, and integrating LLMs for explainability and personalization, we will gain valuable insights into customer churn prediction while building practical skills in machine learning, data science, web development, and LLM integration.

The combination of technical rigor, business relevance, and explainability bridges academic concepts with practical applications in the banking industry. The dual approach of binary classification and probability-based analysis addresses the business need of identifying at-risk customers with both clear predictions and nuanced risk assessment: the focus on recall helps catch as many potential churners as possible, while probability scores enable risk ranking and resource allocation. SMOTE for class imbalance, ensemble methods, and probability-based evaluation round out the modeling work, and the LLM integration explores how traditional ML and modern language models can complement each other. Together, these give banks both actionable binary decisions and flexible probability-based strategies for customer retention.
