In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import polars as pl

from ethos.constants import PROJECT_ROOT
from ethos.constants import SpecialToken as ST
from ethos.inference.constants import Reason, Task
from ethos.metrics import compute_and_print_metrics, preprocess_inference_results

# Root of all inference outputs; per-task results are read from `results_dir / <Task>`.
results_dir = PROJECT_ROOT / "results"

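## ICU Admission results

Each run directory for the ICU-admission task is scored separately. A token counts as a positive outcome when it is either ICU admission or death. The runs are then ranked by fitted AUC.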

In [None]:
def process_icu_admission_results(input_dir, **kwargs) -> pl.DataFrame:
    """Preprocess one run of ICU-admission inference results.

    Builds boolean outcome expressions for the generated (``actual``) and
    ground-truth (``expected``) tokens, where a token is positive if it is
    either ``ST.DEATH`` or ``ST.ICU_ADMISSION``. These expressions are handed to
    ``preprocess_inference_results``.

    Args:
        input_dir: Directory of inference results to load (in this notebook,
            one run subdirectory of the ICU-admission results folder).
        **kwargs: Forwarded unchanged to ``preprocess_inference_results``.

    Returns:
        The frame produced by ``preprocess_inference_results``. The caller
        below reads its ``expected``, ``actual`` and ``counts`` columns.
    """
    # Death is grouped with ICU admission as a positive outcome.
    positive_tokens = [ST.DEATH, ST.ICU_ADMISSION]
    # Passed as `filter_ambiguous`; how it is applied is up to
    # preprocess_inference_results. Presumably runs that did not stop on a
    # generated token are treated as ambiguous — TODO confirm.
    stopped_on_token = pl.col("stop_reason") == Reason.GOT_TOKEN

    return preprocess_inference_results(
        input_dir,
        actual_expr=pl.col("actual").is_in(positive_tokens),
        expected_expr=pl.col("expected").is_in(positive_tokens),
        filter_ambiguous=stopped_on_token,
        **kwargs,
    )

In [None]:
# Evaluate every run directory for the ICU-admission task and collect one
# summary record per run.
# - Path.iterdir() order is filesystem-dependent, so directories are sorted to
#   keep the plot order and the final table reproducible.
# - Stray files (e.g. .DS_Store) are skipped rather than treated as runs.
icu_admission_dir = results_dir / Task.ICU_ADMISSION
run_dirs = sorted(path for path in icu_admission_dir.iterdir() if path.is_dir())
# Fail loudly here instead of with an obscure "column not found" from .sort().
assert run_dirs, f"No result directories found in {icu_admission_dir}"

icu_admission_rows = []
for input_dir in run_dirs:
    df = process_icu_admission_results(input_dir)
    # Mean of the per-row `counts` column, reported as rep_num in the title.
    rep_num = df["counts"].mean()
    res = compute_and_print_metrics(
        df["expected"],
        df["actual"],
        f"ICU Admission (rep_num={rep_num:.2f})\n{input_dir.name}",
    )
    icu_admission_rows.append(
        {
            "name": input_dir.name,
            "auc": res["fitted_auc"],
            "rep_num": rep_num,
        }
    )
    # compute_and_print_metrics presumably draws a figure; show it per run.
    plt.show()

# Build the summary table once from the collected records, best AUC first.
icu_admission_results = pl.DataFrame(icu_admission_rows).sort("auc", descending=True)

In [None]:
# Per-run summary (name, fitted AUC, rep_num), ranked by AUC, highest first.
icu_admission_results