In [None]:
import os
import re
import glob

import pandas as pd
import numpy as np

In [None]:
def rm_CUD_baseline(subs_list):
    """Drop the subjects that met the cannabis-use-disorder exclusion criterion.

    Parameters
    ----------
    subs_list : list of str
        Subject labels such as 'MM_014'.

    Returns
    -------
    list of str
        The labels from `subs_list`, in their original order, that are not on
        the exclusion list.
    """
    # MM subjects with cannabis use disorder at baseline (exclusion criterion)
    cud_at_baseline = ('MM_014', 'MM_188', 'MM_197', 'MM_217', 'MM_228', 'MM_239', 'MM_241')

    kept_subs = []
    for sub_label in subs_list:
        if sub_label in cud_at_baseline:
            continue
        kept_subs.append(sub_label)

    return kept_subs
    

In [None]:
def _find_subs(bids_root, prefix, session):
    """Return sorted '<prefix>_<id>' labels for subject folders that contain `session`."""
    session_paths = glob.glob(f'{bids_root}/sub-{prefix}*/ses-{session}')

    #the subject folder is the parent of the matched session folder; taking it
    #with os.path keeps this independent of the path separator and of how many
    #directory levels bids_root contains
    sub_dirs = [os.path.basename(os.path.dirname(path)) for path in session_paths]

    #sorted so the result does not depend on filesystem or set iteration order
    return sorted({f'{prefix}_' + sub_dir.split(f'-{prefix}')[1] for sub_dir in sub_dirs})


def get_all_subs(bids_root='../../..'):
    """Collect the subject labels of every group/timepoint found on disk.

    Subjects are discovered by globbing `<bids_root>/sub-<prefix>*/ses-<session>`
    and are labelled '<prefix>_<id>'. Subjects listed in `rm_CUD_baseline` are
    removed from all MM lists.

    Parameters
    ----------
    bids_root : str, optional
        Directory that holds the `sub-*` folders. Defaults to the location
        relative to this notebook.

    Returns
    -------
    dict
        Keys 'HC_baseline', 'MM_baseline', 'MM_1year' and 'MM_paired'; each
        value is a sorted list of subject labels. 'MM_paired' holds the MM
        subjects that have both a baseline and a 1year session.
    """
    #get subs for each group and timepoint
    HC_subs_baseline = _find_subs(bids_root, 'HC', 'baseline')
    MM_subs_baseline = _find_subs(bids_root, 'MM', 'baseline')
    MM_subs_1year = _find_subs(bids_root, 'MM', '1year')

    #additionally specify which MM subs are paired and remove CUD baseline subs from all MM lists
    MM_subs_paired = rm_CUD_baseline(sorted(set(MM_subs_baseline).intersection(MM_subs_1year)))
    MM_subs_baseline = rm_CUD_baseline(MM_subs_baseline)
    MM_subs_1year = rm_CUD_baseline(MM_subs_1year)

    #put all subs lists together as a dictionary
    subs_all_dict = {'HC_baseline':HC_subs_baseline, 'MM_baseline':MM_subs_baseline, 'MM_1year':MM_subs_1year, 'MM_paired':MM_subs_paired}

    return subs_all_dict
    

In [None]:
def create_indiv_subs_df(group, subs_list):
    """Build a one-row-per-subject demographics dataframe for a group.

    Parameters
    ----------
    group : str
        Group name such as 'MM_baseline'; the text after the last underscore
        is stored in the 'session' column.
    subs_list : list of str
        Subject labels, looked up in the 'IDS.CHR.Subject' column of the
        non-imaging CSV.

    Returns
    -------
    pandas.DataFrame
        Columns 'subs', 'session', 'Sex', 'Age', 'Race', 'Ethnicity',
        'Education years', 'Condition' and 'Handedness'. Subjects that are not
        found in the CSV get missing values in the demographic columns.
    """
    session_label = group.rsplit('_', 1)[-1]

    df_subs = pd.DataFrame({
        'subs': subs_list,
        'session': [session_label for _ in subs_list],
    })

    #load the non-imaging data
    non_img_data = pd.read_csv(f"../../../sourcedata/non_imaging_data/MMJ-Processed_data-2022_05_27-13_58-6858bbe.csv",low_memory=False)
    by_subject = non_img_data.groupby("IDS.CHR.Subject")

    #(column in the CSV, column name in the output)
    source_to_output = [
        ('SBJ.CHR.Sex', 'Sex'),
        ("SBJ.INT.Age", 'Age'),
        ("SBJ.CHR.Race", 'Race'),
        ("SBJ.CHR.Ethnicity", 'Ethnicity'),
        ("SBJ.INT.Education_years", 'Education years'),
        ('SSS.CHR.Primary_condition', 'Condition'),
        ("SBJ.CHR.Handedness", 'Handedness'),
    ]

    #labels to rename per source column; values not listed are kept unchanged
    relabel_tables = {
        'SBJ.CHR.Race': {
            'Caucasian': 'White',
            'African American': 'Black',
            'Asian': 'Other',
            'Multi-racial': 'Other',
            'Pacific Islander': 'Other',
        },
        'SSS.CHR.Primary_condition': {
            'Affective Disorder (Depression/Anxiety)': 'Depression/anxiety symptoms',
            'Insomnia': 'Insomnia symptoms',
            'Pain': 'Pain symptoms',
        },
    }

    for source_col, output_col in source_to_output:
        #first non-missing value recorded for each subject
        value_by_sub = by_subject[source_col].agg("first").to_dict()

        relabel = relabel_tables.get(source_col)
        if relabel is not None:
            value_by_sub = {sub: relabel.get(value, value) for sub, value in value_by_sub.items()}

        df_subs[output_col] = df_subs['subs'].map(value_by_sub)

    return df_subs


