This notebook demonstrates all the feature engineering and model building code to train a dynamic Bayes net on EMR data. The sample data is generated by Synthea, using the `virtual_generalist_ckd` module. Please note that this process does not work on Synthea data generated without this module because it does not include patients with advanced stages of the disease; those stage transitions can only be learned from data generated when the `virtual_generalist_ckd` module is included.

In [0]:
%sql

-- use missouri_test20y
-- use missouri_test_big
use missouri_test_ckd2mo

In [0]:
%python

import pandas as pd
import numpy as np
import re
from io import StringIO

concept_table_text = """
attribute                | test       | attribute_type | icd_pattern | icd_name                              | snomed_concept_name
ckd_1                    | ==1        | integer        | N18.1       | Chronic_kidney_disease_stage_1        | Chronic kidney disease stage 1 (disorder)
ckd_2                    | ==2        | integer        | N18.2       | Chronic_kidney_disease_stage_2        | Chronic kidney disease stage 2 (disorder)
ckd_3                    | ==3        | integer        | N18.3       | Chronic_kidney_disease_stage_3        | Chronic kidney disease stage 3 (disorder)
ckd_4                    | ==4        | integer        | N18.4       | Chronic_kidney_disease_stage_4        | Chronic kidney disease stage 4 (disorder)
ckd_5                    | ==5        | integer        | N18.[56]    | Chronic_kidney_disease_stage_5        | End stage renal disease (disorder)
smoker                   | is true    | boolean        | F17         | Nicotine_dependence                   | Smokes tobacco daily
diabetes                 | is true    | boolean        | E11         | Type_2_diabetes_mellitus              | Diabetes
coronary_heart_disease   | is true    | boolean        | I25         | Chronic_ischemic_heart_disease        | Coronary Heart Disease
copd_variant             | is not nil | ConditionOnset | J44         | Chronic_obstructive_pulmonary_disease | Chronic obstructive bronchitis (disorder)
"""

# ICD10 code 'N18.6' is really 'End stage chronic kidney disease', but we lump it in with CKD stage 5 because there is no ckd==6 in Synthea.

concept_table = pd.read_csv(StringIO(re.sub(' +\| +', ',', concept_table_text)))

spark.createDataFrame(concept_table).write.mode('overwrite').saveAsTable('concept_table')

# Library functions

In [0]:
%python 

import numpy as np
import pandas as pd
import pyspark.sql.functions as fn
from pyspark.sql.types import *
from pyspark.sql import Window
from pyspark.sql.dataframe import DataFrame

if getattr(DataFrame, "transform", None) is None:
  DataFrame.transform = lambda self,f: f(self)  # 'monkey patching'

  
def timestamp_to_timeslice(sdf, timestamp_col, timeslice_id_col, time_unit='hour'):
  """ Convert timestamp column to integer ID representing number of time units since Unix epoch.
  Args:
    sdf (Spark dataframe): input dataframe.
    timestamp_col (str): name of input timestamp column; this will be replaced by a timestamp_id.
    timeslice_id_col (str): name of the column to be generated
    time_unit (str): period of time defined by a fixed number of seconds 
        ('day', 'hour', etc. as defined in `seconds_per` dict in this function)
  
  Returns:
    Spark dataframe with timestamp column replaced by timeslice ID.
    
    To do: support months as time_unit:
    ( year(encounter_date) - 1970 ) * 12 + month(encounter_date) epoch_month
  """
  # time_udf = fn.udf(lambda seconds: int(seconds/seconds_per[time_unit]), IntegerType())
  
  if time_unit == 'month':
    return sdf.withColumn(timeslice_id_col, fn.expr(f'( year({timestamp_col}) - 1970 ) * 12 + month({timestamp_col})'))
  
  def convert_seconds(num_seconds, time_unit):
    seconds_per = {'week': 60*60*24*7, 'day': 60*60*24, 'hour':60*60, 'minute':60, 'second':1}
    result = None
    try:
      result = int(num_seconds/seconds_per[time_unit])
    except:
      pass
    
    return result
  
  time_udf = fn.udf(lambda sec: convert_seconds(sec, time_unit), IntegerType())
  
  return sdf\
    .withColumn('posix_timestamp', fn.unix_timestamp(fn.col(timestamp_col)))\
    .withColumn(timeslice_id_col, time_udf(fn.col('posix_timestamp')))\
    .drop('posix_timestamp', timestamp_col)


