In [None]:
import pandas as pd
import numpy as np
from tableone import TableOne
import glob
import re
import os

In [None]:
def get_all_subs(root='../../..', sessions=('baseline', '1year')):
    """Discover every subject/session pair present in the BIDS-style tree.

    Looks for directories matching ``<root>/sub-*/ses-<session>``. Each
    subject directory name such as ``sub-MM001`` becomes a subject label with
    the group prefix split off by an underscore (``MM_001``).

    Parameters
    ----------
    root : str
        Directory that contains the ``sub-*`` folders. Defaults to the
        location the notebook has always used (three levels up).
    sessions : sequence of str
        Session labels to search for, in output order. The default keeps the
        original order: all baseline sessions first, then all one-year ones.

    Returns
    -------
    list of [str, str]
        ``[subject_label, session_label]`` pairs, e.g. ``['MM_001', 'baseline']``.
        Within a session, entries are sorted by path so that re-runs are
        deterministic.
    """
    subs_all = []
    for session in sessions:
        pattern = os.path.join(root, 'sub-*', f'ses-{session}')
        # sorted(): glob order is filesystem-dependent.
        for ses_path in sorted(glob.glob(pattern)):
            # Use os.path rather than split('/') at fixed indexes, which broke
            # on Windows separators and on roots of a different depth.
            sub_dir = os.path.basename(os.path.dirname(os.path.normpath(ses_path)))
            sub_id = sub_dir.split('-')[1]
            # 'MM001' -> ['', 'MM', '001'] -> 'MM_001'
            group_and_num = re.split(r'(MM|HC)', sub_id)[1:]
            subs_all.append(['_'.join(group_and_num), session])
    return subs_all
    

In [None]:
def create_indiv_subs_df(subs_input,
                         non_img_csv="../../../sourcedata/non_imaging_data/MMJ-Processed_data-2022_05_27-13_58-6858bbe.csv"):
    """Build one row per subject/session with demographic and clinical columns.

    Parameters
    ----------
    subs_input : list of [subject, session]
        Pairs such as ``['MM_001', 'baseline']`` (as returned by
        ``get_all_subs``). The list is not modified.
    non_img_csv : str
        Path to the non-imaging data export. Defaults to the snapshot the
        notebook was written against.

    Returns
    -------
    pandas.DataFrame
        Sorted by (subject, session). Columns are ``subs`` (subject label),
        ``sub_ses`` (``<subject>_<session>``), ``group_ses``
        (``<group>_<session>``, e.g. ``MM_1year``), plus one column per entry
        in ``simple_additions`` / ``by_ses_additions``.
    """
    subject_col = 'IDS.CHR.Subject'
    time_col = 'SSS.CHR.Time_point'

    # sorted() rather than .sort() so the caller's list is left untouched.
    subs = [[item[0], f'{item[0]}_{item[1]}', f"{item[0].split('_')[0]}_{item[1]}"]
            for item in sorted(subs_input)]
    df_subs = pd.DataFrame.from_records(subs, columns=['subs', 'sub_ses', 'group_ses'])

    # Load the non-imaging data.
    non_img_data = pd.read_csv(non_img_csv, low_memory=False)

    def first_by_subject(rows, column):
        # First non-null value of `column` per subject, as {subject: value}.
        return rows.groupby(subject_col)[column].agg('first').to_dict()

    # Subject-level columns as (source column, output column), mapped onto 'subs'.
    # Output names containing 'at one-year' use only 'One year' rows. All others
    # take the first non-null value across every time point.
    # NOTE(review): 'at baseline' columns therefore depend on CSV row order and
    # fall back to a later visit if the baseline value is missing — confirm
    # this is intended.
    simple_additions = [('SBJ.CHR.Sex', 'Sex'), ("SBJ.INT.Age", 'Age'), ("SBJ.CHR.Race", 'Race'),
                        ("SBJ.CHR.Ethnicity", 'Ethnicity'), ("SBJ.CHR.Education_level", 'Education level'),
                        ("SBJ.INT.Education_years", 'Education years'), ("SBJ.CHR.Employment_status", 'Employment status'),
                        ("SBJ.CHR.Handedness", 'Handedness'), ('SSS.CHR.Primary_condition', 'Condition'),
                        ('URN.LGC.THC_present', 'Positive urine THC at baseline'), ('URN.LGC.THC_present', 'Positive urine THC at one-year'),
                        ('CUD.CHR.Diagnosis', 'CUD diagnosis at baseline'), ('CUD.CHR.Diagnosis', 'CUD diagnosis at one-year'),
                        ('INV.INT.CUDIT.Summed_score', 'CUDIT summed score at baseline'), ('INV.INT.CUDIT.Summed_score', 'CUDIT summed score at one-year'),
                        ('TLF.CHR.THC.Frequency_in_month', 'THC frequency per month at baseline'), ('TLF.CHR.THC.Frequency_in_month', 'THC frequency per month at one-year')]

    # Session-specific columns, mapped onto 'sub_ses' (currently none).
    by_ses_additions = []

    race_renames = {'Caucasian': 'White', 'African American': 'Black'}
    # Hoisted: the same one-year subset is used by several columns.
    one_year_rows = non_img_data[non_img_data[time_col] == 'One year']

    for orig_name, col_name in simple_additions:
        rows = one_year_rows if 'at one-year' in col_name else non_img_data
        dict_map = first_by_subject(rows, orig_name)
        if orig_name == 'SBJ.CHR.Race':
            dict_map = {sub: race_renames.get(race, race) for sub, race in dict_map.items()}
        df_subs[col_name] = df_subs['subs'].map(dict_map)

    for orig_name, col_name in by_ses_additions:
        # HC baseline values come from 'Screening' rows; MM baseline values
        # come from 'Baseline' rows; MM one-year values from 'One year' rows.
        hc_baseline = first_by_subject(non_img_data[non_img_data[time_col] == 'Screening'], orig_name)
        mm_baseline = first_by_subject(non_img_data[non_img_data[time_col] == 'Baseline'], orig_name)
        mm_1year = first_by_subject(one_year_rows, orig_name)
        dict_all = {
            **{f'{sub}_baseline': val for sub, val in hc_baseline.items() if 'HC' in sub},
            **{f'{sub}_baseline': val for sub, val in mm_baseline.items() if 'MM' in sub},
            **{f'{sub}_1year': val for sub, val in mm_1year.items() if 'MM' in sub},
        }
        df_subs[col_name] = df_subs['sub_ses'].map(dict_all)

    # Assign the result back: an inplace replace on df_subs['...'] is chained
    # assignment and silently does nothing under pandas Copy-on-Write.
    df_subs['Employment status'] = df_subs['Employment status'].replace({'self': 'Self'}, regex=True)

    return df_subs


