# Kratos Defense ML Models - Model Registry

This notebook trains ML models for the Kratos Defense Intelligence Agent:
- **Program Risk Prediction** - Predict risk level for defense programs
- **Supplier Risk Prediction** - Identify suppliers at risk based on performance
- **Asset Maintenance Prediction** - Predict maintenance urgency for assets

All models are registered in the Snowflake Model Registry, from where they can be added as tools to the Intelligence Agent.

## Prerequisites

**Required Packages** (configured automatically; `xgboost` and `matplotlib` are installed but not used by the cells in this notebook):
- `snowflake-ml-python`
- `scikit-learn`
- `xgboost`
- `matplotlib`

**Database Context:**
- **Database:** KRATOS_INTELLIGENCE  
- **Schema:** ANALYTICS  
- **Warehouse:** KRATOS_WH

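**Note:** This notebook uses the Snowflake Model Registry. Your role needs `SELECT` on `RAW.PROGRAMS`, `RAW.SUPPLIERS`, and `RAW.ASSETS`, plus the privileges required to create and register models in `KRATOS_INTELLIGENCE.ANALYTICS` (for example, `CREATE MODEL` on the schema).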


In [None]:
# Import Python packages
# NOTE(review): pandas (pd), F, T, Window, LinearRegression, LogisticRegression,
# mean_squared_error, mean_absolute_error and roc_auc_score are not referenced by
# the cells below; they are kept in case other cells rely on them.
import pandas as pd
import warnings
# Silences every Python warning for the rest of the session, not only the
# Decimal-conversion warnings that the ::FLOAT casts in the SQL cells target.
# Placed before the Snowpark/ML imports so their import-time warnings are
# suppressed too -- do not move it below them.
# NOTE(review): consider narrowing with category=/module= so genuine
# deprecation warnings stay visible.
warnings.filterwarnings('ignore')

# Import Snowpark
from snowflake.snowpark.context import get_active_session
import snowflake.snowpark.functions as F
import snowflake.snowpark.types as T
from snowflake.snowpark import Window

# Import Snowpark ML
# Preprocessing, pipeline, estimators, metrics and the Model Registry client
# used by the training/registration cells.
from snowflake.ml.modeling.preprocessing import StandardScaler, OneHotEncoder
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.linear_model import LinearRegression, LogisticRegression
from snowflake.ml.modeling.ensemble import RandomForestClassifier
from snowflake.ml.modeling.metrics import mean_squared_error, mean_absolute_error, accuracy_score, roc_auc_score
from snowflake.ml.registry import Registry

print("✅ Packages imported successfully")


In [None]:
# Attach to the Snowflake session this notebook is running inside.
session = get_active_session()

# Pin the working context. These names MUST MATCH 01_database_and_schema.sql.
# Order matters: the schema is resolved inside the database selected first.
for context_setter, context_name in (
    (session.use_database, 'KRATOS_INTELLIGENCE'),
    (session.use_schema, 'ANALYTICS'),
    (session.use_warehouse, 'KRATOS_WH'),
):
    context_setter(context_name)

# Echo the effective role / warehouse / schema so a wrong context is obvious.
connection_report = [
    f"✅ Connected - Role: {session.get_current_role()}",
    f"   Warehouse: {session.get_current_warehouse()}",
    f"   Database.Schema: {session.get_fully_qualified_current_schema()}",
]
print("\n".join(connection_report))


## MODEL 1: Program Risk Prediction

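Predict program risk level based on budget, schedule, and milestone performance. Labels: `LOW` = 0, `MEDIUM` = 1, `HIGH` = 2; any other `risk_level` value (including NULL) is labeled 3.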


In [None]:
# Program-level features plus the risk label, read from RAW.PROGRAMS.
# Column names follow 02_create_tables.sql (PROGRAMS table). Every numeric
# column is cast to ::FLOAT so Snowpark does not return Decimal values.
# Ratio features use NULLIF(..., 0) to avoid division by zero and COALESCE to 0.
# Label encoding: LOW=0, MEDIUM=1, HIGH=2; the ELSE branch maps every other
# value -- including a NULL risk_level -- to 3.
# NOTE(review): confirm class 3 is meant to absorb NULL/unknown risk levels.
PROGRAM_RISK_SQL = """
SELECT
    p.program_id,
    p.budget_amount::FLOAT AS budget,
    p.spent_amount::FLOAT AS spent,
    p.budget_variance::FLOAT AS variance,
    p.schedule_variance_days::FLOAT AS schedule_variance,
    p.percent_complete::FLOAT AS completion_pct,
    p.milestone_count::FLOAT AS total_milestones,
    COALESCE((p.milestones_completed::FLOAT / NULLIF(p.milestone_count, 0) * 100), 0)::FLOAT AS milestone_pct,
    COALESCE((p.spent_amount::FLOAT / NULLIF(p.budget_amount, 0) * 100), 0)::FLOAT AS budget_utilization,
    p.program_type AS prog_type,
    CASE p.risk_level
        WHEN 'LOW' THEN 0
        WHEN 'MEDIUM' THEN 1
        WHEN 'HIGH' THEN 2
        ELSE 3
    END AS risk_label
FROM RAW.PROGRAMS p
WHERE p.program_status IN ('ACTIVE', 'COMPLETED', 'ON_HOLD')
"""
program_risk_df = session.sql(PROGRAM_RISK_SQL)

# Quick sanity check of volume and a preview of the feature rows.
program_row_total = program_risk_df.count()
print(f"Program risk data: {program_row_total} program records")
program_risk_df.show(5)


In [None]:
# Seeded 80/20 split; PROGRAM_ID is an identifier, not a feature, so it is
# removed from both halves before training/evaluation.
train_program, test_program = (
    part.drop("PROGRAM_ID")
    for part in program_risk_df.random_split([0.8, 0.2], seed=42)
)

print(f"Training set: {train_program.count()} rows")
print(f"Test set: {test_program.count()} rows")

# Numeric inputs to standardize; each one gets a "<NAME>_SCALED" output column.
PROGRAM_NUMERIC_FEATURES = [
    "BUDGET", "SPENT", "VARIANCE", "SCHEDULE_VARIANCE",
    "COMPLETION_PCT", "TOTAL_MILESTONES", "MILESTONE_PCT", "BUDGET_UTILIZATION",
]

# One-hot encode the program type (unseen categories are ignored at predict
# time), scale the numeric features, then fit a random forest on RISK_LABEL.
program_steps = [
    ("Encoder", OneHotEncoder(
        input_cols=["PROG_TYPE"],
        output_cols=["PROG_TYPE_ENC"],
        drop_input_cols=True,
        handle_unknown="ignore",
    )),
    ("Scaler", StandardScaler(
        input_cols=list(PROGRAM_NUMERIC_FEATURES),
        output_cols=[f"{feature}_SCALED" for feature in PROGRAM_NUMERIC_FEATURES],
    )),
    ("Classifier", RandomForestClassifier(
        label_cols=["RISK_LABEL"],
        output_cols=["PREDICTED_RISK"],
        n_estimators=100,
        max_depth=10,
    )),
]
program_pipeline = Pipeline(program_steps)

# Train model
program_pipeline.fit(train_program)
print("✅ Program risk prediction model trained")


In [None]:
# Score the held-out programs; test accuracy is stored as registry metadata.
program_predictions = program_pipeline.predict(test_program)

accuracy = accuracy_score(
    df=program_predictions,
    y_pred_col_names="PREDICTED_RISK",
    y_true_col_names="RISK_LABEL",
)
program_metrics = {"accuracy": round(accuracy, 4)}
print(f"Model metrics: {program_metrics}")

# `reg` is created here and reused by the supplier and asset cells below.
# NOTE(review): the version name is fixed at "V1"; re-running this cell
# presumably fails once that version exists -- bump the version or remove the
# old one first.
reg = Registry(session)
reg.log_model(
    model=program_pipeline,
    model_name="PROGRAM_RISK_PREDICTOR",
    version_name="V1",
    metrics=program_metrics,
    comment="Predicts program risk level using Random Forest based on budget, schedule, and milestone performance",
)

print("✅ Program risk model registered to Model Registry as PROGRAM_RISK_PREDICTOR")


## MODEL 2: Supplier Risk Prediction

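Predict supplier risk level based on quality and delivery performance, order volume, spend, and payment terms. Labels: `LOW` = 0, `MEDIUM` = 1, `HIGH` = 2; any other `risk_rating` value (including NULL) is labeled 3.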


In [None]:
# Supplier-level features plus the risk label, read from RAW.SUPPLIERS.
# Column names follow 02_create_tables.sql (SUPPLIERS table). Every numeric
# column is cast to ::FLOAT so Snowpark does not return Decimal values.
# Label encoding: LOW=0, MEDIUM=1, HIGH=2; the ELSE branch maps every other
# value -- including a NULL risk_rating -- to 3.
# NOTE(review): payment_terms is cast with ::FLOAT; confirm the column is numeric
# (e.g. days) rather than a text code, or the cast will fail.
SUPPLIER_RISK_SQL = """
SELECT
    s.supplier_id,
    s.quality_rating::FLOAT AS quality_score,
    s.delivery_rating::FLOAT AS delivery_score,
    ((s.quality_rating + s.delivery_rating) / 2)::FLOAT AS overall_rating,
    s.total_orders::FLOAT AS order_count,
    s.total_spend::FLOAT AS total_spend,
    COALESCE((s.total_spend::FLOAT / NULLIF(s.total_orders, 0)), 0)::FLOAT AS avg_order_value,
    s.payment_terms::FLOAT AS payment_terms,
    s.supplier_type AS sup_type,
    CASE s.risk_rating
        WHEN 'LOW' THEN 0
        WHEN 'MEDIUM' THEN 1
        WHEN 'HIGH' THEN 2
        ELSE 3
    END AS risk_label
FROM RAW.SUPPLIERS s
WHERE s.supplier_status IN ('ACTIVE', 'PREFERRED', 'PROBATION')
"""
supplier_risk_df = session.sql(SUPPLIER_RISK_SQL)

# Quick sanity check of volume and a preview of the feature rows.
supplier_row_total = supplier_risk_df.count()
print(f"Supplier risk data: {supplier_row_total} supplier records")
supplier_risk_df.show(5)


In [None]:
# Seeded 80/20 split; SUPPLIER_ID is an identifier, not a feature, so it is
# removed from both halves before training/evaluation.
train_supplier, test_supplier = (
    part.drop("SUPPLIER_ID")
    for part in supplier_risk_df.random_split([0.8, 0.2], seed=42)
)

print(f"Training set: {train_supplier.count()} rows")
print(f"Test set: {test_supplier.count()} rows")

# Numeric inputs to standardize; each one gets a "<NAME>_SCALED" output column.
SUPPLIER_NUMERIC_FEATURES = [
    "QUALITY_SCORE", "DELIVERY_SCORE", "OVERALL_RATING", "ORDER_COUNT",
    "TOTAL_SPEND", "AVG_ORDER_VALUE", "PAYMENT_TERMS",
]

# One-hot encode the supplier type (unseen categories are ignored at predict
# time), scale the numeric features, then fit a random forest on RISK_LABEL.
supplier_steps = [
    ("Encoder", OneHotEncoder(
        input_cols=["SUP_TYPE"],
        output_cols=["SUP_TYPE_ENC"],
        drop_input_cols=True,
        handle_unknown="ignore",
    )),
    ("Scaler", StandardScaler(
        input_cols=list(SUPPLIER_NUMERIC_FEATURES),
        output_cols=[f"{feature}_SCALED" for feature in SUPPLIER_NUMERIC_FEATURES],
    )),
    ("Classifier", RandomForestClassifier(
        label_cols=["RISK_LABEL"],
        output_cols=["PREDICTED_RISK"],
        n_estimators=100,
        max_depth=10,
    )),
]
supplier_pipeline = Pipeline(supplier_steps)

# Train model
supplier_pipeline.fit(train_supplier)
print("✅ Supplier risk prediction model trained")


In [None]:
# Score the held-out suppliers; test accuracy is stored as registry metadata.
supplier_predictions = supplier_pipeline.predict(test_supplier)

accuracy = accuracy_score(
    df=supplier_predictions,
    y_pred_col_names="PREDICTED_RISK",
    y_true_col_names="RISK_LABEL",
)
supplier_metrics = {"accuracy": round(accuracy, 4)}
print(f"Model metrics: {supplier_metrics}")

# Reuses the `reg` Registry handle created in the program-model cell.
# NOTE(review): fixed version name "V1" -- re-running presumably fails once
# that version already exists.
reg.log_model(
    model=supplier_pipeline,
    model_name="SUPPLIER_RISK_PREDICTOR",
    version_name="V1",
    metrics=supplier_metrics,
    comment="Predicts supplier risk level using Random Forest based on quality and delivery performance",
)

print("✅ Supplier risk model registered to Model Registry as SUPPLIER_RISK_PREDICTOR")


## MODEL 3: Asset Maintenance Prediction

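Predict maintenance urgency for assets based on usage, condition, and maintenance history. Labels: 0 = on schedule, 1 = due within 14 days, 2 = overdue. The label is computed from `next_maintenance_due`, and the model also receives that date as the `DAYS_UNTIL_DUE` feature, so the reported test accuracy is likely inflated.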


In [None]:
# Asset-level usage/condition features plus the urgency label, read from
# RAW.ASSETS. Column names follow 02_create_tables.sql (ASSETS table). Every
# numeric column is cast to ::FLOAT so Snowpark does not return Decimal values.
# Urgency label: 2 = overdue (due date already passed), 1 = due within 14 days,
# 0 = on schedule. Both date-based columns depend on CURRENT_DATE(), so the
# data (and labels) change from day to day.
# NOTE(review): urgency_label is a deterministic function of days_until_due,
# which is also passed to the model as a feature -- this is target leakage, so
# test accuracy will overstate real predictive power. Dropping DAYS_UNTIL_DUE
# would change the registered model's input signature; coordinate with the
# wrapper SQL before changing it.
ASSET_MAINTENANCE_SQL = """
SELECT
    a.asset_id,
    a.total_flight_hours::FLOAT AS flight_hours,
    a.maintenance_interval_hours::FLOAT AS maint_interval,
    COALESCE((a.total_flight_hours::FLOAT / NULLIF(a.maintenance_interval_hours, 0) * 100), 0)::FLOAT AS utilization_pct,
    DATEDIFF('day', a.last_maintenance_date, CURRENT_DATE())::FLOAT AS days_since_maintenance,
    DATEDIFF('day', CURRENT_DATE(), a.next_maintenance_due)::FLOAT AS days_until_due,
    CASE a.condition_rating
        WHEN 'EXCELLENT' THEN 4
        WHEN 'GOOD' THEN 3
        WHEN 'FAIR' THEN 2
        ELSE 1
    END::FLOAT AS condition_score,
    CASE WHEN a.mission_ready = TRUE THEN 1 ELSE 0 END::FLOAT AS is_ready,
    a.asset_type AS ast_type,
    -- Urgency label based on maintenance due date
    CASE 
        WHEN DATEDIFF('day', CURRENT_DATE(), a.next_maintenance_due) < 0 THEN 2  -- OVERDUE
        WHEN DATEDIFF('day', CURRENT_DATE(), a.next_maintenance_due) <= 14 THEN 1  -- DUE_SOON
        ELSE 0  -- ON_SCHEDULE
    END AS urgency_label
FROM RAW.ASSETS a
WHERE a.asset_status IN ('OPERATIONAL', 'MAINTENANCE', 'STANDBY')
  AND a.next_maintenance_due IS NOT NULL
  AND a.last_maintenance_date IS NOT NULL
"""
asset_maintenance_df = session.sql(ASSET_MAINTENANCE_SQL)

# Quick sanity check of volume and a preview of the feature rows.
asset_row_total = asset_maintenance_df.count()
print(f"Asset maintenance data: {asset_row_total} asset records")
asset_maintenance_df.show(5)


In [None]:
# Seeded 80/20 split; ASSET_ID is an identifier, not a feature, so it is
# removed from both halves before training/evaluation.
train_asset, test_asset = (
    part.drop("ASSET_ID")
    for part in asset_maintenance_df.random_split([0.8, 0.2], seed=42)
)

print(f"Training set: {train_asset.count()} rows")
print(f"Test set: {test_asset.count()} rows")

# Numeric inputs to standardize; each one gets a "<NAME>_SCALED" output column.
# NOTE(review): DAYS_UNTIL_DUE alone determines URGENCY_LABEL (see the query
# cell), so the classifier can learn the label directly from this feature.
ASSET_NUMERIC_FEATURES = [
    "FLIGHT_HOURS", "MAINT_INTERVAL", "UTILIZATION_PCT", "DAYS_SINCE_MAINTENANCE",
    "DAYS_UNTIL_DUE", "CONDITION_SCORE", "IS_READY",
]

# One-hot encode the asset type (unseen categories are ignored at predict
# time), scale the numeric features, then fit a random forest on URGENCY_LABEL.
asset_steps = [
    ("Encoder", OneHotEncoder(
        input_cols=["AST_TYPE"],
        output_cols=["AST_TYPE_ENC"],
        drop_input_cols=True,
        handle_unknown="ignore",
    )),
    ("Scaler", StandardScaler(
        input_cols=list(ASSET_NUMERIC_FEATURES),
        output_cols=[f"{feature}_SCALED" for feature in ASSET_NUMERIC_FEATURES],
    )),
    ("Classifier", RandomForestClassifier(
        label_cols=["URGENCY_LABEL"],
        output_cols=["PREDICTED_URGENCY"],
        n_estimators=100,
        max_depth=10,
    )),
]
asset_pipeline = Pipeline(asset_steps)

# Train model
asset_pipeline.fit(train_asset)
print("✅ Asset maintenance prediction model trained")


In [None]:
# Score the held-out assets; test accuracy is stored as registry metadata.
# NOTE(review): because URGENCY_LABEL is derived from the DAYS_UNTIL_DUE
# feature, this accuracy is likely optimistic.
asset_predictions = asset_pipeline.predict(test_asset)

accuracy = accuracy_score(
    df=asset_predictions,
    y_pred_col_names="PREDICTED_URGENCY",
    y_true_col_names="URGENCY_LABEL",
)
asset_metrics = {"accuracy": round(accuracy, 4)}
print(f"Model metrics: {asset_metrics}")

# Reuses the `reg` Registry handle created in the program-model cell.
# NOTE(review): fixed version name "V1" -- re-running presumably fails once
# that version already exists.
reg.log_model(
    model=asset_pipeline,
    model_name="ASSET_MAINTENANCE_PREDICTOR",
    version_name="V1",
    metrics=asset_metrics,
    comment="Predicts asset maintenance urgency using Random Forest based on usage and condition",
)

print("✅ Asset maintenance model registered to Model Registry as ASSET_MAINTENANCE_PREDICTOR")


## Summary

All 3 models have been trained and registered in the Snowflake Model Registry:

| Model | Type | Output Column | Purpose |
|-------|------|---------------|---------|
| `PROGRAM_RISK_PREDICTOR` | RandomForestClassifier | `PREDICTED_RISK` | Predicts program risk level |
| `SUPPLIER_RISK_PREDICTOR` | RandomForestClassifier | `PREDICTED_RISK` | Predicts supplier risk level |
| `ASSET_MAINTENANCE_PREDICTOR` | RandomForestClassifier | `PREDICTED_URGENCY` | Predicts maintenance urgency |

**Next Steps:**
1. Run `sql/ml/07_create_model_wrapper_functions.sql` to create stored procedures that wrap these models
2. Run `sql/agent/08_create_intelligence_agent.sql` to create the Intelligence Agent with ML tools
3. Test the agent in Snowsight under AI & ML > Agents


In [None]:
# Verify all models are registered.
# Prints the registry listing, then checks that every model this notebook logs
# is actually present. The success banner is printed only when none are
# missing; otherwise the cell raises instead of reporting a false success.
print("=" * 60)
print("VERIFICATION: Models registered in Model Registry")
print("=" * 60)

# List all models in registry (tabular view for the reader).
models = reg.show_models()
print(models)

# Names passed as model_name= to reg.log_model(...) in the cells above.
expected_models = [
    "PROGRAM_RISK_PREDICTOR",
    "SUPPLIER_RISK_PREDICTOR",
    "ASSET_MAINTENANCE_PREDICTOR",
]

# Normalize identifiers (strip any surrounding double quotes, compare
# upper-case) so the check does not depend on how names are rendered.
registered_names = {ref.name.strip('"').upper() for ref in reg.models()}
missing_models = [name for name in expected_models if name not in registered_names]
if missing_models:
    raise RuntimeError(
        "Models missing from Model Registry: " + ", ".join(missing_models)
    )

print(f"\n✅ All {len(expected_models)} models registered successfully!")
for name in expected_models:
    print(f"   - {name}")
