# Crew Timeout Prediction Model

Predicts the probability of a crew FDP (Flight Duty Period) timeout from duty-period attributes and current flight delays.

**Target Variable:** `TIMEOUT_FLAG` (binary; 1 when `CREW_TIMEOUT_RISK_SCORE` >= 0.5, otherwise 0)
**Algorithm:** Gradient-Boosted Classifier
**Output:** `IROP_GNN_RISK.ML_PROCESSING.CREW_TIMEOUT_PREDICTIONS`

In [None]:
import snowflake.snowpark as snowpark
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, lit, when, avg, current_timestamp, datediff
from snowflake.snowpark.types import FloatType, IntegerType
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
import uuid

In [None]:
# get_active_session is not among the notebook's top-level imports, so bring
# it into scope here; without this the call below raises NameError.
from snowflake.snowpark.context import get_active_session

# Reuse the session that is already active in this runtime instead of
# building a new connection.
session = get_active_session()

# The feature-engineering cells read their source tables from
# IROP_GNN_RISK.ATOMIC, so make that the current database/schema.
session.use_database('IROP_GNN_RISK')
session.use_schema('ATOMIC')
print(f"Connected: {session.get_current_database()}.{session.get_current_schema()}")

In [None]:
# Table references for the three source tables used to build features.
# Names are unqualified, so they are looked up in the session's current
# database and schema.
crew_df, assignments_df, flights_df = (
    session.table(table_name)
    for table_name in ('CREW_DUTY_PERIOD', 'CREW_ASSIGNMENT', 'FLIGHT_INSTANCE')
)

# Quick sanity check on input volume (each count() runs a query).
print(f"Crew duties: {crew_df.count()} rows")
print(f"Assignments: {assignments_df.count()} rows")

In [None]:
# --- Step 1: per-duty delay statistics -------------------------------------
# Pair every crew assignment with its flight instance (inner join on
# FLIGHT_KEY), then aggregate the flights' current departure delay per duty.
assignment_flights = assignments_df.alias('a').join(
    flights_df.alias('f'),
    col('a.FLIGHT_KEY') == col('f.FLIGHT_KEY')
)
duty_delays = assignment_flights.group_by('a.DUTY_ID').agg(
    avg('f.CURRENT_DELAY_DEPARTURE').alias('AVG_SEGMENT_DELAY'),
    snowpark.functions.sum('f.CURRENT_DELAY_DEPARTURE').alias('TOTAL_DELAY')
)

# --- Step 2: assemble the modelling dataset ---------------------------------
# Left join keeps duty periods that have no matching assignment/flight rows;
# their delay columns come back NULL and are zero-filled at the end.
duties_with_delays = crew_df.alias('c').join(
    duty_delays.alias('d'),
    col('c.DUTY_ID') == col('d.DUTY_ID'),
    'left'
)

# Binary label: 1 when the stored risk score is at least 0.5, otherwise 0.
timeout_label = when(col('c.CREW_TIMEOUT_RISK_SCORE') >= 0.5, 1).otherwise(0).alias('TIMEOUT_FLAG')

# Boolean flag re-encoded as 0/1 so it can be used as a numeric feature.
augmented_indicator = when(col('c.AUGMENTED_CREW_FLAG') == True, 1).otherwise(0).alias('IS_AUGMENTED')

selected_columns = [
    col('c.DUTY_ID'),
    timeout_label,
    col('c.FDP_LIMIT_MINUTES'),
    col('c.FDP_TIME_USED_MINUTES'),
    col('c.FDP_REMAINING_MINUTES'),
    col('c.NUM_SEGMENTS'),
    col('c.TIME_ZONE_SPAN_HOURS'),
    augmented_indicator,
    col('c.REST_IN_LAST_168_HOURS_MINUTES'),
    col('d.AVG_SEGMENT_DELAY').cast(FloatType()).alias('AVG_DELAY'),
    col('d.TOTAL_DELAY').cast(FloatType()).alias('TOTAL_DELAY'),
]

# Replace remaining NULLs in numeric columns with 0.
features_df = duties_with_delays.select(*selected_columns).na.fill(0)

print(f"Crew feature dataset: {features_df.count()} rows")
features_df.show(5)

In [None]:
# Hold out 20% of the rows for evaluation; the fixed seed keeps the
# partition repeatable between runs.
split_weights = [0.8, 0.2]
train_df, test_df = features_df.random_split(split_weights, seed=42)

# Report the size of each partition.
print(f"Training set: {train_df.count()} rows")
print(f"Test set: {test_df.count()} rows")

In [None]:
# Label column produced by the feature-engineering cell.
target_col = 'TIMEOUT_FLAG'

