In [None]:
%load_ext autoreload
%autoreload 2

import polars as pl

from ethos.constants import PROJECT_ROOT
from ethos.constants import SpecialToken as ST
from ethos.datasets.base import InferenceDataset
from ethos.inference.constants import Task
from ethos.inference.utils import get_dataset_cls

# Root of the tokenized ED dataset; dump_labels reads one sub-directory per fold
# (e.g. "test", "train") from here.
input_dir = PROJECT_ROOT / "data/tokenized_datasets/mimic_ed"
# Destination for the per-task label parquet files; dump_labels creates it if missing.
output_dir = PROJECT_ROOT / "data/ed_task_labels"

In [None]:
import time
from typing import Generator

from tqdm import tqdm


def iter_dataset(dataset: InferenceDataset) -> Generator[dict, None, None]:
    """Lazily yield the label dict of every sample in ``dataset``.

    Each sample is unpacked as a 2-tuple; the first element (presumably the
    model input) is discarded and the second, a dict, is yielded. Iteration is
    wrapped in ``tqdm`` so a progress bar tracks the traversal.
    """
    yield from (label for _inputs, label in tqdm(dataset))


def retrieve_labels(
    dataset: InferenceDataset, boolean_value_expr: pl.Expr | None = None
) -> pl.DataFrame:
    """Collect the labels of ``dataset`` into a ``subject_id``/``time``/``boolean_value`` frame.

    Args:
        dataset: inference dataset whose samples yield label dicts containing at
            least ``patient_id``, ``data_idx``, ``expected`` and ``true_token_time``.
            ``dataset.times`` must be indexable by ``data_idx``.
        boolean_value_expr: optional expression that computes the final label.
            It is evaluated after the base columns exist, so it may reference
            ``expected``, ``boolean_value`` (a raw copy of ``expected``) or
            ``true_token_time`` (already cast to Duration). When omitted,
            ``expected`` cast to Boolean is used.

    Returns:
        Eager DataFrame with exactly the columns ``subject_id``, ``time`` and
        ``boolean_value``.
    """
    samples = pl.from_dicts(iter_dataset(dataset)).lazy()

    base_columns = [
        pl.col("patient_id").alias("subject_id"),
        # data_idx points into dataset.times; the Int64 value found there is
        # reinterpreted with polars' default Datetime time unit.
        pl.col("data_idx")
        .map_elements(lambda idx: dataset.times[idx], return_dtype=pl.Int64)
        .cast(pl.Datetime)
        .alias("time"),
        # Expose the raw target under the final name so boolean_value_expr can
        # build on it; it is overwritten by the label expression below.
        pl.col("expected").alias("boolean_value"),
        pl.col("true_token_time").cast(pl.Duration),
    ]

    if boolean_value_expr is None:
        label_expr = pl.col("expected").cast(pl.Boolean)
    else:
        label_expr = boolean_value_expr

    labelled = samples.with_columns(*base_columns).with_columns(boolean_value=label_expr)
    return labelled.select("subject_id", "time", "boolean_value").collect()


def dump_labels(
    task: Task,
    output_fn: str | None = None,
    boolean_value_expr: pl.Expr | None = None,
    folds: tuple[str, ...] = ("test", "train"),
) -> None:
    """Compute labels for ``task`` on each fold and write them to one parquet file.

    The output is ``output_dir / f"{output_fn}.parquet"``. It holds the
    ``retrieve_labels`` columns plus a ``fold`` column, with the folds stacked
    in the order given. If the file already exists, nothing is recomputed, so
    rerunning the notebook skips finished tasks.

    Args:
        task: inference task; ``get_dataset_cls(task)`` selects the dataset class.
        output_fn: stem of the output file name. Defaults to ``str(task)``.
        boolean_value_expr: optional label expression forwarded to
            ``retrieve_labels``.
        folds: sub-directories of ``input_dir`` to process, in output order.
    """
    if output_fn is None:
        output_fn = str(task)
    output_fp = output_dir / f"{output_fn}.parquet"

    if output_fp.exists():
        print(f"Output file {output_fp} already exists, skipping.")
        return

    processed_datasets = []
    for fold in folds:
        # perf_counter is monotonic and intended for measuring durations;
        # time.time() follows the (adjustable) wall clock.
        t = time.perf_counter()
        dataset = get_dataset_cls(task)(input_dir / fold)
        print(f"Time taken to init {fold} dataset: {time.perf_counter() - t:.2f}s")
        processed_datasets.append(
            retrieve_labels(dataset, boolean_value_expr).with_columns(fold=pl.lit(fold))
        )

    output_dir.mkdir(parents=True, exist_ok=True)
    pl.concat(processed_datasets).write_parquet(output_fp, use_pyarrow=True)

