In [1]:
# findspark.init() is called before `import pyspark`, presumably so findspark
# can locate the local Spark installation and put its Python package on
# sys.path (findspark's documented purpose). Keep this order.
# NOTE(review): depends on SPARK_HOME / a discoverable Spark install — confirm
# in the target environment.
import findspark
findspark.init()

# The top-level package is not referenced by name later; the notebook only
# imports from pyspark submodules in the next cell.
import pyspark

In [2]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark.sql.functions import when

In [3]:
# One SparkSession for the whole notebook; getOrCreate() hands back the
# already-running session if this cell is executed again.
spark = (
    SparkSession.builder
    .appName("HeartDiseaseClassification")
    .getOrCreate()
)

In [4]:
# Read the space-separated data file. It has no header row, so Spark names the
# columns _c0, _c1, ... and infers each column's type from the values.
df = spark.read.csv('heart.dat', sep=' ', header=False, inferSchema=True)

In [5]:
df.show(n=5)  # preview the raw, unnamed columns

+----+---+---+-----+-----+---+---+-----+---+---+----+----+----+----+
| _c0|_c1|_c2|  _c3|  _c4|_c5|_c6|  _c7|_c8|_c9|_c10|_c11|_c12|_c13|
+----+---+---+-----+-----+---+---+-----+---+---+----+----+----+----+
|70.0|1.0|4.0|130.0|322.0|0.0|2.0|109.0|0.0|2.4| 2.0| 3.0| 3.0|   2|
|67.0|0.0|3.0|115.0|564.0|0.0|2.0|160.0|0.0|1.6| 2.0| 0.0| 7.0|   1|
|57.0|1.0|2.0|124.0|261.0|0.0|0.0|141.0|0.0|0.3| 1.0| 0.0| 7.0|   2|
|64.0|1.0|4.0|128.0|263.0|0.0|0.0|105.0|1.0|0.2| 2.0| 1.0| 7.0|   1|
|74.0|0.0|2.0|120.0|269.0|0.0|2.0|121.0|1.0|0.2| 1.0| 1.0| 3.0|   1|
+----+---+---+-----+-----+---+---+-----+---+---+----+----+----+----+
only showing top 5 rows



In [6]:
# Rename the columns.
# Descriptive names for the first 13 raw columns (_c0 .. _c12), in file order.
# NOTE(review): 'year' holds values such as 70.0 and 67.0, which look like ages
# rather than years — consider renaming it (and the 'colesterol' typo) together
# with every later cell that references these names.
new_column_names = [
    'year',
    'sex',
    'tPain',
    'restPressure',
    'colesterol',
    'bloodSugarL120',
    'electrocardioRest',
    'maxHeartRate',
    'angina',
    'oldPeak',
    'stSlope',
    'numVessels',
    'thal'
]

raw_columns = df.columns
# This cell must run on the freshly read frame. Running the rename against an
# already-renamed df would throw away a real column (the last one, 'thal'),
# so fail loudly instead and ask for the load cell to be re-run.
assert all(name.startswith('_c') for name in raw_columns), (
    "df already has renamed columns; re-run the data-loading cell first"
)
assert len(raw_columns) == len(new_column_names) + 1, (
    f"expected {len(new_column_names) + 1} raw columns, got {len(raw_columns)}"
)

# Project each raw column under its new name. zip() stops at the shorter list,
# so the trailing raw column (_c13) is not carried over.
# NOTE(review): _c13 only shows 1/2 values in the preview and is presumably the
# dataset's own disease label — confirm; this notebook derives its target from
# 'thal' instead.
df = df.select([df[old].alias(new) for old, new in zip(raw_columns, new_column_names)])
df.show(5)

+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+
|year|sex|tPain|restPressure|colesterol|bloodSugarL120|electrocardioRest|maxHeartRate|angina|oldPeak|stSlope|numVessels|thal|
+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+
|70.0|1.0|  4.0|       130.0|     322.0|           0.0|              2.0|       109.0|   0.0|    2.4|    2.0|       3.0| 3.0|
|67.0|0.0|  3.0|       115.0|     564.0|           0.0|              2.0|       160.0|   0.0|    1.6|    2.0|       0.0| 7.0|
|57.0|1.0|  2.0|       124.0|     261.0|           0.0|              0.0|       141.0|   0.0|    0.3|    1.0|       0.0| 7.0|
|64.0|1.0|  4.0|       128.0|     263.0|           0.0|              0.0|       105.0|   1.0|    0.2|    2.0|       1.0| 7.0|
|74.0|0.0|  2.0|       120.0|     269.0|           0.0|              2.0|       121.0|   1.0|    0.2|    1.0|       1.

