# Training binary classification model for Jivi restart writers

## TODO: Tune hyper-parameters and continue the search for a champion model

In [0]:
%pip install shap

In [0]:
%restart_python

In [0]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Any
import warnings
warnings.filterwarnings('ignore')
import mlflow
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.base import clone
from sklearn.model_selection import TimeSeriesSplit
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    roc_auc_score, precision_score, recall_score, 
    f1_score, precision_recall_curve, auc, confusion_matrix, classification_report
)
import shap

### Section for user-defined functions

In [0]:
def train_test_split_udf(
    df: pd.DataFrame,
    target_col: str,
    feature_cols: List[str],
    numeric_cols: List[str],
    train_end_month: str,
    scale: bool = False
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series, pd.Series, pd.Series]:
    """
    Prepare data for training and testing based on COHORT_MONTH.

    Rows whose COHORT_MONTH (formatted as YYYY-MM) is on or before
    ``train_end_month`` form the training set; all other rows form the
    test set. The input DataFrame is never modified.

    Args:
        df: Input Pandas DataFrame; must contain 'COHORT_MONTH', 'BH_ID',
            the target column and every feature column
        target_col: Name of target column
        feature_cols: List of feature column names
        numeric_cols: Feature columns to standardise when ``scale`` is True
        train_end_month: End month for training data (YYYY-MM format)
        scale: Whether to apply StandardScaler to the numeric features
            (fitted on the training rows only, then applied to the test rows)

    Returns:
        X_train, X_test, y_train, y_test, bh_id_train, bh_id_test
        (feature DataFrames, target Series and BH_ID Series for each split)

    Raises:
        TypeError: If ``df`` is not a pandas DataFrame
    """
    # Ensure input is a pandas DataFrame
    if not isinstance(df, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")

    # String comparison is safe because both sides use the zero-padded YYYY-MM format
    train_mask = pd.to_datetime(df['COHORT_MONTH']).dt.strftime('%Y-%m') <= train_end_month
    test_mask = ~train_mask

    # Explicit copies: the feature frames may be overwritten by the scaler below
    # and must never write back into (or warn about) the caller's DataFrame
    X_train = df.loc[train_mask, feature_cols].copy()
    X_test = df.loc[test_mask, feature_cols].copy()
    y_train = df.loc[train_mask, target_col]
    y_test = df.loc[test_mask, target_col]
    # Get the corresponding HCP IDs
    bh_id_train = df.loc[train_mask, 'BH_ID']
    bh_id_test = df.loc[test_mask, 'BH_ID']

    print("No. of features in input dataframe: ", len(feature_cols))
    print("Positives/Negatives in train: \n", y_train.value_counts())
    print("Positives/Negatives in test: \n", y_test.value_counts())
    print("Shape of X_train: ", X_train.shape)
    print("Shape of X_test: ", X_test.shape)

    # Scale only when requested and when there is something to scale
    if scale and len(numeric_cols) > 0:
        scaler = StandardScaler()
        # Fit on train only so no test-period statistics leak into training
        X_train[numeric_cols] = pd.DataFrame(
            scaler.fit_transform(X_train[numeric_cols]),
            columns=numeric_cols,
            index=X_train.index
        )
        X_test[numeric_cols] = pd.DataFrame(
            scaler.transform(X_test[numeric_cols]),
            columns=numeric_cols,
            index=X_test.index
        )

    return X_train, X_test, y_train, y_test, bh_id_train, bh_id_test

In [0]:
def custom_confusion_matrix(y_true, y_pred, df):
    """
    Calculate confusion matrix based on unique BH_ID counts

    Each cell holds the number of distinct BH_IDs that have at least one
    record in that (actual, predicted) combination. The cells are counted
    independently, so a BH_ID with records in several combinations is
    counted once in each of them.

    Derived metrics (precision, recall, specificity, F1, accuracy) are
    printed; any metric whose denominator is zero is reported as 0.

    Parameters:
    y_true: true labels
    y_pred: predicted labels
    df: dataframe containing BH_ID column and predictions

    Returns:
    2x2 numpy array laid out as [[TN, FP], [FN, TP]]
    """
    # Create a DataFrame with true labels, predictions, and BH_ID
    results_df = pd.DataFrame({
        'y_true': y_true,
        'y_pred': y_pred,
        'BH_ID': df['BH_ID']
    })

    def _unique_ids(actual, predicted):
        # Number of distinct BH_IDs with at least one record in this cell
        in_cell = (results_df['y_true'] == actual) & (results_df['y_pred'] == predicted)
        return results_df.loc[in_cell, 'BH_ID'].nunique(dropna=False)

    tn = _unique_ids(0, 0)  # True Negatives
    fp = _unique_ids(0, 1)  # False Positives
    fn = _unique_ids(1, 0)  # False Negatives
    tp = _unique_ids(1, 1)  # True Positives

    # Calculate metrics
    metrics = {}

    # Precision
    metrics['precision'] = tp / (tp + fp) if (tp + fp) > 0 else 0

    # Recall (Sensitivity)
    metrics['recall'] = tp / (tp + fn) if (tp + fn) > 0 else 0

    # Specificity
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0

    # F1 Score
    pr_sum = metrics['precision'] + metrics['recall']
    metrics['f1_score'] = (
        2 * (metrics['precision'] * metrics['recall']) / pr_sum if pr_sum > 0 else 0
    )

    # Accuracy (guarded like the other metrics so empty input does not raise)
    total = tp + tn + fp + fn
    metrics['accuracy'] = (tp + tn) / total if total > 0 else 0

    # Print metrics with formatted output
    print("\n=== Model Performance Metrics (Based on Unique BH_ID) ===")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall (Sensitivity): {metrics['recall']:.4f}")
    print(f"Specificity: {metrics['specificity']:.4f}")
    print(f"F1 Score: {metrics['f1_score']:.4f}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")

    return np.array([[tn, fp], [fn, tp]])

In [0]:
def plot_custom_confusion_matrix(y_test, y_pred, X_test):
    """
    Draw the BH_ID-based confusion matrix twice: as counts and as row percentages.

    Parameters:
    y_test: true labels
    y_pred: predicted labels
    X_test: dataframe with a BH_ID column (forwarded to custom_confusion_matrix)
    """
    # Unique-BH_ID confusion matrix, laid out as [[TN, FP], [FN, TP]]
    counts = custom_confusion_matrix(y_test, y_pred, X_test)

    # Share of each actual class (row) that falls into each predicted class
    row_totals = counts.sum(axis=1)[:, np.newaxis]
    shares = counts.astype('float') / row_totals

    predicted_ticks = ['Predicted Negative', 'Predicted Positive']
    actual_ticks = ['Actual Negative', 'Actual Positive']

    # (matrix to draw, annotation format, figure title)
    panels = (
        (counts, 'd', 'Confusion Matrix (Based on Unique BH_ID counts)'),
        (shares, '.2%', 'Confusion Matrix with Percentages (Based on Unique BH_ID counts)'),
    )
    for matrix, cell_format, heading in panels:
        plt.figure(figsize=(10, 8))
        sns.heatmap(matrix, annot=True, fmt=cell_format, cmap='Blues', cbar=False,
                    xticklabels=predicted_ticks,
                    yticklabels=actual_ticks)
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title(heading)
        plt.show()

In [0]:
def plot_prediction_probability_distribution(y_pred_proba, y_true, figsize=(12, 6), threshold=0.5):
    """
    Plot the distribution of prediction probabilities for both classes

    Draws one kernel-density curve per actual class, marks the decision
    threshold with a vertical line, and prints the mean / standard deviation
    of the predicted probabilities of each class.

    Parameters:
    y_pred_proba: Predicted probabilities from the model (expected to hold one
                  class-1 probability per entry of y_true)
    y_true: True labels (rows equal to 1 / 0 are treated as the two classes)
    figsize: Size of the figure (width, height)
    threshold: Decision threshold for classification
    """

    # Create figure and axis
    plt.figure(figsize=figsize)

    # Separate probabilities for actual positive and negative classes
    prob_positive = y_pred_proba[y_true == 1]
    prob_negative = y_pred_proba[y_true == 0]

    # Create the distribution plot
    # (`fill=True` is the current spelling of the deprecated `shade=True`)
    sns.kdeplot(prob_negative, label='Class 0 (Actual)', color='blue', fill=True)
    sns.kdeplot(prob_positive, label='Class 1 (Actual)', color='red', fill=True)

    # Add vertical line at the specified threshold.
    # It is drawn before the legend is built so that its label is listed too.
    plt.axvline(x=threshold, color='green', linestyle='--', alpha=0.5, label=f'Decision Threshold ({threshold})')

    # Customize the plot
    plt.title('Distribution of Predicted Probabilities by Actual Class', fontsize=12)
    plt.xlabel('Predicted Probability of Class 1', fontsize=10)
    plt.ylabel('Density', fontsize=10)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print some statistics
    print("\nProbability Distribution Statistics:")
    print(f"Class 0 - Mean: {np.mean(prob_negative):.3f}, Std: {np.std(prob_negative):.3f}")
    print(f"Class 1 - Mean: {np.mean(prob_positive):.3f}, Std: {np.std(prob_positive):.3f}")

# Example usage (replace `your_model` with the fitted classifier).
# The plotting helpers expect the positive-class probabilities only,
# i.e. column 1 of predict_proba:
# y_pred_proba = your_model.predict_proba(X_test)[:, 1]

In [0]:
def plot_prediction_probability_histogram(y_pred_proba, y_true, bins=50, figsize=(12, 6), threshold=0.5):
    """
    Plot histogram of prediction probabilities for both classes

    Draws one density-normalised histogram per actual class, marks the
    decision threshold with a vertical line, and prints per-class mean,
    standard deviation and record count of the predicted probabilities.

    Parameters:
    y_pred_proba: Predicted probabilities from the model (expected to hold one
                  class-1 probability per entry of y_true)
    y_true: True labels (rows equal to 1 / 0 are treated as the two classes)
    bins: Number of bins for histogram
    figsize: Size of the figure (width, height)
    threshold: Decision threshold for classification
    """

    # Create figure and axis
    plt.figure(figsize=figsize)

    # Separate probabilities for actual positive and negative classes
    prob_positive = y_pred_proba[y_true == 1]
    prob_negative = y_pred_proba[y_true == 0]

    # Create histograms (density=True so the two classes are comparable
    # even though their record counts differ)
    plt.hist(prob_negative, bins=bins, alpha=0.6, color='blue',
             label=f'Class 0 (n={len(prob_negative)})',
             density=True)
    plt.hist(prob_positive, bins=bins, alpha=0.6, color='red',
             label=f'Class 1 (n={len(prob_positive)})',
             density=True)

    # Add vertical line at the specified threshold.
    # It is drawn before the legend is built so that its label is listed too.
    plt.axvline(x=threshold, color='green', linestyle='--', alpha=0.5,
                label=f'Decision Threshold ({threshold})')

    # Customize the plot
    plt.title('Distribution of Predicted Probabilities by Actual Class', fontsize=12)
    plt.xlabel('Predicted Probability of Class 1', fontsize=10)
    plt.ylabel('Density', fontsize=10)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print some statistics
    print("\nProbability Distribution Statistics:")
    print(f"Class 0 - Mean: {np.mean(prob_negative):.3f}, Std: {np.std(prob_negative):.3f}")
    print(f"Class 1 - Mean: {np.mean(prob_positive):.3f}, Std: {np.std(prob_positive):.3f}")
    print(f"\nClass 0 count: {len(prob_negative)}")
    print(f"Class 1 count: {len(prob_positive)}")

In [0]:
def optimize_threshold_best_precision(y_true, y_prob):
    """
    Pick the probability cut-off with the highest precision.

    Candidates are 0.1, 0.2, ... in steps of 0.1 below 1.0. A candidate
    replaces the current choice only when its precision is strictly higher,
    so the lowest of several equally good cut-offs is kept, and 0.5 is
    returned when no candidate achieves a precision above 0.

    Parameters:
    y_true: true labels
    y_prob: predicted probabilities of the positive class

    Returns:
    The selected threshold
    """
    chosen_cutoff, top_precision = 0.5, 0

    for cutoff in np.arange(0.1, 1.0, 0.1):
        labels_at_cutoff = (y_prob >= cutoff).astype(int)
        cutoff_precision = precision_score(y_true, labels_at_cutoff)

        # Only a strict improvement displaces the current choice
        if not (cutoff_precision > top_precision):
            continue
        top_precision, chosen_cutoff = cutoff_precision, cutoff

    return chosen_cutoff

In [0]:
def optimize_threshold_best_recall(y_true, y_prob):
    """
    Pick the probability cut-off with the highest recall.

    Candidates are 0.1, 0.2, ... in steps of 0.1 below 1.0. A candidate
    replaces the current choice only when its recall is strictly higher,
    so the lowest of several equally good cut-offs is kept, and 0.5 is
    returned when no candidate achieves a recall above 0.

    Parameters:
    y_true: true labels
    y_prob: predicted probabilities of the positive class

    Returns:
    The selected threshold
    """
    winner = 0.5
    highest_recall = 0

    candidates = np.arange(0.1, 1.0, 0.1)
    for candidate in candidates:
        candidate_recall = recall_score(y_true, (y_prob >= candidate).astype(int))
        if candidate_recall > highest_recall:
            highest_recall, winner = candidate_recall, candidate

    return winner

In [0]:
def calculate_metrics(y_true, y_pred, y_pred_proba, threshold=0.5):
    """
    Calculate and print various classification metrics.

    Prints every metric and the classification report, then plots the
    confusion matrix (counts and row percentages). When probabilities are
    supplied, ROC-AUC and PR-AUC are added and the probability histogram is
    drawn with the given threshold marked.

    Parameters:
    y_true (array-like): True labels.
    y_pred (array-like): Predicted labels.
    y_pred_proba (array-like or None): Predicted probabilities of the positive
        class. Pass None to skip the probability-based metrics and plot.
    threshold (float): Decision threshold used to derive y_pred; only used to
        mark the cut-off on the probability histogram.

    Returns:
    dict: A dictionary containing various classification metrics
        ('precision', 'recall', 'f1', 'TNs', 'FPs', 'FNs', 'TPs', plus
        'auc_roc' and 'auc_pr' when probabilities are supplied).
    """
    has_proba = y_pred_proba is not None

    # Calculate metrics (insertion order is the print order below)
    metrics = {}
    if has_proba:
        metrics['auc_roc'] = roc_auc_score(y_true, y_pred_proba)
    metrics['precision'] = precision_score(y_true, y_pred)
    metrics['recall'] = recall_score(y_true, y_pred)
    metrics['f1'] = f1_score(y_true, y_pred)

    # Calculate PR AUC
    if has_proba:
        precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
        metrics['auc_pr'] = auc(recall, precision)

    # Calculate confusion matrix (rows = actual class, columns = predicted class)
    cm = confusion_matrix(y_true, y_pred)
    metrics['TNs'] = cm[0, 0]
    metrics['FPs'] = cm[0, 1]
    metrics['FNs'] = cm[1, 0]
    metrics['TPs'] = cm[1, 1]

    for metric_name, value in metrics.items():
        print(f"{metric_name}: {value:.3f}")

    print("Classification Report: ")
    print(classification_report(y_true, y_pred))

    # Share of each actual class (row) that falls into each predicted class
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    # Plot the confusion matrix twice: raw counts, then row percentages
    heatmaps = (
        (cm, 'd', 'Confusion Matrix with Counts'),
        (cm_percent, '.2%', 'Confusion Matrix with Percentages'),
    )
    for matrix, cell_format, heading in heatmaps:
        plt.figure(figsize=(10, 8))
        sns.heatmap(matrix, annot=True, fmt=cell_format, cmap='Blues', cbar=False,
                    xticklabels=['Predicted Negative', 'Predicted Positive'],
                    yticklabels=['Actual Negative', 'Actual Positive'])
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title(heading)
        plt.show()

    # Plot y_pred_proba histogram with respect to classes in y
    if has_proba:
        plot_prediction_probability_histogram(y_pred_proba=y_pred_proba,
                                              y_true=y_true,
                                              threshold=threshold)

    return metrics

### Start of modeling workflow

In [0]:
%run "../00_config/set-up"

In [0]:
# Month and Date parameters for manual control
# (all values are YYYY-MM strings)
first_month = "2019-12"
last_month = "2024-11"

# NOTE(review): in the visible notebook only train_end_month is used (it is
# passed to train_test_split_udf); the other boundaries are defined but not
# referenced below. Confirm whether they should also filter the data.
train_start_month = "2023-01"
train_end_month = "2024-04"
test_start_month = "2024-05"
test_end_month = "2024-11"

In [0]:
# Reading the feature master table from Hivestore
hcp_feats_master_w_target_sdf = spark.sql("SELECT * FROM jivi_new_writer_model.hcp_feats_master_w_target")
# Sanity check: number of rows and columns that were read
print(
    "Row count: ",
    hcp_feats_master_w_target_sdf.count(),
    "Column Count: ",
    len(hcp_feats_master_w_target_sdf.columns),
)

In [0]:
# Converting Spark dataframe to Pandas dataframe
# NOTE(review): toPandas() materialises the whole table as one pandas
# DataFrame; presumably the table is small enough for that. Confirm.
hcp_feats_master_w_target_pdf = hcp_feats_master_w_target_sdf.toPandas()

In [0]:
# Target column and the features that are binary flags
target_col_nm = 'JIVI_NEW_WRITER_FLG'
binary_cols = ['AFFL_WI_INSN', 'AFFL_WI_JIVI_HCP_12M']

# Every column except the identifier, the cohort month and the target is a
# model feature (original column order is preserved)
feat_cols = [
    col for col in hcp_feats_master_w_target_pdf.columns
    if col not in ('BH_ID', 'COHORT_MONTH', target_col_nm)
]
# Features that are not binary flags are treated as numeric
numeric_cols = [col for col in feat_cols if col not in binary_cols]

print("Names of binary feats", binary_cols)
print("Number of features: ", len(feat_cols))
print("Names of numeric feats")
display(pd.DataFrame({'continuous_feats': numeric_cols}))

### Splitting data for train and test

In [0]:
# Create train and test dataset
# Rows with COHORT_MONTH <= train_end_month form the training set, all later
# rows form the test set; scale=True standardises numeric_cols with a scaler
# fitted on the training rows only.
X_train, X_test, y_train, y_test, bh_id_train, bh_id_test = train_test_split_udf(
  hcp_feats_master_w_target_pdf, 
  target_col_nm, 
  feat_cols,
  numeric_cols,
  train_end_month, 
  scale=True)

**Applying Oversampling**

In [0]:
# applying oversampling for the minority class
# random_state is fixed (same seed as the models below) so that the resampled
# training set, and every metric derived from it, is reproducible across runs
ros = RandomOverSampler(random_state=42)
X_train_oversampled, y_train_oversampled = ros.fit_resample(X_train, y_train)

In [0]:
# Class balance and shape of the training data after oversampling
print("Positives/Negatives in dataset after oversampling: \n", y_train_oversampled.value_counts())
print("Shape of dataset after oversampling: ", X_train_oversampled.shape)

**Applying undersampling**

In [0]:
# applying undersampling for the majority class
# random_state is fixed (same seed as the models below) so that the retained
# majority-class rows, and every metric derived from them, are reproducible
rus = RandomUnderSampler(random_state=42)
X_train_undersampled, y_train_undersampled = rus.fit_resample(X_train, y_train)

In [0]:
# Class balance and shape of the training data after undersampling
print("Positives/Negatives in dataset after undersampling: \n", y_train_undersampled.value_counts())
print("Shape of dataset after undersampling: ", X_train_undersampled.shape)

### Lasso (L1) logistic regression models trained on the undersampled and oversampled datasets

In [0]:
# For Ridge regression
# logit_reg = LogisticRegression(penalty='l2', class_weight='balanced', random_state=42, max_iter=1000)

In [0]:
# Initialize the Lasso logistic regression models
# Both models use the same L1-penalised configuration; the second one is an
# unfitted copy of the first, so the two sets of hyper-parameters stay in sync.
logit_reg_undersampled = LogisticRegression(
    penalty='l1',
    solver='liblinear',
    class_weight='balanced',
    random_state=42,
    max_iter=1000,
)
logit_reg_oversampled = clone(logit_reg_undersampled)

In [0]:
# Fit, log and register one model per resampling strategy.
# Tuple layout: (run name, estimator, training features, training target,
#                artifact path, registered model name)
for run_label, estimator, X_fit, y_fit, artifact_path, registry_name in (
    (
        "Logistic Regression Lasso - Undersampled",
        logit_reg_undersampled,
        X_train_undersampled,
        y_train_undersampled,
        "logit_reg_undersampled",
        "LogisticRegressionLassoUndersampled",
    ),
    (
        "Logistic Regression Lasso - Oversampled",
        logit_reg_oversampled,
        X_train_oversampled,
        y_train_oversampled,
        "logit_reg_oversampled",
        "LogisticRegressionLassoOversampled",
    ),
):
    with mlflow.start_run(run_name=run_label) as run:
        mlflow.autolog()
        estimator.fit(X_fit, y_fit)
        # Explicitly log the fitted model, then register it from this run
        mlflow.sklearn.log_model(estimator, artifact_path)
        model_uri = f"runs:/{run.info.run_id}/{artifact_path}"
        mlflow.register_model(model_uri, registry_name)

### Check the performance of the model fitted to the undersampled (class-balanced) dataset against the whole dataset in the test period

In [0]:
# Predict on the full test dataset
# y_pred holds the model's default class predictions; y_pred_proba keeps only
# column 1 of predict_proba (presumably the positive class; confirm via classes_)
y_pred = logit_reg_undersampled.predict(X_test)
y_pred_proba = logit_reg_undersampled.predict_proba(X_test)[:, 1]

# Print metrics and draw the confusion matrices / probability histogram
calculate_metrics(y_test, y_pred, y_pred_proba)

### Go ahead with logistic regression fitted to Undersampled dataset

**Optimize probability cut-offs for Precision and Recall**

In [0]:
# Recompute the test-set predictions of the undersampled model so that the
# threshold cells below can be run on their own
y_pred = logit_reg_undersampled.predict(X_test)
y_pred_proba = logit_reg_undersampled.predict_proba(X_test)[:, 1]

In [0]:
# Find optimal threshold for PRECISION
# NOTE(review): the threshold is selected on y_test and then evaluated on the
# same y_test, so the reported metrics are optimistic. Consider tuning it on a
# separate validation period.
threshold = float(optimize_threshold_best_precision(y_test, y_pred_proba))
print("Optimal probability threshold for precision: ", threshold)  
# Make final predictions
y_pred = (y_pred_proba >= threshold).astype(int)
# Calculate metrics
calculate_metrics(y_test, y_pred, y_pred_proba, threshold=threshold)

In [0]:
# Find optimal threshold for RECALL
# NOTE: recall cannot increase as the cut-off rises, so over the ascending
# candidate grid this selects the lowest cut-off (or the 0.5 fallback when
# recall is 0 for every candidate).
# NOTE(review): the threshold is selected on y_test and then evaluated on the
# same y_test, so the reported metrics are optimistic.
# float(...) keeps the threshold a plain Python float, as in the precision cell.
threshold = float(optimize_threshold_best_recall(y_test, y_pred_proba))
print("Optimal probability threshold for recall: ", threshold)
# Make final predictions
# y_pred = (y_pred_proba >= 0.3).astype(int)
y_pred = (y_pred_proba >= threshold).astype(int)
# Calculate metrics
calculate_metrics(y_test, y_pred, y_pred_proba, threshold=threshold)

#### Custom threshold

In [0]:
# Manually chosen probability cut-off (the earlier candidate is kept for reference)
# threshold = 0.83
threshold = 0.95
# Rows with a positive-class probability at or above the cut-off are predicted positive
y_pred = (y_pred_proba >= threshold).astype(int)
# Calculate metrics
calculate_metrics(y_test, y_pred, y_pred_proba, threshold=threshold)

#### Plotting the confusion matrix based on unique HCP counts rather than on the total number of records in the negative class

In [0]:
# # ADDING THE PREDICTED PROBABILITIES TO THE MASTER DATAFRAME
# hcp_feats_master_w_target_pdf['y_pred_proba'] = y_pred_proba
# hcp_feats_master_w_target_pdf['y_pred'] = y_pred
# display(hcp_feats_master_w_target_pdf)

In [0]:
# Positive-class probabilities of the undersampled model on the test set,
# converted to labels with the manually chosen cut-off
y_pred_proba = logit_reg_undersampled.predict_proba(X_test)[:, 1]
# threshold = 0.83
threshold = 0.95
y_pred = (y_pred_proba >= threshold).astype(int)

In [0]:
# Attach the predicted probabilities, predicted labels and HCP identifiers to
# a copy of the test features (X_test itself is left untouched).
# BH_ID is a Series and is therefore matched to the rows by index.
X_test_copy = X_test.assign(
    y_pred_proba=y_pred_proba,
    y_pred=y_pred,
    BH_ID=bh_id_test,
)
display(X_test_copy.head(15))

In [0]:
# Number of records per actual class in the test period
y_test.value_counts()

In [0]:
# Number of records per predicted class in the test period
X_test_copy.y_pred.value_counts()

In [0]:
# Distinct HCP identifiers seen in each period
bh_id_train_unq = set(bh_id_train)
bh_id_test_unq = set(bh_id_test)
print("Unique HCPs in Trainset: ", len(bh_id_train_unq))
print("Unique HCPs in Testset: ", len(bh_id_test_unq))

# HCPs that appear in both the training and the test period
common_bh_ids = bh_id_train_unq & bh_id_test_unq
print("Common HCPs in train and test: ", len(common_bh_ids))

In [0]:
# Confusion matrix counted in unique BH_IDs; X_test_copy supplies the BH_ID column
plot_custom_confusion_matrix(y_test, y_pred, X_test_copy)

In [0]:
from sklearn.metrics import precision_recall_curve, average_precision_score

def plot_precision_recall_curve(y_test, y_pred_proba):
    """
    Plot the precision-recall curve and show the average precision in the legend.

    Parameters:
    y_test: true labels
    y_pred_proba: predicted probabilities of the positive class
    """
    precision_pts, recall_pts, _cutoffs = precision_recall_curve(y_test, y_pred_proba)
    avg_precision = average_precision_score(y_test, y_pred_proba)
    curve_label = f'Average Precision = {avg_precision:.2f}'

    plt.figure(figsize=(10, 8))
    # Recall on the x-axis, precision on the y-axis
    plt.plot(recall_pts, precision_pts, marker='.', label=curve_label)
    plt.title('Precision-Recall Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()
    plt.show()

In [0]:
# Assuming y_test and y_pred_proba are already defined
# (at this point y_pred_proba holds the undersampled model's test-set probabilities)
plot_precision_recall_curve(y_test, y_pred_proba)

In [0]:
# Restrict the master table to HCPs present in both the train and the test
# period, then show how their records split across the target classes
mask = hcp_feats_master_w_target_pdf['BH_ID'].isin(common_bh_ids)
hcp_feats_master_w_target_pdf_common = hcp_feats_master_w_target_pdf.loc[mask]
hcp_feats_master_w_target_pdf_common['JIVI_NEW_WRITER_FLG'].value_counts()

### Check the performance of the model fitted to the oversampled (class-balanced) dataset against the whole dataset in the test period

In [0]:
# Predict on the full test dataset with the model fitted to the oversampled data
# NOTE: this overwrites y_pred / y_pred_proba from the undersampled model
y_pred = logit_reg_oversampled.predict(X_test)
y_pred_proba = logit_reg_oversampled.predict_proba(X_test)[:, 1]

# Print metrics and draw the confusion matrices / probability histogram
calculate_metrics(y_test, y_pred, y_pred_proba)

In [0]:
# Coefficients of the logistic regression fitted to the undersampled data
co_eff = logit_reg_undersampled.coef_[0]

# One row per feature: signed coefficient and its magnitude
co_eff_df = pd.DataFrame({
    'feature': feat_cols,
    'co_eff': co_eff,
    'abs_co_eff': np.abs(co_eff),
})

# Ordered by the signed coefficient, largest positive effect first
# NOTE(review): abs_co_eff is computed but not used for the ordering. Confirm
# whether sorting by magnitude was intended.
co_eff_df_sorted = co_eff_df.sort_values(by='co_eff', ascending=False)
display(co_eff_df_sorted)

### SHAP feature importance

In [0]:
# Initialize the SHAP explainer
# (the scaled, un-resampled training features are passed as the second
# argument; presumably they serve as background data. Confirm against the
# shap.Explainer documentation)
explainer = shap.Explainer(logit_reg_undersampled, X_train)

# Calculate SHAP values
# (computed for the test-period rows; the training-set variant is kept below for reference)
shap_values = explainer(X_test)
# shap_values = explainer(X_train)

# # Plot the SHAP summary plot
# shap.summary_plot(shap_values, X_test, feature_names=feat_cols)
shap.plots.beeswarm(shap_values)

In [0]:
# Dot summary plot of the SHAP values for the test set.
# The feature matrix and feature names must describe the same rows/columns
# the SHAP values were computed for, i.e. X_test (the previously referenced
# X_undersampled is not defined in this notebook and raised a NameError).
shap.summary_plot(shap_values = explainer(X_test), 
                  features = X_test.values,
                  feature_names = X_test.columns.values,
                  plot_type='dot',
                  max_display=15,
                  show=False)
# show=False above lets the layout be adjusted before the figure is rendered
plt.tight_layout(rect=[0, 0, 2, 1])
plt.show()

In [0]:
# Bar summary plot of the SHAP values computed on the training features
# (SHAP values, feature matrix and feature names all come from X_train)
shap.summary_plot(shap_values = explainer(X_train), 
                  features = X_train.values,
                  feature_names = X_train.columns.values,
                  plot_type='bar',
                  max_display=15,
                  show=False)
# show=False above lets the layout be adjusted before the figure is rendered
plt.tight_layout(rect=[0, 0, 2, 1])
plt.show()

In [0]:
# Initialize SHAP JavaScript visualization
shap.initjs()

# Select an index for the SHAP force plot
# (position of the row within shap_values, which were computed on X_test)
ind = 1

# Plot the SHAP force plot
# (matplotlib=True requests a static matplotlib rendering)
shap.force_plot(shap_values[ind], matplotlib=True)

In [0]:
# top 20 features to show importance
max_display = 20

# For linear models, use coefficients directly
# (importance = absolute value of each coefficient of the undersampled model)
importance = np.abs(logit_reg_undersampled.coef_[0])
feature_importance_df = pd.DataFrame({
    'feature': feat_cols,
    'importance': importance
})
# Keep only the max_display features with the largest absolute coefficients
feature_importance_df = feature_importance_df.sort_values(
    'importance', ascending=False
).head(max_display)

# Horizontal bar chart: one bar per retained feature, in the sorted order
plt.figure(figsize=(10, 8))
plt.barh(
    range(len(feature_importance_df)),
    feature_importance_df['importance']
)
# Label each bar position with its feature name
plt.yticks(
    range(len(feature_importance_df)),
    feature_importance_df['feature']
)
plt.xlabel('|Coefficient|')
plt.title('Feature Importance (Logistic Regression Coefficients)')