# Description

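Notebook for colon cancer screening RDD discovery analysis. Exercises the end-to-end process of RDD discovery with sample splitting: candidate subgroups are discovered on one half of the data and re-tested on the held-out half.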

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import seaborn as sns
import sys

from tqdm import tqdm

In [None]:
# sample split for TMLR rebuttal
from sklearn.model_selection import train_test_split

In [None]:
# NOTE(review): machine-specific absolute path to the directory holding the
# rdsgd module; it only resolves on the original author's environment —
# consider a relative path or installing rdsgd as a package.
sys.path.append("/home/liutony/optum-pipeline/notebooks/tmlr/")

import rdsgd

In [None]:
# user imports: add the directory two levels up (presumably the repo root) to
# sys.path so the local `rddd` package can be imported.
sys.path.append("../../")

from rddd.feat import gen_feat_df
# NOTE(review): create_feat_df, used in the baseline-extraction section, was
# provided by the rddd.rddd import below, which is commented out, so it is not
# in scope from these imports.
#from rdsgd import *
#from rddd.rddd import policy_tree_discovery, test_discontinuity, create_feat_df

In [None]:
# notebook magics: automatically reload edited modules (e.g. rddd, rdsgd)
# before each cell runs, and render matplotlib figures inline.
%load_ext autoreload

%autoreload 2

%matplotlib inline

# Load data

In [None]:
%%time
# Load the merged colon-cancer cohort from Parquet.
cc_df = pd.read_parquet("/project/liu_optum_causal_inference/data/colon_cancer/merge/colon_cancer.parq")

In [None]:
# Inspect the available raw columns.
cc_df.columns

In [None]:
print(cc_df.shape)

# Clean data

In [None]:
%%time
# Build the feature frame from the raw cohort with rddd.feat.gen_feat_df.
cc_feat = gen_feat_df(cc_df)

In [None]:
cc_feat.columns

In [None]:
# Frequency of each household-income range code.
cc_feat['d_household_income_range_code'].value_counts()

# Run RDSGD

In [None]:
test_df = cc_feat.copy()

In [None]:
# import warnings
# warnings.filterwarnings("ignore", module='sk.*')

In [None]:
%%time
grid_dict = {
    'age': np.arange(40, 61, 5)
}
alpha = 0.05
treat = 'indicator'
running_cols = ['age']
tree_kwargs = {
    'max_depth': 2,
    'min_balancedness_tol': 0.3,
}
random_state = 42
bw = 4

sample_df = test_df.copy()

In [None]:
# Sample splitting: discover subgroups on one half (s1_df) and re-test them on
# the other half (s2_df); random_state makes the split reproducible.
s1_df, s2_df = train_test_split(sample_df, test_size=0.5, random_state=random_state)

In [None]:
s1_df.shape

In [None]:
%%time
# Discovery stage: search for RD subgroups on the first half of the split only,
# so that s2_df stays untouched for the hold-out re-test below.
# Returns subgroup_dict (running col -> cutoff -> list of nodes) and num_tests.
# NOTE(review): tree_kwargs and random_state are defined in the config cell but
# their arguments are commented out here, so whatever defaults
# rd_subgroup_discovery has for them apply — confirm this is intended.
subgroup_dict, num_tests = rdsgd.rd_subgroup_discovery(s1_df,
                                                 running_cols=running_cols,
                                                 grid_dict=grid_dict,
                                                 treat=treat,
                                                 alpha=alpha,
                                                 rescale=False,
                                                 omit_mask=True,
                                                 bw=bw,
                                                 #tree_kwargs=tree_kwargs,
                                                 #random_state=random_state
                                                )

In [None]:
s2_df.columns

In [None]:
# Number of nodes discovered for the age-60 cutoff.
len(subgroup_dict['age'][60])

In [None]:
# Re-test every discovered subgroup on the held-out half (s2_df): rebuild the
# subgroup from the node's rule path, rerun rdsgd.test_discontinuity at the same
# cutoff and bandwidth, and store the result on the node as 'llr_results'.
import operator

# Comparison for each rule direction. An unrecognised direction raises instead
# of being silently skipped, which would have widened the hold-out subgroup.
_RULE_OPS = {
    '<': operator.lt,
    '>=': operator.ge,
    '<=': operator.le,
    '>': operator.gt,
    '==': operator.eq,
}

# NOTE(review): sel_nodes stays empty while the selection lines at the bottom
# of the loop remain commented out.
sel_nodes = []
for cutoff, nodes in subgroup_dict['age'].items():
    for node in nodes:
        rule_path = node['rule_path']
        holdout = s2_df
        # NOTE(review): the final element of rule_path is not applied as a
        # filter — presumably it is not a feature split; confirm.
        for rule in rule_path[:-1]:
            try:
                compare = _RULE_OPS[rule.path_dir]
            except KeyError:
                raise ValueError(
                    f"Unsupported rule direction {rule.path_dir!r} "
                    f"on feature {rule.feature!r}"
                ) from None
            holdout = holdout[compare(holdout[rule.feature], rule.threshold)]
        # Copy after filtering (cheaper than copying all of s2_df per node) so
        # nothing done to the frame downstream can alter s2_df, which holdout
        # would still be when the rule path has no splits.
        holdout = holdout.copy()

        llr_results, _, _ = rdsgd.test_discontinuity(holdout, cutoff, 'age', treat=treat, bw=bw, kernel='triangular')
        node['llr_results'] = llr_results

        #if (node['llr_results'].pvalues['z'] < alpha / num_tests):
            #sel_nodes.append((cutoff, node))

In [None]:
out_dir = "/project/liu_optum_causal_inference/results/tmlr_sample_split/"

In [None]:
pickle.dump((subgroup_dict, num_tests), open(f"{out_dir}/colon_cancer_subgroup_results_tmlr.pkl", "wb"), -1)

# Extract baseline and subgroup data

In [None]:
cutoff = 50
running = 'age'
bw = 4
# create_feat_df is not otherwise in scope in this notebook (its rddd.rddd
# import at the top is commented out), so import it here; without this the
# cell fails with NameError.
from rddd.rddd import create_feat_df

baseline_df = create_feat_df(sample_df, running=running,
                             cutoff=cutoff, bw=bw)

In [None]:
# Flag baseline rows that fall in the selected subgroup (1) or not (0).
# NOTE(review): `sorted_nodes` is not defined anywhere in this notebook (the
# node-selection step that would fill sel_nodes is commented out), so this
# line raises NameError as written — presumably a leftover from an earlier
# version; confirm which node is meant.
# NOTE(review): if the node comes from subgroup_dict, its mask was computed
# during discovery on s1_df, while baseline_df is built from the full
# sample_df — confirm the row alignment is intended.
baseline_df['in_subgroup'] = (sorted_nodes[0][1]['subgroup_mask']).astype(int)

#baseline_df[[running, 'indicator', 'in_subgroup']].to_parquet("/project/liu_optum_causal_inference/results/colon_cancer_running.parq")

In [None]:
print(baseline_df.shape)

In [None]:
col = 'gdr_cd_F'  # column to tabulate; presumably a female-gender indicator — confirm

def get_descriptives(baseline_df, col):
    """Display and return the frequency table for one column.

    Parameters
    ----------
    baseline_df : pd.DataFrame
        Frame containing ``col``.
    col : str
        Name of the column to tabulate.

    Returns
    -------
    pd.DataFrame
        One row per distinct non-missing value, with columns ``count`` (number
        of rows with that value) and ``pct`` (that count divided by the total
        number of rows in ``baseline_df``, so missing values still count
        toward the denominator).
    """
    counts = baseline_df[col].value_counts()
    # Name the columns explicitly: concatenating the two value_counts results
    # gave both columns the same label (the Series name), i.e. duplicate labels.
    table = pd.DataFrame({'count': counts, 'pct': counts / baseline_df.shape[0]})
    # `display` is the notebook's IPython display function.
    display(table)
    return table
    
# Frequency table for the single column chosen above; get_descriptives needs
# the frame as its first argument (the call previously omitted it -> TypeError).
get_descriptives(baseline_df, col)

In [None]:
# Frequency table for every column of baseline_df.
# NOTE(review): numeric columns such as age are also tabulated value-by-value,
# which can produce very long tables; consider restricting to categorical /
# indicator columns.
for col in baseline_df.columns:
    get_descriptives(baseline_df, col)

In [None]:
# Summary statistics of the running variable.
baseline_df['age'].describe()