In [None]:
def create_summary_table(group, df):
    """Summarise one group's demographics as a Table-1 style dataframe.

    Parameters
    ----------
    group : str
        Group label; used as the name of the value column.
    df : pandas.DataFrame
        One row per subject with the columns 'Sex', 'Age', 'Race',
        'Ethnicity', 'Education years', 'Condition' and 'Handedness'.

    Returns
    -------
    pandas.DataFrame
        Columns 'Items', 'Levels' and `group`. The first row ('n') holds the
        number of rows in `df`. The remaining rows follow the order Sex, Age,
        Race, Ethnicity, Education years, Condition, Handedness: numerical
        columns give one 'median (q1-q3)' row, categorical columns give one
        'count (percent)' row per level.
    """
    #rows are emitted directly in display order so the layout does not depend
    #on how many levels any categorical column happens to have
    column_order = ['Sex', 'Age', 'Race', 'Ethnicity', 'Education years', 'Condition', 'Handedness']

    #explicitly specify which columns are numerical; the rest are categorical
    numerical_columns = {'Age', 'Education years'}

    #fixed level order for these columns; levels absent from df are reported
    #as 0 and levels not listed here are left out of the table
    custom_orders = {
        'Race': ['Black', 'White', 'Other'],
        'Ethnicity': ['Hispanic or Latino', 'Not Hispanic or Latino'],
        'Condition': ['Depression/anxiety symptoms', 'Insomnia symptoms', 'Pain symptoms', 'Healthy control'],
    }

    #the count row comes first
    summary_list = [{'Items': 'n', 'Levels': np.nan, f'{group}': len(df)}]

    for column in column_order:
        if column in numerical_columns:
            median_value = df[column].median()
            q1 = df[column].quantile(0.25)
            q3 = df[column].quantile(0.75)
            iqr = f'({q1:.1f}-{q3:.1f})'
            result = f'{median_value:.1f} {iqr}'
            summary_list.append({'Items': f'{column}, median (IQR)', 'Levels': np.nan, f'{group}': result})
            continue

        #value_counts drops missing values, so percentages are relative to the
        #subjects that have a recorded value for this column
        counts = df[column].value_counts()
        proportions = df[column].value_counts(normalize=True)

        if column in custom_orders:
            counts = counts.reindex(custom_orders[column], fill_value=0)
            proportions = proportions.reindex(custom_orders[column], fill_value=0)

        for category in counts.index:
            count = counts[category]
            percentage = proportions[category] * 100
            result = f'{count} ({percentage:.1f})'
            summary_list.append({'Items': f'{column}, n (%)', 'Levels': category, f'{group}': result})

    #create a DataFrame from the summary list
    summary_df = pd.DataFrame(summary_list)

    return summary_df

In [None]:
def _align_group_summaries(groups, summary_dfs):
    """Join per-group summary tables on their (Items, Levels) row labels."""
    row_keys = []       #row labels in display order; a missing Levels value is stored as None
    group_values = {}   #group -> {row label: formatted value}

    for group, group_df in zip(groups, summary_dfs):
        values = {}
        previous_key = None
        for item, level, value in group_df.iloc[:, :3].itertuples(index=False, name=None):
            key = (item, None if pd.isna(level) else level)
            if key not in row_keys:
                #a row first seen in a later group is placed right after the row
                #that precedes it there, which keeps it next to its sibling levels
                position = 0 if previous_key is None else row_keys.index(previous_key) + 1
                row_keys.insert(position, key)
            values[key] = value
            previous_key = key
        group_values[group] = values

    aligned_df = pd.DataFrame({
        'Items': [item for item, _ in row_keys],
        'Levels': [np.nan if level is None else level for _, level in row_keys],
    })

    for group, values in group_values.items():
        #a categorical level that a group never has is a zero count
        aligned_df[group] = [
            values.get(key, np.nan if key[1] is None else '0 (0.0)')
            for key in row_keys
        ]

    return aligned_df


def create_table():
    """Build the demographics table for all groups, display it and save it as CSV.

    For every group returned by `get_all_subs`, the per-subject demographics
    are collected with `create_indiv_subs_df` and summarised with
    `create_summary_table`. The group summaries are then combined into one
    table with the columns 'Items', 'Levels' and one column per group, which
    is displayed and written to '../../../derivatives/demographics/table1.csv'.

    Rows are matched across groups by their (Items, Levels) labels rather than
    by position: levels without a fixed order are listed by frequency, so their
    order, and even their number, can differ between groups.
    """
    #create df of all subs with relevant demographics
    subs_all_dict = get_all_subs()

    groups = list(subs_all_dict.keys())

    summary_dfs = []

    for group in groups:
        #get individual dataframes per group with relevant demographics
        indiv_subs_df = create_indiv_subs_df(group, subs_all_dict[group])

        #create summary tables
        indiv_summary_df = create_summary_table(group, indiv_subs_df)
        summary_dfs.append(indiv_summary_df)

    #make shared summary statistics dataframe
    summary_df = _align_group_summaries(groups, summary_dfs)

    display(summary_df)

    #make sure the output folder exists before writing
    out_path = '../../../derivatives/demographics/table1.csv'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    summary_df.to_csv(out_path, index=False)

    return

    

In [None]:
# Run the full pipeline: displays the combined demographics table and writes it to table1.csv
create_table()