In [8]:
import sys 
sys.path.insert(0, '../')
from CPRD.spark import spark_init,read_txt,read_parquet, read_csv
from CPRD.table import Patient,Practice,Clinical,Diagnosis,Hes, Therapy, cvt_str2time, EHR
import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql.types import *
import random
import numpy as np
import datetime
import pandas as pd
from utils.utils import save_obj, load_obj
import matplotlib.pyplot as plt
import random

In [2]:
def create_vocab(code_cnt, symbols=None):
    """Build token<->index lookup tables.

    Special symbols take the first indices, followed by every value of ``code_cnt.code`` in order.

    Args:
        code_cnt: pandas DataFrame with a ``code`` column (e.g. per-code counts).
        symbols: special tokens placed first; defaults to PAD, UNK, SEP, CLS, MASK.

    Returns:
        (token2idx, idx2token). If a token occurs twice, token2idx keeps its last index.
    """
    if symbols is None:
        symbols = ["PAD", "UNK", "SEP", "CLS", "MASK"]

    vocabulary = [str(sym) for sym in symbols] + [str(code) for code in code_cnt.code.values]

    token2idx = {}
    idx2token = {}
    for position, token in enumerate(vocabulary):
        token2idx[token] = position
        idx2token[position] = token
    return token2idx, idx2token

def age_vocab(max_age, mon=1, symbol=None):
    """Build an age vocabulary.

    Args:
        max_age: number of years covered.
        mon: granularity. 12 -> one token per year ('0' .. str(max_age - 1));
            1 -> one token per month ('0' .. str(max_age * 12 - 1)).
            Any other value gives (None, None).
        symbol: special tokens placed first; defaults to ['PAD', 'UNK'].

    Returns:
        (age2idx, idx2age) dicts, or (None, None) for an unsupported ``mon``.
    """
    specials = ['PAD', 'UNK'] if symbol is None else symbol
    special_tokens = [str(sym) for sym in specials]

    if mon == 12:
        n_ages = max_age
    elif mon == 1:
        n_ages = max_age * 12
    else:
        return None, None

    tokens = special_tokens + [str(age) for age in range(n_ages)]
    age2idx = {token: idx for idx, token in enumerate(tokens)}
    idx2age = dict(enumerate(tokens))
    return age2idx, idx2age

# yearly age vocabulary (mon=12): PAD, UNK, then one token per year '0'..'109'
age2idx, idx2age = age_vocab(max_age=110, mon=12, symbol=None)

# persisted with the same 'token2idx' / 'idx2token' layout used for the code vocabulary
data = dict(token2idx=age2idx, idx2token=idx2age)
save_obj(data, '/home/shared/yikuan/HiBEHRT/data/dict4age')

# Data Preprocessing

In [3]:
def rename_col(df, old, new):
    """Return ``df`` (a PySpark DataFrame) with column ``old`` renamed to ``new``."""
    renamed = df.withColumnRenamed(old, new)
    return renamed

# construct representation learning dataset 
def check_time(df, col, time_a=1985, time_b=2015):
    """Keep rows whose date column ``col`` falls in calendar years time_a..time_b (both inclusive).

    Args:
        df: PySpark DataFrame.
        col: name (or Column) of a date/timestamp column.
        time_a: first calendar year kept.
        time_b: last calendar year kept.

    Returns:
        The filtered DataFrame with its original columns; rows with a null date are dropped.
    """
    # Native F.year is integer-typed and null-safe. A Python UDF (`lambda x: x.year`) would return
    # strings, fail the whole job on a null date, and a helper column named 'Y' would overwrite
    # (and then drop) any existing column of that name.
    return df.filter(F.year(col).between(time_a, time_b))

In [4]:
# initialise Spark; the returned object's `sc` and `sqlContext` are used by every later cell
spark = spark_init()

# SSL: self-supervised learning dataset (records from 1985-2005)

In [None]:
# patients eligible for the self-supervised dataset; patid is renamed so it does not clash with
# the event tables' patid when joining
rep_pat = rename_col(
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/pat_rep.parquet'),
    'patid', 'patid_eligible')

# diagnoses 1985-2005 for eligible patients only (inner join == left join + non-null filter)
diag = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/diagnoses.parquet')
diag = check_time(diag, 'eventdate', time_a=1985, time_b=2005)
diag = diag.join(rep_pat, diag.patid == rep_pat.patid_eligible, 'inner').drop('patid_eligible')

