In [1]:
## Authors:     Michael Quillen & Max Parker; M.D. Candidates @ University of Florida
## Project:     'Towards prediction of CRC in patients under the age of 50'
## PIs:         Dr. Thomas George, MD; Dr. Jiang Bian, PhD
## 
## Base code adapted from the project of Dr. Xi Yang, PhD: 'Early Prediction of Alzheimer's Disease
##     and Related Dementias Using Electronic Health Records'

In [2]:
import numpy as np
import pandas as pd
from pathlib import Path

In [3]:
import time
import os
import sys
import re 
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, wait
from functools import partial
import mmap
import json
import pickle as pkl
import gc
import logging
from tqdm import tqdm
CPU_COUNT = 9  # NOTE(review): not referenced in the cells shown here; presumably sizes a process pool — confirm

In [4]:
def pkl_dump(data, file):
    """Pickle `data` into the file at path `file` using pickle's default protocol."""
    with open(file, mode="wb") as handle:
        pkl.dump(data, handle)

        
def pkl_load(file):
    """Unpickle and return the object stored in the file at path `file`.

    NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    """
    with open(file, "rb") as handle:
        return pkl.load(handle)


def pkl4_dump(data, file):
    """Pickle `data` into the file at path `file` using pickle.HIGHEST_PROTOCOL."""
    with open(file, mode="wb") as stream:
        pkl.dump(data, stream, protocol=pkl.HIGHEST_PROTOCOL)

        
def pkl4_load(file):
    """Unpickle and return the object stored in the file at path `file`.

    Behaves exactly like pkl_load: pickle detects the protocol automatically,
    so no protocol argument is needed here (presumably kept as the counterpart
    of pkl4_dump). Only use on trusted files — pickle.load can run arbitrary code.
    """
    with open(file, mode="rb") as stream:
        obj = pkl.load(stream)
    return obj

In [5]:
# Aggregated per-patient records. The file suffixes (01, 3, 5) correspond to the
# thc values they are encoded with further down (0/1, 3 and 5).
# NOTE: pickle.load can execute arbitrary code — only load trusted project files.
# (The stray `os.getcwd()` that used to sit here was a dead statement: its value was
# discarded because it was not the last expression of the cell.)
os.chdir('/mnt/data1/chong/2021-CRC/updated_data/agg_files')
global_data_CC_01yr = pkl_load("aggCC01.pkl")
global_data_CC_3yr = pkl_load("aggCC3.pkl")
global_data_CC_5yr = pkl_load("aggCC5.pkl")

In [6]:
# Per-patient vitals summaries, one per aggregated file; each is passed to matching() as `vit`.
os.chdir(Path('/mnt/data1/chong/2021-CRC/updated_data/agg_files'))
d_l_v01 = pkl_load("d_l_v_CC_01.pkl")
d_l_v3 = pkl_load("d_l_v_CC_3.pkl")
d_l_v5 = pkl_load("d_l_v_CC_5.pkl")

In [7]:
import datetime

def bmi_map(x):
    """Bucket a BMI value into a categorical feature name.

    <= 18.5 -> 'underweight', (18.5, 23] -> 'normal', (23, 30] -> 'overweight',
    > 30 -> 'obesity'. A value that fails every comparison (e.g. NaN) -> 'bmi_other'.
    """
    for upper_bound, label in ((18.5, 'underweight'), (23, 'normal'), (30, 'overweight')):
        if x <= upper_bound:
            return label
    if x > 30:
        return 'obesity'
    return 'bmi_other'
    
def diastolic_map(x):
    """Bucket a diastolic blood-pressure value into a categorical feature name.

    Upper bounds (inclusive): 80 Optimal, 90 Prehypertension, 100 stage 1,
    110 stage 2; above 110 is a crisis. NaN (fails every comparison) -> 'diastolic_other'.
    """
    bands = (
        (80, 'diastolic_Optimal'),
        (90, 'diastolic_Prehypertension'),
        (100, 'diastolic_Hypertension_stage_1'),
        (110, 'diastolic_Hypertension_stage_2'),
    )
    for upper_bound, label in bands:
        if x <= upper_bound:
            return label
    if x > 110:
        return 'diastolic_Hypertension_crisis'
    return 'diastolic_other'
    
def systolic_map(x):
    """Bucket a systolic blood-pressure value into a categorical feature name.

    Upper bounds (inclusive): 120 Optimal, 140 Prehypertension, 160 stage 1,
    180 stage 2; above 180 is a crisis. NaN (fails every comparison) -> 'systolic_other'.
    """
    bands = (
        (120, 'systolic_Optimal'),
        (140, 'systolic_Prehypertension'),
        (160, 'systolic_Hypertension_stage_1'),
        (180, 'systolic_Hypertension_stage_2'),
    )
    for upper_bound, label in bands:
        if x <= upper_bound:
            return label
    if x > 180:
        return 'systolic_Hypertension_crisis'
    return 'systolic_other'

