## Imports

In [0]:
from pyspark.sql.functions import col
from pyspark.sql import Window
import pyspark.sql.functions as F
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator

import random

import mlflow
print(mlflow.__version__)

import os
# Runtime switches: one environment variable for PySpark and one Spark conf
# flag whose key refers to MLflow tracking of MLlib.
os.environ['PYSPARK_PIN_THREAD'] = 'false'
spark.conf.set("spark.databricks.mlflow.trackMLlib.enabled", "true")

# NOTE(review): RANDOM_SEED is defined here but no visible cell passes it to
# random / numpy / torch -- confirm whether seeding happens elsewhere.
RANDOM_SEED = 0

# Workspace path of the MLflow experiment used by this notebook.
EXPERIMENT_NAME = "/Shared/team_2_2/mlflow-nn-mlp-layers"

# Reuse the experiment when it already exists, otherwise create it; either way
# make it the active experiment. Any failure falls back to a per-user
# "default" experiment instead of stopping the notebook.
try:
    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if experiment is not None:
        print(f"Using existing experiment: {experiment.name}")
    else:
        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
        print(f"Created new experiment with ID: {experiment_id}")
    mlflow.set_experiment(EXPERIMENT_NAME)
except Exception as e:
    print(f"Error with experiment setup: {e}")
    # The fallback experiment lives under the current notebook user's folder.
    mlflow.set_experiment(f"/Users/{dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()}/default")



## Helper Functions


In [0]:
def checkpoint_dataset(dataset, file_path):
    """Write ``dataset`` as parquet under the group's DBFS folder.

    Args:
        dataset: Object exposing ``write.mode(...).parquet(...)`` (a Spark
            DataFrame in this notebook).
        file_path: Path relative to the group folder, without the
            ``.parquet`` suffix; it may contain ``/`` to target a subfolder.

    Side effects:
        Calls ``dbutils.fs.mkdirs`` on the group folder and on the target's
        parent folder, overwrites any existing output at the target path,
        and prints a confirmation line.
    """
    section, number = "2", "2"
    base_folder = f"dbfs:/student-groups/Group_{section}_{number}"
    full_path = f"{base_folder}/{file_path}.parquet"

    # Everything before the final "/" is the directory that must exist.
    parent_folder = full_path.rsplit("/", 1)[0]
    for folder in (base_folder, parent_folder):
        dbutils.fs.mkdirs(folder)

    dataset.write.mode("overwrite").parquet(full_path)
    print(f"Checkpointed {file_path}")

## Datasets

In [0]:
import torch

In [0]:
# NOTE(review): `month_or_year` is not assigned in any cell visible above this
# one -- confirm it is defined before this cell runs on a fresh kernel.
output_root = f"dbfs:/student-groups/Group_2_2/{month_or_year}/nn_layers"

# Read the three parquet datasets stored under `output_root`
# (train_pytorch / val_pytorch / test_pytorch), in that order.
train_df, validation_df, test_df = (
    spark.read.parquet(f"{output_root}/{split_name}_pytorch.parquet")
    for split_name in ("train", "val", "test")
)

# # Convert to torch tensors
# x_cat_train = torch.tensor(train_df[categorical_idx_cols].values, dtype=torch.long)
# x_num_train = torch.tensor(train_df[numerical_cols].values, dtype=torch.float)
# x_time_train = torch.tensor(train_df[[time_col]].values, dtype=torch.float)
# y_train = torch.tensor(train_df[target_col].values, dtype=torch.float).view(-1, 1)


In [0]:
from pyspark.sql import functions as F

# Combine train + validation into the single frame the CV splits are cut from.
# unionByName matches columns by name, not by position.
df = train_df.unionByName(validation_df)

# Filter out cancelled flights
# df = df.filter(F.col("CANCELLED") != 1)

# The filter above is disabled, so the count reported here is simply the size
# of the combined frame (the previous message claimed it was "after filtering").
print(f"Number of rows in combined train + validation data: {df.count()}")
display(df.limit(10))


In [0]:
# # combine date and scheduled departure time

