## 1. Setup and Dependencies
Installing required packages and configuring the environment for mushroom classification analysis.

In [2]:
pip install -q catboost shap optuna scikit-learn matplotlib pandas numpy psutil

Note: you may need to restart the kernel to use updated packages.


In [3]:
# Suppress warnings and set seeds
import warnings
warnings.filterwarnings('ignore')
import os, random, json, gc, hashlib, psutil

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from catboost import CatBoostClassifier
import shap
import optuna
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve
from sklearn.inspection import permutation_importance

# Reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
optuna.logging.set_verbosity(optuna.logging.WARNING)
DATA_DIR = './data'
FIG_DIR = './figures'
MODEL_DIR = './models'
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

## 2. Data Loading and Initial Exploration
Loading the mushroom dataset and performing initial data exploration to understand the structure and characteristics of the data.

In [21]:
import itertools

file_path = f"{DATA_DIR}/mushroom.csv"

df = pd.read_csv(file_path)
display(df.head())

Unnamed: 0,class,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,gill-attachment,gill-spacing,gill-color,stem-height,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
0,p,15.26,x,g,o,f,e,,w,16.95,...,s,y,w,u,w,t,g,,d,w
1,p,16.6,x,g,o,f,e,,w,17.99,...,s,y,w,u,w,t,g,,d,u
2,p,14.07,x,g,o,f,e,,w,17.8,...,s,y,w,u,w,t,g,,d,w
3,p,14.17,f,h,e,f,e,,w,15.77,...,s,y,w,u,w,t,p,,d,w
4,p,14.64,x,h,o,f,e,,w,16.53,...,s,y,w,u,w,t,p,,d,w


In [6]:
# Dataset overview
print("Shape:", df.shape)
print("Columns:", df.columns.tolist())
print("Dtypes:\n", df.dtypes)

Shape: (61069, 21)
Columns: ['class', 'cap-diameter', 'cap-shape', 'cap-surface', 'cap-color', 'does-bruise-or-bleed', 'gill-attachment', 'gill-spacing', 'gill-color', 'stem-height', 'stem-width', 'stem-root', 'stem-surface', 'stem-color', 'veil-type', 'veil-color', 'has-ring', 'ring-type', 'spore-print-color', 'habitat', 'season']
Dtypes:
 class                    object
cap-diameter            float64
cap-shape                object
cap-surface              object
cap-color                object
does-bruise-or-bleed     object
gill-attachment          object
gill-spacing             object
gill-color               object
stem-height             float64
stem-width              float64
stem-root                object
stem-surface             object
stem-color               object
veil-type                object
veil-color               object
has-ring                 object
ring-type                object
spore-print-color        object
habitat                  object
season           

## 3. Exploratory Data Analysis (EDA)
Comprehensive analysis of the dataset including missing values, feature distributions, and class balance visualization.

In [7]:
# Missing values analysis
missing_pct = df.isna().mean() * 100
plt.figure()
missing_pct.plot.bar()
plt.title("Missing Value Percentage per Column")
plt.ylabel("Percentage (%)")
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/eda_missing.png')
plt.close()
print(missing_pct)

class                    0.000000
cap-diameter             0.000000
cap-shape                0.000000
cap-surface             23.121387
cap-color                0.000000
does-bruise-or-bleed     0.000000
gill-attachment         16.184971
gill-spacing            41.040462
gill-color               0.000000
stem-height              0.000000
stem-width               0.000000
stem-root               84.393064
stem-surface            62.427746
stem-color               0.000000
veil-type               94.797688
veil-color              87.861272
has-ring                 0.000000
ring-type                4.046243
spore-print-color       89.595376
habitat                  0.000000
season                   0.000000
dtype: float64


In [8]:
# Descriptive statistics
numeric_cols = df.select_dtypes(include=['int64','float64']).columns.tolist()
cat_cols = df.select_dtypes(include=['object']).columns.tolist()

print("Numeric columns:", numeric_cols)
print("Categorical columns:", cat_cols)

