# AOG/MEL Risk Prediction Model

Predicts the probability of Aircraft on Ground (AOG) events from aircraft rotation and Minimum Equipment List (MEL) maintenance data.

**Target Variable:** `aog_event_flag` (binary; built in this notebook as `AOG_FLAG`, which is 1 when `AOG_RISK_SCORE >= 0.5` and 0 otherwise)
**Algorithm:** XGBoost Classifier
**Output:** `IROP_GNN_RISK.ML_PROCESSING.AOG_RISK_PREDICTIONS`

In [None]:
import snowflake.snowpark as snowpark
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, lit, when, avg, current_timestamp
from snowflake.snowpark.types import FloatType, IntegerType
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
import uuid

In [None]:
# Attach to the session provided by the hosting environment and point it at
# the database/schema that holds the source tables read by the cells below.
# get_active_session is not part of the notebook's top-level imports, so it
# is imported here; without this the cell fails with NameError.
from snowflake.snowpark.context import get_active_session

session = get_active_session()
session.use_database('IROP_GNN_RISK')
session.use_schema('ATOMIC')
print(f"Connected: {session.get_current_database()}.{session.get_current_schema()}")

In [None]:
# Lazy handles on the two source tables; names are unqualified, so they are
# resolved against the session's current database and schema.
rotations_df = session.table('AIRCRAFT_ROTATION')
flights_df = session.table('FLIGHT_INSTANCE')

# Report row counts as a quick sanity check that both tables are reachable.
for table_label, table_df in (('Rotations', rotations_df), ('Flights', flights_df)):
    print(f"{table_label}: {table_df.count()} rows")

In [None]:
def _indicator(condition, alias):
    """Return a 0/1 column named *alias* that is 1 where *condition* holds."""
    return when(condition, 1).otherwise(0).alias(alias)


# Binary training label: 1 when the stored risk score is at least 0.5.
aog_flag = _indicator(col('AOG_RISK_SCORE') >= 0.5, 'AOG_FLAG')

# Ordinal encoding of the MEL severity category; any other value
# (including NULL) falls through to 0.
severity_col = col('MEL_SEVERITY')
severity_score = when(severity_col == 'CAT-A', 4)
for category, rank in (('CAT-B', 3), ('CAT-C', 2), ('CAT-D', 1)):
    severity_score = severity_score.when(severity_col == category, rank)
severity_score = severity_score.otherwise(0).alias('MEL_SEVERITY_SCORE')

selected_columns = [
    # Identifiers carried through for reporting (not model inputs).
    col('TAIL_NUMBER'),
    col('FLIGHT_KEY'),
    aog_flag,
    # Numeric features used as-is.
    col('AIRCRAFT_AGE_YEARS'),
    col('UTILIZATION_HOURS_24H'),
    # Boolean / presence flags converted to 0/1 integers.
    _indicator(col('MEL_APU_FLAG') == True, 'HAS_APU_MEL'),
    _indicator(col('MEL_ITEM_CODE').isNotNull(), 'HAS_MEL'),
    severity_score,
    _indicator(col('MAINTENANCE_STATION_FLAG') == True, 'AT_MX_STATION'),
    _indicator(col('ETOPS_CAPABLE_FLAG') == True, 'IS_ETOPS'),
]

# Replace remaining NULLs with 0 before handing the data to the model.
features_df = rotations_df.select(selected_columns).na.fill(0)

print(f"AOG feature dataset: {features_df.count()} rows")
features_df.show(5)

In [None]:
# 80/20 train/test split; the fixed seed keeps the split repeatable.
split_weights = [0.8, 0.2]
train_df, test_df = features_df.random_split(split_weights, seed=42)

for split_label, split_df in (('Training set', train_df), ('Test set', test_df)):
    print(f"{split_label}: {split_df.count()} rows")

In [None]:
# Label column and the engineered feature columns fed to the classifier.
target_col = 'AOG_FLAG'
feature_cols = [
    'AIRCRAFT_AGE_YEARS',
    'UTILIZATION_HOURS_24H',
    'HAS_APU_MEL',
    'HAS_MEL',
    'MEL_SEVERITY_SCORE',
    'AT_MX_STATION',
    'IS_ETOPS',
]

