In [None]:
%load_ext autoreload
%autoreload 2

import polars as pl
from matplotlib import pyplot as plt

from ethos.constants import PROJECT_ROOT
from ethos.constants import SpecialToken as ST
from ethos.inference.constants import Task
from ethos.metrics import compute_and_print_metrics, preprocess_inference_results

# Root of saved inference outputs; each task writes its runs to results/<task>/.
result_dir = PROJECT_ROOT.joinpath("results")

## Hospitalization Predicted at Triage

In [None]:
# Score every inference run saved under the ED hospitalization task directory.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order; sorting
# keeps the printed metrics and plots in the same order on every re-run.
for result_fp in sorted((result_dir / Task.ED_HOSPITALIZATION).iterdir()):
    df = preprocess_inference_results(
        result_fp,
        # A run counts as positive when it emits the ADMISSION token with
        # token_time <= 3 days.
        # NOTE(review): presumably token_time is the elapsed time from the
        # triage prediction point — confirm in ethos.metrics.
        actual_expr=(pl.col("actual") == ST.ADMISSION)
        & (pl.col("token_time") <= pl.duration(days=3)),
    )
    compute_and_print_metrics(
        df["expected"],
        df["actual"],
        f"Prediction of Hospitalization at Triage (rep_num={df['counts'].mean():.2f})\n{result_fp.name}",
    )
    plt.show()

## Critical Outcome (ICU Admission or Death) Within 12h Predicted at Triage

In [None]:
# Critical outcome = ICU admission or death. The same token set is used for
# both the prediction (`actual`) and the ground truth (`expected`).
critical_tokens = [ST.ICU_ADMISSION, ST.DEATH]

# Sorted so runs are reported in a deterministic order (iterdir() order is
# arbitrary and filesystem-dependent).
for result_fp in sorted((result_dir / Task.ED_CRITICAL_OUTCOME).iterdir()):
    df = preprocess_inference_results(
        result_fp,
        actual_expr=pl.col("actual").is_in(critical_tokens),
        # Ground truth additionally requires true_token_time <= 12 hours;
        # no time filter is applied to the prediction side here.
        expected_expr=pl.col("expected").is_in(critical_tokens)
        & (pl.col("true_token_time") <= pl.duration(hours=12)),
    )
    compute_and_print_metrics(
        df["expected"],
        df["actual"],
        "Prediction of Critical Outcome within 12h at Triage "
        f"(rep_num={df['counts'].mean():.2f})\n{result_fp.name}",
    )
    plt.show()

## Emergency Department Re-presentation (Return Visit) Within 72h Predicted at Triage

In [None]:
# Score every inference run saved under the ED re-presentation task directory.
# Sorted so results print in a stable order (iterdir() order is arbitrary).
for result_fp in sorted((result_dir / Task.ED_REPRESENTATION).iterdir()):
    df = preprocess_inference_results(
        result_fp,
        # Predicted positive: the generated token is an ED_ADMISSION.
        actual_expr=pl.col("actual").is_in([ST.ED_ADMISSION]),
        # `expected` is combined with `&` directly, so it is used as a boolean
        # flag here (presumably "an ED return occurred" — confirm in
        # ethos.metrics), restricted to true_token_time <= 72 hours.
        expected_expr=pl.col("expected") & (pl.col("true_token_time") <= pl.duration(hours=72)),
    )
    compute_and_print_metrics(
        df["expected"],
        df["actual"],
        "Prediction of ED representation within 72h "
        f"(rep_num={df['counts'].mean():.2f})\n{result_fp.name}",
    )
    plt.show()