In [1]:
import pandas as pd
import pickle
import seaborn as sns
import numpy as np
from kde_ebm.mixture_model import fit_all_gmm_models
from kde_ebm import mixture_model
from sSuStaIn.sEBMSustain import sEBMSustain, sEBMSustainData
import os
import numpy.ma as ma
from collections import Counter
from dateutil.relativedelta import relativedelta
import matplotlib.pyplot as plt
from matplotlib import cm
import pdb

In [2]:
def process_L(L, min_val=0):
    mx = ma.masked_less_equal(L,min_val)
    min_masked = mx.min(axis=0)
    L_new = mx.filled(fill_value=min_masked)
    return L_new

### Read the pickled file

In [3]:
pickle_path = "/nethome/rtandon32/ebm/s-SuStain-outputs/sim_tadpole9_mixture_GMM/pickle_files/sim_tadpole9_subtype3.pickle"
with open(pickle_path, "rb") as input_file:
    pkl = pickle.load(input_file)

### MCI and CN/AD data

In [4]:
cn_ad_path = "/home/rtandon32/ebm/ebm_experiments/experiment_scripts/real_data/dfMri_D12_ebm_final_n327.csv"
df_cnad = pd.read_csv(cn_ad_path)
mci_path = "/nethome/rtandon32/ebm/ebm_experiments/experiment_scripts/real_data/dfMri_D12_ebm_mci.csv"
df_mci = pd.read_csv(mci_path)
df_cnad["EXAMDATE"] = pd.to_datetime(df_cnad["EXAMDATE"])
df_mci["EXAMDATE"] = pd.to_datetime(df_mci["EXAMDATE"])
assert all(df_mci.columns == df_cnad.columns)

### Get the mixture model for the CN/AD subjects

In [5]:

k=119
X = df_cnad.iloc[:,:k].values
bm_names = df_cnad.columns[:k].tolist()
y = df_cnad["DX"].map({"Dementia":1, "CN":0})
mm_fit = mixture_model.fit_all_gmm_models
mixture_models = mm_fit(X, y)
L_yes = np.zeros(X.shape)
L_no = np.zeros(X.shape)
for i in range(k):
    L_no[:, i], L_yes[:, i] = mixture_models[i].pdf(None, X[:, i])

L_no = process_L(L_no)
L_yes = process_L(L_yes)



### Create a SuStaIn object to call the associated methods

In [6]:

# Create the SuStaIn object for the sEBM model
stage_sizes = [25,25,25,25,19]
N_startpoints           = 25
N_S_max                 = 3
rep = 20
N_iterations_MCMC_init = int(2e4)
N_iterations_MCMC       = int(5e5)  #Generally recommend either 1e5 or 1e6 (the latter may be slow though) in practice
n_stages = 5
min_clust_size = 8
p_absorb = 0.3
SuStaInLabels = df_cnad.columns[:k].tolist()
dataset_name            = 'sim_tadpole9'
output_dir              = '/home/rtandon32/ebm/s-SuStain-outputs'
sustainType             = 'mixture_GMM'
output_folder           = os.path.join(output_dir, dataset_name + '_' + sustainType)
use_parallel_startpoints = True
sustain = sEBMSustain(L_yes, L_no, n_stages, stage_sizes, min_clust_size, p_absorb, rep, SuStaInLabels, N_startpoints, N_S_max, N_iterations_MCMC_init, N_iterations_MCMC, output_folder, dataset_name, use_parallel_startpoints)

### Get the prob mat for the mci subjects

In [33]:
# Get the prob mat for the mci subjects
X_mci = df_mci.iloc[:,:k].values
prob_mat_mci = mixture_model.get_prob_mat(X_mci, mixture_models)
L_no_mci = prob_mat_mci[:,:,0]
L_yes_mci = prob_mat_mci[:,:,1]
L_no_mci = process_L(L_no_mci)
L_yes_mci = process_L(L_yes_mci)

### Subtype and Stage the MCI subjects


In [34]:
# Subtype and Stage the MCI subjects
last_N = 1000000
N_samples = 1000
sustainData_newData = sEBMSustainData(L_yes_mci, L_no_mci, n_stages)
samples_sequence = pkl["samples_sequence"][:,:,-last_N:]
samples_f = pkl["samples_f"][:,-last_N:]
shape_seq = pkl["shape_seq"]
temp_mean_f = np.mean(samples_f, axis=1)
ix = np.argsort(temp_mean_f)[::-1]

