In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, FloatType, LongType, StringType, DoubleType
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col
from pyspark.ml import Pipeline, Transformer
from pyspark.ml.feature import StringIndexer, VectorAssembler, Imputer
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import pyspark.sql.functions as F
from itertools import combinations
import os

## Check Python Path

In [2]:
import sys
sys.executable

'/tmp/demos/bin/python3'

In [3]:
DATA_FOLDER = "data"

NUMBER_OF_FOLDS = 3
SPLIT_SEED = 7576
TRAIN_TEST_SPLIT = 0.8

## Function for data reading

In [4]:

def read_data(spark: SparkSession) -> DataFrame:
    """
    read data; since the data has the header we let spark guess the schema
    """
    
    # Read the Titanic CSV data into a DataFrame
    titanic_data = spark.read \
        .format("csv") \
        .option("header", "true") \
        .option("inferSchema", "true") \
        .load(os.path.join(DATA_FOLDER,"heart_disease.csv"))

    return titanic_data

## Writing new Transformer type class : adding cross product of features

In [5]:
class PairwiseProduct(Transformer):

    def __init__(self, inputCols, outputCols):
        self.__inputCols = inputCols
        self.__outputCols = outputCols

        self._paramMap = self._params = {}

    def _transform(self, df):
        for cols, out_col in zip(self.__inputCols, self.__outputCols):
            df = df.withColumn(out_col, col(cols[0]) * col(cols[1]))
        return df

## The ML pipeline

In [6]:


def pipeline(data: DataFrame):

    """
    every attribute that is numeric is non-categorical; this is questionable
    """

    numeric_features = [f.name for f in data.schema.fields if isinstance(f.dataType, (DoubleType, FloatType, IntegerType, LongType))]
    string_features = [f.name for f in data.schema.fields if isinstance(f.dataType, StringType)]
    #numeric_features.remove("PassengerId")
    #numeric_features.remove("Survived")
    #string_features.remove("Name")
    print(numeric_features)
    print(string_features)

    # Fill missing values for string columns with a placeholder before indexing
    data = data.fillna({col: 'null' for col in string_features})
    
    # Index string features
    indexed_string_columns = [f"{v}Index" for v in string_features]
    indexers = [StringIndexer(inputCol=col, outputCol=indexed_col, handleInvalid='keep') for col, indexed_col in zip(string_features, indexed_string_columns)]

    # Impute missing values for indexed string columns
    imputed_columns_string = [f"Imputed{v}" for v in indexed_string_columns]
    imputer_string = Imputer(inputCols=indexed_string_columns, outputCols=imputed_columns_string, strategy="mode")

    
    # numeric columns
    imputed_columns_numeric = [f"Imputed{v}" for v in numeric_features]
    imputer_numeric = Imputer(inputCols=numeric_features, outputCols=imputed_columns_numeric, strategy = "mean")

    # Create all pairwise products of numeric features
    #all_pairs = [v for v in combinations(imputed_columns_numeric, 2)]
    #pairwise_columns = [f"{col1}_{col2}" for col1, col2 in all_pairs]
    #pairwise_product = PairwiseProduct(inputCols=all_pairs, outputCols=pairwise_columns)

    # Assemble feature columns into a single feature vector
    assembler = VectorAssembler(
        inputCols=imputed_columns_numeric + imputed_columns_string, 
        outputCol="features"
        )

    # Create a list of pipeline stages
    stages = indexers + [imputer_string, imputer_numeric, assembler]
    
    # Create and fit the pipeline
    pipeline = Pipeline(stages=stages)
    model = pipeline.fit(data)
    
    # Transform the data
    transformed_data = model.transform(data)
    
    return transformed_data

    
    

    

In [7]:
def main():
    # Create a Spark session
    spark = SparkSession.builder \
        .appName("Predict Titanic Survival") \
        .getOrCreate()

    try:
        # Read data
        data = read_data(spark)
        
        # Print schema and preview the data
        data.printSchema()
        data.show(5)

        # Apply the pipeline
        transformed_data = pipeline(data)
        
        # Show the transformed data, including the imputed columns
        columns_to_show = [col for col in transformed_data.columns if col.startswith("Imputed")]
        transformed_data.select(columns_to_show).show(truncate=False)
        
    finally:
        # Stop the Spark session
        spark.stop()

In [8]:

main()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/26 16:33:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


root
 |-- age: string (nullable = true)
 |-- sex: integer (nullable = true)
 |-- painloc: integer (nullable = true)
 |-- painexer: integer (nullable = true)
 |-- relrest: integer (nullable = true)
 |-- pncaden: string (nullable = true)
 |-- cp: integer (nullable = true)
 |-- trestbps: integer (nullable = true)
 |-- htn: integer (nullable = true)
 |-- chol: integer (nullable = true)
 |-- smoke: integer (nullable = true)
 |-- cigs: integer (nullable = true)
 |-- years: integer (nullable = true)
 |-- fbs: integer (nullable = true)
 |-- dm: integer (nullable = true)
 |-- famhist: integer (nullable = true)
 |-- restecg: integer (nullable = true)
 |-- ekgmo: integer (nullable = true)
 |-- ekgday(day: integer (nullable = true)
 |-- ekgyr: integer (nullable = true)
 |-- dig: integer (nullable = true)
 |-- prop: integer (nullable = true)
 |-- nitr: integer (nullable = true)
 |-- pro: integer (nullable = true)
 |-- diuretic: integer (nullable = true)
 |-- proto: integer (nullable = true)
 |-- th

24/05/26 16:33:49 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+---+---+-------+--------+-------+-------+---+--------+---+----+-----+----+-----+---+----+-------+-------+-----+----------+-----+---+----+----+---+--------+-----+-------+--------+----+-------+--------+--------+--------+-----+--------+-----+-----+-------+-----+-----+------+---+-------+-------+------+------+------+------+----+-------+-------+-------+---+----+---+------+
|age|sex|painloc|painexer|relrest|pncaden| cp|trestbps|htn|chol|smoke|cigs|years|fbs|  dm|famhist|restecg|ekgmo|ekgday(day|ekgyr|dig|prop|nitr|pro|diuretic|proto|thaldur|thaltime| met|thalach|thalrest|tpeakbps|tpeakbpd|dummy|trestbpd|exang|xhypo|oldpeak|slope|rldv5|rldv5e| ca|restckm|exerckm|restef|restwm|exeref|exerwm|thal|thalsev|thalpul|earlobe|cmo|cday|cyr|target|
+---+---+-------+--------+-------+-------+---+--------+---+----+-----+----+-----+---+----+-------+-------+-----+----------+-----+---+----+----+---+--------+-----+-------+--------+----+-------+--------+--------+--------+-----+--------+-----+-----+-------+----

                                                                                

+---------------+-------------------+-------------------+----------+--------------+---------------+--------------+---------+---------------+----------+-----------+------------+-----------+------------+----------+---------+--------------+--------------+------------+-----------------+------------+----------+-----------+-----------+----------+---------------+------------+--------------+-----------------+----------+--------------+---------------+---------------+---------------+------------+---------------+------------+------------+--------------+------------+------------+-------------+---------+--------------+------------------+-------------+-------------+-------------+-----------+--------------+--------------+--------------+----------+-----------+----------+-------------+
|ImputedageIndex|ImputedpncadenIndex|ImputedrestckmIndex|Imputedsex|Imputedpainloc|Imputedpainexer|Imputedrelrest|Imputedcp|Imputedtrestbps|Imputedhtn|Imputedchol|Imputedsmoke|Imputedcigs|Imputedyears|Imputedfbs|Imputedd