# Neural Network Stack: Flight Delay Prediction (ResFiLM-MLP)

| Section | Description |
| :--- | :--- |
| **Model** | Multi-Towered Residual FiLM-MLP (Deep Learning) |
| **Objective** | Model architecture code, multi-fold hyperparameter tuning, final training, and prediction generation for the 5-year dataset. |
| **Target Metric** | Maximizing **F2 Score** (Primary), Minimizing **MAE** (Secondary). |
| **Architecture TL;DR** | Dual-head **multi-task** learning. A **Categorical Tower** (embeddings modulated by FiLM) and a **Numerical Tower** (residual blocks) are fused with Time2Vec time features and fed to a Regression head (MAE) and a Classification head (F2). |
| **Key Optimization** | Class imbalance handled with 4x positive weighting in the Classification Loss (BCE). |
| **Output** | Parquet file of **Out-of-Sample Predictions** for the entire 5-year period. |
| **Best Params Used** | `{'lr': 0.0001556, 'batch_size': 4096, 'alpha': 0.342, 'time_dim': 16, 'emb_drop': 0.046, 'num_drop': 0.324, 'final_drop': 0.100}` |

## Imports & Global Config

In [0]:
# ---------------------------------------------------------
# Imports & Global Configuration
# ---------------------------------------------------------
%pip install optuna

import pyspark.sql.functions as sf 
import torch.nn.functional as F    
import numpy as np
import pandas as pd
from pyspark.sql import Window
from pyspark.ml.feature import StringIndexer
from pyspark.sql.types import FloatType, IntegerType
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import fbeta_score, roc_auc_score, mean_squared_error, mean_absolute_error
import copy
import sys
import time
import mlflow
import optuna

# Enable Arrow for faster, lower-memory conversion from Spark to Pandas
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

def get_device(): 
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

DEVICE = get_device()
print(f"Using device: {DEVICE}")

# Feature Definitions
categorical_cols = [
    "OP_UNIQUE_CARRIER", "ORIGIN_AIRPORT_SEQ_ID", "DEST_AIRPORT_SEQ_ID",
    "route", "AIRPORT_HUB_CLASS", "AIRLINE_CATEGORY"
]

numerical_cols = [
    "DISTANCE", "CRS_ELAPSED_TIME", "prev_flight_delay_in_minutes",
    "origin_delays_4h", "delay_origin_7d", "delay_origin_carrier_7d",
    "delay_route_7d", "flight_count_24h", "AVG_TAXI_OUT_ORIGIN",
    "AVG_ARR_DELAY_ORIGIN", "in_degree", "out_degree",
    "weighted_in_degree", "weighted_out_degree", "betweenness",
    "closeness", "N_RUNWAYS", "HourlyVisibility", "HourlyStationPressure",
    "HourlyWindSpeed", "HourlyDryBulbTemperature", "HourlyDewPointTemperature",
    "HourlyRelativeHumidity", "HourlyAltimeterSetting", "HourlyWetBulbTemperature",
    "HourlyPrecipitation", "HourlyCloudCoverage", "HourlyCloudElevation",
    "ground_flights_last_hour", "arrivals_last_hour",
    "dow_sin", "dow_cos", "doy_sin", "doy_cos"
]

# Deduplicate numerical columns to prevent shape mismatches
numerical_cols = list(dict.fromkeys(numerical_cols))
print(f"Numerical columns count: {len(numerical_cols)}")

time_col = "CRS_DEP_MINUTES"
target_col = "DEP_DELAY_NEW"

Note: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.
Using device: cpu
Numerical columns count: 34


## Data Prep

In [0]:
EXPERIMENT_NAME = "/Shared/team_2_2/mlflow-nn-tower-tuned"


In [0]:
# ---------------------------------------------------------
# Data Prep
# ---------------------------------------------------------
from pyspark.ml.feature import StringIndexer
from pyspark.sql.types import FloatType, IntegerType

def prepare_fold_data(fold_df, categorical_cols, numerical_cols, time_col, target_col):
    """
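    Prepares one CV fold: splits it into train/validation, fits the StringIndexers and the
    StandardScaler on the training split only (unseen categories map to an extra UNK index),
    and collects both splits to pandas.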
    """
    # Split Data
    train_fe = fold_df.filter(sf.col("split_type") == "train")
    val_fe   = fold_df.filter(sf.col("split_type") == "validation")

    # Cast Numerics to Float32
    for c in numerical_cols + [time_col, target_col]:
        train_fe = train_fe.withColumn(c, sf.col(c).cast(FloatType()))
        val_fe = val_fe.withColumn(c, sf.col(c).cast(FloatType()))

    # Distributed String Indexing
    indexers = []
    encoded_cat_cols = []
    
    for c in categorical_cols:
        output_col = f"{c}_idx"
        encoded_cat_cols.append(output_col)
        # handleInvalid='keep' creates a bucket for UNK values at the end
        indexers.append(StringIndexer(inputCol=c, outputCol=output_col, 
                                      stringOrderType="alphabetAsc", handleInvalid="keep"))
    
    # Fit Indexers on TRAIN data only
    for ind in indexers:
        model = ind.fit(train_fe)
        train_fe = model.transform(train_fe)
        val_fe = model.transform(val_fe)

    # Select & Cast Indices to Integer
    final_cols = encoded_cat_cols + numerical_cols + [time_col, target_col]
    
    for c in encoded_cat_cols:
        train_fe = train_fe.withColumn(c, sf.col(c).cast(IntegerType()))
        val_fe = val_fe.withColumn(c, sf.col(c).cast(IntegerType()))

    # Collect to Pandas
    train_pd = train_fe.select(final_cols).toPandas()
    val_pd   = val_fe.select(final_cols).toPandas()

    # Rename columns
    rename_map = {f"{c}_idx": c for c in categorical_cols}
    train_pd = train_pd.rename(columns=rename_map)
    val_pd = val_pd.rename(columns=rename_map)

    # Fit Scaler (Train Only)
    scaler = StandardScaler()
    train_pd[numerical_cols] = scaler.fit_transform(train_pd[numerical_cols])
    val_pd[numerical_cols]   = scaler.transform(val_pd[numerical_cols])

    # Calculate Embedding Dimensions 
    cat_dims = [int(train_pd[c].max() + 2) for c in categorical_cols]
    
    emb_dims = [min(64, int(n**0.3)) for n in cat_dims]

    return train_pd, val_pd, cat_dims, emb_dims
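The indexing and sizing rules in `prepare_fold_data` can be checked without Spark. The sketch below is a hypothetical pure-Python stand-in (`fit_alphabet_indexer`, `transform`, and `embedding_sizes` are illustrative names, not Spark APIs). It assumes `StringIndexer(stringOrderType="alphabetAsc", handleInvalid="keep")` assigns indices 0..n-1 to the sorted training labels and sends unseen labels to index n. Under that assumption the largest training index is n-1, so `max + 2` rows cover every training label plus the UNK bucket.

```python
# Hypothetical pure-Python stand-in for StringIndexer(stringOrderType="alphabetAsc",
# handleInvalid="keep"); it mimics only the index layout, not Spark's implementation.

def fit_alphabet_indexer(train_values):
    """Map each distinct training label to 0..n-1 in ascending alphabetical order."""
    labels = sorted(set(train_values))
    return {label: i for i, label in enumerate(labels)}

def transform(mapping, values):
    """Labels unseen during fitting fall into one extra bucket at index n = len(mapping)."""
    unk = len(mapping)
    return [mapping.get(v, unk) for v in values]

def embedding_sizes(train_indices):
    """Mirror prepare_fold_data: rows = max training index + 2, dim = min(64, int(rows ** 0.3))."""
    n_rows = max(train_indices) + 2   # n labels (indices 0..n-1) plus the UNK bucket at index n
    return n_rows, min(64, int(n_rows ** 0.3))

carriers_train = ["WN", "AA", "DL", "AA", "UA"]
carriers_val = ["DL", "B6", "WN"]          # "B6" never appears in training

mapping = fit_alphabet_indexer(carriers_train)
train_idx = transform(mapping, carriers_train)
val_idx = transform(mapping, carriers_val)
n_rows, dim = embedding_sizes(train_idx)

print(mapping)      # {'AA': 0, 'DL': 1, 'UA': 2, 'WN': 3}
print(val_idx)      # [1, 4, 3]
print(n_rows, dim)  # 5 1
```

With the `n ** 0.3` rule, embeddings stay small: a 5-row table gets a 1-dimensional embedding, a 400-row table gets 6, and the 64 cap only applies once a column has roughly a million distinct values.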

## Model Definitions

In [0]:
# ---------------------------------------------------------
# Model Definitions
# ---------------------------------------------------------
class FlightDataset(Dataset):
    def __init__(self, df):
        # Data is already cast to correct types in prepare_fold_data
        self.cat = torch.tensor(df[categorical_cols].values, dtype=torch.long)
        self.num = torch.tensor(df[numerical_cols].values, dtype=torch.float32)
        self.time = torch.tensor(df[time_col].values, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(df[target_col].values, dtype=torch.float32).unsqueeze(1)

    def __len__(self): return len(self.y)
    def __getitem__(self, idx): return self.cat[idx], self.num[idx], self.time[idx], self.y[idx]

class ResBlock(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        h = F.gelu(self.fc1(self.ln(x)))
        return x + self.fc2(self.dropout(h))

class Time2Vec(nn.Module):
    def __init__(self, k):
        super().__init__()
        self.wb = nn.Linear(1, 1)
        self.ws = nn.Linear(1, k)
    def forward(self, t):
        return torch.cat([self.wb(t), torch.sin(self.ws(t))], dim=-1)

class ResFiLMMLP(nn.Module):
    def __init__(self, cat_dims, emb_dims, num_numerical, time_dim=8, 
                 emb_dropout=0.1, num_dropout=0.1, film_dropout=0.1, final_dropout=0.2):
        super().__init__()
        
        # Embedding Tower
        self.embeddings = nn.ModuleList([nn.Embedding(d, e) for d, e in zip(cat_dims, emb_dims)])
        self.emb_dropout = nn.Dropout(emb_dropout)
        emb_total = sum(emb_dims)
        
        # Numeric Tower
        self.fc_num = nn.Linear(num_numerical, 256)
        self.res_blocks = nn.ModuleList([ResBlock(256, num_dropout) for _ in range(4)])
        
        # FiLM & Time
        self.film = nn.Linear(256, 2 * emb_total)
        self.film_dropout = nn.Dropout(film_dropout)
        self.t2v = Time2Vec(time_dim)
        
        # Heads
        fused_dim = 256 + emb_total + (time_dim + 1) + 1
        self.reg_head = nn.Sequential(nn.Linear(fused_dim, 256), nn.GELU(), nn.Dropout(final_dropout),
                                      nn.Linear(256, 128), nn.GELU(), nn.Dropout(final_dropout), nn.Linear(128, 1))
        self.clf_head = nn.Sequential(nn.Linear(fused_dim, 256), nn.GELU(), nn.Dropout(final_dropout),
                                      nn.Linear(256, 128), nn.GELU(), nn.Dropout(final_dropout), nn.Linear(128, 1))

    def forward(self, x_cat, x_num, x_time):
        emb = self.emb_dropout(torch.cat([emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)], dim=-1))
        
        h = F.gelu(self.fc_num(x_num))
        for block in self.res_blocks: h = block(h)
        
        gamma, beta = torch.chunk(self.film(h), 2, dim=-1)
        emb_mod = self.film_dropout(gamma) * emb + self.film_dropout(beta)
        
        z = torch.cat([emb_mod, h, self.t2v(x_time), x_time], dim=-1)
        return self.reg_head(z), self.clf_head(z)
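For a quick shape check of the fusion step in `ResFiLMMLP.forward`, here is a NumPy sketch of Time2Vec, FiLM modulation, and the concatenation that feeds both heads. It uses random weights and hypothetical embedding sizes `[2, 3, 3]`, so it is not the trained model; it only confirms that `fused_dim = 256 + emb_total + (time_dim + 1) + 1` matches the concatenated width.

```python
import numpy as np

rng = np.random.default_rng(0)

def time2vec(t, w0, b0, w, b):
    """Time2Vec as in the Time2Vec module: one linear term plus k sine terms -> (batch, k + 1)."""
    linear = t * w0 + b0           # (batch, 1), plays the role of nn.Linear(1, 1)
    periodic = np.sin(t * w + b)   # (batch, k), plays the role of sin(nn.Linear(1, k))
    return np.concatenate([linear, periodic], axis=-1)

def film(emb, h, W, c):
    """FiLM: project the numeric-tower state h to (gamma, beta) and modulate emb elementwise."""
    gamma, beta = np.split(h @ W + c, 2, axis=-1)   # each (batch, emb_total)
    return gamma * emb + beta

batch, hidden, time_dim = 4, 256, 16
emb_dims = [2, 3, 3]                         # hypothetical per-column embedding sizes
emb_total = sum(emb_dims)

t = rng.uniform(0, 1440, size=(batch, 1))    # hypothetical departure-time values
emb = rng.normal(size=(batch, emb_total))    # stand-in for the concatenated embeddings
h = rng.normal(size=(batch, hidden))         # stand-in for the residual-tower output

tv = time2vec(t, rng.normal(), rng.normal(),
              rng.normal(size=(1, time_dim)), rng.normal(size=(1, time_dim)))
emb_mod = film(emb, h, rng.normal(size=(hidden, 2 * emb_total)) * 0.01, np.zeros(2 * emb_total))

# Same order as the model: [emb_mod, h, t2v(x_time), x_time]
z = np.concatenate([emb_mod, h, tv, t], axis=-1)
fused_dim = hidden + emb_total + (time_dim + 1) + 1
print(z.shape, fused_dim)   # (4, 282) 282
```

Note that `time_col` is not part of `numerical_cols`, so the `StandardScaler` never touches it: `CRS_DEP_MINUTES` enters both Time2Vec and the fused vector on its raw scale.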

## Hyperparameter Tuning

In [0]:
def train_one_epoch(model, loader, opt, crit_reg, crit_clf, alpha):
    model.train()
    for cat, num, time, y in loader:
        cat, num, time, y = cat.to(DEVICE), num.to(DEVICE), time.to(DEVICE), y.to(DEVICE)
        opt.zero_grad()
        reg, clf = model(cat, num, time)
        loss = alpha * crit_reg(reg, y) + (1-alpha) * crit_clf(clf, (y >= 15.0).float())
        loss.backward()
        opt.step()
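The loss in `train_one_epoch` mixes an L1 term on delay minutes with a weighted BCE on the `delay >= 15` label. The NumPy sketch below reproduces that combination for illustration, using the weighted-BCE formula given in the PyTorch `BCEWithLogitsLoss` documentation; `bce_with_logits` and `multitask_loss` are hypothetical helper names.

```python
import numpy as np

def bce_with_logits(logits, targets, pos_weight=4.0):
    """Mean BCE on raw logits with a positive-class weight:
    -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    log_sig = -np.logaddexp(0.0, -logits)        # log(sigmoid(x)), numerically stable
    log_one_minus = -np.logaddexp(0.0, logits)   # log(1 - sigmoid(x))
    return float(np.mean(-(pos_weight * targets * log_sig + (1 - targets) * log_one_minus)))

def multitask_loss(reg_pred, clf_logit, delay_minutes, alpha, pos_weight=4.0, threshold=15.0):
    """alpha * L1(regression) + (1 - alpha) * weighted BCE(classification on delay >= threshold)."""
    l1 = float(np.mean(np.abs(reg_pred - delay_minutes)))
    labels = (delay_minutes >= threshold).astype(float)
    return alpha * l1 + (1 - alpha) * bce_with_logits(clf_logit, labels, pos_weight)

y = np.array([0.0, 30.0])       # one on-time flight, one delayed 30 minutes
reg = np.array([2.0, 20.0])     # regression head outputs (L1 = (2 + 10) / 2 = 6)
logits = np.array([0.0, 0.0])   # undecided classifier: sigmoid(0) = 0.5

print(round(multitask_loss(reg, logits, y, alpha=0.342), 4))   # 3.1922
```

Because the L1 term is measured in minutes and the BCE term in nats, `alpha` balances two terms on different scales; with the tuned `alpha = 0.342`, the classification loss carries about two thirds of the nominal weight.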

In [0]:
def evaluate_metrics(model, loader):
    model.eval()
    preds_reg, preds_clf = [], []
    targets_reg, targets_clf = [], []
    with torch.no_grad():
        for cat, num, time, y in loader:
            reg, clf = model(cat.to(DEVICE), num.to(DEVICE), time.to(DEVICE))
            preds_reg.append(reg.cpu()); targets_reg.append(y.cpu())
            preds_clf.append(torch.sigmoid(clf).cpu()); targets_clf.append((y >= 15.0).float().cpu())
    
    y_pred_reg, y_true_reg = torch.cat(preds_reg).numpy(), torch.cat(targets_reg).numpy()
    y_pred_clf, y_true_clf = torch.cat(preds_clf).numpy(), torch.cat(targets_clf).numpy()
    
    mae = mean_absolute_error(y_true_reg, y_pred_reg)
    f2 = fbeta_score(y_true_clf, (y_pred_clf > 0.5).astype(int), beta=2)
    return f2, mae
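`evaluate_metrics` thresholds the sigmoid output at 0.5 and scores it with `fbeta_score(beta=2)`; the final training loop rebuilds the same metric from running TP/FP/FN counts. A minimal pure-Python sketch of F-beta from confusion counts (hypothetical helper names) shows why beta = 2 rewards recall over precision:

```python
def confusion_counts(y_true, y_pred):
    """Count true positives, false positives, and false negatives for binary labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp, fp, fn

def fbeta_from_counts(tp, fp, fn, beta=2.0):
    """F-beta = (1 + b^2) * P * R / (b^2 * P + R); returns 0.0 where it is undefined."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = beta**2 * precision + recall
    return (1 + beta**2) * precision * recall / denom if denom else 0.0

# Same F1 = 2/3 for both, but F2 favours the high-recall classifier.
print(round(fbeta_from_counts(4, 4, 0), 3))   # P=0.5, R=1.0 -> 0.833
print(round(fbeta_from_counts(2, 0, 2), 3))   # P=1.0, R=0.5 -> 0.556
```

The training loop uses a `1e-8` epsilon instead of the explicit zero checks above; the two agree whenever the counts are non-zero.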

In [0]:
def objective(trial):
    params = {
        "lr": trial.suggest_float("lr", 1e-4, 5e-3, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [1024, 2048, 4096]),
        "alpha": trial.suggest_float("alpha", 0.3, 0.7),
        "time_dim": trial.suggest_categorical("time_dim", [4, 8, 16]),
        "emb_drop": trial.suggest_float("emb_drop", 0.0, 0.4),
        "num_drop": trial.suggest_float("num_drop", 0.0, 0.4),
        "final_drop": trial.suggest_float("final_drop", 0.0, 0.4)
    }
    
    tuning_folds = [folds[0], folds[len(folds)//2], folds[-1]]
    val_f2_list, val_mae_list = [], []
    train_f2_list, train_mae_list = [], [] 
    
    print(f"\n[Trial {trial.number}] START | Batch={params['batch_size']}, LR={params['lr']:.5f}")
    sys.stdout.flush() 

    for i, fold in enumerate(tuning_folds):
        print(f"  > Fold {fold} ({i+1}/3)...", end=" ")
        sys.stdout.flush()
        start_t = time.time()
        
        train_pd, val_pd, cat_dims, emb_dims = prepare_fold_data(
            cv_full_df.filter(sf.col("fold_id") == fold), 
            categorical_cols, numerical_cols, time_col, target_col
        )
        
        # num_workers=0 to prevent multiprocessing crash
        train_dl = DataLoader(FlightDataset(train_pd), batch_size=params["batch_size"], shuffle=True, num_workers=0, pin_memory=True)
        val_dl = DataLoader(FlightDataset(val_pd), batch_size=params["batch_size"], num_workers=0, pin_memory=True)
        
        model = ResFiLMMLP(cat_dims, emb_dims, len(numerical_cols), time_dim=params["time_dim"],
                           emb_dropout=params["emb_drop"], num_dropout=params["num_drop"], 
                           final_dropout=params["final_drop"]).to(DEVICE)
        opt = optim.AdamW(model.parameters(), lr=params["lr"])
        crit_reg = nn.L1Loss()
        crit_clf = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4.0]).to(DEVICE))
        
        for _ in range(3):
            train_one_epoch(model, train_dl, opt, crit_reg, crit_clf, params["alpha"])
        
        v_f2, v_mae = evaluate_metrics(model, val_dl)
        val_f2_list.append(v_f2); val_mae_list.append(v_mae)
        
        t_f2, t_mae = evaluate_metrics(model, train_dl)
        train_f2_list.append(t_f2); train_mae_list.append(t_mae)
        
        elapsed = (time.time() - start_t) / 60
        print(f"Done ({elapsed:.1f}m). Val F2={v_f2:.3f} | Train F2={t_f2:.3f}")
        sys.stdout.flush()
        
        trial.report(np.mean(val_f2_list), i)
        if trial.should_prune(): 
            print(f"  > Pruned at Fold {fold}.")
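            # The parent tuning-session run is active here, so key the tag by trial number
            mlflow.set_tag(f"trial_{trial.number}_status", "pruned")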
            raise optuna.exceptions.TrialPruned()
        
    with mlflow.start_run(nested=True, run_name=f"Trial_{trial.number}"):
        mlflow.log_params(params)
        mlflow.log_metric("val_f2", np.mean(val_f2_list))
        mlflow.log_metric("val_mae", np.mean(val_mae_list))
        mlflow.log_metric("train_f2", np.mean(train_f2_list)) 
        mlflow.log_metric("train_mae", np.mean(train_mae_list)) 
    
    return np.mean(val_f2_list)

In [0]:
# ---------------------------------------------------------
# Hyperparameter Tuning (Single Worker)
# ---------------------------------------------------------
import optuna
import sys
import time

CV_DATA_PATH = "dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday_nnfeat/cv_splits"
cv_full_df = spark.read.parquet(CV_DATA_PATH)
folds = sorted([row['fold_id'] for row in cv_full_df.select("fold_id").distinct().collect()])

EXPERIMENT_NAME = "/Shared/team_2_2/mlflow-nn-tower-tuned"
mlflow.set_experiment(EXPERIMENT_NAME)

print(f"--- Starting Optuna Tuning (8 Trials, Stable Mode) ---")
sys.stdout.flush()

with mlflow.start_run(run_name="Hyperparameter_Tuning_Session"):
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=8)

BEST_PARAMS = study.best_params
print(f"\n>>> Best Params: {BEST_PARAMS}")

--- Starting Optuna Tuning (8 Trials, Stable Mode) ---


[I 2025-12-11 07:39:08,784] A new study created in memory with name: no-name-b9410282-c4aa-4b84-8474-d6581c428c56



[Trial 0] START | Batch=2048, LR=0.00096
  > Fold 1 (1/3)... Done (23.5m). Val F2=0.569 | Train F2=0.590
  > Fold 6 (2/3)... Done (22.7m). Val F2=0.536 | Train F2=0.575
  > Fold 10 (3/3)... Done (24.9m). Val F2=0.572 | Train F2=0.582


[I 2025-12-11 08:50:18,478] Trial 0 finished with value: 0.5588664642492351 and parameters: {'lr': 0.0009575058150175052, 'batch_size': 2048, 'alpha': 0.40065021283464114, 'time_dim': 4, 'emb_drop': 0.16404933909407846, 'num_drop': 0.16736632485412783, 'final_drop': 0.21101424331333254}. Best is trial 0 with value: 0.5588664642492351.



[Trial 1] START | Batch=1024, LR=0.00038
  > Fold 1 (1/3)... Done (22.6m). Val F2=0.560 | Train F2=0.584
  > Fold 6 (2/3)... Done (24.1m). Val F2=0.555 | Train F2=0.598
  > Fold 10 (3/3)... Done (25.3m). Val F2=0.572 | Train F2=0.579


[I 2025-12-11 10:02:19,849] Trial 1 finished with value: 0.5625465733551208 and parameters: {'lr': 0.00038191325052377973, 'batch_size': 1024, 'alpha': 0.5817744495942692, 'time_dim': 4, 'emb_drop': 0.2973993187515431, 'num_drop': 0.01820144637246144, 'final_drop': 0.05244952735329807}. Best is trial 1 with value: 0.5625465733551208.



[Trial 2] START | Batch=4096, LR=0.00061
  > Fold 1 (1/3)... Done (23.8m). Val F2=0.577 | Train F2=0.597
  > Fold 6 (2/3)... Done (22.7m). Val F2=0.539 | Train F2=0.576
  > Fold 10 (3/3)... Done (26.1m). Val F2=0.576 | Train F2=0.581


[I 2025-12-11 11:14:55,948] Trial 2 finished with value: 0.5638415356441168 and parameters: {'lr': 0.0006079442355415744, 'batch_size': 4096, 'alpha': 0.6171095826693376, 'time_dim': 16, 'emb_drop': 0.31252774208004025, 'num_drop': 0.08086622245692543, 'final_drop': 0.3506675733290697}. Best is trial 2 with value: 0.5638415356441168.



[Trial 3] START | Batch=4096, LR=0.00016
  > Fold 1 (1/3)... Done (24.7m). Val F2=0.565 | Train F2=0.582
  > Fold 6 (2/3)... Done (24.5m). Val F2=0.544 | Train F2=0.584
  > Fold 10 (3/3)... Done (26.5m). Val F2=0.597 | Train F2=0.600


[I 2025-12-11 12:30:36,594] Trial 3 finished with value: 0.5687250006294451 and parameters: {'lr': 0.00015558999196709387, 'batch_size': 4096, 'alpha': 0.34214834331587723, 'time_dim': 16, 'emb_drop': 0.046257475013397054, 'num_drop': 0.32431529675298293, 'final_drop': 0.10010488301186182}. Best is trial 3 with value: 0.5687250006294451.



[Trial 4] START | Batch=4096, LR=0.00013
  > Fold 1 (1/3)... Done (24.9m). Val F2=0.565 | Train F2=0.586
  > Fold 6 (2/3)... Done (24.6m). Val F2=0.548 | Train F2=0.587
  > Fold 10 (3/3)... Done (26.9m). Val F2=0.577 | Train F2=0.583


[I 2025-12-11 13:47:01,732] Trial 4 finished with value: 0.5634774039139566 and parameters: {'lr': 0.00012577372010664665, 'batch_size': 4096, 'alpha': 0.33059836514431606, 'time_dim': 16, 'emb_drop': 0.3265276151501351, 'num_drop': 0.37956164780145407, 'final_drop': 0.2627991690022928}. Best is trial 3 with value: 0.5687250006294451.



[Trial 5] START | Batch=4096, LR=0.00186
  > Fold 1 (1/3)... Done (24.7m). Val F2=0.561 | Train F2=0.580
  > Pruned at Fold 1.


[I 2025-12-11 14:11:45,191] Trial 5 pruned. 



[Trial 6] START | Batch=2048, LR=0.00042
  > Fold 1 (1/3)... Done (25.0m). Val F2=0.577 | Train F2=0.596
  > Fold 6 (2/3)... Done (28.7m). Val F2=0.543 | Train F2=0.584
  > Fold 10 (3/3)... Done (28.5m). Val F2=0.571 | Train F2=0.580


[I 2025-12-11 15:33:51,876] Trial 6 finished with value: 0.5639208265448387 and parameters: {'lr': 0.0004174074924664462, 'batch_size': 2048, 'alpha': 0.6374870217158395, 'time_dim': 4, 'emb_drop': 0.17756553684120396, 'num_drop': 0.33457012601060093, 'final_drop': 0.091024878754628}. Best is trial 3 with value: 0.5687250006294451.



[Trial 7] START | Batch=2048, LR=0.00130
  > Fold 1 (1/3)... Done (25.2m). Val F2=0.555 | Train F2=0.577


[I 2025-12-11 15:59:01,277] Trial 7 pruned. 


  > Pruned at Fold 1.

>>> Best Params: {'lr': 0.00015558999196709387, 'batch_size': 4096, 'alpha': 0.34214834331587723, 'time_dim': 16, 'emb_drop': 0.046257475013397054, 'num_drop': 0.32431529675298293, 'final_drop': 0.10010488301186182}
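The log shows trials 5 and 7 pruned after their first fold. `optuna.create_study` uses `MedianPruner` when no pruner is given, and the objective reports the running mean of validation F2 after each fold. The sketch below is a simplified re-implementation of that median rule as we understand it (5 startup trials, no warm-up steps, the trial's best intermediate value compared with the median of completed trials at the same step); it is not Optuna's code, and the fold scores are the rounded values printed above.

```python
from statistics import median

def running_means(fold_scores):
    """Intermediate values as reported by the objective: mean F2 over the folds seen so far."""
    return [sum(fold_scores[:i + 1]) / (i + 1) for i in range(len(fold_scores))]

def should_prune(trial_history, completed_histories, n_startup_trials=5):
    """Simplified median rule for a maximised metric (illustrative, not Optuna's implementation)."""
    if len(completed_histories) < n_startup_trials:
        return False                                  # still in the startup phase
    step = len(trial_history) - 1
    peers = [h[step] for h in completed_histories if len(h) > step]
    if not peers:
        return False
    return max(trial_history) < median(peers)         # best-so-far below the peers' median

# Rounded per-fold validation F2 (folds 1, 6, 10) of the completed trials 0-4
completed = [running_means(s) for s in [
    [0.569, 0.536, 0.572],
    [0.560, 0.555, 0.572],
    [0.577, 0.539, 0.576],
    [0.565, 0.544, 0.597],
    [0.565, 0.548, 0.577],
]]

print(should_prune(running_means([0.561]), completed))          # trial 5 after fold 1 -> True
print(should_prune(running_means([0.577, 0.543]), completed))   # trial 6 after fold 6 -> False
completed.append(running_means([0.577, 0.543, 0.571]))          # trial 6 completes
print(should_prune(running_means([0.555]), completed))          # trial 7 after fold 1 -> True
```

Replaying the rounded scores reproduces the log: trial 5 (0.561) and trial 7 (0.555) fall below the step-0 median, while trial 6 (0.577) continues.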


## Final Training on Full Dataset + Early Stopping

In [0]:
# ---------------------------------------------------------
# Final Training Run (FULL DATASET + EARLY STOPPING)
# ---------------------------------------------------------
import torch
import mlflow
import pandas as pd
import numpy as np
import sys
import time
import copy
import os
import gc
import uuid
import pyspark.sql.functions as sf 
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from pyspark.ml.feature import StringIndexer
from pyspark.sql.types import IntegerType, FloatType
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import fbeta_score
from mlflow.models import infer_signature

# MODEL ARCHITECTURE
class ResBlock(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        h = F.gelu(self.fc1(self.ln(x)))
        return x + self.fc2(self.dropout(h))

class Time2Vec(nn.Module):
    def __init__(self, k):
        super().__init__()
        self.wb = nn.Linear(1, 1)
        self.ws = nn.Linear(1, k)
    def forward(self, t):
        return torch.cat([self.wb(t), torch.sin(self.ws(t))], dim=-1)

class ResFiLMMLP(nn.Module):
    def __init__(self, cat_dims, emb_dims, num_numerical, time_dim=8, 
                 emb_dropout=0.1, num_dropout=0.1, film_dropout=0.1, final_dropout=0.2):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(d, e) for d, e in zip(cat_dims, emb_dims)])
        self.emb_dropout = nn.Dropout(emb_dropout)
        emb_total = sum(emb_dims)
        self.fc_num = nn.Linear(num_numerical, 256)
        self.res_blocks = nn.ModuleList([ResBlock(256, num_dropout) for _ in range(4)])
        self.film = nn.Linear(256, 2 * emb_total)
        self.film_dropout = nn.Dropout(film_dropout)
        self.t2v = Time2Vec(time_dim)
        fused_dim = 256 + emb_total + (time_dim + 1) + 1
        self.reg_head = nn.Sequential(nn.Linear(fused_dim, 256), nn.GELU(), nn.Dropout(final_dropout),
                                      nn.Linear(256, 128), nn.GELU(), nn.Dropout(final_dropout), nn.Linear(128, 1))
        self.clf_head = nn.Sequential(nn.Linear(fused_dim, 256), nn.GELU(), nn.Dropout(final_dropout),
                                      nn.Linear(256, 128), nn.GELU(), nn.Dropout(final_dropout), nn.Linear(128, 1))

    def forward(self, x_cat, x_num, x_time):
        emb = self.emb_dropout(torch.cat([emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)], dim=-1))
        h = F.gelu(self.fc_num(x_num))
        for block in self.res_blocks: h = block(h)
        gamma, beta = torch.chunk(self.film(h), 2, dim=-1)
        emb_mod = self.film_dropout(gamma) * emb + self.film_dropout(beta)
        z = torch.cat([emb_mod, h, self.t2v(x_time), x_time], dim=-1)
        return self.reg_head(z), self.clf_head(z)

# DATASET & HELPERS
class ProductionFlightDataset(Dataset):
    def __init__(self, df, cat_cols, num_cols, time_col, target_col, id_col="flight_uid"):
        self.cat = torch.tensor(df[cat_cols].values, dtype=torch.long)
        self.num = torch.tensor(df[num_cols].values, dtype=torch.float32)
        self.time = torch.tensor(df[time_col].values, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(df[target_col].values, dtype=torch.float32).unsqueeze(1)
        self.ids = df[id_col].values 
    def __len__(self): return len(self.y)
    def __getitem__(self, idx): return self.cat[idx], self.num[idx], self.time[idx], self.y[idx], self.ids[idx]

def evaluate_metrics(model, loader):
    model.eval()
    all_y, all_reg, all_clf = [], [], []
    with torch.no_grad():
        for cat, num, time_t, y, _ in loader:
            reg, clf = model(cat, num, time_t)
            all_y.append(y); all_reg.append(reg); all_clf.append(clf)
    
    y_true = torch.cat(all_y).cpu().numpy().flatten()
    y_reg = torch.cat(all_reg).cpu().numpy().flatten()
    y_prob = torch.sigmoid(torch.cat(all_clf)).cpu().numpy().flatten()
    
    mae = np.mean(np.abs(y_true - y_reg))
    f2 = fbeta_score((y_true >= 15.0).astype(int), (y_prob > 0.5).astype(int), beta=2, zero_division=0)
    return f2, mae

def save_predictions_and_link(model, loader, fold_name, save_path_base):
    print(f"  > Generating predictions for {fold_name}...")
    model.eval()
    res = {"flight_uid": [], "target_delay": [], "pred_delay": [], "target_class": [], "pred_prob": []}
    with torch.no_grad():
        for cat, num, time_t, y, ids in loader:
            reg, clf = model(cat, num, time_t)
            res["flight_uid"].append(ids)
            res["target_delay"].append(y.numpy().flatten())
            res["pred_delay"].append(reg.numpy().flatten())
            res["target_class"].append((y >= 15.0).float().numpy().flatten())
            res["pred_prob"].append(torch.sigmoid(clf).numpy().flatten())
    
    flat_ids = np.concatenate(res["flight_uid"]) if len(res["flight_uid"]) > 0 else []
    pdf = pd.DataFrame({
        "flight_uid": flat_ids, "target_delay": np.concatenate(res["target_delay"]),
        "pred_delay": np.concatenate(res["pred_delay"]), "target_class": np.concatenate(res["target_class"]),
        "pred_prob": np.concatenate(res["pred_prob"]), "split_name": fold_name
    })
    
    unique_id = str(uuid.uuid4())[:8]
    # MLflow Upload
    temp_file = f"/tmp/{fold_name}_{unique_id}.parquet"
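    try:
        pdf.to_parquet(temp_file)
        mlflow.log_artifact(temp_file, artifact_path="predictions")
    except Exception as e:  # don't abort the DBFS save below if the MLflow upload fails
        print(f"    >> MLflow artifact upload skipped: {e}")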
    
    # DBFS Save
    save_path = f"{save_path_base}/{fold_name}_{unique_id}"
    print(f"    >> Saving to DBFS: {save_path}")
    spark.createDataFrame(pdf).write.mode("overwrite").parquet(save_path)
    
    # CSV Link
    csv_name = f"{fold_name}_{unique_id}.csv"
    local_csv = f"/tmp/{csv_name}"
    pdf.to_csv(local_csv, index=False)
    dbutils.fs.cp(f"file:{local_csv}", f"dbfs:/FileStore/shared_uploads/predictions/{csv_name}")
    displayHTML(f"""<div style="background-color:#e6f7ff;padding:10px;border:1px solid #91d5ff;">
    <b>{fold_name} Predictions:</b> <a href="/files/shared_uploads/predictions/{csv_name}" target="_blank">Download CSV</a></div>""")
    
    if os.path.exists(temp_file): os.remove(temp_file)
    if os.path.exists(local_csv): os.remove(local_csv)

# CONFIG & DATA 
BASE_PATH = "dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/"
TRAIN_PATH = BASE_PATH + "train.parquet/"
VAL_PATH   = BASE_PATH + "val.parquet/"
PREDS_SAVE_PATH = "dbfs:/student-groups/Group_2_2/5_year_custom_joined/nn_predictions_final"

params = {'lr': 0.0001556, 'batch_size': 4096, 'alpha': 0.342, 'time_dim': 16, 
          'emb_drop': 0.046, 'num_drop': 0.324, 'final_drop': 0.100}
NUM_EPOCHS = 10  # Increased since we have early stopping
PATIENCE = 4

print("--- !!! FULL PRODUCTION RUN !!! ---")
# Load FULL Dataset
train_spark = spark.read.parquet(TRAIN_PATH)
val_spark = spark.read.parquet(VAL_PATH)

# Leakage Check
if "xgb_predicted_delay" in train_spark.columns: train_spark = train_spark.drop("xgb_predicted_delay")
if "xgb_predicted_delay" in val_spark.columns: val_spark = val_spark.drop("xgb_predicted_delay")

# Indexing
print("  > Indexing...")
indexers = [StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep").fit(train_spark) for c in categorical_cols]

def spark_to_pandas_safe(df, indexers):
    for ind in indexers: df = ind.transform(df)
    select_expr = [sf.col(f"{c}_idx").cast(IntegerType()).alias(c) for c in categorical_cols] + \
                  [sf.col(c).cast(FloatType()) for c in numerical_cols] + \
                  [sf.col(time_col).cast(FloatType()), sf.col(target_col).cast(FloatType()), sf.col("flight_uid")]
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    return df.select(*select_expr).toPandas()

train_pd = spark_to_pandas_safe(train_spark, indexers)
val_pd = spark_to_pandas_safe(val_spark, indexers)

# Scaling
scaler = StandardScaler()
train_pd[numerical_cols] = scaler.fit_transform(train_pd[numerical_cols])
val_pd[numerical_cols] = scaler.transform(val_pd[numerical_cols])

# Dims
cat_dims = [int(max(train_pd[c].max(), val_pd[c].max()) + 2) for c in categorical_cols]
emb_dims = [min(64, int(n**0.3)) for n in cat_dims]

train_ds = ProductionFlightDataset(train_pd, categorical_cols, numerical_cols, time_col, target_col)
val_ds = ProductionFlightDataset(val_pd, categorical_cols, numerical_cols, time_col, target_col)
print(f"  > Train: {len(train_ds)} | Val: {len(val_ds)}")
del train_pd, val_pd, train_spark, val_spark; gc.collect()

# TRAINING LOOP
torch.set_num_threads(60)
DEVICE = torch.device("cpu")
train_dl = DataLoader(train_ds, batch_size=params["batch_size"], shuffle=True, num_workers=0)
val_dl = DataLoader(val_ds, batch_size=params["batch_size"], num_workers=0)

mlflow.set_experiment("/Shared/team_2_2/mlflow-nn-tower-final")
print(f"\n--- Starting Full Training ({NUM_EPOCHS} Epochs, Patience={PATIENCE}) ---")

with mlflow.start_run(run_name="Final_Prod_Run_Full_Data"):
    mlflow.log_params(params)
    model = ResFiLMMLP(cat_dims, emb_dims, len(numerical_cols), time_dim=params["time_dim"],
                       emb_dropout=params["emb_drop"], num_dropout=params["num_drop"], 
                       final_dropout=params["final_drop"]).to(DEVICE)
    opt = optim.AdamW(model.parameters(), lr=params["lr"])
    crit_reg = nn.L1Loss()
    crit_clf = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4.0]))

    best_f2 = -1.0; best_state = None
    patience_cnt = 0 # Early Stopping Counter
    
    for epoch in range(NUM_EPOCHS):
        model.train()
        t_loss_sum, t_ae_sum, t_tp, t_fp, t_fn = 0.0, 0.0, 0, 0, 0
        batch_cnt = 0; total_samples = 0
        
        for i, (cat, num, t, y, _) in enumerate(train_dl):
            opt.zero_grad()
            reg, clf = model(cat, num, t)
            loss = params["alpha"]*crit_reg(reg, y) + (1-params["alpha"])*crit_clf(clf, (y>=15.0).float())
            loss.backward(); opt.step()
            
            with torch.no_grad():
                batch_size = y.size(0)
                t_loss_sum += loss.item(); t_ae_sum += torch.sum(torch.abs(reg - y)).item()
                y_true = (y >= 15.0).long(); y_pred = (torch.sigmoid(clf) > 0.5).long()
                t_tp += ((y_true == 1) & (y_pred == 1)).sum().item()
                t_fp += ((y_true == 0) & (y_pred == 1)).sum().item()
                t_fn += ((y_true == 1) & (y_pred == 0)).sum().item()
                batch_cnt += 1; total_samples += batch_size

            if i % 500 == 0 and i > 0: print(f"    Ep {epoch} Batch {i}/{len(train_dl)}...", end="\r")
        
        train_loss = t_loss_sum / batch_cnt
        train_mae = t_ae_sum / total_samples
        precision = t_tp / (t_tp + t_fp + 1e-8); recall = t_tp / (t_tp + t_fn + 1e-8); beta = 2
        train_f2 = (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall + 1e-8)
        
        val_f2, val_mae = evaluate_metrics(model, val_dl)
        print(f"    Ep {epoch}: Val F2={val_f2:.4f} MAE={val_mae:.2f} | Train F2={train_f2:.4f} MAE={train_mae:.2f}")
        
        mlflow.log_metrics({"train_loss": train_loss, "train_mae": train_mae, "train_f2": train_f2, "val_mae": val_mae, "val_f2": val_f2}, step=epoch)
        
        # --- EARLY STOPPING CHECK ---
        if val_f2 > best_f2:
            best_f2 = val_f2
            best_state = copy.deepcopy(model.state_dict())
            patience_cnt = 0 # Reset counter
        else:
            patience_cnt += 1
            print(f"       >> No improve. Patience {patience_cnt}/{PATIENCE}")
            if patience_cnt >= PATIENCE:
                print("       >> Early stopping triggered.")
                break
            
    print("\n--- Finalizing Best Model ---")
    model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        c, n, t, _, _ = next(iter(train_dl))
        r, cl = model(c, n, t)
        sig = infer_signature({"cat": c.numpy(), "num": n.numpy(), "time": t.numpy()}, {"pred_delay": r.numpy(), "pred_class": cl.numpy()})
    mlflow.pytorch.log_model(model, "model_final", signature=sig)
    
    save_predictions_and_link(model, val_dl, "final_val_preds", PREDS_SAVE_PATH)
    print("SUCCESS. Run complete.")

--- !!! FULL PRODUCTION RUN !!! ---
  > Indexing...
  > Train: 18066230 | Val: 4384854

--- Starting Full Training (10 Epochs, Patience=4) ---
    Ep 0: Val F2=0.5935 MAE=10.97 | Train F2=0.5641 MAE=9.94
    Ep 1: Val F2=0.5876 MAE=10.90 | Train F2=0.5803 MAE=9.72
       >> No improve. Patience 1/4
    Ep 2: Val F2=0.5987 MAE=10.


  > Generating predictions for final_val_preds...
    >> Saving to DBFS: dbfs:/student-groups/Group_2_2/5_year_custom_joined/nn_predictions_final/final_val_preds_4e65b009


SUCCESS. Run complete.
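The training loop above keeps a deep copy of the weights with the best validation F2 and stops once `PATIENCE` consecutive epochs fail to improve on it. Below is a minimal standalone sketch of that same patience pattern. `EarlyStopper` is a hypothetical helper that does not exist in the original notebook. The sketch assumes a higher score is better, and it uses a plain dict as a stand-in for `model.state_dict()`.

```python
import copy


class EarlyStopper:
    """Hypothetical helper: tracks the best score and a snapshot of the best state."""

    def __init__(self, patience):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_state = None
        self.counter = 0

    def step(self, score, state):
        """Record one epoch's score; return True when training should stop."""
        if score > self.best_score:
            self.best_score = score
            # Deep copy so later in-place parameter updates cannot alter the snapshot
            self.best_state = copy.deepcopy(state)
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


# Replay the first three validation F2 values printed in the log above
stopper = EarlyStopper(patience=4)
for epoch, val_f2 in enumerate([0.5935, 0.5876, 0.5987]):
    state = {"epoch": epoch}  # stand-in for model.state_dict()
    if stopper.step(val_f2, state):
        break
print(stopper.best_score, stopper.best_state, stopper.counter)
```

Epoch 1 increments the counter to 1, matching the `Patience 1/4` line in the log. Epoch 2 sets a new best and resets the counter to 0.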


## Optimal Classification Threshold Finder

In [0]:
# ---------------------------------------------------------
# Optimal Classification Threshold Finder
# ---------------------------------------------------------
import os
import time
import uuid

import mlflow
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import fbeta_score

print("--- Starting Optimal Threshold Optimization ---")

# --- 1. Generate Full Validation Probabilities and True Labels ---
def get_validation_probabilities(model, loader):
    """Generates a single array of true labels and probabilities for the entire validation set."""
    model.eval()
    all_y_true = []
    all_y_prob = []
    
    # Disable gradient tracking as this is just inference
    with torch.no_grad():
        for cat, num, time_t, y, _ in loader:
            # Note: We are only interested in the classification head (clf)
            _, clf = model(cat, num, time_t)
            
            # Convert delay minutes (y) to binary class (0 or 1)
            y_true_class = (y >= 15.0).long().cpu().numpy().flatten()
            
            # Convert logits (clf) to probabilities
            y_prob = torch.sigmoid(clf).cpu().numpy().flatten()
            
            all_y_true.append(y_true_class)
            all_y_prob.append(y_prob)
            
    return np.concatenate(all_y_true), np.concatenate(all_y_prob)

start_time = time.time()
print("  > Generating probabilities for Validation Set...")
y_true_val, y_prob_val = get_validation_probabilities(model, val_dl)
print(f"  > Probabilities generated in {time.time() - start_time:.1f} seconds.")

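# --- 2. Define Threshold Sweep Range ---
# Sweep thresholds from 0.05 up to (but not including) 0.95 in steps of 0.01;
# rounding removes floating-point noise such as 0.36000000000000004
threshold_range = np.round(np.arange(0.05, 0.95, 0.01), 2)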

best_f2_score = -1.0
optimal_threshold = 0.5
f2_scores = []

# --- 3. Sweep Thresholds and Calculate F2 Score ---
print("  > Sweeping thresholds and calculating F2...")
for threshold in threshold_range:
    # Classify based on the current threshold
    y_pred_class = (y_prob_val >= threshold).astype(int)
    
    # Calculate F2 score (beta=2 is what we optimize for)
    current_f2 = fbeta_score(y_true_val, y_pred_class, beta=2, zero_division=0)
    f2_scores.append(current_f2)
    
    if current_f2 > best_f2_score:
        best_f2_score = current_f2
        optimal_threshold = threshold

# --- 4. Log Results to MLflow ---
mlflow.set_experiment("/Shared/team_2_2/mlflow-nn-tower-final")
with mlflow.start_run(run_name="Threshold_Optimization_F2_Score", nested=True):
    
    mlflow.log_param("optimization_metric", "F2 Score (beta=2)")
    mlflow.log_param("sweep_range_start", threshold_range[0])
    mlflow.log_param("sweep_range_end", threshold_range[-1])
    
    mlflow.log_metric("best_f2_score_found", best_f2_score)
    mlflow.log_metric("optimal_threshold_f2", optimal_threshold)
    
    # Log the full curve for visualization
    threshold_log_data = pd.DataFrame({
        "threshold": threshold_range,
        "f2_score": f2_scores
    })
    
    # Create a local file to log as an artifact
    temp_csv_path = f"/tmp/threshold_curve_{str(uuid.uuid4())[:8]}.csv"
    threshold_log_data.to_csv(temp_csv_path, index=False)
    mlflow.log_artifact(temp_csv_path, "threshold_analysis")
    os.remove(temp_csv_path)

    print("\nOptimization Complete:")
    print(f"  > Optimal F2 Score: {best_f2_score:.5f}")
    print(f"  > Optimal Threshold: {optimal_threshold:.3f}")
    print("  > Detailed curve logged to MLflow run.")

--- Starting Optimal Threshold Optimization ---
  > Generating probabilities for Validation Set...
  > Probabilities generated in 107.8 seconds.
  > Sweeping thresholds and calculating F2...

Optimization Complete:
  > Optimal F2 Score: 0.62595
  > Optimal Threshold: 0.360
  > Detailed curve logged to MLflow run.
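The sweep above calls `fbeta_score` once for every threshold, so it rescans all ~4.4M validation rows 90 times. F2 depends only on the confusion counts: F_beta = (1 + beta^2)·TP / ((1 + beta^2)·TP + beta^2·FN + FP), and beta = 2 gives 5·TP / (5·TP + 4·FN + FP). Because of that, every threshold can be scored from one sort of the probabilities. The function below is a minimal sketch of this approach. `f2_sweep` is a hypothetical helper that is not part of the original notebook. It assumes `y_true` holds 0/1 labels and keeps the loop's `y_prob >= threshold` rule.

```python
import numpy as np
from sklearn.metrics import fbeta_score


def f2_sweep(y_true, y_prob, thresholds, beta=2.0):
    """Hypothetical helper: F-beta at every threshold from one sort of the scores.

    A row is predicted positive when y_prob >= threshold.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_prob = np.asarray(y_prob, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)

    pos = np.sort(y_prob[y_true])    # scores of actual positives, ascending
    neg = np.sort(y_prob[~y_true])   # scores of actual negatives, ascending

    # searchsorted(side="left") counts scores < t, so total minus it counts scores >= t
    tp = len(pos) - np.searchsorted(pos, thresholds, side="left")
    fp = len(neg) - np.searchsorted(neg, thresholds, side="left")
    fn = len(pos) - tp

    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    # Score is 0 when the denominator is 0, mirroring zero_division=0
    return np.where(denom > 0, (1 + b2) * tp / np.maximum(denom, 1), 0.0)
```

Call it as `f2_sweep(y_true_val, y_prob_val, threshold_range)`. The best threshold is then `threshold_range[np.argmax(scores)]`. `np.argmax` returns the first maximum, which matches the strict `>` comparison in the loop when scores tie.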


## Run Inference with the Trained Model and Optimized Threshold (Restart-Safe)

In [0]:
import torch
import mlflow
import pandas as pd
import numpy as np
import sys
import time
import copy
import os
import gc
import uuid
import pyspark.sql.functions as sf 
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from pyspark.ml.feature import StringIndexer
from pyspark.sql.types import IntegerType, FloatType
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import fbeta_score, mean_absolute_error
from mlflow.models import infer_signature

In [0]:
# ---------------------------------------------------------
# Rebuild Environment and Run Final Inference
# ---------------------------------------------------------

# config
MLFLOW_RUN_ID = "8706b956e0bd4ed681234979ad86206b"
OPTIMAL_THRESHOLD = 0.36 

# features
categorical_cols = [
    "OP_UNIQUE_CARRIER", "ORIGIN_AIRPORT_SEQ_ID", "DEST_AIRPORT_SEQ_ID",
    "route", "AIRPORT_HUB_CLASS", "AIRLINE_CATEGORY"
]
numerical_cols = [
    "DISTANCE", "CRS_ELAPSED_TIME", "prev_flight_delay_in_minutes",
    "origin_delays_4h", "delay_origin_7d", "delay_origin_carrier_7d",
    "delay_route_7d", "flight_count_24h", "AVG_TAXI_OUT_ORIGIN",
    "AVG_ARR_DELAY_ORIGIN", "in_degree", "out_degree",
    "weighted_in_degree", "weighted_out_degree", "betweenness",
    "closeness", "N_RUNWAYS", "HourlyVisibility", "HourlyStationPressure",
    "HourlyWindSpeed", "HourlyDryBulbTemperature", "HourlyDewPointTemperature",
    "HourlyRelativeHumidity", "HourlyAltimeterSetting", "HourlyWetBulbTemperature",
    "HourlyPrecipitation", "HourlyCloudCoverage", "HourlyCloudElevation",
    "ground_flights_last_hour", "arrivals_last_hour",
    "dow_sin", "dow_cos", "doy_sin", "doy_cos"
]
numerical_cols = list(dict.fromkeys(numerical_cols))  # drop duplicate names, preserving order
time_col = "CRS_DEP_MINUTES"
target_col = "DEP_DELAY_NEW"

# paths & params
BASE_PATH = "dbfs:/student-groups/Group_2_2/5_year_custom_joined/fe_graph_and_holiday_nnfeat/training_splits/"
TRAIN_PATH = BASE_PATH + "train.parquet/"
VAL_PATH   = BASE_PATH + "val.parquet/" 
TEST_PATH  = BASE_PATH + "test.parquet/"
PREDS_SAVE_PATH = "dbfs:/student-groups/Group_2_2/5_year_custom_joined/nn_predictions_final"

# hyperparameters
params = {'lr': 0.0001556, 'batch_size': 4096, 'alpha': 0.342, 'time_dim': 16, 
          'emb_drop': 0.046, 'num_drop': 0.324, 'final_drop': 0.100}


# =========================================================
# REDEFINE CLASSES AND FUNCTIONS (restart-safe)
# =========================================================

class ResBlock(nn.Module):
    """Pre-norm residual block: x + fc2(dropout(gelu(fc1(layernorm(x)))))."""
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = F.gelu(self.fc1(self.ln(x)))
        return x + self.fc2(self.dropout(h))

class Time2Vec(nn.Module):
    """One linear time component plus k periodic (sine) components."""
    def __init__(self, k):
        super().__init__()
        self.wb = nn.Linear(1, 1)
        self.ws = nn.Linear(1, k)

    def forward(self, t):
        return torch.cat([self.wb(t), torch.sin(self.ws(t))], dim=-1)

class ResFiLMMLP(nn.Module):
    # forward() relies on the attribute names stored in the logged (pickled) model, so keep them unchanged
    def __init__(self, cat_dims, emb_dims, num_numerical, time_dim=8,
                 emb_dropout=0.1, num_dropout=0.1, film_dropout=0.1, final_dropout=0.2):
        super().__init__()
        # Categorical tower: one embedding table per categorical column
        self.embeddings = nn.ModuleList([nn.Embedding(d, e) for d, e in zip(cat_dims, emb_dims)])
        self.emb_dropout = nn.Dropout(emb_dropout)
        emb_total = sum(emb_dims)
        # Numerical tower: projection to 256 followed by 4 residual blocks
        self.fc_num = nn.Linear(num_numerical, 256)
        self.res_blocks = nn.ModuleList([ResBlock(256, num_dropout) for _ in range(4)])
        # FiLM generator: numerical features -> (gamma, beta) for the embeddings
        self.film = nn.Linear(256, 2 * emb_total)
        self.film_dropout = nn.Dropout(film_dropout)
        self.t2v = Time2Vec(time_dim)
        # Fused width: numerical (256) + embeddings + Time2Vec (time_dim + 1) + raw time (1)
        fused_dim = 256 + emb_total + (time_dim + 1) + 1
        # Regression head: predicted delay in minutes
        self.reg_head = nn.Sequential(
            nn.Linear(fused_dim, 256), nn.GELU(), nn.Dropout(final_dropout),
            nn.Linear(256, 128), nn.GELU(), nn.Dropout(final_dropout),
            nn.Linear(128, 1),
        )
        # Classification head: logit for delay >= 15 minutes
        self.clf_head = nn.Sequential(
            nn.Linear(fused_dim, 256), nn.GELU(), nn.Dropout(final_dropout),
            nn.Linear(256, 128), nn.GELU(), nn.Dropout(final_dropout),
            nn.Linear(128, 1),
        )

    def forward(self, x_cat, x_num, x_time):
        # Categorical tower: embed each column and concatenate
        emb = torch.cat([layer(x_cat[:, i]) for i, layer in enumerate(self.embeddings)], dim=-1)
        emb = self.emb_dropout(emb)
        # Numerical tower
        h = F.gelu(self.fc_num(x_num))
        for block in self.res_blocks:
            h = block(h)
        # FiLM is applied once, after the full residual stack
        gamma, beta = torch.chunk(self.film(h), 2, dim=-1)
        emb_mod = self.film_dropout(gamma) * emb + self.film_dropout(beta)
        # Fuse all representations and run both heads
        z = torch.cat([emb_mod, h, self.t2v(x_time), x_time], dim=-1)
        return self.reg_head(z), self.clf_head(z)

class ProductionFlightDataset(Dataset):
    """Holds pre-processed features as tensors and also returns flight IDs for traceability."""
    def __init__(self, df, cat_cols, num_cols, time_col, target_col, id_col="flight_uid"):
        self.cat = torch.tensor(df[cat_cols].values, dtype=torch.long)
        self.num = torch.tensor(df[num_cols].values, dtype=torch.float32)
        self.time = torch.tensor(df[time_col].values, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(df[target_col].values, dtype=torch.float32).unsqueeze(1)
        self.ids = df[id_col].values

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.cat[idx], self.num[idx], self.time[idx], self.y[idx], self.ids[idx]

def spark_to_pandas_safe(df, indexers, categorical_cols_names, numerical_cols_names, time_col_name, target_col_name):
    """Apply the fitted StringIndexers, cast columns, and collect to pandas via Arrow."""
    for ind in indexers:
        df = ind.transform(df)
    select_expr = (
        [sf.col(f"{c}_idx").cast(IntegerType()).alias(c) for c in categorical_cols_names]
        + [sf.col(c).cast(FloatType()) for c in numerical_cols_names]
        + [sf.col(time_col_name).cast(FloatType()), sf.col(target_col_name).cast(FloatType()), sf.col("flight_uid")]
    )
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    return df.select(*select_expr).toPandas()

def get_metrics_from_loader(model, loader, threshold):
    """Return (F2 at `threshold`, MAE) computed over every batch in `loader`."""
    model.eval()
    all_y, all_reg, all_clf = [], [], []
    with torch.no_grad():
        for cat, num, time_t, y, _ in loader:
            reg, clf = model(cat, num, time_t)
            all_y.append(y)
            all_reg.append(reg)
            all_clf.append(clf)
    y_true = torch.cat(all_y).cpu().numpy().flatten()
    y_reg = torch.cat(all_reg).cpu().numpy().flatten()
    y_prob = torch.sigmoid(torch.cat(all_clf)).cpu().numpy().flatten()
    mae = mean_absolute_error(y_true, y_reg)
    # A flight is in the positive (delayed) class when departure delay >= 15 minutes
    f2 = fbeta_score((y_true >= 15.0).astype(int), (y_prob >= threshold).astype(int), beta=2, zero_division=0)
    return f2, mae

def save_predictions_optimized(model, loader, fold_name, save_path_base, threshold):
    """Generate predictions, compute F2/MAE, log to MLflow, save to DBFS, and show a CSV download link."""
    print(f"  > Generating predictions for {fold_name}...")

    # 1. Start timer
    start_time = time.time()

    model.eval()
    res = {"flight_uid": [], "target_delay": [], "pred_delay": [], "target_class": [], "pred_prob": [], "pred_class_optimized": []}

    with torch.no_grad():
        for cat, num, time_t, y, ids in loader:
            reg, clf = model(cat, num, time_t)
            y_prob = torch.sigmoid(clf).cpu().numpy().flatten()
            res["flight_uid"].append(ids)
            res["target_delay"].append(y.cpu().numpy().flatten())
            res["pred_delay"].append(reg.cpu().numpy().flatten())
            res["target_class"].append((y >= 15.0).float().cpu().numpy().flatten())
            res["pred_prob"].append(y_prob)
            res["pred_class_optimized"].append((y_prob >= threshold).astype(int))

    # 2. Stop timer (covers the prediction loop only, not saving)
    inference_time_seconds = time.time() - start_time

    flat_ids = np.concatenate(res["flight_uid"]) if res["flight_uid"] else []
    pdf = pd.DataFrame({
        "flight_uid": flat_ids,
        "target_delay": np.concatenate(res["target_delay"]),
        "pred_delay": np.concatenate(res["pred_delay"]),
        "target_class": np.concatenate(res["target_class"]),
        "pred_prob": np.concatenate(res["pred_prob"]),
        "pred_class_optimized": np.concatenate(res["pred_class_optimized"]),
        "split_name": fold_name,
    })

    # Calculate F2 and MAE from the collected predictions
    final_f2 = fbeta_score(pdf["target_class"], pdf["pred_class_optimized"], beta=2, zero_division=0)
    final_mae = mean_absolute_error(pdf["target_delay"], pdf["pred_delay"])

    # MLflow logging
    mlflow.set_experiment("/Shared/team_2_2/mlflow-nn-tower-final")
    run_name = f"{fold_name}_Result_T={threshold:.3f}"
    with mlflow.start_run(run_name=run_name, nested=True):
        mlflow.log_param("optimal_threshold_used", threshold)
        mlflow.log_metric(f"{fold_name}_f2_optimized", final_f2)
        mlflow.log_metric(f"{fold_name}_mae_optimized", final_mae)
        mlflow.log_metric(f"{fold_name}_inference_time_seconds", inference_time_seconds)

    # DBFS saving; round() avoids float truncation such as int(0.29 * 100) == 28
    unique_id = str(uuid.uuid4())[:8]
    file_tag = f"{fold_name}_T{round(threshold * 100)}_{unique_id}"
    save_path_final = f"{save_path_base}/{file_tag}"
    spark.createDataFrame(pdf).write.mode("overwrite").parquet(save_path_final)

    # CSV copy under /FileStore so it can be downloaded from the browser
    csv_name = f"{file_tag}.csv"
    local_csv = f"/tmp/{csv_name}"
    pdf.to_csv(local_csv, index=False)
    dbutils.fs.cp(f"file:{local_csv}", f"dbfs:/FileStore/shared_uploads/predictions/{csv_name}")
    print(f"  > Saved to DBFS: {save_path_final}")
    print(f"  > Final F2 Score: {final_f2:.5f} | Final MAE: {final_mae:.4f}")
    print(f"  > Inference Time: {inference_time_seconds:.2f} seconds")
    displayHTML(f"""<div style="background-color:#ccffee;padding:10px;border:1px solid #91d5ff;"><b>{fold_name} PREDICTIONS (T={threshold:.3f}):</b> <a href="/files/shared_uploads/predictions/{csv_name}" target="_blank">Download CSV</a></div>""")
    if os.path.exists(local_csv):
        os.remove(local_csv)


# =========================================================
# EXECUTION
# =========================================================

# --- A. Recover Environment and Data Scalers ---
print("--- 1. Recovering Environment and Data Scalers ---")
train_spark_full = spark.read.parquet(TRAIN_PATH)
print("  > Re-fitting String Indexers on Train Set...")
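# Re-fitting on the identical training data reproduces the training-time index mapping
# (StringIndexer orders labels by frequency and breaks ties alphabetically)
indexers = [StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep").fit(train_spark_full) for c in categorical_cols]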
train_pd_full = spark_to_pandas_safe(train_spark_full, indexers, categorical_cols, numerical_cols, time_col, target_col)
del train_spark_full; gc.collect()

print("  > Re-fitting StandardScaler on Train Set...")
scaler = StandardScaler()
train_pd_full[numerical_cols] = scaler.fit_transform(train_pd_full[numerical_cols])

# Load Model
print("  > Loading Model & Determining embedding dimensions...")
# cat_dims / emb_dims do not need to be recomputed here: the logged MLflow model
# already contains its embedding tables at their trained sizes.
model = mlflow.pytorch.load_model(f"runs:/{MLFLOW_RUN_ID}/model_final")


# --- B. Train & Validation Set Metric Calculation and Logging ---
print("\n--- 2. Calculating and Logging TRAIN & VAL Set Metrics ---")

# 1. Train Data Prep & Loader
train_ds = ProductionFlightDataset(train_pd_full, categorical_cols, numerical_cols, time_col, target_col)
train_dl = DataLoader(train_ds, batch_size=params["batch_size"], num_workers=0)
del train_pd_full; gc.collect()

# 2. Validation Data Prep & Loader (Full Load)
val_spark_full = spark.read.parquet(VAL_PATH)
val_pd_full = spark_to_pandas_safe(val_spark_full, indexers, categorical_cols, numerical_cols, time_col, target_col)
val_pd_full[numerical_cols] = scaler.transform(val_pd_full[numerical_cols])
val_ds = ProductionFlightDataset(val_pd_full, categorical_cols, numerical_cols, time_col, target_col)
# The dataset now holds its own tensor copies, so the Spark and pandas frames can be freed
del val_spark_full, val_pd_full
gc.collect()
val_dl = DataLoader(val_ds, batch_size=params["batch_size"], num_workers=0)


# 3. Calculate Metrics
train_f2, train_mae = get_metrics_from_loader(model, train_dl, OPTIMAL_THRESHOLD)
val_f2, val_mae = get_metrics_from_loader(model, val_dl, OPTIMAL_THRESHOLD)

# 4. Log Metrics (Nested under a new run for documentation)
mlflow.set_experiment("/Shared/team_2_2/mlflow-nn-tower-final")
with mlflow.start_run(run_name=f"Final_Reported_Metrics_T={OPTIMAL_THRESHOLD:.3f}", nested=True):
    mlflow.log_param("reporting_threshold", OPTIMAL_THRESHOLD)
    mlflow.log_metric("final_train_f2", train_f2)
    mlflow.log_metric("final_train_mae", train_mae)
    mlflow.log_metric("final_val_f2", val_f2)
    mlflow.log_metric("final_val_mae", val_mae)

# 5. Print Final Report
print("\n=======================================================")
print(f"| FINAL REPORTED METRICS (Threshold: {OPTIMAL_THRESHOLD:.3f}) |")
print("=======================================================")
print(f"| TRAIN SET (Internal Fit) | F2: {train_f2:.4f} | MAE: {train_mae:.4f} |")
print(f"| VALIDATION SET (Tuning)  | F2: {val_f2:.4f} | MAE: {val_mae:.4f} |")
print("=======================================================")


# --- C. Save Validation Predictions ---
print("\n--- 3. Saving VALIDATION Predictions ---")
# This run will calculate F2 and MAE and log them to MLflow
save_predictions_optimized(
    model, 
    val_dl, 
    "FINAL_VAL_OPTIMIZED",
    PREDS_SAVE_PATH, 
    OPTIMAL_THRESHOLD
)

# --- D. Run Final Optimized Inference (TEST Set) ---
print("\n--- 4. Running Final Optimized Inference on TEST Set ---")
test_spark = spark.read.parquet(TEST_PATH)
if "xgb_predicted_delay" in test_spark.columns: test_spark = test_spark.drop("xgb_predicted_delay")
    
test_pd = spark_to_pandas_safe(test_spark, indexers, categorical_cols, numerical_cols, time_col, target_col)
test_pd[numerical_cols] = scaler.transform(test_pd[numerical_cols])

test_ds = ProductionFlightDataset(test_pd, categorical_cols, numerical_cols, time_col, target_col)
test_dl = DataLoader(test_ds, batch_size=params["batch_size"], num_workers=0)
del test_pd, test_spark; gc.collect()

# Compute F2, MAE, and inference time on the test set, log them to MLflow, and save the predictions
save_predictions_optimized(
    model, 
    test_dl, 
    "FINAL_TEST_OPTIMIZED", 
    PREDS_SAVE_PATH, 
    OPTIMAL_THRESHOLD
)

print("\nSUCCESS: All predictions and metrics generated and saved.")

--- 1. Recovering Environment and Data Scalers ---
  > Re-fitting String Indexers on Train Set...
  > Re-fitting StandardScaler on Train Set...
  > Loading Model & Determining embedding dimensions...




--- 2. Calculating and Logging TRAIN & VAL Set Metrics ---

| FINAL REPORTED METRICS (Threshold: 0.360) |
| TRAIN SET (Internal Fit) | F2: 0.6213 | MAE: 9.5335 |
| VALIDATION SET (Tuning)  | F2: 0.6260 | MAE: 10.8419 |

--- 3. Saving VALIDATION Predictions ---
  > Generating predictions for FINAL_VAL_OPTIMIZED...
  > Saved to DBFS: dbfs:/student-groups/Group_2_2/5_year_custom_joined/nn_predictions_final/FINAL_VAL_OPTIMIZED_T36_93b24868
  > Final F2 Score: 0.62595 | Final MAE: 10.8419
  > Inference Time: 95.92 seconds



--- 4. Running Final Optimized Inference on TEST Set ---
  > Generating predictions for FINAL_TEST_OPTIMIZED...
  > Saved to DBFS: dbfs:/student-groups/Group_2_2/5_year_custom_joined/nn_predictions_final/FINAL_TEST_OPTIMIZED_T36_9fea73a1
  > Final F2 Score: 0.61870 | Final MAE: 11.3363
  > Inference Time: 142.73 seconds



SUCCESS: All predictions and metrics generated and saved.
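In `ResFiLMMLP`, the numerical tower's output `h` conditions the categorical embeddings through FiLM: a per-feature scale `gamma` and shift `beta` are both predicted from `h`. `Time2Vec` encodes the scheduled departure time as one linear term plus `time_dim` sine terms. The fused vector therefore has width `256 + emb_total + (time_dim + 1) + 1`. The snippet below is a minimal shape check of that fusion. It assumes toy sizes chosen arbitrarily rather than the notebook's real dimensions, uses a random stand-in for `CRS_DEP_MINUTES`, and leaves out dropout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, hidden, emb_total, time_dim = 4, 8, 6, 3  # toy sizes, not the notebook's

h = torch.randn(batch, hidden)       # stand-in for the numerical residual tower output
emb = torch.randn(batch, emb_total)  # stand-in for the concatenated categorical embeddings
t = torch.rand(batch, 1)             # stand-in for the departure-time column

# FiLM: the numerical tower predicts a per-feature scale (gamma) and shift (beta)
film = nn.Linear(hidden, 2 * emb_total)
gamma, beta = torch.chunk(film(h), 2, dim=-1)
emb_mod = gamma * emb + beta         # feature-wise affine modulation

# Time2Vec: one linear component plus time_dim periodic (sine) components
wb, ws = nn.Linear(1, 1), nn.Linear(1, time_dim)
t2v = torch.cat([wb(t), torch.sin(ws(t))], dim=-1)

# Fused vector fed to both heads: modulated embeddings, numerical features,
# Time2Vec features, and the raw time value
z = torch.cat([emb_mod, h, t2v, t], dim=-1)
print(emb_mod.shape, t2v.shape, z.shape)
```

With these toy sizes the fused width is 6 + 8 + 4 + 1 = 19. This is the same accounting as `fused_dim` in the class, with 256 replaced by the toy `hidden`.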
