In [None]:
from mlflow.tracking import MlflowClient
from os.path import join
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import os

# Locations of the MLflow tracking store and of the per-window parquet files.
# The defaults are the original machine-specific paths; set CRRT_MLFLOW_PATH /
# CRRT_DATA_PATH to run the notebook elsewhere without editing this cell.
mlflow_path = os.environ.get("CRRT_MLFLOW_PATH", "/home/davina/Private/repos/CRRT/mlruns")
data_path = os.environ.get("CRRT_DATA_PATH", "/home/davina/Private/crrt-data")
# Client used by the run queries in the cells below.
client = MlflowClient(mlflow_path)

In [None]:
from sklearn.metrics import (
    roc_auc_score,
    brier_score_loss,
    accuracy_score,
    f1_score,
    average_precision_score,
    recall_score,
    precision_score,
    confusion_matrix,
)

def _hard_labels(pred_probs, decision_thresh):
    """Turn probabilities into 0/1 predictions: 1 where ``pred_probs >= decision_thresh``."""
    return (pred_probs >= decision_thresh).astype(int)


def _confusion_cell(row, col):
    """Return a metric function that yields cell ``[row, col]`` of the confusion matrix.

    ``labels=[0, 1]`` pins the matrix to 2x2. Without it, a subset in which
    only one class occurs (in both ground truth and predictions) produces a
    1x1 matrix, so ``[1, 1]`` raises IndexError and ``[0, 0]`` returns the
    count of whichever class happens to be present.
    """

    def metric(gt, pred_probs, decision_thresh):
        matrix = confusion_matrix(
            gt, _hard_labels(pred_probs, decision_thresh), labels=[0, 1]
        )
        return matrix[row, col]

    return metric


# Name -> metric function. Every function takes (gt, pred_probs, decision_thresh).
# The ranking/calibration metrics ignore the threshold; the rest score the
# thresholded 0/1 predictions.
METRIC_MAP = {
    "auroc": lambda gt, pred_probs, decision_thresh: roc_auc_score(gt, pred_probs),
    "ap": lambda gt, pred_probs, decision_thresh: average_precision_score(
        gt, pred_probs
    ),
    "brier": lambda gt, pred_probs, decision_thresh: brier_score_loss(gt, pred_probs),
    "accuracy": lambda gt, pred_probs, decision_thresh: accuracy_score(
        gt, _hard_labels(pred_probs, decision_thresh)
    ),
    "f1": lambda gt, pred_probs, decision_thresh: f1_score(
        gt, _hard_labels(pred_probs, decision_thresh)
    ),
    "recall": lambda gt, pred_probs, decision_thresh: recall_score(
        gt, _hard_labels(pred_probs, decision_thresh)
    ),
    # Specificity is recall of the negative class.
    "specificity": lambda gt, pred_probs, decision_thresh: recall_score(
        gt, _hard_labels(pred_probs, decision_thresh), pos_label=0
    ),
    "precision": lambda gt, pred_probs, decision_thresh: precision_score(
        gt, _hard_labels(pred_probs, decision_thresh)
    ),
    # The full confusion matrix is not reported as one entry; its four cells
    # are exposed individually (rows = true class, columns = predicted class).
    "TN": _confusion_cell(0, 0),
    "FN": _confusion_cell(1, 0),
    "TP": _confusion_cell(1, 1),
    "FP": _confusion_cell(0, 1),
}

In [None]:
# Base MLflow run name; the run queries below match it exactly and with the
# " // eval best" suffix appended.
experiment_name = "serialize-test"

In [None]:
# The run name is matched exactly, so the "// tune trial" and "// eval best"
# runs are excluded from this query.
static_learning_id = client.get_experiment_by_name("static_learning").experiment_id
window_name_filter = f'tags.mlflow.runName="{experiment_name}"'
window_runs = client.search_runs(
    experiment_ids=static_learning_id,
    filter_string=window_name_filter,
)

In [None]:
# Runs named "<experiment_name> // eval best" in the same experiment.
best_name_filter = f'tags.mlflow.runName="{experiment_name} // eval best"'
best = client.search_runs(
    experiment_ids=client.get_experiment_by_name("static_learning").experiment_id,
    filter_string=best_name_filter,
)