# from pyspark.sql import functions as F

# # Combine date and scheduled departure time into a timestamp
# df = df.withColumn(
#     "utc_timestamp",
#     F.to_timestamp(
#         F.concat(
#             F.col("FL_DATE"),               # Flight date
#             F.lit(" "),                      # Space separator
#             F.lpad(F.col("CRS_DEP_TIME").cast("string"), 4, "0")  # Ensure 4-digit HHmm
#         ),
#         "yyyy-MM-dd HHmm"                   # Format to parse
#     )
# )

# display(df.select("FL_DATE", "CRS_DEP_TIME", "utc_timestamp").limit(10))


# Create Splits for CV

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Ordering used to number the hour buckets.
# NOTE(review): the window has no partitionBy, i.e. one global ordering over
# the whole frame -- confirm that is acceptable at this data size.
window_spec = Window.orderBy("hour")

# 1) "hour": utc_timestamp truncated to the hour.
# 2) "time_idx": dense rank of that hour, so rows sharing an hour share an
#    index and indices count the distinct hours present in the data
#    (1, 2, 3, ... with no gaps in the numbering).
df_indexed = (
    df
    .withColumn("hour", F.date_trunc("hour", F.col("utc_timestamp")))
    .withColumn("time_idx", F.dense_rank().over(window_spec))
)

df_indexed.display()

# 3-month splits config

In [0]:
# Largest time_idx present, i.e. the number of distinct hour buckets.
max_time_idx = df_indexed.agg(F.max("time_idx")).first()[0]
print(f"  Max time index: {max_time_idx}")

# All sizes are counted in time_idx steps (one step per distinct hour bucket).
# NOTE: when the notebook is run top to bottom, the "1 Year" config cell that
# follows assigns the same names and therefore replaces these values.
train_size = 30 * 24  # 720 steps  (30 days of hourly buckets)
gap_size = 2          # steps skipped between train and validation
val_size = 7 * 24     # 168 steps  (7 days of hourly buckets)
step_size = 85        # offset between consecutive fold starts; chosen to target 10 folds

# A fold needs room for train + gap + validation; count how many start
# offsets (multiples of step_size) fit before running past max_time_idx.
fold_window_size = sum((train_size, gap_size, val_size))
n_folds = 1 + (max_time_idx - fold_window_size) // step_size
print(f"Step 2: Calculated {n_folds} folds")

# 1 Year splits config

In [0]:
# Largest time_idx present, i.e. the number of distinct hour buckets.
max_time_idx = df_indexed.agg(F.max("time_idx")).first()[0]
print(f"  Max time index: {max_time_idx}")

# Same layout as the 3-month config with every window scaled by 4.
# All sizes are counted in time_idx steps (one step per distinct hour bucket).
# NOTE: this cell reassigns the names set by the 3-month config cell, so on a
# top-to-bottom run these are the values the fold builder sees.
train_size = 4 * 720  # 2880 steps (120 days of hourly buckets)
gap_size = 2          # steps skipped between train and validation
val_size = 4 * 168    # 672 steps  (28 days of hourly buckets)
step_size = 4 * 90    # 360 steps between fold starts
# NOTE(review): the earlier comment said this yields exactly 10 folds --
# confirm against the printed n_folds for the dataset in use.

fold_window_size = sum((train_size, gap_size, val_size))
n_folds = 1 + (max_time_idx - fold_window_size) // step_size
print(f"Step 2: Calculated {n_folds} folds")

# 5 Year splits config

In [0]:
# max_time_idx = df_indexed.agg(F.max("time_idx")).collect()[0][0]
# print(f"  Max time index: {max_time_idx}")

# train_size = 720*4*5      # 30 days (720 hours)
# gap_size = 2          # 2 hours
# val_size = 168*4*5        # 7 days (168 hours)
# step_size = 90*4*5       # Calculated to get exactly 10 folds

# fold_window_size = train_size + gap_size + val_size
# n_folds = (max_time_idx - fold_window_size) // step_size + 1
# print(f"Step 2: Calculated {n_folds} folds")

