# Training a binary classification model for Jivi restart writers

In [0]:
%pip install shap

In [0]:
# %restart_python

In [0]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Any
import warnings
warnings.filterwarnings('ignore')
import mlflow
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.base import clone
from sklearn.model_selection import TimeSeriesSplit
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    roc_auc_score, precision_score, recall_score, 
    f1_score, precision_recall_curve, auc, confusion_matrix, classification_report
)
import shap

In [0]:
%run "../00_config/set-up"

In [0]:
# Manually controlled cohort window (YYYY-MM strings): first and last month.
first_month, last_month = "2019-12", "2024-11"

In [0]:
# Reading the feature master table from Hivestore
hcp_feats_master_w_target_sdf = spark.sql("SELECT * FROM jivi_new_writer_model.hcp_feats_master_w_target")
print(
    "Row count: ",
    hcp_feats_master_w_target_sdf.count(),
    "Column Count: ",
    len(hcp_feats_master_w_target_sdf.columns),
)

In [0]:
# Collect the Spark DataFrame onto the driver as a pandas DataFrame
# (toPandas pulls every row into driver memory).
hcp_feats_master_w_target_pdf = (
    hcp_feats_master_w_target_sdf
    .toPandas()
)

In [0]:
# Target label. The target and the identifier/bookkeeping columns below are never model inputs.
target_col_nm = 'JIVI_NEW_WRITER_FLG'
# A later cell writes a 'y_pred_proba' score column into this same frame. Excluding it here
# stops a re-run of this cell from feeding model scores back in as a feature (leakage).
non_feat_cols = {'BH_ID', 'COHORT_MONTH', target_col_nm, 'y_pred_proba'}
feat_cols = [col for col in hcp_feats_master_w_target_pdf.columns if col not in non_feat_cols]
binary_cols = ['AFFL_WI_INSN', 'AFFL_WI_JIVI_HCP_12M']
numeric_cols = [col for col in feat_cols if col not in binary_cols]
print("Names of binary feats", binary_cols)
print("Names of numeric feats", numeric_cols)
print("Number of features: ", len(feat_cols))

In [0]:
# Feature matrix and binary target for the full, not-yet-resampled dataset.
X, y = (
    hcp_feats_master_w_target_pdf[feat_cols],
    hcp_feats_master_w_target_pdf[target_col_nm],
)
print("Positives/Negatives in the dataset: \n", y.value_counts())
print("Shape of dataset before oversampling: ", hcp_feats_master_w_target_pdf.shape)

**Applying Oversampling**

In [0]:
# Randomly duplicate minority-class rows (RandomOverSampler's default strategy balances the classes).
# A fixed random_state makes the resampled set, and every model fit on it, reproducible.
# The value matches the random_state=42 used for the models.
ros = RandomOverSampler(random_state=42)
X_oversampled, y_oversampled = ros.fit_resample(X, y)
hcp_feats_master_w_target_oversampled_pdf = pd.concat([X_oversampled, y_oversampled], axis=1)

In [0]:
# Class balance and size after random oversampling.
print("Positives/Negatives in dataset after oversampling: \n",
      hcp_feats_master_w_target_oversampled_pdf['JIVI_NEW_WRITER_FLG'].value_counts())
print("Shape of dataset after oversampling: ",
      hcp_feats_master_w_target_oversampled_pdf.shape)

**Applying undersampling**

In [0]:
# Randomly drop majority-class rows (RandomUnderSampler's default strategy balances the classes).
# A fixed random_state makes the resampled set, and every model fit on it, reproducible.
# The value matches the random_state=42 used for the models.
rus = RandomUnderSampler(random_state=42)
X_undersampled, y_undersampled = rus.fit_resample(X, y)
hcp_feats_master_w_target_undersampled_pdf = pd.concat([X_undersampled, y_undersampled], axis=1)

In [0]:
# Class balance and size after random undersampling.
print("Positives/Negatives in dataset after undersampling: \n",
      hcp_feats_master_w_target_undersampled_pdf['JIVI_NEW_WRITER_FLG'].value_counts())
print("Shape of dataset after undersampling: ",
      hcp_feats_master_w_target_undersampled_pdf.shape)