def expand_rows(sdf, from_col, to_col, sequence_col_name, *id_cols):
  """ Expand a range of integers into a set of rows representing all values in the sequence.
  
  Args:
  
    sdf (spark dataframe): input dataframe
    from_col (str): name of column specifying beginning integer value of sequence.
    to_col (str): name of column specifying ending integer value of sequence
    sequence_col_name (str): name of new column to be generated with sequence values
    id_col_names (array of str): names of id columns
  
  Returns: 
  
    spark dataframe with columns specified in `id_col_names` and `sequence_col_name`, with one row per sequence element.
  
  Example::
  
    range_sdf = spark.createDataFrame(data=[('Ali',3,7), ('Bay',5,10), ('Cal',1,3)], schema = ['name','from','to'])
    expand_rows(range_sdf, 'from', 'to', 'sequence_id', 'name').show()
    
  """
  
  arrayify_udf = fn.udf(lambda s1, s2: [i for i in range(s1, s2+1)] if s1 is not None and s2 is not None else [], ArrayType(IntegerType()))
  
  id_range_df = sdf.select([ *id_cols, arrayify_udf(fn.col(from_col), fn.col(to_col)).alias('int_range')])

  return id_range_df.select([ *id_cols, fn.explode(fn.col('int_range')).alias(sequence_col_name)])


def fill_missing_values_forward(sdf, ordering_col, cols_to_fill, *id_cols):
  """ Fill missing values by carrying previous values forward.
  
  Args:
    sdf: a Spark DataFrame
    ordering_col: column by which rows should be sorted
    cols_to_fill: list of columns where missing values will be filled
    id_cols: list of columns that collectively form a unique identifier that can be used to partition cases.
  
  """
  lookback_window = Window.partitionBy(*id_cols)\
                 .orderBy(ordering_col)\
                 .rowsBetween( Window.unboundedPreceding, 0)

  for ctf in cols_to_fill:
    filled_col = ctf + "_filled"
    sdf = sdf.withColumn(filled_col, fn.last(sdf[ctf], ignorenulls=True).over(lookback_window))

  return sdf

# Collect encounter-month data

In [0]:
%sql

select * from conditions where description rlike('Chronic kidney disease') order by patient, start -- limit 10

START,STOP,PATIENT,ENCOUNTER,CODE,DESCRIPTION
1984-10-05,,00000d0b-cc2b-3dd1-2fd8-50db08b48cb7,b9953efc-5dd5-1da6-d01f-14753165c9d0,431855005,Chronic kidney disease stage 1 (disorder)
1993-05-17,,00014fd4-78f9-07a0-cd57-e8838c90b32c,7ab2fdf4-0a83-40ca-f127-65b6f6a6a3ac,431855005,Chronic kidney disease stage 1 (disorder)
2011-02-23,,0001ccca-76ae-1db4-1614-0b57e6b08f1e,02c831ac-0844-055d-e7b5-4890d9233285,433144002,Chronic kidney disease stage 3 (disorder)
2018-05-23,,0001ccca-76ae-1db4-1614-0b57e6b08f1e,1c7d87a3-9541-386f-3568-37dd51c9d116,431856006,Chronic kidney disease stage 2 (disorder)
1988-08-16,,0001d77e-e70d-3b53-d91e-bfdce2693a39,3dc5d310-c9a3-7a2d-1644-957ca82d25ce,433144002,Chronic kidney disease stage 3 (disorder)
1998-06-16,,0001d77e-e70d-3b53-d91e-bfdce2693a39,0d921a04-acb8-93f1-b782-56deba208df6,431857002,Chronic kidney disease stage 4 (disorder)
1985-07-03,,0001fdcf-7bc9-2afa-6bee-93dd06e4cf4f,77261b4f-2ee0-14a7-ab1a-46d8601681d4,431855005,Chronic kidney disease stage 1 (disorder)
1976-09-19,,000255b9-eb72-fded-6772-b1ceb372b31d,2e1d3d03-a2cd-b2de-8515-d955c397ee00,433144002,Chronic kidney disease stage 3 (disorder)
1991-06-23,,000255b9-eb72-fded-6772-b1ceb372b31d,484226dd-5ebf-f4a0-254e-196016cc699d,431857002,Chronic kidney disease stage 4 (disorder)
2007-09-11,,00037be2-d64b-adb8-c7e7-90dcffa144bf,3bfbda5c-8bf7-20f3-055b-e6c75ab6cd3e,433144002,Chronic kidney disease stage 3 (disorder)


