In [1]:
# findspark locates the Spark installation and puts its Python libraries on
# sys.path, so it has to run before the first pyspark import to have any
# effect on which pyspark gets loaded.
import findspark
findspark.init()

from json import loads
from time import sleep

from IPython.display import display, clear_output
from kafka import KafkaConsumer
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, IndexToString, VectorAssembler
from pyspark.sql import SparkSession
from pyspark.sql import functions as fn, Row
from pyspark.sql.functions import when, col, regexp_extract

In [2]:
# Versions interpolated into the Maven coordinate of the Spark-Kafka connector.
scala_version = '2.12'
spark_version = '3.0.1'

# Maven coordinates passed to Spark through `spark.jars.packages`.
packages = [
    f'org.apache.spark:spark-sql-kafka-0-10_{scala_version}:{spark_version}',
    'org.apache.kafka:kafka-clients:2.8.0',
]

# Local-mode session; reuses an existing session if one is already running.
spark = (
    SparkSession.builder
    .master("local")
    .appName("kafka-example")
    .config("spark.jars.packages", ",".join(packages))
    .getOrCreate()
)

In [3]:
# Load the training CSV; the first line is the header and column types are
# inferred from the data.
# NOTE(review): absolute, machine-specific path - consider a configurable data dir.
train = (
    spark.read.format("csv")
    .option("header", 'True')
    .option("inferSchema", 'True')
    .load("D:/Big data/train.csv")
)

In [4]:
# Show the first row to check the column names and the inferred value types.
train.head()

Row(encounter_id=2278392, patient_nbr=8222157, race='Caucasian', gender='Female', age='[0-10)', weight='?', admission_type_id=6, discharge_disposition_id=25, admission_source_id=1, time_in_hospital=1, payer_code='?', medical_specialty='Pediatrics-Endocrinology', num_lab_procedures=41, num_procedures=0, num_medications=1, number_outpatient=0, number_emergency=0, number_inpatient=0, diag_1='250.83', diag_2='?', diag_3='?', number_diagnoses=1, max_glu_serum='None', A1Cresult='None', metformin='No', repaglinide='No', nateglinide='No', chlorpropamide='No', glimepiride='No', acetohexamide='No', glipizide='No', glyburide='No', tolbutamide='No', pioglitazone='No', rosiglitazone='No', acarbose='No', miglitol='No', troglitazone='No', tolazamide='No', examide='No', citoglipton='No', insulin='No', glyburide-metformin='No', glipizide-metformin='No', glimepiride-pioglitazone='No', metformin-rosiglitazone='No', metformin-pioglitazone='No', change='No', diabetesMed='No', readmitted='NO')

# Data preprocessing

In [5]:
# Nulls in the two lab-result columns become the literal string "None";
# non-null values are kept as they are.
for lab_column in ("max_glu_serum", "A1Cresult"):
    train = train.withColumn(lab_column, fn.coalesce(col(lab_column), fn.lit("None")))

# Keep only rows whose gender is exactly 'Male' or 'Female'.
train = train.filter(col('gender').isin(['Male', 'Female']))

In [6]:
# Balance the three `readmitted` classes by oversampling (with replacement)
# each class towards the size of the largest one.

# One grouped aggregation yields every class size in a single Spark job
# instead of one full count per class.
class_counts = {
    row['readmitted']: row['count']
    for row in train.groupBy('readmitted').count().collect()
}
no_count = class_counts.get('NO', 0)
gt30_count = class_counts.get('>30', 0)
lt30_count = class_counts.get('<30', 0)
max_count = max(no_count, gt30_count, lt30_count)


def _extra_rows(readmitted_value, class_count):
    """Return the additional rows sampled for one class.

    The sampling fraction is the relative shortfall against the largest class
    (0.0 for the largest class itself). A class with no rows gets fraction 0.0
    rather than triggering a division by zero.
    """
    fraction = max_count / class_count - 1 if class_count else 0.0
    return train.filter(col('readmitted') == readmitted_value).sample(True, fraction, seed=42)


no_samples = _extra_rows('NO', no_count)
gt30_samples = _extra_rows('>30', gt30_count)
lt30_samples = _extra_rows('<30', lt30_count)

# `union` replaces the deprecated `unionAll`; both keep duplicate rows.
train_balanced = train.union(no_samples).union(gt30_samples).union(lt30_samples)


