# Add features needed for the NN tower

## Imports

In [0]:
from pyspark.sql.functions import col
from pyspark.sql import Window
import pyspark.sql.functions as F
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline

from pyspark.ml.regression import LinearRegression, RandomForestRegressor
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import RegressionEvaluator, BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from xgboost.spark import SparkXGBRegressor

from mlflow.models import infer_signature


import random
import numpy as np
import pandas as pd

import mlflow
print(mlflow.__version__)

import os

2.21.3


## Helper Functions

In [0]:
def checkpoint_dataset(dataset, file_path, section="2", number="2"):
    """Write ``dataset`` as a Parquet checkpoint under the group's DBFS folder.

    The target location is
    ``dbfs:/student-groups/Group_{section}_{number}/{file_path}.parquet``.
    Anything already stored at that location is overwritten.

    Args:
        dataset: Object exposing the Spark DataFrame writer API
            (``.write.mode(...).parquet(...)``).
        file_path: Location relative to the group folder, without the
            ``.parquet`` suffix. May contain ``/`` to target subfolders.
        section: Section identifier used in the group folder name.
        number: Group number used in the group folder name.
    """
    base_folder = f"dbfs:/student-groups/Group_{section}_{number}"
    full_path = f"{base_folder}/{file_path}.parquet"

    # The parent folder always lives under base_folder, and mkdirs creates
    # missing ancestors, so a single call covers both the group folder and
    # any nested subfolders named in file_path.
    parent_folder = full_path.rsplit("/", 1)[0]
    dbutils.fs.mkdirs(parent_folder)

    dataset.write.mode("overwrite").parquet(full_path)
    print(f"Checkpointed {file_path}")

## Datasets (Custom Join)

In [0]:
# List the contents of the training-splits folder and render the listing.
training_splits_listing = dbutils.fs.ls(
    "dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday/training_splits"
)
display(training_splits_listing)

path,name,size,modificationTime
dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday/training_splits/test.parquet/,test.parquet/,0,1765424905930
dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday/training_splits/train.parquet/,train.parquet/,0,1765424905930
dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday/training_splits/validation.parquet/,validation.parquet/,0,1765424905930


In [0]:
# Load the train / validation / test splits of the 5-year custom-joined data.
# (The original note mentions 1_year_custom_joined as an alternative dataset
# folder for each of these paths.)
train_df, val_df, test_df = (
    spark.read.parquet(split_path)
    for split_path in (
        "dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday/training_splits/train.parquet/",
        "dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday/training_splits/validation.parquet/",
        "dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday/training_splits/test.parquet/",
    )
)

## Feature Engineering

In [0]:
# -----------------------------
# Time-derived cyclic features
# -----------------------------
def add_time_features(df):
    """Append departure-time features and their sine/cosine encodings.

    Adds, in this order:
      * ``dep_hour``: ``CRS_DEP_MINUTES`` divided by 60.
      * ``day_of_year``: day of year of ``utc_timestamp``, cast to double.
      * ``dep_hour_sin`` / ``dep_hour_cos``: ``dep_hour`` on a period of 24.
      * ``dow_sin`` / ``dow_cos``: ``DAY_OF_WEEK`` on a period of 7.
      * ``doy_sin`` / ``doy_cos``: ``day_of_year`` on a period of 365.

    Args:
        df: DataFrame with ``CRS_DEP_MINUTES``, ``utc_timestamp`` and
            ``DAY_OF_WEEK`` columns.

    Returns:
        The input DataFrame with the eight columns above added.
    """
    df = df.withColumn("dep_hour", F.col("CRS_DEP_MINUTES") / 60.0)
    df = df.withColumn("day_of_year", F.dayofyear("utc_timestamp").cast("double"))

    # (source column, period, sine output, cosine output)
    cyclic_specs = (
        ("dep_hour", 24, "dep_hour_sin", "dep_hour_cos"),
        ("DAY_OF_WEEK", 7, "dow_sin", "dow_cos"),
        ("day_of_year", 365, "doy_sin", "doy_cos"),
    )
    for source, period, sin_name, cos_name in cyclic_specs:
        # Map the value onto an angle so both ends of the period meet.
        angle = 2 * F.lit(np.pi) * F.col(source) / period
        df = df.withColumn(sin_name, F.sin(angle))
        df = df.withColumn(cos_name, F.cos(angle))
    return df

