# PNR Misconnect Prediction Model

Predicts probability of passenger misconnection based on itinerary and delay propagation.

**Target Variable:** `MISCONNECT_FLAG` (binary; 1 when `PNR_MISCONNECT_PROB` >= 0.3, otherwise 0)
**Algorithm:** XGBoost Classifier
**Output:** `IROP_GNN_RISK.ML_PROCESSING.PNR_MISCONNECT_PREDICTIONS`

In [None]:
import snowflake.snowpark as snowpark
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, lit, when, avg, current_timestamp, array_size
from snowflake.snowpark.types import FloatType, IntegerType
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
import uuid

In [None]:
# get_active_session is not among the notebook's top-level imports, so bring
# it into scope here; without this the cell fails with NameError on a fresh
# kernel.
from snowflake.snowpark.context import get_active_session

# Reuse the session supplied by the runtime and select the database/schema
# that later cells resolve unqualified table names against.
session = get_active_session()
session.use_database('IROP_GNN_RISK')
session.use_schema('ATOMIC')
print(f"Connected: {session.get_current_database()}.{session.get_current_schema()}")

In [None]:
# Handles to the source tables, resolved against the session's current
# database and schema.
pnr_df = session.table('PNR_TRIP')
flights_df = session.table('FLIGHT_INSTANCE')
airports_df = session.table('AIRPORT_CAPABILITY')

# Report row counts for the PNR and flight tables (each count runs a query).
for _label, _frame in (("PNR trips", pnr_df), ("Flights", flights_df)):
    print(f"{_label}: {_frame.count()} rows")

In [None]:
# Binary label: 1 when the stored misconnect probability is at least 0.3,
# otherwise 0.
_misconnect_flag = (
    when(col('PNR_MISCONNECT_PROB') >= 0.3, 1)
    .otherwise(0)
    .alias('MISCONNECT_FLAG')
)

# Encode the international flag as 0/1.
_is_intl = when(col('INTL_FLAG') == True, 1).otherwise(0).alias('IS_INTL')

# 1 when an elite status level is present on the row, 0 when it is NULL.
_is_elite = (
    when(col('ELITE_STATUS_LEVEL').isNotNull(), 1)
    .otherwise(0)
    .alias('IS_ELITE')
)

# Column order matches the original projection: ids, label, then predictors.
_projection = [
    col('TRIP_ID'),
    col('PNR_ID'),
    _misconnect_flag,
    col('GROUP_SIZE'),
    _is_intl,
    col('REBOOK_FLEXIBILITY_INDEX'),
    col('LOYALTY_VALUE_INDEX'),
    col('PNR_REACCOM_COMPLEXITY_SCORE'),
    _is_elite,
    col('ESTIMATED_VOUCHER_COST_USD'),
]

# Project the columns, then fill NULLs with 0 before the data is split and
# passed to the model.
features_df = pnr_df.select(*_projection).na.fill(0)

print(f"PNR feature dataset: {features_df.count()} rows")
features_df.show(5)

In [None]:
# 80/20 split of the feature dataset; the fixed seed is passed so the split
# can be repeated.
_splits = features_df.random_split([0.8, 0.2], seed=42)
train_df, test_df = _splits

# Print the size of each partition.
for _label, _part in (("Training set", train_df), ("Test set", test_df)):
    print(f"{_label}: {_part.count()} rows")

In [None]:
# Predictor columns, all produced by the feature-engineering cell.
feature_cols = [
    'GROUP_SIZE',
    'IS_INTL',
    'REBOOK_FLEXIBILITY_INDEX',
    'LOYALTY_VALUE_INDEX',
    'PNR_REACCOM_COMPLEXITY_SCORE',
    'IS_ELITE',
    'ESTIMATED_VOUCHER_COST_USD',
]
# Label column the classifier is trained to predict.
target_col = 'MISCONNECT_FLAG'

# XGBoost hyperparameters, grouped so they are easy to find and tweak.
_xgb_params = {
    'n_estimators': 100,
    'max_depth': 5,
    'learning_rate': 0.1,
    'scale_pos_weight': 2,
}

