In [None]:
%load_ext autoreload
%autoreload 2

from datetime import timedelta

import polars as pl

from ethos.constants import PROJECT_ROOT
from ethos.constants import SpecialToken as ST

# Root directory of the tokenized datasets; the dataset cells below load splits from here.
data_dir = PROJECT_ROOT / "data/tokenized_datasets"

In [None]:
from ethos.datasets import HospitalMortalityDataset

# Hospital-mortality task dataset over the test split.
hosp_test_dir = data_dir / "mimic_old_ed/test"
d_hosp = HospitalMortalityDataset(hosp_test_dir)
d_hosp

In [None]:
from ethos.datasets import ICUAdmissionDataset

# ICU-admission task dataset over the same test split.
icu_test_dir = data_dir / "mimic_old_ed/test"
d_icu = ICUAdmissionDataset(icu_test_dir)
d_icu

### Find patients for generating trajectories

In [None]:
def get_static_data(dataset, idx):
    """Return age and latest demographic codes of the patient behind sample ``idx``.

    Age is the time between the patient's ``MEDS_BIRTH`` time and the sample's start
    time (both microseconds), in years of 365.25 days, capped at 99. Marital status,
    gender and race are the last recorded code of each static field.
    """
    pid = dataset.patient_id_at_idx[idx].item()
    start_us = dataset.times[idx].item()
    record = dataset.static_data[pid]

    birth_us = record["MEDS_BIRTH"]["time"][0]
    elapsed = timedelta(microseconds=start_us - birth_us)

    latest_codes = {
        out_key: record[field]["code"][-1]
        for out_key, field in (("marital", "MARITAL"), ("gender", "GENDER"), ("race", "RACE"))
    }
    return {"age": min(elapsed.days / 365.25, 99), **latest_codes}


def dataset_to_df(dataset, n):
    """Collect element ``[1]`` of the first ``n`` samples into a row-oriented DataFrame.

    Element ``[1]`` is presumably a per-sample dict of fields (TODO confirm against the
    dataset classes). A leading ``index`` row-number column is added, and every column
    whose name contains "time" is cast to ``pl.Duration``.
    """
    sample_rows = (dataset[i][1] for i in range(n))
    frame = pl.DataFrame(sample_rows, infer_schema_length=1_000_000, orient="row")
    time_columns = pl.col("^.*time.*$")
    return frame.with_row_index().with_columns(time_columns.cast(pl.Duration))


# Materialise every sample of each dataset. Each frame is sized by its own dataset, so
# the ICU frame is neither truncated nor indexed past its end when the lengths differ.
n = len(d_hosp)
df_hosp = dataset_to_df(d_hosp, n)
df_icu = dataset_to_df(d_icu, len(d_icu))
df_hosp

In [None]:
columns = ["index", "expected", "true_token_dist", "true_token_time", "patient_id", "data_idx"]

# Keep samples whose true_token_dist (presumably the token distance to the true outcome
# token) lies in [500, 1500], restricted to the columns used downstream.
token_dist_expr = pl.col("true_token_dist").is_between(500, 1500)
df_hosp, df_icu = (
    frame.filter(token_dist_expr).select(columns) for frame in (df_hosp, df_icu)
)

In [None]:
seed = 42
n_sample = 10


def sample_up_to(df, n, seed):
    """Sample ``n`` rows of ``df`` with ``seed``, or all rows when ``df`` has fewer.

    ``DataFrame.sample`` without replacement raises when ``n`` exceeds the height, which
    would abort the cell whenever an outcome is rare after the token-distance filter.
    """
    return df.sample(min(n, df.height), seed=seed)


trajectories_patients = pl.concat(
    [
        sample_up_to(df_hosp.filter(expected=ST.DISCHARGE), n_sample, seed),
        sample_up_to(df_hosp.filter(expected=ST.DEATH), n_sample, seed),
    ]
)

# ICU-admission samples, skipping rows whose "index" was already picked above.
icu_candidates = df_icu.filter(
    ~pl.col("index").is_in(trajectories_patients["index"]), expected=ST.ICU_ADMISSION
)