# Model inputs: every numeric column of features_df except the DUTY_ID
# identifier and the label.
feature_cols = [
    'FDP_LIMIT_MINUTES',
    'FDP_TIME_USED_MINUTES',
    'FDP_REMAINING_MINUTES',
    'NUM_SEGMENTS',
    'TIME_ZONE_SPAN_HOURS',
    'IS_AUGMENTED',
    'REST_IN_LAST_168_HOURS_MINUTES',
    'AVG_DELAY',
    'TOTAL_DELAY',
]

# Booster hyperparameters. scale_pos_weight=3 up-weights the positive
# (timeout) class relative to the negative class.
xgb_params = {
    'n_estimators': 100,
    'max_depth': 5,
    'learning_rate': 0.1,
    'scale_pos_weight': 3,
}

model = XGBClassifier(
    input_cols=feature_cols,
    label_cols=[target_col],
    output_cols=['PREDICTED_TIMEOUT'],
    **xgb_params,
)

# Train on the 80% partition only; the remaining rows are used for evaluation.
model.fit(train_df)
print("Model training complete")

In [None]:
# Score the held-out partition; predict() appends the PREDICTED_TIMEOUT
# column configured on the model.
predictions = model.predict(test_df)
predictions.select('DUTY_ID', 'TIMEOUT_FLAG', 'PREDICTED_TIMEOUT', 'FDP_REMAINING_MINUTES').show(10)

# Accuracy = share of test rows whose predicted label equals the true label.
# Count the rows first so an empty test partition (possible with very small
# datasets) is reported instead of raising ZeroDivisionError.
total_rows = predictions.count()
if total_rows:
    correct_rows = predictions.filter(
        col('TIMEOUT_FLAG') == col('PREDICTED_TIMEOUT')
    ).count()
    accuracy = correct_rows / total_rows
    print(f"Accuracy: {accuracy:.2%}")
else:
    accuracy = None
    print("Accuracy: n/a (test set is empty)")

In [None]:
# Score every duty period (training and test rows alike) for the output table.
all_predictions = model.predict(features_df)

output_df = all_predictions.select(
    # Generate the ID in SQL so that each row receives its own value. A
    # Python-side uuid wrapped in lit() is evaluated once and would stamp the
    # same ID on every row. The 8-character upper-case format is kept.
    # NOTE(review): 8 hex characters can collide on large tables; widen the
    # ID if strict uniqueness is required.
    snowpark.functions.sql_expr('UPPER(LEFT(UUID_STRING(), 8))').alias('PREDICTION_ID'),
    col('DUTY_ID'),
    current_timestamp().alias('SNAPSHOT_TS'),
    # NOTE(review): PREDICTED_TIMEOUT is the output of model.predict(), which
    # for a classifier is presumably the 0/1 class label rather than a
    # probability; consider model.predict_proba() if TIMEOUT_PROB must be a
    # true probability - confirm the output column names before switching.
    col('PREDICTED_TIMEOUT').cast(FloatType()).alias('TIMEOUT_PROB'),
    col('FDP_REMAINING_MINUTES').alias('TIME_TO_TIMEOUT_MINUTES'),
    lit('v1.0').alias('MODEL_VERSION'),
    # Placeholder column; no feature importances are computed in this notebook.
    lit(None).alias('FEATURE_IMPORTANCE')
)

# Write through a fully qualified name while the session still points at the
# schema the source tables were opened in. output_df is built from
# unqualified table names, so switching schema before the query executes
# could change how those names resolve.
predictions_table = 'IROP_GNN_RISK.ML_PROCESSING.CREW_TIMEOUT_PREDICTIONS'
output_df.write.mode('overwrite').save_as_table(predictions_table)

# Keep the original side effect: later work runs with ML_PROCESSING current.
session.use_schema('ML_PROCESSING')

# Count the persisted rows rather than output_df, which would re-run scoring.
saved_rows = session.table(predictions_table).count()
print(f"Saved {saved_rows} predictions to ML_PROCESSING.CREW_TIMEOUT_PREDICTIONS")

In [None]:
# Registry handle for IROP_GNN_RISK.ML_PROCESSING; database and schema are
# passed explicitly, so this does not depend on the session's current schema.
reg = Registry(session=session, database_name='IROP_GNN_RISK', schema_name='ML_PROCESSING')

# Log the fitted classifier as version v1 of CREW_TIMEOUT_MODEL.
# The sample input is restricted to the model's input columns so that the
# DUTY_ID identifier and the TIMEOUT_FLAG label are not presented as inputs.
# NOTE(review): logging an already existing model/version name is expected to
# fail on re-run; bump version_name when retraining.
model_version = reg.log_model(
    model=model,
    model_name='CREW_TIMEOUT_MODEL',
    version_name='v1',
    sample_input_data=train_df.select(feature_cols).limit(10),
    comment='XGBoost classifier for crew FDP timeout prediction'
)

# version_name already carries its "v" prefix, so none is added here.
print(f"Model registered: {model_version.model_name} {model_version.version_name}")