### Joining conditions to encounters

Note that the conditions table includes an encounter id, but this is just for the encounter where the diagnosis was made. we want to include all diagnoses that are still in effect at the time of the encounter, so we join based on the start and stop dates of the condition.

In [0]:
%sql

drop table if exists encounter_attribute;

create table encounter_attribute as
with 
conditions_attribute as (
  select c.*, ct.attribute 
    from conditions c 
    join concept_table ct on c.description = ct.snomed_concept_name
)
select dxa.patient
      , int((year(enc.start) - 1970) * 12 + month(enc.start)) month_number
      , dxa.encounter
      , dxa.code
      , dxa.description
      , dxa.attribute
  from conditions_attribute dxa
  join encounters enc on dxa.encounter = enc.id
    and dxa.start <= enc.start and ((dxa.stop >= enc.stop) or dxa.stop is null)
;

select * from encounter_attribute order by patient, month_number;

patient,month_number,encounter,code,description,attribute
00000d0b-cc2b-3dd1-2fd8-50db08b48cb7,62,68a43393-f4a1-5480-f1b3-7d45c8e1f09f,44054006,Diabetes,diabetes
00000d0b-cc2b-3dd1-2fd8-50db08b48cb7,178,b9953efc-5dd5-1da6-d01f-14753165c9d0,431855005,Chronic kidney disease stage 1 (disorder),ckd_1
00000d0b-cc2b-3dd1-2fd8-50db08b48cb7,220,071b4e6d-d3e7-91f2-5d73-de4e4ba21145,53741008,Coronary Heart Disease,coronary_heart_disease
00001638-85b1-4f99-0260-125deeed577d,452,5650721c-4a81-7468-935c-6ca1140e4328,53741008,Coronary Heart Disease,coronary_heart_disease
00001e47-7c13-be69-cd2f-1d5513edd4e9,-38,a044e84c-5146-d974-c497-108ca31e0f8b,449868002,Smokes tobacco daily,smoker
00002b08-788e-7184-50e4-fab6be044a6d,359,b4ebec76-9c0f-77d6-43d6-e96a325ac355,53741008,Coronary Heart Disease,coronary_heart_disease
00014fd4-78f9-07a0-cd57-e8838c90b32c,281,7ab2fdf4-0a83-40ca-f127-65b6f6a6a3ac,44054006,Diabetes,diabetes
00014fd4-78f9-07a0-cd57-e8838c90b32c,281,7ab2fdf4-0a83-40ca-f127-65b6f6a6a3ac,431855005,Chronic kidney disease stage 1 (disorder),ckd_1
0001ccca-76ae-1db4-1614-0b57e6b08f1e,494,02c831ac-0844-055d-e7b5-4890d9233285,433144002,Chronic kidney disease stage 3 (disorder),ckd_3
0001d77e-e70d-3b53-d91e-bfdce2693a39,224,3dc5d310-c9a3-7a2d-1644-957ca82d25ce,433144002,Chronic kidney disease stage 3 (disorder),ckd_3