def select_diagnose(diagnoses):
    """Filter diagnosis rows by coding system and code prefix; return (patid, eventdate, code).

    ICD rows are dropped when the code starts with Z, V, R, U, X or Y. Read rows are dropped
    when the code starts with a digit, Z or U. Rows of any other ``type`` are discarded.
    """
    kept_columns = ['patid', 'eventdate', 'code', 'type', 'source']

    def _keep(code_type, excluded_first_chars):
        # rows of one coding system whose first character is not excluded
        first_char = F.col('code').substr(0, 1)
        return diagnoses.filter(F.col('type') == code_type) \
                        .filter(~first_char.isin(*excluded_first_chars)) \
                        .select(kept_columns)

    icd_rows = _keep('icd', ['Z', 'V', 'R', 'U', 'X', 'Y'])
    read_rows = _keep('read', ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'Z', 'U'])
    return icd_rows.union(read_rows).select(['patid', 'eventdate', 'code'])

diag = select_diagnose(diag)  # patid, eventdate, code
print('diag:', diag.schema)

# medication 1985-2005 for eligible patients
med = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/medication.parquet/')
med = check_time(med, 'eventdate', time_a=1985, time_b=2005)
med = med.join(rep_pat, med.patid == rep_pat.patid_eligible, 'inner').drop('patid_eligible')  # patid, eventdate, code
print('med:', med.schema)

# HES procedures 1985-2005 for eligible patients; the OPCS column becomes the token
procedure = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/hes_procedure.parquet/')
procedure = check_time(procedure, 'eventdate', time_a=1985, time_b=2005)
procedure = procedure.join(rep_pat, procedure.patid == rep_pat.patid_eligible, 'inner') \
                     .drop('patid_eligible') \
                     .withColumnRenamed('OPCS', 'code')  # patid, eventdate, code
print('procedure:', procedure.schema)

# CPRD tests 1985-2005 for eligible patients; the entity type (enttype) becomes the token
test = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/cprd_test.parquet/')
test = check_time(test, 'eventdate', time_a=1985, time_b=2005)
test = test.join(rep_pat, test.patid == rep_pat.patid_eligible, 'inner') \
           .select(['patid', 'eventdate', 'enttype']) \
           .withColumnRenamed('enttype', 'code')
print('test:', test.schema)

# Measurements are bucketed into string tokens with native Spark expressions. A Python UDF such as
# `lambda x: int(x // 5)` fails the whole job on a null value. These expressions are null-safe and
# give the same strings for non-null input (e.g. BMI 25.7 -> '25', bp 123 -> '24').
def round_bmi(value_col):
    """Floor the numeric BMI column named ``value_col``; returns a string Column."""
    return F.floor(value_col).cast('string')

def round_bp(value_col):
    """Bucket the numeric blood-pressure column named ``value_col`` into width-5 bins, floor(x / 5), as a string Column."""
    return F.floor(F.col(value_col) / 5).cast('string')

# bmi
bmi = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/bmi.parquet/')
bmi = check_time(bmi, 'eventdate', time_a=1985, time_b=2005)
bmi = bmi.join(rep_pat, bmi.patid==rep_pat.patid_eligible, 'left')
bmi = bmi.filter(F.col('patid_eligible').isNotNull()).drop('patid_eligible')

# Drop missing values. A null token would later be skipped by collect_list('code') while its age is
# kept, misaligning the code and age sequences.
bmi = bmi.filter(F.col('BMI').isNotNull())
bmi = bmi.withColumn('BMI', round_bmi('BMI'))
bmi = rename_col(bmi, 'BMI', 'code')

print('bmi:',bmi.schema)

# bp_low
bp_low = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/bp_low.parquet/')
bp_low = check_time(bp_low, 'eventdate', time_a=1985, time_b=2005)
bp_low = bp_low.join(rep_pat, bp_low.patid==rep_pat.patid_eligible, 'left')
bp_low = bp_low.filter(F.col('patid_eligible').isNotNull()).drop('patid_eligible')

bp_low = bp_low.filter(F.col('bp_low').isNotNull())
bp_low = bp_low.withColumn('bp_low', round_bp('bp_low'))
bp_low = rename_col(bp_low, 'bp_low', 'code')

print('bp_low:',bp_low.schema)

# bp_high
bp_high = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/bp_high.parquet/')
bp_high = check_time(bp_high, 'eventdate', time_a=1985, time_b=2005)
bp_high = bp_high.join(rep_pat, bp_high.patid==rep_pat.patid_eligible, 'left')
bp_high = bp_high.filter(F.col('patid_eligible').isNotNull()).drop('patid_eligible')

bp_high = bp_high.filter(F.col('bp_high').isNotNull())
bp_high = bp_high.withColumn('bp_high', round_bp('bp_high'))
bp_high = rename_col(bp_high, 'bp_high', 'code')

print('bp_high:', bp_high.schema)