if numeric_cols:
    num_summary = pd.DataFrame({
        'mean': df[numeric_cols].mean(),
        'median': df[numeric_cols].median(),
        'mode': df[numeric_cols].mode().iloc[0],
        'std': df[numeric_cols].std()
    })
    print(num_summary)

for col in cat_cols:
    vc = df[col].value_counts().head(10)
    plt.figure()
    vc.plot.bar()
    plt.title(f"Top 10 Value Counts for {col}")
    plt.tight_layout()
    plt.savefig(f'{FIG_DIR}/eda_vc_{col}.png')
    plt.close()

Numeric columns: ['cap-diameter', 'stem-height', 'stem-width']
Categorical columns: ['class', 'cap-shape', 'cap-surface', 'cap-color', 'does-bruise-or-bleed', 'gill-attachment', 'gill-spacing', 'gill-color', 'stem-root', 'stem-surface', 'stem-color', 'veil-type', 'veil-color', 'has-ring', 'ring-type', 'spore-print-color', 'habitat', 'season']
                   mean  median  mode        std
cap-diameter   6.733854    5.86  3.18   5.264845
stem-height    6.581538    5.95  0.00   3.370017
stem-width    12.149410   10.19  0.00  10.035955


In [9]:
# Numeric distributions
for col in numeric_cols:
    plt.figure()
    df[col].plot.hist(density=True, alpha=0.6)
    df[col].plot.kde()
    plt.title(f"Density & Histogram for {col}")
    plt.tight_layout()
    plt.savefig(f'{FIG_DIR}/eda_dist_{col}.png')
    plt.close()

In [10]:
# Class balance
plt.figure()
df['class'].value_counts().plot.bar()
plt.title("Class Balance")
plt.ylabel("Count")
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/eda_class_balance.png')
plt.close()

In [11]:
# Pre-processing pipeline
print("Missing before:\n", df.isna().sum())

df[cat_cols] = df[cat_cols].fillna('unknown').astype(str)
df[numeric_cols] = df[numeric_cols].astype('float32')

print("Missing after:\n", df.isna().sum())

df['target'] = df['class'].map({'e':0, 'p':1})
df.drop(columns=['class'], inplace=True)

X = df.drop(columns=['target'])
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=SEED
)

cat_cols_train = X_train.select_dtypes(include=['object']).columns.tolist()
print("Categorical features:", cat_cols_train)

Missing before:
 class                       0
cap-diameter                0
cap-shape                   0
cap-surface             14120
cap-color                   0
does-bruise-or-bleed        0
gill-attachment          9884
gill-spacing            25063
gill-color                  0
stem-height                 0
stem-width                  0
stem-root               51538
stem-surface            38124
stem-color                  0
veil-type               57892
veil-color              53656
has-ring                    0
ring-type                2471
spore-print-color       54715
habitat                     0
season                      0
dtype: int64
Missing after:
 class                   0
cap-diameter            0
cap-shape               0
cap-surface             0
cap-color               0
does-bruise-or-bleed    0
gill-attachment         0
gill-spacing            0
gill-color              0
stem-height             0
stem-width              0
stem-root               0
stem-surface

## 4. Data Preprocessing and Feature Engineering
Cleaning the data, handling missing values, encoding categorical variables, and preparing the dataset for machine learning.

## 5. Baseline Model - Logistic Regression
Establishing a baseline performance using Logistic Regression with one-hot encoding for categorical features.

In [12]:
# Baseline model: Logistic Regression
baseline_pipe = Pipeline([
    ('ohe', OneHotEncoder(handle_unknown='ignore')),
    ('clf', LogisticRegression(random_state=SEED, max_iter=1000))
])
baseline_pipe.fit(X_train, y_train)

y_pred_base = baseline_pipe.predict(X_test)
y_proba_base = baseline_pipe.predict_proba(X_test)[:,1]

print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred_base))
print("Classification Report:\n", classification_report(y_test, y_pred_base))
auc_base = roc_auc_score(y_test, y_proba_base)
print(f"ROC-AUC: {auc_base:.4f}")