In [0]:
# Build a lookup table (time_idx, fold_id, split_type) describing every
# rolling-window fold, then attach it to the data with a broadcast join.
fold_mapping = []

for fold_id in range(1, n_folds + 1):
    # Fold k starts (k - 1) * step_size steps after time_idx 1.
    fold_start = 1 + (fold_id - 1) * step_size
    gap_start = fold_start + train_size
    val_start = gap_start + gap_size
    val_stop = val_start + val_size

    # Half-open [first, stop) ranges, in chronological order.
    segments = (
        ("train", fold_start, gap_start),
        ("gap", gap_start, val_start),
        ("validation", val_start, val_stop),
    )
    for segment_label, first_idx, stop_idx in segments:
        fold_mapping.extend(
            (t, fold_id, segment_label) for t in range(first_idx, stop_idx)
        )

fold_df = spark.createDataFrame(fold_mapping, ["time_idx", "fold_id", "split_type"])

# Inner join: a row is emitted once for every fold whose window covers its
# time_idx (so rows are duplicated across overlapping folds), and rows whose
# time_idx is in no fold are dropped.
result = df_indexed.join(F.broadcast(fold_df), on='time_idx', how='inner')


In [0]:
# month_or_year = "3_month_"
# Interactive guard: the splits are only written when the user types exactly "y".
# NOTE(review): the training function in this notebook reads from
# ".../{month_or_year}/cv_splits", which is not the folder written here --
# confirm which location is intended.
overwrite_answer = input("Careful! About to overwrite splits. If you want to continue, type y")
if overwrite_answer == "y":
    cv_splits_path = f"dbfs:/student-groups/Group_2_2/{month_or_year}/fe_graph_and_holiday/cv_splits_nn_layers"
    # One directory per (fold_id, split_type) pair; existing output is replaced.
    (
        result.write
        .partitionBy("fold_id", "split_type")
        .mode("overwrite")
        .parquet(cv_splits_path)
    )

## How to read the CV splits

In [0]:
def read_specific_fold(path: str, fold_id: int, split_type: str):
    """Read one (fold, split) partition directory of the CV splits.

    Args:
        path: Root folder of the partitioned parquet output.
        fold_id: Value of the ``fold_id`` partition.
        split_type: Value of the ``split_type`` partition.

    Returns:
        Whatever ``spark.read.parquet`` returns for
        ``{path}/fold_id={fold_id}/split_type={split_type}``.

    Note:
        A function with the same name is defined again later in this
        notebook; once that cell has run, it replaces this definition.
    """
    partition_dir = f"{path}/fold_id={fold_id}/split_type={split_type}"
    return spark.read.parquet(partition_dir)


In [0]:
# Categorical encoding
carrier_indexer = StringIndexer(inputCol="OP_CARRIER", outputCol="carrier_idx", handleInvalid="keep")
origin_indexer = StringIndexer(inputCol="ORIGIN_AIRPORT_SEQ_ID", outputCol="origin_idx", handleInvalid="keep")
dest_indexer = StringIndexer(inputCol="DEST_AIRPORT_SEQ_ID", outputCol="dest_idx", handleInvalid="keep")
tail_num_indexer = StringIndexer(inputCol="TAIL_NUM", outputCol="tail_num_idx", handleInvalid="keep")

carrier_encoder = OneHotEncoder(inputCol="carrier_idx", outputCol="carrier_vec")
origin_encoder = OneHotEncoder(inputCol="origin_idx", outputCol="origin_vec")
dest_encoder = OneHotEncoder(inputCol="dest_idx", outputCol="dest_vec")
tail_num_encoder = OneHotEncoder(inputCol="tail_num_idx", outputCol="tail_num_vec")


