# Create Train and Test Cohorts

## Load libraries and enable Pyspark and PyArrow to allow for distributed processing

In [1]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"

In [2]:
import numpy as np
import pandas as pd
import csv
import pickle
import time
import math
import collections
import os
from tqdm import tqdm
import pyspark
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf
import pyspark.sql.functions as f

In [3]:
# create the session
conf = SparkConf().set("spark.ui.port", "4050")
conf = (conf.setMaster('local[*]')
        .set('spark.executor.memory', '35G')
        .set('spark.driver.memory', '35G')
        .set('spark.driver.maxResultSize', '35G'))
# create the context
sc = pyspark.SparkContext(conf=conf)
spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# arrow enabling is what makes the conversion from pandas to spark dataframe really fast
sc._conf.get('spark.driver.memory')

'35G'

## Read in raw data and select for diagnosis code columns (16-41) and demographic information (0-8)

In [28]:
dirPath = '/home/ubuntu/BioMedProject/Data/'
patientDataPath = dirPath + 'patientData.csv' # file was renamed from 'B220_SAA_v1.csv'

begin = time.time()
df = spark.read.load(patientDataPath, format="csv", sep=",", inferSchema="true", header="true")
df = df.select(df.columns[0:8] + df.columns[16:41])
df = df.drop('Type')
end = time.time()
print("Reading csv took: {} seconds".format(end - begin))

Reading csv took: 5.038235664367676 seconds


## Create 30-day readmission labels based on immediate future visit for each patient

In [29]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lit

begin = time.time()
my_window = Window.partitionBy().orderBy("ID")

df = df.withColumn("next_Date", lead(df.Date).over(my_window))
df = df.withColumn("next_Id", lead(df.ID).over(my_window))
df = df.withColumn("decFirst", lit("2018-12-01"))
df = df.withColumn("daysUntilDecFirst", datediff(col("decFirst"), col("Date")).alias("finalVisitDiff"))
df = df.withColumn("datediff", datediff(col("next_Date"), col("Date")).alias("datediff"))
end = time.time()
print("Cell took {}".format(end - begin))

Cell took 0.05290389060974121


In [32]:
begin = time.time()
df = df.withColumn("label", when((f.col("ID") != f.col("next_Id")) & (f.col("daysUntilDecFirst") >= 0), 0)\
                   .when((f.col("ID") != f.col("next_Id")) & (f.col("daysUntilDecFirst") < 0), 2)\
                   .when(col("datediff") > 30, 0)\
                   .when(col("datediff") <= 30, 1)\
                   .otherwise(2))
end = time.time()
print("Cell took {}".format(end - begin))

Cell took 0.02687835693359375


In [33]:
df = df.filter(f.col("label") != 2)

In [34]:
total_count = df.count()
print("Total number of visits: {}".format(total_count))

Total number of visits: 27502638


In [35]:
df = df.drop("next_Date", "next_Id", "decFirst", "daysUntilDecFirst", "datediff")
df.printSchema()

root
 |-- ID: integer (nullable = true)
 |-- Visit: integer (nullable = true)
 |-- Visits: integer (nullable = true)
 |-- Date: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Race: string (nullable = true)
 |-- Dx10_prin: string (nullable = true)
 |-- Dx10_1: string (nullable = true)
 |-- Dx10_2: string (nullable = true)
 |-- Dx10_3: string (nullable = true)
 |-- Dx10_4: string (nullable = true)
 |-- Dx10_5: string (nullable = true)
 |-- Dx10_6: string (nullable = true)
 |-- Dx10_7: string (nullable = true)
 |-- Dx10_8: string (nullable = true)
 |-- Dx10_9: string (nullable = true)
 |-- Dx10_10: string (nullable = true)
 |-- Dx10_11: string (nullable = true)
 |-- Dx10_12: string (nullable = true)
 |-- Dx10_13: string (nullable = true)
 |-- Dx10_14: string (nullable = true)
 |-- Dx10_15: string (nullable = true)
 |-- Dx10_16: string (nullable = true)
 |-- Dx10_17: string (nullable = true)
 |-- Dx10_18: string (nullable = true)
 |-- D

In [10]:
# COUNTS NUMBER OF VISITS FILTERED OUT:
# df.filter(f.col("daysUntilDecFirst") < 0).select(countDistinct("ID")).show()

Visits to keep: 27502639

Visits filtered out: 475293

Total visits: 27977932

## Shuffle the data at random

In [36]:
from pyspark.sql.functions import monotonically_increasing_id 

df = df.orderBy(rand(seed=1234))
df = df.withColumn("index", monotonically_increasing_id())

assert df.count()== total_count

## Create an 80-20 train test split

In [37]:
split = int(total_count * 0.8)
print("Total count: {}".format(total_count))
print("Splitting at: {}".format(split))

Total count: 27502638
Splitting at: 22002110


In [51]:
train = df.filter(col("index").between(0, split)).drop("index") # repartitin or coalesce after filtering -- 1GB per partition rule
test = df.filter(col("index").between(split + 1, total_count)).drop("index")

In [52]:
print("Train size: {}".format(train.count()))
print("Test size: {}".format(test.count()))

Train size: 22002111
Test size: 5500527


## Save train and test datasets

In [53]:
begin = time.time()
train.write.format('csv').option('header', True).option('sep',',').mode('overwrite').save(dirPath + 'train')
end = time.time()
print("Saving train file took {} seconds".format(end - begin))

Saving train file took 165.25227093696594 seconds


In [54]:
begin = time.time()
test.write.format('csv').option('header', True).option('sep',',').mode('overwrite').save(dirPath + 'test')
end = time.time()
print("Saving test file took {} seconds".format(end - begin))

Saving test file took 117.19075202941895 seconds


## Load from saved files

In [55]:
begin = time.time()
train_loaded = spark.read.format("csv").load(dirPath + "train",
                     sep=",", inferSchema="true", header="true")
end = time.time()
print("Cell took {} seconds".format(end - begin))

Cell took 2.823918581008911 seconds


In [56]:
begin = time.time()
test_loaded = spark.read.load(dirPath + "test",
                     format="csv", sep=",", inferSchema="true", header="true")
end = time.time()
print("Cell took {} seconds".format(end - begin))

Cell took 0.7571585178375244 seconds


## Assert no errors during saving

In [57]:
assert train.count() == train_loaded.count()
assert test.count() == test_loaded.count()