### Overall, logistic regression performs consistently, and undersampling appears to handle the class imbalance better than oversampling

In [0]:
# For Ridge regression
# logit_reg = LogisticRegression(penalty='l2', class_weight='balanced', random_state=42, max_iter=1000)

In [0]:
# L1-penalised (Lasso) logistic regression; the liblinear solver supports the l1 penalty.
# Both models share identical hyper-parameters and differ only in their training data.
lasso_params = dict(penalty='l1', solver='liblinear', class_weight='balanced', random_state=42, max_iter=1000)
logit_reg_undersampled = LogisticRegression(**lasso_params)
logit_reg_oversampled = LogisticRegression(**lasso_params)

In [0]:
# Turn off MLflow autologging, then fit each model on its own resampled training set.
mlflow.autolog(disable=True)
for estimator, X_fit, y_fit in (
    (logit_reg_undersampled, X_undersampled, y_undersampled),
    (logit_reg_oversampled, X_oversampled, y_oversampled),
):
    estimator.fit(X_fit, y_fit)

In [0]:
def _log_and_register(model, X_fit, y_fit, run_name, artifact_path, registered_name):
    """
    Fit `model` inside a new MLflow run, log it with its hyper-parameters, and register it.

    Parameters:
    model: scikit-learn estimator; it is (re)fit on X_fit / y_fit inside the run
    X_fit, y_fit: training features and labels
    run_name: name of the MLflow run
    artifact_path: artifact path of the logged model within the run
    registered_name: Model Registry name to register the logged model under

    Returns:
    The `runs:/<run_id>/<artifact_path>` URI of the logged model.
    """
    with mlflow.start_run(run_name=run_name) as run:
        model.fit(X_fit, y_fit)
        # Record the hyper-parameters so each registered version can be traced to its configuration.
        mlflow.log_params(model.get_params())
        mlflow.sklearn.log_model(model, artifact_path)
        run_model_uri = f"runs:/{run.info.run_id}/{artifact_path}"
        mlflow.register_model(run_model_uri, registered_name)
    return run_model_uri


# Log the undersampled model
_log_and_register(
    logit_reg_undersampled, X_undersampled, y_undersampled,
    run_name="Logistic Regression Lasso - Undersampled",
    artifact_path="logit_reg_undersampled",
    registered_name="LogisticRegressionLassoUndersampled",
)

# Log the oversampled model. `model_uri` is kept as a global, as in the original cell.
model_uri = _log_and_register(
    logit_reg_oversampled, X_oversampled, y_oversampled,
    run_name="Logistic Regression Lasso - Oversampled",
    artifact_path="logit_reg_oversampled",
    registered_name="LogisticRegressionLassoOversampled",
)

### Check model performance on the whole dataset. The training data was resampled from these same rows, so the metrics are an optimistic upper bound on performance on unseen/new data

In [0]:
# Score the full, non-resampled dataset with the undersampled Lasso model
# (the variant preferred in the class-imbalance comparison).
# NOTE: the undersampled training set was drawn from the rows of X, so these metrics are
# partly in-sample and should be read as an optimistic upper bound.
logit_reg = logit_reg_undersampled
y_pred = logit_reg.predict(X)
y_pred_proba = logit_reg.predict_proba(X)[:, 1]

# Threshold-free (AUC) and 0.5-threshold (precision/recall/F1) metrics.
# zero_division=0 gives the same value as the default, without the warning.
metrics = {
    'auc_roc': roc_auc_score(y, y_pred_proba),
    'precision': precision_score(y, y_pred, zero_division=0),
    'recall': recall_score(y, y_pred, zero_division=0),
    'f1': f1_score(y, y_pred, zero_division=0),
}

# PR AUC: trapezoidal area under the precision-recall curve.
pr_precision, pr_recall, _ = precision_recall_curve(y, y_pred_proba)
metrics['auc_pr'] = auc(pr_recall, pr_precision)