In [None]:
# All candidate runs: the plain-named runs followed by the "eval best" runs.
runs = [*window_runs, *best]

In [None]:
# [run.info.artifact_uri for run in runs]
# [run.data.tags["slide_window_by"] for run in runs]
# Show the combined run list as this cell's output; the commented lines above
# are narrower views of the same list.
runs

In [None]:
from datetime import datetime, timezone, timedelta
# Keep only the runs that carry a "slide_window_by" tag and that started within
# the lookback period.
lookback = {"days": 5, "hours": 18}
# Compute the cutoff once, in UTC. start_time is treated as epoch milliseconds
# (hence the /1000), and timezone-aware datetimes keep the comparison exact
# across local DST changes, which naive local datetimes do not.
earliest_start = datetime.now(timezone.utc) - timedelta(**lookback)
each_window_results = [
    run for run in runs
    if run.data.tags.get("slide_window_by") is not None
    and datetime.fromtimestamp(run.info.start_time / 1000, tz=timezone.utc) >= earliest_start
]
# Display the window offsets that survived the filter.
[run.data.tags["slide_window_by"] for run in each_window_results]

In [None]:
# For every selected run, load the parquet file of its window and the predicted
# probabilities stored with the run. The four lists stay aligned by position.
# Earlier variants read xgb pickles instead, e.g.
#   /home/davina/Private/repos/CRRT/predict_probas/xgb_val__predict_probas.pkl
#   <artifact_uri>/predict_probas/xgb_test__predict_probas.pkl
#   <artifact_uri>/xgb_test__predict_probas.pkl
shapes = []
labels = []
predict_probas = []
slides = []
for run in each_window_results:
    i = run.data.tags["slide_window_by"]
    slides.append(i)
    # Offset 0 has its own file name without the "+<offset>" parts.
    file = (
        f"df_[startdate+{i}-7d,startdate+{i}].parquet"
        if int(i)
        else "df_[startdate-7d,startdate].parquet"
    )
    f_df = pd.read_parquet(join(data_path, file))

    shapes.append(f_df.shape)
    labels.append(f_df["recommend_crrt"])

    probas_file = join(
        run.info.artifact_uri, "predict_probas/lgb_test__predict_probas.pkl"
    )
    predict_probas.append(pd.read_pickle(probas_file))

In [None]:
# Summary statistics of a few hand-picked features, one describe() table per
# sliding window.
top_features = ["CT-INTEM_skew", "CT-EXTEM_skew", "ABSOLUTE NUCLEATED RBC COUNT_skew", "ALPHA-ANGLE-EX_len"]
features_describe = {}
for run in each_window_results:
    i = run.data.tags["slide_window_by"]
    # This cell deliberately does not append to `slides`: that list is filled
    # by the loading cell, and appending again here duplicated its entries on
    # every execution.
    if int(i):
        file = f"df_[startdate+{i}-7d,startdate+{i}].parquet"
    else:
        file = "df_[startdate-7d,startdate].parquet"
    f_df = pd.read_parquet(join(data_path, file))
    # Key by the integer offset so windows sort numerically; as strings,
    # "10" would sort before "2".
    features_describe[int(i)] = f_df[top_features].describe()
# Outer index level = window offset, inner level = describe() statistic.
df = pd.concat(features_describe, axis=0).sort_index(level=0)
# Long format: after reset_index the two index levels become the columns
# "level_0" (offset) and "level_1" (statistic).
melted = df.melt(ignore_index=False, var_name="Features").reset_index()


In [None]:
# Alternative layout, one grid per feature:
# for feature in df.columns:
#     print(feature)
#     g = sns.FacetGrid(
#         melted[melted["Features"] == feature],
#         col="level_1", col_wrap=4, sharey=False
#     )
#     g.map(sns.lineplot, "level_0", "value")

# One panel per describe() statistic, one line per feature. `hue` is declared on
# the grid so the grid subsets the data per feature and keeps colours
# consistent across panels. Passing the full-length "Features" column through
# g.map's keyword arguments only worked through implicit index alignment.
g = sns.FacetGrid(melted, col="level_1", hue="Features", col_wrap=4, sharey=False)
g.map(sns.lineplot, "level_0", "value")
g.add_legend()