In [10]:
# (column name, Spark SQL type) pairs — the same list `df.dtypes` returns.
[(field.name, field.dataType.simpleString()) for field in df.schema.fields]

[('year', 'double'),
 ('sex', 'double'),
 ('tPain', 'double'),
 ('restPressure', 'double'),
 ('colesterol', 'double'),
 ('bloodSugarL120', 'double'),
 ('electrocardioRest', 'double'),
 ('maxHeartRate', 'double'),
 ('angina', 'double'),
 ('oldPeak', 'double'),
 ('stSlope', 'double'),
 ('numVessels', 'double'),
 ('thal', 'double')]

In [11]:
# Binary target derived from 'thal': codes 3 and 6 -> 0 (not sick), any other
# value -> 1. A null 'thal' also falls through to otherwise(), i.e. 1.
# NOTE(review): the target is a proxy built from a feature column rather than
# a dedicated label column — confirm this is the intended task.
df = df.withColumn(
    'sick',
    when(df['thal'].isin(3, 6), 0).otherwise(1),
)
df.show(5)

+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+
|year|sex|tPain|restPressure|colesterol|bloodSugarL120|electrocardioRest|maxHeartRate|angina|oldPeak|stSlope|numVessels|thal|sick|
+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+
|70.0|1.0|  4.0|       130.0|     322.0|           0.0|              2.0|       109.0|   0.0|    2.4|    2.0|       3.0| 3.0|   0|
|67.0|0.0|  3.0|       115.0|     564.0|           0.0|              2.0|       160.0|   0.0|    1.6|    2.0|       0.0| 7.0|   1|
|57.0|1.0|  2.0|       124.0|     261.0|           0.0|              0.0|       141.0|   0.0|    0.3|    1.0|       0.0| 7.0|   1|
|64.0|1.0|  4.0|       128.0|     263.0|           0.0|              0.0|       105.0|   1.0|    0.2|    2.0|       1.0| 7.0|   1|
|74.0|0.0|  2.0|       120.0|     269.0|           0.0|              2.0|       121

In [12]:
def _flag_sick(frame):
    """Return `frame` plus a 'newSick' column: 0 where thal is 3 or 6, else 1."""
    thal = frame['thal']
    return frame.withColumn(
        'newSick',
        when((thal == 3) | (thal == 6), 0).otherwise(1),
    )


# Same rule as the 'sick' column, expressed through DataFrame.transform.
df = df.transform(_flag_sick)
df.show(5)

+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+-------+
|year|sex|tPain|restPressure|colesterol|bloodSugarL120|electrocardioRest|maxHeartRate|angina|oldPeak|stSlope|numVessels|thal|sick|newSick|
+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+-------+
|70.0|1.0|  4.0|       130.0|     322.0|           0.0|              2.0|       109.0|   0.0|    2.4|    2.0|       3.0| 3.0|   0|      0|
|67.0|0.0|  3.0|       115.0|     564.0|           0.0|              2.0|       160.0|   0.0|    1.6|    2.0|       0.0| 7.0|   1|      1|
|57.0|1.0|  2.0|       124.0|     261.0|           0.0|              0.0|       141.0|   0.0|    0.3|    1.0|       0.0| 7.0|   1|      1|
|64.0|1.0|  4.0|       128.0|     263.0|           0.0|              0.0|       105.0|   1.0|    0.2|    2.0|       1.0| 7.0|   1|      1|
|74.0|0.0|  2.0|       120.

In [13]:
# 'newSick' applies the same 'thal' rule as 'sick', so it is redundant; remove it.
scratch_columns = ['newSick']
df = df.drop(*scratch_columns)
df.show(5)

+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+
|year|sex|tPain|restPressure|colesterol|bloodSugarL120|electrocardioRest|maxHeartRate|angina|oldPeak|stSlope|numVessels|thal|sick|
+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+
|70.0|1.0|  4.0|       130.0|     322.0|           0.0|              2.0|       109.0|   0.0|    2.4|    2.0|       3.0| 3.0|   0|
|67.0|0.0|  3.0|       115.0|     564.0|           0.0|              2.0|       160.0|   0.0|    1.6|    2.0|       0.0| 7.0|   1|
|57.0|1.0|  2.0|       124.0|     261.0|           0.0|              0.0|       141.0|   0.0|    0.3|    1.0|       0.0| 7.0|   1|
|64.0|1.0|  4.0|       128.0|     263.0|           0.0|              0.0|       105.0|   1.0|    0.2|    2.0|       1.0| 7.0|   1|
|74.0|0.0|  2.0|       120.0|     269.0|           0.0|              2.0|       121

