In [1]:
from pyspark.sql import SparkSession
import warnings
warnings.filterwarnings('ignore')


spark = SparkSession.builder \
    .appName("stroke-prediction") \
    .config("spark.executor.memory", "2g") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/12/02 08:43:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
data_path = '/kaggle/input/stroke-prediction-dataset/healthcare-dataset-stroke-data.csv'
df = spark.read.option("header", True).option("inferSchema", True).csv(data_path)

In [3]:
df.show(5, truncate=False)

+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
|id   |gender|age |hypertension|heart_disease|ever_married|work_type    |Residence_type|avg_glucose_level|bmi |smoking_status |stroke|
+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
|9046 |Male  |67.0|0           |1            |Yes         |Private      |Urban         |228.69           |36.6|formerly smoked|1     |
|51676|Female|61.0|0           |0            |Yes         |Self-employed|Rural         |202.21           |N/A |never smoked   |1     |
|31112|Male  |80.0|0           |1            |Yes         |Private      |Rural         |105.92           |32.5|never smoked   |1     |
|60182|Female|49.0|0           |0            |Yes         |Private      |Urban         |171.23           |34.4|smokes         |1     |
|1665 |Female|79.0|1           |0            |Yes      

### EDA

In [4]:
df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: string (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [5]:
print(f"Total rows: {df.count()}")
print(f"Total columns: {len(df.columns)}")

Total rows: 5110
Total columns: 12


In [6]:
# Check class distribution
df.groupBy('stroke').count().show()

+------+-----+
|stroke|count|
+------+-----+
|     1|  249|
|     0| 4861|
+------+-----+



In [7]:
# Check for missing values
from pyspark.sql.functions import col, sum as spark_sum, isnan, when, count, mean

df.select([spark_sum(when(col(c).isNull() | isnan(c), 1).otherwise(0)).alias(c) 
           for c in df.columns]).show()

+---+------+---+------------+-------------+------------+---------+--------------+-----------------+---+--------------+------+
| id|gender|age|hypertension|heart_disease|ever_married|work_type|Residence_type|avg_glucose_level|bmi|smoking_status|stroke|
+---+------+---+------------+-------------+------------+---------+--------------+-----------------+---+--------------+------+
|  0|     0|  0|           0|            0|           0|        0|             0|                0|  0|             0|     0|
+---+------+---+------------+-------------+------------+---------+--------------+-----------------+---+--------------+------+



### Data Cleaning

In [8]:
df = df.drop('id')

### Feature Engineering & Preprocessing

In [9]:
from pyspark.sql.functions import col, mean, when

df = spark.read.option("header", True).option("inferSchema", True).csv(data_path)

# Drop id column
df = df.drop('id')

# Convert bmi to double and handle nulls
df = df.withColumn('bmi', col('bmi').cast('double'))

# Fill nulls with mean
mean_bmi = df.select(mean('bmi')).collect()[0][0]
df = df.fillna({'bmi': mean_bmi})

df.printSchema()
df.show(5)

root
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: double (nullable = false)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)

+------+----+------------+-------------+------------+-------------+--------------+-----------------+------------------+---------------+------+
|gender| age|hypertension|heart_disease|ever_married|    work_type|Residence_type|avg_glucose_level|               bmi| smoking_status|stroke|
+------+----+------------+-------------+------------+-------------+--------------+-----------------+------------------+---------------+------+
|  Male|67.0|           0|            1|         Yes|      Private|         Urban|           228.69|

In [10]:
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml import Pipeline

# Convert categorical columns to numeric indices
categorical_cols = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']

indexers = [StringIndexer(inputCol=col, outputCol=col+"_index", handleInvalid="keep") 
            for col in categorical_cols]

# Assemble all features into a vector
numerical_cols = ['age', 'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi']
indexed_cols = [col+"_index" for col in categorical_cols]
all_feature_cols = numerical_cols + indexed_cols

assembler = VectorAssembler(inputCols=all_feature_cols, outputCol="features_raw")

# Scale features
scaler = StandardScaler(inputCol="features_raw", outputCol="features")

# Create preprocessing pipeline
preprocessing_pipeline = Pipeline(stages=indexers + [assembler, scaler])

