## Imports

In [0]:
from pyspark.sql.functions import col
from pyspark.sql import Window
import pyspark.sql.functions as F
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator


import random

import mlflow
print(mlflow.__version__)

import os

# Turn on Databricks' automatic MLflow tracking for MLlib.
spark.conf.set("spark.databricks.mlflow.trackMLlib.enabled", "true")

# Notebook-wide configuration.
RANDOM_SEED = 0
EXPERIMENT_NAME = "/Shared/team_2_2/mlflow-baseline"  # workspace path of the shared experiment

# Select the shared experiment, creating it first when it is missing.
# Any failure is reported and the notebook falls back to a per-user
# "default" experiment so later cells still have somewhere to log.
try:
    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if experiment is not None:
        print(f"Using existing experiment: {experiment.name}")
    else:
        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
        print(f"Created new experiment with ID: {experiment_id}")
    mlflow.set_experiment(EXPERIMENT_NAME)
except Exception as setup_error:
    print(f"Error with experiment setup: {setup_error}")
    # The current user's name comes from the notebook context.
    _notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    mlflow.set_experiment(f"/Users/{_notebook_context.userName().get()}/default")



## Helper Functions


In [0]:
def checkpoint_dataset(dataset, file_path):
    """Write `dataset` as parquet under the group's DBFS folder.

    Args:
        dataset: DataFrame to persist (anything exposing `.write`).
        file_path: Path relative to the group folder, without extension;
            may contain sub-directories. ".parquet" is appended.

    The target is overwritten if it already exists.
    """
    section, number = "2", "2"
    base_folder = f"dbfs:/student-groups/Group_{section}_{number}"
    full_path = f"{base_folder}/{file_path}.parquet"

    # Ensure the group folder and any intermediate directories exist.
    dbutils.fs.mkdirs(base_folder)
    parent_folder = full_path.rsplit("/", 1)[0]
    dbutils.fs.mkdirs(parent_folder)

    dataset.write.mode("overwrite").parquet(full_path)
    print(f"Checkpointed {file_path}")

In [0]:

# def checkpoint_dataset(dataset, file_path):
#     # Create folder
#     section = "2"
#     number = "2"
#     folder_path = f"dbfs:/student-groups/Group_{section}_{number}"
#     dbutils.fs.mkdirs(folder_path)
#     # Save df_weather as a parquet file
#     dataset.write.parquet(f"{folder_path}/{file_path}.parquet")
#     print(f"Checkpointed {file_path}")

## Datasets

### Custom Join Dataset - 1 year

In [0]:
%fs ls dbfs:/mnt/mids-w261/daniel_costa@berkeley.edu/Custom_Joins/V3/

In [0]:
# Load the one-year custom-joined dataset and keep only rows whose
# CANCELLED flag is not 1 (rows with a null flag are dropped as well).
custom_joined_path = 'dbfs:/mnt/mids-w261/daniel_costa@berkeley.edu/Custom_Joins/V3/custom_join_v3_1y.parquet'

df = (
    spark.read
    .parquet(custom_joined_path)
    .filter(F.col("CANCELLED") != 1)
)

# Row count after filtering, plus a small preview.
print(df.count())
display(df.limit(10))

# Things to keep in mind
- Predict two hours before scheduled departure
- Remove all the delay columns
- Are we only predicting departure delays or arrival delays also? For example, the pilot misses the landing, and has to circle back for 20 minutes. Should we solve for that? I don't think we should.

## Preprocessing / Cleanup

In [0]:
# Mark the joined dataset for caching; later cells derive columns and splits from it.
df = df.cache() # cache joined dataset

In [0]:
# Build a timestamp from the flight date and the scheduled departure time.
# CRS_DEP_TIME is an HHMM number, so it is left-padded to four digits first
# (e.g. 530 -> "0530") and parsed together with FL_DATE.
# NOTE(review): a later cell describes CRS_DEP_TIME as local time, so despite
# its name this column may not be true UTC - confirm against the join logic.
# NOTE(review): a value of 2400 would not match the "HHmm" pattern - verify
# whether such rows exist.
_dep_hhmm = F.lpad(F.col("CRS_DEP_TIME").cast("string"), 4, "0")
_dep_datetime_text = F.concat(F.col("FL_DATE"), F.lit(" "), _dep_hhmm)

