In [3]:
import os
import json
import joblib
import pandas as pd
import shap
import matplotlib.pyplot as plt
import sys
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

# Define project_root
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)
from src.data_processing.feature_engineer import feature_engineer

def train_and_compare_models(
    input_path="../data/processed/cleaned_customer_purchase_data.csv",
    model_paths={"xgboost": "../models/xgboost_v1.joblib", "random_forest": "../models/rf_v1.joblib"},
    feature_path="../data/outputs/feature_columns.json",
    shap_path="../data/outputs/xgboost_shap_summary_v1.png",
    test_size=0.2,
    random_state=42,
    save_artifacts=True,
    verbose=True
):
    # Load and preprocess
    df = pd.read_csv(input_path)
    df = feature_engineer(df)

    X = df.drop("PurchaseStatus", axis=1)
    y = df["PurchaseStatus"]
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )

    # Initialize models
    models = {
        "XGBoost": XGBClassifier(eval_metric="logloss", random_state=random_state),
        "Random Forest": RandomForestClassifier(random_state=random_state)
    }

    results = {}

    # Train and evaluate each model
    for model_name, model in models.items():
        model.fit(X_train, y_train)
        y_pred = model.predict(X_val)
        
        # Store results
        accuracy = accuracy_score(y_val, y_pred)
        results[model_name] = accuracy
        
        if verbose:
            print(f"📈 {model_name} Classification Report:\n", classification_report(y_val, y_pred))

        # Generate SHAP values and summary plot for XGBoost only
        if model_name == "XGBoost":
            try:
                explainer = shap.TreeExplainer(model)
                shap_values = explainer.shap_values(X_val)
                summary = shap_values[1] if isinstance(shap_values, list) else shap_values
                shap.summary_plot(summary, X_val, show=False)
                if save_artifacts:
                    os.makedirs(os.path.dirname(shap_path), exist_ok=True)
                    plt.tight_layout()
                    plt.savefig(shap_path)
                    if verbose:
                        print(f"📸 SHAP plot saved to: {shap_path}")
                plt.close()
            except Exception as e:
                if verbose:
                    print(f"⚠️ Failed to generate SHAP plot for {model_name}: {e}")

        # Save model
        if save_artifacts:
            model_key = model_name.lower().replace(" ", "_")  # Use lowercase and replace spaces with underscores
            os.makedirs(os.path.dirname(model_paths[model_key]), exist_ok=True)
            joblib.dump(model, model_paths[model_key])
            if verbose:
                print(f"💾 {model_name} model saved to: {model_paths[model_key]}")

    # Visualization of model performance
    plt.figure(figsize=(8, 5))
    plt.bar(results.keys(), results.values(), color=['blue', 'green'])
    plt.ylabel('Accuracy')
    plt.title('Model Comparison: Accuracy of XGBoost vs Random Forest')
    plt.ylim(0, 1)
    plt.grid(axis='y')
    plt.savefig("../data/outputs/model_comparison.png")
    plt.show()

    # Save feature columns
    if save_artifacts:
        with open(feature_path, "w") as f:
            json.dump(list(X.columns), f)
        if verbose:
            print(f"📜 Feature columns saved to: {feature_path}")

    return models, X_val, y_val

if __name__ == "__main__":
    train_and_compare_models()

📈 XGBoost Classification Report:
               precision    recall  f1-score   support

           0       0.91      0.92      0.91       148
           1       0.91      0.89      0.90       130

    accuracy                           0.91       278
   macro avg       0.91      0.91      0.91       278
weighted avg       0.91      0.91      0.91       278

📸 SHAP plot saved to: ../data/outputs/xgboost_shap_summary_v1.png


KeyError: 'xgboost'