In [None]:
from matplotlib import pyplot as plt
from matplotlib import image as mpimg

# Display the feature-importance image stored with the first selected run.
first_run_artifacts = each_window_results[0].info.artifact_uri
file = join(
    first_run_artifacts,
    "img_artifacts",
    "xgb_test__feature_importance.png",
)
img = mpimg.imread(file)
plt.imshow(img)
plt.show()

In [None]:
import numpy as np
# Score every window on the same rows: the ids that have predictions in the
# window whose dataframe has the fewest rows.
smallest_dataset = np.argmin(np.array(shapes)[:, 0])
reference_ids = predict_probas[smallest_dataset].index

metrics = {}
for label, predict_proba, slide in zip(labels, predict_probas, slides):
    # Use one shared index for both inputs. The metric functions compare them
    # position by position, so selecting labels and probabilities with two
    # separately computed intersections could pair values from different ids
    # whenever the two indexes differ in order or membership.
    # NOTE(review): assumes predict_proba is a Series indexed by row id, as the
    # original `predict_proba[ids]` lookup implies.
    common_ids = reference_ids.intersection(label.index).intersection(predict_proba.index)
    # Integer keys make the final sort numeric; as strings "10" sorts before "2".
    metrics[int(slide)] = {
        metric_name: metric_fn(label.loc[common_ids], predict_proba.loc[common_ids], 0.5)
        for metric_name, metric_fn in METRIC_MAP.items()
    }
equal_sized_df = pd.DataFrame.from_dict(metrics, orient="index").rename_axis(index="Slide").sort_index(axis=0)
equal_sized_df

In [None]:
# g = sns.lineplot(data=metrics_over_windows)
# g = sns.relplot(data=metrics_over_windows, kind="line")
# g.legend(loc='center left', bbox_to_anchor=(1.25, 0.5), ncol=1)
# plt.show()
# One panel per metric (independent y-axes), plotted against the slide offset.
equal_sized_long = equal_sized_df.melt(
    ignore_index=False, var_name="Metrics"
).reset_index()
g = sns.FacetGrid(equal_sized_long, col="Metrics", col_wrap=4, sharey=False)
g.map(sns.lineplot, "Slide", "value")

# No filter, visualize metrics across windows

In [None]:
# Metrics logged on each selected run, one row per window offset.
# The offset tag is a string, so it is converted to int before becoming the
# index; otherwise sort_index orders "10" before "2".
metrics_over_windows = (
    pd.DataFrame(
        [run.data.metrics for run in each_window_results],
        index=[int(run.data.tags["slide_window_by"]) for run in each_window_results],
    )
    .rename_axis(index="Slide")
    .sort_index(axis=0)
)
# metrics_over_windows.columns = pd.MultiIndex.from_product([["Metrics"], metrics_over_windows.columns])
metrics_over_windows

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# g = sns.lineplot(data=metrics_over_windows)
# g = sns.relplot(data=metrics_over_windows, kind="line")
# g.legend(loc='center left', bbox_to_anchor=(1.25, 0.5), ncol=1)
# plt.show()
# One panel per logged metric (independent y-axes), plotted against the slide offset.
logged_metrics_long = metrics_over_windows.melt(
    ignore_index=False, var_name="Metrics"
).reset_index()
g = sns.FacetGrid(logged_metrics_long, col="Metrics", col_wrap=4, sharey=False)
g.map(sns.lineplot, "Slide", "value")

In [None]:
# Same plot, restricted to the logged metrics whose name contains "auroc".
auroc_mask = metrics_over_windows.columns.str.contains("auroc")
auroc_cols = metrics_over_windows.columns[auroc_mask]
auroc_long = (
    metrics_over_windows[auroc_cols]
    .melt(ignore_index=False, var_name="Metrics")
    .reset_index()
)
g = sns.FacetGrid(auroc_long, col="Metrics", col_wrap=4, sharey=False)
g.map(sns.lineplot, "Slide", "value")

In [None]:
# Cleanup (disabled): uncomment the loop below to delete every run collected in `runs`.
# for run in runs:
#     client.delete_run(run.info.run_id)