# Assemble all features
assembler = VectorAssembler(
    inputCols=[
        "QUARTER",
        "MONTH", 
        "YEAR",
        "DAY_OF_MONTH",
        "DAY_OF_WEEK",
        "carrier_vec",
        "origin_vec",
        "dest_vec",
        "tail_num_vec",
        "CRS_ELAPSED_TIME",
        "DISTANCE",
        'HourlyDryBulbTemperature',
        'HourlyDewPointTemperature',
        'HourlyRelativeHumidity',
        'HourlyAltimeterSetting',
        'HourlyVisibility',
        'HourlyStationPressure',
        'HourlyWetBulbTemperature',
        'HourlyPrecipitation',
        'HourlyCloudCoverage',
        'HourlyCloudElevation',
        'HourlyWindSpeed'  
    ],
    outputCol="features"
)

In [0]:
## Model architecture defined
import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------
# Residual Block
# -----------------------------
class ResBlock(nn.Module):
    """Residual MLP block computing ``x + fc2(dropout(gelu(fc1(ln(x)))))``.

    LayerNorm is applied to the input before the first linear layer, and the
    un-normalised input is added back at the end, so input and output have
    the same last dimension ``dim``.

    Args:
        dim: Size of the last dimension of the input and output.
        dropout: Dropout probability applied between the two linear layers.
    """

    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Use nn.functional explicitly instead of the global alias `F`: this
        # notebook also binds `F` to pyspark.sql.functions, so whichever
        # import cell ran last would otherwise decide whether this works.
        h = nn.functional.gelu(self.fc1(self.ln(x)))
        h = self.dropout(h)
        h = self.fc2(h)
        return x + h


# -----------------------------
# Time2Vec
# -----------------------------
class Time2Vec(nn.Module):
    """Encode a scalar time feature as one linear term plus ``k`` sine terms.

    Args:
        k: Number of sine components.

    ``forward`` expects the last dimension of ``t`` to be 1 (both layers are
    ``nn.Linear(1, ...)``) and returns a tensor whose last dimension is
    ``k + 1``: the linear term first, then the ``k`` sine terms.
    """

    def __init__(self, k):
        super().__init__()
        self.wb = nn.Linear(1, 1)  # linear (non-periodic) component
        self.ws = nn.Linear(1, k)  # affine map fed into sin(): learnable frequency and phase

    def forward(self, t):
        linear_part = self.wb(t)
        periodic_part = self.ws(t).sin()
        return torch.cat((linear_part, periodic_part), dim=-1)