In [0]:
# -----------------------------
# Weather delta features (3-hour changes)
# -----------------------------
def add_weather_deltas(df):
    """Append a ``<column>_3h_change`` delta for each hourly weather column.

    For every weather column, the new column is the current value minus the
    value found three rows earlier within the same ``ORIGIN_AIRPORT_SEQ_ID``
    partition, with rows ordered by ``utc_timestamp``. The result is null
    when there is no row three positions back, or when either value is null.

    NOTE(review): the lag is counted in rows, not in hours, so this is a true
    three-hour change only if each airport has exactly one row per hour.
    Confirm against the input data before relying on the ``_3h_`` name.

    Args:
        df: DataFrame with ``ORIGIN_AIRPORT_SEQ_ID``, ``utc_timestamp`` and
            the five hourly weather columns listed below.

    Returns:
        The input DataFrame with one ``_3h_change`` column per weather column.
    """
    weather_columns = (
        "HourlyVisibility",
        "HourlyStationPressure",
        "HourlyDryBulbTemperature",
        "HourlyWindSpeed",
        "HourlyPrecipitation",
    )
    w = Window.partitionBy("ORIGIN_AIRPORT_SEQ_ID").orderBy("utc_timestamp")

    # The loop variable is deliberately not called `col`, which would shadow
    # the `col` function imported from pyspark.sql.functions.
    for weather_col in weather_columns:
        previous_value = F.lag(weather_col, 3).over(w)
        # Subtracting a null lag already yields null, so a missing history
        # stays missing instead of being filled with 0.
        df = df.withColumn(
            f"{weather_col}_3h_change",
            F.col(weather_col) - previous_value,
        )
    return df

In [0]:
# -----------------------------
# Origin congestion features
# -----------------------------
def add_congestion_features(df):
    """Append ``ground_flights_last_hour`` per origin airport.

    Adds ``utc_ts_sec`` (``utc_timestamp`` cast to long) and then, for each
    row, counts the rows sharing its ``ORIGIN_AIRPORT_SEQ_ID`` whose
    ``utc_ts_sec`` lies within the 3600 seconds up to and including the
    row's own value. One is subtracted from that count.

    Args:
        df: DataFrame with ``ORIGIN_AIRPORT_SEQ_ID`` and ``utc_timestamp``.

    Returns:
        The input DataFrame with ``utc_ts_sec`` and
        ``ground_flights_last_hour`` added; the helper ``utc_ts_sec`` column
        is left in the output.
    """
    trailing_hour = (
        Window.partitionBy("ORIGIN_AIRPORT_SEQ_ID")
        .orderBy("utc_ts_sec")
        .rangeBetween(-3600, 0)
    )
    rows_in_window = F.count("utc_ts_sec").over(trailing_hour)

    return (
        df.withColumn("utc_ts_sec", F.col("utc_timestamp").cast("long"))
        # The window includes the current row, hence the minus one.
        .withColumn("ground_flights_last_hour", rows_in_window - 1)
    )

In [0]:
# -----------------------------
# Destination congestion features
# -----------------------------
def add_dest_congestion_features(df):
    """Append ``arrivals_last_hour`` per destination airport.

    Adds ``utc_ts_sec`` (``utc_timestamp`` cast to long) and then, for each
    row, counts the rows sharing its ``DEST_AIRPORT_SEQ_ID`` whose
    ``utc_ts_sec`` lies within the 3600 seconds up to and including the
    row's own value. One is subtracted from that count.

    NOTE(review): the window is ordered by ``utc_timestamp``; whether that
    is an arrival-side or departure-side time is not visible here, so
    confirm that the count really represents arrivals.

    Args:
        df: DataFrame with ``DEST_AIRPORT_SEQ_ID`` and ``utc_timestamp``.

    Returns:
        The input DataFrame with ``utc_ts_sec`` and ``arrivals_last_hour``
        added.
    """
    timestamp_seconds = F.col("utc_timestamp").cast("long")
    with_seconds = df.withColumn("utc_ts_sec", timestamp_seconds)

    by_destination = Window.partitionBy("DEST_AIRPORT_SEQ_ID")
    trailing_hour = by_destination.orderBy("utc_ts_sec").rangeBetween(-3600, 0)

    # The window includes the current row, so drop it from the count.
    other_rows_in_window = F.count("utc_ts_sec").over(trailing_hour) - 1
    return with_seconds.withColumn("arrivals_last_hour", other_rows_in_window)

In [0]:
### Apply additional feature engineering
# Every split goes through the same four feature steps, in this order.
nn_feature_steps = (
    add_time_features,
    add_weather_deltas,
    add_congestion_features,
    add_dest_congestion_features,
)


def _apply_nn_feature_steps(split_df):
    """Run each feature-engineering step on ``split_df`` in sequence."""
    for feature_step in nn_feature_steps:
        split_df = split_df.transform(feature_step)
    return split_df


train_df_fe = _apply_nn_feature_steps(train_df)
val_df_fe = _apply_nn_feature_steps(val_df)
test_df_fe = _apply_nn_feature_steps(test_df)


In [0]:
## Checkpoint updated data

# Persist each feature-engineered split to its own checkpoint location.
for split_df, checkpoint_path in (
    (train_df_fe, "5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/train"),
    (val_df_fe, "5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/val"),
    (test_df_fe, "5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/test"),
):
    checkpoint_dataset(split_df, checkpoint_path)


Checkpointed 5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/train
Checkpointed 5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/val
Checkpointed 5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/test