In [8]:
def get_age(age):
    """Bucket an age (in years) into a categorical feature name.

    < 18 -> "age_<18", [18, 30) -> "age_18_29", [30, 40) -> "age_30_39",
    [40, 50) -> "age_40_49", >= 50 -> "age_>=50".

    A missing age (NaN) fails every comparison. It used to fall into the catch-all
    branch and be silently encoded as "age_>=50"; it now gets its own
    "age_unknown" bucket.
    """
    if age < 18:
        return "age_<18"
    if age < 30:
        return "age_18_29"
    if age < 40:
        return "age_30_39"
    if age < 50:
        return "age_40_49"
    if age >= 50:
        return "age_>=50"
    # Only reachable when every comparison above is False, i.e. age is NaN.
    return "age_unknown"
    

def get_fea_id(fea_dict, features, fea):
    """Return the integer id of feature name `fea`, registering it if unseen.

    New names are appended to `features` and assigned id len(features) after
    the append. Ids are therefore 1-based: id i corresponds to features[i - 1].
    Both `fea_dict` and `features` are mutated in place.
    """
    if fea not in fea_dict:
        features.append(fea)
        fea_dict[fea] = len(features)
    return fea_dict[fea]
    
def s2t(t):
    """Parse a 'YYYY-MM-DD' string into a datetime.datetime (midnight)."""
    date_format = "%Y-%m-%d"
    parsed = datetime.datetime.strptime(t, date_format)
    return parsed

def diff_days(d1, d2):
    """Return the number of whole days from d2 to d1 (positive when d1 is later).

    Both arguments are 'YYYY-MM-DD' strings, parsed with s2t.
    """
    return (s2t(d1) - s2t(d2)).days

In [9]:
import traceback

def get_clinic(tag, pat_data, idx, th, f2i, feas, p):
    """Encode one clinical category of a patient as a list of feature ids.

    Args:
        tag: key into `pat_data` (e.g. 'diag', 'proc', 'med_p', 'med_d', 'lab').
            It is also the feature-name prefix: value `val` becomes f"{tag}_{val}".
        pat_data: per-patient record. pat_data[tag] is expected to map a
            'YYYY-MM-DD' date string to an iterable of values recorded on that date.
        idx: the patient's index date as a 'YYYY-MM-DD' string.
        th: gap in days. An event is kept only when diff_days(idx, date) > th,
            i.e. it occurred more than `th` days before the index date. Events
            on or after the index date, or within the gap, are skipped.
        f2i, feas: feature vocabulary (name -> id dict and names list), extended
            in place via get_fea_id.
        p: patient id; used only to label error records.

    Returns:
        List of feature ids (may contain duplicates). Returns [] on any failure.

    Any exception (including a missing `tag` or an unparsable date) is recorded
    as (p, tag, formatted traceback string) in the module-level `errors` list.
    Features registered before the failure stay in the vocabulary.
    """
    try:
        events_by_date = pat_data[tag]
        ids = []
        for event_date, values in events_by_date.items():
            if diff_days(idx, event_date) <= th:
                continue
            for val in values:
                ids.append(get_fea_id(f2i, feas, f"{tag}_{val}"))
        return ids
    except Exception:
        # Best-effort: log and contribute nothing for this tag. format_exc must be
        # *called* — storing the bare function object loses the traceback text.
        errors.append((p, tag, traceback.format_exc()))
        return []

In [10]:
# Record keys read by matching(): 'index_date', 'age', 'has_CRC', 'SEX', 'Race', 'Hispanic', 'diag', 'proc', 'med_p', 'med_d', 'lab'
from tqdm import tqdm_notebook as tqdm

def _vital_ids(pv, fea_dict, feature_list):
    """Return feature ids for the categorised mean BMI / diastolic / systolic values in `pv`.

    A reading contributes only when its key is present and its value is not
    missing (per pd.isna). Readings are registered in BMI -> diastolic -> systolic
    order, which fixes the ids given to newly seen categories.
    """
    ids = []
    for key, mapper in (
        ('BMI_mean', bmi_map),
        ('DIASTOLIC_mean', diastolic_map),
        ('SYSTOLIC_mean', systolic_map),
    ):
        if key not in pv:
            continue
        value = pv[key]
        if not pd.isna(value):
            ids.append(get_fea_id(fea_dict, feature_list, mapper(value)))
    return ids