trajectories_patients = (
    pl.concat([trajectories_patients, sample_up_to(icu_candidates, n_sample, seed)])
    .with_columns(
        # NOTE(review): ICU rows are also looked up in d_hosp; this assumes both datasets
        # share the same underlying timeline arrays — confirm.
        static_data=pl.col("data_idx").map_elements(
            lambda idx: get_static_data(d_hosp, idx), return_dtype=pl.Struct
        ),
    )
    .unnest("static_data")
    .sort("index")
)

trajectories_patients

In [None]:
from ethos.constants import PROJECT_ROOT

results_dir = PROJECT_ROOT / "results" / "trajectories"
# parents=True: the "results" directory itself may not exist yet.
results_dir.mkdir(parents=True, exist_ok=True)

# true_token_time is left out of the exported patient list.
trajectories_patients.select(pl.exclude("true_token_time")).write_csv(
    results_dir / "trajectories_patients.csv"
)

### Generate trajectories

In [None]:
from enum import StrEnum

import altair as alt
import numpy as np
import polars as pl

from ethos.constants import PROJECT_ROOT
from ethos.constants import SpecialToken as ST
from ethos.inference.constants import Reason, Task
from ethos.metrics import preprocess_inference_results

results_root = PROJECT_ROOT / "results"
results_hosp = results_root / Task.HOSPITAL_MORTALITY_SINGLE
results_icu = results_root / Task.ICU_ADMISSION_SINGLE
results_trajectories = results_root / "trajectories"
# Patient list exported to this directory by the patient-selection section.
patients_info = pl.read_csv(results_trajectories / "trajectories_patients.csv")

In [None]:
# Re-imported so this section also runs on a fresh kernel, like its other imports.
from datetime import timedelta

prolonged_stay_cutoff = timedelta(days=10)
prolonged_stay_cutoff2 = timedelta(days=15)

# Time between the first row of the trajectory and the current row.
# NOTE(review): this reads true_token_time as a countdown to the true outcome — confirm.
elapsed_since_start = pl.col("true_token_time").first() - pl.col("true_token_time")

first_cutoff_condition = elapsed_since_start < prolonged_stay_cutoff
first_cutoff = pl.col("token_time") >= prolonged_stay_cutoff - elapsed_since_start
second_cutoff = pl.col("token_time") >= prolonged_stay_cutoff2 - elapsed_since_start

# Prolonged-stay event: while fewer than 10 days have elapsed, elapsed time plus the
# predicted token_time must reach 10 days; afterwards it must reach 15 days.
prolonged_stay_expr = pl.when(first_cutoff_condition).then(first_cutoff).otherwise(second_cutoff)

# Per task: which generated ("actual") and true ("expected") outcomes count as positive.
config_by_task = {
    ST.DEATH: {
        "actual_expr": pl.col("actual").is_in([ST.DEATH]),
        "expected_expr": pl.col("expected").is_in([ST.DEATH]),
    },
    ST.ICU_ADMISSION: {
        "actual_expr": pl.col("actual").is_in([ST.ICU_ADMISSION]),
        "expected_expr": pl.col("expected").is_in([ST.ICU_ADMISSION]),
    },
    "COMPOSITE": {
        "actual_expr": (
            pl.col("actual").is_in([ST.ICU_ADMISSION, ST.DEATH]) | prolonged_stay_expr
        ),
        "expected_expr": (
            pl.col("expected").is_in([ST.ICU_ADMISSION, ST.DEATH]) | prolonged_stay_expr
        ),
    },
    "PROLONGED_STAY": {
        "actual_expr": prolonged_stay_expr,
        "expected_expr": prolonged_stay_expr,
    },
}


class TrajectoryType(StrEnum):
    """How inference results are laid out along the x-axis of the trajectory charts.

    TOKEN: one point per result row; x is the row ``index``.
    TIME: rows collapsed to one per distinct ``true_token_time``; x is the elapsed
        ``true_token_time_duration``.
    TOKEN_GROUPED: consecutive tokens merged by ``group_tokens_by_piece_of_info``;
        x is the row ``index``.
    """

    TOKEN = "TOKEN"
    TIME = "TIME"
    TOKEN_GROUPED = "TOKEN_GROUPED"