In [0]:
%sql

create or replace temporary view patient_info as
  select id patient
      , (year(birthdate) - 1970) * 12 + month(birthdate) birth_month 
      , gender
      , race
      , ethnicity
  from patients;


-- all the ckd stages 1 or higher
create or replace temporary view patient_month_ckd_stage as
with encounter_month_ckd_positive as (
  select encounter, patient, month_number, attribute ckd_stage 
    from encounter_attribute 
    where attribute rlike('ckd_')
)
select patient, month_number, max(ckd_stage) ckd_stage 
  from encounter_month_ckd_positive 
  group by patient, month_number
;


create or replace temporary view encounter_month_attribute as 
select encounter, patient, month_number, attribute
  from encounter_attribute 
  where attribute not rlike('ckd_')
;

In [0]:
%python

emal = spark.sql('select * from encounter_month_attribute')

pmaw = ( emal
         .withColumn('mark', fn.lit(1))
         .groupby('patient', 'month_number')
         .pivot('attribute')
         .agg(fn.sum('mark'))
       )

attribute_cols = ['copd_variant', 'coronary_heart_disease', 'diabetes', 'smoker']  # pmaw.columns[2:]

pmaw = fill_missing_values_forward(pmaw, 'month_number', attribute_cols, 'patient')

pmaw = pmaw.selectExpr(['patient', 'month_number', *[c + '_filled as ' + c for c in attribute_cols]])
    
for attribute_col in attribute_cols:
  pmaw = pmaw.withColumn(attribute_col, fn.when(fn.col(attribute_col).isNotNull(),'T').otherwise('F'))

pmaw.write.format("parquet").mode('overwrite').saveAsTable('patient_month_attribute_wide')

# pmaw.filter("patient == '0X10001BEC'").orderBy(['patient', 'month_number']).show(10)

In [0]:
%sql

select * from patient_month_attribute_wide limit 10

patient,month_number,copd_variant,coronary_heart_disease,diabetes,smoker
0003c81b-8c17-8ea0-3248-f90b7d043b2a,437,F,T,F,F
00046737-1075-9bdf-9df4-d54183ad2460,280,F,F,T,F
0004d31c-f584-9e74-f3b9-efbf6b71b811,219,T,F,F,F
00057907-23ce-6f30-2ae1-6b15fea01224,498,F,T,F,F
0005fa49-6568-8798-298b-034139002252,430,F,T,F,F
00080917-e2eb-ef97-8e57-1ff2a0cbab11,486,F,F,T,F
00090784-4fd0-3c3d-b5ee-79b3a6558118,149,F,F,T,F
000b8e34-33ee-cd05-c3ec-2495abcdc26e,21,T,F,F,F
0011fea8-e6ee-a055-caf8-837f29f81b7a,563,F,T,F,F
00137918-0c44-c237-8321-b2c68fe0f83a,380,F,T,F,F


In [0]:
%sql

select * from patient_month_ckd_stage order by patient, month_number

patient,month_number,ckd_stage
00000d0b-cc2b-3dd1-2fd8-50db08b48cb7,178,ckd_1
00014fd4-78f9-07a0-cd57-e8838c90b32c,281,ckd_1
0001ccca-76ae-1db4-1614-0b57e6b08f1e,494,ckd_3
0001d77e-e70d-3b53-d91e-bfdce2693a39,224,ckd_3
0001d77e-e70d-3b53-d91e-bfdce2693a39,342,ckd_4
0001d77e-e70d-3b53-d91e-bfdce2693a39,357,ckd_5
0001d77e-e70d-3b53-d91e-bfdce2693a39,369,ckd_5
0001d77e-e70d-3b53-d91e-bfdce2693a39,390,ckd_5
0001d77e-e70d-3b53-d91e-bfdce2693a39,403,ckd_5
0001d77e-e70d-3b53-d91e-bfdce2693a39,454,ckd_5


In [0]:
%sql