# labels=[0, 1] keeps the matrix 2x2 even if only one class is present.
cm = confusion_matrix(y, y_pred, labels=[0, 1])
metrics['TNs'] = cm[0, 0]
metrics['FPs'] = cm[0, 1]
metrics['FNs'] = cm[1, 0]
metrics['TPs'] = cm[1, 1]

for metric_name, value in metrics.items():
    # Confusion-matrix counts are integers; print them without spurious decimals.
    if isinstance(value, (int, np.integer)):
        print(f"{metric_name}: {value:d}")
    else:
        print(f"{metric_name}: {value:.3f}")

print("Classification Report: ")
print(classification_report(y, y_pred))

fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(y_pred_proba, bins=50, ax=ax)
ax.set_title('Prediction Probability Distribution')
plt.show()

In [0]:
# Store the full-dataset class-1 probabilities as a column of the master table.
hcp_feats_master_w_target_pdf.loc[:, 'y_pred_proba'] = y_pred_proba

In [0]:
# Inspect the master table together with the 'y_pred_proba' column assigned in the preceding cell.
# NOTE(review): this passes the entire frame to display(), including BH_ID (presumably an HCP
# identifier). Consider .head()/.sample() and check for sensitive data before sharing outputs.
display(hcp_feats_master_w_target_pdf)

In [0]:
def plot_prediction_probability_distribution(y_pred_proba, y_true, figsize=(12, 6), threshold=0.5):
    """
    Plot kernel-density estimates of the predicted probabilities, one curve per actual class.

    Parameters:
    y_pred_proba: Predicted probabilities of class 1 (array-like, one value per row)
    y_true: True 0/1 labels, matched to y_pred_proba by position
    figsize: Size of the figure (width, height)
    threshold: Decision threshold drawn as a dashed reference line (default 0.5)
    """
    # Work on plain arrays so the class split is purely positional. A pandas Series of
    # probabilities would otherwise be masked by index label.
    probabilities = np.asarray(y_pred_proba)
    labels = np.asarray(y_true)

    # Separate probabilities for actual positive and negative classes
    prob_positive = probabilities[labels == 1]
    prob_negative = probabilities[labels == 0]

    fig, ax = plt.subplots(figsize=figsize)

    # `fill=True` replaces seaborn's deprecated `shade=True`.
    sns.kdeplot(prob_negative, label='Class 0 (Actual)', color='blue', fill=True, ax=ax)
    sns.kdeplot(prob_positive, label='Class 1 (Actual)', color='red', fill=True, ax=ax)

    # Draw the threshold line BEFORE building the legend so its label appears in it.
    ax.axvline(x=threshold, color='green', linestyle='--', alpha=0.5,
               label=f'Decision Threshold ({threshold})')

    ax.set_title('Distribution of Predicted Probabilities by Actual Class', fontsize=12)
    ax.set_xlabel('Predicted Probability of Class 1', fontsize=10)
    ax.set_ylabel('Density', fontsize=10)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    # Print some statistics
    print("\nProbability Distribution Statistics:")
    print(f"Class 0 - Mean: {np.mean(prob_negative):.3f}, Std: {np.std(prob_negative):.3f}")
    print(f"Class 1 - Mean: {np.mean(prob_positive):.3f}, Std: {np.std(prob_positive):.3f}")

# Per-class KDEs of the full-dataset predicted probabilities against the actual labels.
plot_prediction_probability_distribution(
    y_pred_proba=y_pred_proba,
    y_true=hcp_feats_master_w_target_pdf['JIVI_NEW_WRITER_FLG'],
)