ml_subtype_mci, \
prob_ml_subtype_mci, \
ml_stage_mci, \
prob_ml_stage_mci, \
prob_subtype_mci, \
prob_stage_mci, \
prob_subtype_stage_mci = sustain.subtype_and_stage_individuals(sustainData_newData, shape_seq, samples_sequence, samples_f, N_samples)


### Subtype and Stage CN/AD data

In [35]:
# subtype and stage CN/AD data

sustainData_cnad = sEBMSustainData(L_yes, L_no, n_stages)
samples_sequence = pkl["samples_sequence"][:,:,-last_N:]
samples_f = pkl["samples_f"][:,-last_N:]
shape_seq = pkl["shape_seq"]

ml_subtype_cnad, \
prob_ml_subtype_cnad, \
ml_stage_cnad, \
prob_ml_stage_cnad, \
prob_subtype_cnad, \
prob_stage_cnad, \
prob_subtype_stage_cnad = sustain.subtype_and_stage_individuals(sustainData_cnad, shape_seq, samples_sequence, samples_f, N_samples)


### Prepare final dataframe which has subtype, stage, PTID, DX, and EXAMDATE for all subjects

In [36]:
# DataFrame for CN/AD subjects
array_subtype_stage_cnad = np.hstack([ml_subtype_cnad, ml_stage_cnad, y.values.reshape(-1,1)])
cnad_solved = pd.DataFrame(data=array_subtype_stage_cnad, columns=["subtype", "stage", "DX"])
cnad_solved[["PTID", "EXAMDATE"]] = df_cnad[["PTID", "EXAMDATE"]]

In [42]:
# DataFrame for MCI subjects
array_subtype_stage_mci = np.hstack([ml_subtype_mci, ml_stage_mci])
mci_solved = pd.DataFrame(data=array_subtype_stage_mci, columns=["subtype", "stage"])
mci_solved["DX"] = 2.0
mci_solved[["PTID", "EXAMDATE"]] = df_mci[["PTID", "EXAMDATE"]]
final_df = pd.concat([cnad_solved, mci_solved], axis=0)

In [43]:
final_df["DX"] = final_df["DX"].map({0.0:"Controls", 1.0:"AD", 2.0:"MCI"})


In [44]:
final_df

Unnamed: 0,subtype,stage,DX,PTID,EXAMDATE
0,1.0,2.0,Controls,002_S_0295,2006-11-02
1,1.0,1.0,Controls,002_S_0413,2007-06-01
2,2.0,3.0,AD,002_S_0619,2006-12-13
3,1.0,4.0,Controls,002_S_0685,2010-07-15
4,1.0,2.0,AD,002_S_0816,2008-01-28
...,...,...,...,...,...
546,3.0,3.0,MCI,941_S_4100,2015-08-28
547,3.0,1.0,MCI,941_S_4187,2012-02-29
548,0.0,2.0,MCI,941_S_4377,2012-08-16
549,2.0,1.0,MCI,941_S_4420,2013-03-25


# Longitudinal data analysis

In [7]:
# Longitudinal data analysis
long_path = "/nethome/rtandon32/ebm/ebm_experiments/experiment_scripts/real_data/df12_longitudinal_ebm.csv"
df_long = pd.read_csv(long_path)
cnad_ptid = df_cnad["PTID"].tolist()
mci_ptid = df_mci["PTID"].tolist()
df_long_cnad = df_long[df_long["sid"].isin(cnad_ptid)]
df_long_mci = df_long[df_long["sid"].isin(mci_ptid)]

In [8]:
k = 119
prob_mat_mci_long = mixture_model.get_prob_mat(df_long_mci.iloc[:,:k].values, mixture_models)
L_no_mci_long = prob_mat_mci_long[:,:,0]
L_yes_mci_long = prob_mat_mci_long[:,:,1]
L_yes_mci_long = process_L(L_yes_mci_long)
L_no_mci_long = process_L(L_no_mci_long)

### Subtype and Stage the MCI subjects

In [10]:

last_N = 1000000
N_samples = 1000
sustainData_mci_long = sEBMSustainData(L_yes_mci_long, L_no_mci_long, n_stages)
samples_sequence = pkl["samples_sequence"][:,:,-last_N:]
samples_f = pkl["samples_f"][:,-last_N:]
shape_seq = pkl["shape_seq"]
temp_mean_f = np.mean(samples_f, axis=1)
ix = np.argsort(temp_mean_f)[::-1]

### Get subtype and stage fir the MCI longitudinal observations

In [11]:
ml_subtype_mci, \
prob_ml_subtype_mci, \
ml_stage_mci, \
prob_ml_stage_mci, \
prob_subtype_mci, \
prob_stage_mci, \
prob_subtype_stage_mci = sustain.subtype_and_stage_individuals(sustainData_mci_long, shape_seq, samples_sequence, samples_f, N_samples)


In [14]:
mci_long_subtype_stage = np.hstack([ml_subtype_mci, ml_stage_mci])
mci_long_subtype_stage = pd.DataFrame(mci_long_subtype_stage, columns=["subtype", "stage"])
mci_long_subtype_stage[["sid", "date"]] = df_long_mci.reset_index()[["sid", "date"]]
mci_long_subtype_stage["date"] = pd.to_datetime(mci_long_subtype_stage["date"])
df_followup = mci_long_subtype_stage
df_followup = df_followup.sort_values(["sid", "date"])

In [22]:
mci_ptid = pd.unique(df_followup["sid"]).tolist()

In [65]:
st_dict = dict(zip(final_df["PTID"], final_df["subtype"]))

### Read the pdxconv file

In [67]:
pdxconv_path = "/home/rtandon32/ebm/ebm_experiments/experiment_scripts/adni_post_hoc/DXSUM_PDXCONV_ADNIALL_25Jan2024.csv"
df_pdxconv = pd.read_csv(pdxconv_path)
df_pdxconv["EXAMDATE"] = pd.to_datetime(df_pdxconv["EXAMDATE"])
df_pdxconv_mci =  df_pdxconv[df_pdxconv["PTID"].isin(mci_ptid)]


In [69]:
df_pdxconv_mci["cs_subtype"] = df_pdxconv_mci["PTID"].map(st_dict)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_pdxconv_mci["cs_subtype"] = df_pdxconv_mci["PTID"].map(st_dict)


In [70]:
df_pdxconv_mci

Unnamed: 0,Phase,ID,RID,PTID,SITEID,VISCODE,VISCODE2,VISDATE,USERDATE,USERDATE2,...,DXPDES,DXPCOG,DXPATYP,DXDEP,DXOTHDEM,DXODES,DXCONFID,DIAGNOSIS,update_stamp,cs_subtype
10,ADNI1,22,22,011_S_0022,107,bl,bl,2005-11-01,2005-11-02,,...,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2005-11-02 00:00:00.0,0.0
15,ADNI1,32,41,007_S_0041,2,bl,bl,2005-11-14,2005-11-14,,...,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2005-11-14 00:00:00.0,1.0
27,ADNI1,56,33,035_S_0033,6,bl,bl,2005-12-09,2005-12-21,,...,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2005-12-21 00:00:00.0,2.0
30,ADNI1,62,51,099_S_0051,45,bl,bl,2005-12-29,2006-01-03,,...,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2006-01-03 00:00:00.0,2.0
37,ADNI1,76,101,007_S_0101,2,bl,bl,2006-01-05,2006-01-09,,...,-4.0,-4.0,-4.0,,-4.0,-4.0,3.0,,2006-01-09 00:00:00.0,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13438,ADNI3,175136,4513,041_S_4513,28,y5,m132,2023-04-17,2023-05-17,2023-05-17,...,,,,0.0,,,,2.0,2023-05-18 04:20:27.0,0.0
13468,ADNI3,176432,4514,126_S_4514,46,y5,m138,2023-06-27,2023-06-27,2023-06-27,...,,,,1.0,,,,2.0,2023-06-28 04:22:25.0,3.0
13485,ADNI3,177571,2249,052_S_2249,30,init,m90,2018-08-27,2023-08-01,2023-08-01,...,,,,0.0,,,,2.0,2023-08-02 04:21:22.0,3.0
13506,ADNI3,179204,4210,127_S_4210,47,y5,m132,2022-10-04,2023-11-28,2023-11-28,...,,,,,,,,,2023-11-29 04:24:50.0,0.0