-- fill forward to cover missing values where there is a value to fill with; replace the rest with 'ckd_0'
create or replace temporary view patient_month_ckd_stage_all as
select pmaw.patient, pmaw.month_number, pmckd.ckd_stage
  from patient_month_attribute_wide pmaw
  left join patient_month_ckd_stage pmckd 
    on pmaw.patient = pmckd.patient 
    and pmaw.month_number = pmckd.month_number
  order by patient, month_number
;


In [0]:
%sql
-- TEMP
-- select * from patient_month_ckd_stage_all where ckd_stage is not null order by patient, month_number

-- select patient, count(*) tally from patient_month_ckd_stage_all group by patient order by  tally desc

-- dc860e66-5529-e6d6-874a-ca8a352ad053 4

select patient, count(*) tally from conditions where description rlike ('Chronic kidney disease stage ')
  group by patient order by tally desc

patient,tally
ef1cd24a-6aa3-ce1f-a1de-864e1d4170de,4
414c4dca-7086-f4b8-d35f-4cf45edfe03b,4
565f8883-40ad-9354-c4d0-5b67a93e6c47,4
7225a192-7cbc-f124-dbda-23d324e1eca9,4
b002f919-2ce6-4fa5-468f-946a7d7ffd46,4
5fce4f24-808e-7681-f23b-a46c2b139503,4
84a50a48-d778-96f1-3ffe-c8127215a6a9,4
9f2b979f-6681-0d0a-500b-8ab23aec26d6,4
27078de0-2f6b-9542-2080-2c74c2b9dd09,4
c17a528a-d152-dade-6b0e-c75850984923,4


In [0]:
%python

pmcsa = spark.sql("select * from patient_month_ckd_stage_all")
pmcsa2 = fill_missing_values_forward(pmcsa, 'month_number', ['ckd_stage'], 'patient')
pmcsa2a = pmcsa2.selectExpr(['patient', 'month_number', "ckd_stage_filled as ckd_stage"])
pmcsa2b = pmcsa2a.na.fill(value='ckd_0', subset=["ckd_stage"])

# This might leave us with cases where the first few values are set to 'ckd_0' just because ckd stage hasn't been mentioned yet.
# Drop the first few months' observations for each patient

window_spec = Window.partitionBy("patient")
pmcsa3 = pmcsa2b\
  .withColumn('row_num', fn.row_number().over(window_spec.orderBy("month_number")))\
  .withColumn('num_rows', fn.max('row_num').over(window_spec))\
  .filter('num_rows > 2')  #  and row_num > 1
                           
pmcsa3.show(100)

In [0]:
%python

sdf2 = pmcsa3.select(['patient', 'month_number', 'ckd_stage'])\
  .orderBy(['patient', 'month_number'])\
  .groupBy(['patient', 'month_number'])\
  .agg(
    fn.max('ckd_stage').alias('ckd_stage'),
  )

sdf3 = sdf2\
  .withColumn("next_month_number", fn.lead(fn.col("month_number")).over(window_spec.orderBy("month_number")))\
  .withColumn("end_month_number", fn.coalesce(fn.col("next_month_number") - 1, fn.col(("month_number"))))

# filtering rows where next_stage is null means anybody whose CKD doesn't show up at least twice will not be counted as having CKD
expand_rows(sdf3, 'month_number', 'end_month_number', 'month_number', 'patient', 'ckd_stage')\
  .orderBy(['patient', 'month_number'])\
  .withColumn('next_ckd_stage', fn.lead(fn.col("ckd_stage")).over(window_spec.orderBy("month_number")))\
  .filter(fn.col('next_ckd_stage').isNotNull())\
  .select('patient', 'month_number', 'ckd_stage', 'next_ckd_stage')\
  .write.format("parquet").mode('overwrite').saveAsTable('patient_monthly_ckd_stage_expanded')

In [0]:
%sql

-- TEMP

select * from patient_monthly_ckd_stage_expanded order by patient, month_number;