# Model Building

In [7]:
# Target: index the `readmitted` strings into a numeric `label` column.
label = StringIndexer(inputCol="readmitted", outputCol="label")

# Predictor columns. Every one of them - the numeric ones included - is passed
# through its own StringIndexer below.
string_columns = [
    'race', 'gender', 'age',
    'admission_type_id', 'discharge_disposition_id', 'admission_source_id',
    'time_in_hospital', 'num_lab_procedures', 'num_procedures',
    'num_medications', 'number_outpatient', 'number_emergency',
    'number_inpatient', 'diag_1', 'diag_2', 'diag_3', 'number_diagnoses',
    'max_glu_serum', 'A1Cresult', 'metformin', 'repaglinide', 'nateglinide',
    'chlorpropamide', 'glimepiride', 'acetohexamide', 'glipizide',
    'glyburide', 'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose',
    'miglitol', 'troglitazone', 'tolazamide', 'examide', 'citoglipton',
    'insulin', 'glyburide-metformin', 'glipizide-metformin',
    'glimepiride-pioglitazone', 'metformin-rosiglitazone',
    'metformin-pioglitazone', 'change', 'diabetesMed',
]

# Name of the indexed counterpart of each predictor; these feed the assembler.
assembler_inputs = [f"{name}_index" for name in string_columns]

# handleInvalid='keep' lets values unseen during fitting pass through at
# transform time instead of raising.
indexers = [
    StringIndexer(inputCol=name, outputCol=indexed_name, handleInvalid='keep')
    for name, indexed_name in zip(string_columns, assembler_inputs)
]
assembler = VectorAssembler(inputCols=assembler_inputs, outputCol="features")

# NOTE: despite the variable name, this is a single decision tree.
# NOTE(review): maxBins=750 is presumably sized for the highest-cardinality
# indexed column - confirm against the data.
rf = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="label",
    maxDepth=30,
    maxBins=750,
)

# Stage order: label indexer, feature indexers, assembler, classifier.
pipeline = Pipeline(stages=[label, *indexers, assembler, rf])
model = pipeline.fit(train_balanced)

# Method 1: KafkaConsumer

In [13]:
topic_name = 'Patient_data'
kafka_server = 'localhost:9092'


def _decode_json_message(raw_bytes):
    """Decode a UTF-8 message payload and parse it as JSON."""
    return loads(raw_bytes.decode('utf-8'))


# Each message value is handed to the loop below already parsed from JSON.
consumer = KafkaConsumer(
    topic_name,
    bootstrap_servers=kafka_server,
    value_deserializer=_decode_json_message,
)

In [16]:
# Decode predictions with the label order that the fitted label indexer (the
# first pipeline stage) actually learned. StringIndexer orders labels by
# frequency by default, and the model was fitted on oversampled data whose
# class counts are close to each other, so a hard-coded order is not reliable.
labelsArray = model.stages[0].labels
label_decoder = IndexToString(inputCol="prediction", outputCol="predicted_label", labels=labelsArray)


def _score_messages(messages):
    """Score every received message with the fitted pipeline.

    Returns a tuple of (input DataFrame, pipeline output DataFrame,
    DataFrame holding only `index` and the decoded `predicted_label`).
    """
    messages_df = spark.createDataFrame([Row(**x) for x in messages])
    scored_df = model.transform(messages_df)
    decoded_df = label_decoder.transform(scored_df).select("index", "predicted_label")
    return messages_df, scored_df, decoded_df


i = 0        # number of messages handled so far
test_1 = []  # every decoded message received during this run
try:
    for c in consumer:
        print("Showing live view refreshed every 5 seconds")
        print(f"Seconds passed: {i*5}")

        test_1.append(c.value)
        test_df_1, trans_data_1, result_1 = _score_messages(test_1)
        # Latest message only, then the predictions for everything so far.
        display(test_df_1.toPandas().tail(1))
        display(result_1.toPandas())

        i += 1
        clear_output(wait=True)

except KeyboardInterrupt:
    print("Showing live view refreshed every 5 seconds")
    print(f"Seconds passed: {i*5}")

    # Spark cannot infer a schema from an empty list, so only build the final
    # view when at least one message has arrived.
    if test_1:
        test_df_1, trans_data_1, result_1 = _score_messages(test_1)
        test_pd_1 = test_df_1.toPandas()
        display(test_pd_1)
        display(result_1.toPandas())
    else:
        print("No messages received")
    print("Break")