def matching(agg, vit, thc, fea_dict=None, feature_list=None):
    """Encode every patient in `agg` as [patient_id, label, *sorted unique feature ids].

    Args:
        agg: mapping patient id -> record dict. Keys read here:
            - 'index_date' and 'has_CRC' (the label);
            - 'age', 'SEX', 'Race', 'Hispanic' (SEX/Race/Hispanic are assumed
              to be strings, since they are concatenated onto a prefix);
            - the clinical tags 'med_p', 'med_d', 'diag', 'proc', 'lab'. Each is
              handled by get_clinic; a missing tag is logged to `errors` and
              contributes nothing.
        vit: mapping patient id -> dict of vitals. 'BMI_mean', 'DIASTOLIC_mean'
            and 'SYSTOLIC_mean' are used when present. Every patient in `agg`
            must have an entry (a missing one raises KeyError).
        thc: gap in years. Clinical events are kept only if they occur more than
            365*thc days before the index date. A falsy thc (0/None) uses a 1-day gap.
        fea_dict, feature_list: feature vocabulary (name -> id dict, names list)
            to extend. They default to the notebook globals `fea2id` and
            `features`, so existing calls behave exactly as before.

    Returns:
        One list per patient, in `agg` iteration order.
    """
    # Resolve the vocabulary at call time so callers that reset the globals
    # before each run keep working unchanged.
    if fea_dict is None:
        fea_dict = fea2id
    if feature_list is None:
        feature_list = features

    th = 365 * thc if thc else 1

    data_points = []
    for ii, (pid, record) in enumerate(agg.items()):
        if ii % 10000 == 0:
            print("processed ", ii)
        index_date = record['index_date']
        label = record['has_CRC']

        # Demographics: age bucket, SEX, Race, Hispanic.
        ids = [
            get_fea_id(fea_dict, feature_list, get_age(record['age'])),
            get_fea_id(fea_dict, feature_list, "SEX_" + record['SEX']),
            get_fea_id(fea_dict, feature_list, "Race_" + record['Race']),
            get_fea_id(fea_dict, feature_list, "Hispanic_" + record['Hispanic']),
        ]

        # Clinical codes. The tag order fixes the ids given to newly seen features.
        for tag in ('med_p', 'med_d', 'diag', 'proc', 'lab'):
            ids.extend(get_clinic(tag, record, index_date, th, fea_dict, feature_list, pid))

        ids.extend(_vital_ids(vit[pid], fea_dict, feature_list))

        # NOTE: an earlier (disabled) variant also encoded per-lab abnormality flags
        # read from vit keys like f"{lab}_l_abnind_{thc}y".

        data_points.append([pid, label, *sorted(set(ids))])
    return data_points

In [11]:
os.chdir(Path('/mnt/data1/chong/2021-CRC/updated_data/encoding_files'))  # encoded outputs below use paths relative to this directory

In [12]:
# Start from a fresh vocabulary and error log: matching() reads fea2id/features
# and get_clinic() appends to errors.
features, fea2id, errors = [], {}, []
d5_1fl = matching(global_data_CC_01yr, d_l_v01, thc=0)
print(len(features))
pkl_dump(d5_1fl, "./data_CC0yr_expr.pkl")
pkl_dump((fea2id, features), "./data_CC0yr_expr_features.pkl")

processed  0
8601


In [13]:
# Start from a fresh vocabulary and error log: matching() reads fea2id/features
# and get_clinic() appends to errors.
features, fea2id, errors = [], {}, []
d5_1fl = matching(global_data_CC_01yr, d_l_v01, thc=1)
print(len(features))
pkl_dump(d5_1fl, "./data_CC1yr_expr.pkl")
pkl_dump((fea2id, features), "./data_CC1yr_expr_features.pkl")

processed  0
7714


In [14]:
# Start from a fresh vocabulary and error log: matching() reads fea2id/features
# and get_clinic() appends to errors.
features, fea2id, errors = [], {}, []
d5_1fl = matching(global_data_CC_3yr, d_l_v3, thc=3)
print(len(features))
pkl_dump(d5_1fl, "./data_CC3yr_expr.pkl")
pkl_dump((fea2id, features), "./data_CC3yr_expr_features.pkl")

processed  0
5852


In [15]:
# Start from a fresh vocabulary and error log: matching() reads fea2id/features
# and get_clinic() appends to errors.
features, fea2id, errors = [], {}, []
d5_1fl = matching(global_data_CC_5yr, d_l_v5, thc=5)
print(len(features))
pkl_dump(d5_1fl, "./data_CC5yr_expr.pkl")
pkl_dump((fea2id, features), "./data_CC5yr_expr_features.pkl")

processed  0
3794