In [None]:
# ED hospitalization: no custom expression, so the label is `expected` cast to Boolean.
dump_labels(task=Task.ED_HOSPITALIZATION)

In [None]:
# ED critical outcome: the observed token is ICU admission or death AND
# true_token_time (presumably the delay until that token) is at most 12 hours.
dump_labels(
    task=Task.ED_CRITICAL_OUTCOME,
    boolean_value_expr=(
        pl.col("expected").is_in([ST.ICU_ADMISSION, ST.DEATH])
        & (pl.col("true_token_time") <= pl.duration(hours=12))
    ),
)

In [None]:
# ED re-presentation: the dataset's raw target (exposed as boolean_value) AND
# true_token_time of at most 72 hours.
dump_labels(
    task=Task.ED_REPRESENTATION,
    boolean_value_expr=(
        pl.col("boolean_value")
        & (pl.col("true_token_time") <= pl.duration(hours=72))
    ),
)

In [None]:
# Prolonged stay, derived from the hospital-mortality dataset: positive when
# true_token_time is at least 10 days. The cutoff is also used by a later cell.
prolonged_stay_cutoff = pl.duration(days=10)
dump_labels(
    task=Task.HOSPITAL_MORTALITY,
    output_fn="prolonged_stay",
    boolean_value_expr=(pl.col("true_token_time") >= prolonged_stay_cutoff),
)

In [None]:
from ethos.constants import SpecialToken as ST

# Hospital mortality: positive when the raw target token equals DEATH.
dump_labels(
    task=Task.HOSPITAL_MORTALITY,
    boolean_value_expr=(pl.col("boolean_value") == ST.DEATH),
)

In [None]:
from ethos.constants import SpecialToken as ST

# ICU admission: positive when the raw target token is ICU_ADMISSION.
dump_labels(
    task=Task.ICU_ADMISSION,
    boolean_value_expr=(pl.col("boolean_value").is_in([ST.ICU_ADMISSION])),
)

In [None]:
from ethos.constants import SpecialToken as ST

# Composite outcome on the ICU-admission dataset: ICU admission or death, OR a
# prolonged stay (requires prolonged_stay_cutoff from the prolonged-stay cell).
dump_labels(
    task=Task.ICU_ADMISSION,
    output_fn="composite",
    boolean_value_expr=(
        pl.col("boolean_value").is_in([ST.ICU_ADMISSION, ST.DEATH])
        | (pl.col("true_token_time") >= prolonged_stay_cutoff)
    ),
)

### Task Prevalence

In [None]:
# Positive-label prevalence (mean of boolean_value) for every label file,
# pivoted to one row per task and one column per fold.
pl.concat(
    [
        (
            pl.scan_parquet(output_fp)
            .group_by("fold")
            .agg(pl.mean("boolean_value"))
            .collect()
            .with_columns(task=pl.lit(output_fp.stem))
        )
        # Only parquet files: any other entry in output_dir (OS metadata,
        # checkpoint directories, temp files) would make scan_parquet fail.
        for output_fp in sorted(output_dir.glob("*.parquet"))
    ]
).pivot(on="fold", index="task", values="boolean_value").sort("task")