In [1]:
import os
import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix

In [2]:
# Run-once guard: the kernel keeps `first_run` between executions, so re-running
# this cell in the same session must not move the working directory up again.
# Checking globals() avoids assigning to `_`, which would shadow IPython's
# last-output cache.
if "first_run" not in globals():
    first_run = True
    # os.path.dirname is separator-aware; splitting on "/" is a silent no-op
    # on Windows paths.
    os.chdir(os.path.dirname(os.getcwd()))

# Load preprocessed training and validation data

In [3]:
# Each artifact is unpacked into a (features, labels) pair.
# NOTE: joblib.load unpickles its input -- only load files this project produced.
train_artifact = "../data/train/preprocessed/train_features_labels.joblib.gz"
validation_artifact = "../data/train/preprocessed/validation_features_labels.joblib.gz"

X_train, y_train = joblib.load(train_artifact)
X_validation, y_validation = joblib.load(validation_artifact)

# Train and save the baseline model

In [4]:
# Fixed seed: a random forest draws bootstrap samples and feature subsets at
# random, so without it the saved model (and every metric derived from it)
# changes on each Restart & Run All.
RANDOM_STATE = 42

baseline = RandomForestClassifier(random_state=RANDOM_STATE).fit(X_train, y_train)

# Make sure the artifact directory exists before writing into it.
os.makedirs("../ml_artifacts", exist_ok=True)
# Trailing ";" suppresses the returned file list without shadowing IPython's `_`.
joblib.dump(baseline, "../ml_artifacts/baseline_model.joblib.gz");

In [5]:
prediction = baseline.predict_proba(X_validation)
# predict_proba columns follow baseline.classes_ (sorted); assuming 0/1 labels,
# column 1 is the positive-class probability -- TODO confirm label encoding.
positive_score = prediction[:, 1]

# Build thresholds from integers so each one is the exact decimal (0.15, not
# 0.15000000000000002 as float-step arange produces), and the element count
# cannot drift with rounding: 0.05, 0.10, ..., 0.90 (18 values).
thresholds = np.arange(5, 95, 5) / 100

threshold_perf = pd.DataFrame(
    [
        (
            threshold,
            # labels=[0, 1] pins a 2x2 matrix, so ravel() always yields exactly
            # tn, fp, fn, tp even if one class is absent from both inputs.
            *confusion_matrix(
                y_validation,
                (positive_score > threshold).astype(int),
                labels=[0, 1],
            ).ravel(),
        )
        for threshold in thresholds
    ],
    columns=["threshold", "tn", "fp", "fn", "tp"],
).assign(
    # A 0/0 ratio (e.g. no predicted positives at a high threshold) yields NaN.
    precision=lambda df: df["tp"] / (df["tp"] + df["fp"]),
    recall=lambda df: df["tp"] / (df["tp"] + df["fn"]),
    f1=lambda df: 2
    * (df["precision"] * df["recall"])
    / (df["precision"] + df["recall"]),
)

threshold_perf.to_csv("../ml_artifacts/baseline_model_performance.csv", index=False)

In [6]:
def highlight_max(data, color="yellow"):
    """Build CSS that marks the largest value(s) in ``data``.

    Meant to be passed to ``DataFrame.style.apply``. Every element equal to
    the maximum is highlighted, so ties are all marked.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        A 1-D slice (``apply`` with axis=0 or axis=1) or the whole frame
        (``apply`` with axis=None).
    color : str
        CSS color used for the background.

    Returns
    -------
    list of str or pandas.DataFrame
        One style string per element. Non-maximal cells get ``""``.
    """
    css = "background-color: {}".format(color)

    if data.ndim != 1:
        # Whole-frame call: compare every cell against the global maximum.
        mask = data == data.max().max()
        styles = np.where(mask, css, "")
        return pd.DataFrame(styles, index=data.index, columns=data.columns)

    # 1-D call: one entry per element of the Series.
    peak = data.max()
    return [css if hit else "" for hit in (data == peak)]


# Styler.apply defaults to axis=0, so each metric column gets its own maximum
# highlighted across the thresholds.
metric_columns = ["precision", "recall", "f1"]
threshold_perf.style.apply(highlight_max, subset=metric_columns, color="darkorange")

Unnamed: 0,threshold,tn,fp,fn,tp,precision,recall,f1
0,0.05,13695,487,141,73,0.130357,0.341121,0.18863
1,0.1,13867,315,160,54,0.146341,0.252336,0.185249
2,0.15,13971,211,174,40,0.159363,0.186916,0.172043
3,0.2,14029,153,179,35,0.18617,0.163551,0.174129
4,0.25,14058,124,182,32,0.205128,0.149533,0.172973
5,0.3,14080,102,185,29,0.221374,0.135514,0.168116
6,0.35,14099,83,187,27,0.245455,0.126168,0.166667
7,0.4,14115,67,190,24,0.263736,0.11215,0.157377
8,0.45,14129,53,192,22,0.293333,0.102804,0.152249
9,0.5,14138,44,195,19,0.301587,0.088785,0.137184