df = df.withColumn(
    "utc_timestamp",
    F.to_timestamp(_dep_datetime_text, "yyyy-MM-dd HHmm"),
)

### Split 1 year joined dataset

In [0]:
from pyspark.sql import Window
# Chronological 70 / 10 / 20 split into train / validation / test.
TRAIN_SIZE = 0.70
VALIDATION_SIZE = 0.10

# REMOVE ALL CANCELLED FLIGHTS
df = df.filter(F.col("CANCELLED") != 1)

# Number the rows in time order. Many flights share the same scheduled
# timestamp, and the three split filters below each re-evaluate this window,
# so the ordering needs tie-breakers: with 'utc_timestamp' alone, tied rows
# could be numbered differently per evaluation and land in two splits or none.
# A separate global sort is not needed; the window performs its own ordering.
# NOTE(review): an un-partitioned window pulls all rows into one partition;
# acceptable for a one-off split, but slow on large inputs.
window = Window.orderBy(
    'utc_timestamp', 'OP_UNIQUE_CARRIER', 'TAIL_NUM', 'ORIGIN', 'DEST'
)
df = df.withColumn("row_num", F.row_number().over(window))

total_rows = df.count()

# Calculate split points
train_end = int(total_rows * TRAIN_SIZE)
validation_end = int(total_rows * (TRAIN_SIZE + VALIDATION_SIZE))  # 70% + 10%

# Split based on row number, then drop the helper column
train_df = df.filter(F.col("row_num") <= train_end).drop("row_num")
validation_df = df.filter(
    (F.col("row_num") > train_end) & (F.col("row_num") <= validation_end)
).drop("row_num")
test_df = df.filter(F.col("row_num") > validation_end).drop("row_num")

In [0]:
# Get the last utc_timestamp from train_df
last_flight_ts = train_df.agg(F.max("utc_timestamp").alias("last_ts")).collect()[0]["last_ts"]

# Add a 2 hour gap
gap_ts = F.timestamp_add("HOUR", F.lit(2), F.lit(last_flight_ts))

# Keep only validation rows after the gap timestamp. The result must be
# assigned back to `validation_df` (previously it went to a misspelled,
# unused name), because `validation_df` is what gets checkpointed below.
validation_df = validation_df.filter(F.col("utc_timestamp") > gap_ts)

In [0]:
%fs ls dbfs:/mnt/mids-w261/daniel_costa@berkeley.edu/Custom_Joins/V3/

In [0]:
# Manual confirmation before overwriting the raw split checkpoints on DBFS.
_confirmation = input("CAREFUL: You're about to write to DBFS. Type 'y' to continue.")
if _confirmation == "y":
    for _split_name, _split_df in (
        ("train", train_df),
        ("validation", validation_df),
        ("test", test_df),
    ):
        checkpoint_dataset(_split_df, f"1_year_custom_joined/raw_data/training_splits/{_split_name}")

#### check checkpoint files

In [0]:
# %fs ls dbfs:/student-groups/Group_2_2/1_year_custom_joined

## Model Iterations

In [0]:
# Location of the checkpointed raw splits.
checkpoint_path = "dbfs:/student-groups/Group_2_2"
month_or_year = "1_year_custom_joined"
dataset_path = "/".join([checkpoint_path, month_or_year, "raw_data", "training_splits"])

# Reload the three splits from their parquet checkpoints.
train_df, validation_df, test_df = (
    spark.read.parquet(f"{dataset_path}/{split_name}.parquet")
    for split_name in ("train", "validation", "test")
)

### Ignore weather rows with nan's