# smoking and alcohol records for eligible patients, 1985-2005; the value column becomes the token
smoke = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/smoke.parquet/')
smoke = check_time(smoke, 'eventdate', time_a=1985, time_b=2005)
smoke = smoke.join(rep_pat, smoke.patid == rep_pat.patid_eligible, 'inner') \
             .drop('patid_eligible') \
             .withColumnRenamed('smoke', 'code')
print('smoke:', smoke.schema)

alcohol = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/alcohol.parquet/')
alcohol = check_time(alcohol, 'eventdate', time_a=1985, time_b=2005)
alcohol = alcohol.join(rep_pat, alcohol.patid == rep_pat.patid_eligible, 'inner') \
                 .drop('patid_eligible') \
                 .withColumnRenamed('alcohol', 'code')
print('alcohol:', alcohol.schema)

# prefix each code with a 3-letter modality tag so equal raw codes from different sources stay distinct
def _tag_codes(df, tag):
    """Return ``df`` with ``tag`` prepended to its 'code' column."""
    return df.withColumn('code', F.concat(F.lit(tag), F.col('code')))

diag = _tag_codes(diag, 'DIA')
med = _tag_codes(med, 'MED')
procedure = _tag_codes(procedure, 'PRO')
test = _tag_codes(test, 'TES')
bmi = _tag_codes(bmi, 'BMI')
bp_low = _tag_codes(bp_low, 'BPL')
bp_high = _tag_codes(bp_high, 'BPH')
smoke = _tag_codes(smoke, 'SMO')
alcohol = _tag_codes(alcohol, 'ALC')

def format_age(path_to_demo, data):
    """Attach an 'age' column computed from each patient's year of birth.

    Args:
        path_to_demo: path of the CPRD Patient text extract.
        data: event DataFrame with at least patid, eventdate and code.

    Returns:
        DataFrame with columns patid, eventdate, code, age. The age comes from
        ``EHR.cal_age(..., year=True)``. Events whose patient is missing from the demographic
        table are kept (left join); presumably their age is null — TODO confirm in EHR.cal_age.
    """
    patients = Patient(read_txt(spark.sc, spark.sqlContext, path=path_to_demo))
    demographic = patients.accept_flag() \
        .yob_calibration() \
        .cvt_crd2date() \
        .cvt_tod2date() \
        .cvt_deathdate2date() \
        .get_pracid() \
        .drop('accept') \
        .select(['patid', 'yob'])

    with_yob = data.join(demographic, data.patid == demographic.patid, 'left').drop(demographic.patid)
    return EHR(with_yob).cal_age('eventdate', 'yob', year=True, name='age') \
        .select(['patid', 'eventdate', 'code', 'age'])

# every modality takes year of birth from the same CPRD Patient extract
patient_path = '/home/workspace/datasets/cprd/cuts/02_cprd2015/1_RawData/Patient'
diag = format_age(patient_path, diag)
med = format_age(patient_path, med)
procedure = format_age(patient_path, procedure)
test = format_age(patient_path, test)
bmi = format_age(patient_path, bmi)
bp_low = format_age(patient_path, bp_low)
bp_high = format_age(patient_path, bp_high)
smoke = format_age(patient_path, smoke)
alcohol = format_age(patient_path, alcohol)

# stack all modalities; format_age has already aligned every frame to (patid, eventdate, code, age)
records = (diag.union(med)
               .union(procedure)
               .union(test)
               .union(bmi)
               .union(bp_low)
               .union(bp_high)
               .union(smoke)
               .union(alcohol))

# # create dictionary
# encounter = records.groupBy('code').count().toPandas()
# encounter = encounter[encounter['count']>1000]

# def create_vocab(code_cnt, symbols=None):
#     if symbols is None:
#         symbols = ["PAD", "UNK", "SEP", "CLS", "MASK"]

#     # initialize dictionaries
#     token2idx = {}
#     idx2token = {}

#     # set up predefined symbols
#     for i in range(len(symbols)):
#         token2idx[str(symbols[i])] = i
#         idx2token[i] = str(symbols[i])

#     # add all the tokens into the dictionary
#     token = code_cnt.code.values
#     for i in range(len(token)):
#         idx = i + len(symbols)
#         token2idx[str(token[i])] = idx
#         idx2token[idx] = str(token[i])

#     return token2idx, idx2token

# token2idx, idx2token = create_vocab(encounter)
# # save vocab
# data = {
#     'token2idx': token2idx,
#     'idx2token': idx2token
# }

# save_obj(data, '/home/shared/yikuan/HiBEHRT/data/dict4all')

# del encounter
# del data