In [0]:
def plot_prediction_probability_histogram(y_pred_proba, y_true, bins=50, figsize=(12, 6), threshold=0.5):
    """
    Plot overlaid, density-normalised histograms of the predicted probabilities per actual class.

    Parameters:
    y_pred_proba: Predicted probabilities of class 1 (array-like, one value per row)
    y_true: True 0/1 labels, matched to y_pred_proba by position
    bins: Number of bins for each histogram
    figsize: Size of the figure (width, height)
    threshold: Decision threshold drawn as a dashed reference line (default 0.5)
    """
    # Work on plain arrays so the class split is purely positional. A pandas Series of
    # probabilities would otherwise be masked by index label.
    probabilities = np.asarray(y_pred_proba)
    labels = np.asarray(y_true)

    # Separate probabilities for actual positive and negative classes
    prob_positive = probabilities[labels == 1]
    prob_negative = probabilities[labels == 0]

    fig, ax = plt.subplots(figsize=figsize)

    # density=True so both classes are comparable despite very different counts.
    ax.hist(prob_negative, bins=bins, alpha=0.6, color='blue',
            label=f'Class 0 (n={len(prob_negative)})', density=True)
    ax.hist(prob_positive, bins=bins, alpha=0.6, color='red',
            label=f'Class 1 (n={len(prob_positive)})', density=True)

    # Draw the threshold line BEFORE building the legend so its label appears in it.
    ax.axvline(x=threshold, color='green', linestyle='--', alpha=0.5,
               label=f'Decision Threshold ({threshold})')

    ax.set_title('Distribution of Predicted Probabilities by Actual Class', fontsize=12)
    ax.set_xlabel('Predicted Probability of Class 1', fontsize=10)
    ax.set_ylabel('Density', fontsize=10)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    # Print some statistics
    print("\nProbability Distribution Statistics:")
    print(f"Class 0 - Mean: {np.mean(prob_negative):.3f}, Std: {np.std(prob_negative):.3f}")
    print(f"Class 1 - Mean: {np.mean(prob_positive):.3f}, Std: {np.std(prob_positive):.3f}")
    print(f"\nClass 0 count: {len(prob_negative)}")
    print(f"Class 1 count: {len(prob_positive)}")

# Per-class histograms of the full-dataset predicted probabilities against the actual labels.
plot_prediction_probability_histogram(
    y_pred_proba=y_pred_proba,
    y_true=hcp_feats_master_w_target_pdf['JIVI_NEW_WRITER_FLG'],
)

**Optimize probability cut-offs for Precision and Recall**

In [0]:
def optimize_threshold_best_precision(y_true, y_prob, thresholds=None, min_recall=0.0):
    """
    Return the probability cut-off that maximises precision on (y_true, y_prob).

    Maximising precision alone tends to favour high cut-offs that flag very few rows.
    Set `min_recall` to require a minimum recall from every eligible cut-off.

    Parameters:
    y_true: True 0/1 labels
    y_prob: Predicted probabilities of class 1, aligned with y_true
    thresholds: Candidate cut-offs to scan, in order (default 0.1, 0.2, ..., 0.9)
    min_recall: Minimum recall a cut-off must reach to be eligible (default 0.0, i.e. all)

    Returns:
    The first eligible cut-off with the highest precision, as a float.
    Returns 0.5 if no eligible cut-off reaches a precision above 0.
    """
    if thresholds is None:
        # i / 10 yields clean cut-offs; np.arange(0.1, 1.0, 0.1) produces 0.30000000000000004 etc.
        thresholds = np.arange(1, 10) / 10
    y_prob = np.asarray(y_prob)

    best_threshold = 0.5
    best_precision = 0
    for threshold in thresholds:
        y_pred = (y_prob >= threshold).astype(int)
        if min_recall > 0 and recall_score(y_true, y_pred, zero_division=0) < min_recall:
            continue
        # zero_division=0: a cut-off that predicts no positives scores 0 (as before, minus the warning).
        precision = precision_score(y_true, y_pred, zero_division=0)
        if precision > best_precision:
            best_precision = precision
            best_threshold = float(threshold)

    return best_threshold