In [0]:
# Hourly weather measurements that must all be present for a row to be kept.
_WEATHER_COLUMNS = [
    'HourlyDryBulbTemperature',
    'HourlyDewPointTemperature',
    'HourlyRelativeHumidity',
    'HourlyAltimeterSetting',
    'HourlyVisibility',
    'HourlyStationPressure',
    'HourlyWetBulbTemperature',
    'HourlyPrecipitation',
    'HourlyCloudCoverage',
    'HourlyCloudElevation',
    'HourlyWindSpeed',
]

# Apply the same filter to every split: drop rows missing any of the above.
train_df, validation_df, test_df = (
    split.dropna(subset=_WEATHER_COLUMNS)
    for split in (train_df, validation_df, test_df)
)

In [0]:
# Feature Engineering

## CRS_DEP_TIME is a local-time HHMM number (e.g. 1432), so it is not a usable
## numeric scale as-is. Convert it to minutes since midnight:
## hours * 60 + minutes.
_crs_dep_hhmm = F.col("CRS_DEP_TIME")
_crs_dep_minutes = F.floor(_crs_dep_hhmm / 100) * 60 + (_crs_dep_hhmm % 100)

train_df, validation_df, test_df = (
    split.withColumn("CRS_DEP_MINUTES", _crs_dep_minutes)
    for split in (train_df, validation_df, test_df)
)


### Feature Eng.

#### Was the previous flight delayed? And by how much was the previous flight delayed?

In [0]:
# Cache every split; the window features below reuse them repeatedly.
train_df, validation_df, test_df = (
    split.cache() for split in (train_df, validation_df, test_df)
)

In [0]:
# Per-origin trailing window over rows whose timestamp lies between
# 4 hours (14400 s) and 2 hours (7200 s) before the current row.
window_4h = (
    Window
    .partitionBy("ORIGIN_AIRPORT_SEQ_ID")
    .orderBy(F.col("utc_timestamp").cast("long"))
    .rangeBetween(-14400, -7200)  # 4 hours to 2 hours before
)

# Count the window's flights with DEP_DELAY_NEW above 15. F.when without an
# otherwise() yields null for the remaining rows, which F.count ignores.
# Each split is windowed on its own, so history never crosses a split boundary.
# NOTE(review): the window is keyed on scheduled time; a flight scheduled just
# over 2 hours earlier with a long delay may not have departed yet at
# prediction time - confirm this is acceptable.
_origin_delays_4h = F.count(F.when(F.col("DEP_DELAY_NEW") > 15, 1)).over(window_4h)

train_df, validation_df, test_df = (
    split.withColumn("origin_delays_4h", _origin_delays_4h)
    for split in (train_df, validation_df, test_df)
)

In [0]:
def _with_prev_flight_delay(frame):
    """Add the previous flight's departure delay for the same TAIL_NUM.

    Adds two columns:
      * prev_flight_delay_in_minutes - DEP_DELAY_NEW of the preceding row for
        the same TAIL_NUM (ordered by utc_timestamp), or -1 when there is none.
      * prev_flight_delay - 1 when that value exceeds 15, otherwise 0.
    """
    # NOTE(review): the preceding flight may have been scheduled less than
    # 2 hours earlier, in which case its delay might not be known at
    # prediction time - confirm.
    by_tail_in_time_order = Window.partitionBy("TAIL_NUM").orderBy("utc_timestamp")
    previous_delay = F.lag("DEP_DELAY_NEW", 1).over(by_tail_in_time_order)

    with_minutes = frame.withColumn(
        "prev_flight_delay_in_minutes",
        F.when(previous_delay.isNull(), -1).otherwise(previous_delay),
    )
    return with_minutes.withColumn(
        "prev_flight_delay",
        F.when(F.col("prev_flight_delay_in_minutes") > 15, 1).otherwise(F.lit(0)),
    )


train_df, validation_df, test_df = (
    _with_prev_flight_delay(split)
    for split in (train_df, validation_df, test_df)
)


### [Feature] Delay time for flights at departure locations over the past 7 days

In [0]:
# Per-origin trailing window from 7 days (604800 s) to 4 hours (14400 s)
# before the current row's timestamp.
window_7d_origin = (
    Window
    .partitionBy("ORIGIN_AIRPORT_SEQ_ID")
    .orderBy(F.col("utc_timestamp").cast("long"))
    .rangeBetween(-604800, -14400)  # -7 days, -4 hours
)