fpr, tpr, _ = roc_curve(y_test, y_proba_base)
plt.figure()
plt.plot(fpr, tpr)
plt.plot([0,1],[0,1], linestyle='--')
plt.title("Baseline ROC Curve")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/roc_baseline.png')
plt.close()

Confusion Matrix:
 [[4606  830]
 [ 769 6009]]
Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.85      0.85      5436
           1       0.88      0.89      0.88      6778

    accuracy                           0.87     12214
   macro avg       0.87      0.87      0.87     12214
weighted avg       0.87      0.87      0.87     12214

ROC-AUC: 0.9416


In [None]:
import multiprocessing

# Optimized CatBoost with dynamic resource allocation

# Get available CPU cores
n_cores = multiprocessing.cpu_count()
print(f"Available CPU cores: {n_cores}")

cat_indices = [X_train.columns.get_loc(c) for c in cat_cols_train]

def objective(trial):
    params = {
        'iterations': 500,
        'learning_rate': trial.suggest_loguniform('learning_rate', 0.01, 0.2),
        'depth': trial.suggest_int('depth', 4, 6),
        'l2_leaf_reg': trial.suggest_loguniform('l2_leaf_reg', 1, 5),
        'random_seed': SEED,
        'auto_class_weights': 'Balanced',   
        'verbose': 0,
        'early_stopping_rounds': 30,
        'thread_count': n_cores  # Use all available cores
    }
    
    cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=SEED)
    losses = []
    
    for fold_idx, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
        print(f"    Trial {trial.number}: Fold {fold_idx + 1}/3")
        X_tr, X_val = X_train.iloc[train_idx], X_train.iloc[val_idx]
        y_tr, y_val = y_train.iloc[train_idx], y_train.iloc[val_idx]
        
        m = CatBoostClassifier(**params)
        m.fit(X_tr, y_tr, cat_features=cat_indices, eval_set=(X_val, y_val))
        losses.append(m.get_best_score()['validation']['Logloss'])
        
    return np.mean(losses)

study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=SEED))
print("Starting hyperparameter optimization...")
study.optimize(objective, n_trials=10, show_progress_bar=True)

best_params = study.best_params
print("Best parameters:", best_params)

# Train final model with optimized parameters
print("Training final model...")
model = CatBoostClassifier(**best_params,
                           iterations=500,
                           auto_class_weights='Balanced',
                           random_seed=SEED,
                           verbose=100,  # Show progress every 100 iterations
                           thread_count=n_cores)  # Use all cores
model.fit(X_train, y_train, cat_features=cat_indices)
model.save_model(f'{MODEL_DIR}/catboost_model.cbm')
print("Model training completed!")

Available CPU cores: 12
Starting hyperparameter optimization...


  0%|          | 0/10 [00:00<?, ?it/s]

    Trial 0: Fold 1/3
    Trial 0: Fold 2/3
    Trial 0: Fold 2/3
    Trial 0: Fold 3/3
    Trial 0: Fold 3/3


Best trial: 0. Best value: 0.00628773:  10%|█         | 1/10 [00:18<02:48, 18.67s/it]

    Trial 1: Fold 1/3
    Trial 1: Fold 2/3
    Trial 1: Fold 2/3
    Trial 1: Fold 3/3
    Trial 1: Fold 3/3


Best trial: 1. Best value: 0.00621021:  20%|██        | 2/10 [00:31<02:03, 15.43s/it]

    Trial 2: Fold 1/3
    Trial 2: Fold 2/3
    Trial 2: Fold 2/3
    Trial 2: Fold 3/3
    Trial 2: Fold 3/3


Best trial: 1. Best value: 0.00621021:  30%|███       | 3/10 [00:50<01:58, 16.87s/it]

    Trial 3: Fold 1/3
    Trial 3: Fold 2/3
    Trial 3: Fold 2/3
    Trial 3: Fold 3/3
    Trial 3: Fold 3/3


