# Phantom Airlines IROPS - Cost Estimation Model

This notebook implements a disruption cost estimation model:

- **Feature Store**: Disruption characteristics, airport impact, passenger counts
- **Model Registry**: XGBoost regression model for cost prediction
- **Model Observability**: Cost prediction accuracy tracking and drift alerts

## Business Context
Accurate cost estimation enables prioritization of disruption responses and ROI tracking.
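The model predicts the total disruption cost in USD, which is intended to cover passenger compensation,
rebooking, crew overtime, and downstream cascading impacts. The training label is the recorded actual
cost (`ACTUAL_COST_USD`), falling back to the recorded estimate (`ESTIMATED_COST_USD`) when no actual
cost is available.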

## 1. Environment Setup

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')

from snowflake.snowpark import Session
from snowflake.snowpark import functions as F

from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from snowflake.ml.registry import Registry
from snowflake.ml.modeling.preprocessing import StandardScaler, OneHotEncoder
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.modeling.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score
)

import pandas as pd
import numpy as np

In [None]:
# Resolve every environment-driven setting first, then connect.
# `or` (rather than a getenv default) also maps an empty SNOWFLAKE_CONNECTION_NAME to "default".
connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME") or "default"
DATABASE = os.getenv("IROPS_DATABASE", "PHANTOM_IROPS")
WAREHOUSE = os.getenv("IROPS_WAREHOUSE", "PHANTOM_IROPS_WH")

# Open the Snowpark session from the named connection profile and pin the
# database/warehouse that all later cells query against.
session = (
    Session.builder
    .config("connection_name", connection_name)
    .create()
)
session.use_database(DATABASE)
session.use_warehouse(WAREHOUSE)

current_account = session.get_current_account()
print(f"Connected to: {current_account}")
print(f"Database: {DATABASE}")

## 2. Feature Store - Disruption Features

In [None]:
session.sql("CREATE SCHEMA IF NOT EXISTS FEATURE_STORE").collect()

# The Feature Store lives in <DATABASE>.FEATURE_STORE; CREATE_IF_NOT_EXIST lets
# this cell be re-run against an already-initialized store.
fs = FeatureStore(
    session=session,
    database=DATABASE,
    name="FEATURE_STORE",
    default_warehouse=WAREHOUSE,
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST
)

# One entity per disruption event, keyed on DISRUPTION_ID; both feature views
# below are attached to it.
disruption_entity = Entity(
    name="DISRUPTION",
    join_keys=["DISRUPTION_ID"],
    desc="Operational disruption event for cost estimation"
)

# Registration stays best-effort so the cell can be re-run when the entity
# already exists. A bare `except:` would also swallow KeyboardInterrupt and
# silently hide real failures (permissions, connectivity), so catch Exception
# only and report the reason. print() is used instead of warnings.warn because
# the setup cell calls warnings.filterwarnings('ignore').
try:
    fs.register_entity(disruption_entity)
except Exception as exc:
    print(
        "register_entity(DISRUPTION) did not succeed; continuing on the "
        f"assumption the entity already exists: {exc!r}"
    )

print("Feature Store initialized with DISRUPTION entity")

### 2.1 Disruption Characteristics Features

In [None]:
# Per-disruption features computed straight from RAW.DISRUPTIONS. FEATURE_TIMESTAMP
# is DISRUPTION_START, the same column the label spine uses as LABEL_TIMESTAMP,
# so the point-in-time join in generate_training_set matches each disruption to
# its own feature row.
# NOTE(review): DAYOFWEEK numbering depends on the session's WEEK_START setting;
# keep it identical between training and any scoring path.
disruption_features_query = f"""
SELECT 
    d.DISRUPTION_ID,
    d.DISRUPTION_START AS FEATURE_TIMESTAMP,
    d.DISRUPTION_TYPE,
    d.DISRUPTION_SUBTYPE,
    d.SEVERITY,
    d.DURATION_MINUTES,
    d.IMPACT_FLIGHTS_COUNT,
    d.IMPACT_PASSENGERS_COUNT,
    d.AFFECTED_AIRPORT,
    CASE d.SEVERITY 
        WHEN 'CRITICAL' THEN 4 
        WHEN 'SEVERE' THEN 3 
        WHEN 'MODERATE' THEN 2 
        ELSE 1 
    END AS SEVERITY_NUMERIC,
    CASE 
        WHEN d.DISRUPTION_TYPE = 'WEATHER' THEN 1.5
        WHEN d.DISRUPTION_TYPE = 'MECHANICAL' THEN 1.2
        WHEN d.DISRUPTION_TYPE = 'CREW' THEN 1.0
        ELSE 0.8
    END AS TYPE_COST_MULTIPLIER,
    HOUR(d.DISRUPTION_START) AS DISRUPTION_HOUR,
    DAYOFWEEK(d.DISRUPTION_START) AS DISRUPTION_DAY_OF_WEEK,
    CASE 
        WHEN HOUR(d.DISRUPTION_START) BETWEEN 6 AND 9 THEN TRUE
        WHEN HOUR(d.DISRUPTION_START) BETWEEN 16 AND 19 THEN TRUE
        ELSE FALSE
    END AS IS_PEAK_HOUR
FROM {DATABASE}.RAW.DISRUPTIONS d
WHERE d.DISRUPTION_ID IS NOT NULL
"""

disruption_df = session.sql(disruption_features_query)

# refresh_freq asks the Feature Store to keep this view refreshed every 30 minutes.
disruption_fv = FeatureView(
    name="DISRUPTION_CHARACTERISTICS",
    entities=[disruption_entity],
    feature_df=disruption_df,
    timestamp_col="FEATURE_TIMESTAMP",
    refresh_freq="30 minutes",
    desc="Disruption event characteristics for cost estimation"
)

# block=True waits for registration to finish before the next cell runs.
disruption_features = fs.register_feature_view(
    feature_view=disruption_fv,
    version="v1",
    block=True
)

# The version string already carries its "v" prefix ("v1"); prefixing another
# "v" in the message printed "vv1".
print(f"Registered: {disruption_features.name} version {disruption_features.version}")

### 2.2 Airport Impact Features

In [None]:
# Airport-level context for each disruption. The inner JOIN means disruptions
# whose AFFECTED_AIRPORT has no row in RAW.AIRPORTS get no airport features.
# AFFECTED_TIMEZONE is materialized here but is not one of the model's FEATURE_COLS.
# NOTE(review): this view refreshes daily while DISRUPTION_CHARACTERISTICS refreshes
# every 30 minutes. Because the query is driven by RAW.DISRUPTIONS, disruptions
# created since the last refresh presumably lack airport features until the next
# run — confirm this cadence is intended.
airport_impact_query = f"""
SELECT 
    d.DISRUPTION_ID,
    d.DISRUPTION_START AS FEATURE_TIMESTAMP,
    a.IS_HUB AS AFFECTED_IS_HUB,
    CASE a.HUB_TYPE 
        WHEN 'PRIMARY' THEN 3
        WHEN 'SECONDARY' THEN 2
        WHEN 'FOCUS_CITY' THEN 1
        ELSE 0
    END AS HUB_IMPORTANCE_SCORE,
    a.GATES_COUNT AS AFFECTED_GATES,
    a.DAILY_OPERATIONS AS AFFECTED_DAILY_OPS,
    a.TIMEZONE AS AFFECTED_TIMEZONE,
    CASE 
        WHEN a.IS_HUB AND a.HUB_TYPE = 'PRIMARY' THEN 2.0
        WHEN a.IS_HUB THEN 1.5
        ELSE 1.0
    END AS HUB_COST_MULTIPLIER
FROM {DATABASE}.RAW.DISRUPTIONS d
JOIN {DATABASE}.RAW.AIRPORTS a ON d.AFFECTED_AIRPORT = a.AIRPORT_CODE
"""

airport_impact_df = session.sql(airport_impact_query)

airport_impact_fv = FeatureView(
    name="DISRUPTION_AIRPORT_IMPACT",
    entities=[disruption_entity],
    feature_df=airport_impact_df,
    timestamp_col="FEATURE_TIMESTAMP",
    refresh_freq="1 day",
    desc="Airport characteristics affecting disruption cost"
)

airport_features = fs.register_feature_view(
    feature_view=airport_impact_fv,
    version="v1",
    block=True
)

# The version string is already "v1"; don't prepend a second "v".
print(f"Registered: {airport_features.name} version {airport_features.version}")

## 3. Training Data

In [None]:
# Label spine: one row per disruption with a positive recorded cost, stamped at
# DISRUPTION_START for the point-in-time feature join.
# NOTE(review): TOTAL_COST falls back to ESTIMATED_COST_USD when ACTUAL_COST_USD is
# NULL, so part of the target is itself an estimate — confirm this is intended.
spine_query = f"""
SELECT 
    DISRUPTION_ID,
    DISRUPTION_START AS LABEL_TIMESTAMP,
    COALESCE(ACTUAL_COST_USD, ESTIMATED_COST_USD) AS TOTAL_COST
FROM {DATABASE}.RAW.DISRUPTIONS
WHERE COALESCE(ACTUAL_COST_USD, ESTIMATED_COST_USD) IS NOT NULL
  AND COALESCE(ACTUAL_COST_USD, ESTIMATED_COST_USD) > 0
"""

spine_df = session.sql(spine_query)
spine_row_count = spine_df.count()
print(f"Training spine rows: {spine_row_count}")

In [None]:
# Attach both feature views to the label spine via a point-in-time join on
# DISRUPTION_ID / LABEL_TIMESTAMP; feature-view timestamp columns are dropped.
feature_views_for_training = [disruption_features, airport_features]

training_data = fs.generate_training_set(
    spine_df=spine_df,
    features=feature_views_for_training,
    spine_timestamp_col="LABEL_TIMESTAMP",
    spine_label_cols=["TOTAL_COST"],
    include_feature_view_timestamp_col=False,
)

training_df = training_data.to_snowpark_dataframe()
training_row_count = training_df.count()
print(f"Training data with features: {training_row_count} rows")
training_df.limit(5).show()

## 4. Model Training - XGBoost Regressor

In [None]:
# Seeded 80/20 split of the joined training data.
train_df, test_df = training_df.random_split([0.8, 0.2], seed=42)

for split_label, split_df in (("Training set", train_df), ("Test set", test_df)):
    print(f"{split_label}: {split_df.count()} rows")

In [None]:
# Numeric features; the pipeline standardizes these into *_SCALED columns.
NUMERIC_COLS = [
    # from DISRUPTION_CHARACTERISTICS
    "SEVERITY_NUMERIC",
    "DURATION_MINUTES",
    "IMPACT_FLIGHTS_COUNT",
    "IMPACT_PASSENGERS_COUNT",
    "TYPE_COST_MULTIPLIER",
    "DISRUPTION_HOUR",
    "DISRUPTION_DAY_OF_WEEK",
    # from DISRUPTION_AIRPORT_IMPACT
    "HUB_IMPORTANCE_SCORE",
    "AFFECTED_GATES",
    "AFFECTED_DAILY_OPS",
    "HUB_COST_MULTIPLIER",
]

# Boolean flags fed to the regressor without scaling.
BOOLEAN_COLS = ["IS_PEAK_HOUR", "AFFECTED_IS_HUB"]

LABEL_COL = "TOTAL_COST"

# Raw (unscaled) model inputs, in the order used for the registry sample input.
FEATURE_COLS = [*NUMERIC_COLS, *BOOLEAN_COLS]

In [None]:
# Standardized copies of the numeric features go to *_SCALED columns; the
# regressor reads those plus the untouched boolean flags.
scaled_numeric_cols = [f"{col}_SCALED" for col in NUMERIC_COLS]

scaler = StandardScaler(input_cols=NUMERIC_COLS, output_cols=scaled_numeric_cols)

xgb_regressor = XGBRegressor(
    input_cols=scaled_numeric_cols + BOOLEAN_COLS,
    label_cols=[LABEL_COL],
    output_cols=["PREDICTED_COST"],
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    random_state=42,
)

pipeline = Pipeline(
    steps=[
        ("scaler", scaler),
        ("regressor", xgb_regressor),
    ]
)

print("Pipeline created with StandardScaler + XGBRegressor")

In [None]:
# Fit scaler then regressor, in pipeline order, on the training split only.
training_banner, done_banner = "Training cost estimation model...", "Model training complete!"
print(training_banner)
pipeline.fit(train_df)
print(done_banner)

## 5. Model Evaluation

In [None]:
# Score the held-out split and pull only the label and prediction columns locally.
predictions_df = pipeline.predict(test_df)

predictions_pd = predictions_df.select(
    LABEL_COL,
    "PREDICTED_COST"
).to_pandas()

# snowflake.ml.modeling.metrics functions are keyword-only and operate on a
# Snowpark DataFrame plus column names (df=..., y_true_col_names=...); they do
# not accept y_true=/y_pred= arrays. The data is already in pandas, so compute
# the metrics with numpy. to_numpy(dtype=float) also normalizes Decimal values
# that NUMBER columns may come back as.
y_true = predictions_pd[LABEL_COL].to_numpy(dtype=float)
y_pred = predictions_pd["PREDICTED_COST"].to_numpy(dtype=float)
if y_true.size == 0:
    raise ValueError("Test split is empty; cannot evaluate the cost estimation model.")

residuals = y_true - y_pred
mae = float(np.mean(np.abs(residuals)))
mse = float(np.mean(residuals ** 2))
rmse = float(np.sqrt(mse))

# R² = 1 - SS_res / SS_tot; undefined (NaN) when all actual costs are identical.
ss_res = float(np.sum(residuals ** 2))
ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")

# MAPE is undefined where the actual cost is zero. The spine already keeps only
# positive costs, but exclude zeros explicitly so a bad row cannot yield inf.
nonzero_mask = y_true != 0
if nonzero_mask.any():
    mape = float(np.mean(np.abs(residuals[nonzero_mask] / y_true[nonzero_mask])) * 100)
else:
    mape = float("nan")

metrics = {
    "mae": float(mae),
    "rmse": float(rmse),
    "r2": float(r2),
    "mape": float(mape)
}

print("Model Performance Metrics:")
print(f"  MAE:  ${mae:,.2f}")
print(f"  RMSE: ${rmse:,.2f}")
print(f"  R²:   {r2:.4f}")
print(f"  MAPE: {mape:.2f}%")

## 6. Model Registry

In [None]:
session.sql("CREATE SCHEMA IF NOT EXISTS ML_MODELS").collect()

registry = Registry(session=session, database_name=DATABASE, schema_name="ML_MODELS")

# A 100-row slice of the raw (unscaled) feature columns; the registry uses it
# as sample input for the logged pipeline.
sample_input = train_df.select(FEATURE_COLS).limit(100)

# NOTE(review): version_name is fixed at "V1", so re-running this cell against a
# registry that already holds V1 will presumably fail — confirm, and bump or
# parameterize the version for retraining runs.
model_version = registry.log_model(
    model=pipeline,
    model_name="COST_ESTIMATION_MODEL",
    version_name="V1",
    comment="XGBoost regressor for disruption cost estimation. Predicts total cost in USD.",
    sample_input_data=sample_input,
    metrics=metrics,
    conda_dependencies=["scikit-learn", "xgboost"],
)

registered_name, registered_version = model_version.model_name, model_version.version_name
print(f"Model registered: {registered_name} version {registered_version}")

## 7. Model Observability

In [None]:
# Three DDL statements build the monitoring stack, executed in dependency order:
# the predictions log table, a daily per-version performance view over it, and a
# 7-day drift-alert view on top of that.
# NOTE(review): nothing in this notebook inserts into COST_MODEL_PREDICTIONS_LOG,
# so both views stay empty until an external scoring job writes to it.
# NOTE(review): the 20/30 thresholds compare against MAPE as a percentage (the
# evaluation cell multiplies by 100), so writers presumably must store
# ABSOLUTE_PERCENTAGE_ERROR in percent units (25.0, not 0.25) — confirm.

observability_sql = f"""
CREATE TABLE IF NOT EXISTS {DATABASE}.ML_MODELS.COST_MODEL_PREDICTIONS_LOG (
    PREDICTION_ID VARCHAR(50) DEFAULT UUID_STRING(),
    PREDICTION_TIMESTAMP TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
    MODEL_VERSION VARCHAR(20),
    DISRUPTION_ID VARCHAR(50),
    PREDICTED_COST FLOAT,
    ACTUAL_COST FLOAT,
    PREDICTION_ERROR FLOAT,
    ABSOLUTE_PERCENTAGE_ERROR FLOAT
)
"""

monitor_sql = f"""
CREATE OR REPLACE VIEW {DATABASE}.ML_MODELS.COST_MODEL_PERFORMANCE_MONITOR AS
SELECT 
    DATE_TRUNC('day', PREDICTION_TIMESTAMP) AS PREDICTION_DATE,
    MODEL_VERSION,
    COUNT(*) AS PREDICTION_COUNT,
    AVG(PREDICTED_COST) AS AVG_PREDICTED_COST,
    AVG(ACTUAL_COST) AS AVG_ACTUAL_COST,
    AVG(ABS(PREDICTION_ERROR)) AS MAE,
    SQRT(AVG(POWER(PREDICTION_ERROR, 2))) AS RMSE,
    AVG(ABSOLUTE_PERCENTAGE_ERROR) AS MAPE,
    SUM(ACTUAL_COST) AS TOTAL_ACTUAL_COST,
    SUM(PREDICTED_COST) AS TOTAL_PREDICTED_COST
FROM {DATABASE}.ML_MODELS.COST_MODEL_PREDICTIONS_LOG
WHERE ACTUAL_COST IS NOT NULL
GROUP BY DATE_TRUNC('day', PREDICTION_TIMESTAMP), MODEL_VERSION
"""

drift_alert_sql = f"""
CREATE OR REPLACE VIEW {DATABASE}.ML_MODELS.COST_MODEL_DRIFT_ALERTS AS
SELECT 
    PREDICTION_DATE,
    MODEL_VERSION,
    MAE,
    MAPE,
    CASE 
        WHEN MAPE > 30 THEN 'CRITICAL'
        WHEN MAPE > 20 THEN 'WARNING'
        ELSE 'HEALTHY'
    END AS ALERT_STATUS,
    CASE 
        WHEN MAPE > 30 THEN 'MAPE exceeds 30% - model retraining recommended'
        WHEN MAPE > 20 THEN 'MAPE exceeds 20% - monitor closely'
        ELSE 'Cost predictions within acceptable range'
    END AS ALERT_MESSAGE
FROM {DATABASE}.ML_MODELS.COST_MODEL_PERFORMANCE_MONITOR
WHERE PREDICTION_DATE >= DATEADD('day', -7, CURRENT_DATE())
"""

# Order matters: each view reads from the object created by the previous statement.
for ddl_statement in (observability_sql, monitor_sql, drift_alert_sql):
    session.sql(ddl_statement).collect()

print("Cost model observability infrastructure created")

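## 8. Cost Prediction View

> **Note:** the `ML_PREDICTED_COST` column in this view is computed by a fixed rule-based formula over the engineered features, not by the registered `COST_ESTIMATION_MODEL`.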

In [None]:
# Creates (or replaces) the COST_PREDICTIONS view over disruptions whose
# RESOLUTION_STATUS is 'OPEN' or 'IN_PROGRESS'. The inner JOIN to RAW.AIRPORTS
# drops disruptions whose AFFECTED_AIRPORT has no matching airport row.
#
# NOTE(review): despite its name, ML_PREDICTED_COST is NOT produced by the
# COST_ESTIMATION_MODEL registered above. It is a hard-coded linear formula:
#   (SEVERITY_NUMERIC*5000 + DURATION_MINUTES*50 + IMPACT_FLIGHTS_COUNT*2000
#    + IMPACT_PASSENGERS_COUNT*75 + HUB_IMPORTANCE_SCORE*3000)
#   * TYPE_COST_MULTIPLIER * HUB_COST_MULTIPLIER
# A NULL in any of those inputs makes the result NULL (e.g. DURATION_MINUTES on a
# still-open disruption, if it is only filled in at resolution — TODO confirm).
# Consider scoring with the registered model (e.g. via the model version's run
# method) before downstream consumers treat this column as model output.
#
# The CASE expressions in active_disruptions duplicate the feature-view SQL from
# section 2 and must be kept in sync with it by hand.
cost_prediction_view_sql = f"""
CREATE OR REPLACE VIEW {DATABASE}.ML_MODELS.COST_PREDICTIONS AS
WITH active_disruptions AS (
    SELECT 
        d.DISRUPTION_ID,
        d.FLIGHT_ID,
        d.DISRUPTION_TYPE,
        d.SEVERITY,
        d.DURATION_MINUTES,
        d.IMPACT_FLIGHTS_COUNT,
        d.IMPACT_PASSENGERS_COUNT,
        d.AFFECTED_AIRPORT,
        d.ESTIMATED_COST_USD AS CURRENT_ESTIMATE,
        CASE d.SEVERITY 
            WHEN 'CRITICAL' THEN 4 
            WHEN 'SEVERE' THEN 3 
            WHEN 'MODERATE' THEN 2 
            ELSE 1 
        END AS SEVERITY_NUMERIC,
        CASE 
            WHEN d.DISRUPTION_TYPE = 'WEATHER' THEN 1.5
            WHEN d.DISRUPTION_TYPE = 'MECHANICAL' THEN 1.2
            WHEN d.DISRUPTION_TYPE = 'CREW' THEN 1.0
            ELSE 0.8
        END AS TYPE_COST_MULTIPLIER,
        HOUR(d.DISRUPTION_START) AS DISRUPTION_HOUR,
        DAYOFWEEK(d.DISRUPTION_START) AS DISRUPTION_DAY_OF_WEEK,
        CASE 
            WHEN HOUR(d.DISRUPTION_START) BETWEEN 6 AND 9 THEN TRUE
            WHEN HOUR(d.DISRUPTION_START) BETWEEN 16 AND 19 THEN TRUE
            ELSE FALSE
        END AS IS_PEAK_HOUR,
        a.IS_HUB AS AFFECTED_IS_HUB,
        CASE a.HUB_TYPE 
            WHEN 'PRIMARY' THEN 3
            WHEN 'SECONDARY' THEN 2
            WHEN 'FOCUS_CITY' THEN 1
            ELSE 0
        END AS HUB_IMPORTANCE_SCORE,
        a.GATES_COUNT AS AFFECTED_GATES,
        a.DAILY_OPERATIONS AS AFFECTED_DAILY_OPS,
        CASE 
            WHEN a.IS_HUB AND a.HUB_TYPE = 'PRIMARY' THEN 2.0
            WHEN a.IS_HUB THEN 1.5
            ELSE 1.0
        END AS HUB_COST_MULTIPLIER
    FROM {DATABASE}.RAW.DISRUPTIONS d
    JOIN {DATABASE}.RAW.AIRPORTS a ON d.AFFECTED_AIRPORT = a.AIRPORT_CODE
    WHERE d.RESOLUTION_STATUS IN ('OPEN', 'IN_PROGRESS')
)
SELECT 
    DISRUPTION_ID,
    FLIGHT_ID,
    DISRUPTION_TYPE,
    SEVERITY,
    IMPACT_FLIGHTS_COUNT,
    IMPACT_PASSENGERS_COUNT,
    AFFECTED_AIRPORT,
    CURRENT_ESTIMATE,
    (
        SEVERITY_NUMERIC * 5000 +
        DURATION_MINUTES * 50 +
        IMPACT_FLIGHTS_COUNT * 2000 +
        IMPACT_PASSENGERS_COUNT * 75 +
        HUB_IMPORTANCE_SCORE * 3000
    ) * TYPE_COST_MULTIPLIER * HUB_COST_MULTIPLIER AS ML_PREDICTED_COST
FROM active_disruptions
"""
# Executes the CREATE OR REPLACE VIEW statement.
session.sql(cost_prediction_view_sql).collect()

print(f"Created: {DATABASE}.ML_MODELS.COST_PREDICTIONS view")

## 9. Summary

In [None]:
# Human-readable recap of every artifact this notebook created.
rule = "=" * 60
summary_lines = [
    rule,
    "COST ESTIMATION MODEL - DEPLOYMENT COMPLETE",
    rule,
    "\nFeature Store Artifacts:",
    "  - Entity: DISRUPTION",
    "  - Feature View: DISRUPTION_CHARACTERISTICS v1",
    "  - Feature View: DISRUPTION_AIRPORT_IMPACT v1",
    "\nModel Registry:",
    "  - Model: COST_ESTIMATION_MODEL V1",
    "  - Algorithm: XGBoost Regressor",
    f"  - Features: {len(FEATURE_COLS)}",
    f"  - R²: {r2:.4f}",
    f"  - MAPE: {mape:.2f}%",
    "\nObservability:",
    f"  - Predictions Log: {DATABASE}.ML_MODELS.COST_MODEL_PREDICTIONS_LOG",
    f"  - Performance Monitor: {DATABASE}.ML_MODELS.COST_MODEL_PERFORMANCE_MONITOR",
    f"  - Drift Alerts: {DATABASE}.ML_MODELS.COST_MODEL_DRIFT_ALERTS",
    "\nInference:",
    f"  - View: {DATABASE}.ML_MODELS.COST_PREDICTIONS",
    rule,
]
print("\n".join(summary_lines))

In [None]:
# Release the Snowpark session; no further queries can run after this cell.
closing_message = "Session closed."
session.close()
print(closing_message)