# ML Observability Suite - Snowflake Native
## Financial Services Model Monitoring & Drift Detection

This notebook demonstrates Snowflake's ML Observability capabilities for production model monitoring.

## What We'll Cover
- Model performance monitoring with automatic drift detection
- Data quality monitoring and alerting
- Business impact tracking and ROI measurement
- Production-ready monitoring dashboards

## ML Observability Features
- **Performance Tracking**: Accuracy, precision, recall, F1, AUC
- **Drift Detection**: Feature drift, prediction drift, concept drift
- **Data Quality**: Completeness, validity, consistency monitoring
- **Alerting System**: Automated alerts for performance degradation
- **Business Metrics**: Revenue impact, cost savings, ROI tracking


In [None]:
# Import required libraries
import snowflake.snowpark as snowpark
from snowflake.snowpark import Session
from snowflake.snowpark.functions import *
from snowflake.ml.model_observability import ModelMonitor
from snowflake.ml.registry import Registry
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

# Get the notebook's active Snowpark session through the public API.
# (The original used snowpark.session._get_active_session, a private helper whose
# name/location is not part of Snowpark's supported interface.)
from snowflake.snowpark.context import get_active_session

session = get_active_session()

# Echo the execution context so every result below can be traced back to the
# database / schema / warehouse it ran against.
print("🏔️ Snowflake ML Observability Suite")
print(f"Database: {session.get_current_database()}")
print(f"Schema: {session.get_current_schema()}")
print(f"Warehouse: {session.get_current_warehouse()}")
print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


## Step 1: Setup ML Observability Infrastructure


In [None]:
# Build the ML Observability storage layer: an inference log, a model-performance
# history, and an alerts table.
#
# NOTE: every statement is CREATE OR REPLACE, so re-running this cell drops and
# recreates the tables, discarding any rows already stored in them.
# (Snowflake PRIMARY KEY constraints are informational, not enforced.)
print("🛠️ Setting up ML Observability infrastructure...")

# Inference log: one row per prediction with its input features (used for drift
# detection), the model outputs, optional ground truth, and quality metrics.
inference_table_sql = """
CREATE OR REPLACE TABLE ml_inference_logs (
    inference_id STRING PRIMARY KEY,
    model_name STRING NOT NULL,
    model_version STRING NOT NULL,
    timestamp TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
    
    -- Client context
    client_id STRING,
    
    -- Input features (for drift detection)
    total_events_30d INTEGER,
    web_visits_30d INTEGER,
    email_opens_30d INTEGER,
    engagement_frequency DECIMAL(10,4),
    annual_income INTEGER,
    current_401k_balance DECIMAL(12,2),
    age INTEGER,
    service_tier STRING,
    
    -- Model predictions
    conversion_probability DECIMAL(5,4),
    churn_probability DECIMAL(5,4),
    recommended_action STRING,
    prediction_confidence DECIMAL(5,4),
    
    -- Ground truth (when available)
    actual_conversion BOOLEAN,
    actual_churn BOOLEAN,
    
    -- Quality metrics
    inference_latency_ms INTEGER,
    data_quality_score DECIMAL(5,4)
)
"""

# Performance history: one row per evaluation of a model version.
performance_table_sql = """
CREATE OR REPLACE TABLE ml_model_performance (
    performance_id STRING PRIMARY KEY,
    model_name STRING NOT NULL,
    model_version STRING NOT NULL,
    evaluation_timestamp TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
    accuracy DECIMAL(5,4),
    precision_score DECIMAL(5,4),
    recall_score DECIMAL(5,4),
    f1_score DECIMAL(5,4),
    auc_score DECIMAL(5,4),
    total_inferences INTEGER,
    drift_score DECIMAL(5,4),
    data_quality_score DECIMAL(5,4)
)
"""

# Alerts: observed value vs. threshold; new rows default to status 'open'.
alerts_table_sql = """
CREATE OR REPLACE TABLE ml_observability_alerts (
    alert_id STRING PRIMARY KEY,
    alert_timestamp TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
    alert_type STRING,
    severity STRING,
    model_name STRING,
    alert_message STRING,
    current_value DECIMAL(10,4),
    threshold_value DECIMAL(10,4),
    status STRING DEFAULT 'open'
)
"""

# Execute the DDL in a fixed order, confirming each table as it is created.
for ddl, table_name in (
    (inference_table_sql, "ml_inference_logs"),
    (performance_table_sql, "ml_model_performance"),
    (alerts_table_sql, "ml_observability_alerts"),
):
    session.sql(ddl).collect()
    print(f"✅ Created {table_name} table")

print("🎯 ML Observability infrastructure ready!")


## Step 2: Simulate Inference Data with Ground Truth


In [None]:
# Generate synthetic inference data with ground truth for monitoring.
#
# Rows are loaded into the ml_inference_logs table created in Step 1 with
# INSERT OVERWRITE: existing rows are replaced, but the table definition
# (NOT NULL / PRIMARY KEY, DECIMAL precisions, DEFAULTs) is kept. Writing with
# DataFrame.write.mode("overwrite").save_as_table would instead recreate the
# table from the query's inferred schema and lose that definition.
print("📊 Generating inference data with ground truth...")