def format_sequence(data):
    """Collapse event rows into one date-ordered code/age sequence per patient.

    Args:
        data: DataFrame with columns patid, eventdate, code, age.

    Returns:
        DataFrame (patid, code, age). ``code`` is the patient's flattened token sequence, where each
        visit (one eventdate) gets a 'SEP' token added by EHR.array_add_element (presumably appended
        at its end). ``age`` is kept the same length as ``code``.
    """
    # one row per (patient, date): all codes recorded that day and their ages
    data = data.groupby(['patid', 'eventdate']).agg(F.collect_list('code').alias('code'), F.collect_list('age').alias('age'))

    data = EHR(data).array_add_element('code', 'SEP')
    # add extra age to fill the gap of sep (the visit's first age)
    extract_age = F.udf(lambda x: x[0])
    data = data.withColumn('age_temp', extract_age('age')).withColumn('age', F.concat(F.col('age'),F.array(F.col('age_temp')))).drop('age_temp')

    # Gather each patient's visits in date order. sort_array compares the structs field by field, so
    # eventdate (unique per patient after the groupby above) decides the order. A running
    # collect_list window followed by max() gives the same result but materialises every prefix of
    # the visit list, which is quadratic in the number of visits.
    visits = F.sort_array(F.collect_list(F.struct('eventdate', 'code', 'age')))
    data = data.groupBy('patid').agg(visits.alias('visits')) \
               .select('patid', F.col('visits.code').alias('code'), F.col('visits.age').alias('age'))
    data = EHR(data).array_flatten('code').array_flatten('age') # patid, code, age
    return data

records = format_sequence(records)
 
def format_seg_and_position_code(data):
    """Add per-token 'seg' and 'position' arrays derived from the 'SEP' tokens in ``code``.

    Each token, including the SEP that closes a visit, gets the index of its visit. The index
    starts at 0 and advances after every SEP. ``position`` is that index and ``seg`` is the index
    modulo 2 (alternating 0/1). Both columns are declared as arrays of strings.
    """
    def seg_records(tokens):
        segments = []
        current = 0
        for token in tokens:
            segments.append(current)
            if token == 'SEP':
                current = 1 - current
        return segments

    def posi_records(tokens):
        positions = []
        visit = 0
        for token in tokens:
            positions.append(visit)
            if token == 'SEP':
                visit += 1
        return positions

    seg = F.udf(seg_records, ArrayType(StringType(), True))
    posi = F.udf(posi_records, ArrayType(StringType(), True))
    return data.withColumn('seg', seg('code')).withColumn('position', posi('code'))

records = format_seg_and_position_code(records)

from pyspark.sql.types import *

# return schema of remove_sep: patid plus four parallel string arrays
schema = StructType(
    [StructField('patid', StringType(), True)]
    + [StructField(field, ArrayType(StringType(), True), True) for field in ('code', 'age', 'seg', 'position')]
)

def remove_sep(patid, code, age, seg, position):
    """Drop 'SEP' tokens while keeping the parallel age/seg/position lists aligned.

    Args:
        patid: patient id, passed through unchanged.
        code, age, seg, position: parallel per-token lists (same length as ``code``).

    Returns:
        (patid, code, age, seg, position) with every index where ``code`` is 'SEP' removed from
        all four lists. Visit boundaries remain encoded in ``seg`` / ``position``.
    """
    # Compute the surviving indices once and filter every list with them. Previously the filtered
    # lists were built but the unfiltered inputs were returned.
    keep = [i for i, token in enumerate(code) if token != 'SEP']
    return (patid,
            [code[i] for i in keep],
            [age[i] for i in keep],
            [seg[i] for i in keep],
            [position[i] for i in keep])
        

# run remove_sep on every row, then unpack the returned struct into top-level columns
test_udf = F.udf(remove_sep, schema)
records = records.select(test_udf('patid', 'code', 'age', 'seg', 'position').alias("test")) \
                 .select("test.*")

records.write.parquet('/home/shared/yikuan/HiBEHRT/data/selfsupervise.parquet')
print('end')

# Merge patient summaries from 1985-2005 and 2005-2015, excluding patients who were already positive within 1985-2005

In [10]:
patient_summary_1985_2005 = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/HF/HF_1985_2005.parquet/')
patient_summary_2005_2015 = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/HF/HF_2005_2015.parquet/')

# suffix every summary column with its period so both periods can sit side by side after the join
name_list = ['patid', 'firstDate', 'lastDate', 'label', 'HFDate']
for name in name_list:
    patient_summary_1985_2005 = rename_col(patient_summary_1985_2005, name, '{}_1985_2005'.format(name))
    patient_summary_2005_2015 = rename_col(patient_summary_2005_2015, name, '{}_2005_2015'.format(name))

# every 2005-2015 patient, with their 1985-2005 summary when one exists
patient_summary = patient_summary_2005_2015.join(
    patient_summary_1985_2005,
    patient_summary_2005_2015.patid_2005_2015 == patient_summary_1985_2005.patid_1985_2005,
    'left')

