In [None]:
%load_ext autoreload
%autoreload 2

import polars as pl

from ethos.constants import PROJECT_ROOT
from ethos.datasets.base import InferenceDataset
from ethos.inference.constants import Task

# Root of the tokenized dataset; `dump_labels` below reads one sub-directory per
# fold from here (`input_dir / "test"` and `input_dir / "train"`).
input_dir = PROJECT_ROOT / "data/tokenized_datasets/mimic_old_ed"
# Directory that receives one `<task>.parquet` label file per task and that the
# prevalence summary at the end of the notebook reads back.
output_dir = PROJECT_ROOT / "data/ed_task_labels"

In [None]:
import time
from typing import Generator

from tqdm import tqdm


def iter_dataset(dataset: InferenceDataset) -> Generator[dict, None, None]:
    """Lazily yield the second element of every ``(x, y)`` pair in ``dataset``.

    Iteration is wrapped in ``tqdm`` so a progress bar is shown while the
    dataset is consumed. The first element of each pair is discarded.
    """
    yield from (label for _, label in tqdm(dataset))


def retrieve_labels(
    dataset: InferenceDataset, boolean_value_expr: pl.Expr | None = None
) -> pl.DataFrame:
    """Build the label table (``subject_id``, ``time``, ``boolean_value``) for a dataset.

    The dicts yielded by ``iter_dataset(dataset)`` are expected to provide the
    keys ``patient_id``, ``data_idx``, ``expected`` and ``true_token_time``.

    Args:
        dataset: Dataset to iterate; its ``times`` attribute is indexed with
            each row's ``data_idx`` to obtain the row timestamp.
        boolean_value_expr: Optional polars expression that replaces the
            default ``boolean_value`` (which is a copy of ``expected``). It is
            applied after the base columns exist, so it may reference
            ``boolean_value`` and the Duration-typed ``true_token_time``.

    Returns:
        An eager DataFrame with exactly the three columns listed above.
    """

    def timestamp_at(idx):
        # NOTE(review): assumes `dataset.times[idx]` is an integer timestamp in
        # the unit polars uses by default for `pl.Datetime` — confirm.
        return dataset.times[idx]

    base_columns = [
        pl.col("patient_id").alias("subject_id"),
        pl.col("data_idx")
        .map_elements(timestamp_at, return_dtype=pl.Int64)
        .cast(pl.Datetime)
        .alias("time"),
        pl.col("expected").alias("boolean_value"),
        # Cast so that callers can compare against `pl.duration(...)`.
        pl.col("true_token_time").cast(pl.Duration),
    ]

    labels = pl.from_dicts(iter_dataset(dataset)).lazy().with_columns(*base_columns)
    if boolean_value_expr is not None:
        labels = labels.with_columns(boolean_value=boolean_value_expr)

    return labels.select("subject_id", "time", "boolean_value").collect()


def dump_labels(
    output_fp, dataset_cls: type[InferenceDataset], boolean_value_expr: pl.Expr | None = None
) -> None:
    """Compute the labels of one task for every fold and write them to a parquet file.

    For each of the folds ``"test"`` and ``"train"`` the dataset is built from
    ``input_dir / fold``, turned into a label table with ``retrieve_labels`` and
    tagged with a ``fold`` column. The per-fold tables are concatenated and
    written to ``output_fp`` with its suffix set to ``.parquet``.

    Args:
        output_fp: Path-like object supporting ``with_suffix`` (e.g. a
            ``pathlib.Path``); the destination file without extension.
        dataset_cls: Dataset class, instantiated with the fold directory.
        boolean_value_expr: Forwarded unchanged to ``retrieve_labels``.
    """
    parquet_fp = output_fp.with_suffix(".parquet")
    # Create the destination directory before the expensive dataset pass, so a
    # missing directory cannot make the final write fail after all the work.
    parquet_fp.parent.mkdir(parents=True, exist_ok=True)

    processed_datasets = []
    for fold in ("test", "train"):
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments.
        t = time.perf_counter()
        dataset = dataset_cls(input_dir / fold)
        print(f"Time taken to init {fold} dataset: {time.perf_counter() - t:.2f}s")
        processed_datasets.append(
            retrieve_labels(dataset, boolean_value_expr).with_columns(fold=pl.lit(fold))
        )

    pl.concat(processed_datasets).write_parquet(parquet_fp, use_pyarrow=True)

