In [47]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd
from typing import Optional

In [15]:
# Allocate 8GB of RAM to the driver
spark = (
    SparkSession.builder.appName("Dessert or Not?")
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)

# Set the maximum number of pd rows to display to make data exploration easier
pd.set_option("display.max_rows", 1000) 


### Exploratory Data Analysis

In [7]:
food = spark.read.csv('./data/epi_r.csv', inferSchema=True, header=True)

print(food.count(), len(food.columns))

food.printSchema()

20057 680
root
 |-- title: string (nullable = true)
 |-- rating: string (nullable = true)
 |-- calories: string (nullable = true)
 |-- protein: double (nullable = true)
 |-- fat: double (nullable = true)
 |-- sodium: double (nullable = true)
 |-- #cakeweek: double (nullable = true)
 |-- #wasteless: double (nullable = true)
 |-- 22-minute meals: double (nullable = true)
 |-- 3-ingredient recipes: double (nullable = true)
 |-- 30 days of groceries: double (nullable = true)
 |-- advance prep required: double (nullable = true)
 |-- alabama: double (nullable = true)
 |-- alaska: double (nullable = true)
 |-- alcoholic: double (nullable = true)
 |-- almond: double (nullable = true)
 |-- amaretto: double (nullable = true)
 |-- anchovy: double (nullable = true)
 |-- anise: double (nullable = true)
 |-- anniversary: double (nullable = true)
 |-- anthony bourdain: double (nullable = true)
 |-- aperitif: double (nullable = true)
 |-- appetizer: double (nullable = true)
 |-- apple: double (nullabl

Many columns contain undesirable characters such as # (from hashtags), and invalid characters. These should be standardised.

In [8]:
def sanitize_column_name(name):
    """Remove unwanted characters from a column name."""
    out = name
    for i, j in ((" ", "_"), ("-", "_"), ("/", "_"), ("&", "and")):
        out = out.replace(i, j)
    # Keep only letters, numbers and underscores
    return "".join(
        [char for char in out if char.isalpha() or char.isdigit() or char == "-"]
    )

In [9]:
food = food.toDF(*[sanitize_column_name(name) for name in food.columns])

Identify binary columns 

In [30]:
is_binary = food.agg(
    # If the number of distinct values in the column is 2, then it's probably binary
    *[(F.size(F.collect_set(x)) == 2).alias(x) for x in food.columns]
).toPandas()

# Unpivot the dataframe to make it easier to read in the terminal
is_binary.unstack().sort_values()

title                  0    False
rating                 0    False
calories               0    False
protein                0    False
fat                    0    False
sodium                 0    False
cakeweek               0    False
wasteless              0    False
passionfruit           0     True
passover               0     True
pasta                  0     True
pastamaker             0     True
pastry                 0     True
pea                    0     True
peach                  0     True
peanut                 0     True
pecan                  0     True
peanutfree             0     True
pear                   0     True
pasadena               0     True
pennsylvania           0     True
pepper                 0     True
pernod                 0     True
persiannewyear         0     True
persimmon              0     True
peanutbutter           0     True
party                  0     True
paris                  0     True
parsley                0     True
orange        

`cakeweek` and `wasteless` do not appear to be binary:

In [39]:
# Show the distinct values for the suspect columns
food.agg(*[F.collect_set(x) for x in ("cakeweek", "wasteless")]).show(1, False)

# Show the first and last few records with non-binary values to check for data alignment issues
(
    food
    .where("cakeweek > 1.0 or wasteless > 1.0")
    .select("title", "rating", "wasteless", "cakeweek", food.columns[-1])
    .show(truncate=False)
)

+-------------------------------+----------------------+
|collect_set(cakeweek)          |collect_set(wasteless)|
+-------------------------------+----------------------+
|[0.0, 1.0, 1188.0, 24.0, 880.0]|[0.0, 1.0, 1439.0]    |
+-------------------------------+----------------------+

+----------------------------------------------------------------+------------------------+---------+--------+------+
|title                                                           |rating                  |wasteless|cakeweek|turkey|
+----------------------------------------------------------------+------------------------+---------+--------+------+
|"Beet Ravioli with Pine Nut ""Goat Cheese"" Rosemary-Cream Sauce| Aged Balsamic Vinegar "|0.0      |880.0   |0.0   |
|"Seafood ""Cataplana"" with Saffron                             | Vermouth               |1439.0   |24.0    |0.0   |
|"""Pot Roast"" of Seitan                                        | Aunt Gloria-Style "    |0.0      |1188.0  |0.0   |
+-----

Since this is only a small number of records compared to the total dataset size, we remove them.

In [41]:
# Keep only legit values for cakeweek and wasteless
food = (
    food.where(
        (F.col("cakeweek").isin([0.0, 1.0]) | F.col("cakeweek").isNull())
        &
        (F.col("wasteless").isin([0.0, 1.0]) | F.col("wasteless").isNull())   
    )
)

print(food.count(), len(food.columns))

20054 680


Classifying variable types

In [43]:
# Columns containing information unique to each record
IDENTIFIERS = ["title"]

# Columns containing ML features
CONTINUOUS_COLUMNS = [
    "rating",
    "calories",
    "protein",
    "fat",
    "sodium"
]

# Columns containing the features we wish to predict
TARGET_COLUMN = ["dessert"]

# Columns containing binary features
BINARY_COLUMNS = [
    x for x in food.columns 
    if x not in CONTINUOUS_COLUMNS 
    and x not in TARGET_COLUMN
    and x not in IDENTIFIERS
]

We remove records that have only `null` values. After that, we equate `null` to `False` and fill zero as a default value. 

In [44]:
# Remove records that have only null values
food = food.dropna(
    how="all",
    subset=[x for x in food.columns if x not in IDENTIFIERS]
)

food = food.dropna(subset=TARGET_COLUMN)

print(food.count(), len(food.columns))

20049 680


In [46]:
food = food.fillna(0.0, subset=BINARY_COLUMNS)

print(food.where(F.col(BINARY_COLUMNS[0]).isNull()).count())

0


Cleaning continuous columns

In [49]:
# If the value is not null, try casting it to a float.
# If this fails, it's not a number.
@F.udf(T.BooleanType())
def is_a_number(value: Optional[str]) -> bool:
    if not value:
        return True
    try:
        _ = float(value)
    except ValueError:
        return False
    return True

In [51]:
# Show non-numerical values in the continuous columns
food.where(~is_a_number(F.col("rating"))).select(*CONTINUOUS_COLUMNS).show()

+---------+------------+-------+----+------+
|   rating|    calories|protein| fat|sodium|
+---------+------------+-------+----+------+
| Cucumber| and Lemon "|   3.75|null|  null|
+---------+------------+-------+----+------+



In [52]:
# Remove rogue continuous column values, cast remaining values to double
for column in CONTINUOUS_COLUMNS:
    food = food.where(is_a_number(F.col(column)))
    food = food.withColumn(column, F.col(column).cast(T.DoubleType()))

print(food.count(), len(food.columns))

20048 680


We examine the summary statistics of our continuous columns to look for remaining non-sensible values.

In [55]:
food.select(CONTINUOUS_COLUMNS).summary(
    "mean",
    "stddev",
    "min",
    "1%",
    "5%",
    "50%",
    "95%",
    "99%",
    "max"
).show()

+-------+------------------+------------------+------------------+-----------------+-----------------+
|summary|            rating|          calories|           protein|              fat|           sodium|
+-------+------------------+------------------+------------------+-----------------+-----------------+
|   mean| 3.714460295291301|6324.0634571930705|100.17385283565179|346.9398083953107|6226.927244193346|
| stddev|1.3409187660508959|359079.83696340164|3840.6809971287403|20458.04034412409|333349.5680370268|
|    min|               0.0|               0.0|               0.0|              0.0|              0.0|
|     1%|               0.0|              18.0|               0.0|              0.0|              1.0|
|     5%|               0.0|              62.0|               0.0|              0.0|              5.0|
|    50%|             4.375|             331.0|               8.0|             17.0|            294.0|
|    95%|               5.0|            1315.0|              75.0|       

Some remaining nutrition values are very high (way higher than the 75th percentile). We cap the values at the 99th percentile.

In [58]:
# Remove rows with values above the 99th percentile.
# Use hardcoded maxima to ensure analysis is consistent across runs.
maximum = {
    "calories": 3184.0,
    "protein": 173.0,
    "fat": 207.0,
    "sodium": 5649.0
}

# Replace large values while holding onto null values:
for k, v in maximum.items():
    food = food.withColumn(
        k,
        F.when(F.isnull(F.col(k)), F.col(k)).otherwise(
            F.least(F.col(k), F.lit(v))
        )
    )   

Weed out binary columns which are not present enough to be reliable predictors. 

In [70]:
# Calculate the proportion of True and False values for each binary column.
# In this calculation, we consider anything other than 1 to be "False".
proportions_df = food.agg(*[(F.mean(F.col(c)).alias(c)) for c in BINARY_COLUMNS])

# Transpose the proportions DataFrame to create the final table
final_table = (
    proportions_df
    .withColumn("column_name", F.lit(proportions_df.columns[0]))
    .select("column_name", *proportions_df.columns[1:])
    .orderBy(*proportions_df.columns[1:], ascending=False)
)

# Show the resulting table
final_table.show()

+-----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+------------------+-------------------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+-------------

In [None]:
# Remove binary features that occur too little or too often
inst_sum_of_