In [0]:
def optimize_threshold_best_recall(y_true, y_prob, thresholds=None, min_precision=0.0):
    """
    Return the probability cut-off that maximises recall on (y_true, y_prob).

    Recall can only fall as the cut-off rises. Without a precision floor, this therefore
    returns the lowest candidate cut-off in practice. Set `min_precision` to make the search meaningful.

    Parameters:
    y_true: True 0/1 labels
    y_prob: Predicted probabilities of class 1, aligned with y_true
    thresholds: Candidate cut-offs to scan, in order (default 0.1, 0.2, ..., 0.9)
    min_precision: Minimum precision a cut-off must reach to be eligible (default 0.0, i.e. all)

    Returns:
    The first eligible cut-off with the highest recall, as a float.
    Returns 0.5 if no eligible cut-off reaches a recall above 0.
    """
    if thresholds is None:
        # i / 10 yields clean cut-offs; np.arange(0.1, 1.0, 0.1) produces 0.30000000000000004 etc.
        thresholds = np.arange(1, 10) / 10
    y_prob = np.asarray(y_prob)

    best_threshold = 0.5
    best_recall = 0
    for threshold in thresholds:
        y_pred = (y_prob >= threshold).astype(int)
        if min_precision > 0 and precision_score(y_true, y_pred, zero_division=0) < min_precision:
            continue
        recall = recall_score(y_true, y_pred, zero_division=0)
        if recall > best_recall:
            best_recall = recall
            best_threshold = float(threshold)

    return best_threshold

In [0]:
# Pick the cut-off with the best precision, then binarise the scores at it.
threshold = optimize_threshold_best_precision(y, y_pred_proba)
print("Optimal probability threshold for precision: ", threshold)
y_pred = np.where(y_pred_proba >= threshold, 1, 0)

In [0]:
# Evaluate the precision-optimised binarisation on the full dataset.
# The AUCs depend only on the scores; precision/recall/F1 and the counts depend on `y_pred`.
# zero_division=0 gives the same value as the default, without the warning.
metrics = {
    'auc_roc': roc_auc_score(y, y_pred_proba),
    'precision': precision_score(y, y_pred, zero_division=0),
    'recall': recall_score(y, y_pred, zero_division=0),
    'f1': f1_score(y, y_pred, zero_division=0),
}

# PR AUC: trapezoidal area under the precision-recall curve.
pr_precision, pr_recall, _ = precision_recall_curve(y, y_pred_proba)
metrics['auc_pr'] = auc(pr_recall, pr_precision)

# labels=[0, 1] keeps the matrix 2x2 even if only one class is present.
cm = confusion_matrix(y, y_pred, labels=[0, 1])
metrics['TNs'] = cm[0, 0]
metrics['FPs'] = cm[0, 1]
metrics['FNs'] = cm[1, 0]
metrics['TPs'] = cm[1, 1]

for metric_name, value in metrics.items():
    # Confusion-matrix counts are integers; print them without spurious decimals.
    if isinstance(value, (int, np.integer)):
        print(f"{metric_name}: {value:d}")
    else:
        print(f"{metric_name}: {value:.3f}")

print("Classification Report: ")
print(classification_report(y, y_pred))

In [0]:
# Pick the cut-off with the best recall, then binarise the scores at it.
threshold = optimize_threshold_best_recall(y, y_pred_proba)
print("Optimal probability threshold for recall: ", threshold)
is_predicted_positive = y_pred_proba >= threshold
y_pred = is_predicted_positive.astype(int)

In [0]:
# Evaluate the recall-optimised binarisation on the full dataset.
# The AUCs depend only on the scores; precision/recall/F1 and the counts depend on `y_pred`.
# zero_division=0 gives the same value as the default, without the warning.
metrics = {
    'auc_roc': roc_auc_score(y, y_pred_proba),
    'precision': precision_score(y, y_pred, zero_division=0),
    'recall': recall_score(y, y_pred, zero_division=0),
    'f1': f1_score(y, y_pred, zero_division=0),
}

# PR AUC: trapezoidal area under the precision-recall curve.
pr_precision, pr_recall, _ = precision_recall_curve(y, y_pred_proba)
metrics['auc_pr'] = auc(pr_recall, pr_precision)

# labels=[0, 1] keeps the matrix 2x2 even if only one class is present.
cm = confusion_matrix(y, y_pred, labels=[0, 1])
metrics['TNs'] = cm[0, 0]
metrics['FPs'] = cm[0, 1]
metrics['FNs'] = cm[1, 0]
metrics['TPs'] = cm[1, 1]

for metric_name, value in metrics.items():
    # Confusion-matrix counts are integers; print them without spurious decimals.
    if isinstance(value, (int, np.integer)):
        print(f"{metric_name}: {value:d}")
    else:
        print(f"{metric_name}: {value:.3f}")

