# Kratos Defense ML Models - Model Registry

This notebook trains ML models for the Kratos Defense Intelligence Agent:
- **Program Risk Prediction** - Predict likelihood of program schedule/cost issues
- **Supplier Risk Prediction** - Identify suppliers at risk of performance issues
- **Production Forecaster** - Forecast production throughput and capacity

All models are registered in the Snowflake Model Registry and can be added as tools to the Intelligence Agent.

## Prerequisites

**Required Packages** (configured automatically):
- `snowflake-ml-python`
- `scikit-learn`
- `xgboost`
- `matplotlib`

**Database Context:**
- **Database:** KRATOS_INTELLIGENCE  
- **Schema:** ANALYTICS  
- **Warehouse:** KRATOS_WH

**Note:** This notebook uses the Snowflake Model Registry. Make sure your current role has permission to create and register models in `KRATOS_INTELLIGENCE.ANALYTICS`.


In [None]:
# Import Python packages
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Import Snowpark
from snowflake.snowpark.context import get_active_session
import snowflake.snowpark.functions as F
import snowflake.snowpark.types as T
from snowflake.snowpark import Window

# Import Snowpark ML
from snowflake.ml.modeling.preprocessing import StandardScaler, OneHotEncoder
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.linear_model import LinearRegression, LogisticRegression
from snowflake.ml.modeling.ensemble import RandomForestClassifier
from snowflake.ml.modeling.metrics import mean_squared_error, mean_absolute_error, accuracy_score, roc_auc_score
from snowflake.ml.registry import Registry

print("✅ Packages imported successfully")


In [None]:
# Get active Snowflake session
session = get_active_session()

# Set context - MUST MATCH 01_database_and_schema.sql
session.use_database('KRATOS_INTELLIGENCE')
session.use_schema('ANALYTICS')
session.use_warehouse('KRATOS_WH')

print(f"✅ Connected - Role: {session.get_current_role()}")
print(f"   Warehouse: {session.get_current_warehouse()}")
print(f"   Database.Schema: {session.get_fully_qualified_current_schema()}")


## MODEL 1: Program Risk Prediction

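Predict whether a program is at risk of schedule delays or cost overruns from its duration, contract and funded value, cost and schedule variance, technology readiness level (TRL), program type, and classification level.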
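
The `IS_AT_RISK` label in the next cell is not observed directly; a `CASE` expression derives it. Below is a minimal pure-Python mirror of that rule and of the `PROGRAM_DURATION_DAYS` feature. It assumes `None` stands in for SQL `NULL`, and the function names are illustrative, not part of the notebook. Mirroring the rule like this makes the thresholds easy to unit-test:

```python
from datetime import date
from typing import Optional


def program_duration_days(start: date, planned_end: Optional[date], today: date) -> float:
    """Days from start to planned end, falling back to today (mirrors DATEDIFF + COALESCE)."""
    end = planned_end if planned_end is not None else today
    return float((end - start).days)


def program_is_at_risk(
    risk_level: Optional[str],
    cost_variance_pct: Optional[float],
    schedule_variance_pct: Optional[float],
) -> int:
    """Rule-based IS_AT_RISK label; None variances are treated as 0, as COALESCE does."""
    cost_variance = cost_variance_pct if cost_variance_pct is not None else 0.0
    schedule_variance = schedule_variance_pct if schedule_variance_pct is not None else 0.0
    if risk_level in ("HIGH", "CRITICAL"):
        return 1
    if cost_variance < -10:
        return 1
    if schedule_variance < -15:
        return 1
    return 0


print(program_is_at_risk("LOW", -12.0, 0.0))     # 1: cost variance below -10%
print(program_is_at_risk("MEDIUM", None, -5.0))  # 0: within both thresholds
print(program_duration_days(date(2024, 1, 1), None, date(2024, 3, 1)))  # 60.0
```

The `LOW` and `MEDIUM` risk levels above are placeholder values. `COST_VARIANCE` and `SCHEDULE_VARIANCE` are also model inputs, so the classifier can reproduce much of the label directly from its features. The reported accuracy therefore likely overstates how predictable program risk is from the remaining columns.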


In [None]:
# Get program risk data with features
# COLUMN NAMES VERIFIED against 02_create_tables.sql
program_risk_df = session.sql("""
SELECT
    p.program_id,
    DATEDIFF('day', p.start_date, COALESCE(p.planned_end_date, CURRENT_DATE()))::FLOAT AS program_duration_days,
    p.total_contract_value::FLOAT AS contract_value,
    p.funded_value::FLOAT AS funded_value,
    COALESCE(p.cost_variance_pct, 0)::FLOAT AS cost_variance,
    COALESCE(p.schedule_variance_pct, 0)::FLOAT AS schedule_variance,
    COALESCE(p.technology_readiness_level, 5)::FLOAT AS trl,
    p.program_type AS program_type,
    p.classification_level AS classification,
    -- Risk label: risk_level HIGH/CRITICAL, cost variance below -10%, or schedule variance below -15%
    CASE 
        WHEN p.risk_level IN ('HIGH', 'CRITICAL') THEN 1
        WHEN COALESCE(p.cost_variance_pct, 0) < -10 THEN 1
        WHEN COALESCE(p.schedule_variance_pct, 0) < -15 THEN 1
        ELSE 0
    END AS is_at_risk
FROM RAW.PROGRAMS p
WHERE p.program_status IN ('ACTIVE', 'COMPLETED')
  AND p.start_date IS NOT NULL
LIMIT 10000
""")

print(f"Program risk data: {program_risk_df.count()} program records")
program_risk_df.show(5)
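

The next cell holds out about 20% of rows with `random_split([0.8, 0.2], seed=42)`. As a rough mental model, and not Snowpark's actual implementation, here is a seeded split that assigns each row independently, written in plain Python. With per-row assignment, the proportions come out close to 80/20 but not exactly:

```python
import random
from typing import List, Sequence, TypeVar

RowT = TypeVar("RowT")


def seeded_split(rows: Sequence[RowT], weights: Sequence[float], seed: int) -> List[List[RowT]]:
    """Assign each row independently to a split, with probability proportional to its weight."""
    total = float(sum(weights))
    # Cumulative upper bounds, e.g. [0.8, 1.0] for weights [0.8, 0.2]
    bounds: List[float] = []
    running = 0.0
    for weight in weights:
        running += weight / total
        bounds.append(running)
    rng = random.Random(seed)  # fixed seed -> same assignment on every run
    splits: List[List[RowT]] = [[] for _ in weights]
    for row in rows:
        draw = rng.random()
        for index, bound in enumerate(bounds):
            if draw < bound:
                splits[index].append(row)
                break
        else:
            splits[-1].append(row)  # guards against floating-point rounding in the last bound
    return splits


train, test = seeded_split(list(range(1000)), [0.8, 0.2], seed=42)
print(len(train), len(test))  # close to 800 and 200, but not exactly
```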


In [None]:
# Train/test split (80/20)
train_prog, test_prog = program_risk_df.random_split([0.8, 0.2], seed=42)
train_prog = train_prog.drop("PROGRAM_ID")
test_prog = test_prog.drop("PROGRAM_ID")

# Create pipeline with preprocessing and classification
program_pipeline = Pipeline([
    ("Encoder", OneHotEncoder(
        input_cols=["PROGRAM_TYPE", "CLASSIFICATION"],
        output_cols=["PROGRAM_TYPE_ENC", "CLASSIFICATION_ENC"],
        drop_input_cols=True,
        handle_unknown="ignore"
    )),
    ("Scaler", StandardScaler(
        input_cols=["PROGRAM_DURATION_DAYS", "CONTRACT_VALUE", "FUNDED_VALUE", "COST_VARIANCE", "SCHEDULE_VARIANCE", "TRL"],
        output_cols=["DURATION_SCALED", "CONTRACT_SCALED", "FUNDED_SCALED", "CV_SCALED", "SV_SCALED", "TRL_SCALED"]
    )),
    # No input_cols: the classifier uses every non-label column, including the unscaled originals
    ("Classifier", RandomForestClassifier(
        label_cols=["IS_AT_RISK"],
        output_cols=["PREDICTED_RISK"],
        n_estimators=100,
        max_depth=10
    ))
])

# Train model
program_pipeline.fit(train_prog)
print("✅ Program risk prediction model trained")
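

Snowflake ML's `StandardScaler` and `OneHotEncoder` are modeled on their scikit-learn counterparts. This sketch assumes they follow scikit-learn's conventions: population standard deviation, a scale of 1 for constant columns, and all-zero encodings for unseen categories under `handle_unknown="ignore"`. Under those assumptions, the two preprocessing steps reduce to the arithmetic below. The category names are hypothetical:

```python
from statistics import fmean, pstdev
from typing import Dict, List, Sequence, Tuple


def fit_standard_scaler(values: Sequence[float]) -> Tuple[float, float]:
    """Learn the mean and the population standard deviation (ddof=0) from training data."""
    mean = fmean(values)
    std = pstdev(values)
    # A constant column would divide by zero; scikit-learn substitutes a scale of 1.0.
    return mean, (std if std > 0 else 1.0)


def standardize(values: Sequence[float], mean: float, scale: float) -> List[float]:
    """Apply z = (x - mean) / scale using statistics learned on the training split."""
    return [(value - mean) / scale for value in values]


def fit_one_hot(categories: Sequence[str]) -> List[str]:
    """Learn the sorted set of categories seen during training."""
    return sorted(set(categories))


def one_hot(value: str, known: List[str]) -> Dict[str, int]:
    """Encode one value; with handle_unknown='ignore' an unseen value maps to all zeros."""
    return {category: int(value == category) for category in known}


mean, scale = fit_standard_scaler([2.0, 4.0, 6.0])
print(standardize([2.0, 4.0, 6.0], mean, scale))  # symmetric around 0, unit variance
known = fit_one_hot(["DEVELOPMENT", "PRODUCTION", "DEVELOPMENT"])
print(one_hot("SUSTAINMENT", known))  # {'DEVELOPMENT': 0, 'PRODUCTION': 0}
```

Random forests split on feature thresholds and are unaffected by this kind of monotonic rescaling. The scaler therefore changes little for Models 1 and 2, but it matters more for the linear regressor in Model 3.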


In [None]:
# Evaluate and register program risk model
prog_predictions = program_pipeline.predict(test_prog)

accuracy = accuracy_score(
    df=prog_predictions, 
    y_true_col_names="IS_AT_RISK", 
    y_pred_col_names="PREDICTED_RISK"
)

prog_metrics = {"accuracy": round(accuracy, 4)}
print(f"Model metrics: {prog_metrics}")

# Register model
reg = Registry(session)
reg.log_model(
    model=program_pipeline,
    model_name="PROGRAM_RISK_PREDICTOR",
    version_name="V1",
    comment="Predicts likelihood of program schedule/cost risk using Random Forest based on program metrics",
    metrics=prog_metrics
)

print("✅ Program risk model registered to Model Registry as PROGRAM_RISK_PREDICTOR")
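

`accuracy_score` reports the fraction of test rows whose `PREDICTED_RISK` matches `IS_AT_RISK`. When at-risk programs are rare, a model that never flags risk can still score high. Comparing against the majority-class rate is therefore a cheap sanity check. Here is a plain-Python sketch with a hypothetical test set in which 10% of programs are at risk:

```python
from collections import Counter
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of rows where the predicted label equals the true label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def majority_baseline(y_true: Sequence[int]) -> float:
    """Accuracy of a model that always predicts the most common class."""
    most_common_count = Counter(y_true).most_common(1)[0][1]
    return most_common_count / len(y_true)


labels = [0] * 9 + [1]   # hypothetical test set: 1 of 10 programs at risk
always_safe = [0] * 10   # a "model" that never flags risk
print(accuracy(labels, always_safe))  # 0.9
print(majority_baseline(labels))      # 0.9
```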


## MODEL 2: Supplier Risk Prediction

Identify suppliers at risk of quality or delivery performance issues.

**CRITICAL:** Column names and data types MUST match `07_create_model_wrapper_functions.sql` EXACTLY.
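
Suppliers with no `SUPPLIER_PERFORMANCE` rows still appear after the `LEFT JOIN`, and their aggregates fall back to the `COALESCE` defaults. The sketch below mirrors, in plain Python, the four aggregates that drive the supplier `IS_AT_RISK` label. It assumes each performance row is a dict and that `None` stands in for `NULL`. The helper names and sample rows are hypothetical:

```python
from typing import Dict, List, Optional

PerfRow = Dict[str, Optional[float]]


def _avg(values: List[Optional[float]], default: float) -> float:
    """AVG ignores NULLs; COALESCE supplies the default when nothing is left."""
    present = [v for v in values if v is not None]
    return sum(present) / len(present) if present else default


def _sum(values: List[Optional[float]]) -> float:
    """SUM ignores NULLs; COALESCE turns an all-NULL sum into 0."""
    present = [v for v in values if v is not None]
    return float(sum(present)) if present else 0.0


def supplier_features(perf_rows: List[PerfRow]) -> Dict[str, float]:
    """Aggregate one supplier's performance rows (empty list = no match in the LEFT JOIN)."""
    return {
        "avg_quality_score": _avg([r["quality_score"] for r in perf_rows], 75.0),
        "avg_otd_pct": _avg([r["on_time_delivery_pct"] for r in perf_rows], 85.0),
        "avg_defect_rate": _avg([r["defect_rate_pct"] for r in perf_rows], 2.0),
        "cars_issued": _sum([r["corrective_actions_issued"] for r in perf_rows]),
    }


def supplier_is_at_risk(f: Dict[str, float]) -> int:
    """Mirror of the CASE expression that defines the supplier IS_AT_RISK label."""
    return int(
        f["avg_quality_score"] < 70
        or f["avg_otd_pct"] < 80
        or f["avg_defect_rate"] > 5
        or f["cars_issued"] > 3
    )


no_history = supplier_features([])
print(supplier_is_at_risk(no_history))  # 0
late = supplier_features([
    {"quality_score": 90, "on_time_delivery_pct": 70, "defect_rate_pct": 1, "corrective_actions_issued": 0},
    {"quality_score": 88, "on_time_delivery_pct": 80, "defect_rate_pct": None, "corrective_actions_issued": 1},
])
print(supplier_is_at_risk(late))  # 1: average on-time delivery is 75%
```

The defaults (quality 75, on-time delivery 85%, defect rate 2%, zero corrective actions) sit on the safe side of every threshold. As a result, a supplier without performance history is always labeled 0. The label is also computed entirely from columns the classifier receives as features, which will tend to inflate test accuracy.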


In [None]:
# Get supplier risk data with features
supplier_risk_df = session.sql("""
WITH supplier_metrics AS (
    SELECT
        s.supplier_id,
        s.supplier_tier::FLOAT AS supplier_tier,
        COALESCE(s.quality_rating, 75)::FLOAT AS quality_rating,
        COALESCE(s.delivery_rating, 75)::FLOAT AS delivery_rating,
        COALESCE(s.overall_rating, 75)::FLOAT AS overall_rating,
        s.supplier_type,
        CASE WHEN s.is_small_business = TRUE THEN 1 ELSE 0 END AS is_small_business,
        COALESCE(AVG(sp.quality_score), 75)::FLOAT AS avg_quality_score,
        COALESCE(AVG(sp.on_time_delivery_pct), 85)::FLOAT AS avg_otd_pct,
        COALESCE(AVG(sp.defect_rate_pct), 2)::FLOAT AS avg_defect_rate,
        COALESCE(SUM(sp.total_orders), 0)::FLOAT AS total_orders,
        COALESCE(SUM(sp.late_orders), 0)::FLOAT AS late_orders,
        COALESCE(SUM(sp.corrective_actions_issued), 0)::FLOAT AS cars_issued
    FROM RAW.SUPPLIERS s
    LEFT JOIN RAW.SUPPLIER_PERFORMANCE sp ON s.supplier_id = sp.supplier_id
    WHERE s.is_active = TRUE
    GROUP BY s.supplier_id, s.supplier_tier, s.quality_rating, s.delivery_rating, 
             s.overall_rating, s.supplier_type, s.is_small_business
)
SELECT
    supplier_id,
    supplier_tier,
    quality_rating,
    delivery_rating,
    overall_rating,
    supplier_type,
    is_small_business,
    avg_quality_score,
    avg_otd_pct,
    avg_defect_rate,
    total_orders,
    late_orders,
    cars_issued,
    CASE 
        WHEN avg_quality_score < 70 THEN 1
        WHEN avg_otd_pct < 80 THEN 1
        WHEN avg_defect_rate > 5 THEN 1
        WHEN cars_issued > 3 THEN 1
        ELSE 0
    END AS is_at_risk
FROM supplier_metrics
LIMIT 5000
""")

print(f"Supplier risk data: {supplier_risk_df.count()} supplier records")
supplier_risk_df.show(5)


In [None]:
# Train supplier risk model
train_sup, test_sup = supplier_risk_df.random_split([0.8, 0.2], seed=42)
train_sup = train_sup.drop("SUPPLIER_ID")
test_sup = test_sup.drop("SUPPLIER_ID")

supplier_pipeline = Pipeline([
    ("Encoder", OneHotEncoder(
        input_cols=["SUPPLIER_TYPE"],
        output_cols=["SUPPLIER_TYPE_ENC"],
        drop_input_cols=True,
        handle_unknown="ignore"
    )),
    ("Scaler", StandardScaler(
        input_cols=["SUPPLIER_TIER", "QUALITY_RATING", "DELIVERY_RATING", "OVERALL_RATING", 
                    "AVG_QUALITY_SCORE", "AVG_OTD_PCT", "AVG_DEFECT_RATE", "TOTAL_ORDERS", "LATE_ORDERS", "CARS_ISSUED"],
        output_cols=["TIER_SCALED", "QUAL_SCALED", "DEL_SCALED", "OVERALL_SCALED",
                     "AVG_QUAL_SCALED", "OTD_SCALED", "DEFECT_SCALED", "ORDERS_SCALED", "LATE_SCALED", "CARS_SCALED"]
    )),
    # No input_cols: the classifier uses every non-label column, including the unscaled originals
    ("Classifier", RandomForestClassifier(
        label_cols=["IS_AT_RISK"],
        output_cols=["PREDICTED_RISK"],
        n_estimators=100,
        max_depth=10
    ))
])

supplier_pipeline.fit(train_sup)
print("✅ Supplier risk prediction model trained")

# Evaluate and register
sup_predictions = supplier_pipeline.predict(test_sup)
accuracy = accuracy_score(df=sup_predictions, y_true_col_names="IS_AT_RISK", y_pred_col_names="PREDICTED_RISK")
sup_metrics = {"accuracy": round(accuracy, 4)}
print(f"Model metrics: {sup_metrics}")

reg.log_model(
    model=supplier_pipeline,
    model_name="SUPPLIER_RISK_PREDICTOR",
    version_name="V1",
    comment="Predicts supplier performance risk using Random Forest based on quality and delivery metrics",
    metrics=sup_metrics
)
print("✅ Supplier risk model registered to Model Registry as SUPPLIER_RISK_PREDICTOR")
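

Each `log_model` call uses a fixed `version_name="V1"`. To our understanding, logging a version name that already exists for the model raises an error, so re-running this notebook unchanged will fail at registration. One option is to compute the next free name. The hypothetical helper below does that, assuming versions follow a `V<n>` pattern. The existing names would have to be read from the registry first; check the Snowflake ML `Registry` API reference for the exact call:

```python
import re
from typing import Iterable


def next_version_name(existing: Iterable[str], prefix: str = "V") -> str:
    """Hypothetical helper: return 'V<n+1>' where n is the highest existing 'V<n>' number."""
    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)$", re.IGNORECASE)
    numbers = [int(m.group(1)) for name in existing if (m := pattern.match(name))]
    return f"{prefix}{max(numbers, default=0) + 1}"


print(next_version_name([]))                        # V1
print(next_version_name(["V1", "V2", "BASELINE"]))  # V3 (non-matching names are ignored)
```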


## MODEL 3: Production Forecaster

Forecast monthly production throughput (distinct work orders per month) from monthly work-order aggregates covering the last three years.
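
In the next cell, `work_order_count` and `production_count` are both computed as `COUNT(DISTINCT wo.work_order_id)` per calendar month. Here is a small sketch of that grouping in plain Python, using made-up work-order IDs and dates:

```python
from collections import defaultdict
from datetime import date
from typing import Dict, List, Set, Tuple

MonthKey = Tuple[int, int]  # (year, month)


def monthly_work_order_counts(orders: List[Tuple[str, date]]) -> Dict[MonthKey, int]:
    """Count distinct work orders per (year, month) of planned_start_date, sorted by month."""
    ids_by_month: Dict[MonthKey, Set[str]] = defaultdict(set)
    for work_order_id, planned_start in orders:
        ids_by_month[(planned_start.year, planned_start.month)].add(work_order_id)
    return {key: len(ids) for key, ids in sorted(ids_by_month.items())}


orders = [
    ("WO-1", date(2024, 1, 5)),
    ("WO-2", date(2024, 1, 20)),
    ("WO-2", date(2024, 1, 20)),  # duplicate row is counted once, as with COUNT(DISTINCT ...)
    ("WO-3", date(2024, 2, 2)),
]
print(monthly_work_order_counts(orders))  # {(2024, 1): 2, (2024, 2): 1}
```

The two columns are identical, so `WO_SCALED` is a linear rescaling of the target, and the linear regressor can fit `PRODUCTION_COUNT` essentially exactly. The resulting MSE and MAE therefore say little about forecasting ability. A forecast would need a target that is not itself an input, such as the following month's count.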


In [None]:
# Get monthly production data for forecasting
production_forecast_df = session.sql("""
SELECT
    MONTH(wo.planned_start_date) AS month_num,
    YEAR(wo.planned_start_date) AS year_num,
    COUNT(DISTINCT wo.work_order_id)::FLOAT AS work_order_count,
    SUM(wo.quantity_ordered)::FLOAT AS total_quantity_ordered,
    SUM(wo.quantity_completed)::FLOAT AS total_quantity_completed,
    COALESCE(AVG(wo.estimated_hours), 100)::FLOAT AS avg_estimated_hours,
    COALESCE(AVG(wo.actual_hours), 100)::FLOAT AS avg_actual_hours,
    COALESCE(SUM(wo.estimated_cost), 1000000)::FLOAT AS total_estimated_cost,
    COUNT(DISTINCT wo.work_order_id)::FLOAT AS production_count  -- target: same expression as work_order_count
FROM RAW.PRODUCTION_ORDERS wo
WHERE wo.planned_start_date >= DATEADD('year', -3, CURRENT_DATE())
  AND wo.work_order_status IN ('COMPLETED', 'IN_PROGRESS')
GROUP BY MONTH(wo.planned_start_date), YEAR(wo.planned_start_date)
ORDER BY year_num, month_num
""")

print(f"Production forecast data: {production_forecast_df.count()} monthly records")
production_forecast_df.show(10)


In [None]:
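# Train production forecaster model
# Note: random_split mixes months from every year into both splits, so the test score
# reflects in-sample fit rather than forecasting months that come after the training data
train_prod, test_prod = production_forecast_df.random_split([0.8, 0.2], seed=42)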

production_pipeline = Pipeline([
    ("Scaler", StandardScaler(
        input_cols=["WORK_ORDER_COUNT", "TOTAL_QUANTITY_ORDERED", "TOTAL_QUANTITY_COMPLETED", 
                    "AVG_ESTIMATED_HOURS", "AVG_ACTUAL_HOURS", "TOTAL_ESTIMATED_COST"],
        output_cols=["WO_SCALED", "QTY_ORD_SCALED", "QTY_COMP_SCALED", 
                     "EST_HRS_SCALED", "ACT_HRS_SCALED", "EST_COST_SCALED"]
    )),
    ("Regressor", LinearRegression(
        input_cols=["MONTH_NUM", "YEAR_NUM", "WO_SCALED", "QTY_ORD_SCALED", "QTY_COMP_SCALED", 
                    "EST_HRS_SCALED", "ACT_HRS_SCALED", "EST_COST_SCALED"],
        label_cols=["PRODUCTION_COUNT"],
        output_cols=["PREDICTED_PRODUCTION"]
    ))
])

production_pipeline.fit(train_prod)
print("✅ Production forecaster model trained")

# Evaluate and register
prod_predictions = production_pipeline.predict(test_prod)
mse = mean_squared_error(df=prod_predictions, y_true_col_names="PRODUCTION_COUNT", y_pred_col_names="PREDICTED_PRODUCTION")
mae = mean_absolute_error(df=prod_predictions, y_true_col_names="PRODUCTION_COUNT", y_pred_col_names="PREDICTED_PRODUCTION")
prod_metrics = {"mse": round(mse, 2), "mae": round(mae, 2)}
print(f"Model metrics: {prod_metrics}")

reg.log_model(
    model=production_pipeline,
    model_name="PRODUCTION_FORECASTER",
    version_name="V1",
    comment="Forecasts monthly production throughput using Linear Regression based on historical patterns",
    metrics=prod_metrics
)
print("✅ Production forecaster model registered to Model Registry as PRODUCTION_FORECASTER")
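

`mean_squared_error` and `mean_absolute_error` summarize the residuals between `PRODUCTION_COUNT` and `PREDICTED_PRODUCTION`. Here are the standard definitions, sketched in plain Python with hypothetical monthly counts:

```python
from typing import Sequence


def _check(y_true: Sequence[float], y_pred: Sequence[float]) -> None:
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean of squared residuals; large misses dominate (units: count squared)."""
    _check(y_true, y_pred)
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean of absolute residuals; same units as the target (work orders per month)."""
    _check(y_true, y_pred)
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


actual = [120.0, 95.0, 110.0]      # hypothetical monthly production counts
predicted = [118.0, 100.0, 110.0]
print(mse(actual, predicted))  # (4 + 25 + 0) / 3, about 9.67
print(mae(actual, predicted))  # (2 + 5 + 0) / 3, about 2.33
```

MAE is expressed in the target's units. MSE is in squared units, so its square root (RMSE) is often easier to read alongside MAE.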


## Summary

All 3 models have been trained and registered in the Snowflake Model Registry:

| Model | Type | Output Column | Purpose |
|-------|------|---------------|--------|
| `PROGRAM_RISK_PREDICTOR` | RandomForestClassifier | `PREDICTED_RISK` | Predicts program schedule/cost risk |
| `SUPPLIER_RISK_PREDICTOR` | RandomForestClassifier | `PREDICTED_RISK` | Identifies supplier performance risk |
| `PRODUCTION_FORECASTER` | LinearRegression | `PREDICTED_PRODUCTION` | Forecasts monthly production |

**Next Steps:**
1. Run `sql/ml/07_create_model_wrapper_functions.sql` to create stored procedures that wrap these models
2. Run `sql/agent/08_create_intelligence_agent.sql` to create the Intelligence Agent with ML tools
3. Test the agent in Snowsight under AI & ML > Agents


In [None]:
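# Verify all models are registered. reg.show_models() returns a pandas DataFrame, so use print() rather than .show()
print("=" * 60)
print("VERIFICATION: Models registered in Model Registry")
print("=" * 60)
models = reg.show_models()
print(models)

# Check the expected names instead of printing success unconditionally
expected = ["PROGRAM_RISK_PREDICTOR", "SUPPLIER_RISK_PREDICTOR", "PRODUCTION_FORECASTER"]
registered = set(models["name"])  # SHOW MODELS output includes a "name" column
missing = [name for name in expected if name not in registered]
if missing:
    print(f"\n⚠️ Missing models: {missing}")
else:
    print("\n✅ All 3 models registered successfully!")
    for name in expected:
        print(f"   - {name}")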
