# ML Observability Suite

This notebook demonstrates both:
1. **Native Snowflake Model Registry Observability** - Using built-in monitoring features
2. **Custom Observability Extensions** - Additional business-specific metrics and dashboards, built natively in Snowflake

## Financial Services Model Monitoring & Drift Detection

This notebook demonstrates Snowflake's ML Observability capabilities for production model monitoring.

## What We'll Cover
- Model performance monitoring with automatic drift detection
- Data quality monitoring and alerting
- Business impact tracking and ROI measurement
- Production-ready monitoring dashboards

## ML Observability Features
- **Performance Tracking**: Accuracy, precision, recall, F1, AUC
- **Drift Detection**: Feature drift, prediction drift, concept drift
- **Data Quality**: Completeness, validity, consistency monitoring
- **Alerting System**: Automated alerts for performance degradation
- **Business Metrics**: Revenue impact, cost savings, ROI tracking
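
As a minimal sketch of the performance-tracking metrics listed above, accuracy, precision, recall, and F1 can all be derived from the four confusion-matrix counts. The names `safe_div` and `classification_metrics` and the example counts are assumptions for illustration, not part of this notebook's pipeline. Undefined ratios are returned as `None`, in the same spirit as a SQL `NULLIF` guard yielding NULL.

```python
from typing import Dict, Optional


def safe_div(num: float, den: float) -> Optional[float]:
    """Return num / den, or None when the denominator is zero."""
    return num / den if den else None


def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> Dict[str, Optional[float]]:
    """Derive accuracy, precision, recall, and F1 from confusion-matrix counts."""
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    if precision is None or recall is None:
        f1 = None
    else:
        f1 = safe_div(2 * precision * recall, precision + recall)
    return {
        "accuracy": safe_div(tp + tn, tp + fp + tn + fn),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Hypothetical counts for illustration only
metrics = classification_metrics(tp=40, fp=10, tn=45, fn=5)
print(metrics)
```

Returning `None` rather than `0` for an undefined ratio keeps "no positive predictions yet" distinguishable from "every positive prediction was wrong".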


## Leveraging Snowflake Model Registry Observability

This notebook demonstrates how to use Snowflake's native Model Registry observability features alongside custom monitoring:

### Native Features We Use:
- **Model Metadata Tracking** - Version control and lineage through the Model Registry
- **Model Tags** - Storing metrics and configuration as searchable tags
- **Dynamic Tables** - Near-real-time monitoring of predictions, refreshed on a configurable target lag (1 hour in this notebook)
- **Model Versioning** - Tracking model iterations and performance over time
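
Because tag values are stored as strings, numeric metrics need to be formatted before they are attached to a model. The sketch below shows one possible way to do that; the helper name `metrics_to_tags`, the rounding to four decimal places, and the sample values are assumptions for illustration.

```python
from datetime import datetime
from typing import Dict, Union

Number = Union[int, float]


def metrics_to_tags(metrics: Dict[str, Number], monitored_at: datetime) -> Dict[str, str]:
    """Convert numeric metrics into upper-case tag names with string values."""
    tags = {}
    for name, value in metrics.items():
        # Keep integer counts exact; round ratios to four decimal places
        tags[name.upper()] = str(value) if isinstance(value, int) else f"{value:.4f}"
    tags["LAST_MONITORED"] = monitored_at.strftime("%Y-%m-%d %H:%M:%S")
    return tags


# Hypothetical metric values for illustration only
tags = metrics_to_tags(
    {"latest_accuracy": 0.87312, "total_predictions": 12500},
    monitored_at=datetime(2025, 1, 15, 9, 30, 0),
)
for tag_name, tag_value in tags.items():
    print(f"{tag_name}: {tag_value}")
```

Each resulting pair could then be applied with a call such as `model.set_tag(tag_name, tag_value)`, provided the tag has already been created in the schema.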

### Custom Extensions We Add:
- **Business Impact Metrics** - ROI, revenue impact, churn prevention value
- **Advanced Drift Detection** - Feature-level and prediction drift analysis
- **Automated Alerts** - Custom thresholds and business rules
- **Comprehensive Dashboards** - Unified view of model health and business value
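
One simple way to score feature-level drift is the absolute shift in a feature's mean between a baseline window and a current window, expressed in units of the baseline standard deviation. The sketch below illustrates that idea in plain Python. The function names, the sample values, and the warning and alert cut-offs of 1 and 2 are assumptions for illustration; suitable thresholds depend on the feature and the business context.

```python
import statistics
from typing import Optional, Sequence


def drift_score(baseline: Sequence[float], current: Sequence[float]) -> Optional[float]:
    """Absolute mean shift in units of the baseline sample standard deviation."""
    base_std = statistics.stdev(baseline)  # sample std dev; needs at least 2 values
    if base_std == 0:
        return None  # undefined when the baseline has no spread
    return abs(statistics.mean(current) - statistics.mean(baseline)) / base_std


def drift_status(score: Optional[float], warn: float = 1.0, alert: float = 2.0) -> str:
    """Map a drift score onto a coarse status label."""
    if score is None:
        return "unknown"
    if score > alert:
        return "alert"
    if score > warn:
        return "warning"
    return "ok"


# Hypothetical feature samples for illustration only
baseline_events = [10.0, 12.0, 14.0, 16.0, 18.0]
current_events = [17.0, 19.0, 21.0, 23.0, 25.0]
score = drift_score(baseline_events, current_events)
print(f"{score:.2f} -> {drift_status(score)}")
```

A mean-shift score is cheap to compute but only catches changes in location; changes in spread or shape would need a distribution-level statistic.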


In [None]:
# Import required libraries
import snowflake.snowpark as snowpark
from snowflake.snowpark import Session
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import *
from snowflake.ml.registry import Registry
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

# Get active session
session = get_active_session()

print(f"🏔️ Snowflake ML Observability Suite")
print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Set up database context
try:
    session.sql("USE DATABASE FINANCIAL_ML_DB").collect()
    session.sql("USE SCHEMA ML_PIPELINE").collect()
    print("Database: FINANCIAL_ML_DB")
except Exception:
    # Primary database unavailable: fall back to the Feature Store workaround database
    session.sql("USE DATABASE FINANCIAL_ML_DEMO_20250923_093605").collect()
    session.sql("USE SCHEMA ML_PIPELINE").collect()
    print("Database: FINANCIAL_ML_DEMO_20250923_093605")

print(f"Schema: {session.get_current_schema()}")
print(f"Warehouse: {session.get_current_warehouse()}")


## Step 1.5: Native Model Registry Observability


In [None]:
# Use Snowflake's Native Model Registry Observability
print("🔍 Leveraging Snowflake Model Registry Observability...")

# Import model monitoring capabilities
from snowflake.ml.registry import Registry
from snowflake.ml.model import ModelVersion

# Connect to the registry
registry = Registry(session=session)

# Get our registered model
model = registry.get_model("CONVERSION_PREDICTOR")
model_version = model.default

print(f"✅ Connected to model: {model.name} (version: {model_version.version_name})")

# Enable monitoring for the model
print("\n📊 Enabling Model Monitoring...")

# First, check what columns are available in the predictions view
check_cols_sql = """
SELECT * FROM CLIENT_CONVERSION_PREDICTIONS LIMIT 1
"""
sample_df = session.sql(check_cols_sql).collect()
if sample_df:
    print("📋 Available columns in predictions view:")
    print(f"   {list(sample_df[0].asDict().keys())}")

# Create a monitor for the model
monitor_name = f"{model.name}_MONITOR"

# Set up dynamic monitoring table with feature enrichment
session.sql(f"""
    CREATE OR REPLACE DYNAMIC TABLE {monitor_name}_PREDICTIONS
    TARGET_LAG = '1 hour'
    WAREHOUSE = {session.get_current_warehouse()}
    AS
    SELECT 
        cp.CLIENT_ID,
        cp.CONVERSION_PROBABILITY,
        cp.ACTUAL_CONVERSION,
        -- Join with feature store to get feature values
        fs.TOTAL_EVENTS_30D,
        fs.ENGAGEMENT_SCORE_30D,
        fs.DAYS_SINCE_LAST_ACTIVITY,
        fs.ANNUAL_INCOME,
        fs.CURRENT_401K_BALANCE,
        fs.TOTAL_ASSETS_UNDER_MANAGEMENT,
        -- Add monitoring metadata
        CURRENT_TIMESTAMP() as MONITOR_TIMESTAMP,
        DATE_TRUNC('hour', CURRENT_TIMESTAMP()) as MONITOR_HOUR,
        -- Calculate prediction value (handle OBJECT type)
        CAST(cp.CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) as PREDICTION_VALUE
    FROM CLIENT_CONVERSION_PREDICTIONS cp
    LEFT JOIN FEATURE_STORE fs ON cp.CLIENT_ID = fs.CLIENT_ID
    WHERE cp.ACTUAL_CONVERSION IS NOT NULL  -- Only monitor records with ground truth
""").collect()

print(f"✅ Created dynamic monitoring table: {monitor_name}_PREDICTIONS")

# Create tags for ML observability (tags must exist before use)
print("\n🏷️ Creating ML Observability Tags...")

tags_to_create = [
    "ML_OBSERVABILITY_ENABLED",
    "MONITOR_TABLE", 
    "MONITORING_START_DATE",
    "MODEL_TYPE",
    "BUSINESS_DOMAIN",
    "LATEST_ACCURACY",
    "TOTAL_PREDICTIONS",
    "POSITIVE_RATE",
    "LAST_MONITORED"
]

for tag_name in tags_to_create:
    try:
        session.sql(f"""
            CREATE TAG IF NOT EXISTS {tag_name} 
            COMMENT = 'ML Model Observability Metadata'
        """).collect()
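    except Exception as e:
        # IF NOT EXISTS already covers existing tags, so report any other failure (e.g. missing privileges)
        print(f"⚠️ Could not create tag {tag_name}: {e}")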

print("✅ Created observability tags")

# Apply tags to the model
print("\n🏷️ Applying tags to model...")

# Set model tags for observability
model.set_tag("ML_OBSERVABILITY_ENABLED", "true")
model.set_tag("MONITOR_TABLE", f"{monitor_name}_PREDICTIONS")
model.set_tag("MONITORING_START_DATE", str(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
model.set_tag("MODEL_TYPE", "binary_classification")
model.set_tag("BUSINESS_DOMAIN", "financial_services")

print("✅ Model monitoring enabled with tags")

# Create monitoring views for different metrics
print("\n📊 Creating Monitoring Views...")

# Performance monitoring view
session.sql(f"""
    CREATE OR REPLACE VIEW {monitor_name}_PERFORMANCE AS
    WITH hourly_metrics AS (
        SELECT 
            MONITOR_HOUR,
            COUNT(*) as predictions_count,
            AVG(CASE WHEN PREDICTION_VALUE = ACTUAL_CONVERSION THEN 1 ELSE 0 END) as accuracy,
            AVG(CASE WHEN PREDICTION_VALUE = 1 AND ACTUAL_CONVERSION = 1 THEN 1
                     WHEN PREDICTION_VALUE = 1 THEN 0 END) as precision,
            AVG(CASE WHEN ACTUAL_CONVERSION = 1 AND PREDICTION_VALUE = 1 THEN 1
                     WHEN ACTUAL_CONVERSION = 1 THEN 0 END) as recall
        FROM {monitor_name}_PREDICTIONS
        GROUP BY MONITOR_HOUR
    )
    SELECT 
        *,
        2 * (precision * recall) / NULLIF(precision + recall, 0) as f1_score
    FROM hourly_metrics
    ORDER BY MONITOR_HOUR DESC
""").collect()

# Feature drift monitoring view
session.sql(f"""
    CREATE OR REPLACE VIEW {monitor_name}_FEATURE_DRIFT AS
    WITH baseline_stats AS (
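        -- Baseline: the 1000 rows with the earliest MONITOR_TIMESTAMP.
        -- MONITOR_TIMESTAMP is assigned at refresh time, so rows loaded in the same refresh tie and are picked arbitrarily.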
        SELECT 
            AVG(TOTAL_EVENTS_30D) as baseline_events_mean,
            STDDEV(TOTAL_EVENTS_30D) as baseline_events_std,
            AVG(ENGAGEMENT_SCORE_30D) as baseline_engagement_mean,
            STDDEV(ENGAGEMENT_SCORE_30D) as baseline_engagement_std,
            AVG(ANNUAL_INCOME) as baseline_income_mean,
            STDDEV(ANNUAL_INCOME) as baseline_income_std
        FROM (
            SELECT * FROM {monitor_name}_PREDICTIONS 
            ORDER BY MONITOR_TIMESTAMP 
            LIMIT 1000
        )
    ),
    current_window AS (
        -- Most recent 1000 predictions
        SELECT 
            AVG(TOTAL_EVENTS_30D) as current_events_mean,
            AVG(ENGAGEMENT_SCORE_30D) as current_engagement_mean,
            AVG(ANNUAL_INCOME) as current_income_mean
        FROM (
            SELECT * FROM {monitor_name}_PREDICTIONS 
            ORDER BY MONITOR_TIMESTAMP DESC 
            LIMIT 1000
        )
    )
    SELECT 
        'TOTAL_EVENTS_30D' as feature_name,
        baseline_events_mean as baseline_mean,
        current_events_mean as current_mean,
        ABS(current_events_mean - baseline_events_mean) / NULLIF(baseline_events_std, 0) as drift_score
    FROM baseline_stats, current_window
    UNION ALL
    SELECT 
        'ENGAGEMENT_SCORE_30D',
        baseline_engagement_mean as baseline_mean,
        current_engagement_mean as current_mean,
        ABS(current_engagement_mean - baseline_engagement_mean) / NULLIF(baseline_engagement_std, 0)
    FROM baseline_stats, current_window
    UNION ALL
    SELECT 
        'ANNUAL_INCOME',
        baseline_income_mean as baseline_mean,
        current_income_mean as current_mean,
        ABS(current_income_mean - baseline_income_mean) / NULLIF(baseline_income_std, 0)
    FROM baseline_stats, current_window
""").collect()

print(f"✅ Created monitoring views:")
print(f"   - {monitor_name}_PERFORMANCE (hourly performance metrics)")
print(f"   - {monitor_name}_FEATURE_DRIFT (feature drift detection)")

# Log current model metrics
print("\n📈 Logging Current Model Metrics...")

# Get model version info
version_info = {
    "name": model.name,
    "version": model_version.version_name,
    "model_ref": str(model_version.model_ref) if hasattr(model_version, 'model_ref') else "N/A"
}

# Calculate and log performance metrics
# First check if we have data in CLIENT_CONVERSION_PREDICTIONS
check_data_sql = """
SELECT COUNT(*) as count FROM CLIENT_CONVERSION_PREDICTIONS WHERE ACTUAL_CONVERSION IS NOT NULL
"""
data_check = session.sql(check_data_sql).collect()
has_data = data_check[0]['COUNT'] > 0

if has_data:
    # Use the source table directly for initial metrics
    metrics_sql = """
    SELECT 
        COUNT(*) as total_predictions,
        AVG(CASE 
            WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = ACTUAL_CONVERSION 
            THEN 1 ELSE 0 
        END) as accuracy,
        COUNT(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 THEN 1 END) as positive_predictions,
        COUNT(CASE WHEN ACTUAL_CONVERSION = 1 THEN 1 END) as actual_positives,
        NULL as last_monitored
    FROM CLIENT_CONVERSION_PREDICTIONS
    WHERE ACTUAL_CONVERSION IS NOT NULL
    """
else:
    print("⚠️ No ground truth data available yet for metrics calculation")
    metrics_sql = None

if metrics_sql:
    metrics_df = session.sql(metrics_sql).collect()
    if metrics_df:
        metrics = metrics_df[0]
        
        # Set metrics as model tags for tracking
        # Use float() to ensure proper type conversion and format with string formatting
        model.set_tag("LATEST_ACCURACY", f"{float(metrics['ACCURACY']):.4f}")
        model.set_tag("TOTAL_PREDICTIONS", str(int(metrics['TOTAL_PREDICTIONS'])))
        model.set_tag("POSITIVE_RATE", f"{float(metrics['POSITIVE_PREDICTIONS']) / float(metrics['TOTAL_PREDICTIONS']):.4f}")
        if metrics['LAST_MONITORED']:
            model.set_tag("LAST_MONITORED", str(metrics['LAST_MONITORED']))
        
        print(f"✅ Model Metrics:")
        print(f"   Total Predictions: {metrics['TOTAL_PREDICTIONS']:,}")
        print(f"   Accuracy: {metrics['ACCURACY']:.2%}")
        print(f"   Positive Predictions: {metrics['POSITIVE_PREDICTIONS']:,}")
        print(f"   Actual Positives: {metrics['ACTUAL_POSITIVES']:,}")
        if metrics['LAST_MONITORED']:
            print(f"   Last Monitored: {metrics['LAST_MONITORED']}")

# Query performance trends (might be empty initially)
try:
    perf_trends = session.sql(f"""
        SELECT * FROM {monitor_name}_PERFORMANCE 
        ORDER BY MONITOR_HOUR DESC 
        LIMIT 5
    """).collect()
    
    if perf_trends:
        print("\n📊 Recent Performance Trends:")
        print("Hour                    | Predictions | Accuracy | Precision | Recall | F1")
        print("-" * 75)
        for row in perf_trends:
            print(f"{row['MONITOR_HOUR']} | {row['PREDICTIONS_COUNT']:>11,} | {row['ACCURACY']:>8.2%} | {row['PRECISION']:>9.2%} | {row['RECALL']:>6.2%} | {row['F1_SCORE']:>4.2f}")
    else:
        print("\n📊 Performance trends will be available after the dynamic table refreshes (1 hour)")
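except Exception as e:
    # Report the underlying error rather than assuming the view is simply empty
    print(f"\n📊 Performance trends are not available yet (the dynamic table refreshes on a 1 hour lag): {e}")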

# Show all model tags
print("\n🏷️ Model Tags (Observability Metadata):")
tags = model.show_tags()
for tag_name, tag_value in tags.items():
    print(f"   {tag_name}: {tag_value}")

print("\n✅ Snowflake Native Model Registry Observability fully configured!")
print("📊 Monitor your model at:")
print(f"   - Dynamic Table: {monitor_name}_PREDICTIONS")
print(f"   - Performance View: {monitor_name}_PERFORMANCE")
print(f"   - Drift View: {monitor_name}_FEATURE_DRIFT")

print("\n" + "="*60)
print("Continuing with additional custom observability features...")


## Step 1.6: Streamlit Model Monitoring Dashboard

Interactive visualizations for real-time model observability.
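
Before wiring checks into a dashboard, it can help to keep the alerting rules in a plain function that is easy to test outside Streamlit. The sketch below is one possible shape, assuming an accuracy floor and a per-feature drift ceiling. The function name `evaluate_alerts`, the record layout, and the sample inputs are assumptions for illustration; the default thresholds mirror the values used in this notebook's alert tab.

```python
from typing import Dict, List, Optional


def evaluate_alerts(
    accuracy: Optional[float],
    drift_scores: Dict[str, float],
    min_accuracy: float = 0.65,
    max_drift: float = 2.0,
) -> List[Dict[str, str]]:
    """Return one alert record per violated rule; an empty list means healthy."""
    alerts: List[Dict[str, str]] = []
    # Skip the accuracy rule when no ground truth is available yet
    if accuracy is not None and accuracy < min_accuracy:
        alerts.append({
            "type": "Performance",
            "severity": "High",
            "message": f"Model accuracy ({accuracy:.2%}) below threshold ({min_accuracy:.0%})",
        })
    for feature, score in drift_scores.items():
        if score > max_drift:
            alerts.append({
                "type": "Drift",
                "severity": "Medium",
                "message": f"High drift detected in {feature} (score: {score:.2f})",
            })
    return alerts


# Hypothetical inputs for illustration only
alerts = evaluate_alerts(0.61, {"ANNUAL_INCOME": 2.4, "TOTAL_EVENTS_30D": 0.5})
for alert in alerts:
    print(alert["severity"], "-", alert["message"])
```

Keeping the rules separate from the rendering code means the same function could feed both a dashboard and a scheduled notification job.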


In [None]:
# Create Streamlit Model Monitoring Dashboard
print("📊 Creating Streamlit Model Monitoring Dashboard...")

import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Dashboard Title
st.title("🎯 ML Model Observability Dashboard")
st.subheader("Financial Services - Conversion Prediction Model")
st.markdown("---")

# Get model info
monitor_name = "CONVERSION_PREDICTOR_MONITOR"

# Initialize shared variables
available_cols = []
drift_df = pd.DataFrame()

# Create tabs for different monitoring aspects
tab1, tab2, tab3, tab4, tab5 = st.tabs(["📈 Performance", "🔍 Drift Analysis", "💰 Business Impact", "🚨 Alerts", "🏷️ Model Info"])

with tab1:
    st.header("Model Performance Metrics")
    
    # Performance over time
    col1, col2 = st.columns(2)
    
    with col1:
        # Check if performance view exists first
        try:
            check_perf_view = f"""
            SHOW VIEWS LIKE '{monitor_name}_PERFORMANCE' IN SCHEMA {session.get_current_schema()}
            """
            perf_view_exists = len(session.sql(check_perf_view).collect()) > 0
            
            if perf_view_exists:
                # Get performance metrics
                perf_sql = f"""
                SELECT 
                    MONITOR_HOUR,
                    ACCURACY,
                    PRECISION,
                    RECALL,
                    F1_SCORE,
                    PREDICTIONS_COUNT
                FROM {monitor_name}_PERFORMANCE
                ORDER BY MONITOR_HOUR DESC
                LIMIT 24  -- Last 24 hours
                """
                
                perf_df = session.sql(perf_sql).to_pandas()
            else:
                perf_df = pd.DataFrame()  # Empty dataframe
        except Exception:
            perf_df = pd.DataFrame()  # Fall back to an empty dataframe if the view cannot be queried
        
        if not perf_df.empty:
            # Performance metrics over time
            fig_perf = go.Figure()
            fig_perf.add_trace(go.Scatter(x=perf_df['MONITOR_HOUR'], y=perf_df['ACCURACY'], 
                                          mode='lines+markers', name='Accuracy', line=dict(color='blue')))
            fig_perf.add_trace(go.Scatter(x=perf_df['MONITOR_HOUR'], y=perf_df['PRECISION'], 
                                          mode='lines+markers', name='Precision', line=dict(color='green')))
            fig_perf.add_trace(go.Scatter(x=perf_df['MONITOR_HOUR'], y=perf_df['RECALL'], 
                                          mode='lines+markers', name='Recall', line=dict(color='orange')))
            fig_perf.add_trace(go.Scatter(x=perf_df['MONITOR_HOUR'], y=perf_df['F1_SCORE'], 
                                          mode='lines+markers', name='F1 Score', line=dict(color='red')))
            
            fig_perf.update_layout(
                title='Model Performance Over Time',
                xaxis_title='Time',
                yaxis_title='Score',
                yaxis=dict(range=[0, 1]),
                height=400
            )
            st.plotly_chart(fig_perf, use_container_width=True)
        else:
            st.info("Performance data will be available after the monitoring table refreshes (1 hour)")
    
    with col2:
        # Current metrics display
        st.subheader("Current Performance")
        
        current_metrics_sql = """
        SELECT 
            AVG(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = ACTUAL_CONVERSION THEN 1 ELSE 0 END) as accuracy,
            COUNT(*) as total_predictions,
            COUNT(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 THEN 1 END) as positive_predictions
        FROM CLIENT_CONVERSION_PREDICTIONS
        WHERE ACTUAL_CONVERSION IS NOT NULL
        """
        
        current_metrics = session.sql(current_metrics_sql).collect()[0]
        
        # Display metrics
        metric_col1, metric_col2, metric_col3 = st.columns(3)
        with metric_col1:
            st.metric("Accuracy", f"{float(current_metrics['ACCURACY']):.2%}")
        with metric_col2:
            st.metric("Total Predictions", f"{int(current_metrics['TOTAL_PREDICTIONS']):,}")
        with metric_col3:
            positive_rate = float(current_metrics['POSITIVE_PREDICTIONS']) / float(current_metrics['TOTAL_PREDICTIONS'])
            st.metric("Positive Rate", f"{positive_rate:.2%}")
        
        # Confusion Matrix
        st.subheader("Confusion Matrix")
        cm_sql = """
        SELECT 
            SUM(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 AND ACTUAL_CONVERSION = 1 THEN 1 ELSE 0 END) as TP,
            SUM(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 AND ACTUAL_CONVERSION = 0 THEN 1 ELSE 0 END) as FP,
            SUM(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 0 AND ACTUAL_CONVERSION = 0 THEN 1 ELSE 0 END) as TN,
            SUM(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 0 AND ACTUAL_CONVERSION = 1 THEN 1 ELSE 0 END) as FN
        FROM CLIENT_CONVERSION_PREDICTIONS
        WHERE ACTUAL_CONVERSION IS NOT NULL
        """
        cm_data = session.sql(cm_sql).collect()[0]
        
        cm_matrix = [[int(cm_data['TN']), int(cm_data['FP'])], 
                     [int(cm_data['FN']), int(cm_data['TP'])]]
        
        fig_cm = px.imshow(cm_matrix, 
                          labels=dict(x="Predicted", y="Actual", color="Count"),
                          x=['No Conversion', 'Conversion'],
                          y=['No Conversion', 'Conversion'],
                          text_auto=True,
                          color_continuous_scale='Blues')
        fig_cm.update_layout(title='Confusion Matrix', height=300)
        st.plotly_chart(fig_cm, use_container_width=True)

with tab2:
    st.header("Feature Drift Analysis")
    
    # Check if drift monitoring view exists
    try:
        # First check if the view exists
        check_view_sql = f"""
        SHOW VIEWS LIKE '{monitor_name}_FEATURE_DRIFT' IN SCHEMA {session.get_current_schema()}
        """
        view_exists = len(session.sql(check_view_sql).collect()) > 0
        
        if not view_exists:
            st.info("🔄 Feature drift monitoring is being set up. The dynamic table will refresh in approximately 1 hour.")
            
            # Show sample drift analysis with mock data for demo purposes
            st.subheader("Sample Drift Analysis (Demo Data)")
            
            # Create sample drift data
            sample_features = ['TOTAL_EVENTS_30D', 'ENGAGEMENT_SCORE_30D', 'ANNUAL_INCOME']
            sample_drift_scores = [0.5, 1.3, 0.8]
            sample_baseline = [15.2, 0.65, 75000]
            sample_current = [14.8, 0.71, 73500]
            
            fig_sample = go.Figure()
            fig_sample.add_trace(go.Bar(
                x=sample_features,
                y=sample_drift_scores,
                marker_color=['green', 'orange', 'green'],
                text=[f"{score:.2f}" for score in sample_drift_scores],
                textposition='auto'
            ))
            
            fig_sample.update_layout(
                title='Sample Feature Drift Scores (Demo)',
                xaxis_title='Feature',
                yaxis_title='Drift Score',
                height=400
            )
            
            # Add threshold lines
            fig_sample.add_hline(y=1, line_dash="dash", line_color="orange", 
                               annotation_text="Warning Threshold")
            fig_sample.add_hline(y=2, line_dash="dash", line_color="red", 
                               annotation_text="Alert Threshold")
            
            st.plotly_chart(fig_sample, use_container_width=True)
            
            st.caption("⏱️ Real drift analysis will be available once the monitoring table refreshes.")
        else:
            # Get drift data
            drift_sql = f"""
            SELECT 
                FEATURE_NAME,
                BASELINE_MEAN,
                CURRENT_MEAN,
                DRIFT_SCORE
            FROM {monitor_name}_FEATURE_DRIFT
            """
            
            drift_df = session.sql(drift_sql).to_pandas()
            
            if not drift_df.empty:
                # Drift visualization
                fig_drift = go.Figure()
                fig_drift.add_trace(go.Bar(
                    x=drift_df['FEATURE_NAME'],
                    y=drift_df['DRIFT_SCORE'],
                    marker_color=['red' if score > 2 else 'orange' if score > 1 else 'green' 
                                 for score in drift_df['DRIFT_SCORE']],
                    text=[f"{score:.2f}" for score in drift_df['DRIFT_SCORE']],
                    textposition='auto'
                ))
                
                fig_drift.update_layout(
                    title='Feature Drift Scores',
                    xaxis_title='Feature',
                    yaxis_title='Drift Score',
                    height=400
                )
                
                # Add threshold lines
                fig_drift.add_hline(y=1, line_dash="dash", line_color="orange", 
                                   annotation_text="Warning Threshold")
                fig_drift.add_hline(y=2, line_dash="dash", line_color="red", 
                                   annotation_text="Alert Threshold")
                
                st.plotly_chart(fig_drift, use_container_width=True)
                
                # Feature comparison
                st.subheader("Feature Distribution Comparison")
                for _, row in drift_df.iterrows():
                    with st.expander(f"{row['FEATURE_NAME']} Details"):
                        col1, col2, col3 = st.columns(3)
                        with col1:
                            st.metric("Baseline Mean", f"{row['BASELINE_MEAN']:.2f}")
                        with col2:
                            st.metric("Current Mean", f"{row['CURRENT_MEAN']:.2f}")
                        with col3:
                            st.metric("Drift Score", f"{row['DRIFT_SCORE']:.2f}",
                                     delta=f"{row['DRIFT_SCORE']:.2f}",
                                     delta_color="inverse")
            else:
                st.info("📊 Drift analysis is initializing. Data will populate after the first monitoring cycle completes.")
    except Exception as e:
        st.warning(f"⚠️ Drift monitoring setup in progress. This feature will be available shortly.")
        if st.checkbox("Show technical details"):
            st.code(str(e))

with tab3:
    st.header("Business Impact Analysis")
    
    col1, col2 = st.columns(2)
    
    with col1:
        # Conversion Impact
        # First check what columns exist in CLIENT_SEGMENTS_BATCH
        try:
            check_cols = session.sql("SELECT * FROM CLIENT_SEGMENTS_BATCH LIMIT 1").collect()
            if check_cols:
                available_cols = list(check_cols[0].asDict().keys())
                # Debug: Show available columns in console
                print(f"DEBUG: Available columns in CLIENT_SEGMENTS_BATCH: {available_cols}")
                # Also show in UI for transparency
                with st.expander("Debug Info - Available Columns"):
                    st.write(f"CLIENT_SEGMENTS_BATCH columns: {available_cols}")
        except Exception as e:
            st.error("CLIENT_SEGMENTS_BATCH table not found. Please run the Inference notebook first.")
            print(f"DEBUG: Error accessing CLIENT_SEGMENTS_BATCH: {str(e)}")
            available_cols = []
        
        if available_cols:
            # Dynamically check for available columns and adjust query
            if 'SEGMENT' in available_cols:
                segment_col = 'SEGMENT'
                high_value = "'High Value'"
            elif 'PRIORITY_TIER' in available_cols:
                segment_col = 'PRIORITY_TIER'
                high_value = "'Tier 1'"
            else:
                segment_col = None
            
            if segment_col:
                conversion_sql = f"""
                SELECT 
                    COUNT(DISTINCT CLIENT_ID) as total_clients,
                    COUNT(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 END) as predicted_conversions,
                    COUNT(CASE WHEN {segment_col} = {high_value} THEN 1 END) as high_priority_clients,
                    AVG(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 ELSE 0 END) as conversion_rate
                FROM CLIENT_SEGMENTS_BATCH
                """
            else:
                # No segment column found, use basic metrics
                conversion_sql = """
                SELECT 
                    COUNT(DISTINCT CLIENT_ID) as total_clients,
                    COUNT(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 END) as predicted_conversions,
                    COUNT(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 END) as high_priority_clients,
                    AVG(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 ELSE 0 END) as conversion_rate
                FROM CLIENT_SEGMENTS_BATCH
                """
            
            conv_metrics = session.sql(conversion_sql).collect()[0]
        else:
            # Use CLIENT_CONVERSION_PREDICTIONS as fallback
            conversion_sql = """
            SELECT 
                COUNT(DISTINCT CLIENT_ID) as total_clients,
                COUNT(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 THEN 1 END) as predicted_conversions,
                COUNT(CASE WHEN BUSINESS_PRIORITY_SCORE > 0.8 THEN 1 END) as high_priority_clients,
                AVG(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 THEN 1 ELSE 0 END) as conversion_rate
            FROM CLIENT_CONVERSION_PREDICTIONS
            WHERE ACTUAL_CONVERSION IS NOT NULL
            """
            
            conv_metrics = session.sql(conversion_sql).collect()[0]
        
        st.subheader("Conversion Metrics")
        st.metric("Total Clients Analyzed", f"{int(conv_metrics['TOTAL_CLIENTS']):,}")
        st.metric("Predicted Conversions", f"{int(conv_metrics['PREDICTED_CONVERSIONS']):,}")
        st.metric("High Priority Clients", f"{int(conv_metrics['HIGH_PRIORITY_CLIENTS']):,}")
        st.metric("Conversion Rate", f"{float(conv_metrics['CONVERSION_RATE']):.2%}")
    
    with col2:
        # Churn Prevention
        try:
            churn_sql = """
            SELECT 
                COUNT(*) as at_risk_clients,
                COUNT(CASE WHEN VALUE_TIER = 'High' THEN 1 END) as high_value_at_risk,
                AVG(CHURN_RISK_SCORE) as avg_churn_risk
            FROM CLIENT_CHURN_SEGMENTS
            WHERE CHURN_RISK_SCORE > 0.6
            """
            
            churn_metrics = session.sql(churn_sql).collect()[0]
        except Exception as e:
            print(f"DEBUG: Error accessing CLIENT_CHURN_SEGMENTS: {str(e)}")
            # Use default values if churn table doesn't exist
            churn_metrics = {
                'AT_RISK_CLIENTS': 0,
                'HIGH_VALUE_AT_RISK': 0,
                'AVG_CHURN_RISK': 0.0
            }
        
        st.subheader("Churn Prevention")
        st.metric("At-Risk Clients", f"{int(churn_metrics['AT_RISK_CLIENTS']):,}")
        st.metric("High Value at Risk", f"{int(churn_metrics['HIGH_VALUE_AT_RISK']):,}")
        st.metric("Avg Churn Risk", f"{float(churn_metrics['AVG_CHURN_RISK']):.2%}")
    
    # Revenue Impact Chart
    st.subheader("Revenue Impact Projection")
    
    if available_cols:
        # Check which grouping column is available
        if 'SEGMENT' in available_cols:
            group_col = 'SEGMENT'
            title_text = 'Client Distribution by Segment'
        elif 'PRIORITY_TIER' in available_cols:
            group_col = 'PRIORITY_TIER'
            title_text = 'Client Distribution by Priority Tier'
        else:
            group_col = None
        
        if group_col:
            revenue_sql = f"""
            WITH revenue_calc AS (
                SELECT 
                    {group_col},
                    COUNT(*) as client_count,
                    AVG(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 50000 ELSE 0 END) as avg_revenue_per_group
                FROM CLIENT_SEGMENTS_BATCH
                GROUP BY {group_col}
            )
            SELECT * FROM revenue_calc ORDER BY avg_revenue_per_group DESC
            """
            
            revenue_df = session.sql(revenue_sql).to_pandas()
            
            fig_revenue = px.bar(revenue_df, x=group_col, y='CLIENT_COUNT', 
                                color='AVG_REVENUE_PER_GROUP',
                                title=title_text,
                                labels={'CLIENT_COUNT': 'Number of Clients',
                                       'AVG_REVENUE_PER_GROUP': 'Potential Revenue'})
        else:
            # Fallback to grouping by conversion probability
            revenue_sql = """
            WITH revenue_calc AS (
                SELECT 
                    CASE 
                        WHEN CONVERSION_PROBABILITY = 1 THEN 'High Conversion'
                        ELSE 'Low Conversion'
                    END as conversion_group,
                    COUNT(*) as client_count,
                    AVG(CONVERSION_PROBABILITY) as avg_conversion_rate
                FROM CLIENT_SEGMENTS_BATCH
                GROUP BY conversion_group
            )
            SELECT * FROM revenue_calc ORDER BY avg_conversion_rate DESC
            """
            
            revenue_df = session.sql(revenue_sql).to_pandas()
            
            fig_revenue = px.bar(revenue_df, x='CONVERSION_GROUP', y='CLIENT_COUNT', 
                                color='AVG_CONVERSION_RATE',
                                title='Client Distribution by Conversion Likelihood',
                                labels={'CLIENT_COUNT': 'Number of Clients',
                                       'AVG_CONVERSION_RATE': 'Conversion Rate'})
    else:
        # Fallback to CLIENT_CONVERSION_PREDICTIONS
        revenue_sql = """
        WITH revenue_calc AS (
            SELECT 
                CASE 
                    WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 THEN 'High Conversion'
                    ELSE 'Low Conversion'
                END as conversion_group,
                COUNT(*) as client_count,
                AVG(BUSINESS_PRIORITY_SCORE) as avg_priority_score
            FROM CLIENT_CONVERSION_PREDICTIONS
            WHERE ACTUAL_CONVERSION IS NOT NULL
            GROUP BY conversion_group
        )
        SELECT * FROM revenue_calc ORDER BY avg_priority_score DESC
        """
        
        revenue_df = session.sql(revenue_sql).to_pandas()
        
        fig_revenue = px.bar(revenue_df, x='CONVERSION_GROUP', y='CLIENT_COUNT', 
                            color='AVG_PRIORITY_SCORE',
                            title='Client Distribution by Conversion Likelihood',
                            labels={'CLIENT_COUNT': 'Number of Clients',
                                   'AVG_PRIORITY_SCORE': 'Business Priority'})
    
    st.plotly_chart(fig_revenue, use_container_width=True)

with tab4:
    st.header("Monitoring Alerts")
    
    # Define alert thresholds
    alert_thresholds = {
        'accuracy': 0.65,
        'drift': 2.0,
        'prediction_volume': 100
    }
    
    # Check current status
    alerts = []
    
    # Performance alerts
    if float(current_metrics['ACCURACY']) < alert_thresholds['accuracy']:
        alerts.append({
            'type': 'Performance',
            'severity': 'High',
            'message': f"Model accuracy ({float(current_metrics['ACCURACY']):.2%}) below threshold ({alert_thresholds['accuracy']:.0%})",
            'action': 'Consider model retraining'
        })
    
    # Drift alerts
    if not drift_df.empty:
        try:
            high_drift = drift_df[drift_df['DRIFT_SCORE'] > alert_thresholds['drift']]
            for _, feature in high_drift.iterrows():
                alerts.append({
                    'type': 'Drift',
                    'severity': 'Medium',
                    'message': f"High drift detected in {feature['FEATURE_NAME']} (score: {feature['DRIFT_SCORE']:.2f})",
                    'action': 'Investigate feature distribution changes'
                })
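        except Exception as e:
            # Surface the problem instead of silently hiding missing drift alerts
            st.caption(f"Drift alerts unavailable: {e}")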
    
    # Display alerts
    if alerts:
        for alert in alerts:
            if alert['severity'] == 'High':
                st.error(f"🚨 **{alert['type']} Alert**: {alert['message']}")
            else:
                st.warning(f"⚠️ **{alert['type']} Alert**: {alert['message']}")
            st.caption(f"Recommended Action: {alert['action']}")
    else:
        st.success("✅ All systems operating normally")
    
    # Alert History
    st.subheader("Alert History")
    alert_history_sql = """
    SELECT 
        MONITORING_TIMESTAMP,
        FEATURE_NAME,
        DRIFT_SCORE,
        ALERT_TRIGGERED
    FROM DATA_DRIFT_MONITORING
    WHERE ALERT_TRIGGERED = TRUE
    ORDER BY MONITORING_TIMESTAMP DESC
    LIMIT 10
    """
    
    try:
        alert_history_df = session.sql(alert_history_sql).to_pandas()
        if not alert_history_df.empty:
            st.dataframe(alert_history_df, use_container_width=True)
        else:
            st.info("No historical alerts found")
    except Exception:
        st.info("Alert history will be available as monitoring progresses")

with tab5:
    st.header("Model Information & Tags")
    
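    # Get model from registry
    model = None  # stays None if the registry lookup fails, so later checks are safe
    try:
        from snowflake.ml.registry import Registry
        registry = Registry(session)
        model = registry.get_model("CONVERSION_PREDICTOR")
        model_version = model.version("V1")
        model_tags = model.show_tags()
    except Exception:
        model = None  # discard a partially loaded model before using the fallback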
        # Fallback - read tags from MODEL_DEPLOYMENT_METADATA
        try:
            tags_sql = """
            SELECT TAG_NAME, TAG_VALUE 
            FROM MODEL_DEPLOYMENT_METADATA 
            WHERE MODEL_NAME = 'CONVERSION_PREDICTOR'
            """
            tags_df = session.sql(tags_sql).to_pandas()
            model_tags = dict(zip(tags_df['TAG_NAME'], tags_df['TAG_VALUE']))
            model_version = type('obj', (object,), {'version_name': 'V1'})()
        except Exception:
            model_tags = {
                'ML_OBSERVABILITY_ENABLED': 'TRUE',
                'MODEL_TYPE': 'BINARY_CLASSIFICATION',
                'BUSINESS_DOMAIN': 'FINANCIAL_SERVICES'
            }
            model_version = type('obj', (object,), {'version_name': 'V1'})()
    
    col1, col2 = st.columns(2)
    
    with col1:
        st.subheader("Model Metadata")
        for tag_name, tag_value in model_tags.items():
            if tag_name in ['ML_OBSERVABILITY_ENABLED', 'MODEL_TYPE', 'BUSINESS_DOMAIN']:
                st.text(f"{tag_name}: {tag_value}")
    
    with col2:
        st.subheader("Latest Metrics")
        for tag_name, tag_value in model_tags.items():
            if tag_name in ['LATEST_ACCURACY', 'TOTAL_PREDICTIONS', 'POSITIVE_RATE', 'LAST_MONITORED']:
                st.text(f"{tag_name}: {tag_value}")
    
    # Model version info
    st.subheader("Model Version Details")
    if hasattr(model, 'name'):
        st.text(f"Model Name: {model.name}")
    else:
        st.text(f"Model Name: CONVERSION_PREDICTOR")
    st.text(f"Current Version: {model_version.version_name}")
    st.text(f"Monitor Table: {monitor_name}_PREDICTIONS")
    
    # Monitoring Configuration
    st.subheader("Monitoring Configuration")
    st.text(f"Dynamic Table Refresh: Every 1 hour")
    st.text(f"Performance View: {monitor_name}_PERFORMANCE")
    st.text(f"Drift View: {monitor_name}_FEATURE_DRIFT")

# Add refresh button
if st.button("🔄 Refresh Dashboard"):
    st.rerun()

st.markdown("---")
# <script> tags passed to st.markdown are rendered but not executed, so a
# JavaScript timer cannot reload the page. Rely on the refresh button instead.
st.caption("Click 'Refresh Dashboard' to load the latest metrics")

print("\n✅ Streamlit Model Monitoring Dashboard created!")
print("📊 The dashboard provides real-time visualization of:")
print("   - Model performance metrics over time")
print("   - Feature drift analysis")
print("   - Business impact metrics")
print("   - Monitoring alerts")
print("   - Model metadata and tags")


## Step 2: Custom ML Observability Infrastructure

This step builds on the native capabilities by adding custom monitoring tables for business-specific metrics.
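
The custom tables in this step use an `AUTOINCREMENT` key, so inserts are easiest to get right when they name their columns and leave the key out. The sketch below only builds the statement and its bind values. `build_impact_insert` and the sample row are hypothetical, and running the result through `session.sql(sql, params=values)` assumes a Snowpark version that supports `?` bind parameters.

```python
# Hypothetical helper, not part of the notebook: builds a column-named,
# parameterized INSERT so that impact_id (AUTOINCREMENT) is left to Snowflake.
IMPACT_COLUMNS = [
    "evaluation_date", "model_name", "model_version",
    "total_conversions", "conversion_rate", "revenue_generated",
    "churns_prevented", "retention_rate", "revenue_retained",
    "total_business_value", "roi_percentage",
]


def build_impact_insert(row):
    """Return (sql, values) for one ML_BUSINESS_IMPACT row given as a dict."""
    missing = [col for col in IMPACT_COLUMNS if col not in row]
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    column_list = ", ".join(IMPACT_COLUMNS)
    placeholders = ", ".join("?" for _ in IMPACT_COLUMNS)
    sql = f"INSERT INTO ML_BUSINESS_IMPACT ({column_list}) VALUES ({placeholders})"
    return sql, [row[col] for col in IMPACT_COLUMNS]


# Illustrative values only
example_row = {
    "evaluation_date": "2024-01-31",
    "model_name": "CONVERSION_PREDICTOR",
    "model_version": "V1",
    "total_conversions": 120,
    "conversion_rate": 0.24,
    "revenue_generated": 310000.0,
    "churns_prevented": 35,
    "retention_rate": 0.93,
    "revenue_retained": 87000.0,
    "total_business_value": 397000.0,
    "roi_percentage": 2500.0,
}

sql, values = build_impact_insert(example_row)
print(sql)
print(values)
# With a live session (assumption): session.sql(sql, params=values).collect()
```

Binding values this way also avoids quoting problems that can arise when metric values are interpolated directly into the SQL text.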


In [None]:
# Check and use existing ML Observability infrastructure
print("🛠️ Checking ML Observability infrastructure...")

# Check for existing monitoring tables from notebook 4
existing_tables = []
for table_name in ["PREDICTION_MONITORING", "MODEL_PERFORMANCE_TRACKING", "DATA_DRIFT_MONITORING"]:
    try:
        # A one-row probe is enough to confirm the table exists and is readable
        session.sql(f"SELECT * FROM {table_name} LIMIT 1").collect()
        existing_tables.append(table_name)
        print(f"✅ Found existing {table_name} table")
    except Exception:
        print(f"❌ {table_name} table not found")

# Create business impact tracking table (new for observability)
business_impact_sql = """
CREATE TABLE IF NOT EXISTS ML_BUSINESS_IMPACT (
    impact_id NUMBER AUTOINCREMENT,
    evaluation_date DATE,
    model_name VARCHAR,
    model_version VARCHAR,
    -- Conversion metrics
    total_conversions INTEGER,
    conversion_rate FLOAT,
    revenue_generated FLOAT,
    -- Churn metrics
    churns_prevented INTEGER,
    retention_rate FLOAT,
    revenue_retained FLOAT,
    -- Overall impact
    total_business_value FLOAT,
    roi_percentage FLOAT
)
"""

session.sql(business_impact_sql).collect()
print("✅ Created ML_BUSINESS_IMPACT table")

# Create model comparison table
comparison_sql = """
CREATE TABLE IF NOT EXISTS ML_MODEL_COMPARISON (
    comparison_id NUMBER AUTOINCREMENT,
    comparison_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP(),
    model_a_name VARCHAR,
    model_a_version VARCHAR,
    model_a_accuracy FLOAT,
    model_b_name VARCHAR,
    model_b_version VARCHAR,
    model_b_accuracy FLOAT,
    performance_difference FLOAT,
    recommendation VARCHAR
)
"""

session.sql(comparison_sql).collect()
print("✅ Created ML_MODEL_COMPARISON table")

print(f"\n🎯 ML Observability infrastructure ready!")
print(f"   Using {len(existing_tables)} existing monitoring tables")
print(f"   Added 2 new observability tables")


## Step 3: Analyze Production Model Performance
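
The metrics in this step all derive from the four confusion-matrix counts. As a minimal sketch of that arithmetic in plain Python (the counts are made up, and `safe_ratio` and `classification_metrics` are hypothetical helper names), with `None` playing the role of the SQL `NULLIF(..., 0)` guard:

```python
from typing import Optional


def safe_ratio(numerator: float, denominator: float) -> Optional[float]:
    """Mirror SQL's x / NULLIF(y, 0): return None when the denominator is zero."""
    return None if denominator == 0 else numerator / denominator


def classification_metrics(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Compute accuracy, precision, recall and F1 from confusion-matrix counts."""
    precision = safe_ratio(tp, tp + fp)
    recall = safe_ratio(tp, tp + fn)
    if precision is None or recall is None:
        f1 = None  # F1 is undefined when either component is undefined
    else:
        f1 = safe_ratio(2 * precision * recall, precision + recall)
    return {
        "accuracy": safe_ratio(tp + tn, tp + fp + tn + fn),
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
    }


# Illustrative counts only
metrics = classification_metrics(tp=40, fp=10, tn=35, fn=15)
for name, value in metrics.items():
    print(f"{name}: {value:.2%}" if value is not None else f"{name}: N/A")
```

Checking `value is not None` rather than truthiness matters here: a precision of exactly 0 is a real result and should not be displayed as N/A.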


In [None]:
# Analyze production model performance using real data
print("📊 Analyzing production model performance...")

# Get current model performance metrics
performance_sql = """
WITH prediction_analysis AS (
    SELECT 
        'CONVERSION_PREDICTOR' as model_name,
        'V1' as model_version,
        COUNT(*) as total_predictions,
        -- Extract binary predictions from OBJECT
        SUM(CASE WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 THEN 1 ELSE 0 END) as predicted_conversions,
        SUM(CASE WHEN ACTUAL_CONVERSION = 1 THEN 1 ELSE 0 END) as actual_conversions,
        
        -- Calculate metrics
        COUNT(CASE 
            WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = ACTUAL_CONVERSION 
            THEN 1 
        END) / NULLIF(COUNT(*), 0) as accuracy,
        
        -- True Positives, False Positives, etc.
        SUM(CASE 
            WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 
            AND ACTUAL_CONVERSION = 1 THEN 1 ELSE 0 
        END) as true_positives,
        
        SUM(CASE 
            WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 1 
            AND ACTUAL_CONVERSION = 0 THEN 1 ELSE 0 
        END) as false_positives,
        
        SUM(CASE 
            WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 0 
            AND ACTUAL_CONVERSION = 0 THEN 1 ELSE 0 
        END) as true_negatives,
        
        SUM(CASE 
            WHEN CAST(CONVERSION_PROBABILITY:"PREDICTION" AS FLOAT) = 0 
            AND ACTUAL_CONVERSION = 1 THEN 1 ELSE 0 
        END) as false_negatives
        
    FROM CLIENT_CONVERSION_PREDICTIONS
)
SELECT 
    model_name,
    model_version,
    total_predictions,
    predicted_conversions,
    actual_conversions,
    accuracy,
    true_positives / NULLIF(true_positives + false_positives, 0) as precision,
    true_positives / NULLIF(true_positives + false_negatives, 0) as recall,
    2 * (precision * recall) / NULLIF(precision + recall, 0) as f1_score,
    true_positives,
    false_positives,
    true_negatives,
    false_negatives
FROM prediction_analysis
"""

performance_df = session.sql(performance_sql).collect()

if performance_df:
    perf = performance_df[0]
    print(f"\n🎯 Model Performance Summary:")
    print(f"   Model: {perf['MODEL_NAME']} {perf['MODEL_VERSION']}")
    print(f"   Total Predictions: {perf['TOTAL_PREDICTIONS']:,}")
    print(f"   Predicted Conversions: {perf['PREDICTED_CONVERSIONS']:,}")
    print(f"   Actual Conversions: {perf['ACTUAL_CONVERSIONS']:,}")
    print(f"\n📊 Performance Metrics:")
    print(f"   Accuracy: {perf['ACCURACY']:.2%}" if perf['ACCURACY'] is not None else "   Accuracy: N/A")
    # Compare against None so that a legitimate 0% score is not reported as N/A
    print(f"   Precision: {perf['PRECISION']:.2%}" if perf['PRECISION'] is not None else "   Precision: N/A")
    print(f"   Recall: {perf['RECALL']:.2%}" if perf['RECALL'] is not None else "   Recall: N/A")
    print(f"   F1 Score: {perf['F1_SCORE']:.2%}" if perf['F1_SCORE'] is not None else "   F1 Score: N/A")
    
    # Store performance metrics (skipped when accuracy is NULL, e.g. no predictions yet)
    if perf['ACCURACY'] is not None:
        session.sql(f"""
            INSERT INTO MODEL_PERFORMANCE_TRACKING 
            VALUES (
                NULL,
                CURRENT_TIMESTAMP(),
                '{perf['MODEL_NAME']}',
                '{perf['MODEL_VERSION']}',
                'ACCURACY',
                {perf['ACCURACY']},
                'PRODUCTION'
            )
        """).collect()
        
        print("\n✅ Performance metrics logged to MODEL_PERFORMANCE_TRACKING")

# Check data quality
data_quality_sql = """
SELECT 
    COUNT(*) as total_clients,
    COUNT(DISTINCT CLIENT_ID) as unique_clients,
    AVG(DAYS_SINCE_LAST_ACTIVITY) as avg_days_inactive,
    COUNT(CASE WHEN TOTAL_EVENTS_30D IS NULL THEN 1 END) as null_events,
    COUNT(CASE WHEN ANNUAL_INCOME IS NULL THEN 1 END) as null_income,
    MIN(DAYS_SINCE_LAST_ACTIVITY) as most_active_days_ago,
    MAX(DAYS_SINCE_LAST_ACTIVITY) as least_active_days_ago
FROM FEATURE_STORE
"""

quality_df = session.sql(data_quality_sql).collect()
if quality_df:
    q = quality_df[0]
    print(f"\n📋 Data Quality Check:")
    print(f"   Total Clients: {q['TOTAL_CLIENTS']:,}")
    print(f"   Unique Clients: {q['UNIQUE_CLIENTS']:,}")
    print(f"   Avg Days Inactive: {q['AVG_DAYS_INACTIVE']:.1f}")
    print(f"   Most Active: {q['MOST_ACTIVE_DAYS_AGO']} days ago")
    print(f"   Least Active: {q['LEAST_ACTIVE_DAYS_AGO']} days ago")
    print(f"   Null Events: {q['NULL_EVENTS']:,}")
    print(f"   Null Income: {q['NULL_INCOME']:,}")


## Step 4: Drift Detection and Monitoring
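
The feature drift score in this step is a standardized mean shift: the absolute difference between the current and baseline means, divided by the baseline standard deviation, with severity cut-offs at 1 and 2. A minimal plain-Python sketch of that calculation follows. The sample values are made up, and `drift_score` and `drift_severity` are hypothetical helper names.

```python
from statistics import mean, stdev
from typing import Optional, Sequence


def drift_score(baseline: Sequence[float], current: Sequence[float]) -> Optional[float]:
    """|current mean - baseline mean| / baseline sample std, or None if the std is zero."""
    baseline_std = stdev(baseline)  # sample standard deviation
    if baseline_std == 0:
        return None  # same role as NULLIF(baseline_std, 0) in the SQL
    return abs(mean(current) - mean(baseline)) / baseline_std


def drift_severity(score: Optional[float]) -> str:
    """Map a drift score to the High / Medium / Low bands used in this step."""
    if score is None:
        return "Low"  # a NULL score falls through to the ELSE branch in SQL
    if score > 2:
        return "High"
    if score > 1:
        return "Medium"
    return "Low"


# Illustrative values only
baseline_values = [10, 12, 14, 16, 18]
current_values = [20, 22, 24]

score = drift_score(baseline_values, current_values)
print(f"drift score: {score:.2f}, severity: {drift_severity(score)}")
```

Because the score only compares means, it will not flag a feature whose spread or shape changes while its mean stays put; a distribution-level test would be needed for that.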


In [None]:
# Detect feature drift and prediction drift
print("🔍 Analyzing drift in features and predictions...")

# Feature drift detection
feature_drift_sql = """
WITH feature_baselines AS (
    -- Baseline statistics from a random 25% sample. Each TABLESAMPLE draws
    -- independently and without a seed, so baselines vary between runs.
    SELECT 
        'ENGAGEMENT_SCORE_30D' as feature_name,
        AVG(ENGAGEMENT_SCORE_30D) as baseline_mean,
        STDDEV(ENGAGEMENT_SCORE_30D) as baseline_std
    FROM FEATURE_STORE 
    TABLESAMPLE (25)
    UNION ALL
    SELECT 
        'TOTAL_EVENTS_30D',
        AVG(TOTAL_EVENTS_30D),
        STDDEV(TOTAL_EVENTS_30D)
    FROM FEATURE_STORE 
    TABLESAMPLE (25)
    UNION ALL
    SELECT 
        'DAYS_SINCE_LAST_ACTIVITY',
        AVG(DAYS_SINCE_LAST_ACTIVITY),
        STDDEV(DAYS_SINCE_LAST_ACTIVITY)
    FROM FEATURE_STORE 
    TABLESAMPLE (25)
),
current_stats AS (
    -- Calculate current statistics (recently active clients only)
    SELECT 
        'ENGAGEMENT_SCORE_30D' as feature_name,
        AVG(ENGAGEMENT_SCORE_30D) as current_mean,
        STDDEV(ENGAGEMENT_SCORE_30D) as current_std
    FROM FEATURE_STORE 
    WHERE DAYS_SINCE_LAST_ACTIVITY <= 30
    UNION ALL
    SELECT 
        'TOTAL_EVENTS_30D',
        AVG(TOTAL_EVENTS_30D),
        STDDEV(TOTAL_EVENTS_30D)
    FROM FEATURE_STORE 
    WHERE DAYS_SINCE_LAST_ACTIVITY <= 30
    UNION ALL
    SELECT 
        'DAYS_SINCE_LAST_ACTIVITY',
        AVG(DAYS_SINCE_LAST_ACTIVITY),
        STDDEV(DAYS_SINCE_LAST_ACTIVITY)
    FROM FEATURE_STORE 
    WHERE DAYS_SINCE_LAST_ACTIVITY <= 30
)
SELECT 
    b.feature_name,
    b.baseline_mean,
    c.current_mean,
    ABS(c.current_mean - b.baseline_mean) / NULLIF(b.baseline_std, 0) as drift_score,
    CASE 
        WHEN ABS(c.current_mean - b.baseline_mean) / NULLIF(b.baseline_std, 0) > 2 THEN 'High'
        WHEN ABS(c.current_mean - b.baseline_mean) / NULLIF(b.baseline_std, 0) > 1 THEN 'Medium'
        ELSE 'Low'
    END as drift_severity,
    CASE 
        WHEN ABS(c.current_mean - b.baseline_mean) / NULLIF(b.baseline_std, 0) > 2 THEN TRUE
        ELSE FALSE
    END as alert_triggered
FROM feature_baselines b
JOIN current_stats c ON b.feature_name = c.feature_name
"""

drift_df = session.sql(feature_drift_sql).collect()

print("\n📊 Feature Drift Analysis:")
print("Feature Name                | Baseline | Current | Drift Score | Severity")
print("-" * 75)

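for row in drift_df:
    # DRIFT_SCORE is NULL when the baseline standard deviation is zero
    drift_score_text = f"{row['DRIFT_SCORE']:>11.2f}" if row['DRIFT_SCORE'] is not None else f"{'N/A':>11}"
    print(f"{row['FEATURE_NAME']:<26} | {row['BASELINE_MEAN']:>8.2f} | {row['CURRENT_MEAN']:>7.2f} | {drift_score_text} | {row['DRIFT_SEVERITY']}")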
    
    # Log drift to monitoring table
    if row['ALERT_TRIGGERED']:
        session.sql(f"""
            INSERT INTO DATA_DRIFT_MONITORING VALUES (
                NULL,
                CURRENT_TIMESTAMP(),
                '{row['FEATURE_NAME']}',
                {row['BASELINE_MEAN']},
                {row['CURRENT_MEAN']},
                {row['DRIFT_SCORE']},
                {row['ALERT_TRIGGERED']}
            )
        """).collect()

# Prediction drift detection
prediction_drift_sql = """
WITH prediction_windows AS (
    SELECT 
        DATE_TRUNC('day', BATCH_TIMESTAMP) as prediction_date,
        AVG(CONVERSION_PROBABILITY) as avg_prediction,
        COUNT(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 END) / COUNT(*) as positive_rate,
        COUNT(*) as daily_volume
    FROM CLIENT_SEGMENTS_BATCH
    GROUP BY DATE_TRUNC('day', BATCH_TIMESTAMP)
    ORDER BY prediction_date DESC
    LIMIT 7
)
SELECT 
    prediction_date,
    avg_prediction,
    positive_rate,
    daily_volume,
    LAG(positive_rate) OVER (ORDER BY prediction_date) as prev_positive_rate,
    ABS(positive_rate - LAG(positive_rate) OVER (ORDER BY prediction_date)) as daily_change
FROM prediction_windows
ORDER BY prediction_date DESC
"""

pred_drift_df = session.sql(prediction_drift_sql).collect()

print("\n📈 Prediction Drift Analysis (Last 7 Days):")
print("Date       | Positive Rate | Daily Change | Volume")
print("-" * 50)

for row in pred_drift_df:
    if row['DAILY_CHANGE'] is not None:
        print(f"{row['PREDICTION_DATE'].strftime('%Y-%m-%d')} | {row['POSITIVE_RATE']:>13.2%} | {row['DAILY_CHANGE']:>12.2%} | {row['DAILY_VOLUME']:>6,}")

print("\n✅ Drift analysis complete")


## Step 5: Business Impact Analysis
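
The impact figures in this step rest on a few simple assumptions: a predicted conversion is worth 2% of assets plus a 1,500 advisory fee, a retained client is worth a 1% annual management fee, and each retention action has a fixed cost. The sketch below works through that arithmetic for one converting client and one at-risk client. The asset values and helper names are hypothetical.

```python
from typing import Optional

# Retention action costs as assumed in this step; unlisted actions cost 100
RETENTION_COSTS = {
    "Priority Retention Call": 500,
    "Personal Advisory Meeting": 300,
    "Special Offer": 200,
}


def conversion_value(assets: float, predicted_conversion: bool) -> float:
    """2% of assets plus a flat 1,500 advisory fee, only for predicted conversions."""
    return assets * 0.02 + 1500 if predicted_conversion else 0.0


def retention_cost(action: str) -> int:
    return RETENTION_COSTS.get(action, 100)


def roi_percentage(total_value: float, total_cost: float) -> Optional[float]:
    """(value / cost - 1) * 100, or None when nothing was spent."""
    if not total_cost:
        return None
    return (total_value / total_cost - 1) * 100


# Illustrative clients only
potential_revenue = conversion_value(assets=100_000, predicted_conversion=True)
annual_revenue = 500_000 * 0.01  # 1% management fee on an at-risk client
cost = retention_cost("Priority Retention Call")
net_revenue_saved = annual_revenue - cost  # assumes the retention action succeeds

total_value = potential_revenue + net_revenue_saved
print(f"potential revenue: {potential_revenue:,.2f}")
print(f"net revenue saved: {net_revenue_saved:,.2f}")
print(f"ROI: {roi_percentage(total_value, cost):.1f}%")
```

Note that the net revenue figure treats every retention action as successful, so it is better read as an upper bound than as a forecast.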


In [None]:
# Calculate business impact of ML models
print("💰 Analyzing Business Impact of ML Models...")

# Calculate conversion impact
conversion_impact_sql = """
WITH model_results AS (
    SELECT 
        cs.CLIENT_ID,
        cs.RECOMMENDED_ACTION,
        cs.SEGMENT,
        cs.CONVERSION_PROBABILITY,
        f.ANNUAL_INCOME,
        f.TOTAL_ASSETS_UNDER_MANAGEMENT,
        -- Assume conversion value is 2% of assets + advisory fee
        CASE 
            WHEN cs.CONVERSION_PROBABILITY = 1 THEN 
                f.TOTAL_ASSETS_UNDER_MANAGEMENT * 0.02 + 1500
            ELSE 0 
        END as potential_revenue
    FROM CLIENT_SEGMENTS_BATCH cs
    JOIN FEATURE_STORE f ON cs.CLIENT_ID = f.CLIENT_ID
    WHERE cs.BATCH_TIMESTAMP >= DATEADD(day, -7, CURRENT_DATE())
)
SELECT 
    COUNT(DISTINCT CLIENT_ID) as clients_targeted,
    COUNT(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 END) as predicted_conversions,
    SUM(potential_revenue) as total_potential_revenue,
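    -- Average over predicted conversions only; non-converters carry zero revenue
    COALESCE(AVG(CASE WHEN CONVERSION_PROBABILITY = 1 THEN potential_revenue END), 0) as avg_revenue_per_conversion,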
    SUM(CASE WHEN SEGMENT = 'High Value' THEN potential_revenue ELSE 0 END) as high_value_revenue,
    COUNT(CASE WHEN RECOMMENDED_ACTION = 'Wealth_Advisory_Consultation' THEN 1 END) as wealth_advisory_leads
FROM model_results
"""

conversion_impact_df = session.sql(conversion_impact_sql).collect()

if conversion_impact_df:
    impact = conversion_impact_df[0]
    print(f"\n📊 Conversion Model Impact (Last 7 Days):")
    print(f"   Clients Targeted: {impact['CLIENTS_TARGETED']:,}")
    print(f"   Predicted Conversions: {impact['PREDICTED_CONVERSIONS']:,}")
    print(f"   Total Potential Revenue: ${impact['TOTAL_POTENTIAL_REVENUE']:,.2f}")
    print(f"   Avg Revenue per Conversion: ${impact['AVG_REVENUE_PER_CONVERSION']:,.2f}")
    print(f"   High Value Segment Revenue: ${impact['HIGH_VALUE_REVENUE']:,.2f}")
    print(f"   Wealth Advisory Leads: {impact['WEALTH_ADVISORY_LEADS']:,}")

# Calculate churn prevention impact
churn_impact_sql = """
WITH churn_prevention AS (
    SELECT 
        c.CLIENT_ID,
        c.CHURN_RISK_SCORE,
        c.VALUE_TIER,
        c.RECOMMENDED_ACTION,
        f.TOTAL_ASSETS_UNDER_MANAGEMENT,
        f.ANNUAL_INCOME,
        -- Annual revenue from client (1% management fee)
        f.TOTAL_ASSETS_UNDER_MANAGEMENT * 0.01 as annual_revenue,
        -- Cost of retention action
        CASE c.RECOMMENDED_ACTION
            WHEN 'Priority Retention Call' THEN 500
            WHEN 'Personal Advisory Meeting' THEN 300
            WHEN 'Special Offer' THEN 200
            ELSE 100
        END as retention_cost
    FROM CLIENT_CHURN_SEGMENTS c
    JOIN FEATURE_STORE f ON c.CLIENT_ID = f.CLIENT_ID
    WHERE c.CHURN_RISK_SCORE > 0.6
)
SELECT 
    COUNT(*) as at_risk_clients,
    SUM(annual_revenue) as revenue_at_risk,
    AVG(annual_revenue) as avg_revenue_per_client,
    SUM(retention_cost) as total_retention_cost,
    -- Upper bound: assumes every retention action succeeds
    SUM(annual_revenue) - SUM(retention_cost) as net_revenue_saved,
    COUNT(CASE WHEN VALUE_TIER = 'High' THEN 1 END) as high_value_at_risk
FROM churn_prevention
"""

churn_impact_df = session.sql(churn_impact_sql).collect()

if churn_impact_df:
    churn = churn_impact_df[0]
    print(f"\n🛡️ Churn Prevention Impact:")
    print(f"   At-Risk Clients: {churn['AT_RISK_CLIENTS']:,}")
    print(f"   Revenue at Risk: ${churn['REVENUE_AT_RISK']:,.2f}")
    print(f"   Avg Revenue per Client: ${churn['AVG_REVENUE_PER_CLIENT']:,.2f}")
    print(f"   Total Retention Cost: ${churn['TOTAL_RETENTION_COST']:,.2f}")
    print(f"   Net Revenue Saved: ${churn['NET_REVENUE_SAVED']:,.2f}")
    print(f"   High Value Clients at Risk: {churn['HIGH_VALUE_AT_RISK']:,}")

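# Store business impact metrics
# Aggregates over empty inputs come back as NULL (None), so default them to zero
clients_targeted = impact['CLIENTS_TARGETED'] or 0
predicted_conversions = impact['PREDICTED_CONVERSIONS'] or 0
potential_revenue = float(impact['TOTAL_POTENTIAL_REVENUE'] or 0)
# Proxy for churns prevented: at-risk clients targeted for retention
at_risk_clients = (churn['AT_RISK_CLIENTS'] or 0) if churn_impact_df else 0
net_revenue_saved = float(churn['NET_REVENUE_SAVED'] or 0) if churn_impact_df else 0.0
total_retention_cost = float(churn['TOTAL_RETENTION_COST'] or 0) if churn_impact_df else 0.0

total_business_value = potential_revenue + net_revenue_saved
conversion_rate = predicted_conversions / clients_targeted if clients_targeted else 0
# Proxy for retention rate: share of targeted clients not flagged as at risk
retention_rate = 1 - at_risk_clients / clients_targeted if clients_targeted else 0
# ROI is undefined without any retention spend, so store NULL in that case
roi_percentage = (total_business_value / total_retention_cost - 1) * 100 if total_retention_cost else 'NULL'

try:
    # List columns explicitly so impact_id (AUTOINCREMENT) is generated by Snowflake
    session.sql(f"""
        INSERT INTO ML_BUSINESS_IMPACT (
            evaluation_date, model_name, model_version,
            total_conversions, conversion_rate, revenue_generated,
            churns_prevented, retention_rate, revenue_retained,
            total_business_value, roi_percentage
        ) VALUES (
            CURRENT_DATE(), 'CONVERSION_PREDICTOR', 'V1',
            {predicted_conversions}, {conversion_rate}, {potential_revenue},
            {at_risk_clients}, {retention_rate}, {net_revenue_saved},
            {total_business_value}, {roi_percentage}
        )
    """).collect()
    print("\n✅ Business impact metrics logged")
except Exception as e:
    print(f"\n⚠️ Could not log business impact: {str(e)}")

print("\n🎯 Total Business Value Generated: ${:,.2f}".format(total_business_value))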


## Step 6: Comprehensive Observability Dashboard & Alerts
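
The dashboard view in this step reduces model health to a short rule list: accuracy below 0.6 is Critical, more than three drifted features or a maximum drift score above 2 is a Warning, and anything else is Healthy. A minimal sketch of the same rules in plain Python, with a hypothetical `model_health_status` helper, where `None` stands in for a SQL NULL:

```python
from typing import Optional


def model_health_status(
    avg_accuracy: Optional[float],
    drifted_features: Optional[int],
    max_drift_score: Optional[float],
) -> str:
    """Apply the dashboard's health rules in order; None never triggers a rule."""
    # A comparison against NULL is not true in SQL, so that branch is skipped
    if avg_accuracy is not None and avg_accuracy < 0.6:
        return "Critical"
    if drifted_features is not None and drifted_features > 3:
        return "Warning"
    if max_drift_score is not None and max_drift_score > 2:
        return "Warning"
    return "Healthy"


# Illustrative inputs only
examples = [
    (0.55, 0, None),   # low accuracy wins regardless of drift
    (0.72, 1, 2.4),    # one feature drifted strongly
    (0.72, 0, None),   # no drift alerts logged yet
]
for accuracy, features, drift in examples:
    print(accuracy, features, drift, "->", model_health_status(accuracy, features, drift))
```

One consequence of the NULL handling is that a model with no logged drift checks reports Healthy, so an absent drift history is worth surfacing separately.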


In [None]:
# Create comprehensive observability dashboard views
print("📊 Creating ML Observability Dashboard...")

# Create unified dashboard view
dashboard_sql = """
CREATE OR REPLACE VIEW ML_OBSERVABILITY_DASHBOARD AS
WITH model_metrics AS (
    SELECT 
        MODEL_NAME,
        MODEL_VERSION,
        MAX(EVALUATION_TIMESTAMP) as last_evaluated,
        AVG(METRIC_VALUE) as avg_accuracy,
        MIN(METRIC_VALUE) as min_accuracy,
        MAX(METRIC_VALUE) as max_accuracy
    FROM MODEL_PERFORMANCE_TRACKING
    WHERE METRIC_NAME = 'ACCURACY'
    GROUP BY MODEL_NAME, MODEL_VERSION
),
drift_summary AS (
    SELECT 
        COUNT(DISTINCT FEATURE_NAME) as drifted_features,
        MAX(DRIFT_SCORE) as max_drift_score,
        MAX(MONITORING_TIMESTAMP) as last_drift_check
    FROM DATA_DRIFT_MONITORING
    WHERE ALERT_TRIGGERED = TRUE
),
prediction_volume AS (
    SELECT 
        COUNT(*) as daily_predictions,
        COUNT(CASE WHEN CONVERSION_PROBABILITY = 1 THEN 1 END) as positive_predictions,
        COUNT(DISTINCT CLIENT_ID) as unique_clients
    FROM CLIENT_SEGMENTS_BATCH
    WHERE BATCH_TIMESTAMP >= CURRENT_DATE()
),
business_metrics AS (
    SELECT 
        SUM(total_business_value) as total_value,
        AVG(roi_percentage) as avg_roi,
        SUM(total_conversions) as total_conversions,
        SUM(churns_prevented) as total_churns_prevented
    FROM ML_BUSINESS_IMPACT
    WHERE evaluation_date >= DATEADD(day, -30, CURRENT_DATE())
)
SELECT 
    -- Model Performance
    m.MODEL_NAME,
    m.MODEL_VERSION,
    m.avg_accuracy,
    m.last_evaluated,
    
    -- Data Quality
    d.drifted_features,
    d.max_drift_score,
    
    -- Prediction Volume
    p.daily_predictions,
    p.positive_predictions,
    p.unique_clients,
    
    -- Business Impact
    b.total_value,
    b.avg_roi,
    b.total_conversions,
    b.total_churns_prevented,
    
    -- Health Status
    CASE 
        WHEN m.avg_accuracy < 0.6 THEN 'Critical'
        WHEN d.drifted_features > 3 THEN 'Warning'
        WHEN d.max_drift_score > 2 THEN 'Warning'
        ELSE 'Healthy'
    END as model_health_status,
    
    CURRENT_TIMESTAMP() as dashboard_updated
FROM model_metrics m
CROSS JOIN drift_summary d
CROSS JOIN prediction_volume p
CROSS JOIN business_metrics b
"""

session.sql(dashboard_sql).collect()
print("✅ Created ML_OBSERVABILITY_DASHBOARD view")

# Check dashboard
dashboard_df = session.sql("SELECT * FROM ML_OBSERVABILITY_DASHBOARD").collect()

if dashboard_df:
    dash = dashboard_df[0]
    print(f"\n🎯 ML System Health Dashboard")
    print(f"{'=' * 60}")
    print(f"\n📊 Model Performance:")
    print(f"   Model: {dash['MODEL_NAME']} {dash['MODEL_VERSION']}")
    print(f"   Average Accuracy: {dash['AVG_ACCURACY']:.2%}")
    print(f"   Last Evaluated: {dash['LAST_EVALUATED']}")
    
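    print(f"\n🔍 Data Quality:")
    print(f"   Drifted Features: {dash['DRIFTED_FEATURES']}")
    # MAX_DRIFT_SCORE is NULL when no drift alerts have been logged
    max_drift = dash['MAX_DRIFT_SCORE']
    print(f"   Max Drift Score: {max_drift:.2f}" if max_drift is not None else "   Max Drift Score: N/A")
    
    print(f"\n📈 Today's Activity:")
    print(f"   Predictions: {dash['DAILY_PREDICTIONS']:,}")
    # Avoid dividing by zero when no predictions were made today
    if dash['DAILY_PREDICTIONS']:
        print(f"   Positive Rate: {dash['POSITIVE_PREDICTIONS'] / dash['DAILY_PREDICTIONS']:.2%}")
    else:
        print("   Positive Rate: N/A")
    print(f"   Unique Clients: {dash['UNIQUE_CLIENTS']:,}")
    
    # Business aggregates are NULL when no impact rows exist for the last 30 days
    print(f"\n💰 Business Impact (30 Days):")
    print(f"   Total Value: ${dash['TOTAL_VALUE'] or 0:,.2f}")
    print(f"   Average ROI: {dash['AVG_ROI']:.1f}%" if dash['AVG_ROI'] is not None else "   Average ROI: N/A")
    print(f"   Conversions: {dash['TOTAL_CONVERSIONS'] or 0:,}")
    print(f"   Churns Prevented: {dash['TOTAL_CHURNS_PREVENTED'] or 0:,}")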
    
    print(f"\n🚦 System Status: {dash['MODEL_HEALTH_STATUS']}")

# Set up automated alerts
alert_procedure_sql = """
CREATE OR REPLACE PROCEDURE CHECK_ML_ALERTS()
RETURNS VARCHAR
LANGUAGE SQL
AS
$$
DECLARE
    alert_message VARCHAR;
    accuracy_threshold FLOAT := 0.65;
    drift_threshold FLOAT := 2.0;
BEGIN
    -- Check model performance
    -- Snowflake Scripting places INTO :variable after the select list
    SELECT 
        CASE 
            WHEN AVG(METRIC_VALUE) < :accuracy_threshold THEN 
                'ALERT: Model accuracy below threshold - ' || AVG(METRIC_VALUE)
            ELSE NULL
        END
    INTO :alert_message
    FROM MODEL_PERFORMANCE_TRACKING
    WHERE METRIC_NAME = 'ACCURACY'
    AND EVALUATION_TIMESTAMP >= DATEADD(hour, -24, CURRENT_TIMESTAMP());
    
    IF (alert_message IS NOT NULL) THEN
        RETURN alert_message;
    END IF;
    
    -- Check drift
    SELECT 
        CASE 
            WHEN MAX(DRIFT_SCORE) > :drift_threshold THEN 
                'ALERT: High feature drift detected - ' || MAX(DRIFT_SCORE)
            ELSE NULL
        END
    INTO :alert_message
    FROM DATA_DRIFT_MONITORING
    WHERE MONITORING_TIMESTAMP >= DATEADD(hour, -24, CURRENT_TIMESTAMP());
    
    IF (alert_message IS NOT NULL) THEN
        RETURN alert_message;
    END IF;
    
    RETURN 'All systems normal';
END;
$$
"""

session.sql(alert_procedure_sql).collect()
print("\n✅ Created alert monitoring procedure")

# Check for current alerts
alert_result = session.sql("CALL CHECK_ML_ALERTS()").collect()
print(f"\n🚨 Alert Status: {alert_result[0][0]}")

print("\n🎉 ML Observability Suite Complete!")
print("\n📝 Next Steps:")
print("   1. Schedule MONITOR_PREDICTIONS task to run every 6 hours")
print("   2. Set up email alerts for critical model health")
print("   3. Create Snowsight dashboards using ML_OBSERVABILITY_DASHBOARD")
print("   4. Review business impact weekly")
print("   5. Retrain models when accuracy drops below 65%")
