In [1]:
# Install Java
!apt-get install openjdk-11-jdk-headless -qq > /dev/null

# Download and install Spark
!wget -q https://downloads.apache.org/spark/spark-3.5.1/spark-3.5.1-bin-hadoop3.tgz
!tar xf spark-3.5.1-bin-hadoop3.tgz

# Install findspark, a Python library that makes it easier for Python to find Spark
!pip install -q findspark

# Set environment variables
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.5.1-bin-hadoop3"

In [2]:
# Install PySpark directly
!pip install pyspark==3.5.1

# Now you can import PySpark directly without needing findspark
from pyspark.sql import SparkSession

# Create a Spark session
spark = SparkSession.builder.master("local[*]").appName("Colab").getOrCreate()

Collecting pyspark==3.5.1
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.1-py2.py3-none-any.whl size=317488491 sha256=3c3c048413c3a2cc888f2f51441a8a629b17f67bffcca0d6afd65f96257c33a8
  Stored in directory: /root/.cache/pip/wheels/80/1d/60/2c256ed38dddce2fdd93be545214a63e02fbd8d74fb0b7f3a6
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.1


In [37]:
#Importing libraries
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.feature import Imputer
from pyspark.sql.functions import when, count, broadcast
from pyspark.sql.functions import col, expr, lit,udf
from pyspark.sql.window import Window
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, FloatType
from pyspark.sql.types import StructType, StructField, StringType, FloatType
import math


In [5]:
# Ensure that the Spark session is created (as previously shown)
spark = SparkSession.builder.master("local[*]").appName("HeartDiseaseAnalysis").getOrCreate()

# Read the CSV file into a DataFrame
df1 = spark.read.csv("heart_disease(in).csv", header=True, inferSchema=True)

# Show the first few rows of the DataFrame to verify it's loaded correctly
df1.show()