In [None]:
from ethos.datasets import HospitalAdmissionAtTriageDataset

# No custom label expression: `boolean_value` is the dataset's `expected` field as-is.
dump_labels(
    output_fp=output_dir / Task.ED_HOSPITALIZATION,
    dataset_cls=HospitalAdmissionAtTriageDataset,
)

In [None]:
from ethos.datasets import CriticalOutcomeAtTriageDataset

# Positive only when the dataset's own label is true AND the true token occurs
# no later than 12 hours after the prediction point.
critical_outcome_within_12h = pl.col("boolean_value") & (
    pl.col("true_token_time") <= pl.duration(hours=12)
)
dump_labels(
    output_fp=output_dir / Task.ED_CRITICAL_OUTCOME,
    dataset_cls=CriticalOutcomeAtTriageDataset,
    boolean_value_expr=critical_outcome_within_12h,
)

In [None]:
from ethos.datasets import EdReattendenceDataset

# Positive only when the dataset's own label is true AND the true token occurs
# no later than 72 hours after the prediction point.
reattendance_within_72h = pl.col("boolean_value") & (
    pl.col("true_token_time") <= pl.duration(hours=72)
)
dump_labels(
    output_fp=output_dir / Task.ED_REPRESENTATION,
    dataset_cls=EdReattendenceDataset,
    boolean_value_expr=reattendance_within_72h,
)

In [None]:
from ethos.datasets import HospitalMortalityDataset

# Threshold shared with the "composite" task further down in the notebook.
prolonged_stay_cutoff = pl.duration(days=10)

# Positive when the time to the true token is at least the cutoff; the dataset's
# own `expected` value is not used for this task.
is_prolonged_stay = pl.col("true_token_time") >= prolonged_stay_cutoff
dump_labels(
    output_fp=output_dir / "prolonged_stay",
    dataset_cls=HospitalMortalityDataset,
    boolean_value_expr=is_prolonged_stay,
)

In [None]:
from ethos.constants import SpecialToken as ST
from ethos.datasets import HospitalMortalityDataset

# Positive when the dataset's expected value is the DEATH special token.
expected_is_death = pl.col("boolean_value").is_in([ST.DEATH])
dump_labels(
    output_fp=output_dir / Task.HOSPITAL_MORTALITY,
    dataset_cls=HospitalMortalityDataset,
    boolean_value_expr=expected_is_death,
)

In [None]:
from ethos.constants import SpecialToken as ST
from ethos.datasets import ICUAdmissionDataset

# Positive when the dataset's expected value is the ICU_ADMISSION special token.
expected_is_icu_admission = pl.col("boolean_value").is_in([ST.ICU_ADMISSION])
dump_labels(
    output_fp=output_dir / Task.ICU_ADMISSION,
    dataset_cls=ICUAdmissionDataset,
    boolean_value_expr=expected_is_icu_admission,
)

In [None]:
from ethos.constants import SpecialToken as ST
from ethos.datasets import ICUAdmissionDataset

# Composite outcome: the expected token is ICU_ADMISSION or DEATH, or the time to
# the true token reaches `prolonged_stay_cutoff` (defined in the prolonged-stay
# cell above, which must have been run first).
expected_is_icu_or_death = pl.col("boolean_value").is_in([ST.ICU_ADMISSION, ST.DEATH])
stay_reaches_cutoff = pl.col("true_token_time") >= prolonged_stay_cutoff
dump_labels(
    output_fp=output_dir / "composite",
    dataset_cls=ICUAdmissionDataset,
    boolean_value_expr=expected_is_icu_or_death | stay_reaches_cutoff,
)

### Prevalence of Tasks

In [None]:
# Mean of `boolean_value` per task and fold, shown as one row per task and one
# column per fold. Only `*.parquet` files are read, so stray entries in
# `output_dir` (checkpoint folders, hidden files, ...) cannot break the scan;
# sorting the paths keeps the processing order deterministic.
pl.concat(
    [
        (
            pl.scan_parquet(output_fp)
            .group_by("fold")
            .agg(pl.mean("boolean_value"))
            .collect()
            .with_columns(task=pl.lit(output_fp.stem))
        )
        for output_fp in sorted(output_dir.glob("*.parquet"))
    ]
).pivot("fold", index="task", values="boolean_value").sort("task")