def group_tokens_by_piece_of_info(df: pl.DataFrame):
    """Merge consecutive tokens that together encode one piece of information.

    The group size depends on the token that starts the group:
    - a token named in ``full_tokens`` starts a group of that many tokens;
    - a token starting with a ``prefix_tokens`` prefix starts a group of that prefix's
      size (the first matching prefix wins);
    - any other token forms a group of its own.

    A group that would run past the end of the frame is truncated to the remaining
    rows instead of raising. Each group becomes one row:
    - ``start_token`` values are space-joined;
    - ``true_token_time`` keeps the group's first value;
    - every other column keeps its last value.

    Row order of the result is not guaranteed (no ``maintain_order``); callers sort.
    """
    full_tokens = {
        "ED_REGISTRATION": 2,
        "ED_ACUITY": 2,
        "HOSPITAL_ADMISSION": 2,
        "ICU_ADMISSION": 2,
        "SOFA": 2,
    }

    prefix_tokens = {
        "ICD//PCS": 7,
        "ICD//CM": 3,
        "VITAL": 2,
        "LAB": 2,
        "ATC": 3,
    }

    tokens = df["start_token"].to_list()
    groups = []
    pos = 0
    group_id = 0
    while pos < len(tokens):
        token = tokens[pos]
        size = full_tokens.get(token)
        if size is None:
            size = next(
                (count for prefix, count in prefix_tokens.items() if token.startswith(prefix)),
                1,
            )
        # Clamp so a group starting near the end cannot consume more rows than remain.
        size = min(size, len(tokens) - pos)
        groups.extend([group_id] * size)
        pos += size
        group_id += 1

    return df.group_by(by=pl.Series(groups).set_sorted()).agg(
        pl.col("start_token").str.join(" "),
        pl.exclude("start_token", "true_token_time").last(),
        pl.first("true_token_time"),
    )


def format_duration(duration):
    """Render a microsecond duration as whole days and hours, e.g. ``"2d 5h"``.

    Zero-valued parts are dropped; anything under one hour renders as ``"0h"``.
    """
    total_seconds = duration // 1_000_000
    days, seconds_into_day = divmod(total_seconds, 86400)
    hours = seconds_into_day // 3600

    pieces = [f"{value}{unit}" for value, unit in ((days, "d"), (hours, "h")) if value > 0]
    return " ".join(pieces) if pieces else "0h"


# Vega expressions for Altair ``labelExpr`` that render microsecond durations on time
# axes. In the d / d_h / d_h_m variants, each day/hour/minute piece is omitted when zero.
time_label_expr_one_liner_d = (
    "(floor(datum.value / 1e6 / 86400) > 0 ? floor(datum.value / 1e6 / 86400) + 'd ' : '')"
)
# Whole days, always shown. The former trailing ")" had no matching "(" and made the
# expression invalid.
time_label_expr_one_liner_d0 = "floor(datum.value / 1e6 / 86400) + 'd'"

time_label_expr_one_liner_d_h = (
    time_label_expr_one_liner_d
    + " +  (floor(datum.value / 1e6 % 86400 / 3600) > 0 ? floor(datum.value / 1e6 % 86400 / 3600) + 'h ' : '')"
)

time_label_expr_one_liner_d_h_m = (
    time_label_expr_one_liner_d_h
    + " + (floor(datum.value / 1e6 % 3600 / 60) > 0 ? floor(datum.value / 1e6 % 3600 / 60) + 'min ' : '')"
)
# Risk colour encoding on the "actual" column: light grey at 0, orange at 0.5, red at 1.
_risk_color_scale = alt.Scale(domain=[0, 0.5, 1], range=["lightgrey", "orange", "red"])
color_gradient = alt.Color("actual", legend=None, scale=_risk_color_scale)