Best trial: 1. Best value: 0.00621021:  40%|████      | 4/10 [01:02<01:30, 15.12s/it]

    Trial 4: Fold 1/3
    Trial 4: Fold 2/3
    Trial 4: Fold 2/3
    Trial 4: Fold 3/3
    Trial 4: Fold 3/3


Best trial: 4. Best value: 0.00316043:  50%|█████     | 5/10 [01:15<01:11, 14.35s/it]

    Trial 5: Fold 1/3
    Trial 5: Fold 2/3
    Trial 5: Fold 2/3
    Trial 5: Fold 3/3
    Trial 5: Fold 3/3


Best trial: 4. Best value: 0.00316043:  60%|██████    | 6/10 [01:29<00:55, 13.95s/it]

    Trial 6: Fold 1/3
    Trial 6: Fold 2/3
    Trial 6: Fold 2/3
    Trial 6: Fold 3/3
    Trial 6: Fold 3/3


Best trial: 4. Best value: 0.00316043:  70%|███████   | 7/10 [01:42<00:41, 13.67s/it]

    Trial 7: Fold 1/3
    Trial 7: Fold 2/3
    Trial 7: Fold 2/3
    Trial 7: Fold 3/3
    Trial 7: Fold 3/3


Best trial: 4. Best value: 0.00316043:  80%|████████  | 8/10 [01:55<00:27, 13.54s/it]

    Trial 8: Fold 1/3
    Trial 8: Fold 2/3
    Trial 8: Fold 2/3
    Trial 8: Fold 3/3
    Trial 8: Fold 3/3


Best trial: 4. Best value: 0.00316043:  90%|█████████ | 9/10 [02:14<00:15, 15.27s/it]

    Trial 9: Fold 1/3
    Trial 9: Fold 2/3
    Trial 9: Fold 2/3
    Trial 9: Fold 3/3
    Trial 9: Fold 3/3


Best trial: 4. Best value: 0.00316043: 100%|██████████| 10/10 [02:30<00:00, 15.04s/it]



Best parameters: {'learning_rate': 0.12106896936002161, 'depth': 4, 'l2_leaf_reg': 1.3399549522183016}
Training final model...
0:	learn: 0.6598917	total: 16.3ms	remaining: 8.15s
100:	learn: 0.0168697	total: 1.11s	remaining: 4.38s
100:	learn: 0.0168697	total: 1.11s	remaining: 4.38s
200:	learn: 0.0095387	total: 2.16s	remaining: 3.21s
200:	learn: 0.0095387	total: 2.16s	remaining: 3.21s
300:	learn: 0.0071895	total: 3.16s	remaining: 2.09s
300:	learn: 0.0071895	total: 3.16s	remaining: 2.09s
400:	learn: 0.0057588	total: 4.17s	remaining: 1.03s
400:	learn: 0.0057588	total: 4.17s	remaining: 1.03s
499:	learn: 0.0044922	total: 5.17s	remaining: 0us
Model training completed!
499:	learn: 0.0044922	total: 5.17s	remaining: 0us
Model training completed!


## 6. Advanced Model - CatBoost with Hyperparameter Optimization
Training an optimized CatBoost classifier using Optuna for hyperparameter optimization and leveraging categorical feature handling.

## 7. Model Evaluation and Performance Analysis
Evaluating the advanced model performance using various metrics including confusion matrix, classification report, and ROC curves.

In [17]:
# Evaluation: Advanced Model
y_pred_adv = model.predict(X_test)
y_proba_adv = model.predict_proba(X_test)[:,1]

print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred_adv))
print("Classification Report:\n", classification_report(y_test, y_pred_adv))
auc_adv = roc_auc_score(y_test, y_proba_adv)
print(f"ROC-AUC: {auc_adv:.4f}")

fpr2, tpr2, _ = roc_curve(y_test, y_proba_adv)
plt.figure()
plt.plot(fpr2, tpr2)
plt.plot([0,1],[0,1], linestyle='--')
plt.title("Advanced Model ROC Curve")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/roc_advanced.png')
plt.close()