In [106]:
def counts_from_table(table, ptid, fields):
    # table - dataframe
    # fields - list of columns
    # ptid - list of PTID (subjects)
    # vals - dictionary, where keys 
    # are the fields and vals are a 
    # list of their possible values
    ptid_dict = {}
    for pt in ptid:
        table_pt = table[table["PTID"].isin([pt])][fields]
        f_dict = {}
        for f in fields:
            f_dict[f] = table_pt[f].unique()
        ptid_dict[pt] = f_dict
    return ptid_dict

In [109]:
ptid_dict = counts_from_table(df_pdxconv_mci, mci_ptid, ["DXMCI", "DXOTHDEM"])

In [110]:
ptid_dict

{'002_S_0729': {'DXMCI': array([ 1., -4., nan]),
  'DXOTHDEM': array([-4., nan])},
 '002_S_0782': {'DXMCI': array([1.]), 'DXOTHDEM': array([-4.])},
 '002_S_0954': {'DXMCI': array([ 1., -4.]), 'DXOTHDEM': array([-4.])},
 '002_S_1155': {'DXMCI': array([ 1., nan]), 'DXOTHDEM': array([-4., nan])},
 '002_S_1268': {'DXMCI': array([ 1., nan]), 'DXOTHDEM': array([-4., nan])},
 '002_S_2073': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4229': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4237': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4251': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4447': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4473': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4521': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4654': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '002_S_4746': {'DXMCI': array([nan]), 'DXOTHDEM': array([nan])},
 '003_S_0908': {'DXMCI': array([ 1.

In [100]:
# st = 0.0
# df_pdxconv_mci_st = df_pdxconv_mci[df_pdxconv_mci["cs_subtype"]==st]
table = df_pdxconv_mci.pivot_table(index='cs_subtype', columns='DXMDUE', 
                        aggfunc=len, fill_value=0)["PTID"]
table

DXMDUE,-4.0,1.0,2.0
cs_subtype,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0.0,124,996,38
1.0,20,906,54
2.0,37,485,23
3.0,43,635,56


In [98]:
table.sum().sum()

4062

In [48]:
long_to_cs = df_followup.sort_values(["sid", "date"]).drop_duplicates(subset=["sid"])

In [56]:
df_followup.sort_values(["sid", "date"])

Unnamed: 0,subtype,stage,sid,date
1683,0.0,3.0,002_S_0729,2007-02-22
1339,0.0,3.0,002_S_0729,2007-09-05
1186,0.0,3.0,002_S_0729,2008-09-29
1433,0.0,3.0,002_S_0729,2009-08-13
721,3.0,2.0,002_S_0782,2007-04-11
...,...,...,...,...
211,2.0,1.0,941_S_4420,2013-03-25
1484,1.0,4.0,941_S_4764,2012-06-01
1416,1.0,4.0,941_S_4764,2012-08-20
2279,1.0,4.0,941_S_4764,2013-06-10


In [57]:
final_df[final_df["DX"].isin(["MCI"])]

Unnamed: 0,subtype,stage,DX,PTID,EXAMDATE
0,0.0,3.0,MCI,002_S_0729,2007-02-22
1,3.0,2.0,MCI,002_S_0782,2008-10-17
2,0.0,4.0,MCI,002_S_0954,2007-05-03
3,1.0,1.0,MCI,002_S_1155,2012-12-20
4,3.0,3.0,MCI,002_S_1268,2007-09-21
...,...,...,...,...,...
546,3.0,3.0,MCI,941_S_4100,2015-08-28
547,3.0,1.0,MCI,941_S_4187,2012-02-29
548,0.0,2.0,MCI,941_S_4377,2012-08-16
549,2.0,1.0,MCI,941_S_4420,2013-03-25


In [58]:
long_to_cs.merge(final_df[final_df["DX"].isin(["MCI"])], left_on=["sid", "date"], right_on=["PTID", "EXAMDATE"], how="inner")

Unnamed: 0,subtype_x,stage_x,sid,date,subtype_y,stage_y,DX,PTID,EXAMDATE
0,0.0,3.0,002_S_0729,2007-02-22,0.0,3.0,MCI,002_S_0729,2007-02-22
1,0.0,4.0,002_S_0954,2007-05-03,0.0,4.0,MCI,002_S_0954,2007-05-03
2,0.0,4.0,006_S_0675,2007-04-26,0.0,4.0,MCI,006_S_0675,2007-04-26
3,2.0,3.0,007_S_0101,2007-01-29,2.0,3.0,MCI,007_S_0101,2007-01-29
4,0.0,3.0,007_S_0128,2006-08-14,0.0,3.0,MCI,007_S_0128,2006-08-14
...,...,...,...,...,...,...,...,...,...
78,1.0,1.0,137_S_0800,2011-08-30,1.0,1.0,MCI,137_S_0800,2011-08-30
79,3.0,2.0,137_S_0973,2007-05-22,3.0,2.0,MCI,137_S_0973,2007-05-22
80,1.0,1.0,137_S_4623,2013-05-13,1.0,1.0,MCI,137_S_4623,2013-05-13
81,3.0,1.0,137_S_4631,2013-06-17,3.0,1.0,MCI,137_S_4631,2013-06-17


In [30]:
df_followup.sort_values(["sid", "date"])

Unnamed: 0,subtype,stage,sid,date
1683,0.0,3.0,002_S_0729,2007-02-22
1339,0.0,3.0,002_S_0729,2007-09-05
1186,0.0,3.0,002_S_0729,2008-09-29
1433,0.0,3.0,002_S_0729,2009-08-13
721,3.0,2.0,002_S_0782,2007-04-11
...,...,...,...,...
211,2.0,1.0,941_S_4420,2013-03-25
1484,1.0,4.0,941_S_4764,2012-06-01
1416,1.0,4.0,941_S_4764,2012-08-20
2279,1.0,4.0,941_S_4764,2013-06-10


In [29]:
pd.merge(df_followup, df_pdxconv, left_on="sid", right_on="PTID", how="inner")

Unnamed: 0,subtype,stage,sid,date,Phase,ID,RID,PTID,SITEID,VISCODE,...,DXPARK,DXPDES,DXPCOG,DXPATYP,DXDEP,DXOTHDEM,DXODES,DXCONFID,DIAGNOSIS,update_stamp
0,0.0,3.0,002_S_0729,2007-02-22,ADNI1,786,729,002_S_0729,101,bl,...,-4.0,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2006-08-03 00:00:00.0
1,0.0,3.0,002_S_0729,2007-02-22,ADNI1,2418,729,002_S_0729,101,m06,...,-4.0,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2007-03-21 00:00:00.0
2,0.0,3.0,002_S_0729,2007-02-22,ADNI1,3832,729,002_S_0729,101,m12,...,-4.0,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2007-09-27 00:00:00.0
3,0.0,3.0,002_S_0729,2007-02-22,ADNI1,5634,729,002_S_0729,101,m18,...,-4.0,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2008-08-08 00:00:00.0
4,0.0,3.0,002_S_0729,2007-02-22,ADNI1,5884,729,002_S_0729,101,m24,...,-4.0,-4.0,-4.0,-4.0,,-4.0,-4.0,4.0,,2008-10-03 00:00:00.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17340,1.0,4.0,941_S_4764,2014-06-16,ADNI2,3404,4764,941_S_4764,58,v03,...,0.0,,,,0.0,,,,,2013-03-14 19:16:46.0
17341,1.0,4.0,941_S_4764,2014-06-16,ADNI2,4620,4764,941_S_4764,58,v05,...,0.0,,,,0.0,,,,,2013-03-14 19:16:47.0
17342,1.0,4.0,941_S_4764,2014-06-16,ADNI2,6588,4764,941_S_4764,58,v11,...,0.0,,,,0.0,,,,,2013-07-01 19:16:37.0
17343,1.0,4.0,941_S_4764,2014-06-16,ADNI2,8884,4764,941_S_4764,58,v21,...,0.0,,,,0.0,,,,,2014-07-03 19:16:58.0


In [59]:
pdxconv_path = "/home/rtandon32/ebm/ebm_experiments/experiment_scripts/adni_post_hoc/DXSUM_PDXCONV_ADNIALL_25Jan2024.csv"

In [26]:
df_pdxconv.shape

(13514, 45)