# patients absent from 1985-2005: every field comes from 2005-2015
patient_summary_new = patient_summary.filter(F.col('patid_1985_2005').isNull()) \
    .select([F.col('{}_2005_2015'.format(col_name)).alias(col_name) for col_name in name_list])

# patients present in 1985-2005: drop those already positive (label == 1) there, keep the
# 1985-2005 firstDate, and take lastDate, label and HFDate from 2005-2015
patient_summary_old = patient_summary.filter(F.col('patid_1985_2005').isNotNull()) \
    .filter(F.col('label_1985_2005') != 1) \
    .select([F.col('{}_{}'.format(col_name, '1985_2005' if col_name == 'firstDate' else '2005_2015')).alias(col_name)
             for col_name in name_list])

patient_summary = patient_summary_old.union(patient_summary_new)
patient_summary.write.parquet('/home/shared/yikuan/HiBEHRT/data/HF/1985_2005/test/HF_1985_2015.parquet')

# format training dataset

In [24]:
# patients eligible for the downstream dataset; patid renamed to avoid a clash on join
# NOTE(review): the file name reads 'pat_downstrean' — confirm it matches the file on disk
rep_pat = rename_col(
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/pat_downstrean.parquet'),
    'patid', 'patid_eligible')

# calendar-year window applied to every modality in this cell
start_year, end_year = 1985, 2015

# diagnoses for eligible patients within the window (inner join == left join + non-null filter)
diag = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/diagnoses.parquet')
diag = check_time(diag, 'eventdate', time_a=start_year, time_b=end_year)
diag = diag.join(rep_pat, diag.patid == rep_pat.patid_eligible, 'inner').drop('patid_eligible')

def select_diagnose(diagnoses):
    """Filter diagnosis rows by coding system and code prefix; return (patid, eventdate, code).

    ICD rows are dropped when the code starts with Z, V, R, U, X or Y. Read rows are dropped
    when the code starts with a digit, Z or U. Rows of any other ``type`` are discarded.
    """
    kept_columns = ['patid', 'eventdate', 'code', 'type', 'source']

    def _keep(code_type, excluded_first_chars):
        # rows of one coding system whose first character is not excluded
        first_char = F.col('code').substr(0, 1)
        return diagnoses.filter(F.col('type') == code_type) \
                        .filter(~first_char.isin(*excluded_first_chars)) \
                        .select(kept_columns)

    icd_rows = _keep('icd', ['Z', 'V', 'R', 'U', 'X', 'Y'])
    read_rows = _keep('read', ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'Z', 'U'])
    return icd_rows.union(read_rows).select(['patid', 'eventdate', 'code'])

diag = select_diagnose(diag)  # patid, eventdate, code
print('diag:', diag.schema)

# medication for eligible patients within the window
med = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/medication.parquet/')
med = check_time(med, 'eventdate', time_a=start_year, time_b=end_year)
med = med.join(rep_pat, med.patid == rep_pat.patid_eligible, 'inner').drop('patid_eligible')  # patid, eventdate, code
print('med:', med.schema)

# HES procedures for eligible patients within the window; the OPCS column becomes the token
procedure = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/hes_procedure.parquet/')
procedure = check_time(procedure, 'eventdate', time_a=start_year, time_b=end_year)
procedure = procedure.join(rep_pat, procedure.patid == rep_pat.patid_eligible, 'inner') \
                     .drop('patid_eligible') \
                     .withColumnRenamed('OPCS', 'code')  # patid, eventdate, code
print('procedure:', procedure.schema)

# CPRD tests for eligible patients within the window; the entity type (enttype) becomes the token
test = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/cprd_test.parquet/')
test = check_time(test, 'eventdate', time_a=start_year, time_b=end_year)
test = test.join(rep_pat, test.patid == rep_pat.patid_eligible, 'inner') \
           .select(['patid', 'eventdate', 'enttype']) \
           .withColumnRenamed('enttype', 'code')
print('test:', test.schema)

# Measurements are bucketed into string tokens with native Spark expressions. A Python UDF such as
# `lambda x: int(x // 5)` fails the whole job on a null value. These expressions are null-safe and
# give the same strings for non-null input (e.g. BMI 25.7 -> '25', bp 123 -> '24').
def round_bmi(value_col):
    """Floor the numeric BMI column named ``value_col``; returns a string Column."""
    return F.floor(value_col).cast('string')

def round_bp(value_col):
    """Bucket the numeric blood-pressure column named ``value_col`` into width-5 bins, floor(x / 5), as a string Column."""
    return F.floor(F.col(value_col) / 5).cast('string')