In [14]:
# Assemble the predictor columns into the single vector column MLlib expects.
# 'thal' is not included: the 'sick' target is derived from it, so feeding it
# to the model would leak the label into the features.
# NOTE(review): tPain, electrocardioRest, stSlope look like categorical codes
# passed in as plain numbers — consider OneHotEncoder if that is confirmed.
feature_cols = [
    'year', 'sex', 'tPain', 'restPressure',
    'colesterol', 'bloodSugarL120', 'electrocardioRest',
    'maxHeartRate', 'angina', 'oldPeak', 'stSlope',
    'numVessels'
]

assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

# VectorAssembler refuses to write to an output column that already exists,
# so drop any 'features' left by a previous run; this keeps the cell re-runnable.
df = assembler.transform(df.drop('features'))

In [15]:
df.show(5, truncate=True)  # truncated view: the features vector is long

+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+--------------------+
|year|sex|tPain|restPressure|colesterol|bloodSugarL120|electrocardioRest|maxHeartRate|angina|oldPeak|stSlope|numVessels|thal|sick|            features|
+----+---+-----+------------+----------+--------------+-----------------+------------+------+-------+-------+----------+----+----+--------------------+
|70.0|1.0|  4.0|       130.0|     322.0|           0.0|              2.0|       109.0|   0.0|    2.4|    2.0|       3.0| 3.0|   0|[70.0,1.0,4.0,130...|
|67.0|0.0|  3.0|       115.0|     564.0|           0.0|              2.0|       160.0|   0.0|    1.6|    2.0|       0.0| 7.0|   1|[67.0,0.0,3.0,115...|
|57.0|1.0|  2.0|       124.0|     261.0|           0.0|              0.0|       141.0|   0.0|    0.3|    1.0|       0.0| 7.0|   1|[57.0,1.0,2.0,124...|
|64.0|1.0|  4.0|       128.0|     263.0|           0.0|              0.0|       105.0|  

In [16]:
# Keep only what the model needs: the assembled feature vector and the target.
# LogisticRegression reads its target from a column named 'label' by default,
# so expose 'sick' under that name.
model_data = df.select('features', df['sick'].alias('label'))

# Split into training and test sets. A fixed seed makes the split — and every
# metric computed downstream — reproducible across kernel restarts.
SPLIT_SEED = 42
train_data, test_data = model_data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [17]:
# Fit a logistic-regression classifier on the training split. Column names are
# spelled out to make the contract with the assembler / label cells explicit.
lr = LogisticRegression(featuresCol='features', labelCol='label')
lr_model = lr.fit(train_data)

# Score the held-out split.
predictions = lr_model.transform(test_data)

# Show some predictions.
predictions.show()

# Quantify performance rather than only eyeballing rows: area under the ROC
# curve, computed from the model's rawPrediction column on the test split.
evaluator = BinaryClassificationEvaluator(
    labelCol='label',
    rawPredictionCol='rawPrediction',
    metricName='areaUnderROC',
)
test_auc = evaluator.evaluate(predictions)
print(f"Test AUC: {test_auc:.3f}")

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|(12,[0,2,3,4,7,10...|    0|[3.29444650177848...|[0.96423780092264...|       0.0|
|[29.0,1.0,2.0,130...|    0|[1.68330619836882...|[0.84334183010437...|       0.0|
|[35.0,1.0,4.0,126...|    1|[0.04022173626572...|[0.51005407865560...|       0.0|
|[40.0,1.0,1.0,140...|    1|[0.98237922910892...|[0.72758005222614...|       0.0|
|[40.0,1.0,4.0,110...|    1|[-0.2407645079885...|[0.44009795836661...|       1.0|
|[40.0,1.0,4.0,152...|    1|[-0.2790060296920...|[0.43069747795044...|       1.0|
|[41.0,0.0,2.0,105...|    0|[3.82278970785327...|[0.97860120729678...|       0.0|
|[41.0,0.0,3.0,112...|    0|[3.10797251035206...|[0.95722040801919...|       0.0|
|[41.0,1.0,2.0,110...|    0|[1.77370328226339...|[0.85491760609190...|       0.0|
|[41.0,1.0,3.0,1

In [18]:
# Shut down the SparkSession created at the top of the notebook. Run this last:
# every DataFrame above (df, model_data, predictions, ...) depends on it.
spark.stop()