# Delay Prediction Model

Predicts flight departure delays (in minutes) using gradient-boosted trees.

**Target Variable:** `TARGET_DELAY_MINUTES` (continuous; taken from `FLIGHT_INSTANCE.CURRENT_DELAY_DEPARTURE`)
**Algorithm:** XGBoost Regression via Snowflake ML
**Output:** `IROP_GNN_RISK.ML_PROCESSING.DELAY_PREDICTIONS`

In [None]:
import snowflake.snowpark as snowpark
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, lit, when, avg, sum as sum_, count, current_timestamp
from snowflake.snowpark.types import FloatType, IntegerType, StringType
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.registry import Registry
import uuid

In [None]:
# get_active_session is not a builtin; without this import the cell raises NameError.
from snowflake.snowpark.context import get_active_session

# Attach to the notebook's active Snowpark session and point it at the source schema.
session = get_active_session()
session.use_database('IROP_GNN_RISK')
session.use_schema('ATOMIC')
print(f"Connected: {session.get_current_database()}.{session.get_current_schema()}")

In [None]:
# Source tables, resolved against the ATOMIC schema selected in the session cell.
flights_df = session.table('FLIGHT_INSTANCE')
weather_df = session.table('WEATHER_ATC')
airports_df = session.table('AIRPORT_CAPABILITY')

# Report row counts so an empty or missing source is obvious before feature building.
for source_label, source_df in (
    ('Flights', flights_df),
    ('Weather', weather_df),
    ('Airports', airports_df),
):
    print(f"{source_label}: {source_df.count()} rows")

In [None]:
def _flag_as_int(column_name):
    """Return 1 where *column_name* equals TRUE, else 0 (NULL also maps to 0)."""
    # `== True` is deliberate: it builds a Snowpark column comparison, not a Python one.
    return when(col(column_name) == True, 1).otherwise(0)  # noqa: E712


def _as_float(column_name, alias):
    """Cast *column_name* to FloatType and expose it under *alias*."""
    return col(column_name).cast(FloatType()).alias(alias)


# One row per weather station: averaged convective / EDCT / holding metrics plus a
# count of rows carrying a flow-program flag.
weather_agg = weather_df.group_by('STATION_CODE').agg(
    avg('CONVECTIVE_INDEX').alias('AVG_CONVECTIVE'),
    avg('EDCT_DELAY_MEAN').alias('AVG_EDCT'),
    avg('HOLDING_PROBABILITY').alias('AVG_HOLDING_PROB'),
    # NOTE(review): GDP_COUNT is aggregated but never selected into features_df below.
    sum_(_flag_as_int('FLOW_PROGRAM_FLAG')).alias('GDP_COUNT'),
)

# Output columns in order: flight key, label, then the model features.
feature_exprs = [
    flights_df['FLIGHT_KEY'],
    flights_df['CURRENT_DELAY_DEPARTURE'].alias('TARGET_DELAY_MINUTES'),
    col('BLOCK_TIME_MINUTES'),
    col('TURN_BUFFER_MINUTES'),
    col('PAX_COUNT'),
    col('CONNECTING_PAX_PCT'),
    _flag_as_int('HUB_FLAG').alias('IS_HUB'),
    _flag_as_int('INTL_CONNECTOR_FLAG').alias('IS_INTL'),
    _as_float('AVG_CONVECTIVE', 'WEATHER_CONVECTIVE'),
    _as_float('AVG_EDCT', 'WEATHER_EDCT'),
    _as_float('AVG_HOLDING_PROB', 'WEATHER_HOLDING'),
    _as_float('ATC_CONGESTION_INDEX', 'ATC_CONGESTION'),
]

# Left joins keep every flight even when its departure station has no weather
# aggregate or airport-capability row.
joined_df = (
    flights_df
    .join(weather_agg, flights_df['DEPARTURE_STATION'] == weather_agg['STATION_CODE'], 'left')
    .join(airports_df, flights_df['DEPARTURE_STATION'] == airports_df['STATION_CODE'], 'left')
)

# NOTE(review): na.fill(0) is given no subset, so it presumably also zero-fills
# TARGET_DELAY_MINUTES; flights with a NULL CURRENT_DELAY_DEPARTURE would then be
# trained/scored as 0-minute delays. Confirm this is intended.
features_df = joined_df.select(*feature_exprs).na.fill(0)