# Sum of DEP_DELAY_NEW over the window; an empty window gives null, which is
# replaced with 0.
_delay_origin_7d = F.coalesce(
    F.sum('DEP_DELAY_NEW').over(window_7d_origin),
    F.lit(0),
)

train_df, validation_df, test_df = (
    split.withColumn('delay_origin_7d', _delay_origin_7d)
    for split in (train_df, validation_df, test_df)
)

### [Feature] Total departure delay minutes by origin airport and carrier over the last 7 days

In [0]:
# Trailing window per (origin airport, carrier) from 7 days (604800 s) to
# 4 hours (14400 s) before the current row's timestamp.
window_7d_origin_carrier = (
    Window
    .partitionBy("ORIGIN_AIRPORT_SEQ_ID", "OP_UNIQUE_CARRIER")
    .orderBy(F.col("utc_timestamp").cast("long"))
    .rangeBetween(-604800, -14400)  # -7 days, -4 hours
)

# Sum of DEP_DELAY_NEW over the window (this is a total, not a count of
# delayed flights); an empty window gives null, which is replaced with 0.
_delay_origin_carrier_7d = F.coalesce(
    F.sum('DEP_DELAY_NEW').over(window_7d_origin_carrier),
    F.lit(0),
)

train_df, validation_df, test_df = (
    split.withColumn('delay_origin_carrier_7d', _delay_origin_carrier_7d)
    for split in (train_df, validation_df, test_df)
)

### [Feature] Total departure delay minutes per route over the last 7 days

In [0]:
# Route key of the form "<ORIGIN>-<DEST>".
_route = F.concat(F.col("ORIGIN"), F.lit("-"), F.col("DEST"))

train_df, validation_df, test_df = (
    split.withColumn("route", _route)
    for split in (train_df, validation_df, test_df)
)

In [0]:
# Per-route trailing window from 7 days (604800 s) to 4 hours (14400 s)
# before the current row's timestamp.
window_7d_route = (
    Window
    .partitionBy("route")
    .orderBy(F.col("utc_timestamp").cast("long"))
    .rangeBetween(-604800, -14400)  # -7 days, -4 hours
)

# Sum of DEP_DELAY_NEW over the window (a total, not a count); an empty
# window gives null, which is replaced with 0.
_delay_route_7d = F.coalesce(
    F.sum('DEP_DELAY_NEW').over(window_7d_route),
    F.lit(0),
)

train_df, validation_df, test_df = (
    split.withColumn('delay_route_7d', _delay_route_7d)
    for split in (train_df, validation_df, test_df)
)

### [Feature] - running number of flights so far that day for one plane

In [0]:
# Rows for the same aircraft (TAIL_NUM) on the same FL_DATE, in time order.
# Because the window is ordered and has no explicit frame, the count below is
# a running count: flights of that aircraft on that date up to and including
# the current row (rows sharing the same timestamp are counted together).
window_flights_24h = (
    Window
    .partitionBy("TAIL_NUM", "FL_DATE")
    .orderBy(F.col("utc_timestamp").cast("long"))
)

_flight_count_24h = F.count("*").over(window_flights_24h)

train_df, validation_df, test_df = (
    split.withColumn('flight_count_24h', _flight_count_24h)
    for split in (train_df, validation_df, test_df)
)

### [Feature] time between landed and scheduled flight

In [0]:
from pyspark.sql import functions as F

def hhmm_to_time_str(col):
    """Return a Column expression formatting an HHMM column as "HH:MM".

    Args:
        col: Name of the column holding the HHMM value.

    The value is cast to string and left-padded with zeros to four
    characters, then the first two and last two characters are joined
    with a colon.
    """
    digits = F.lpad(F.col(col).cast("string"), 4, "0")
    hours = digits.substr(1, 2)
    minutes = digits.substr(3, 2)
    return F.concat_ws(":", hours, minutes)