finally:
    print("Live view ended...")

Showing live view refreshed every 5 seconds
Seconds passed: 80


Unnamed: 0,index,encounter_id,patient_nbr,race,gender,age,weight,admission_type_id,discharge_disposition_id,admission_source_id,...,citoglipton,insulin,glyburide-metformin,glipizide-metformin,glimepiride-pioglitazone,metformin-rosiglitazone,metformin-pioglitazone,change,diabetesMed,readmitted
0,27,248916,115196778,Caucasian,Female,[50-60),?,1,1,1,...,No,Steady,No,No,No,No,No,No,Yes,>30
1,28,250872,41606064,Caucasian,Male,[20-30),?,2,1,2,...,No,Down,No,No,No,No,No,Ch,Yes,>30
2,12,40926,85504905,Caucasian,Female,[40-50),?,1,3,7,...,No,Down,No,No,No,No,No,Ch,Yes,<30
3,27,248916,115196778,Caucasian,Female,[50-60),?,1,1,1,...,No,Steady,No,No,No,No,No,No,Yes,>30
4,28,250872,41606064,Caucasian,Male,[20-30),?,2,1,2,...,No,Down,No,No,No,No,No,Ch,Yes,>30
5,32,260166,80845353,Caucasian,Female,[70-80),?,1,1,7,...,No,Steady,No,No,No,No,No,No,Yes,>30
6,33,293058,114715242,Caucasian,Male,[60-70),?,2,6,2,...,No,Steady,No,No,No,No,No,No,Yes,>30
7,35,325848,63023292,Caucasian,Female,[60-70),?,1,1,7,...,No,Down,No,No,No,No,No,Ch,Yes,>30
8,37,326028,112002975,Caucasian,Female,[60-70),?,1,1,7,...,No,Steady,No,No,No,No,No,Ch,Yes,>30
9,43,449142,66274866,Caucasian,Male,[50-60),?,1,1,7,...,No,Steady,No,No,No,No,No,Ch,Yes,>30


Unnamed: 0,index,predicted_label
0,27,>30
1,28,>30
2,12,>30
3,27,>30
4,28,>30
5,32,NO
6,33,>30
7,35,>30
8,37,>30
9,43,<30


Break
Live view ended...


In [17]:
# Compare predictions with the true `readmitted` value. The index -> label
# mapping is read from the fitted label indexer (first pipeline stage) rather
# than from a hard-coded list, so it always matches the trained model.
label_decoder = IndexToString(
    inputCol="prediction",
    outputCol="predicted_label",
    labels=model.stages[0].labels,
)
result_1 = label_decoder.transform(trans_data_1).select("index", "predicted_label", "readmitted")
result_1.toPandas()

Unnamed: 0,index,predicted_label,readmitted
0,27,>30,>30
1,28,>30,>30
2,12,>30,<30
3,27,>30,>30
4,28,>30,>30
5,32,NO,>30
6,33,>30,>30
7,35,>30,>30
8,37,>30,>30
9,43,<30,>30


# Method 2: Spark Structured Streaming

In [8]:
topic_name = 'Patient_data'
kafka_server = 'localhost:9092'

# `spark.read` (not `readStream`) defines a batch query over the topic,
# starting from the earliest offsets.
kafkaDf = (
    spark.read.format("kafka")
    .option("kafka.bootstrap.servers", kafka_server)
    .option("subscribe", topic_name)
    .option("startingOffsets", "earliest")
    .load()
)

In [9]:
# Fields pulled out of the JSON text in the Kafka `value` column, listed in
# output-column order.
_PAYLOAD_FIELDS = [
    'index', 'encounter_id', 'patient_nbr', 'race', 'gender', 'age', 'weight',
    'admission_type_id', 'discharge_disposition_id', 'admission_source_id',
    'time_in_hospital', 'payer_code', 'medical_specialty',
    'num_lab_procedures', 'num_procedures', 'num_medications',
    'number_outpatient', 'number_emergency', 'number_inpatient',
    'diag_1', 'diag_2', 'diag_3', 'number_diagnoses',
    'max_glu_serum', 'A1Cresult', 'metformin', 'repaglinide', 'nateglinide',
    'chlorpropamide', 'glimepiride', 'acetohexamide', 'glipizide',
    'glyburide', 'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose',
    'miglitol', 'troglitazone', 'tolazamide', 'examide', 'citoglipton',
    'insulin', 'glyburide-metformin', 'glipizide-metformin',
    'glimepiride-pioglitazone', 'metformin-rosiglitazone',
    'metformin-pioglitazone', 'change', 'diabetesMed', 'readmitted',
]