Confusion Matrix:
 [[5436    0]
 [   0 6778]]
Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00      5436
           1       1.00      1.00      1.00      6778

    accuracy                           1.00     12214
   macro avg       1.00      1.00      1.00     12214
weighted avg       1.00      1.00      1.00     12214

ROC-AUC: 1.0000


In [18]:
# Explainability: Global
feat_imp = model.get_feature_importance(type='FeatureImportance')
feat_names = model.feature_names_
plt.figure()
plt.barh(feat_names, feat_imp)
plt.title("CatBoost Feature Importance (Gain)")
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/feat_imp_gain.png')
plt.close()

sample_size = min(1000, X_test.shape[0])
X_shap = X_test.sample(n=sample_size, random_state=SEED)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_shap)

plt.figure()
shap.summary_plot(shap_values, X_shap, show=False)
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/shap_beeswarm.png')
plt.close()

plt.figure()
shap.summary_plot(shap_values, X_shap, plot_type='bar', show=False)
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/shap_bar.png')
plt.close()

perm = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=SEED)
idx = perm.importances_mean.argsort()
plt.figure()
plt.barh(X_test.columns[idx], perm.importances_mean[idx])
plt.title("Permutation Feature Importance")
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/perm_importance.png')
plt.close()

## 8. Model Explainability and Feature Importance
Understanding model decisions through feature importance analysis, SHAP values, and permutation importance for both global and local interpretability.

In [19]:
# Explainability: Local
correct = X_test.index[(y_test == y_pred_adv)]
edible_idx = next(i for i in correct if y_test.loc[i]==0)
poison_idx = next(i for i in correct if y_test.loc[i]==1)

fp_e = shap.force_plot(explainer.expected_value, shap_values[X_shap.index.get_indexer([edible_idx])[0]], X_test.loc[edible_idx])
shap.save_html(f'{FIG_DIR}/force_edible.html', fp_e)

fp_p = shap.force_plot(explainer.expected_value, shap_values[X_shap.index.get_indexer([poison_idx])[0]], X_test.loc[poison_idx])
shap.save_html(f'{FIG_DIR}/force_poison.html', fp_p)

In [13]:
# Clean-Up & Memory Hygiene
del df, X_shap, shap_values
gc.collect()
proc = psutil.Process(os.getpid())
print(f"Peak memory usage: {proc.memory_info().rss/1024**2:.2f} MB")

NameError: name 'X_shap' is not defined

## 9. Memory Management and Model Persistence
Cleaning up memory usage and saving model artifacts for reproducibility and future use.

In [14]:
# Reproducibility Helpers
with open(f'{MODEL_DIR}/best_params.json','w') as f:
    json.dump(best_params, f, indent=4)

!pip freeze | tail -20
!zip -r artifacts.zip {FIG_DIR} {MODEL_DIR}

NameError: name 'best_params' is not defined

In [22]:
# Notebook Summary
top_feats = [feat_names[i] for i in idx[-5:]]
print(f"- Baseline ROC-AUC: {auc_base:.4f}")
print(f"- Advanced ROC-AUC: {auc_adv:.4f}")
print(f"- Top 5 features: {', '.join(top_feats)}")
print("- Next steps: ensemble methods, external validation.")

- Baseline ROC-AUC: 0.9416
- Advanced ROC-AUC: 1.0000
- Top 5 features: gill-spacing, stem-color, gill-color, cap-surface, gill-attachment
- Next steps: ensemble methods, external validation.


## 10. Project Summary and Key Results
Summary of model performance, key findings, and recommendations for future improvements.

In [16]:
# Create a detailed classification view for each mushroom
# Combine test data with predictions and probabilities
results_df = X_test.copy()
results_df['actual_class'] = y_test.map({0: 'edible', 1: 'poisonous'})
results_df['predicted_class'] = model.predict(X_test)
results_df['predicted_class'] = results_df['predicted_class'].map({0: 'edible', 1: 'poisonous'})
results_df['prediction_probability'] = model.predict_proba(X_test)[:, 1]
results_df['correct_prediction'] = (y_test == model.predict(X_test))