# bmi
bmi = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/bmi.parquet/')
bmi = check_time(bmi, 'eventdate', time_a=start_year, time_b=end_year)
bmi = bmi.join(rep_pat, bmi.patid==rep_pat.patid_eligible, 'left')
bmi = bmi.filter(F.col('patid_eligible').isNotNull()).drop('patid_eligible')

# Drop missing values. A null token would later be skipped by collect_list('code') while its age is
# kept, misaligning the code and age sequences.
bmi = bmi.filter(F.col('BMI').isNotNull())
bmi = bmi.withColumn('BMI', round_bmi('BMI'))
bmi = rename_col(bmi, 'BMI', 'code')

print('bmi:',bmi.schema)

# bp_low
bp_low = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/bp_low.parquet/')
bp_low = check_time(bp_low, 'eventdate', time_a=start_year, time_b=end_year)
bp_low = bp_low.join(rep_pat, bp_low.patid==rep_pat.patid_eligible, 'left')
bp_low = bp_low.filter(F.col('patid_eligible').isNotNull()).drop('patid_eligible')

bp_low = bp_low.filter(F.col('bp_low').isNotNull())
bp_low = bp_low.withColumn('bp_low', round_bp('bp_low'))
bp_low = rename_col(bp_low, 'bp_low', 'code')

print('bp_low:',bp_low.schema)

# bp_high
bp_high = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/bp_high.parquet/')
bp_high = check_time(bp_high, 'eventdate', time_a=start_year, time_b=end_year)
bp_high = bp_high.join(rep_pat, bp_high.patid==rep_pat.patid_eligible, 'left')
bp_high = bp_high.filter(F.col('patid_eligible').isNotNull()).drop('patid_eligible')

bp_high = bp_high.filter(F.col('bp_high').isNotNull())
bp_high = bp_high.withColumn('bp_high', round_bp('bp_high'))
bp_high = rename_col(bp_high, 'bp_high', 'code')

print('bp_high:', bp_high.schema)

# smoking and alcohol records for eligible patients within the window; the value column becomes the token
smoke = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/smoke.parquet/')
smoke = check_time(smoke, 'eventdate', time_a=start_year, time_b=end_year)
smoke = smoke.join(rep_pat, smoke.patid == rep_pat.patid_eligible, 'inner') \
             .drop('patid_eligible') \
             .withColumnRenamed('smoke', 'code')
print('smoke:', smoke.schema)

alcohol = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/alcohol.parquet/')
alcohol = check_time(alcohol, 'eventdate', time_a=start_year, time_b=end_year)
alcohol = alcohol.join(rep_pat, alcohol.patid == rep_pat.patid_eligible, 'inner') \
                 .drop('patid_eligible') \
                 .withColumnRenamed('alcohol', 'code')
print('alcohol:', alcohol.schema)

# prefix each code with a 3-letter modality tag so equal raw codes from different sources stay distinct
def _tag_codes(df, tag):
    """Return ``df`` with ``tag`` prepended to its 'code' column."""
    return df.withColumn('code', F.concat(F.lit(tag), F.col('code')))

diag = _tag_codes(diag, 'DIA')
med = _tag_codes(med, 'MED')
procedure = _tag_codes(procedure, 'PRO')
test = _tag_codes(test, 'TES')
bmi = _tag_codes(bmi, 'BMI')
bp_low = _tag_codes(bp_low, 'BPL')
bp_high = _tag_codes(bp_high, 'BPH')
smoke = _tag_codes(smoke, 'SMO')
alcohol = _tag_codes(alcohol, 'ALC')

def format_age(path_to_demo, data):
    """Attach an 'age' column computed from each patient's year of birth.

    Args:
        path_to_demo: path of the CPRD Patient text extract.
        data: event DataFrame with at least patid, eventdate and code.

    Returns:
        DataFrame with columns patid, eventdate, code, age. The age comes from
        ``EHR.cal_age(..., year=True)``. Events whose patient is missing from the demographic
        table are kept (left join); presumably their age is null — TODO confirm in EHR.cal_age.
    """
    patients = Patient(read_txt(spark.sc, spark.sqlContext, path=path_to_demo))
    demographic = patients.accept_flag() \
        .yob_calibration() \
        .cvt_crd2date() \
        .cvt_tod2date() \
        .cvt_deathdate2date() \
        .get_pracid() \
        .drop('accept') \
        .select(['patid', 'yob'])

    with_yob = data.join(demographic, data.patid == demographic.patid, 'left').drop(demographic.patid)
    return EHR(with_yob).cal_age('eventdate', 'yob', year=True, name='age') \
        .select(['patid', 'eventdate', 'code', 'age'])