print(f"Feature dataset: {features_df.count()} rows")
features_df.show(5)

In [None]:
# 80/20 train/test partition of the feature set, seeded for a reproducible split.
split_weights = [0.8, 0.2]
train_df, test_df = features_df.random_split(split_weights, seed=42)

for split_label, split_df in (('Training set', train_df), ('Test set', test_df)):
    print(f"{split_label}: {split_df.count()} rows")

In [None]:
# Model inputs (FLIGHT_KEY and the label column are not among them).
feature_cols = [
    'BLOCK_TIME_MINUTES',
    'TURN_BUFFER_MINUTES',
    'PAX_COUNT',
    'CONNECTING_PAX_PCT',
    'IS_HUB',
    'IS_INTL',
    'WEATHER_CONVECTIVE',
    'WEATHER_EDCT',
    'WEATHER_HOLDING',
    'ATC_CONGESTION',
]
target_col = 'TARGET_DELAY_MINUTES'

# Boosting hyperparameters, kept together so they are easy to tune.
xgb_params = dict(n_estimators=100, max_depth=6, learning_rate=0.1)

model = XGBRegressor(
    input_cols=feature_cols,
    label_cols=[target_col],
    output_cols=['PREDICTED_DELAY'],
    **xgb_params,
)
model.fit(train_df)
print("Model training complete")

In [None]:
from snowflake.snowpark.functions import abs as abs_, sqrt

# Score the hold-out set once and materialize the result: `predictions` is otherwise
# a lazy query, so the preview and the metrics query would each re-run inference.
predictions = model.predict(test_df).cache_result()
predictions.select('FLIGHT_KEY', 'TARGET_DELAY_MINUTES', 'PREDICTED_DELAY').show(10)

# Hold-out error metrics (in the label's units) built from a shared residual expression.
residual = col('TARGET_DELAY_MINUTES') - col('PREDICTED_DELAY')
metrics = predictions.select(
    avg(abs_(residual)).alias('MAE'),
    sqrt(avg(residual ** 2)).alias('RMSE'),
)
metrics.show()

In [None]:
from snowflake.snowpark.functions import call_function, upper

# Score every flight in features_df (this includes the rows train_df was drawn from).
all_predictions = model.predict(features_df)

output_df = all_predictions.select(
    # UUID_STRING() is evaluated per row inside Snowflake, so each prediction gets its
    # own ID. The previous lit(uuid.uuid4()...) was computed once in Python and stamped
    # the same 8-character value on every row.
    upper(call_function('UUID_STRING')).alias('PREDICTION_ID'),
    col('FLIGHT_KEY'),
    current_timestamp().alias('SNAPSHOT_TS'),
    col('PREDICTED_DELAY').alias('PREDICTED_DELAY_MINUTES'),
    # NOTE(review): linear scaling (60 predicted minutes -> 100) with no clamping, so
    # negative or >100 scores are possible; confirm whether consumers expect 0-100.
    (col('PREDICTED_DELAY') / 60 * 100).alias('DELAY_RISK_SCORE'),
    lit('v1.0').alias('MODEL_VERSION'),
    lit(None).alias('FEATURE_IMPORTANCE'),
)

session.use_schema('ML_PROCESSING')
output_df.write.mode('overwrite').save_as_table('DELAY_PREDICTIONS')

# Count the persisted table: output_df is lazy, and counting it would re-run inference.
saved_count = session.table('DELAY_PREDICTIONS').count()
print(f"Saved {saved_count} predictions to ML_PROCESSING.DELAY_PREDICTIONS")

In [None]:
# Model registry scoped to IROP_GNN_RISK.ML_PROCESSING.
reg = Registry(session=session, database_name='IROP_GNN_RISK', schema_name='ML_PROCESSING')

# NOTE(review): re-running this cell after 'v1' exists presumably fails; bump
# version_name when retraining.
model_version = reg.log_model(
    model=model,
    model_name='DELAY_PREDICTION_MODEL',
    version_name='v1',
    # Sample only the model's input columns; FLIGHT_KEY and the label are not inputs.
    sample_input_data=train_df.select(feature_cols).limit(10),
    comment='XGBoost regressor for flight delay prediction'
)

# version_name already carries its "v" prefix; the old f"v{...}" printed "vv1".
print(f"Model registered: {model_version.model_name} version {model_version.version_name}")