Stroke Prediction - Kaggle
https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import regexp_replace
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, Imputer, VectorIndexer, MinMaxScaler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors

In [2]:
spark = SparkSession.builder.appName('Stroke Prediction').getOrCreate()

In [3]:
stroke_ds = spark.read.csv('./dataset/healthcare-dataset-stroke-data.csv', header=True, inferSchema = True)

In [4]:
stroke_ds = stroke_ds.sampleBy("stroke", fractions={0:0.0535, 1:1}, seed=1234)

In [5]:
stroke_ds.show(1, vertical=True)

-RECORD 0----------------------------
 id                | 9046            
 gender            | Male            
 age               | 67.0            
 hypertension      | 0               
 heart_disease     | 1               
 ever_married      | Yes             
 work_type         | Private         
 Residence_type    | Urban           
 avg_glucose_level | 228.69          
 bmi               | 36.6            
 smoking_status    | formerly smoked 
 stroke            | 1               
only showing top 1 row



In [6]:
stroke_ds.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: string (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [7]:
stroke_ds.filter(stroke_ds.gender == 'Male').count()

217

In [8]:
stroke_ds.filter(stroke_ds.gender == 'Female').count()

282

In [9]:
stroke_ds.filter(stroke_ds.stroke == 1).count()

249

In [10]:
stroke_ds.filter(stroke_ds.stroke == 0).count()

250

In [11]:
# null empty check
from pyspark.sql.functions import isnan, when, count, col
stroke_ds.select([count(when(isnan(c), c)).alias(c) for c in stroke_ds.columns]).show(vertical=True)

-RECORD 0----------------
 id                | 0   
 gender            | 0   
 age               | 0   
 hypertension      | 0   
 heart_disease     | 0   
 ever_married      | 0   
 work_type         | 0   
 Residence_type    | 0   
 avg_glucose_level | 0   
 bmi               | 0   
 smoking_status    | 0   
 stroke            | 0   



In [12]:
stroke_ds.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in stroke_ds.columns]).show()

+---+------+---+------------+-------------+------------+---------+--------------+-----------------+---+--------------+------+
| id|gender|age|hypertension|heart_disease|ever_married|work_type|Residence_type|avg_glucose_level|bmi|smoking_status|stroke|
+---+------+---+------------+-------------+------------+---------+--------------+-----------------+---+--------------+------+
|  0|     0|  0|           0|            0|           0|        0|             0|                0|  0|             0|     0|
+---+------+---+------------+-------------+------------+---------+--------------+-----------------+---+--------------+------+



In [13]:
stroke_ds.filter(stroke_ds.bmi == 'N/A').count()

49

In [14]:
#Replace part of string with another string
stroke_ds = stroke_ds.withColumn('bmi', regexp_replace('bmi', 'N/A', ''))

In [15]:
stroke_ds = stroke_ds.withColumn("bmi", stroke_ds.bmi.cast("int"))

In [16]:
bmiImputer = Imputer(inputCols=["bmi"], outputCols=["{}_imputed".format(c) for c in ["bmi"]]).setStrategy('mean') 
stroke_ds_mod = bmiImputer.fit(stroke_ds).transform(stroke_ds)

In [17]:
stroke_ds_mod = stroke_ds_mod.drop('bmi', 'id')

In [18]:
stroke_ds_mod.show(3, vertical=True)

-RECORD 0----------------------------
 gender            | Male            
 age               | 67.0            
 hypertension      | 0               
 heart_disease     | 1               
 ever_married      | Yes             
 work_type         | Private         
 Residence_type    | Urban           
 avg_glucose_level | 228.69          
 smoking_status    | formerly smoked 
 stroke            | 1               
 bmi_imputed       | 36              
-RECORD 1----------------------------
 gender            | Female          
 age               | 61.0            
 hypertension      | 0               
 heart_disease     | 0               
 ever_married      | Yes             
 work_type         | Self-employed   
 Residence_type    | Rural           
 avg_glucose_level | 202.21          
 smoking_status    | never smoked    
 stroke            | 1               
 bmi_imputed       | 29              
-RECORD 2----------------------------
 gender            | Male            
 age        

In [19]:
vectorAssembler_age = VectorAssembler(inputCols= ['age'], outputCol='age_v') 
stroke_ds_mod = vectorAssembler_age.transform(stroke_ds_mod)
scale_age = MinMaxScaler(inputCol='age_v', outputCol='age_scaled')
stroke_ds_mod = scale_age.fit(stroke_ds_mod).transform(stroke_ds_mod)

In [20]:
vectorAssembler_bmi = VectorAssembler(inputCols= ['bmi_imputed'], outputCol='bmi_v') 
stroke_ds_mod = vectorAssembler_bmi.transform(stroke_ds_mod)
scale_bmi = MinMaxScaler(inputCol='bmi_v', outputCol='bmi_scaled')
stroke_ds_mod = scale_bmi.fit(stroke_ds_mod).transform(stroke_ds_mod)

In [21]:
vectorAssembler_glu = VectorAssembler(inputCols= ['avg_glucose_level'], outputCol='glu_v') 
stroke_ds_mod = vectorAssembler_glu.transform(stroke_ds_mod)
scale_glu = MinMaxScaler(inputCol='glu_v', outputCol='glu_scaled')
stroke_ds_mod = scale_glu.fit(stroke_ds_mod).transform(stroke_ds_mod)