In [0]:
# train_df = train_df.withColumn(
#     "CRS_ARR_TIME_STR",
#     hhmm_to_time_str("ARR_TIME")
# ).withColumn(
#     "WHEELS_ON_STR",
#     hhmm_to_time_str("WHEELS_ON")
# )

# train_df = train_df.withColumn(
#     "CRS_ARR_TIMESTAMP",
#     F.to_timestamp("CRS_ARR_TIME_STR", "HH:mm")
# ).withColumn(
#     "WHEELS_ON_TIMESTAMP",
#     F.to_timestamp("WHEELS_ON_STR", "HH:mm")
# )

# train_df = train_df.withColumn(
#     "LANDING_TIME_DIFF_MINUTES",
#     F.coalesce(
#         (
#             (F.col("WHEELS_ON_TIMESTAMP").cast("long") - 
#              F.col("CRS_ARR_TIMESTAMP").cast("long")) / 60
#         ),
#         F.lit(0)
#     )
# )

# validation_df = validation_df.withColumn(
#     "CRS_ARR_TIME_STR",
#     hhmm_to_time_str("ARR_TIME")
# ).withColumn(
#     "WHEELS_ON_STR",
#     hhmm_to_time_str("WHEELS_ON")
# )

# validation_df = validation_df.withColumn(
#     "CRS_ARR_TIMESTAMP",
#     F.to_timestamp("CRS_ARR_TIME_STR", "HH:mm")
# ).withColumn(
#     "WHEELS_ON_TIMESTAMP",
#     F.to_timestamp("WHEELS_ON_STR", "HH:mm")
# )

# validation_df = validation_df.withColumn(
#     "LANDING_TIME_DIFF_MINUTES",
#     F.coalesce(
#         (
#             (F.col("WHEELS_ON_TIMESTAMP").cast("long") - 
#              F.col("CRS_ARR_TIMESTAMP").cast("long")) / 60
#         ),
#         F.lit(0)
#     )
# )

In [0]:
window_turnaround = Window \
    .partitionBy("TAIL_NUM") \
    .orderBy(F.col("WHEELS_ON").cast("long")) 


train_df = train_df.withColumn(
    "next_scheduled_dep_ts", 
    F.lead("CRS_DEP_TIME", 1).over(window_turnaround)
)

train_df = train_df.withColumn(
    "LANDING_TIME_DIFF_MINUTES",
    F.coalesce(
        (F.col("next_scheduled_dep_ts").cast("long") - F.col("WHEELS_ON").cast("long")) / 60,
        F.lit(-999) 
    )
).drop("next_scheduled_dep_ts")

train_df.select("TAIL_NUM", "WHEELS_ON", "CRS_DEP_TIME", "LANDING_TIME_DIFF_MINUTES").orderBy("TAIL_NUM", "WHEELS_ON").show(5)

In [0]:
validation_df = validation_df.withColumn(
    "next_scheduled_dep_ts", 
    F.lead("CRS_DEP_TIME", 1).over(window_turnaround)
)

validation_df = validation_df.withColumn(
    "LANDING_TIME_DIFF_MINUTES",
    F.coalesce(
        (F.col("next_scheduled_dep_ts").cast("long") - F.col("WHEELS_ON").cast("long")) / 60,
        F.lit(-999) 
    )
).drop("next_scheduled_dep_ts")

In [0]:
test_df = test_df.withColumn(
    "next_scheduled_dep_ts", 
    F.lead("CRS_DEP_TIME", 1).over(window_turnaround)
)

test_df = test_df.withColumn(
    "LANDING_TIME_DIFF_MINUTES",
    F.coalesce(
        (F.col("next_scheduled_dep_ts").cast("long") - F.col("WHEELS_ON").cast("long")) / 60,
        F.lit(-999) 
    )
).drop("next_scheduled_dep_ts")

### [Feature] Average Delay time by airport
- by origin airport and by destination