def process_df_for_display(df):
    """Add the derived columns that the Altair trajectory charts read.

    Adds:
    - a row ``index``;
    - ``actual_diff``;
    - a 2-decimal risk label ``actual_display`` (e.g. ``".45"``, ``"1.0"``);
    - ``start_token_display``: ``start_token`` with every "//" replaced by "-",
      cut to 15 characters;
    - a constant ``zero`` baseline;
    - ``true_token_time_duration``: elapsed time since the first row, as UInt64, with a
      ``format_duration`` label.

    It also casts ``true_token_time`` and ``token_time`` to UInt64. Finally it adds a
    1.96-sigma normal-approximation band (``actual_minus_error``/``actual_plus_error``)
    clipped to [0, 1]; ``counts`` is presumably the number of generated samples per row
    (TODO confirm).
    """
    z_95 = 1.96
    return (
        df.with_row_index("index")
        .with_columns(
            actual_diff=pl.col("actual").diff(),
            actual_display=pl.when(pl.col("actual") == 1.0)
            .then(pl.lit("1.0"))
            .otherwise(
                pl.col("actual")
                .round(2)
                .cast(pl.String)
                .str.slice(1)
                .str.replace(r"(\.\d)$", r"${1}" + "0")
            ),
            start_token_display=pl.col("start_token").str.replace_all("//", "-", literal=True),
            zero=0,
            true_token_time_duration=(
                pl.col("true_token_time").first() - pl.col("true_token_time")
            ).cast(pl.UInt64),
            true_token_time=pl.col("true_token_time").cast(pl.UInt64),
            token_time=pl.col("token_time").cast(pl.UInt64),
            prob_error=((pl.col("actual") * (1 - pl.col("actual"))) / pl.col("counts")).sqrt(),
        )
        .with_columns(
            start_token_display=pl.when(pl.col("start_token_display").str.len_chars() > 15)
            .then(pl.col("start_token_display").str.slice(0, 15) + "...")
            .otherwise(pl.col("start_token_display")),
            true_token_time_duration_display=pl.col("true_token_time_duration").map_elements(
                format_duration, return_dtype=pl.String
            ),
            actual_plus_error=pl.min_horizontal(
                pl.col("actual") + pl.col("prob_error") * z_95, pl.lit(1)
            ),
            # Clamp at 0 like the upper bound is clamped at 1: a probability band
            # below zero would be drawn outside the [0, 1] y-scale.
            actual_minus_error=pl.max_horizontal(
                pl.col("actual") - pl.col("prob_error") * z_95, pl.lit(0)
            ),
        )
    )


def inference_results_for_token_of_interest(
    dir_name, token_of_interest=ST.ICU_ADMISSION, trajectory_type=TrajectoryType.TOKEN
):
    df = preprocess_inference_results(
        dir_name,
        actual_expr=config_by_task[token_of_interest]["actual_expr"],
        expected_expr=config_by_task[token_of_interest]["expected_expr"],
        filter_ambiguous=(
            ~pl.col("actual").is_in([ST.TIMELINE_END])
            & pl.col("stop_reason").is_in([Reason.GOT_TOKEN])
        ),
        additional_columns=["start_token"],
        warn_on_dropped=False,
    ).sort("data_idx")

    match trajectory_type:
        case TrajectoryType.TOKEN:
            return process_df_for_display(df)
        case TrajectoryType.TIME:
            df = df.group_by("true_token_time").agg(pl.all().last()).sort("data_idx")
        case TrajectoryType.TOKEN_GROUPED:
            df = group_tokens_by_piece_of_info(df).sort("data_idx")

    return process_df_for_display(df)


def create_token_view(df, brushes, width, trajectory_type):
    spacing = 10
    width_token_view = (width - (spacing * len(brushes) - 1)) / len(brushes)
    match trajectory_type:
        case TrajectoryType.TOKEN:
            x_col = "index"
        case TrajectoryType.TIME:
            x_col = "true_token_time_duration"
        case TrajectoryType.TOKEN_GROUPED:
            x_col = "index"
        case _:
            raise ValueError(trajectory_type)
    result = None
    for brush in brushes:
        base_token_view = alt.Chart(df).encode(
            x=alt.X(
                x_col,
                axis=alt.Axis(grid=False, title=None, format="d", labels=False, ticks=False),
                scale=alt.Scale(domain=brush),
            ),
            y=alt.Y("zero:Q", scale=alt.Scale(domain=[-0.2, 0.2])),
        )

        line = base_token_view.mark_line(size=1, color="#5a5255")
        squares = base_token_view.mark_square(size=400, opacity=1).encode(color=color_gradient)

        actual_text = base_token_view.mark_text(dy=5, fontSize=10, color="black").encode(
            text=alt.Text("actual_display")
        )

        start_token_text = (
            base_token_view.mark_text(fontSize=10)
            .transform_calculate(y_adjusted="datum.zero + (datum.index % 2 == 0 ? -0.15 : 0.15)")
            .encode(
                y=alt.Y(
                    "y_adjusted:Q",
                    axis=alt.Axis(grid=False, title=None, labels=False, ticks=False),
                    scale=alt.Scale(domain=[-0.2, 0.2]),
                ),
                text="start_token_display:N",
                color=alt.value("black"),
            )
        )

        token_view = (line + squares + actual_text + start_token_text).properties(
            width=width_token_view, height=60
        )
        result = token_view if result is None else alt.hconcat(result, token_view, spacing=spacing)

    return result


