In [None]:
import numpy as np
import pandas as pd
import os
import re
from prediction_utils.pytorch_utils.metrics import StandardEvaluator

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Output directory for the exported LaTeX cohort tables, relative to the
# notebook's working directory; created if it does not exist yet.
tables_dir = "../zipcode_cvd/experiments/tables"
result_path = os.path.join(tables_dir, "cohort_tables")
os.makedirs(name=result_path, exist_ok=True)

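### ASCVD Cohort Tables

Variables to compute for each attribute and group:

* Total counts
* Uncensored fraction (reported as its complement, the censoring rate)
* Incidence of the outcome in the uncensored population (IPCW-adjusted)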

In [None]:
# Database summarized by the cells below; used as a key into `cohort_paths`.
db_key = "optum"

In [None]:
# Root data directory for each supported database.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
data_paths = dict(
    optum="/local-scratch/nigam/secure/optum/spfohl/zipcode_cvd/optum/dod",
)

# Cohort file (with IPCW weights), located under `<data_path>/cohort/`.
cohort_filename = "cohort_fold_1_5_ipcw.parquet"

cohort_paths = dict(
    (db_name, os.path.join(db_root, "cohort", cohort_filename))
    for db_name, db_root in data_paths.items()
)

# Stratification variables; after melting, each becomes an `attribute` whose
# values are the table's groups. Order here is the row order of the tables.
attributes = [
    "age_group",
    "gender_concept_name",
    "race_eth",
    "race_eth_gender",
    # history flags
    "has_diabetes_type2_history",
    "has_diabetes_type1_history",
    "has_ra_history",
    "has_ckd_history",
]

# Outcome label(s) of interest.
tasks = ["ascvd_binary"]

In [None]:
# Load the cohort parquet file for the selected database.
selected_cohort_path = cohort_paths[db_key]
cohort = pd.read_parquet(path=selected_cohort_path)

In [None]:
# Add copies of the IPCW weight capped at 100 and at 20 (np.minimum keeps NaN),
# presumably to limit the influence of a few very large weights on the
# weighted incidence estimates.
raw_ipcw_weight = cohort["ipcw_weight"]
cohort = cohort.assign(
    ipcw_weight_clip_100=np.minimum(raw_ipcw_weight, 100),
    ipcw_weight_clip_20=np.minimum(raw_ipcw_weight, 20),
)

In [None]:
# Mean IPCW weight by outcome label and `sampled` flag.
# groupby drops rows whose key is NaN by default, so rows with a missing
# ascvd_binary are not represented here.
cohort.groupby(["ascvd_binary", "sampled"])["ipcw_weight"].mean().to_frame()

In [None]:
# Long format: each cohort row is repeated once per stratification attribute,
# with the attribute name in `attribute` and its value in `group`, so every
# attribute can be summarized with a single groupby.
melt_id_columns = [
    "person_id",
    "ascvd_binary",
    "censored_binary",
    "ipcw_weight",
    "ipcw_weight_clip_100",
    "ipcw_weight_clip_20",
    "fold_id",
]
cohort_df_long = pd.melt(
    cohort,
    id_vars=melt_id_columns,
    value_vars=attributes,
    var_name="attribute",
    value_name="group",
)

In [None]:
from IPython.display import display

# Outcome rate of ASCVD per (attribute, group, fold) for test-fold rows with an
# observed label, computed both unweighted and IPCW-weighted.
# pred_probs is a constant placeholder column — presumably StandardEvaluator
# requires a prediction column even for label-only metrics (TODO confirm).
test_labeled_df = cohort_df_long.assign(pred_probs=1).query(
    '~ascvd_binary.isnull() & fold_id == "test"'
)

# A fresh evaluator per call, in case StandardEvaluator keeps state between
# evaluations (not verified).
evaluator = StandardEvaluator(metrics=["outcome_rate"])
outcome_rate_unweighted = evaluator.evaluate(
    df=test_labeled_df,
    strata_vars=["attribute", "group", "fold_id"],
    label_var="ascvd_binary",
)

evaluator = StandardEvaluator(metrics=["outcome_rate"])
outcome_rate_ipcw = evaluator.evaluate(
    df=test_labeled_df,
    strata_vars=["attribute", "group", "fold_id"],
    label_var="ascvd_binary",
    weight_var="ipcw_weight",
)

# Only a cell's last expression is auto-displayed, so show both results explicitly.
display(outcome_rate_unweighted, outcome_rate_ipcw)

In [None]:
# Incidence of ASCVD among rows with an observed label, per (attribute, group):
# unweighted, IPCW-weighted, and IPCW-weighted with weights clipped at 20 / 100.
# Uses the test fold, matching the evaluator cell and the exported table.
# Returning a pd.Series per group yields one row per (attribute, group) with no
# extra index level to strip; selecting the needed columns keeps the grouping
# keys out of the applied frame.
uncensored_statistics_df = (
    cohort_df_long.query("~ascvd_binary.isnull() and fold_id == 'test'")
    .groupby(["attribute", "group"])[
        ["ascvd_binary", "ipcw_weight", "ipcw_weight_clip_20", "ipcw_weight_clip_100"]
    ]
    .apply(
        lambda x: pd.Series(
            {
                "incidence": np.average(x.ascvd_binary),
                "incidence_adjusted": np.average(
                    x.ascvd_binary, weights=x.ipcw_weight
                ),
                "incidence_adjusted_clip_20": np.average(
                    x.ascvd_binary, weights=x.ipcw_weight_clip_20
                ),
                "incidence_adjusted_clip_100": np.average(
                    x.ascvd_binary, weights=x.ipcw_weight_clip_100
                ),
            }
        )
    )
    .reset_index()
)
uncensored_statistics_df

In [None]:
# Per-group row count and censoring rate (mean of censored_binary), computed
# over the whole long-format cohort with no fold filter.
cohort_statistics_df_temp = (
    cohort_df_long
    .groupby(by=["attribute", "group"])
    .agg(
        censoring_rate=("censored_binary", "mean"),
        Count=("censored_binary", "size"),
    )
    .reset_index()
)
cohort_statistics_df_temp

In [None]:
# Display labels for raw group values, applied to the table index via rename.
# Group values without an entry keep their raw value.
group_label_dict = {}

# Race/ethnicity and gender
group_label_dict.update(
    {
        "Black or African American": "Black",
        "Hispanic or Latino": "Hispanic",
        "FEMALE": "Female",
        "MALE": "Male",
    }
)

# Race/ethnicity x gender intersections ("<race_eth> | <gender>")
group_label_dict.update(
    {
        "Asian | FEMALE": "Asian, female",
        "Asian | MALE": "Asian, male",
        "Black or African American | FEMALE": "Black, female",
        "Black or African American | MALE": "Black, male",
        "Hispanic or Latino | FEMALE": "Hispanic, female",
        "Hispanic or Latino | MALE": "Hispanic, male",
        "Other | FEMALE": "Other, female",
        "Other | MALE": "Other, male",
        "White | FEMALE": "White, female",
        "White | MALE": "White, male",
    }
)

# History flags
group_label_dict.update(
    {
        "ra_absent": "RA absent",
        "ra_present": "RA present",
        "ckd_absent": "CKD absent",
        "ckd_present": "CKD present",
        "diabetes_type1_absent": "Type 1 diabetes absent",
        "diabetes_type1_present": "Type 1 diabetes present",
        "diabetes_type2_absent": "Type 2 diabetes absent",
        "diabetes_type2_present": "Type 2 diabetes present",
    }
)

In [None]:
# Preview of the cohort table: per-group count, censoring rate and IPCW-adjusted
# incidence, rows ordered by `attributes` and labelled with `group_label_dict`.
# merge defaults to an inner join on the shared columns, so groups absent from
# `uncensored_statistics_df` are dropped.
table_columns = ["Count", "censoring_rate", "incidence_adjusted"]
column_labels = {
    "size": "Count",
    "censoring_rate": "Censoring rate",
    "incidence_adjusted": "Incidence",
}
cohort_statistics_df = (
    cohort_statistics_df_temp.merge(uncensored_statistics_df)
    .set_index(["attribute", "group"])
    .reindex(columns=table_columns)
    .reindex(axis="index", level=0, labels=attributes)
    .reset_index(level="attribute", drop=True)
    .rename(columns=column_labels, index=group_label_dict)
    # Thousands separators, e.g. 12345 -> "12,345".
    .assign(Count=lambda frame: frame.Count.apply("{:,}".format))
)
cohort_statistics_df

In [None]:
# Build the cohort table for `db_key` and export it as LaTeX to
# `<result_path>/<db_key>.txt`. The cell re-reads the cohort so that it does
# not depend on columns added by the exploratory cells above.
cohort = pd.read_parquet(cohort_paths[db_key])

cohort_df_long = cohort.melt(
    id_vars=[
        "person_id",
        "ascvd_binary",
        "censored_binary",
        "ipcw_weight",
        "fold_id",
    ],
    value_vars=attributes,
    var_name="attribute",
    value_name="group",
)

# IPCW-adjusted incidence among test-fold rows with an observed label.
# Returning a pd.Series per group yields one row per (attribute, group) with no
# extra index level to strip afterwards.
uncensored_statistics_df = (
    cohort_df_long.query('~ascvd_binary.isnull() & fold_id == "test"')
    .groupby(["attribute", "group"])[["ascvd_binary", "ipcw_weight"]]
    .apply(
        lambda x: pd.Series(
            {
                "incidence_adjusted": np.average(
                    x.ascvd_binary, weights=x.ipcw_weight
                ),
            }
        )
    )
    .reset_index()
)

# Group size and censoring rate over all folds, inner-joined with the incidence
# above, then ordered by `attributes` and relabelled for display.
cohort_statistics_df = (
    cohort_df_long.groupby(["attribute", "group"])
    .agg(
        censoring_rate=("censored_binary", "mean"),
        Count=("censored_binary", "size"),
    )
    .reset_index()
    .merge(uncensored_statistics_df)
    .set_index(["attribute", "group"])
    .reindex(columns=["Count", "censoring_rate", "incidence_adjusted"])
    .reindex(axis="index", level=0, labels=attributes)
    .reset_index(level="attribute", drop=True)
    .rename(
        columns={
            "censoring_rate": "Censoring rate",
            "incidence_adjusted": "Incidence",
        },
        index=group_label_dict,
    )
    # Thousands separators, e.g. 12345 -> "12,345".
    .assign(Count=lambda x: x.Count.apply("{:,}".format))
)

latex_output = cohort_statistics_df.to_latex(
    buf=None,
    float_format="%.3g",
    index_names=False,
    index=True,
    label=f"tab:cohort_{db_key}",
    position="!t",
    caption="A caption",  # TODO: replace placeholder caption
)
# NOTE(review): with buf=None, to_latex is documented to return a str; this
# defensive guard now runs before any str method is used on the result.
if isinstance(latex_output, tuple):
    latex_output = latex_output[0]

# Put a "Group" header above the index column (left empty as "{}" by to_latex).
table_str = latex_output.replace("toprule\n{}", "toprule\n Group")

# A row beginning with "[" followed by a digit (e.g. an interval label such as
# "[40, 50)" — presumably from age_group) would be parsed by LaTeX as the
# optional argument of the preceding "\\"; emit \lbrack instead. Raw strings
# avoid invalid-escape warnings for "\[" and "\d".
table_str = re.sub(pattern=r"\[(?=\d)", repl=r"\\lbrack", string=table_str)

with open(os.path.join(result_path, f"{db_key}.txt"), "w", encoding="utf-8") as fp:
    fp.write(table_str)