## Library

In [268]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, Imputer, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

## SparkSession

In [269]:
# Reuse the active Spark session if one exists, otherwise start one named
# "UASPractice"; the trailing expression displays the Spark version in use.
session_builder = SparkSession.builder.appName("UASPractice")
spark = session_builder.getOrCreate()
spark.version

'3.5.0'

# Load Dataset

In [270]:
# Read the CSV with its first row as the header and let Spark infer the
# column types from the data.
df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("Heart_Disease_Prediction.csv")
)

# Exploratory Data Analysis (EDA)

In [271]:
# Preview the first 20 rows of the freshly loaded data.
df.show(n=20)

+---+---+---------------+---+-----------+------------+-----------+------+---------------+-------------+-----------+-----------------------+--------+-------------+
|Age|Sex|Chest pain type| BP|Cholesterol|FBS over 120|EKG results|Max HR|Exercise angina|ST depression|Slope of ST|Number of vessels fluro|Thallium|Heart Disease|
+---+---+---------------+---+-----------+------------+-----------+------+---------------+-------------+-----------+-----------------------+--------+-------------+
| 70|  1|              4|130|        322|           0|          2|   109|              0|          2.4|          2|                      3|       3|     Presence|
| 67|  0|              3|115|        564|           0|          2|   160|              0|          1.6|          2|                      0|       7|      Absence|
| 57|  1|              2|124|        261|           0|          0|   141|              0|          0.3|          1|                      0|       7|     Presence|
| 64|  1|             

In [234]:
# Preview of the data after "Absence"/"Presence" have been replaced by 0/1.
# NOTE(review): this cell was executed out of order (execution count 234) and
# its saved output reflects the label-encoding cell that appears further down.
# On a fresh top-to-bottom run `df` still holds the raw string labels at this
# point, so the output will not match; consider moving this cell below the
# encoding step.
df.show(20)

+---+---+---------------+---+-----------+------------+-----------+------+---------------+-------------+-----------+-----------------------+--------+-------------+
|Age|Sex|Chest pain type| BP|Cholesterol|FBS over 120|EKG results|Max HR|Exercise angina|ST depression|Slope of ST|Number of vessels fluro|Thallium|Heart Disease|
+---+---+---------------+---+-----------+------------+-----------+------+---------------+-------------+-----------+-----------------------+--------+-------------+
| 70|  1|              4|130|        322|           0|          2|   109|              0|          2.4|          2|                      3|       3|            1|
| 67|  0|              3|115|        564|           0|          2|   160|              0|          1.6|          2|                      0|       7|            0|
| 57|  1|              2|124|        261|           0|          0|   141|              0|          0.3|          1|                      0|       7|            1|
| 64|  1|             

Age : Age of the patient (in years)<br>
Sex : Gender of the patient (1 = male, 0 = female)<br>
Chest pain type : Type of chest pain experienced by the patient<br>
BP : Resting blood pressure (mm Hg)<br>
Cholesterol : Serum cholesterol level (mg/dl)<br>
FBS over 120 :  Fasting blood sugar > 120 mg/dl (1 = yes, 0 = no)<br>
EKG results : Resting electrocardiogram results<br>
Max HR : Maximum heart rate achieved during exercise<br>
Exercise angina : Exercise-induced angina (1 = yes, 0 = no)<br>
ST depression : ST depression induced by exercise relative to rest<br>
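Slope of ST : Slope of the peak exercise ST segment (coded 1–3)<br>
Number of vessels fluro : Number of major vessels (0–3) colored by fluoroscopy<br>
Thallium : Thallium stress test result (3 = normal, 6 = fixed defect, 7 = reversible defect)<br>
Heart Disease : Target label — Absence or Presence of heart disease<br>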

## Inconsistent Data

In [272]:
# Inspect the inferred column types; "Heart Disease" is still a string here.
df.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Sex: integer (nullable = true)
 |-- Chest pain type: integer (nullable = true)
 |-- BP: integer (nullable = true)
 |-- Cholesterol: integer (nullable = true)
 |-- FBS over 120: integer (nullable = true)
 |-- EKG results: integer (nullable = true)
 |-- Max HR: integer (nullable = true)
 |-- Exercise angina: integer (nullable = true)
 |-- ST depression: double (nullable = true)
 |-- Slope of ST: integer (nullable = true)
 |-- Number of vessels fluro: integer (nullable = true)
 |-- Thallium: integer (nullable = true)
 |-- Heart Disease: string (nullable = true)



In [273]:
# List the column names exactly as they appear in the CSV header.
df.columns

['Age',
 'Sex',
 'Chest pain type',
 'BP',
 'Cholesterol',
 'FBS over 120',
 'EKG results',
 'Max HR',
 'Exercise angina',
 'ST depression',
 'Slope of ST',
 'Number of vessels fluro',
 'Thallium',
 'Heart Disease']

In [274]:
# Encode the target as an integer label: "Absence" -> 0, "Presence" -> 1.
# The comparison runs on a string view of the column so that this cell is safe
# to re-run: when the column already holds 0/1 those values are kept, whereas
# a plain two-branch `when` chain would match neither branch on a second run
# and wipe the labels. Any other value becomes null.
label_as_text = col("Heart Disease").cast("string")
df = df.withColumn(
    "Heart Disease",
    when(label_as_text == "Absence", 0)
    .when(label_as_text == "Presence", 1)
    .when(label_as_text.isin("0", "1"), label_as_text.cast("int")),
)

In [275]:
# Spot-check the encoded label column (show() prints the first 20 rows).
label_preview = df.select("Heart Disease")
label_preview.show()

+-------------+
|Heart Disease|
+-------------+
|            1|
|            0|
|            1|
|            0|
|            0|
|            0|
|            1|
|            1|
|            1|
|            1|
|            0|
|            0|
|            0|
|            1|
|            0|
|            0|
|            1|
|            1|
|            0|
|            0|
+-------------+
only showing top 20 rows



In [276]:
# These integer-coded fields are category codes rather than quantities, so
# they are cast to string; each column keeps its original name.
df = df.withColumns({
    code_col: col(code_col).cast("string")
    for code_col in (
        "Chest pain type",
        "EKG results",
        "Slope of ST",
        "Number of vessels fluro",
        "Thallium",
    )
})

In [277]:
# Confirm the casts: the five coded columns should now be strings and the
# label an integer.
df.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Sex: integer (nullable = true)
 |-- Chest pain type: string (nullable = true)
 |-- BP: integer (nullable = true)
 |-- Cholesterol: integer (nullable = true)
 |-- FBS over 120: integer (nullable = true)
 |-- EKG results: string (nullable = true)
 |-- Max HR: integer (nullable = true)
 |-- Exercise angina: integer (nullable = true)
 |-- ST depression: double (nullable = true)
 |-- Slope of ST: string (nullable = true)
 |-- Number of vessels fluro: string (nullable = true)
 |-- Thallium: string (nullable = true)
 |-- Heart Disease: integer (nullable = true)



## Value Counts

### Numerical Features

In [278]:
# Number of rows per distinct age (show() prints the first 20 groups).
age_counts = df.groupBy("Age").count()
age_counts.show()

+---+-----+
|Age|count|
+---+-----+
| 65|    8|
| 53|    7|
| 34|    2|
| 76|    1|
| 44|   10|
| 47|    4|
| 52|   11|
| 40|    3|
| 57|   12|
| 54|   16|
| 48|    7|
| 64|    9|
| 41|    9|
| 43|    7|
| 37|    2|
| 61|    7|
| 35|    3|
| 59|   12|
| 55|    6|
| 39|    3|
+---+-----+
only showing top 20 rows



In [279]:
# Number of rows per value of the binary "Sex" column.
sex_counts = df.groupBy("Sex").count()
sex_counts.show()

+---+-----+
|Sex|count|
+---+-----+
|  1|  183|
|  0|   87|
+---+-----+



In [280]:
# Number of rows per distinct blood-pressure reading (first 20 groups shown).
bp_counts = df.groupBy("BP").count()
bp_counts.show()

+---+-----+
| BP|count|
+---+-----+
|148|    1|
|108|    6|
|155|    1|
|115|    3|
|101|    1|
|126|    3|
|192|    1|
|128|    9|
|122|    3|
|140|   30|
|132|    6|
|152|    4|
|146|    1|
|142|    3|
|178|    2|
| 94|    2|
|120|   34|
|117|    1|
|112|    9|
|165|    1|
+---+-----+
only showing top 20 rows



In [281]:
# Number of rows per distinct cholesterol value (first 20 groups shown).
cholesterol_counts = df.groupBy("Cholesterol").count()
cholesterol_counts.show()

+-----------+-----+
|Cholesterol|count|
+-----------+-----+
|        243|    4|
|        255|    2|
|        322|    1|
|        321|    1|
|        211|    4|
|        193|    1|
|        126|    1|
|        210|    1|
|        183|    1|
|        300|    1|
|        271|    2|
|        192|    1|
|        253|    1|
|        236|    2|
|        223|    2|
|        417|    1|
|        409|    1|
|        222|    2|
|        209|    2|
|        330|    2|
+-----------+-----+
only showing top 20 rows



In [282]:
# Number of rows per value of the binary "FBS over 120" flag.
fbs_counts = df.groupBy("FBS over 120").count()
fbs_counts.show()

+------------+-----+
|FBS over 120|count|
+------------+-----+
|           1|   40|
|           0|  230|
+------------+-----+



In [283]:
# Number of rows per distinct maximum heart rate (first 20 groups shown).
max_hr_counts = df.groupBy("Max HR").count()
max_hr_counts.show()

+------+-----+
|Max HR|count|
+------+-----+
|   148|    3|
|   137|    1|
|   133|    2|
|   155|    3|
|   108|    2|
|   126|    4|
|   115|    1|
|   159|    4|
|   192|    1|
|   103|    2|
|   128|    1|
|   122|    4|
|   157|    5|
|   190|    1|
|   111|    3|
|   140|    5|
|   177|    1|
|   152|    6|
|   132|    6|
|   185|    1|
+------+-----+
only showing top 20 rows



In [284]:
# Number of rows per value of the binary "Exercise angina" flag.
exercise_angina_counts = df.groupBy("Exercise angina").count()
exercise_angina_counts.show()

+---------------+-----+
|Exercise angina|count|
+---------------+-----+
|              1|   89|
|              0|  181|
+---------------+-----+



In [285]:
# Number of rows per distinct ST-depression value (first 20 groups shown).
st_depression_counts = df.groupBy("ST depression").count()
st_depression_counts.show()

+-------------+-----+
|ST depression|count|
+-------------+-----+
|          2.4|    3|
|          0.0|   85|
|          3.5|    1|
|          0.2|   11|
|          2.9|    1|
|          1.4|   13|
|          0.7|    1|
|          2.3|    2|
|          0.1|    6|
|          3.4|    2|
|          2.5|    2|
|          1.0|   12|
|          0.6|   12|
|          3.1|    1|
|          0.8|   11|
|          2.2|    4|
|          2.8|    4|
|          4.0|    2|
|          1.9|    5|
|          6.2|    1|
+-------------+-----+
only showing top 20 rows



### Categorical Features

In [286]:
# Class frequencies of the categorical "Chest pain type" column.
chest_pain_counts = df.groupBy("Chest pain type").count()
chest_pain_counts.show()

+---------------+-----+
|Chest pain type|count|
+---------------+-----+
|              3|   79|
|              1|   20|
|              4|  129|
|              2|   42|
+---------------+-----+



In [287]:
# Class frequencies of the categorical "EKG results" column.
ekg_counts = df.groupBy("EKG results").count()
ekg_counts.show()

+-----------+-----+
|EKG results|count|
+-----------+-----+
|          0|  131|
|          1|    2|
|          2|  137|
+-----------+-----+



In [288]:
# Class frequencies of the categorical "Slope of ST" column.
slope_counts = df.groupBy("Slope of ST").count()
slope_counts.show()

+-----------+-----+
|Slope of ST|count|
+-----------+-----+
|          3|   18|
|          1|  130|
|          2|  122|
+-----------+-----+



In [289]:
# Class frequencies of the categorical "Number of vessels fluro" column.
vessels_counts = df.groupBy("Number of vessels fluro").count()
vessels_counts.show()

+-----------------------+-----+
|Number of vessels fluro|count|
+-----------------------+-----+
|                      3|   19|
|                      0|  160|
|                      1|   58|
|                      2|   33|
+-----------------------+-----+



In [290]:
# Class frequencies of the categorical "Thallium" column.
thallium_counts = df.groupBy("Thallium").count()
thallium_counts.show()

+--------+-----+
|Thallium|count|
+--------+-----+
|       7|  104|
|       3|  152|
|       6|   14|
+--------+-----+



### Labels

In [291]:
# Class balance of the encoded target (0 = Absence, 1 = Presence).
label_counts = df.groupBy("Heart Disease").count()
label_counts.show()

+-------------+-----+
|Heart Disease|count|
+-------------+-----+
|            1|  120|
|            0|  150|
+-------------+-----+



# StringIndexer

In [292]:
# Coded clinical fields treated as nominal categories; they are string-indexed
# and one-hot encoded further down.
categorical_cols = [
    "Chest pain type",
    "EKG results",
    "Slope of ST",
    "Number of vessels fluro",
    "Thallium",
]

In [293]:
# One StringIndexer per categorical column, writing the numeric category index
# to "<column>_idx".
# handleInvalid="keep" routes nulls and categories that were not present during
# fit() into one extra index instead of raising at transform time. This matters
# with a random train/test split: a rare category (e.g. "EKG results" == "1"
# has only 2 rows) can end up in the test split only.
# Note: the extra index adds one slot to each downstream one-hot vector.
indexers = [
    StringIndexer(
        inputCol = c,
        outputCol = f"{c}_idx",
        handleInvalid = "keep"
    ) for c in categorical_cols
]

# OneHotEncoder (ohe)

In [294]:
# One OneHotEncoder per categorical column: reads "<column>_idx" and writes
# the encoded vector to "<column>_ohe".
ohe = [
    OneHotEncoder(inputCol=f"{name}_idx", outputCol=f"{name}_ohe")
    for name in categorical_cols
]

# Imputer (handle null)

In [295]:
# Columns fed to the model as plain numbers (this includes the 0/1 flags).
numerical_cols = [
    "Age",
    "Sex",
    "BP",
    "Cholesterol",
    "FBS over 120",
    "Max HR",
    "Exercise angina",
    "ST depression",
]

In [296]:
# Fill missing values in the numeric columns using Imputer's default strategy.
# Input and output names are the same, so the imputed values replace the
# original columns rather than creating new ones.
imputer = Imputer(inputCols=numerical_cols, outputCols=numerical_cols)

# VectorAssembler

In [297]:
# Model inputs: the numeric columns followed by the one-hot vectors of the
# categorical columns, in that order.
encoded_cols = [f"{name}_ohe" for name in categorical_cols]
feature_cols = [*numerical_cols, *encoded_cols]

# Concatenate all inputs into a single vector column, "features_raw".
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features_raw")

# Scaler (required for Logistic Regression)

In [298]:
# Standardize the assembled vector (centering and unit-variance scaling both
# enabled) and write the result to "features", the column the model reads.
scaler = StandardScaler(
    withMean=True,
    withStd=True,
    inputCol="features_raw",
    outputCol="features",
)

# Model (Logistic Regression)

In [299]:
# Logistic regression on the scaled feature vector, predicting the 0/1
# "Heart Disease" label; all other hyperparameters are left at their defaults.
lr = LogisticRegression(labelCol="Heart Disease", featuresCol="features")

# Pipeline

In [300]:
# Stage order: index categoricals -> one-hot encode -> impute numerics ->
# assemble -> scale -> fit the classifier.
pipeline_stages = [*indexers, *ohe, imputer, assembler, scaler, lr]
pipeline = Pipeline(stages=pipeline_stages)

# Split Train and Test Data

In [301]:
# Random 80/20 train/test split; the fixed seed keeps the split repeatable.
train_df, test_df = df.randomSplit(weights=[0.8, 0.2], seed=42)

In [302]:
# Preview the first 10 rows of each split, training set first.
train_df.show(n=10)
test_df.show(n=10)

+---+---+---------------+---+-----------+------------+-----------+------+---------------+-------------+-----------+-----------------------+--------+-------------+
|Age|Sex|Chest pain type| BP|Cholesterol|FBS over 120|EKG results|Max HR|Exercise angina|ST depression|Slope of ST|Number of vessels fluro|Thallium|Heart Disease|
+---+---+---------------+---+-----------+------------+-----------+------+---------------+-------------+-----------+-----------------------+--------+-------------+
| 29|  1|              2|130|        204|           0|          2|   202|              0|          0.0|          1|                      0|       3|            0|
| 34|  0|              2|118|        210|           0|          0|   192|              0|          0.7|          1|                      0|       3|            0|
| 35|  0|              4|138|        183|           0|          0|   182|              0|          1.4|          1|                      0|       3|            0|
| 35|  1|             

# Model Fit

In [303]:
# Fit every pipeline stage (preprocessing and classifier) on the training
# split only, so no statistics are learned from the test rows.
model = pipeline.fit(train_df)

In [304]:
# Apply the fitted pipeline to the held-out test split; the result carries the
# "prediction" and "probability" columns used by the evaluation cells below.
predictions = model.transform(test_df)

# Model Evaluation

## ROC - AUC

In [305]:
# Area under the ROC curve on the test predictions. "areaUnderROC" is the
# evaluator's default metric and is spelled out here for readability.
evaluator = BinaryClassificationEvaluator(
    labelCol="Heart Disease",
    metricName="areaUnderROC",
)

evaluator.evaluate(predictions)

0.8823529411764707

## Confusion Matrix

In [306]:
# Confusion matrix in long form: one row per (actual label, predicted label)
# pair with the number of test rows in that cell.
confusion_counts = predictions.groupBy("Heart Disease", "prediction").count()
confusion_counts.show()

+-------------+----------+-----+
|Heart Disease|prediction|count|
+-------------+----------+-----+
|            1|       0.0|    5|
|            0|       0.0|   22|
|            1|       1.0|   12|
|            0|       1.0|    3|
+-------------+----------+-----+



## Accuracy

In [307]:
# Share of test rows whose predicted label equals the true label.
acc = MulticlassClassificationEvaluator(metricName="accuracy", labelCol="Heart Disease")

acc.evaluate(predictions)

0.8095238095238095

## Precision

In [308]:
# Weighted precision, i.e. averaged over both classes rather than reported
# for the positive ("Presence") class alone.
precision = MulticlassClassificationEvaluator(metricName="weightedPrecision", labelCol="Heart Disease")

precision.evaluate(predictions)

0.8088183421516755

## Recall

In [309]:
# Weighted recall, i.e. averaged over both classes rather than reported for
# the positive ("Presence") class alone.
recall = MulticlassClassificationEvaluator(metricName="weightedRecall", labelCol="Heart Disease")

recall.evaluate(predictions)

0.8095238095238095

## F1-Score

In [310]:
# F1 score on the test predictions, using the evaluator's "f1" metric.
f1 = MulticlassClassificationEvaluator(metricName="f1", labelCol="Heart Disease")

f1.evaluate(predictions)

0.8072344322344323

## PR - AUC

In [311]:
# Area under the precision-recall curve on the test predictions.
pr_auc = BinaryClassificationEvaluator(metricName="areaUnderPR", labelCol="Heart Disease")

pr_auc.evaluate(predictions)

0.7352011804575092

## Log Loss

In [312]:
# Show the predicted class probabilities next to the true label for the first
# five test rows, untruncated. This cell only previews the inputs a log-loss
# calculation would use; it does not compute a log-loss value.
probability_preview = predictions.select("probability", "Heart Disease")
probability_preview.show(n=5, truncate=False)

+------------------------------------------+-------------+
|probability                               |Heart Disease|
+------------------------------------------+-------------+
|[0.9959428720408537,0.004057127959146256] |0            |
|[0.9989566461040099,0.0010433538959900845]|0            |
|[0.7287166887496469,0.2712833112503531]   |1            |
|[0.1020891289452451,0.8979108710547549]   |1            |
|[0.9743068793894551,0.02569312061054485]  |0            |
+------------------------------------------+-------------+
only showing top 5 rows