# every modality takes year of birth from the same CPRD Patient extract
patient_path = '/home/workspace/datasets/cprd/cuts/02_cprd2015/1_RawData/Patient'
diag = format_age(patient_path, diag)
med = format_age(patient_path, med)
procedure = format_age(patient_path, procedure)
test = format_age(patient_path, test)
bmi = format_age(patient_path, bmi)
bp_low = format_age(patient_path, bp_low)
bp_high = format_age(patient_path, bp_high)
smoke = format_age(patient_path, smoke)
alcohol = format_age(patient_path, alcohol)

# stack all modalities; format_age has already aligned every frame to (patid, eventdate, code, age)
records = (diag.union(med)
               .union(procedure)
               .union(test)
               .union(bmi)
               .union(bp_low)
               .union(bp_high)
               .union(smoke)
               .union(alcohol))

diag: StructType(List(StructField(patid,StringType,true),StructField(eventdate,DateType,true),StructField(code,StringType,true)))
med: StructType(List(StructField(patid,StringType,true),StructField(eventdate,DateType,true),StructField(code,StringType,true)))
procedure: StructType(List(StructField(patid,StringType,true),StructField(code,StringType,true),StructField(eventdate,DateType,true)))
test: StructType(List(StructField(patid,StringType,true),StructField(eventdate,DateType,true),StructField(code,StringType,true)))
bmi: StructType(List(StructField(patid,StringType,true),StructField(eventdate,DateType,true),StructField(code,StringType,true)))
bp_low: StructType(List(StructField(patid,StringType,true),StructField(eventdate,DateType,true),StructField(code,StringType,true)))
bp_high: StructType(List(StructField(patid,StringType,true),StructField(eventdate,DateType,true),StructField(code,StringType,true)))
smoke: StructType(List(StructField(patid,StringType,true),StructField(eventdate,Da

In [25]:
# merged HF patient summary written by the merge cell (columns patid, firstDate, lastDate, label, HFDate)
hf_ref_df = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/HF/1985_2005/test/HF_1985_2015.parquet')

In [26]:
# Follow-up length in whole days. datediff counts calendar days exactly. A difference of
# unix_timestamp seconds (session time zone) divided by 86400 and truncated can come out one day
# short when the interval crosses a daylight-saving change.
timeDiff = F.datediff('lastDate', 'firstDate')
hf_ref = hf_ref_df.withColumn('duration', timeDiff.cast('integer'))

# Keep each patient's events inside their [firstDate, lastDate] window. Only the columns used
# below must be non-null; a blanket dropna() would also drop rows whose (unused) HFDate is null.
# NOTE(review): if HFDate is never null this matches the previous dropna() exactly.
records = records.join(hf_ref, records.patid==hf_ref.patid, 'inner').drop(hf_ref.patid)\
            .dropna(subset=['patid', 'eventdate', 'code', 'age', 'label', 'firstDate', 'lastDate', 'duration'])\
            .where((F.col('eventdate') >= F.col('firstDate')) & (F.col('eventdate') <= F.col('lastDate')))\
            .select(['patid', 'eventdate', 'code', 'age', 'label', 'duration'])

# require more than three years (3 * 365 days) of follow-up
records = records.filter(F.col('duration') > 3*365).select(['patid', 'eventdate', 'code', 'age', 'label']).dropDuplicates()

In [27]:
def format_sequence(data):
    """Collapse event rows into one date-ordered code/age sequence per patient, keeping the label.

    Args:
        data: DataFrame with columns patid, eventdate, code, age, label.

    Returns:
        DataFrame (patid, code, age, label). ``code`` is the patient's flattened token sequence,
        where each visit (one eventdate) gets a 'SEP' token added by EHR.array_add_element
        (presumably appended at its end). ``age`` is kept the same length as ``code``. ``label``
        is the first label seen, presumably constant per patient — TODO confirm.
    """
    # one row per (patient, date): all codes recorded that day and their ages
    data = data.groupby(['patid', 'eventdate']).agg(F.collect_list('code').alias('code'), F.collect_list('age').alias('age'), F.first('label').alias('label'))

    data = EHR(data).array_add_element('code', 'SEP')
    # add extra age to fill the gap of sep (the visit's first age)
    extract_age = F.udf(lambda x: x[0])
    data = data.withColumn('age_temp', extract_age('age')).withColumn('age', F.concat(F.col('age'),F.array(F.col('age_temp')))).drop('age_temp')

    # Gather each patient's visits in date order. sort_array compares the structs field by field, so
    # eventdate (unique per patient after the groupby above) decides the order. A running
    # collect_list window followed by max() gives the same result but materialises every prefix of
    # the visit list, which is quadratic in the number of visits.
    visits = F.sort_array(F.collect_list(F.struct('eventdate', 'code', 'age')))
    data = data.groupBy('patid').agg(visits.alias('visits'), F.first('label').alias('label')) \
               .select('patid', F.col('visits.code').alias('code'), F.col('visits.age').alias('age'), 'label')
    data = EHR(data).array_flatten('code').array_flatten('age') # patid, code, age, label
    return data

records = format_sequence(records)
 
def format_seg_and_position_code(data):
    """Add per-token 'seg' and 'position' arrays derived from the 'SEP' tokens in ``code``.

    Each token, including the SEP that closes a visit, gets the index of its visit. The index
    starts at 0 and advances after every SEP. ``position`` is that index and ``seg`` is the index
    modulo 2 (alternating 0/1). Both columns are declared as arrays of strings.
    """
    def seg_records(tokens):
        segments = []
        current = 0
        for token in tokens:
            segments.append(current)
            if token == 'SEP':
                current = 1 - current
        return segments

    def posi_records(tokens):
        positions = []
        visit = 0
        for token in tokens:
            positions.append(visit)
            if token == 'SEP':
                visit += 1
        return positions

    seg = F.udf(seg_records, ArrayType(StringType(), True))
    posi = F.udf(posi_records, ArrayType(StringType(), True))
    return data.withColumn('seg', seg('code')).withColumn('position', posi('code'))

records = format_seg_and_position_code(records)

from pyspark.sql.types import *

# return schema of remove_sep: patid, four parallel string arrays, and the integer label
schema = StructType(
    [StructField('patid', StringType(), True)]
    + [StructField(field, ArrayType(StringType(), True), True) for field in ('code', 'age', 'seg', 'position')]
    + [StructField('label', IntegerType(), True)]
)

def remove_sep(patid, code, age, seg, position, label):
    """Drop 'SEP' tokens while keeping the parallel age/seg/position lists aligned.

    Args:
        patid: patient id, passed through unchanged.
        code, age, seg, position: parallel per-token lists (same length as ``code``).
        label: patient label, passed through unchanged.

    Returns:
        (patid, code, age, seg, position, label) with every index where ``code`` is 'SEP' removed
        from all four lists. Visit boundaries remain encoded in ``seg`` / ``position``.
    """
    # Compute the surviving indices once and filter every list with them. Previously the filtered
    # lists were built but the unfiltered inputs were returned.
    keep = [i for i, token in enumerate(code) if token != 'SEP']
    return (patid,
            [code[i] for i in keep],
            [age[i] for i in keep],
            [seg[i] for i in keep],
            [position[i] for i in keep],
            label)
        

# run remove_sep on every row, then unpack the returned struct into top-level columns
test_udf = F.udf(remove_sep, schema)
records = records.select(test_udf('patid', 'code', 'age', 'seg', 'position', 'label').alias("test")) \
                 .select("test.*")

# 1985-2005: split records into train / tune / valid sets

In [21]:
# NOTE(review): this cell ran as In [21], before the cells that now precede it (In [24]-[27]).
# On a fresh Restart & Run All, `records` here is the 1985-2015 window built above; confirm
# that is the intended input for the 1985_2005 split.
# A fixed seed makes the split reproducible; without one the split differs on every run.
split_seed = 42
final = records.randomSplit([0.5, 0.5], seed=split_seed)
train = final[0]
validation = final[1]

# split the held-out half again: 40% tune / 60% valid (i.e. 20% / 30% of all records)
validation = validation.randomSplit([0.4, 0.6], seed=split_seed)
tune = validation[0]
valid = validation[1]

train.write.parquet('/home/shared/yikuan/HiBEHRT/data/HF/1985_2005/test/train.parquet')
tune.write.parquet('/home/shared/yikuan/HiBEHRT/data/HF/1985_2005/test/tune.parquet')
valid.write.parquet('/home/shared/yikuan/HiBEHRT/data/HF/1985_2005/test/valid.parquet')

# 1985-2015: validation set excluding the patients used for 1985-2005 training and tuning

In [28]:
# patients already used for training / tuning in the 1985-2005 split
pat_list_train = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/HF/1985_2005/test/train.parquet').select(['patid'])
pat_list_tune = read_parquet(spark.sqlContext, '/home/shared/yikuan/HiBEHRT/data/HF/1985_2005/test/tune.parquet').select(['patid'])

pat_list = pat_list_train.union(pat_list_tune)
pat_list = rename_col(pat_list, 'patid', 'patid_temp')

# The anti-join keeps only patients in neither train nor tune. It is equivalent to the left join
# plus null filter, and returns only the columns of `records`.
records = records.join(pat_list, records.patid==pat_list.patid_temp, 'left_anti')
# Write the filtered records, not `valid`: that is the 1985-2005 split and would ignore this filter.
records.write.parquet('/home/shared/yikuan/HiBEHRT/data/HF/1985_2015/test/valid.parquet')