# -----------------------------
# ResFiLM MLP (F2-optimized)
# -----------------------------
class ResFiLMMLP(nn.Module):
    """Two-tower MLP with FiLM-modulated embeddings and two output heads.

    Data flow (see ``forward``):
      * each categorical column is embedded and the embeddings concatenated;
      * numeric features go through a linear layer, GELU and a stack of
        ``ResBlock`` layers;
      * the numeric representation produces a scale (gamma) and shift (beta)
        that modulate the embeddings: ``gamma * emb + beta``;
      * the time feature is expanded by ``Time2Vec``;
      * modulated embeddings, numeric representation, time encoding and the
        raw time value are concatenated and fed to a regression head and a
        classification head (raw logit, no sigmoid).

    Args:
        cat_dims: Number of categories for each categorical column.
        emb_dims: Embedding width for each categorical column; must have the
            same length as ``cat_dims``.
        num_numerical: Number of numeric input features.
        time_dim: Number of sine components used by ``Time2Vec``.
        emb_dropout: Dropout on the concatenated embeddings.
        num_dropout: Dropout inside each residual block.
        film_dropout: Dropout applied to gamma and to beta.
        final_dropout: Dropout inside both output heads.
        hidden_dim: Width of the numeric tower (default 256, the previously
            hard-coded value).
        num_res_blocks: Number of residual blocks in the numeric tower
            (default 4, the previously hard-coded value).

    Raises:
        ValueError: If ``cat_dims`` and ``emb_dims`` differ in length.
    """

    def __init__(
        self,
        cat_dims,
        emb_dims,
        num_numerical,
        time_dim=8,
        emb_dropout=0.05,
        num_dropout=0.1,
        film_dropout=0.1,
        final_dropout=0.2,
        hidden_dim=256,
        num_res_blocks=4,
    ):
        super().__init__()

        # zip() below would silently stop at the shorter list while emb_total
        # sums every entry of emb_dims, so a mismatch either ignores
        # categorical columns or fails later with an opaque shape error.
        cat_dims = list(cat_dims)
        emb_dims = list(emb_dims)
        if len(cat_dims) != len(emb_dims):
            raise ValueError(
                f"cat_dims and emb_dims must have the same length, "
                f"got {len(cat_dims)} and {len(emb_dims)}"
            )

        # --- Embedding tower ---
        self.embeddings = nn.ModuleList([
            nn.Embedding(cat_dim, emb_dim)
            for cat_dim, emb_dim in zip(cat_dims, emb_dims)
        ])
        self.emb_total = sum(emb_dims)
        self.emb_dropout = nn.Dropout(emb_dropout)

        # --- Numeric tower ---
        self.fc_num = nn.Linear(num_numerical, hidden_dim)
        self.res_blocks = nn.ModuleList([
            ResBlock(hidden_dim, dropout=num_dropout)
            for _ in range(num_res_blocks)
        ])

        # --- FiLM for embeddings: one layer emits gamma and beta together ---
        self.film = nn.Linear(hidden_dim, 2 * self.emb_total)
        self.film_dropout = nn.Dropout(film_dropout)

        # --- Time2Vec ---
        self.t2v = Time2Vec(time_dim)

        # --- Classification-specific FiLM ---
        # NOTE: these two modules are never used in forward(). They are kept
        # so state_dict keys and parameter initialisation order stay the same
        # as before.
        self.clf_film = nn.Linear(hidden_dim, 2 * self.emb_total)
        self.clf_film_dropout = nn.Dropout(film_dropout)

        # --- Final fusion dimension ---
        # numeric tower + embeddings + Time2Vec output (time_dim sines + 1
        # linear term) + the raw time value itself.
        fused_dim = hidden_dim + self.emb_total + (time_dim + 1) + 1

        # --- Multi-task heads (identical architecture, separate weights) ---
        self.reg_head = self._build_head(fused_dim, final_dropout)  # raw regression output
        self.clf_head = self._build_head(fused_dim, final_dropout)  # raw logit, no sigmoid

    @staticmethod
    def _build_head(fused_dim, dropout):
        """Return a fused_dim -> 256 -> 128 -> 1 MLP with GELU and dropout."""
        return nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        )

    def forward(self, x_cat, x_num, x_time):
        """Return ``(reg_out, clf_out)``: regression output and raw logit.

        ``x_cat[:, i]`` is looked up in the i-th embedding table; ``x_time``
        is concatenated as-is on the last dimension alongside its Time2Vec
        encoding.
        """
        # --- Embeddings ---
        emb = [emb_layer(x_cat[:, i]) for i, emb_layer in enumerate(self.embeddings)]
        emb = torch.cat(emb, dim=-1)
        emb = self.emb_dropout(emb)

        # --- Numeric tower ---
        # nn.functional is used explicitly: the alias `F` is also bound to
        # pyspark.sql.functions elsewhere in this notebook.
        h = nn.functional.gelu(self.fc_num(x_num))
        for block in self.res_blocks:
            h = block(h)

        # --- FiLM modulation ---
        gamma, beta = torch.chunk(self.film(h), 2, dim=-1)
        gamma = self.film_dropout(gamma)
        beta = self.film_dropout(beta)
        emb_mod = gamma * emb + beta

        # --- Time2Vec ---
        t_feat = self.t2v(x_time)

        # --- Fuse towers ---
        z = torch.cat([emb_mod, h, t_feat, x_time], dim=-1)

        # --- Output tasks ---
        reg_out = self.reg_head(z)
        clf_out = self.clf_head(z)

        return reg_out, clf_out


# -----------------------------
# Example F2-focused loss function
# -----------------------------
def f2_loss(logits, targets, pos_weight=4.0):
    """Binary cross-entropy on raw logits with up-weighted positives.

    Despite the name, this does not compute an F2 score: it is the mean
    BCE-with-logits loss in which the positive-class term is multiplied by
    ``pos_weight``.

    Args:
        logits: Raw (pre-sigmoid) outputs of the classification head.
        targets: Binary labels (0/1) as a float tensor shaped like ``logits``.
        pos_weight: Multiplier for the loss on positive targets.

    Returns:
        Scalar tensor with the mean weighted loss.
    """
    positive_weight = torch.tensor([pos_weight], device=logits.device)
    return nn.functional.binary_cross_entropy_with_logits(
        logits, targets, pos_weight=positive_weight
    )