def create_main_chart(
    dfs,
    labels,
    width,
    trajectory_type,
    return_focused=False,
    token_view_window_num=1,
    token_view_ranges=None,
):
    full = None
    focused = None

    match trajectory_type:
        case TrajectoryType.TOKEN:
            x_col = "index"
            x_title = "Token Number"
            x_format = "d"
            labelExpr_full = labelExpr_focused = "datum.value"
        case TrajectoryType.TIME:
            x_col = "true_token_time_duration"
            x_title = "Time"
            x_format = ""
            labelExpr_full = time_label_expr_one_liner_d_h
            labelExpr_focused = time_label_expr_one_liner_d_h_m
        case TrajectoryType.TOKEN_GROUPED:
            x_col = "index"
            x_title = "Token Grouped"
            x_format = "d"
            labelExpr_full = labelExpr_focused = "datum.value"
        case _:
            raise ValueError(trajectory_type)

    if token_view_ranges is not None and len(token_view_ranges) != token_view_window_num:
        raise ValueError(
            f"Length of ranges expected {token_view_window_num}, got {len(token_view_ranges)}"
        )

    if token_view_ranges is None:
        token_view_ranges = []
        step = len(dfs[-1]) // (token_view_window_num + 1)
        for i in range(token_view_window_num):
            r = step * (i + 1) - 2.5, step * (i + 1) + 2.5
            token_view_ranges.append(r)

    x_min, x_max = dfs[0]["index"].min(), dfs[0]["index"].max()

    for idx, (df, label) in enumerate(zip(dfs, labels)):
        color_scale = alt.Scale(
            domain=["COMPOSITE", "DEATH", "ICU ADMISSION", "PROLONGED STAY"],
            range=["purple", "#bf5b17", "#386cb0", "#666666"],
        )
        brushes = [
            alt.selection_interval(encodings=["x"], value={"x": r}) for r in token_view_ranges
        ]
        df = df.with_columns(Risk=pl.lit(label))

        error_band = (
            alt.Chart(df)
            .mark_area(opacity=0.3)
            .encode(
                x=alt.X(
                    x_col,
                    title=x_title,
                    axis=alt.Axis(grid=False),
                    scale=alt.Scale(domain=[x_min, x_max]),
                ),
                y=alt.Y(
                    "actual_minus_error",
                    title=None,
                    axis=alt.Axis(grid=False),
                    scale=alt.Scale(domain=[0.0, 1.0]),
                ),
                y2=alt.Y2("actual_plus_error"),
                color=alt.Color("Risk:N", scale=color_scale),
            )
        )

        full_view = (
            alt.Chart(df, width=width, height=150)
            .mark_line(strokeWidth=1.3)
            .encode(
                x=alt.X(
                    x_col,
                    title=x_title,
                    axis=alt.Axis(grid=False, labelExpr=labelExpr_full),
                    scale=alt.Scale(domain=[x_min, x_max]),
                ),
                y=alt.Y(
                    "actual",
                    title=None,
                    axis=alt.Axis(grid=False),
                    scale=alt.Scale(domain=[0.0, 1.0]),
                ),
                color=alt.Color(
                    "Risk:N",
                    scale=color_scale,
                    legend=alt.Legend(orient="none", legendX=5, legendY=5, title=None, padding=5),
                ),
            )
        )
        full_view = (error_band + full_view).add_params(*brushes)

        if return_focused:
            focused_view = (
                alt.Chart(df, width=width, height=200)
                .mark_line()
                .encode(
                    x=alt.X(
                        x_col,
                        title=None,
                        axis=alt.Axis(format=x_format, grid=False, labelExpr=labelExpr_focused),
                        scale=alt.Scale(domain=brushes[0]),
                    ),
                    y=alt.Y(
                        "actual",
                        title=None,
                        axis=alt.Axis(grid=False),
                        scale=alt.Scale(domain=[0.0, 1.0]),
                    ),
                    color=alt.Color("Risk:N", title="Risk Score", legend=None),
                )
            )
            focused = focused_view if focused is None else focused + focused_view
        full = full_view if full is None else full + full_view

    if return_focused:
        return focused & full, brushes

    return full, brushes