patient,month_number,ckd_stage,next_ckd_stage
028441dc-c57a-0dbd-90e0-546aedce462a,-244,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-243,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-242,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-241,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-240,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-239,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-238,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-237,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-236,ckd_0,ckd_0
028441dc-c57a-0dbd-90e0-546aedce462a,-235,ckd_0,ckd_0


In [0]:
%sql

drop table if exists bayes_snomed_table_ckd_stage;

create table bayes_snomed_table_ckd_stage as
with bayes_table1 as (
  select coalesce(pmcs.patient, emaw.patient) patient, coalesce(pmcs.month_number, emaw.month_number) month_number,
        pi.birth_month,
        pi.gender, pi.race, pi.ethnicity,
        pmcs.ckd_stage, pmcs.next_ckd_stage, 
        coalesce(emaw.copd_variant, 'F') copd_variant , 
        coalesce(emaw.coronary_heart_disease, 'F') coronary_heart_disease, 
        coalesce(emaw.diabetes, 'F') diabetes, 
        coalesce(emaw.smoker, 'F') smoker
    from patient_monthly_ckd_stage_expanded pmcs 
      left join patient_month_attribute_wide emaw 
        on pmcs.patient=emaw.patient 
        and pmcs.month_number=emaw.month_number
      left join patient_info pi
        on pmcs.patient = pi.patient
),
bayes_table2 as (
  select patient, month_number, 
      floor((month_number - birth_month)/12) age,
      gender, 
      race, ethnicity, 
      ckd_stage, next_ckd_stage, copd_variant, coronary_heart_disease, diabetes, smoker
    from bayes_table1
    order by patient, month_number
)
select patient, month_number, 
        case
          when age <= 1 then 'age_00_01'
          when age <= 8 then 'age_02_08'
          when age <=17 then 'age_09_17'
          when age <=24 then 'age_18_24'
          when age <=44 then 'age_25_44'
          when age <=64 then 'age_45_64'
          when age <=74 then 'age_65_74'
          when age >=75 then 'age_75_plus'
        end age_group,
        gender, 
        race, ethnicity, 
        ckd_stage, next_ckd_stage, copd_variant, coronary_heart_disease, diabetes, smoker
    from bayes_table2
;

num_affected_rows,num_inserted_rows


In [0]:
%sql

select * from bayes_snomed_table_ckd_stage where next_ckd_stage != ckd_stage order by patient, month_number

patient,month_number,age_group,gender,race,ethnicity,ckd_stage,next_ckd_stage,copd_variant,coronary_heart_disease,diabetes,smoker
0b1d9dc4-26c0-6bfa-a9c2-8564ee86c617,-130,age_25_44,M,white,nonhispanic,ckd_0,ckd_1,F,F,F,F
377058e4-ba81-c097-439b-aed380ba12a9,-85,age_25_44,M,white,nonhispanic,ckd_0,ckd_1,F,F,F,F
37c746c2-f4fc-e101-4b4e-7d5a9eb24195,46,age_45_64,M,white,nonhispanic,ckd_0,ckd_3,F,F,F,F
37ebe68a-2f45-4da4-2573-43a368b550ac,366,age_25_44,M,white,nonhispanic,ckd_0,ckd_1,F,F,F,F
3bbad0ab-18e2-b54f-f310-57637c12060f,111,age_25_44,M,black,hispanic,ckd_0,ckd_1,F,F,F,F
469351de-7ae7-1253-59b5-836e283525af,99,age_45_64,F,white,nonhispanic,ckd_0,ckd_2,F,F,F,F
4b123918-fea2-748f-9075-1ce77cbed357,-213,age_25_44,M,white,nonhispanic,ckd_0,ckd_1,F,F,F,F
63b7d5d0-6c87-5342-9eec-11e5d4704df0,35,age_45_64,M,white,nonhispanic,ckd_0,ckd_1,F,F,F,F
744c5e82-5156-ee1d-d88e-229f9bd05b8c,156,age_45_64,M,white,nonhispanic,ckd_0,ckd_1,F,F,F,F
7744b174-9198-df2b-f83e-0528c43126b7,-55,age_25_44,M,white,nonhispanic,ckd_0,ckd_1,F,F,F,F