In [0]:
def read_specific_fold(path: str, fold_id: int, split_type: str):
    """Read one (fold, split) slice of the partitioned CV-splits parquet.

    The partition directory ``{path}/fold_id=.../split_type=...`` is read
    directly; if that raises, the whole dataset under ``path`` is read and
    filtered on the two partition columns instead.

    Args:
        path: Root folder of the parquet output partitioned by
            ``fold_id`` and ``split_type``.
        fold_id: Fold to load.
        split_type: Split to load (this notebook writes "train", "gap" and
            "validation").

    Returns:
        The DataFrame for the requested slice.

    Note:
        NOTE(review): the fallback result still carries the ``fold_id`` and
        ``split_type`` columns used in its filter; presumably the direct
        partition read does not -- confirm if callers depend on the schema.
    """
    fold_path = f"{path}/fold_id={fold_id}/split_type={split_type}"

    try:
        # Try direct partition read
        return spark.read.parquet(fold_path)
    except Exception as exc:
        # Only ordinary errors trigger the fallback. A bare `except:` would
        # also swallow KeyboardInterrupt/SystemExit and start a full read
        # after the user tried to cancel.
        print(f"Direct read failed for fold {fold_id}, using filter method...")
        print(f"  Cause: {type(exc).__name__}: {exc}")
        all_data = spark.read.parquet(path)
        return all_data.filter(
            (all_data["fold_id"] == fold_id) &
            (all_data["split_type"] == split_type)
        )


# A flight counts as "delayed" for the classification task when its
# regression target exceeds this value.
# NOTE(review): presumably minutes of delay -- confirm the target's unit.
_DELAY_THRESHOLD = 15


def _run_resfilm_epoch(model, loader, optimizer=None, reg_loss_fn=None,
                       clf_loss_fn=None, clf_loss_weight=None):
    """Run one pass over ``loader`` and collect targets and predictions.

    When ``optimizer`` is given the model is put in training mode and updated
    on every batch with ``reg_loss + clf_loss_weight * clf_loss``; otherwise
    the model is put in eval mode and gradients are disabled.

    Returns:
        Tuple ``(y_true_reg, y_pred_reg, y_true_clf, y_prob_clf)`` of numpy
        arrays. The two classification arrays are flattened to 1-D, and
        ``y_prob_clf`` holds sigmoid probabilities.
    """
    training = optimizer is not None
    model.train(training)

    y_true_reg, y_pred_reg, y_true_clf, y_prob_clf = [], [], [], []

    with torch.set_grad_enabled(training):
        for x_cat, x_num, x_time, y_reg in loader:
            x_cat = x_cat.to(device).long()
            x_num = x_num.to(device).float()
            x_time = x_time.to(device).float()
            y_reg = y_reg.to(device).float().view(-1, 1)
            # Binary label derived from the regression target.
            y_clf = (y_reg > _DELAY_THRESHOLD).float()

            if training:
                optimizer.zero_grad()

            pred_reg, pred_clf_logits = model(x_cat, x_num, x_time)

            if training:
                loss_reg = reg_loss_fn(pred_reg, y_reg)
                loss_clf = clf_loss_fn(pred_clf_logits, y_clf)
                loss = loss_reg + clf_loss_weight * loss_clf
                loss.backward()
                optimizer.step()

            y_true_reg.append(y_reg.cpu().numpy())
            y_pred_reg.append(pred_reg.detach().cpu().numpy())
            y_true_clf.append(y_clf.cpu().numpy())
            y_prob_clf.append(torch.sigmoid(pred_clf_logits).detach().cpu().numpy())

    return (
        np.concatenate(y_true_reg),
        np.concatenate(y_pred_reg),
        np.concatenate(y_true_clf).reshape(-1),
        np.concatenate(y_prob_clf).reshape(-1),
    )