def sample_df_and_transform_for_rect_drawing(
    df, indices, offset, m1_offset, p1_offset, block_width, shift
):
    """Keep the rows whose ``index`` is in ``indices`` and add ARES rectangle corners.

    x spans from ``n_index = index + offset * block_width - shift`` to
    ``n_index_p1 = n_index + block_width``.
    y spans from ``m1 = zero + m1_offset`` to ``p1 = zero + p1_offset``.
    """
    shifted_index = pl.col("index") + offset * block_width
    baseline = pl.col("zero")
    # The row-index column's values are replaced below; only its position is kept.
    picked = df.filter(pl.col("index").is_in(indices)).with_row_index("n_index")
    return picked.with_columns(
        n_index=shifted_index - shift,
        n_index_p1=shifted_index + block_width - shift,
        m1=baseline + m1_offset,
        p1=baseline + p1_offset,
    )


def create_ares_block(df, x_col, x2_col, y_col, y2_col, n_samples, scale):
    """Draw rounded risk rectangles spanning ``[x_col, x2_col]`` by ``[y_col, y2_col]``.

    The x-scale runs from 0 to ``scale`` and the y-scale from -2.0 to 1.5; both axes are
    fully hidden. ``n_samples`` is accepted but not used here.
    """
    x_encoding = alt.X(
        f"{x_col}:Q",
        scale=alt.Scale(domain=[0, scale]),
        axis=alt.Axis(
            grid=False, title=None, orient="top", labels=False, ticks=False, domain=False
        ),
    )
    y_encoding = alt.Y(
        f"{y_col}:Q",
        scale=alt.Scale(domain=[-2.0, 1.5]),
        axis=alt.Axis(grid=False, title=None, labels=False, ticks=False, domain=False),
    )
    rects = alt.Chart(df).mark_rect(opacity=1, cornerRadius=10)
    return rects.encode(
        x=x_encoding,
        x2=f"{x2_col}:Q",
        y=y_encoding,
        y2=f"{y2_col}:Q",
        color=color_gradient,
    )


def create_timeline(df, tokens_num):
    """Draw the ARES timeline: an arrow over ``[0, tokens_num]``, a dot per row of ``df``
    at its ``index``, and that row's ``true_token_time_duration_display`` label below."""

    def x_scale():
        return alt.Scale(domain=[0, tokens_num])

    def hidden_axis():
        return alt.Axis(grid=False, title=None, labels=False, ticks=False)

    shaft = (
        alt.Chart()
        .mark_line(size=2)
        .encode(
            x=alt.datum(0, scale=x_scale()),
            y=alt.datum(-1.25),
            x2=alt.datum(tokens_num),
            y2=alt.datum(-1.30),
        )
    )
    head = (
        alt.Chart()
        .mark_point(shape="triangle", filled=True, fillOpacity=1)
        .encode(
            x=alt.datum(tokens_num, scale=x_scale()),
            y=alt.datum(-1.275),
            angle=alt.AngleValue(90),
            size=alt.SizeValue(100),
            color=alt.ColorValue("#000000"),
        )
    )
    arrow = alt.layer(shaft, head)

    dots = (
        alt.Chart(df)
        .mark_point(filled=True, fillOpacity=1)
        .encode(
            x=alt.X("index:Q", axis=hidden_axis(), scale=x_scale()),
            y=alt.datum(-1.275),
            size=alt.SizeValue(80),
            color=alt.ColorValue("#000000"),
        )
    )

    texts = (
        alt.Chart(df)
        .mark_text(fillOpacity=1)
        .encode(
            x=alt.X("index:Q", axis=hidden_axis(), scale=x_scale()),
            text="true_token_time_duration_display",
            y=alt.datum(-1.825),
            size=alt.SizeValue(11),
            color=alt.ColorValue("#000000"),
        )
    )

    return arrow + dots + texts