# Fields whose JSON value is an unquoted run of digits; they are cast to integer.
_INTEGER_FIELDS = frozenset({
    'index', 'encounter_id', 'patient_nbr',
    'admission_type_id', 'discharge_disposition_id', 'admission_source_id',
    'time_in_hospital', 'num_lab_procedures', 'num_procedures',
    'num_medications', 'number_outpatient', 'number_emergency',
    'number_inpatient', 'number_diagnoses',
})


def _payload_field(field_name):
    """Build the column expression that extracts one field from the payload.

    Integer fields match `"name": <digits>` and are cast to integer; every
    other field matches `"name": "<text>"` and stays a string.
    """
    payload = col("value").cast("string")
    if field_name in _INTEGER_FIELDS:
        digits = regexp_extract(payload, rf'"{field_name}":\s*([0-9]+)', 1)
        return digits.cast("integer").alias(field_name)
    return regexp_extract(payload, rf'"{field_name}":\s*"([^"]*)"', 1).alias(field_name)


test_df_2 = kafkaDf.select(*[_payload_field(name) for name in _PAYLOAD_FIELDS])

In [10]:
# Run the fitted pipeline on the parsed Kafka rows.
trans_data_2 = model.transform(test_df_2)

# Use the label order learned by the fitted label indexer (first pipeline
# stage) instead of a hard-coded list: StringIndexer orders labels by
# frequency by default, and the oversampled training classes are close in size.
labelsArray = model.stages[0].labels
result_2 = IndexToString(inputCol="prediction", outputCol="predicted_label", labels=labelsArray).transform(
    trans_data_2).select("index", "predicted_label")

In [11]:
# Re-evaluate the Kafka-backed DataFrames every 5 seconds, for at most 2000
# rounds or until the kernel is interrupted.
try:
    for x in range(2000):
        print("Showing live view refreshed every 5 seconds")
        print(f"Seconds passed: {x*5}")
        # Most recent row, then the predictions for all rows read so far.
        display(test_df_2.toPandas().tail(1))
        display(result_2.toPandas())
        sleep(5)
        clear_output(wait=True)
except KeyboardInterrupt:
    print("break")
print("Live view ended...")

Showing live view refreshed every 5 seconds
Seconds passed: 70


Unnamed: 0,index,encounter_id,patient_nbr,race,gender,age,weight,admission_type_id,discharge_disposition_id,admission_source_id,...,citoglipton,insulin,glyburide-metformin,glipizide-metformin,glimepiride-pioglitazone,metformin-rosiglitazone,metformin-pioglitazone,change,diabetesMed,readmitted
15,82,1079592,101707335,Caucasian,Female,[50-60),?,1,7,7,...,No,Up,No,No,No,No,No,Ch,Yes,>30


Unnamed: 0,index,predicted_label
0,12,>30
1,27,>30
2,28,>30
3,32,NO
4,33,>30
5,35,>30
6,37,>30
7,43,<30
8,46,NO
9,50,<30


break
Live view ended...


In [12]:
# Compare predictions with the true `readmitted` value. The index -> label
# mapping is read from the fitted label indexer (first pipeline stage) rather
# than from a hard-coded list, so it always matches the trained model.
label_decoder = IndexToString(
    inputCol="prediction",
    outputCol="predicted_label",
    labels=model.stages[0].labels,
)
result_2 = label_decoder.transform(trans_data_2).select("index", "predicted_label", "readmitted")
result_2.toPandas()

Unnamed: 0,index,predicted_label,readmitted
0,12,>30,<30
1,27,>30,>30
2,28,>30,>30
3,32,NO,>30
4,33,>30,>30
5,35,>30,>30
6,37,>30,>30
7,43,<30,>30
8,46,NO,<30
9,50,<30,<30