# The classifier reads feature_cols, learns target_col, and writes its
# prediction to PREDICTED_MISCONNECT.
model = XGBClassifier(
    input_cols=feature_cols,
    label_cols=[target_col],
    output_cols=['PREDICTED_MISCONNECT'],
    **_xgb_params,
)

model.fit(train_df)
print("Model training complete")

In [None]:
# Score the held-out split and preview label vs. prediction side by side.
predictions = model.predict(test_df)
predictions.select('TRIP_ID', 'MISCONNECT_FLAG', 'PREDICTED_MISCONNECT').show(10)

# Count the scored rows once and guard the division: an empty test split
# (possible on small tables) would otherwise raise ZeroDivisionError.
_total_rows = predictions.count()
if _total_rows:
    _correct_rows = predictions.filter(
        col('MISCONNECT_FLAG') == col('PREDICTED_MISCONNECT')
    ).count()
    accuracy = _correct_rows / _total_rows
    print(f"Accuracy: {accuracy:.2%}")
else:
    # No rows to evaluate; leave accuracy undefined rather than crash.
    accuracy = None
    print("Accuracy: n/a (test set is empty)")

In [None]:
# sql_expr is not among the notebook's top-level imports; import it locally.
from snowflake.snowpark.functions import sql_expr

# Score the full feature dataset (not only the held-out split).
all_predictions = model.predict(features_df)

output_df = all_predictions.select(
    # Generate the id in SQL so it is evaluated per row. A Python-side
    # uuid4() wrapped in lit() is computed once and stamped on every row.
    # NOTE: 8 hex characters can collide on large tables; widen the id if it
    # is relied on as a unique key.
    sql_expr("UPPER(SUBSTR(UUID_STRING(), 1, 8))").alias('PREDICTION_ID'),
    col('PNR_ID'),
    col('TRIP_ID'),
    current_timestamp().alias('SNAPSHOT_TS'),
    # NOTE(review): PREDICTED_MISCONNECT comes from model.predict(), i.e. a
    # class label cast to float, not a probability -- confirm whether
    # predict_proba output should populate this column instead.
    col('PREDICTED_MISCONNECT').cast(FloatType()).alias('PNR_MISCONNECT_PROB'),
    lit(None).alias('CONNECTION_LEG_AT_RISK'),
    lit('v1.0').alias('MODEL_VERSION'),
    lit(None).alias('FEATURE_IMPORTANCE')
)

# Write through a fully qualified name while the session still points at the
# source schema: the DataFrames above reference their tables by unqualified
# name and are evaluated lazily, so switching schema before the write could
# change what they resolve to.
_target_table = 'IROP_GNN_RISK.ML_PROCESSING.PNR_MISCONNECT_PREDICTIONS'
output_df.write.mode('overwrite').save_as_table(_target_table)

# Count the persisted rows instead of output_df.count(), which would re-run
# inference over the whole dataset just to print a number.
_saved_rows = session.table(_target_table).count()

# Preserve the original end state: the session's current schema is
# ML_PROCESSING after this cell.
session.use_schema('ML_PROCESSING')
print(f"Saved {_saved_rows} predictions to ML_PROCESSING.PNR_MISCONNECT_PREDICTIONS")

In [None]:
# Open the model registry located in IROP_GNN_RISK.ML_PROCESSING.
reg = Registry(
    session=session,
    database_name='IROP_GNN_RISK',
    schema_name='ML_PROCESSING',
)

# Ten training rows are passed to the registry as sample input data.
_sample_rows = train_df.limit(10)

# Log the fitted classifier under a fixed model name and version name.
model_version = reg.log_model(
    model=model,
    model_name='PNR_MISCONNECT_MODEL',
    version_name='v1',
    sample_input_data=_sample_rows,
    comment='XGBoost classifier for PNR misconnect prediction',
)

print(f"Model registered: {model_version.model_name} v{model_version.version_name}")