def create_ares_overview(dfs, width):
    """Render the ARES overview as layered charts titled "ARES".

    The layers are:
    - a timeline;
    - one row of risk blocks per frame in ``dfs``, sampled at ``n_samples`` evenly
      spaced row positions of ``dfs[-1]``.

    Each of the first four frames gets its own ``(x offset, m1 offset, p1 offset)``
    placement; ``zip`` ignores any extra frames.
    """
    n_samples = 10
    tokens_num = len(dfs[-1])
    block_width = tokens_num / (n_samples * 3 - 1)
    shift = block_width // 2
    stop = tokens_num - 1 + (tokens_num / n_samples / 3)
    indices = np.linspace(0, stop, n_samples + 1, dtype=int)[:-1] + shift

    placements = [(1, 0.5, 1), (1, 0.25, -0.25), (1, -1, -0.5), (0, -1, 1)]
    sampled_dfs = []
    charts = []
    for frame, placement in zip(dfs, placements):
        sampled = sample_df_and_transform_for_rect_drawing(
            frame, indices, *placement, block_width, shift
        )
        sampled_dfs.append(sampled)
        charts.append(
            create_ares_block(sampled, "n_index", "n_index_p1", "m1", "p1", n_samples, tokens_num)
        )

    timeline = create_timeline(sampled_dfs[-1], tokens_num)

    return alt.layer(timeline, *charts).properties(
        width=width, height=70, title="ARES", view=alt.ViewConfig(stroke=None)
    )


def plot_adaptive_ews(
    dir_name_hosp,
    dir_name_icu,
    trajectory_type=TrajectoryType.TOKEN,
    token_view_window_num=1,
    token_view_ranges=None,
    return_dfs=False,
):
    """Build the ARES early-warning figure for one patient.

    Death, prolonged-stay and composite risks are read from ``dir_name_hosp``; ICU-admission
    risk is read from ``dir_name_icu``. The composite risk is recomputed as the sum of the
    three single-outcome risks, capped at 1.

    The figure stacks three parts:
    - token strips over the composite risk;
    - the layered risk curves;
    - the ARES block overview.

    Returns:
        The chart, or ``(chart, [df_death, df_icu, df_prolonged, df_compound])`` when
        ``return_dfs`` is true.
    """
    width = 800
    z_95 = 1.96
    df_death = inference_results_for_token_of_interest(dir_name_hosp, ST.DEATH, trajectory_type)
    df_icu = inference_results_for_token_of_interest(
        dir_name_icu, ST.ICU_ADMISSION, trajectory_type
    )
    df_prolonged = inference_results_for_token_of_interest(
        dir_name_hosp, "PROLONGED_STAY", trajectory_type
    )
    # TODO using hospital trajectory for compound as it is longer
    df_compound = inference_results_for_token_of_interest(
        dir_name_hosp, "COMPOSITE", trajectory_type
    )

    # The ICU risk is aligned to the hospital trajectory by row position. Pad with zero
    # risk past its end and truncate if it is longer, so the Series always has exactly
    # the composite frame's height.
    n_rows = df_compound.height
    icu_values = df_icu["actual"].to_list()[:n_rows]
    icu_values += [0.0] * (n_rows - len(icu_values))
    df_icu_extended = pl.Series("actual_icu", icu_values, dtype=pl.Float64)

    df_compound = (
        df_compound.with_columns(
            actual=pl.min_horizontal(
                [df_death["actual"] + df_prolonged["actual"] + df_icu_extended, pl.lit(1)]
            )
        )
        .with_columns(
            prob_error=((pl.col("actual") * (1 - pl.col("actual"))) / pl.col("counts")).sqrt(),
            actual_display=pl.when(pl.col("actual") == 1.0)
            .then(pl.lit("1.0"))
            .otherwise(
                pl.col("actual")
                .round(2)
                .cast(pl.String)
                .str.slice(1)
                .str.replace(r"(\.\d)$", r"${1}" + "0")
            ),
        )
        .with_columns(
            actual_plus_error=pl.min_horizontal(
                pl.col("actual") + pl.col("prob_error") * z_95, pl.lit(1)
            ),
            # Clamp at 0, mirroring the clamp at 1 above, so the band stays in [0, 1].
            actual_minus_error=pl.max_horizontal(
                pl.col("actual") - pl.col("prob_error") * z_95, pl.lit(0)
            ),
        )
    )

    dfs = [df_death, df_icu, df_prolonged, df_compound]
    labels = ["DEATH", "ICU ADMISSION", "PROLONGED STAY", "COMPOSITE"]

    main_chart, brushes = create_main_chart(
        dfs,
        labels,
        width,
        trajectory_type,
        token_view_window_num=token_view_window_num,
        token_view_ranges=token_view_ranges,
    )
    token_view = create_token_view(df_compound, brushes, width, trajectory_type)
    composite_overview = create_ares_overview(dfs, width)

    chart = token_view & main_chart & composite_overview
    if return_dfs:
        return chart, dfs

    return chart

