In [63]:
import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)
import seaborn as sns
import matplotlib.pyplot as plt
import joblib
import yaml
from ipywidgets import interact, widgets, VBox
from typing import List, Tuple, Dict, Any

In [2]:
import os
from pathlib import Path

# Go up one directory level from the notebook's location
project_root = Path().resolve().parent  # Navigate to the parent directory
os.chdir(project_root)  # Set this as the working directory

# print("Current working directory set to:", os.getcwd())

In [23]:
with open("config/paths.yaml", "r") as file:
    paths = yaml.load(file, Loader=yaml.FullLoader)

# Load the data
X = pd.read_csv(paths["train"]["final_features"])
y = pd.read_csv(paths["train"]["labels"])
X_val = pd.read_csv(paths["val"]["final_features"])
y_val = pd.read_csv(paths["val"]["labels"])
# test_X = pd.read_csv(paths["test"]["final_features"])

clf_mask = pd.read_csv(paths["clf_mask_file"], index_col="label")
regr_mask = pd.read_csv(paths["regr_mask_file"], index_col="label")

# Path to stored classification models
svc_models_path =  os.path.normpath(os.path.join(paths["models"]["clf"], "SVC"))
rf_models_path = os.path.normpath(
    os.path.join(paths["models"]["clf"], "RandomForestClassifier")
)

# Path to stored regression models
kr_models_path =  os.path.normpath(os.path.join(paths["models"]["regr"], "KernelRidge"))

print(f"SVC: {svc_models_path}")
print(f"RandomForest: {rf_models_path}")
print(f"KernelRidge: {kr_models_path}")

SVC: models\classification\SVC
RandomForest: models\classification\RandomForestClassifier
KernelRidge: models\regression\KernelRidge


In [24]:
def load_models(dir: str) -> Dict:
    models = {}
    for file in os.listdir(dir):
        if file.endswith(".joblib"):
            model = joblib.load(os.path.join(dir, file))
            label = file.split(".")[0]
            models[label] = model
    return models

svc_models = load_models(svc_models_path)
rf_models = load_models(rf_models_path)
kr_models = load_models(kr_models_path)

print(f"SVC models: {len(svc_models)}")
print(f"RandomForest models: {len(rf_models)}")
print(f"KernelRidge models: {len(kr_models)}")

SVC models: 11
RandomForest models: 11
KernelRidge models: 4


In [42]:
from src.helper import standardize_data


def predict(X_train, X_val, feature_mask, model):
    X_train_selected = X_train.loc[:, feature_mask]
    X_val_selected = X_val.loc[:, feature_mask]

    norm_X_train, norm_X_val = standardize_data(X_train_selected, X_val_selected)

    y_pred = model.predict(norm_X_val)

    return y_pred

In [43]:
def compute_clf_metrics(
    X: pd.DataFrame,
    X_val: pd.DataFrame,
    y_val: pd.DataFrame,
    mask: pd.DataFrame,
    models: Dict,
    out_file: str,
) -> Dict:

    metrics = {}

    X = X.drop(columns=["pid"])
    X_val = X_val.drop(columns=["pid"])

    for label, model in models.items():
        y_true = y_val[label]
        label_mask = mask.loc[label, :]
        y_pred = predict(X, X_val, label_mask, model)

        # Compute metrics
        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, zero_division=0)
        rec = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)

        # Compute confusion matrix
        cm = confusion_matrix(y_true, y_pred)

        metrics[label] = {
            "Accuracy": acc,
            "Precision": prec,
            "Recall": rec,
            "F1 Score": f1,
            "Confusion": cm,
        }

    # Convert metrics to a DataFrame for summary
    metrics_df = pd.DataFrame(metrics)
    metrics_df.to_csv(out_file)

    return metrics

In [46]:
svc_metrics_file = paths["evaluation"]["svc"]
rf_metrics_file = paths["evaluation"]["random_forest"]

svc_metrics = compute_clf_metrics(X, X_val, y_val, clf_mask, svc_models, svc_metrics_file)
rf_metrics = compute_clf_metrics(X, X_val, y_val, clf_mask, rf_models, rf_metrics_file)

print(f"SVC: {svc_metrics}")
print(f"RandomForest: {rf_metrics}")