In [0]:
%python
import pandas as pd
pd.set_option("display.max_rows", 20)
dynamic_bn_data = spark.sql("select * from bayes_snomed_table_ckd_stage order by patient, month_number")
dbnd_pdf = dynamic_bn_data.toPandas()

dbnd_pdf

Unnamed: 0,patient,month_number,age_group,gender,race,ethnicity,ckd_stage,next_ckd_stage,copd_variant,coronary_heart_disease,diabetes,smoker
0,028441dc-c57a-0dbd-90e0-546aedce462a,-244,age_25_44,M,white,hispanic,ckd_0,ckd_0,F,F,T,F
1,028441dc-c57a-0dbd-90e0-546aedce462a,-243,age_25_44,M,white,hispanic,ckd_0,ckd_0,F,F,F,F
2,028441dc-c57a-0dbd-90e0-546aedce462a,-242,age_25_44,M,white,hispanic,ckd_0,ckd_0,F,F,F,F
3,028441dc-c57a-0dbd-90e0-546aedce462a,-241,age_25_44,M,white,hispanic,ckd_0,ckd_0,F,F,F,F
4,028441dc-c57a-0dbd-90e0-546aedce462a,-240,age_25_44,M,white,hispanic,ckd_0,ckd_0,F,F,F,F
...,...,...,...,...,...,...,...,...,...,...,...,...
110182,fec86bb8-a771-66d4-4656-1747d1a5eea3,99,age_45_64,M,asian,nonhispanic,ckd_0,ckd_0,F,F,F,F
110183,fec86bb8-a771-66d4-4656-1747d1a5eea3,100,age_45_64,M,asian,nonhispanic,ckd_0,ckd_0,F,F,F,F
110184,fec86bb8-a771-66d4-4656-1747d1a5eea3,101,age_45_64,M,asian,nonhispanic,ckd_0,ckd_0,F,F,F,F
110185,fec86bb8-a771-66d4-4656-1747d1a5eea3,102,age_45_64,M,asian,nonhispanic,ckd_0,ckd_0,F,F,F,F


In [0]:
%r

library(bnlearn)
library(sparklyr)
library(DBI)
library(dplyr)

options(width=200, digits=2)

sc <- spark_connect(method='databricks')

Q <- "select * from bayes_snomed_table_ckd_stage"

stage_data <- dbGetQuery(sc, Q)

keepers = c('age_group', 'gender', 'race', 'ethnicity', 'copd_variant', 'coronary_heart_disease', 'diabetes', 'smoker', 'ckd_stage', 'next_ckd_stage')

for (col in keepers)
  stage_data[[col]] <- factor(stage_data[[col]])


indata <- as.data.frame(stage_data[keepers]) # must be dataframe, not tibble



In [0]:
%r
tiers_list <- list(root_nodes=c('age_group', 'gender', 'race', 'ethnicity'),
                   tier2_nodes=c('smoker', 'diabetes'),
                   middle_nodes=c('ckd_stage', 'copd_variant', 'coronary_heart_disease'),
                   outcomes=c('next_ckd_stage')
                  )

my_blacklist <- tiers2blacklist(tiers_list)
stage_dag <- hc(indata[unlist(tiers_list)], blacklist=my_blacklist)
modelstring(stage_dag)

In [0]:
%r

fit <- bn.fit(stage_dag, indata)

fit$next_ckd_stage$prob %>% as.matrix %>% t %>% format(digits=3, scientific=FALSE)

In [0]:
%r

fit$ckd_stage$prob %>% ftable %>% as.matrix %>% format(digits=3, scientific=FALSE)

In [0]:
%r
# patients with diabetes
# fit$next_ckd_stage$prob[,'T',] %>% as.matrix %>% t %>% format(digits=3, scientific=FALSE)