# Target column list for the INSERT, in ml_inference_logs column order.
inference_columns = [
    "inference_id", "model_name", "model_version", "timestamp",
    "client_id",
    "total_events_30d", "web_visits_30d", "email_opens_30d",
    "engagement_frequency", "annual_income", "current_401k_balance", "age",
    "service_tier",
    "conversion_probability", "churn_probability", "recommended_action",
    "prediction_confidence",
    "actual_conversion", "actual_churn",
    "inference_latency_ms", "data_quality_score",
]

# Two stages:
#  * raw_draws - every random value is drawn once per row and given a name.
#    UNIFORM(0::FLOAT, 1::FLOAT, ...) yields a float in [0, 1]; with integer
#    bounds, UNIFORM(0, 1, ...) returns only the integers 0 or 1.
#  * final SELECT - categorical columns and ground truth are derived from those
#    named draws, so each CASE uses a single draw with cumulative thresholds,
#    and ground truth is tied to the exact prediction stored in the same row.
inference_data_sql = """
WITH raw_draws AS (
  SELECT
    'inf_' || seq8() AS inference_id,
    'CONVERSION_PREDICTOR_V1' AS model_name,
    'V1.0' AS model_version,
    DATEADD(minute, -UNIFORM(1, 10080, RANDOM()), CURRENT_TIMESTAMP()) AS timestamp,

    -- Sample client data
    'client_' || UNIFORM(1, 50000, RANDOM()) AS client_id,
    UNIFORM(10, 100, RANDOM()) AS total_events_30d,
    UNIFORM(5, 50, RANDOM()) AS web_visits_30d,
    UNIFORM(2, 20, RANDOM()) AS email_opens_30d,
    ROUND(UNIFORM(0.1, 2.0, RANDOM()), 4) AS engagement_frequency,
    UNIFORM(30000, 150000, RANDOM()) AS annual_income,
    UNIFORM(10000, 300000, RANDOM()) AS current_401k_balance,
    UNIFORM(25, 65, RANDOM()) AS age,
    UNIFORM(0::FLOAT, 1::FLOAT, RANDOM()) AS tier_draw,

    -- Simulated model predictions
    ROUND(UNIFORM(0.1, 0.9, RANDOM()), 4) AS conversion_probability,
    ROUND(UNIFORM(0.1, 0.7, RANDOM()), 4) AS churn_probability,
    UNIFORM(0::FLOAT, 1::FLOAT, RANDOM()) AS action_draw,
    ROUND(UNIFORM(0.5, 1.0, RANDOM()), 4) AS prediction_confidence,

    -- Independent draws used to sample ground truth
    UNIFORM(0::FLOAT, 1::FLOAT, RANDOM()) AS conversion_draw,
    UNIFORM(0::FLOAT, 1::FLOAT, RANDOM()) AS churn_draw,

    UNIFORM(10, 100, RANDOM()) AS inference_latency_ms,
    ROUND(UNIFORM(0.85, 1.0, RANDOM()), 4) AS data_quality_score
  FROM TABLE(GENERATOR(ROWCOUNT => 5000))
)
SELECT
  inference_id,
  model_name,
  model_version,
  timestamp,
  client_id,
  total_events_30d,
  web_visits_30d,
  email_opens_30d,
  engagement_frequency,
  annual_income,
  current_401k_balance,
  age,
  CASE
    WHEN tier_draw < 0.6 THEN 'Basic'     -- ~60%
    WHEN tier_draw < 0.9 THEN 'Premium'   -- ~30%
    ELSE 'Elite'                          -- ~10%
  END AS service_tier,
  conversion_probability,
  churn_probability,
  CASE
    WHEN action_draw < 0.3 THEN 'Wealth_Advisory_Consultation'
    WHEN action_draw < 0.6 THEN 'Schedule_Planning_Session'
    ELSE 'Educational_Content'
  END AS recommended_action,
  prediction_confidence,

  -- Simulated ground truth (correlated with predictions): the higher the
  -- predicted probability, the more likely the outcome actually occurred.
  CASE
    WHEN conversion_probability > 0.7 THEN conversion_draw < 0.8
    WHEN conversion_probability > 0.5 THEN conversion_draw < 0.6
    WHEN conversion_probability > 0.3 THEN conversion_draw < 0.3
    ELSE FALSE
  END AS actual_conversion,
  CASE
    WHEN churn_probability > 0.6 THEN churn_draw < 0.7
    WHEN churn_probability > 0.4 THEN churn_draw < 0.4
    ELSE FALSE
  END AS actual_churn,

  inference_latency_ms,
  data_quality_score
FROM raw_draws
"""

# Replace the table's rows in one statement while keeping its Step 1 definition.
insert_sql = (
    f"INSERT OVERWRITE INTO ml_inference_logs ({', '.join(inference_columns)})\n"
    + inference_data_sql
)
session.sql(insert_sql).collect()

# DataFrame over the persisted rows. Re-evaluating the generator query itself
# would draw fresh random values that differ from what was stored.
inference_df = session.table("ml_inference_logs")

# Verify data
count_result = session.sql("SELECT COUNT(*) as count FROM ml_inference_logs").collect()
inference_count = count_result[0]['COUNT']

print(f"✅ Generated {inference_count:,} inference records with ground truth")

# Show sample data
print("\n📋 Sample inference data:")
session.sql("SELECT * FROM ml_inference_logs LIMIT 5").show()