In [22]:
stroke_ds_mod.show(1, vertical=True)

-RECORD 0---------------------------------
 gender            | Male                 
 age               | 67.0                 
 hypertension      | 0                    
 heart_disease     | 1                    
 ever_married      | Yes                  
 work_type         | Private              
 Residence_type    | Urban                
 avg_glucose_level | 228.69               
 smoking_status    | formerly smoked      
 stroke            | 1                    
 bmi_imputed       | 36                   
 age_v             | [67.0]               
 age_scaled        | [0.8161764705882353] 
 bmi_v             | [36.0]               
 bmi_scaled        | [0.43396226415094... 
 glu_v             | [228.69]             
 glu_scaled        | [0.8003524555952326] 
only showing top 1 row



In [23]:
strIndexer = StringIndexer(inputCols=['gender', 'smoking_status', 'Residence_type', 'work_type', 'ever_married'], 
outputCols=['gender_idx', 'smoking', 'residence', 'work', 'married']) 

In [24]:
stroke_ds_transform = strIndexer.fit(stroke_ds_mod).transform(stroke_ds_mod)

In [25]:
stroke_ds_transform.printSchema()

root
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)
 |-- bmi_imputed: integer (nullable = true)
 |-- age_v: vector (nullable = true)
 |-- age_scaled: vector (nullable = true)
 |-- bmi_v: vector (nullable = true)
 |-- bmi_scaled: vector (nullable = true)
 |-- glu_v: vector (nullable = true)
 |-- glu_scaled: vector (nullable = true)
 |-- gender_idx: double (nullable = false)
 |-- smoking: double (nullable = false)
 |-- residence: double (nullable = false)
 |-- work: double (nullable = false)
 |-- married: double (nullable = false)



In [26]:
stroke_ds_transform.show(1, vertical=True)

-RECORD 0---------------------------------
 gender            | Male                 
 age               | 67.0                 
 hypertension      | 0                    
 heart_disease     | 1                    
 ever_married      | Yes                  
 work_type         | Private              
 Residence_type    | Urban                
 avg_glucose_level | 228.69               
 smoking_status    | formerly smoked      
 stroke            | 1                    
 bmi_imputed       | 36                   
 age_v             | [67.0]               
 age_scaled        | [0.8161764705882353] 
 bmi_v             | [36.0]               
 bmi_scaled        | [0.43396226415094... 
 glu_v             | [228.69]             
 glu_scaled        | [0.8003524555952326] 
 gender_idx        | 1.0                  
 smoking           | 2.0                  
 residence         | 0.0                  
 work              | 0.0                  
 married           | 0.0                  
only showin

In [27]:
stroke_ds_transform = stroke_ds_transform.withColumnRenamed("stroke","label")

In [28]:
vectorAssembler = VectorAssembler(inputCols= ['hypertension', 'heart_disease', 'glu_scaled', 
                                             'smoking', 'bmi_scaled', 'age_scaled', 'gender_idx', 
                                             'work', 'residence', 'married'], 
                                  outputCol='features') 

In [29]:
vector_stroke_data = vectorAssembler.transform(stroke_ds_transform) 
stroke_data_final = vector_stroke_data.select(["features","label"]) 

In [30]:
stroke_data_final.show(5, vertical=True)

-RECORD 0------------------------
 features | [0.0,1.0,0.800352... 
 label    | 1                    
-RECORD 1------------------------
 features | (10,[2,4,5,7,8],[... 
 label    | 1                    
-RECORD 2------------------------
 features | [0.0,1.0,0.230997... 
 label    | 1                    
-RECORD 3------------------------
 features | (10,[2,3,4,5],[0.... 
 label    | 1                    
-RECORD 4------------------------
 features | [1.0,0.0,0.547280... 
 label    | 1                    
only showing top 5 rows



In [37]:
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(stroke_data_final) 
featureIndexer = VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(stroke_data_final) 
(trainingData, testData) = stroke_data_final.randomSplit([0.6, 0.4]) 
dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures",  maxDepth=5) 
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])  

In [38]:
model = pipeline.fit(trainingData) 

In [39]:
predictions = model.transform(testData)

In [40]:
predictions.select("prediction", "indexedLabel", "features").show(5)

+----------+------------+--------------------+
|prediction|indexedLabel|            features|
+----------+------------+--------------------+
|       0.0|         1.0|(10,[0,2,3,4,5],[...|
|       1.0|         0.0|(10,[0,2,3,4,5],[...|
|       1.0|         1.0|(10,[0,2,4,5],[1....|
|       1.0|         1.0|(10,[0,2,4,5,6],[...|
|       1.0|         1.0|(10,[0,2,4,5,8],[...|
+----------+------------+--------------------+
only showing top 5 rows



In [41]:
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g " % (1.0 - accuracy))

Test Error = 0.285714 


In [36]:
treeModel = model.stages[2]
# summary only
print(treeModel)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_2fed11a00afd, depth=6, numNodes=53, numClasses=2, numFeatures=10