print("Classification Report: ")
print(classification_report(y, y_pred))

In [0]:
# # Get model co_efficients
# co_eff = logit_reg.coef_[0]

# # Put in DataFrame and sort by effect size
# co_eff_df = pd.DataFrame()
# co_eff_df['feature'] = feat_cols
# co_eff_df['co_eff'] = co_eff
# co_eff_df['abs_co_eff'] = np.abs(co_eff)
# co_eff_df_sorted = co_eff_df.sort_values(by='abs_co_eff', ascending=False, inplace=False)
# display(co_eff_df_sorted)

### SHAP feature importance

In [0]:
# Explain the undersampled Lasso model (the variant preferred in the class-imbalance comparison).
# Background distribution: the data the model was trained on.
# Explained rows: the full feature matrix X, whose row order follows the master table.
X_train = X_undersampled
X_test = X
explainer = shap.Explainer(logit_reg_undersampled, X_train)

# Calculate SHAP values
shap_values = explainer(X_test)

# Beeswarm summary of per-feature SHAP contributions
shap.plots.beeswarm(shap_values)

In [0]:
# Global importance bar chart (mean |SHAP| per feature) on the undersampled training data.
shap_values_train = explainer(X_undersampled)
shap.summary_plot(shap_values=shap_values_train,
                  features=X_undersampled.values,
                  feature_names=X_undersampled.columns.values,
                  plot_type='bar',
                  max_display=15,
                  show=False)
# Widen the figure directly instead of passing tight_layout a rect outside its
# normalised 0-1 figure coordinates (the original used rect=[0, 0, 2, 1]).
fig = plt.gcf()
width, height = fig.get_size_inches()
fig.set_size_inches(width * 2, height)
fig.tight_layout()
plt.show()

In [0]:
# # Initialize SHAP JavaScript visualization
# shap.initjs()

# # Select an index for the SHAP force plot
# ind = 1

# # Plot the SHAP force plot
# shap.force_plot(shap_values[ind], matplotlib=True)

In [0]:
# Top-N features by absolute coefficient of the undersampled Lasso model.
max_display = 20

# For linear models, use coefficients directly. feature_names_in_ (set when sklearn fits on a
# DataFrame) matches coef_ order exactly; fall back to feat_cols otherwise.
# NOTE(review): |coefficient| is only comparable across features that share a scale. This notebook
# fits on the raw features (StandardScaler is imported but unused), so treat the ranking with care.
coefs = logit_reg_undersampled.coef_[0]
feature_names = getattr(logit_reg_undersampled, 'feature_names_in_', feat_cols)
feature_importance_df = (
    pd.DataFrame({
        'feature': feature_names,
        'coefficient': coefs,
        'importance': np.abs(coefs),
    })
    .sort_values('importance', ascending=False)
    .head(max_display)
)

fig, ax = plt.subplots(figsize=(10, 8))
positions = range(len(feature_importance_df))
ax.barh(positions, feature_importance_df['importance'])
ax.set_yticks(positions)
ax.set_yticklabels(feature_importance_df['feature'])
# barh draws the first row at the bottom; invert so the most important feature is on top.
ax.invert_yaxis()
ax.set_xlabel('|Coefficient|')
ax.set_title('Feature Importance (Logistic Regression Coefficients)')
plt.show()

### Extracting the top reason per instance from SHAP values

In [0]:
# Per-instance "top reason": the feature with the largest absolute SHAP contribution,
# together with its signed SHAP value.
shap_values_df = pd.DataFrame(shap_values.values, columns=feat_cols)

# Get the feature with the highest absolute SHAP value for each instance
top_reasons = shap_values_df.abs().idxmax(axis=1)

# DataFrame.lookup was removed in pandas 2.0. Select each row's signed value by
# (row position, column position) instead.
row_idx = np.arange(len(shap_values_df))
col_idx = shap_values_df.columns.get_indexer(top_reasons)

top_reasons_df = pd.DataFrame({
    'instance': row_idx,
    'top_reason': top_reasons,
    'shap_value': shap_values_df.to_numpy()[row_idx, col_idx],
})

display(top_reasons_df)