# Example: the token-based ARES figure for run prefix 6982, with three token-view windows.
chart, dfs = plot_adaptive_ews(
    results_hosp / "6982_rep_size_50_2025-01-25_18-29-17",
    results_icu / "6982_rep_size_50_2025-01-26_04-59-01",
    trajectory_type=TrajectoryType.TOKEN,
    token_view_window_num=3,
    return_dfs=True,
)
chart

In [None]:
import os

files_hosp = os.listdir(results_hosp)
files_icu = os.listdir(results_icu)

# Run directories are named "<pid>_..."; index them by that patient-id prefix.
pid_to_file_hosp = {p.split("_")[0]: p for p in files_hosp}
pid_to_file_icu = {p.split("_")[0]: p for p in files_icu}

# Hand-picked token-view windows (row-index units) for selected patients; other patients
# get evenly spaced windows.
ranges_token_view = {
    "7329": [(28.5, 33.5), (508.5, 513.5), (1000.5, 1005.5)],
    "7110": [(92.5, 97.5), (508.5, 513.5), (1393.5, 1398.5)],
    "719": [(110.5, 115.5), (378.5, 383.5), (474.5, 479.5)],
    "415": [(245.5, 250.5), (375.5, 380.5), (940.5, 945.5)],
    "6387830": [(245.5, 250.5), (375.5, 380.5), (941.5, 946.5)],
}
output_kinds = ("token", "time", "token_grouped", "token_1")
for pid, p_icu in pid_to_file_icu.items():
    # Without a hospital-mortality run for this patient, the ICU run is used for both inputs.
    p_hosp = pid_to_file_hosp.get(pid, p_icu)
    p_full_icu = results_icu / p_icu
    p_full_hosp = results_hosp / p_hosp if pid in pid_to_file_hosp else results_icu / p_hosp
    ranges_token = ranges_token_view.get(pid)
    results_path = results_trajectories / pid
    output_paths = {kind: results_path / f"ares_{pid}_{kind}.html" for kind in output_kinds}

    # Skip only fully generated patients. A directory left behind by an interrupted run
    # is regenerated instead of being skipped forever.
    if all(path.exists() for path in output_paths.values()):
        continue
    print(f"Generating trajectories for {pid}, {p_icu}")
    results_path.mkdir(parents=True, exist_ok=True)
    token_view_window_num = len(ranges_token) if ranges_token else 3
    chart_token, dfs_token = plot_adaptive_ews(
        p_full_hosp,
        p_full_icu,
        trajectory_type=TrajectoryType.TOKEN,
        token_view_window_num=token_view_window_num,
        token_view_ranges=ranges_token,
        return_dfs=True,
    )
    chart_time, dfs_time = plot_adaptive_ews(
        p_full_hosp,
        p_full_icu,
        trajectory_type=TrajectoryType.TIME,
        token_view_window_num=3,
        return_dfs=True,
    )
    chart_grouped, dfs_token_grouped = plot_adaptive_ews(
        p_full_hosp,
        p_full_icu,
        trajectory_type=TrajectoryType.TOKEN_GROUPED,
        token_view_window_num=3,
        return_dfs=True,
    )
    chart_token_1 = plot_adaptive_ews(
        p_full_hosp,
        p_full_icu,
        trajectory_type=TrajectoryType.TOKEN,
        token_view_window_num=1,
        token_view_ranges=[ranges_token[-1]] if ranges_token else None,
        return_dfs=False,
    )

    chart_token.save(output_paths["token"])
    chart_time.save(output_paths["time"])
    chart_grouped.save(output_paths["token_grouped"])
    chart_token_1.save(output_paths["token_1"])