# XGBoost hyperparameters, kept together so they are easy to review and tune.
# NOTE(review): scale_pos_weight=5 presumably compensates for AOG rows being
# the minority class - confirm against the actual label distribution.
xgb_params = {
    'n_estimators': 100,
    'max_depth': 4,
    'learning_rate': 0.1,
    'scale_pos_weight': 5,
}

model = XGBClassifier(
    input_cols=feature_cols,
    label_cols=[target_col],
    output_cols=['PREDICTED_AOG'],
    **xgb_params,
)

# Fit on the training split only; the test split is held out for evaluation.
model.fit(train_df)
print("Model training complete")

In [None]:
# Score the held-out test split and preview actual vs. predicted labels.
predictions = model.predict(test_df)
predictions.select('TAIL_NUMBER', 'AOG_FLAG', 'PREDICTED_AOG').show(10)

# Count the test rows once and guard against an empty split, which would
# otherwise raise ZeroDivisionError when computing the ratio.
total_rows = predictions.count()
if total_rows:
    correct_rows = predictions.filter(
        col('AOG_FLAG') == col('PREDICTED_AOG')
    ).count()
    accuracy = correct_rows / total_rows
    print(f"Accuracy: {accuracy:.2%}")
else:
    # Keep `accuracy` defined so later cells can still reference it.
    accuracy = None
    print("Accuracy: n/a (test set is empty)")

In [None]:
# Score every row of the feature set and persist the results.
from snowflake.snowpark.functions import sql_expr

all_predictions = model.predict(features_df)

output_df = all_predictions.select(
    # Generate the ID inside the query so each row gets its own value.
    # A Python-side uuid wrapped in lit() is evaluated once and would stamp
    # the same ID on every row. The 8-character upper-case format is kept.
    # NOTE(review): 8 hex characters can collide on large tables - widen if
    # PREDICTION_ID must be strictly unique.
    sql_expr("UPPER(LEFT(UUID_STRING(), 8))").alias('PREDICTION_ID'),
    col('TAIL_NUMBER'),
    current_timestamp().alias('SNAPSHOT_TS'),
    # NOTE(review): PREDICTED_AOG is the output of model.predict; if that is
    # the predicted class label rather than a probability, this score will
    # only ever be 0.0 or 1.0 - consider predict_proba. Confirm.
    col('PREDICTED_AOG').cast(FloatType()).alias('AOG_RISK_SCORE'),
    when(col('HAS_APU_MEL') == 1, True).otherwise(False).alias('CRITICAL_MEL_FLAG'),
    # Placeholder columns, left NULL by this notebook.
    lit(None).alias('MEL_NARRATIVE_SUMMARY'),
    lit('v1.0').alias('MODEL_VERSION'),
    lit(None).alias('FEATURE_IMPORTANCE')
)

# Write with a fully qualified name instead of switching the session schema.
# output_df is lazy and ultimately reads the unqualified AIRCRAFT_ROTATION
# table; changing the current schema before the write executes would make
# that reference resolve in ML_PROCESSING instead of ATOMIC.
predictions_table = 'IROP_GNN_RISK.ML_PROCESSING.AOG_RISK_PREDICTIONS'
output_df.write.mode('overwrite').save_as_table(predictions_table)

# Count the persisted rows rather than output_df, which would re-run the
# whole scoring query just to produce the number.
saved_rows = session.table(predictions_table).count()
print(f"Saved {saved_rows} predictions to ML_PROCESSING.AOG_RISK_PREDICTIONS")

In [None]:
# Register the fitted model in the registry located in
# IROP_GNN_RISK.ML_PROCESSING.
reg = Registry(session=session, database_name='IROP_GNN_RISK', schema_name='ML_PROCESSING')

model_version = reg.log_model(
    model=model,
    model_name='AOG_RISK_MODEL',
    version_name='v1',
    # A small sample of training rows is passed as example input.
    sample_input_data=train_df.limit(10),
    comment='XGBoost classifier for AOG/MEL risk prediction'
)

# version_name already carries the "v" prefix supplied above, so it is
# printed as-is; prefixing another "v" produced a doubled prefix.
print(f"Model registered: {model_version.model_name} {model_version.version_name}")