def train_cv_resfilm(
        n_folds=10,
        month_or_year="1_year_custom_joined",
        params=params
    ):
    """Train one ResFiLMMLP per CV fold and log everything to MLflow.

    For each fold the train/validation partitions are read from
    ``dbfs:/student-groups/Group_2_2/{month_or_year}/cv_splits``, converted
    to pandas, wrapped with ``make_dataloader`` and used to train a fresh
    model with AdamW, an L1 regression loss and a positively-weighted
    BCE-with-logits classification loss. After every epoch the validation F2
    is maximised over a grid of thresholds; the best-F2 weights are saved to
    ``/tmp/best_model_fold_{fold_id}.pth`` and logged as an artifact.
    Training of a fold stops after ``params["patience"]`` epochs without an
    F2 improvement.

    Args:
        n_folds: Number of folds to train (fold ids ``1..n_folds``).
        month_or_year: Dataset folder name under the group directory.
        params: Dict of hyper-parameters. Keys read here: ``batch_size``,
            ``time_dim``, ``lr``, ``weight_decay``, ``pos_weight``,
            ``clf_loss_weight``, ``num_epochs``, ``patience``.
            The default is the module-level ``params`` as it was bound when
            this function was defined.

    Returns:
        ``(cv_results, cv_models)``: a list with one dict per fold
        (``fold``, ``val_f2``, ``val_mae``, ``best_threshold``, all taken
        from the best-F2 epoch) and the list of checkpoint paths.

    Notes:
        * Also reads the notebook-level names ``categorical_cols``,
          ``numerical_cols``, ``time_col``, ``target_col``, ``cat_dims``,
          ``emb_dims``, ``device`` and ``make_dataloader``.
        * The decision threshold is tuned on the same validation fold that
          F2 is reported on, so ``val_f2`` is an optimistic estimate.
        * The LR scheduler follows validation MAE while early stopping
          follows validation F2.
        * NOTE(review): the split-writing cell in this notebook writes to
          ``.../fe_graph_and_holiday/cv_splits_nn_layers``, not to the
          ``.../cv_splits`` folder read here -- confirm the intended path.
    """
    cv_results = []
    cv_models = []

    # Loop invariants.
    splits_path = f"dbfs:/student-groups/Group_2_2/{month_or_year}/cv_splits"
    selected_cols = categorical_cols + numerical_cols + [time_col, target_col]
    thresholds = np.linspace(0.05, 0.95, 40)

    mlflow.pytorch.autolog(log_models=False)  # checkpoints are logged manually below

    with mlflow.start_run(run_name=f"ResFiLM_CV_{month_or_year}") as parent_run:

        mlflow.log_param("n_folds", n_folds)
        mlflow.log_param("model_type", "ResFiLMMLP")
        mlflow.log_param("dataset", month_or_year)
        mlflow.log_params(params)

        for fold_id in range(1, n_folds + 1):

            # ============================
            # LOAD THIS FOLD'S TRAIN + VAL
            # ============================
            fold_train_df = read_specific_fold(
                path=splits_path,
                fold_id=fold_id,
                split_type="train"
            )
            fold_val_df = read_specific_fold(
                path=splits_path,
                fold_id=fold_id,
                split_type="validation"
            )

            print(f"\n===== FOLD {fold_id} =====")
            print(f"Train rows: {fold_train_df.count()}, Val rows: {fold_val_df.count()}")

            # Spark -> pandas: the whole fold is collected onto the driver.
            train_pd = fold_train_df.select(selected_cols).toPandas()
            val_pd = fold_val_df.select(selected_cols).toPandas()

            train_dl = make_dataloader(train_pd, params["batch_size"], shuffle=True)
            val_dl = make_dataloader(val_pd, params["batch_size"], shuffle=False)

            # ============================
            # Fresh model / optimiser state for this fold
            # ============================
            model = ResFiLMMLP(
                cat_dims,
                emb_dims,
                num_numerical=len(numerical_cols),
                time_dim=params["time_dim"]
            ).to(device)

            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=params["lr"],
                weight_decay=params["weight_decay"]
            )
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=2
            )
            clf_loss_fn = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor(params["pos_weight"]).to(device)
            )
            reg_loss_fn = nn.L1Loss()

            best_val_f2 = -1
            # MAE at the best-F2 epoch. Previously the last epoch's MAE was
            # reported next to the best F2/threshold, which described a
            # different set of weights than the saved checkpoint (and was
            # undefined when num_epochs was 0).
            best_val_mae = None
            best_threshold = 0.5
            patience_counter = 0
            checkpoint_path = f"/tmp/best_model_fold_{fold_id}.pth"

            # ============================
            # MLflow nested run for this fold
            # ============================
            with mlflow.start_run(run_name=f"fold_{fold_id}", nested=True):

                for epoch in range(params["num_epochs"]):

                    # ---------------- TRAIN ----------------
                    (y_true_reg_train, y_pred_reg_train,
                     y_true_clf_train, y_pred_clf_train) = _run_resfilm_epoch(
                        model,
                        train_dl,
                        optimizer=optimizer,
                        reg_loss_fn=reg_loss_fn,
                        clf_loss_fn=clf_loss_fn,
                        clf_loss_weight=params["clf_loss_weight"],
                    )
                    # Training F2 uses the fixed 0.5 threshold.
                    train_pred_binary = (y_pred_clf_train >= 0.5).astype(int)
                    train_f2 = fbeta_score(y_true_clf_train, train_pred_binary, beta=2, zero_division=0)
                    train_mae = mean_absolute_error(y_true_reg_train, y_pred_reg_train)

                    # ---------------- VALIDATION ----------------
                    (y_true_reg_val, y_pred_reg_val,
                     y_true_clf_val, y_pred_clf_val) = _run_resfilm_epoch(model, val_dl)

                    # --- Find best threshold for F2
                    f2_scores = [fbeta_score(y_true_clf_val,
                                             (y_pred_clf_val >= t).astype(int),
                                             beta=2,
                                             zero_division=0)
                                 for t in thresholds]
                    best_idx = np.argmax(f2_scores)
                    val_f2 = f2_scores[best_idx]
                    val_thresh = thresholds[best_idx]

                    val_mae = mean_absolute_error(y_true_reg_val, y_pred_reg_val)

                    print(f"[Fold {fold_id} | Epoch {epoch}] Val F2={val_f2:.4f}  | Best Thresh={val_thresh:.2f}  | Val MAE={val_mae:.4f}")

                    mlflow.log_metrics({
                        "train_mae": train_mae,
                        "train_f2": train_f2,
                        "val_mae": val_mae,
                        "val_f2": val_f2,
                        "threshold": val_thresh,
                    }, step=epoch)

                    scheduler.step(val_mae)

                    # Early stopping on validation F2
                    if val_f2 > best_val_f2:
                        best_val_f2 = val_f2
                        best_val_mae = val_mae
                        best_threshold = val_thresh
                        patience_counter = 0
                        torch.save(model.state_dict(), checkpoint_path)
                        mlflow.log_artifact(checkpoint_path)
                    else:
                        patience_counter += 1
                        if patience_counter >= params["patience"]:
                            print(f"Early stopping at epoch {epoch}")
                            break

            # store fold results (all values refer to the best-F2 epoch)
            cv_results.append({
                "fold": fold_id,
                "val_f2": best_val_f2,
                "val_mae": best_val_mae,
                "best_threshold": best_threshold
            })
            cv_models.append(checkpoint_path)

        # Log summary to parent run
        summary_df = pd.DataFrame(cv_results)
        mlflow.log_table(summary_df, "cv_results.json")

        print("\nCV COMPLETE:")
        print(summary_df)

    return cv_results, cv_models


In [0]:
# Run cross-validated training with the function's default arguments.
# (The previous call used `train_cv_models`, a name that is not defined in
# this notebook; the training function is `train_cv_resfilm`.)
cv_results, cv_models = train_cv_resfilm()