# Add confidence level
results_df['confidence'] = results_df['prediction_probability'].apply(
    lambda x: 'high' if x > 0.8 or x < 0.2 else 'medium' if x > 0.6 or x < 0.4 else 'low'
)

# Reorder columns for better readability
cols_order = ['actual_class', 'predicted_class', 'correct_prediction', 'prediction_probability', 'confidence'] + list(X_test.columns)
results_df = results_df[cols_order]

print(f"Total mushrooms classified: {len(results_df)}")
print(f"Correct predictions: {results_df['correct_prediction'].sum()}")
print(f"Accuracy: {results_df['correct_prediction'].mean():.4f}")
print("\nFirst 10 classifications:")
display(results_df.head(10))

# Show some interesting cases
print("\n=== High confidence correct predictions ===")
high_conf_correct = results_df[(results_df['confidence'] == 'high') & (results_df['correct_prediction'] == True)]
display(high_conf_correct.head(5))

print("\n=== Low confidence or incorrect predictions ===")
low_conf_or_wrong = results_df[(results_df['confidence'] == 'low') | (results_df['correct_prediction'] == False)]
if len(low_conf_or_wrong) > 0:
    display(low_conf_or_wrong.head(5))
else:
    print("No low confidence or incorrect predictions found!")

# Save detailed results
results_df.to_csv(f'{DATA_DIR}/mushroom_classification_results.csv', index=True)
print(f"\nDetailed results saved to {DATA_DIR}/mushroom_classification_results.csv")

Total mushrooms classified: 12214
Correct predictions: 12214
Accuracy: 1.0000

First 10 classifications:


Unnamed: 0,actual_class,predicted_class,correct_prediction,prediction_probability,confidence,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
49474,edible,edible,True,0.000392,high,5.19,x,d,y,t,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,u
22798,poisonous,poisonous,True,0.997514,high,6.84,s,t,y,f,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,a
60027,poisonous,poisonous,True,0.994316,high,10.44,o,unknown,n,f,...,unknown,g,n,unknown,unknown,f,f,unknown,d,s
35232,edible,edible,True,0.000815,high,3.9,x,unknown,n,f,...,b,unknown,w,unknown,w,t,unknown,unknown,g,a
42968,edible,edible,True,0.000412,high,10.76,p,s,n,t,...,s,unknown,w,unknown,unknown,t,l,unknown,m,a
50840,poisonous,poisonous,True,0.999997,high,10.82,x,unknown,n,f,...,c,unknown,n,unknown,unknown,f,f,p,d,a
13595,poisonous,poisonous,True,0.999935,high,1.43,x,g,y,f,...,unknown,s,n,unknown,unknown,f,f,unknown,h,a
47407,poisonous,poisonous,True,0.999176,high,3.67,b,unknown,n,f,...,unknown,unknown,g,unknown,w,f,f,k,g,a
5465,edible,edible,True,5.5e-05,high,7.96,f,unknown,n,f,...,b,unknown,u,unknown,unknown,f,f,unknown,g,w
3106,edible,edible,True,0.000501,high,17.34,f,y,w,f,...,s,unknown,n,unknown,unknown,t,m,unknown,m,u



=== High confidence correct predictions ===


Unnamed: 0,actual_class,predicted_class,correct_prediction,prediction_probability,confidence,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
49474,edible,edible,True,0.000392,high,5.19,x,d,y,t,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,u
22798,poisonous,poisonous,True,0.997514,high,6.84,s,t,y,f,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,a
60027,poisonous,poisonous,True,0.994316,high,10.44,o,unknown,n,f,...,unknown,g,n,unknown,unknown,f,f,unknown,d,s
35232,edible,edible,True,0.000815,high,3.9,x,unknown,n,f,...,b,unknown,w,unknown,w,t,unknown,unknown,g,a
42968,edible,edible,True,0.000412,high,10.76,p,s,n,t,...,s,unknown,w,unknown,unknown,t,l,unknown,m,a