+---+---+-------+--------+-------+-------+---+--------+---+----+-----+----+-----+---+----+-------+-------+-----+----------+-----+---+----+----+---+--------+-----+-------+--------+----+-------+--------+--------+--------+-----+--------+-----+-----+-------+-----+-----+------+---+-------+-------+------+------+------+------+----+-------+-------+-------+---+----+---+------+
|age|sex|painloc|painexer|relrest|pncaden| cp|trestbps|htn|chol|smoke|cigs|years|fbs|  dm|famhist|restecg|ekgmo|ekgday(day|ekgyr|dig|prop|nitr|pro|diuretic|proto|thaldur|thaltime| met|thalach|thalrest|tpeakbps|tpeakbpd|dummy|trestbpd|exang|xhypo|oldpeak|slope|rldv5|rldv5e| ca|restckm|exerckm|restef|restwm|exeref|exerwm|thal|thalsev|thalpul|earlobe|cmo|cday|cyr|target|
+---+---+-------+--------+-------+-------+---+--------+---+----+-----+----+-----+---+----+-------+-------+-----+----------+-----+---+----+----+---+--------+-----+-------+--------+----+-------+--------+--------+--------+-----+--------+-----+-----+-------+----

In [6]:
columns_to_keep = [
    'age', 'sex', 'painloc', 'painexer', 'cp', 'trestbps', 'smoke', 'fbs',
    'prop', 'nitr', 'pro', 'diuretic', 'thaldur', 'thalach', 'exang',
    'oldpeak', 'slope', 'target'
]
# Select only the specified columns
df = df1.select(*columns_to_keep)

# Show the DataFrame to verify the columns
df.show()

+---+---+-------+--------+---+--------+-----+---+----+----+---+--------+-------+-------+-----+-------+-----+------+
|age|sex|painloc|painexer| cp|trestbps|smoke|fbs|prop|nitr|pro|diuretic|thaldur|thalach|exang|oldpeak|slope|target|
+---+---+-------+--------+---+--------+-----+---+----+----+---+--------+-------+-------+-----+-------+-----+------+
| 63|  1|   NULL|    NULL|  1|     145| NULL|  1|   0|   0|  0|       0|   10.5|    150|    0|    2.3|    3|     0|
| 67|  1|   NULL|    NULL|  4|     160| NULL|  0|   1|   0|  0|       0|    9.5|    108|    1|    1.5|    2|     1|
| 67|  1|   NULL|    NULL|  4|     120| NULL|  0|   1|   0|  0|       0|    8.5|    129|    1|    2.6|    2|     1|
| 37|  1|   NULL|    NULL|  3|     130| NULL|  0|   1|   0|  0|       0|   13.0|    187|    0|    3.5|    3|     0|
| 41|  0|   NULL|    NULL|  2|     130| NULL|  0|   0|   0|  0|       0|    7.0|    172|    0|    1.4|    1|     0|
| 56|  1|   NULL|    NULL|  2|     120| NULL|  0|   0|   0|  0|       0|

Removing rows whose value of target is null first

In [7]:
# Drop rows where the 'target' column has missing values
df = df.na.drop(subset=["target"])


Imputing trestbps using median imputation since it's a numerical value

In [8]:
# Filter the DataFrame to include only rows where 'trestbps' is greater than or equal to 100
# Then calculate the median of these filtered values
median_df = df.filter(col('trestbps') >= 100)
replacement_value = median_df.approxQuantile('trestbps', [0.5], 0.01)[0]  # 0.01 is the relative error, adjust as needed
# Replace 'trestbps' values less than 100 or null with the median, otherwise keep original
df = df.withColumn('trestbps', when((col('trestbps') < 100) | col('trestbps').isNull(), replacement_value).otherwise(col('trestbps')))
df.select('trestbps').show()

+--------+
|trestbps|
+--------+
|   145.0|
|   160.0|
|   120.0|
|   130.0|
|   130.0|
|   120.0|
|   140.0|
|   120.0|
|   130.0|
|   140.0|
|   140.0|
|   140.0|
|   130.0|
|   120.0|
|   172.0|
|   150.0|
|   110.0|
|   140.0|
|   130.0|
|   130.0|
+--------+
only showing top 20 rows



Imputing oldpeak, excluding the thresholds

In [9]:
# Filter the DataFrame to include only rows where 'oldpeak' is between 0 and 4
valid_range_df = df.filter((col('oldpeak') >= 0) & (col('oldpeak') <= 4))

# Calculate the median of these filtered values
replacement_value = valid_range_df.approxQuantile('oldpeak', [0.5], 0.01)[0]
# Replace 'oldpeak' values outside the range 0-4 or null with the median, otherwise keep original
df = df.withColumn('oldpeak', when((col('oldpeak') < 0) | (col('oldpeak') > 4) | col('oldpeak').isNull(), replacement_value).otherwise(col('oldpeak')))
# Show the updated column to verify the replacement
df.select('oldpeak').show()

+-------+
|oldpeak|
+-------+
|    2.3|
|    1.5|
|    2.6|
|    3.5|
|    1.4|
|    0.8|
|    3.6|
|    0.6|
|    1.4|
|    3.1|
|    0.4|
|    1.3|
|    0.6|
|    0.0|
|    0.5|
|    1.6|
|    1.0|
|    1.2|
|    0.2|
|    0.6|
+-------+
only showing top 20 rows



Imputing thaldur and thalach using median imputation (replacing in the null values)

In [10]:
# Calculate the median for 'thaldur'
thaldur_median = df.approxQuantile('thaldur', [0.5], 0.01)[0]

# Calculate the median for 'thalach'
thalach_median = df.approxQuantile('thalach', [0.5], 0.01)[0]
# Replace missing values in 'thaldur' with the median
df = df.na.fill({'thaldur': thaldur_median})

# Replace missing values in 'thalach' with the median
df = df.na.fill({'thalach': thalach_median})


Imputing fbs, prop, nitr, pro, diuretic by
replacing the missing values and values greater than 1.

In [11]:
def replace_values(df, column):
    """ Replace values greater than 1 with 1 in the specified column """
    return df.withColumn(column, when(col(column) > 1, 1).otherwise(col(column)))

def calculate_mode(df, column):
    """ Calculate the mode (most frequent value) of the specified column """
    mode_value = df.groupBy(column).count().orderBy('count', ascending=False).first()[0]
    return mode_value

columns = ['fbs', 'prop', 'nitr', 'pro', 'diuretic']
for column in columns:
    # Replace values greater than 1 with 1
    df = replace_values(df, column)

    # Calculate the mode of the column
    mode_value = calculate_mode(df, column)

    # Replace missing values with the mode
    df = df.na.fill({column: mode_value})

Imputing exang and slope by replacing the missing values with the mode.

In [12]:
def calculate_mode(df, column_name):
    """ Calculate the mode (most common value) for the specified column """
    # Group by the column and count occurrences, then order by count descending and take the first
    mode_value = df.groupBy(column_name).count().orderBy(col("count").desc()).first()[0]
    return mode_value

# Calculate mode for 'exang'
exang_mode = calculate_mode(df, 'exang')

# Calculate mode for 'slope'
slope_mode = calculate_mode(df, 'slope')

# Replace missing values with the mode
from pyspark.sql.functions import when
df = df.withColumn("exang", when(col("exang").isNull(), exang_mode).otherwise(col("exang")))
df = df.withColumn("slope", when(col("slope").isNull(), slope_mode).otherwise(col("slope")))


Cleaning out the age column so that it can be used for finding out the corresponding age range.

We will use median imputation

In [13]:
# Attempt to cast the 'age' column to double type (float in Spark SQL)
# Invalid entries will be turned into nulls
df = df.withColumn("age", col("age").cast(DoubleType()))
# Calculate the median age, ignoring nulls by default
median_age = df.approxQuantile("age", [0.5], 0.01)[0]
# Fill NaN values in the 'age' column with the median age
df = df.na.fill({"age": median_age})



Scraping for the smoke column

Using https://www.abs.gov.au/statistics/health/health-conditions-and-risks/smoking/latest-release

In [18]:


# URL of the webpage
url = 'https://www.abs.gov.au/statistics/health/health-conditions-and-risks/smoking/latest-release'

# Send GET request
response = requests.get(url)
soup = BeautifulSoup(response.text, 'html.parser')

# Find all tables with a specific caption
tables = soup.find_all('table', class_='responsive-enabled')
table = None
for tbl in tables:
    caption = tbl.find('caption')
    if caption and 'Proportion of people 15 years and over who were current daily smokers by age, 2011–12 to 2022' in caption.text:
        table = tbl
        break

# Check if the correct table is found
if not table:
    print("The specified table was not found.")
else:
    # Initialize a list to store each row's data
    data = []

    # Iterate over each row in the table body
    for row in table.find_all('tr')[1:]:  # Skip the header row
        cols = row.find_all('td')
        age_group = row.th.text.strip() if row.th else 'No Age Group'  # Ensuring to get the age group if available
        row_data = [age_group] + [col.text.strip() for col in cols]
        data.append(row_data)

    # Define column headers based on the table's structure
    columns = ['Age Group', '2011-12 (%)', 'CI Low (2011-12)', 'CI High (2011-12)',
               '2014-15 (%)', 'CI Low (2014-15)', 'CI High (2014-15)',
               '2017-18 (%)', 'CI Low (2017-18)', 'CI High (2017-18)',
               '2022 (%)', 'CI Low (2022)', 'CI High (2022)']

    # Create a DataFrame
    smoking_data_df = pd.DataFrame(data, columns=columns)

    # Display the first few rows of the DataFrame to check
    print(smoking_data_df['Age Group'],smoking_data_df['2022 (%)'])
    print (type(smoking_data_df['Age Group'][1]))
    print (type(smoking_data_df['2022 (%)'][1]))


0                15–17
1                18–24
2                25–34
3                35–44
4                45–54
5                55–64
6                65–74
7    75 years and over
Name: Age Group, dtype: object 0     1.6
1     7.3
2    10.9
3    10.9
4    13.8
5    14.9
6     8.7
7     2.9
Name: 2022 (%), dtype: object
<class 'str'>
<class 'str'>
Column<'smoke'>


In [21]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType, FloatType

# Initialize Spark Session
spark = SparkSession.builder.appName("Smoking Data Analysis").getOrCreate()

# Convert pandas DataFrame to Spark DataFrame if not already done
smoking_data_spark_df = spark.createDataFrame(smoking_data_df)

# Define a UDF to replace NaN in 'smoke' based on age groups
@udf(StringType())  # Specify the return type of the UDF
def get_smoking_rate(age, smoke):
    age_groups = broadcast_age_groups.value
    if smoke is None:
        age = int(age)
        for age_range, rate in age_groups.items():
            if age_range == '75 years and over':
                if age >= 75:
                    return rate
            else:
                low, high = map(int, age_range.split('–'))
                if low <= age <= high:
                    return rate
        return None
    else:
        return smoke

# Collect age groups and rates into a dictionary for broadcast to workers
age_groups_rates = {row['Age Group']: row['2022 (%)'] for row in smoking_data_spark_df.collect()}
broadcast_age_groups = spark.sparkContext.broadcast(age_groups_rates)

# Apply the UDF to the DataFrame
df = df.withColumn("smoke_abs", get_smoking_rate(col("age"), col("smoke")))

# Show the result
df.select("age", "smoke", "smoke_abs").show()


+----+-----+---------+
| age|smoke|smoke_abs|
+----+-----+---------+
|63.0| NULL|     14.9|
|67.0| NULL|      8.7|
|67.0| NULL|      8.7|
|37.0| NULL|     10.9|
|41.0| NULL|     10.9|
|56.0| NULL|     14.9|
|62.0| NULL|     14.9|
|57.0| NULL|     14.9|
|63.0| NULL|     14.9|
|53.0| NULL|     13.8|
|57.0| NULL|     14.9|
|56.0| NULL|     14.9|
|56.0| NULL|     14.9|
|44.0| NULL|     10.9|
|52.0| NULL|     13.8|
|57.0| NULL|     14.9|
|48.0| NULL|     13.8|
|54.0| NULL|     13.8|
|48.0| NULL|     13.8|
|49.0| NULL|     13.8|
+----+-----+---------+
only showing top 20 rows



Imputing the sex column using mode imputation

In [22]:
# Calculate the mode of the 'sex' column
mode_sex_df = df.groupBy("sex").count().orderBy(col("count").desc()).limit(1)
mode_sex = mode_sex_df.collect()[0]['sex']
# Replace missing values in the 'sex' column with the mode
df = df.na.fill({'sex': mode_sex})

Scraping CDC data by sex

In [24]:
# URL of the webpage
url = 'https://www.cdc.gov/tobacco/data_statistics/fact_sheets/adult_data/cig_smoking/index.htm'

# Send GET request
response = requests.get(url)
soup = BeautifulSoup(response.text, 'html.parser')

# Initialize a dictionary to store the gender-specific smoking rates
smoking_rates_by_gender_cdc = {}

# Find the relevant card containing smoking data by gender
cards = soup.find_all('div', class_='card-body')
gender_card = next((card for card in cards if "Current cigarette smoking was higher among men than women" in card.text), None)

# Process the relevant data if found
if gender_card:
    list_items = gender_card.find_all('li')
    for item in list_items:
        text = item.get_text().strip()
        gender = 'men' if 'adult men' in text else 'women'
        percentage = text.split('(')[-1].rstrip('%)').strip()
        smoking_rates_by_gender_cdc[gender] = float(percentage)
else:
    print("Gender-specific smoking section not found.")

# Output the dictionary to verify the extracted data
print(smoking_rates_by_gender_cdc)

{'men': 13.1, 'women': 10.1}


In [26]:
def scrape_smoking_rates(url, search_text):
    response = requests.get(url)
    soup = BeautifulSoup(response.text, 'html.parser')
    smoking_rates = {}

    cards = soup.find_all('div', class_='card-body')
    age_card = next((card for card in cards if search_text in card.text), None)

    if age_card:
        list_items = age_card.find_all('li')
        for item in list_items:
            text = item.get_text().strip()
            age_range = text.split('adults aged')[1].split('(')[0].strip()
            percentage = text.split('(')[-1].rstrip('%)').strip()
            smoking_rates[age_range] = float(percentage)
    else:
        print(f"{search_text} section not found.")

    return smoking_rates

url = 'https://www.cdc.gov/tobacco/data_statistics/fact_sheets/adult_data/cig_smoking/index.htm'
search_text = "Current cigarette smoking was highest among"
smoking_rates_by_age_cdc = scrape_smoking_rates(url, search_text)
print(smoking_rates_by_age_cdc)

{'18–24 years': 5.3, '25–44 years': 12.6, '45–64 years': 14.9, '65 years and older': 8.3}


In [31]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import FloatType, StringType

# Initialize Spark Session
spark = SparkSession.builder.appName("Smoking Data Update").getOrCreate()

# Define a function to categorize age into age groups
def categorize_age_group(age):
    if age < 25:
        return '18–24 years'
    elif 25 <= age < 45:
        return '25–44 years'
    elif 45 <= age < 65:
        return '45–64 years'
    else:
        return '65 years and older'

# Assuming the dictionaries are provided like this:
smoking_rates_by_gender_cdc = {'men': 13.1, 'women': 10.1}
smoking_rates_by_age_cdc = {'18–24 years': 5.3, '25–44 years': 12.6, '45–64 years': 14.9, '65 years and older': 8.3}

# Broadcast the dictionaries
broadcast_smoking_rates_by_gender = spark.sparkContext.broadcast(smoking_rates_by_gender_cdc)
broadcast_smoking_rates_by_age = spark.sparkContext.broadcast(smoking_rates_by_age_cdc)

# Define the UDF to replace missing values based on age and gender
def replace_missing_smoke_rates(age, sex, smoke):
    age_group = categorize_age_group(age)  # Directly use the categorization here
    if smoke is None:
        gender_key = 'men' if sex == 1 else 'women'
        age_rate = broadcast_smoking_rates_by_age.value.get(age_group, 0)  # Default to 0 if not found
        if sex == 1:
            # Calculate the adjusted rate for men
            men_rate = broadcast_smoking_rates_by_gender.value.get('men', 1)
            women_rate = broadcast_smoking_rates_by_gender.value.get('women', 1)
            return age_rate * (men_rate / women_rate)
        else:
            # Use the age rate directly for women
            return age_rate
    else:
        return smoke

# Register the UDF with FloatType return type
replace_missing_smoke_rates_udf = udf(replace_missing_smoke_rates, FloatType())

# Update the DataFrame with the new 'smoke_updated' column
df = df.withColumn("smoke_cdc", replace_missing_smoke_rates_udf(col("age"), col("sex"), col("smoke")))

# Show the result
df.select("age", "sex", "smoke", "smoke_cdc").show()


+----+---+-----+---------+
| age|sex|smoke|smoke_cdc|
+----+---+-----+---------+
|63.0|  1| NULL|19.325743|
|67.0|  1| NULL|10.765347|
|67.0|  1| NULL|10.765347|
|37.0|  1| NULL|16.342575|
|41.0|  0| NULL|     12.6|
|56.0|  1| NULL|19.325743|
|62.0|  0| NULL|     14.9|
|57.0|  0| NULL|     14.9|
|63.0|  1| NULL|19.325743|
|53.0|  1| NULL|19.325743|
|57.0|  1| NULL|19.325743|
|56.0|  0| NULL|     14.9|
|56.0|  1| NULL|19.325743|
|44.0|  1| NULL|16.342575|
|52.0|  1| NULL|19.325743|
|57.0|  1| NULL|19.325743|
|48.0|  1| NULL|19.325743|
|54.0|  1| NULL|19.325743|
|48.0|  0| NULL|     14.9|
|49.0|  1| NULL|19.325743|
+----+---+-----+---------+
only showing top 20 rows



Imputing painloc and painexer columns using mode imputation because they have categorical values

In [34]:

# Initialize Spark Session
spark = SparkSession.builder.appName("Data Cleaning").getOrCreate()

# Assuming df is your Spark DataFrame

# Calculate the mode for the 'painloc' column
painloc_mode = df.groupBy("painloc").count().orderBy(F.desc("count")).first()['painloc']

# Replace missing values with the mode in 'painloc'
df = df.withColumn("painloc", F.when(col("painloc").isNull(), lit(painloc_mode)).otherwise(col("painloc")))

# Calculate the mode for the 'painexer' column
painexer_mode = df.groupBy("painexer").count().orderBy(F.desc("count")).first()['painexer']

# Replace missing values with the mode in 'painexer'
df = df.withColumn("painexer", F.when(col("painexer").isNull(), lit(painexer_mode)).otherwise(col("painexer")))




In [36]:


# Initialize Spark Session
spark = SparkSession.builder.appName("Transform Data").getOrCreate()

# Assuming df is your Spark DataFrame

# Convert string to float using cast
df = df.withColumn("smoke_abs", col("smoke_abs").cast("float"))
df = df.withColumn("smoke_cdc", col("smoke_cdc").cast("float"))

# Define a UDF for applying log transformation
def log_transform(value):
    if value is not None:
        return math.log(value / 100 + 10)
    else:
        return None

# Register the UDF
log_transform_udf = udf(log_transform, DoubleType())

# Apply the log transformation to 'smoke_abs' and 'smoke_cdc'
df = df.withColumn("smoke_abs", log_transform_udf("smoke_abs"))
df = df.withColumn("smoke_cdc", log_transform_udf("smoke_cdc"))

# Show the summary statistics to ensure the transformation was applied correctly
df.select("smoke_abs", "smoke_cdc").summary().show()


+-------+--------------------+--------------------+
|summary|           smoke_abs|           smoke_cdc|
+-------+--------------------+--------------------+
|  count|                 899|                 899|
|   mean|  2.3048947590206144|   2.304897557806103|
| stddev|5.691948627879489E-6|7.495973519195899E-6|
|    min|   2.304885031232229|   2.304885031232229|
|    25%|   2.304886028386382|   2.304886028386382|
|    50%|   2.304898705383064|  2.3049012044170367|
|    75%|   2.304898705383064|   2.304904128511966|
|    max|  2.3048997872043624|   2.304904128511966|
+-------+--------------------+--------------------+



In [40]:

# Initialize Spark Session
spark = SparkSession.builder.appName("Prepare Data for Modeling").getOrCreate()

# Remove the original 'smoke' column
df = df.drop("smoke")

# Display the DataFrame structure and preview to verify the column has been dropped
df.show()
df.printSchema()

# Check for NaN values in each column
nan_counts = df.select([F.count(F.when(F.isnan(c) | F.col(c).isNull(), c)).alias(c) for c in df.columns])

# Print the counts of NaN values
nan_counts.show()


+----+---+-------+--------+---+--------+---+----+----+---+--------+-------+-------+-----+-------+-----+------+------------------+------------------+-------------+------------------+
| age|sex|painloc|painexer| cp|trestbps|fbs|prop|nitr|pro|diuretic|thaldur|thalach|exang|oldpeak|slope|target|         smoke_abs|         age_group|smoke_updated|         smoke_cdc|
+----+---+-------+--------+---+--------+---+----+----+---+--------+-------+-------+-----+-------+-----+------+------------------+------------------+-------------+------------------+
|63.0|  1|      1|       1|  1|   145.0|  1|   0|   0|  0|       0|   10.5|    150|    0|    2.3|    3|     0|2.3048997872043624|       45–64 years|    19.325743| 2.304904128511966|
|67.0|  1|      1|       1|  4|   160.0|  0|   1|   0|  0|       0|    9.5|    108|    1|    1.5|    2|     1|2.3048936735190133|65 years and older|    10.765347|2.3048957144312814|
|67.0|  1|      1|       1|  4|   120.0|  0|   1|   0|  0|       0|    8.5|    129|    1| 

Splitting data 90 - 10

In [44]:

# Initialize Spark Session
spark = SparkSession.builder.appName("ML Data Split").getOrCreate()

# Assuming 'df' is your DataFrame and 'target' is your target column
# Add a column for random numbers
df = df.withColumn("random", F.rand(seed=42))

# Window specification for stratified sampling
windowSpec = Window.partitionBy("target").orderBy("random")

# Adding a row number within each partition of 'target'
df = df.withColumn("row_number", F.row_number().over(windowSpec))

# Count the number of each 'target' category
target_counts = df.groupBy("target").count().collect()
fractions = {row['target']: 0.9 for row in target_counts}  # 90% for training

# Applying the sampling fraction based on 'target'
train_df = df.sampleBy("target", fractions, seed=42)
test_df = df.join(train_df, on=df.columns, how="left_anti")  # Get the rows not in the training set

# Optionally remove the extra columns if you don't need them anymore
train_df = train_df.drop("random", "row_number")
test_df = test_df.drop("random", "row_number")



List the models we'll try:

In [47]:
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier
from pyspark.sql.functions import col, isnan, when, count

# Initialize PySpark models
log_reg = LogisticRegression(featuresCol='features', labelCol='target')
dec_tree = DecisionTreeClassifier(featuresCol='features', labelCol='target')
rand_forest = RandomForestClassifier(featuresCol='features', labelCol='target')

# List of models and their names
models = [log_reg, dec_tree, rand_forest]
model_names = ['Logistic Regression', 'Decision Tree', 'Random Forest']

# Assuming 'df' is your DataFrame and 'target' is your target column
# First, let's split the data into features and target similar to X and y in your scikit-learn approach
# Normally, you'd need a VectorAssembler here to combine feature columns into a single features vector
# Assuming that 'features' column already exists or you've previously created it using VectorAssembler

# Splitting the data into training and testing sets (90% training, 10% testing)
train_df, test_df = df.randomSplit([0.9, 0.1], seed=42)



In [58]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, DecisionTreeClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, Imputer
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql.functions import col

# Prepare data: impute missing values and assemble features
numeric_features = [c for c, t in df.dtypes if c != 'target' and t != 'string']
imputer = Imputer(inputCols=numeric_features, outputCols=[f"{c}_imputed" for c in numeric_features])
df = imputer.setStrategy("median").fit(df).transform(df)
df.cache()  # Cache the DataFrame after imputation

assembler = VectorAssembler(
    inputCols=[f"{c}_imputed" for c in numeric_features],
    outputCol="features",
    handleInvalid="keep"
)
df_assembled = assembler.transform(df)
train_df, test_df = df_assembled.randomSplit([0.9, 0.1], seed=42)

# Define classifiers and parameter grid with reduced options
models = {
    'RandomForest': RandomForestClassifier(featuresCol="features", labelCol="target"),
    'DecisionTree': DecisionTreeClassifier(featuresCol="features", labelCol="target"),
    'GBT': GBTClassifier(featuresCol="features", labelCol="target")
}

paramGrid = ParamGridBuilder() \
    .addGrid(models['RandomForest'].numTrees, [100]) \
    .addGrid(models['RandomForest'].maxDepth, [10]) \
    .build()

# Setup cross-validation
evaluator = BinaryClassificationEvaluator(labelCol="target")
results = {}

for name, model in models.items():
    pipeline = Pipeline(stages=[model])
    crossval = CrossValidator(estimator=pipeline,
                              estimatorParamMaps=paramGrid,
                              evaluator=evaluator,
                              numFolds=5)  # Maintaining 5 folds
    cvModel = crossval.fit(train_df)
    bestModel = cvModel.bestModel
    predictions = bestModel.transform(test_df)
    test_score = evaluator.evaluate(predictions)
    results[name] = {
        'Best Model': bestModel,
        'CV Score': cvModel.avgMetrics[0],
        'Test Score': test_score
    }

# Print results
for result in results:
    print(f"{result}: CV Score: {results[result]['CV Score']}, Test Score: {results[result]['Test Score']}")


RandomForest: CV Score: 0.9016813414467526, Test Score: 0.8968095712861416
DecisionTree: CV Score: 0.7431830608363378, Test Score: 0.7791625124626121
GBT: CV Score: 0.9071413226933194, Test Score: 0.9032901296111666


From the results above, Gradient-Boosted Trees classifier (GBT) is the best classifier for the given dataset. For future predictions, I would use this model.