SVC: {'LABEL_Alkalinephos': {'Accuracy': 0.7141352987628323, 'Precision': np.float64(0.44017632241813603), 'Recall': np.float64(0.7801339285714286), 'F1 Score': np.float64(0.5628019323671497), 'Confusion': array([[2014,  889],
       [ 197,  699]])}, 'LABEL_AST': {'Accuracy': 0.7078178468017899, 'Precision': np.float64(0.4416918429003021), 'Recall': np.float64(0.797164667393675), 'F1 Score': np.float64(0.5684292379471229), 'Confusion': array([[1958,  924],
       [ 186,  731]])}, 'LABEL_BaseExcess': {'Accuracy': 0.8373256120031587, 'Precision': np.float64(0.6487053883834849), 'Recall': np.float64(0.8887823585810163), 'F1 Score': np.float64(0.75), 'Confusion': array([[2254,  502],
       [ 116,  927]])}, 'LABEL_Bilirubin_direct': {'Accuracy': 0.9010265859436694, 'Precision': np.float64(0.25101214574898784), 'Recall': np.float64(0.9538461538461539), 'F1 Score': np.float64(0.3974358974358974), 'Confusion': array([[3299,  370],
       [   6,  124]])}, 'LABEL_Bilirubin_total': {'Accuracy': 

In [65]:
# Interactive plot function
def plot_metrics(label_name):

    label_metrics = rf_metrics[label_name]
    confusion_matrix = label_metrics["Confusion"]

    # Create a plot with two sections: heatmap and text
    fig, ax = plt.subplots(2, 2, figsize=(12, 6), gridspec_kw={"width_ratios": [1, 1]})

    # Plot metrics for Random Forest
    sns.heatmap(
        confusion_matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["False", "True"],
        yticklabels=["False", "True"],
        ax=ax[0,0],
    )
    ax[0,0].set_title(f"Confusion Matrix")
    ax[0,0].set_xlabel("Predicted")
    ax[0,0].set_ylabel("True")

    # Extract other metrics (accuracy, precision, recall, F1-score)
    accuracy = label_metrics["Accuracy"]
    precision = label_metrics["Precision"]
    recall = label_metrics["Recall"]
    f1 = label_metrics["F1 Score"]

    # Prepare metrics text
    metrics_text = (
        f"Accuracy:  {accuracy:.2f}\n"
        f"Precision: {precision:.2f}\n"
        f"Recall:    {recall:.2f}\n"
        f"F1 Score:  {f1:.2f}"
    )

    # Add metrics text to the second subplot
    ax[0,1].axis("off")  # Turn off axis for text display
    ax[0,1].text(
        0.5,
        0.5,
        metrics_text,
        fontsize=14,
        ha="center",
        va="center",
        bbox=dict(boxstyle="round", facecolor="white", edgecolor="black"),
    )

    # Plot metrics for SVC
    label_metrics = svc_metrics[label_name]
    confusion_matrix = label_metrics["Confusion"]

    sns.heatmap(
        confusion_matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["False", "True"],
        yticklabels=["False", "True"],
        ax=ax[1,0],
    )
    ax[1,0].set_title(f"Confusion Matrix")
    ax[1,0].set_xlabel("Predicted")
    ax[1,0].set_ylabel("True")

    # Extract other metrics (accuracy, precision, recall, F1-score)
    accuracy = label_metrics["Accuracy"]
    precision = label_metrics["Precision"]
    recall = label_metrics["Recall"]
    f1 = label_metrics["F1 Score"]

    # Prepare metrics text
    metrics_text = (
        f"Accuracy:  {accuracy:.2f}\n"
        f"Precision: {precision:.2f}\n"
        f"Recall:    {recall:.2f}\n"
        f"F1 Score:  {f1:.2f}"
    )

    # Add metrics text to the second subplot
    ax[1,1].axis("off")  # Turn off axis for text display
    ax[1,1].text(
        0.5,
        0.5,
        metrics_text,
        fontsize=14,
        ha="center",
        va="center",
        bbox=dict(boxstyle="round", facecolor="white", edgecolor="black"),
    )

    # Add row labels for "Random Forest" and "SVC"
    ax[0, 0].text(
        -0.5,
        1.25,
        "Random Forest",
        transform=ax[0, 0].transAxes,
        fontsize=14,
        fontweight="bold",
    )
    ax[1, 0].text(
        -0.5, 1.25, "SVC", transform=ax[1, 0].transAxes, fontsize=14, fontweight="bold"
    )

    fig.suptitle(f"Classification metrics comparison of {label_name}")

    # Display the plot
    plt.tight_layout()
    plt.show()


# Create interactive selector at the bottom
label_selector = widgets.Dropdown(
    options=list(svc_metrics.keys()),
    description="Label:",
    style={"description_width": "initial"},
)
out = widgets.interactive_output(plot_metrics, {"label_name": label_selector})
ui = VBox([out, label_selector])
display(ui)

VBox(children=(Output(), Dropdown(description='Label:', options=('LABEL_Alkalinephos', 'LABEL_AST', 'LABEL_Basâ€¦