In [None]:
import numpy as np
import pandas as pd
import os
import re
from group_robustness_fairness.prediction_utils.pytorch_utils.metrics import StandardEvaluator

In [None]:
# Destination directory for the exported LaTeX cohort tables; created on
# demand so the export cell at the end of the notebook can write into it.
tables_root = "../tables"
result_path = os.path.join(tables_root, "cohort_tables")
os.makedirs(name=result_path, exist_ok=True)

#### STARR Admissions Cohort Table

In [None]:
# Root directory of each extracted cohort, keyed by database name.
# NOTE(review): hardcoded absolute path on a specific host; consider making
# this configurable (e.g. via an environment variable) for portability.
data_paths = {
    "admissions": "/local-scratch/nigam/projects/spfohl/group_robustness_fairness/cohorts/admissions/starr_20210130"
}

# Full path to the cohort parquet file under each cohort root.
cohort_paths = {
    key: os.path.join(value, "cohort", "cohort_fold_1_5.parquet")
    for key, value in data_paths.items()
}

db_key = "admissions"
# cohort_paths values are already complete file paths, so they are passed
# straight to read_parquet.
cohort = pd.read_parquet(cohort_paths[db_key])

In [None]:
# Cohort columns used to define the subgroups reported in the table.
attributes = [
    "gender_concept_name",
    "age_group",
    "race_eth",
]
# Outcome label columns whose per-group mean (incidence) is tabulated.
tasks = [
    "hospital_mortality",
    "LOS_7",
    "readmission_30",
]

In [None]:
# Reshape the wide cohort into long form in two steps:
#   1. stack the task label columns into (task, labels) pairs;
#   2. stack the attribute columns into (attribute, group) pairs.
# Every cohort row therefore yields len(tasks) * len(attributes) rows,
# keyed by person_id.
_task_long = pd.melt(
    cohort,
    id_vars=["person_id", *attributes],
    value_vars=tasks,
    var_name="task",
    value_name="labels",
)
cohort_df_long = pd.melt(
    _task_long,
    id_vars=["person_id", "task", "labels"],
    value_vars=attributes,
    var_name="attribute",
    value_name="group",
)
del _task_long

In [None]:
# Per-subgroup cohort statistics: group size plus the mean label of each task,
# indexed by (attribute, group) with columns ["size"] + tasks.

# Mean label per (task, attribute, group), pivoted to one column per task.
# Pivoting once on the two-level (attribute, group) index gives the same table
# as groupby("attribute").apply(pivot_table), without pandas' (>= 2.2)
# DeprecationWarning about apply operating on the grouping column.
prevalence_df = (
    cohort_df_long.groupby(["task", "attribute", "group"])
    .agg(prevalence=("labels", "mean"))
    .reset_index()
    .pivot_table(index=["attribute", "group"], columns="task", values="prevalence")
    .reset_index()
)

# Number of long-format rows per (attribute, group). The melt repeats each
# cohort row once per task, so per-task sizes coincide and drop_duplicates
# collapses them to one row per (attribute, group).
group_size_df = (
    cohort_df_long.groupby(["task", "attribute", "group"])
    .agg(size=("labels", "size"))
    .reset_index()
    .drop(columns="task")
    .drop_duplicates()
)

# Explicit join keys; validate="one_to_one" raises a MergeError instead of
# silently duplicating rows if per-task sizes ever diverged.
cohort_statistics_df = prevalence_df.merge(
    group_size_df,
    on=["attribute", "group"],
    how="inner",
    validate="one_to_one",
)
cohort_statistics_df = cohort_statistics_df.set_index(["attribute", "group"])[
    ["size"] + tasks
]

In [None]:
# Publication-ready cohort table:
#   - flatten the (attribute, group) index;
#   - drop the unmatched-gender row;
#   - keep `group` as the row index;
#   - relabel columns and selected groups;
#   - render counts as strings with thousands separators.
_flat_stats = cohort_statistics_df.reset_index()
_is_unmatched_gender = (_flat_stats["attribute"] == "gender_concept_name") & (
    _flat_stats["group"] == "No matching concept"
)
result_df = (
    _flat_stats.loc[~_is_unmatched_gender]
    .drop(columns="attribute")
    .set_index("group")
    .rename(
        index={
            "Black or African American": "Black",
            "Hispanic or Latino": "Hispanic",
            "FEMALE": "Female",
            "MALE": "Male",
        },
        columns={
            "size": "Count",
            "hospital_mortality": "Hospital Mortality",
            "LOS_7": "Prolonged Length of Stay",
            "readmission_30": "30-Day Readmission",
        },
    )
    .assign(Count=lambda frame: frame["Count"].apply("{:,}".format))
)
del _flat_stats, _is_unmatched_gender

In [None]:
# Show the formatted cohort table as the cell's rich output before exporting it to LaTeX.
result_df

In [None]:
# Export the formatted cohort table (result_df) as a LaTeX table and write it
# to <result_path>/admissions_starr.txt.
# NOTE(review): the caption describes Count as a number of patients; confirm
# the cohort has one row per patient, since Count is a row count per group.
caption_string = (
    "Cohort characteristics for patients drawn from STARR. "
    "Data are grouped on the basis of age, sex, and the race and ethnicity category. "
    "Shown, for each group, is the number of patients extracted and the incidence of "
    "hospital mortality, prolonged length of stay, and 30-day readmission."
)
table_str = (
    result_df.to_latex(
        buf=None,  # return the LaTeX source as a str instead of writing to a buffer
        float_format="%.3g",
        index_names=False,
        index=True,
        label="tab:cohort_starr_admissions",
        position="!t",
        caption=caption_string,
    )
    # NOTE(review): relabels the top age bin "75-91" as "75-90"; presumably the
    # upstream bin edge differs from the intended label, so confirm against the cohort definition.
    .replace("75-91", "75-90")
    # Give the unnamed index column a "Group" header...
    .replace("toprule\n{}", "toprule\n Group")
    # ...then insert a spanning "Outcome Incidence" header over columns 3-5
    # (the three outcome columns), underlined by \cmidrule.
    .replace(
        "toprule\n Group",
        "toprule\n{} & {} & \\multicolumn{3}{c}{Outcome Incidence} \\\\\n\\cmidrule{3-5}\nGroup",
    )
)

# A "[" opening a row (e.g. an interval label like "[18-30)") follows the
# previous row's "\\", where LaTeX would read it as the optional length
# argument of "\\". Replacing it with \lbrack avoids that misparse.
table_str = re.sub(pattern=r"\[(?=\d)", repl=r"\\lbrack", string=table_str)

with open(os.path.join(result_path, "admissions_starr.txt"), "w", encoding="utf-8") as fp:
    fp.write(table_str)