=== Low confidence or incorrect predictions ===


Unnamed: 0,actual_class,predicted_class,correct_prediction,prediction_probability,confidence,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
58808,poisonous,poisonous,True,0.528509,low,26.41,x,d,n,f,...,unknown,unknown,n,unknown,unknown,f,f,unknown,d,a



Detailed results saved to ./data/mushroom_classification_results.csv


## 11. Detailed Classification Results and Analysis
Individual mushroom classification results with confidence levels and detailed analysis of model predictions.

In [20]:
# Load the saved CatBoost model
from catboost import CatBoostClassifier
import os

# Load model from saved file
model_path = os.path.join(MODEL_DIR, 'catboost_model.cbm')
if os.path.exists(model_path):
    loaded_model = CatBoostClassifier()
    loaded_model.load_model(model_path)
    print(f"Model loaded successfully from {model_path}")
else:
    print(f"Model file not found at {model_path}")
    # Use the existing model if file doesn't exist
    loaded_model = model

# Create a detailed classification view for each mushroom
# Combine test data with predictions and probabilities
results_df = X_test.copy()
results_df['actual_class'] = y_test.map({0: 'edible', 1: 'poisonous'})
results_df['predicted_class'] = loaded_model.predict(X_test)
results_df['predicted_class'] = results_df['predicted_class'].map({0: 'edible', 1: 'poisonous'})
results_df['prediction_probability'] = loaded_model.predict_proba(X_test)[:, 1]
results_df['correct_prediction'] = (y_test == loaded_model.predict(X_test))

# Add confidence level
results_df['confidence'] = results_df['prediction_probability'].apply(
    lambda x: 'high' if x > 0.8 or x < 0.2 else 'medium' if x > 0.6 or x < 0.4 else 'low'
)

# Reorder columns for better readability
cols_order = ['actual_class', 'predicted_class', 'correct_prediction', 'prediction_probability', 'confidence'] + list(X_test.columns)
results_df = results_df[cols_order]

print(f"Total mushrooms classified: {len(results_df)}")
print(f"Correct predictions: {results_df['correct_prediction'].sum()}")
print(f"Accuracy: {results_df['correct_prediction'].mean():.4f}")
print("\nFirst 10 classifications:")
display(results_df.head(10))

# Show some interesting cases
print("\n=== High confidence correct predictions ===")
high_conf_correct = results_df[(results_df['confidence'] == 'high') & (results_df['correct_prediction'] == True)]
display(high_conf_correct.head(5))

print("\n=== Low confidence or incorrect predictions ===")
low_conf_or_wrong = results_df[(results_df['confidence'] == 'low') | (results_df['correct_prediction'] == False)]
if len(low_conf_or_wrong) > 0:
    display(low_conf_or_wrong.head(5))
else:
    print("No low confidence or incorrect predictions found!")

# Save detailed results
results_df.to_csv(f'{DATA_DIR}/mushroom_classification_results.csv', index=True)
print(f"\nDetailed results saved to {DATA_DIR}/mushroom_classification_results.csv")

# Additional analysis with the loaded model
print("\n=== Model Information ===")
print(f"Model type: {type(loaded_model).__name__}")
try:
    print(f"Number of features: {len(loaded_model.feature_names_)}")
    print(f"Number of trees: {loaded_model.tree_count_}")
except AttributeError:
    print(f"Number of features: {X_test.shape[1]}")
    print("Model tree count information not available for loaded model")

# Show confidence distribution
print("\n=== Confidence Distribution ===")
conf_dist = results_df['confidence'].value_counts()
print(conf_dist)
print(f"\nConfidence percentages:")
for conf_level in ['high', 'medium', 'low']:
    if conf_level in conf_dist:
        pct = (conf_dist[conf_level] / len(results_df)) * 100
        print(f"{conf_level.capitalize()}: {conf_dist[conf_level]} ({pct:.1f}%)")

# Show prediction distribution
print("\n=== Prediction Distribution ===")
pred_dist = results_df['predicted_class'].value_counts()
print(pred_dist)