In [0]:
# Mean ARR_DELAY per destination airport, computed separately for each split.
# NOTE(review): these aggregates cover each split's whole time range and are
# not joined back in the visible cells - joining them onto rows as features
# would expose future arrivals; confirm intended use.
avg_delay_by_airport_train, avg_delay_by_airport_val, avg_delay_by_airport_test = (
    split.groupBy("DEST_AIRPORT_SEQ_ID").agg(F.avg("ARR_DELAY").alias("AVG_ARR_DELAY"))
    for split in (train_df, validation_df, test_df)
)

In [0]:
# train_df.select("DEST", "ARR_DELAY", "AVG_ARR_DELAY").show(20, False)

In [0]:

# Per-origin trailing window from 7 days (604800 s) to 4 hours (14400 s)
# before the current row's timestamp.
window_7d_origin = (
    Window
    .partitionBy("ORIGIN_AIRPORT_SEQ_ID")
    .orderBy(F.col("utc_timestamp").cast("long"))
    .rangeBetween(-604800, -14400)  # -7 days (604800s) to -4 hours (14400s)
)

# Mean ARR_DELAY of the flights in the window; an empty window gives null,
# which is replaced with 0.
_avg_arr_delay_origin = F.coalesce(
    F.avg('ARR_DELAY').over(window_7d_origin),
    F.lit(0),
)

train_df, validation_df, test_df = (
    split.withColumn('AVG_ARR_DELAY_ORIGIN', _avg_arr_delay_origin)
    for split in (train_df, validation_df, test_df)
)

### [Feature] Average taxi-out time by airport

In [0]:
# Mean TAXI_OUT over the per-origin 7-day window (`window_7d_origin`, defined
# in an earlier cell); an empty window gives null, which is replaced with 0.
_avg_taxi_out_origin = F.coalesce(
    F.avg('TAXI_OUT').over(window_7d_origin),
    F.lit(0),
)

train_df, validation_df, test_df = (
    split.withColumn('AVG_TAXI_OUT_ORIGIN', _avg_taxi_out_origin)
    for split in (train_df, validation_df, test_df)
)

In [0]:
# One output column per input column, holding that column's number of nulls:
# F.when emits a value only for null cells and F.count counts those values.
_null_count_exprs = [
    F.count(F.when(F.col(c).isNull(), c)).alias(c)
    for c in validation_df.columns
]
null_counts = validation_df.select(_null_count_exprs)
display(null_counts)

## Checkpoint results with feature engineering

In [0]:
%fs ls dbfs:/student-groups/Group_2_2/1_year_custom_joined/feature_eng

In [0]:
# Manual confirmation before overwriting the feature-engineered checkpoints.
_confirmation = input("CAREFUL: You're about to write to DBFS. Type 'y' to continue.")
if _confirmation == "y":
    for _split_name, _split_df in (
        ("train", train_df),
        ("validation", validation_df),
        ("test", test_df),
    ):
        checkpoint_dataset(_split_df, f"{month_or_year}/feature_eng/training_splits/{_split_name}")

### Check data checkpoint

In [0]:
# Re-read the feature-engineered checkpoints written above to verify them.
# The dataset folder comes from `month_or_year` (the same value used when
# writing) so the check inspects the files this notebook produced rather than
# a different dataset's folder.
checkpoint_path = "dbfs:/student-groups/Group_2_2"
dataset_path = f"{checkpoint_path}/{month_or_year}/feature_eng/training_splits"

# Read datasets from checkpoint
check_train_df = spark.read.parquet(f"{dataset_path}/train.parquet")
check_validation_df = spark.read.parquet(f"{dataset_path}/validation.parquet")

In [0]:
# True only when both reloaded splits have the same column names in the same order.
check_train_df.columns == check_validation_df.columns

In [0]:
# Print every training column that is absent from the validation split.
# The loop variable is deliberately not named `col`: that would overwrite the
# `col` function imported from pyspark.sql.functions for all later cells.
_validation_columns = set(check_validation_df.columns)
for column_name in check_train_df.columns:
    if column_name not in _validation_columns:
        print(column_name)

In [0]:
# Render the reloaded training split for a visual spot check.
display(check_train_df)
# display(check_validation_df)

In [0]:
# List the column names of the reloaded training split.
check_train_df.columns

move to modeling!