In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import altair as alt
import pickle
import time

In [2]:
SCRATCH_DIR = "/scratch/siads696f23_class_root/siads696f23_class/psollars"

# For local dev
SCRATCH_DIR = "./../data"

In [3]:
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
)



In [4]:
DATASET = "model_selection_all_delays_ORD_UA"

models = [
    "logistic_regression",
    "random_forest",
    "svm",
    "gradient_boost",
    "xgboost",
    "catboost",
]

results = pd.DataFrame()

for m in models:
    start_time = time.time()

    with open(f"{SCRATCH_DIR}/{DATASET}_{m}_model.pkl", "rb") as f:
        model = pickle.load(f)

    with open(f"{SCRATCH_DIR}/{DATASET}_{m}_X_test.pkl", "rb") as f:
        X_test = pickle.load(f)

    with open(f"{SCRATCH_DIR}/{DATASET}_{m}_y_test.pkl", "rb") as f:
        y_test = pickle.load(f)

    # evaluate these metrics and add each of them as a column
    y_pred = model.predict(X_test)

    end_time = time.time()

    print(f"{m} model evaluation took: {(end_time - start_time):.4f} seconds")

    test_metrics = {
        "accuracy": [accuracy_score(y_test, y_pred)],
        "precision": [precision_score(y_test, y_pred)],
        "recall": [recall_score(y_test, y_pred)],
        "f1": [f1_score(y_test, y_pred)],
        "roc_auc": [roc_auc_score(y_test, y_pred, average="macro", multi_class="ovr")],
    }

    # Add cross validation scores
    cv_results = {
        key: value for key, value in model.cv_results_.items() if key != "params"
    }

    all_metrics = {"model": m, **test_metrics, **cv_results}

    results = pd.concat([results, pd.DataFrame(all_metrics)])

results

logistic_regression model evaluation took: 0.1518 seconds
random_forest model evaluation took: 5.1299 seconds


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


svm model evaluation took: 737.6583 seconds
gradient_boost model evaluation took: 0.1277 seconds
xgboost model evaluation took: 0.1006 seconds
catboost model evaluation took: 0.4115 seconds


Unnamed: 0,model,accuracy,precision,recall,f1,roc_auc,mean_fit_time,std_fit_time,mean_score_time,std_score_time,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score
0,logistic_regression,0.601263,0.339987,0.594295,0.43253,0.598976,305.535439,2.894408,0.029695,0.003447,0.596246,0.600611,0.609063,0.605169,0.60974,0.604166,0.005127,1
0,random_forest,0.761136,0.556951,0.321915,0.408005,0.616971,81.82745,1.60677,2.150829,0.403742,0.415065,0.636571,0.935571,0.93847,0.93792,0.772719,0.213419,1
0,svm,0.573424,0.31978,0.592894,0.415473,0.579815,2338.458683,233.249205,347.279762,19.023874,0.582261,0.577708,0.579413,0.576404,0.57825,0.578807,0.00198,1
0,gradient_boost,0.780021,0.682286,0.261429,0.378015,0.609804,7.634593,0.079829,0.405706,0.008689,0.226339,0.565866,0.959464,0.9574,0.957811,0.733376,0.295574,1
0,xgboost,0.784579,0.680058,0.297466,0.413891,0.624694,3.839633,0.04552,0.154759,0.013912,0.268204,0.591766,0.959026,0.956063,0.958151,0.746642,0.278061,1
0,catboost,0.786859,0.692489,0.299376,0.41803,0.626852,101.621017,0.121989,0.668524,0.044319,0.259853,0.589035,0.963483,0.96158,0.961145,0.747019,0.283208,1


In [5]:
results

Unnamed: 0,model,accuracy,precision,recall,f1,roc_auc,mean_fit_time,std_fit_time,mean_score_time,std_score_time,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,mean_test_score,std_test_score,rank_test_score
0,logistic_regression,0.601263,0.339987,0.594295,0.43253,0.598976,305.535439,2.894408,0.029695,0.003447,0.596246,0.600611,0.609063,0.605169,0.60974,0.604166,0.005127,1
0,random_forest,0.761136,0.556951,0.321915,0.408005,0.616971,81.82745,1.60677,2.150829,0.403742,0.415065,0.636571,0.935571,0.93847,0.93792,0.772719,0.213419,1
0,svm,0.573424,0.31978,0.592894,0.415473,0.579815,2338.458683,233.249205,347.279762,19.023874,0.582261,0.577708,0.579413,0.576404,0.57825,0.578807,0.00198,1
0,gradient_boost,0.780021,0.682286,0.261429,0.378015,0.609804,7.634593,0.079829,0.405706,0.008689,0.226339,0.565866,0.959464,0.9574,0.957811,0.733376,0.295574,1
0,xgboost,0.784579,0.680058,0.297466,0.413891,0.624694,3.839633,0.04552,0.154759,0.013912,0.268204,0.591766,0.959026,0.956063,0.958151,0.746642,0.278061,1
0,catboost,0.786859,0.692489,0.299376,0.41803,0.626852,101.621017,0.121989,0.668524,0.044319,0.259853,0.589035,0.963483,0.96158,0.961145,0.747019,0.283208,1


In [8]:
with open(f"{SCRATCH_DIR}/model_evaluation_all_delays_ORD_UA_results.pkl", "wb") as f:
    pickle.dump(results, f)

In [11]:
data_melted = results.melt(
    id_vars=["model"],
    value_vars=["mean_test_score", "accuracy", "precision", "recall", "f1", "roc_auc"],
    var_name="metric",
    value_name="value",
)

base = alt.Chart(data_melted).encode(
    y=alt.Y("value:Q", axis=alt.Axis(title="Value"), scale=alt.Scale(domain=[0, 1])),
    x=alt.X("metric:N", axis=alt.Axis(title="Metric"), sort=results.columns.values),
    # color="metric:N",
)

bar = (
    base.mark_bar(width=25)
    .encode(
        tooltip=["model", "metric", "value"],
        color=alt.Color("metric:N", legend=None),
    )
    .properties(width=180)
)

text = base.mark_text(
    align="center",
    color="white",
    size=10,
    dy=10,
).encode(text=alt.Text("value:Q", format=".2f"))

chart = (
    (bar + text)
    .facet(column=alt.Column("model:N", sort=models))
    .properties(
        # title="Model Evaluation Metrics",
        # width=200
    )
    # .configure_axis(labelFontSize=12, titleFontSize=14)
    .configure_mark(
        # opacity=0.2,
        # color='red'
    )
)

chart

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


In [7]:
data_melted

Unnamed: 0,model,metric,value
0,logistic_regression,mean_test_score,0.604166
1,random_forest,mean_test_score,0.772719
2,svm,mean_test_score,0.578807
3,gradient_boost,mean_test_score,0.733376
4,xgboost,mean_test_score,0.746642
5,catboost,mean_test_score,0.747019
6,logistic_regression,accuracy,0.601263
7,random_forest,accuracy,0.761136
8,svm,accuracy,0.573424
9,gradient_boost,accuracy,0.780021