# Fit and transform
pipeline_model = preprocessing_pipeline.fit(df)
df_processed = pipeline_model.transform(df).select('features', 'stroke')

df_processed.show(5, truncate=False)

+----------------------------------------------------------------------------------------------------------------------------------+------+
|features                                                                                                                          |stroke|
+----------------------------------------------------------------------------------------------------------------------------------+------+
|[2.9629437376526147,0.0,4.4235458933104335,5.050177133589643,4.7544706732567965,2.0282180648496397,0.0,0.0,0.0,1.8659871947574447]|1     |
|(10,[0,3,4,7,8],[2.6976054924896937,4.465417456745646,3.7533346325843864,0.9009626130593035,2.000061825390059])                   |1     |
|[3.5378432688389427,0.0,4.4235458933104335,2.3390387073759893,4.22186603499579,2.0282180648496397,0.0,0.0,2.000061825390059,0.0]  |1     |
|(10,[0,3,4,9],[2.1669290021638523,3.781283967749156,4.468682818580159,2.798980792136167])                                         |1     |
|[3.493620227978456,

### Train-Test Split

In [11]:
# Split data (80% train, 20% test)
train_df, test_df = df_processed .randomSplit([0.8, 0.2], seed=42)

print(f"Training set size: {train_df.count()}")
print(f"Test set size: {test_df.count()}")

train_df.groupBy('stroke').count().show()

Training set size: 4123
Test set size: 987
+------+-----+
|stroke|count|
+------+-----+
|     1|  192|
|     0| 3931|
+------+-----+



### Handle Class Imbalance

In [12]:
# Separate classes
train_majority = train_df.filter(col('stroke') == 0)
train_minority = train_df.filter(col('stroke') == 1)

# Calculate oversampling ratio
majority_count = train_majority.count()
minority_count = train_minority.count()
ratio = majority_count / minority_count

print(f"Majority: {majority_count}, Minority: {minority_count}, Ratio: {ratio}")

# Oversample minority class
train_minority_oversampled = train_minority.sample(withReplacement=True, fraction=ratio, seed=42)

# Combine datasets
train_balanced = train_majority.union(train_minority_oversampled)

print(f"Balanced training set size: {train_balanced.count()}")
train_balanced.groupBy('stroke').count().show()

Majority: 3931, Minority: 192, Ratio: 20.473958333333332
Balanced training set size: 7964
+------+-----+
|stroke|count|
+------+-----+
|     0| 3931|
|     1| 4033|
+------+-----+



### Train ML Model

In [13]:
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

# 1. Logistic Regression
lr = LogisticRegression(featuresCol='features', labelCol='stroke', maxIter=10)
lr_model = lr.fit(train_balanced)

# 2. Random Forest
rf = RandomForestClassifier(featuresCol='features', labelCol='stroke', numTrees=100, seed=42)
rf_model = rf.fit(train_balanced)

# 3. Gradient Boosted Trees
gbt = GBTClassifier(featuresCol='features', labelCol='stroke', maxIter=10, seed=42)
gbt_model = gbt.fit(train_balanced)

print("Models trained successfully!")

                                                                                

Models trained successfully!


In [14]:
# Make predictions on test set
lr_predictions = lr_model.transform(test_df)
rf_predictions = rf_model.transform(test_df)
gbt_predictions = gbt_model.transform(test_df)

# Evaluators
auc_evaluator = BinaryClassificationEvaluator(labelCol='stroke', metricName='areaUnderROC')
accuracy_evaluator = MulticlassClassificationEvaluator(labelCol='stroke', metricName='accuracy')

# Evaluate models
models = {
    'Logistic Regression': lr_predictions,
    'Random Forest': rf_predictions,
    'Gradient Boosted Trees': gbt_predictions
}

for name, predictions in models.items():
    auc = auc_evaluator.evaluate(predictions)
    accuracy = accuracy_evaluator.evaluate(predictions)
    print(f"\n{name}:")
    print(f"  AUC: {auc:.4f}")
    print(f"  Accuracy: {accuracy:.4f}")


Logistic Regression:
  AUC: 0.8592
  Accuracy: 0.7244

Random Forest:
  AUC: 0.8474
  Accuracy: 0.7204

Gradient Boosted Trees:
  AUC: 0.8368
  Accuracy: 0.7548


In [15]:
from pyspark.mllib.evaluation import MulticlassMetrics

# Function to calculate detailed metrics
def print_detailed_metrics(predictions, model_name):
    # Convert to RDD for MulticlassMetrics
    predictionAndLabels = predictions.select('prediction', 'stroke').rdd.map(lambda x: (float(x[0]), float(x[1])))
    metrics = MulticlassMetrics(predictionAndLabels)
    
    print(f"\n{'='*50}")
    print(f"{model_name} - Detailed Metrics")
    print(f"{'='*50}")
    
    # Confusion matrix
    print("\nConfusion Matrix:")
    print(metrics.confusionMatrix().toArray())
    
    # Precision, Recall, F1 for each class
    for label in [0.0, 1.0]:
        print(f"\nClass {int(label)} ({'No Stroke' if label == 0 else 'Stroke'}):")
        print(f"  Precision: {metrics.precision(label):.4f}")
        print(f"  Recall: {metrics.recall(label):.4f}")
        print(f"  F1-Score: {metrics.fMeasure(label):.4f}")

# Print metrics for all models
print_detailed_metrics(lr_predictions, "Logistic Regression")
print_detailed_metrics(rf_predictions, "Random Forest")
print_detailed_metrics(gbt_predictions, "Gradient Boosted Trees")

                                                                                


Logistic Regression - Detailed Metrics

Confusion Matrix:
[[664. 266.]
 [  6.  51.]]

Class 0 (No Stroke):
  Precision: 0.9910
  Recall: 0.7140
  F1-Score: 0.8300

Class 1 (Stroke):
  Precision: 0.1609
  Recall: 0.8947
  F1-Score: 0.2727

Random Forest - Detailed Metrics

Confusion Matrix:
[[663. 267.]
 [  9.  48.]]

Class 0 (No Stroke):
  Precision: 0.9866
  Recall: 0.7129
  F1-Score: 0.8277

Class 1 (Stroke):
  Precision: 0.1524
  Recall: 0.8421
  F1-Score: 0.2581

Gradient Boosted Trees - Detailed Metrics

Confusion Matrix:
[[701. 229.]
 [ 13.  44.]]

Class 0 (No Stroke):
  Precision: 0.9818
  Recall: 0.7538
  F1-Score: 0.8528

Class 1 (Stroke):
  Precision: 0.1612
  Recall: 0.7719
  F1-Score: 0.2667


In [16]:
# Save the preprocessing pipeline and model
pipeline_model.write().overwrite().save("/kaggle/working/preprocessing_pipeline")
lr_model.write().overwrite().save("/kaggle/working/stroke_prediction_model")

print("Model and pipeline saved successfully!")

                                                                                

Model and pipeline saved successfully!


In [17]:
def predict_stroke(patient_data):
    """
    Predict stroke risk for a new patient
    
    patient_data: dict with keys: gender, age, hypertension, heart_disease, 
                  ever_married, work_type, Residence_type, avg_glucose_level, 
                  bmi, smoking_status
    """
    # Create DataFrame from input
    patient_df = spark.createDataFrame([patient_data])
    
    # Preprocess
    patient_processed = pipeline_model.transform(patient_df)
    
    # Predict
    prediction = lr_model.transform(patient_processed)
    
    result = prediction.select('prediction', 'probability').collect()[0]
    
    stroke_probability = result['probability'][1]
    prediction_label = "HIGH RISK" if result['prediction'] == 1 else "LOW RISK"
    
    return {
        'prediction': prediction_label,
        'stroke_probability': f"{stroke_probability:.2%}"
    }

# Test with a sample patient
sample_patient = {
    'gender': 'Male',
    'age': 67.0,
    'hypertension': 1,
    'heart_disease': 1,
    'ever_married': 'Yes',
    'work_type': 'Private',
    'Residence_type': 'Urban',
    'avg_glucose_level': 228.69,
    'bmi': 36.6,
    'smoking_status': 'formerly smoked'
}

result = predict_stroke(sample_patient)
print(f"\nPrediction: {result['prediction']}")
print(f"Stroke Probability: {result['stroke_probability']}")


Prediction: HIGH RISK
Stroke Probability: 86.14%