In [None]:
def save_table1(df_subs, include_pval, output_dir='../../../derivatives/demographics'):
    """Build the demographics Table 1, save it as CSV and display it.

    Parameters
    ----------
    df_subs : pandas.DataFrame
        Per-session frame from ``create_indiv_subs_df``. It must contain the
        demographic columns below plus ``group_ses``.
    include_pval : bool
        If true, ask TableOne for between-group p-values and write
        ``table1_all_pval.csv``. Otherwise write ``table1_all_no_pval.csv``.
    output_dir : str
        Directory the CSV is written to. It is created if missing.

    Returns
    -------
    None. The cleaned table is shown with IPython's ``display``, which the
    notebook kernel provides as a builtin.
    """
    # Demographics only. The THC/CUD columns built by create_indiv_subs_df
    # ('Positive urine THC at baseline/one-year', 'CUDIT summed score at
    # baseline/one-year', 'THC frequency per month at baseline/one-year') can
    # be appended to `columns`. Add the non-numeric ones to `categorical` too.
    columns = ['Sex', 'Age', 'Race', 'Ethnicity', 'Education level', 'Education years',
               'Employment status', 'Handedness']
    categorical = ['Sex', 'Race', 'Ethnicity', 'Education level', 'Employment status',
                   'Handedness']
    labels = {'HC_baseline': 'HC baseline', 'MM_baseline': 'MC baseline', 'MM_1year': 'MC one-year'}

    mytable = TableOne(df_subs, columns=columns, categorical=categorical,
                       groupby=['group_ses'], rename=labels, pval=bool(include_pval))

    os.makedirs(output_dir, exist_ok=True)
    suffix = 'pval' if include_pval else 'no_pval'
    out_csv = os.path.join(output_dir, f'table1_all_{suffix}.csv')

    # Round-trip through CSV to flatten TableOne's multi-level header. read_csv
    # de-duplicates the repeated top-level name, so the second
    # 'Grouped by group_ses' column arrives as '... .1' and is dropped here.
    # NOTE(review): presumably that is TableOne's 'Overall' column — confirm.
    mytable.to_csv(out_csv)
    mytable_df = pd.read_csv(out_csv)
    mytable_df = mytable_df.drop(columns='Grouped by group_ses.1')
    mytable_df.to_csv(out_csv, index=False)

    display(mytable_df)

In [None]:
def create_table(include_pval):
    """Run the Table 1 pipeline end to end.

    Steps: discover subject/session directories, build the per-session
    demographics frame, then save and display the summary table.

    include_pval : bool
        Forwarded to ``save_table1``; when true the table gets p-values.
    """
    save_table1(create_indiv_subs_df(get_all_subs()), include_pval)

In [None]:
# Set to True to add between-group p-values to the saved Table 1.
include_pval = False
create_table(include_pval=include_pval)