# Show accuracy by confidence level
print("\n=== Accuracy by Confidence Level ===")
for conf_level in ['high', 'medium', 'low']:
    subset = results_df[results_df['confidence'] == conf_level]
    if len(subset) > 0:
        accuracy = subset['correct_prediction'].mean()
        print(f"{conf_level.capitalize()} confidence: {accuracy:.4f} ({len(subset)} samples)")

Model loaded successfully from ./models/catboost_model.cbm
Total mushrooms classified: 12214
Correct predictions: 12214
Accuracy: 1.0000

First 10 classifications:


Unnamed: 0,actual_class,predicted_class,correct_prediction,prediction_probability,confidence,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
49474,edible,edible,True,0.000392,high,5.19,x,d,y,t,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,u
22798,poisonous,poisonous,True,0.997514,high,6.84,s,t,y,f,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,a
60027,poisonous,poisonous,True,0.994316,high,10.44,o,unknown,n,f,...,unknown,g,n,unknown,unknown,f,f,unknown,d,s
35232,edible,edible,True,0.000815,high,3.9,x,unknown,n,f,...,b,unknown,w,unknown,w,t,unknown,unknown,g,a
42968,edible,edible,True,0.000412,high,10.76,p,s,n,t,...,s,unknown,w,unknown,unknown,t,l,unknown,m,a
50840,poisonous,poisonous,True,0.999997,high,10.82,x,unknown,n,f,...,c,unknown,n,unknown,unknown,f,f,p,d,a
13595,poisonous,poisonous,True,0.999935,high,1.43,x,g,y,f,...,unknown,s,n,unknown,unknown,f,f,unknown,h,a
47407,poisonous,poisonous,True,0.999176,high,3.67,b,unknown,n,f,...,unknown,unknown,g,unknown,w,f,f,k,g,a
5465,edible,edible,True,5.5e-05,high,7.96,f,unknown,n,f,...,b,unknown,u,unknown,unknown,f,f,unknown,g,w
3106,edible,edible,True,0.000501,high,17.34,f,y,w,f,...,s,unknown,n,unknown,unknown,t,m,unknown,m,u



=== High confidence correct predictions ===


Unnamed: 0,actual_class,predicted_class,correct_prediction,prediction_probability,confidence,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
49474,edible,edible,True,0.000392,high,5.19,x,d,y,t,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,u
22798,poisonous,poisonous,True,0.997514,high,6.84,s,t,y,f,...,unknown,unknown,y,unknown,unknown,f,f,unknown,d,a
60027,poisonous,poisonous,True,0.994316,high,10.44,o,unknown,n,f,...,unknown,g,n,unknown,unknown,f,f,unknown,d,s
35232,edible,edible,True,0.000815,high,3.9,x,unknown,n,f,...,b,unknown,w,unknown,w,t,unknown,unknown,g,a
42968,edible,edible,True,0.000412,high,10.76,p,s,n,t,...,s,unknown,w,unknown,unknown,t,l,unknown,m,a



=== Low confidence or incorrect predictions ===


Unnamed: 0,actual_class,predicted_class,correct_prediction,prediction_probability,confidence,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
58808,poisonous,poisonous,True,0.528509,low,26.41,x,d,n,f,...,unknown,unknown,n,unknown,unknown,f,f,unknown,d,a



Detailed results saved to ./data/mushroom_classification_results.csv

=== Model Information ===
Model type: CatBoostClassifier
Number of features: 20
Number of trees: 500

=== Confidence Distribution ===
confidence
high      12208
medium        5
low           1
Name: count, dtype: int64

Confidence percentages:
High: 12208 (100.0%)
Medium: 5 (0.0%)
Low: 1 (0.0%)

=== Prediction Distribution ===
predicted_class
poisonous    6778
edible       5436
Name: count, dtype: int64

=== Accuracy by Confidence Level ===
High confidence: 1.0000 (12208 samples)
Medium confidence: 1.0000 (5 samples)
Low confidence: 1.0000 (1 samples)
