In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
session = get_active_session()


# Power Grid Forecasting: Partitioned Models at Scale

## Demo Overview
This notebook demonstrates a complete end-to-end ML pipeline for forecasting power production across 5,000 electrical substations using Snowflake's partitioned-model capabilities.

### What We'll Build:
- **Data Simulation**: Generate realistic time-series data for 5,000 substations
- **Feature Store**: Point-in-time correct features with serving parity
- **Partitioned Training**: One XGBoost model per substation (5,000 models)
- **Distributed Inference**: Parallel predictions across all substations
- **Production Orchestration**: Automated retraining with ML Jobs and Tasks
- **Monitoring**: Observability and model performance tracking

### Key Benefits:
- **Scalability**: Train 5,000+ models in parallel
- **Performance**: All computation stays close to data
- **Simplicity**: Single notebook, SQL + Python
- **Production-Ready**: Automated retraining and monitoring

## 1. Environment Setup & DDL

First, let's create our database objects and infrastructure.

In [None]:
-- Create database and schemas
CREATE DATABASE IF NOT EXISTS POWERGRID_DEMO;
CREATE SCHEMA IF NOT EXISTS POWERGRID_DEMO.PUBLIC;  -- Main schema for all tables and views
CREATE SCHEMA IF NOT EXISTS POWERGRID_DEMO.ML;      -- ML-specific objects (stages, procedures)

-- Warehouse for interactive development
CREATE WAREHOUSE IF NOT EXISTS DEMO_WH
  WAREHOUSE_SIZE = 'X-LARGE'
  AUTO_SUSPEND = 300
  AUTO_RESUME = TRUE
  INITIALLY_SUSPENDED = TRUE;

-- Stage for Python files and model artifacts.
-- IF NOT EXISTS (like every other object in this cell) so re-running the cell
-- does not drop files already uploaded to the stage.
CREATE STAGE IF NOT EXISTS POWERGRID_DEMO.ML.CODE_STAGE
  DIRECTORY = (ENABLE = TRUE);

-- Stage used by ManyModelTraining / ManyModelInference for per-partition
-- artifacts (stage_name="@POWERGRID_DEMO.ML.ML_STAGE" in the training and
-- inference cells). It must exist before those cells run.
CREATE STAGE IF NOT EXISTS POWERGRID_DEMO.ML.ML_STAGE;

-- Compute Pool for ML Jobs
-- NOTE: MIN_NODES = 2 keeps at least two nodes provisioned while the pool is active.
CREATE COMPUTE POOL IF NOT EXISTS SUBSTATION_MMT_EXAMPLE
  MIN_NODES = 2
  MAX_NODES = 20
  INSTANCE_FAMILY = CPU_X64_S;

-- Set context
USE DATABASE POWERGRID_DEMO;
USE SCHEMA POWERGRID_DEMO.PUBLIC;
USE WAREHOUSE DEMO_WH;

### Create Core Data Tables

In [None]:
-- All tables will be in the PUBLIC schema
-- NOTE: later cells recreate SUBSTATION_PRODUCTION_TRAIN (CREATE OR REPLACE TABLE ... AS)
-- and overwrite SUBSTATION_PRODUCTION_INFERENCE_INPUT (save_as_table in "overwrite" mode),
-- so the column lists/defaults declared below may not survive those steps.

-- Training data: Historical production with weather features
CREATE OR REPLACE TABLE SUBSTATION_PRODUCTION_TRAIN (
  SUBSTATION_ID       NUMBER,         -- 1..5000
  TIMESTAMP_HOUR      TIMESTAMP_NTZ,  -- hourly timestamps
  HIST_PRODUCTION     FLOAT,          -- observed production (target)
  TEMP_C              FLOAT,          -- temperature in Celsius
  WIND_MPS            FLOAT,          -- wind speed in m/s
  HOUR_OF_DAY         NUMBER,         -- 0-23
  DAY_OF_WEEK         NUMBER,         -- 0-6; NOTE(review): must use the same weekday numbering as the inference input
  CREATED_AT          TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()  -- filled at insert time when not supplied
);

-- Inference input: Future weather conditions (no production label)
CREATE OR REPLACE TABLE SUBSTATION_PRODUCTION_INFERENCE_INPUT (
  SUBSTATION_ID       NUMBER,
  TIMESTAMP_HOUR      TIMESTAMP_NTZ,
  TEMP_C              FLOAT,
  WIND_MPS            FLOAT,
  HOUR_OF_DAY         NUMBER,
  DAY_OF_WEEK         NUMBER,         -- 0-6; same numbering as the training table
  CREATED_AT          TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
);

-- Forecast outputs: one row per substation/hour/model version
CREATE OR REPLACE TABLE SUBSTATION_FORECASTS (
  SUBSTATION_ID       NUMBER,
  TIMESTAMP_HOUR      TIMESTAMP_NTZ,
  PREDICTION          FLOAT,
  MODEL_VERSION       STRING,
  CREATED_AT          TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
);

-- Model performance tracking: error metrics per substation, model version and day
CREATE OR REPLACE TABLE MODEL_PERFORMANCE_METRICS (
  SUBSTATION_ID       NUMBER,
  MODEL_VERSION       STRING,
  METRIC_DATE         DATE,
  MAE                 FLOAT,
  RMSE                FLOAT,
  MAPE                FLOAT,
  PREDICTION_COUNT    NUMBER,
  CREATED_AT          TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
);


## 2. Data Simulation

Let's generate realistic time-series data for 5,000 substations with seasonal patterns, weather dependencies, and realistic noise.

In [None]:
from datetime import datetime, timedelta
# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
session = get_active_session()

def generate_substation_data_sql(session, n_substations=5000, days_history=365):
    """
    Generate synthetic hourly power-production data using pure SQL in Snowflake.

    Replaces SUBSTATION_PRODUCTION_TRAIN (CREATE OR REPLACE TABLE ... AS) with
    ``n_substations * days_history * 24`` rows: every substation id
    1..n_substations crossed with every hour of the ``days_history`` days
    ending at the current UTC hour. Everything runs inside Snowflake's engine
    instead of looping in Python.

    Signal components:
    - Per-substation base level in [60, 99], derived from HASH(SUBSTATION_ID)
    - Seasonal sine (peaks around day-of-year 171, i.e. early summer)
    - Daily sine (peaks around 12:00)
    - Weather features (TEMP_C, WIND_MPS) with seasonal/daily patterns
    - Deterministic, zero-centred pseudo-random noise derived from HASH

    DAY_OF_WEEK is 0 = Monday ... 6 = Sunday, the same convention as Python's
    ``datetime.weekday()``, which the inference-input generator uses.

    Args:
        session: Active Snowpark session used to run the SQL.
        n_substations: Number of substations to generate (positive integer).
        days_history: Number of days of hourly history (positive integer).

    Raises:
        ValueError: If ``n_substations`` or ``days_history`` is not positive.
    """
    # Coerce to int before interpolating into SQL text, so only numbers reach the query.
    n_substations = int(n_substations)
    days_history = int(days_history)
    if n_substations <= 0 or days_history <= 0:
        raise ValueError(
            f"n_substations and days_history must be positive, got {n_substations} and {days_history}"
        )
    n_hours = days_history * 24

    print(f"🚀 Generating {n_substations:,} substations × {days_history} days = {n_substations * n_hours:,} rows")
    print(f"   • Using pure SQL - executes in Snowflake engine (100x faster than Python!)")

    end_time = datetime.utcnow().replace(minute=0, second=0, microsecond=0)
    start_time = end_time - timedelta(days=days_history)

    print(f"   • Executing SQL generation...")

    # Notes on the SQL:
    # * ROW_NUMBER() OVER (ORDER BY SEQ4()) is used because SEQ4() alone is not
    #   guaranteed to be gap-free; lag features downstream assume contiguous hours.
    # * HASH() returns a signed integer, so HASH(x) % N lies in (-N, N). ABS()
    #   folds it into [0, N) so the noise terms are centred instead of skewed negative.
    session.sql(f"""
        CREATE OR REPLACE TABLE SUBSTATION_PRODUCTION_TRAIN AS
        WITH
        substations AS (
            SELECT ROW_NUMBER() OVER (ORDER BY SEQ4()) AS SUBSTATION_ID
            FROM TABLE(GENERATOR(ROWCOUNT => {n_substations}))
        ),
        hours AS (
            SELECT DATEADD(hour, ROW_NUMBER() OVER (ORDER BY SEQ4()) - 1, '{start_time}'::TIMESTAMP_NTZ) AS TIMESTAMP_HOUR
            FROM TABLE(GENERATOR(ROWCOUNT => {n_hours}))
        ),
        base_data AS (
            SELECT
                s.SUBSTATION_ID,
                h.TIMESTAMP_HOUR,
                HOUR(h.TIMESTAMP_HOUR) AS HOUR_OF_DAY,
                -- 0 = Monday ... 6 = Sunday (matches Python datetime.weekday())
                DAYOFWEEKISO(h.TIMESTAMP_HOUR) - 1 AS DAY_OF_WEEK,
                DAYOFYEAR(h.TIMESTAMP_HOUR) AS DAY_OF_YEAR
            FROM substations s CROSS JOIN hours h
        )
        SELECT
            SUBSTATION_ID,
            TIMESTAMP_HOUR,
            -- Power production with seasonal and daily patterns
            -- HASH gives deterministic "randomness" per substation / timestamp
            GREATEST(0,
                (40 + (ABS(HASH(SUBSTATION_ID) % 40) + 20)) *
                (1 + 0.3 * SIN(2 * PI() * (DAY_OF_YEAR - 80) / 365)) *
                (1 + 0.4 * SIN(2 * PI() * (HOUR_OF_DAY - 6) / 24)) +
                (ABS(HASH(SUBSTATION_ID, TIMESTAMP_HOUR) % 100) - 50) / 10.0  -- Noise: -5.0 to +4.9
            ) AS HIST_PRODUCTION,
            -- Weather features with patterns
            15 + 10 * SIN(2 * PI() * (DAY_OF_YEAR - 80) / 365) +
                5 * SIN(2 * PI() * HOUR_OF_DAY / 24) +
                (ABS(HASH(TIMESTAMP_HOUR) % 40) - 20) / 10.0 AS TEMP_C,  -- Noise: -2.0 to +1.9
            ABS(5 + 3 * SIN(2 * PI() * DAY_OF_YEAR / 365) +
                (ABS(HASH(DAY_OF_YEAR) % 40) - 20) / 10.0) AS WIND_MPS,  -- Noise: -2.0 to +1.9
            HOUR_OF_DAY,
            DAY_OF_WEEK
        FROM base_data
    """).collect()

    row_count = session.table("SUBSTATION_PRODUCTION_TRAIN").count()
    print(f"✅ Generated {row_count:,} rows in Snowflake!")
    print("   • Data generation complete!")

# Generate the training data - FULL YEAR using pure SQL
# (5,000 substations x 365 days x 24 hours; replaces SUBSTATION_PRODUCTION_TRAIN)
generate_substation_data_sql(session, n_substations=5000, days_history=365)


In [None]:
-- Peek at a handful of generated training rows
SELECT *
FROM SUBSTATION_PRODUCTION_TRAIN
LIMIT 10

### Generate Future Inference Data

In [None]:
from snowflake.snowpark.context import get_active_session
session = get_active_session()
import numpy as np
import pandas as pd 

def generate_inference_data(session: "snowflake.snowpark.Session", n_substations: int = 5000, hours_ahead: int = 48) -> pd.DataFrame:
    """Generate synthetic future weather conditions for inference.

    Builds ``n_substations * hours_ahead`` rows starting at the next full UTC
    hour, overwrites SUBSTATION_PRODUCTION_INFERENCE_INPUT with them, and
    returns the same rows as a pandas DataFrame.

    Each substation gets its own ``numpy.random.Generator`` seeded with
    ``substation_id + 10000``, so its weather is reproducible across runs and
    the global NumPy RNG state is left untouched.

    DAY_OF_WEEK uses Python's ``datetime.weekday()`` (0 = Monday ... 6 = Sunday).

    Args:
        session: Active Snowpark session used to write the table.
        n_substations: Number of substations (ids 1..n_substations).
        hours_ahead: Number of future hourly timestamps per substation.

    Returns:
        pandas DataFrame with columns SUBSTATION_ID, TIMESTAMP_HOUR, TEMP_C,
        WIND_MPS, HOUR_OF_DAY, DAY_OF_WEEK.
    """
    # Local import so the function does not depend on an earlier cell having run.
    from datetime import datetime, timedelta

    print(f"Generating inference data for next {hours_ahead} hours...")

    # Future time range: the next `hours_ahead` full hours (naive UTC timestamps)
    start_time = datetime.utcnow().replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
    future_hours = [start_time + timedelta(hours=i) for i in range(hours_ahead)]

    inference_data = []
    for substation_id in range(1, n_substations + 1):
        # Substation-specific generator: reproducible, and does not reseed np.random globally
        rng = np.random.default_rng(substation_id + 10000)

        # Base weather for this substation (slight regional variation)
        base_temp = 18 + rng.uniform(-5, 5)
        base_wind = 6 + rng.uniform(-2, 2)

        for timestamp in future_hours:
            # Daily temperature cycle plus noise
            temp_cycle = 4 * np.sin(2 * np.pi * (timestamp.hour - 6) / 24)
            temp = base_temp + temp_cycle + rng.normal(0, 1.5)

            # Wind around the substation's base level, clipped at zero
            wind = max(0.0, base_wind + rng.normal(0, 1.5))

            inference_data.append({
                'SUBSTATION_ID': substation_id,
                'TIMESTAMP_HOUR': timestamp,  # Already a Python datetime
                'TEMP_C': float(temp),
                'WIND_MPS': float(wind),
                'HOUR_OF_DAY': timestamp.hour,
                'DAY_OF_WEEK': timestamp.weekday()
            })

    # Explicit columns keep the schema (and the astype below) valid even for empty input
    columns = ['SUBSTATION_ID', 'TIMESTAMP_HOUR', 'TEMP_C', 'WIND_MPS', 'HOUR_OF_DAY', 'DAY_OF_WEEK']
    df_inference = pd.DataFrame(inference_data, columns=columns).astype({
        'SUBSTATION_ID': 'int64',
        'TEMP_C': 'float64',
        'WIND_MPS': 'float64',
        'HOUR_OF_DAY': 'int64',
        'DAY_OF_WEEK': 'int64',
    })

    # Use Snowpark DataFrame for better type handling
    snowpark_df = session.create_dataframe(df_inference)
    snowpark_df.write.mode("overwrite").save_as_table("SUBSTATION_PRODUCTION_INFERENCE_INPUT")

    print(f"Generated {len(inference_data)} inference records.")
    return df_inference

# Generate inference data
# (returns the pandas frame that was also written to SUBSTATION_PRODUCTION_INFERENCE_INPUT)
inference_data = generate_inference_data(session)


## 3. Feature Store Implementation

Build a production-ready feature store with point-in-time correctness and serving parity.


In [None]:
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from snowflake.snowpark import DataFrame
import snowflake.snowpark.functions as F
from snowflake.snowpark.window import Window
from datetime import timedelta
import warnings

# Suppress the online_config private preview warning
# We're using batch/offline features only, not online serving
warnings.filterwarnings('ignore', message='.*online_config.*')

# Initialize Feature Store
# The feature store will be created if it doesn't exist
fs = FeatureStore(
    session=session,
    database="POWERGRID_DEMO",
    name="SUBSTATION_FEATURE_STORE",
    default_warehouse="DEMO_WH",
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST
)

print("✅ Feature Store initialized")
print("   NOTE: Using BATCH/OFFLINE features only (no online serving)")

# Define the Substation Entity (join key: SUBSTATION_ID)
substation_entity = Entity(
    name="SUBSTATION",
    join_keys=["SUBSTATION_ID"],
    desc="Electrical substation entity for power grid forecasting"
)

# Register the entity
fs.register_entity(substation_entity)
print("✅ Entity 'SUBSTATION' registered")

# Create source data view for feature engineering
# Note: Fully qualify the table name to reference the PUBLIC schema
# IS_WEEKEND is derived from TIMESTAMP_HOUR via DAYOFWEEKISO (6 = Saturday, 7 = Sunday)
# so it does not depend on which weekday numbering DAY_OF_WEEK uses
# (Snowflake DAYOFWEEK() counts from Sunday = 0 by default; Python weekday() from Monday = 0).
session.sql("""
CREATE OR REPLACE VIEW POWERGRID_DEMO.PUBLIC.SUBSTATION_SOURCE_DATA AS
SELECT
    SUBSTATION_ID,
    TIMESTAMP_HOUR,
    HIST_PRODUCTION,
    TEMP_C,
    WIND_MPS,
    HOUR_OF_DAY,
    DAY_OF_WEEK,
    -- Derived time features
    CASE WHEN HOUR_OF_DAY BETWEEN 6 AND 18 THEN 1 ELSE 0 END as IS_DAYTIME,
    CASE WHEN DAYOFWEEKISO(TIMESTAMP_HOUR) IN (6, 7) THEN 1 ELSE 0 END as IS_WEEKEND,
    EXTRACT(MONTH FROM TIMESTAMP_HOUR) as MONTH,
    -- Weather interaction features
    TEMP_C * WIND_MPS as TEMP_WIND_INTERACTION,
    CASE WHEN TEMP_C > 25 THEN 1 ELSE 0 END as IS_HOT,
    CASE WHEN WIND_MPS > 8 THEN 1 ELSE 0 END as IS_WINDY
FROM POWERGRID_DEMO.PUBLIC.SUBSTATION_PRODUCTION_TRAIN
""").collect()

print("✅ Source data view created")

# Define Feature View with time-series features
print("📊 Creating Feature View with time-series features...")

# Get source data as Snowpark DataFrame (from PUBLIC schema)
source_df = session.table("POWERGRID_DEMO.PUBLIC.SUBSTATION_SOURCE_DATA")

# Per-substation ordering in time; lags are row-based, so they assume one row per hour.
window_by_substation = Window.partition_by("SUBSTATION_ID").order_by("TIMESTAMP_HOUR")

# Rolling windows end at the previous row (-1) so the current target never leaks in (PIT correctness)
window_24h_past = window_by_substation.rows_between(-24, -1)
window_168h_past = window_by_substation.rows_between(-168, -1)

# Create the feature DataFrame with all transformations
# NOTE: HIST_PRODUCTION (the target) is carried through the feature view; the
# training function must keep it out of its feature columns.
features_df = source_df.select(
    F.col("SUBSTATION_ID"),
    F.col("TIMESTAMP_HOUR"),
    F.col("HIST_PRODUCTION"),
    F.col("TEMP_C"),
    F.col("WIND_MPS"),
    F.col("HOUR_OF_DAY"),
    F.col("DAY_OF_WEEK"),
    F.col("IS_DAYTIME"),
    F.col("IS_WEEKEND"),
    F.col("MONTH"),
    F.col("TEMP_WIND_INTERACTION"),
    F.col("IS_HOT"),
    F.col("IS_WINDY"),
    # Lag features
    F.lag("HIST_PRODUCTION", 1).over(window_by_substation).alias("LAG_1H"),
    F.lag("HIST_PRODUCTION", 24).over(window_by_substation).alias("LAG_24H"),
    F.lag("HIST_PRODUCTION", 168).over(window_by_substation).alias("LAG_1W"),
    # Rolling statistics (past 24 hours)
    F.avg("HIST_PRODUCTION").over(window_24h_past).alias("ROLL_MEAN_24H"),
    F.stddev("HIST_PRODUCTION").over(window_24h_past).alias("ROLL_STD_24H"),
    F.max("HIST_PRODUCTION").over(window_24h_past).alias("ROLL_MAX_24H"),
    F.min("HIST_PRODUCTION").over(window_24h_past).alias("ROLL_MIN_24H"),
    # Weekly rolling mean
    F.avg("HIST_PRODUCTION").over(window_168h_past).alias("ROLL_MEAN_1W")
)

# Trend: change over the 24 hours ending at t-1 (lag 1 minus lag 25)
features_df = features_df.with_column(
    "TREND_24H",
    F.col("LAG_1H") - F.lag("HIST_PRODUCTION", 25).over(window_by_substation)
)

# Register as a Feature View in the Feature Store
feature_view = fs.register_feature_view(
    feature_view=FeatureView(
        name="SUBSTATION_TIME_SERIES_FEATURES",
        entities=[substation_entity],
        feature_df=features_df,
        timestamp_col="TIMESTAMP_HOUR",
        desc="Time-series features for substation power production forecasting"
    ),
    version="V1"
)

print("✅ Feature View 'SUBSTATION_TIME_SERIES_FEATURES' registered")

# Generate training dataset from Feature Store
print("\n📊 Generating training dataset from Feature Store...")
from datetime import datetime, timedelta

end_date = datetime.utcnow()
start_date = end_date - timedelta(days=365)

# Create spine (all substation-timestamp combinations for training)
# NOTE(review): spine_label_cols below names HIST_PRODUCTION, but the spine only
# selects SUBSTATION_ID and TIMESTAMP_HOUR; the target currently reaches the dataset
# through the feature view (which includes HIST_PRODUCTION). Confirm this is intended.
spine_df = session.table("POWERGRID_DEMO.PUBLIC.SUBSTATION_PRODUCTION_TRAIN") \
    .filter((F.col("TIMESTAMP_HOUR") >= start_date) & (F.col("TIMESTAMP_HOUR") <= end_date)) \
    .select(F.col("SUBSTATION_ID"), F.col("TIMESTAMP_HOUR"))

# Generate dataset from Feature Store
training_dataset = fs.generate_dataset(
    spine_df=spine_df,
    features=[feature_view],
    name="SUBSTATION_TRAINING_DATASET",
    spine_timestamp_col="TIMESTAMP_HOUR",
    spine_label_cols=["HIST_PRODUCTION"]
)

# Materialize for ManyModelTraining (needs simple table)
print("   • Materializing dataset...")
training_dataset.read.to_snowpark_dataframe().write.mode("overwrite").save_as_table("SUBSTATION_TRAINING_DATA")

# Read with SQL (not session.table) - forces fresh query plan
training_df = session.sql("SELECT * FROM SUBSTATION_TRAINING_DATA")

print(f"✅ Feature Store ready!")
print(f"   • Entity: SUBSTATION")
print(f"   • Feature View: SUBSTATION_TIME_SERIES_FEATURES")
print(f"   • Training Dataset: SUBSTATION_TRAINING_DATASET")
print(f"   • Period: {start_date.date()} to {end_date.date()}")


In [None]:
training_df.show(n=10)  # first 10 rows of the materialized training data

## 4. Partitioned Model Training

Train one XGBoost model per substation using Snowflake's partitioned model capabilities.

In [None]:
import pandas as pd
import numpy as np
from xgboost import XGBRegressor
from datetime import datetime
import time
from sklearn.metrics import mean_squared_error, r2_score


# Training function (following Snowflake docs pattern)
def train_substation_model(data_connector, context):
    """Train an XGBoost regressor for a single substation partition.

    ManyModelTraining calls this once per SUBSTATION_ID partition.

    Args:
        data_connector: Connector whose ``to_pandas()`` returns this partition's
            rows (the columns of SUBSTATION_TRAINING_DATA).
        context: Partition context; ``context.partition_id`` identifies the partition.

    Returns:
        The fitted ``XGBRegressor``.

    Raises:
        ValueError: If the partition contains no rows.
    """
    partition_id = context.partition_id
    assert partition_id is not None

    # Load partitioned data.
    pandas_df: pd.DataFrame = data_connector.to_pandas()
    print(f"Training model for partition: {partition_id}")
    if pandas_df.empty:
        raise ValueError(f"Partition {partition_id} has no training rows")

    # Feature and target columns. HIST_PRODUCTION is the target and must not be a feature.
    # Lag/rolling columns are NULL (NaN in pandas) for the earliest rows of each
    # partition; XGBoost treats NaN as a missing value.
    FEATURE_COLS = [
        'TEMP_C', 'WIND_MPS', 'HOUR_OF_DAY', 'DAY_OF_WEEK',
        'IS_DAYTIME', 'IS_WEEKEND', 'MONTH',
        'TEMP_WIND_INTERACTION', 'IS_HOT', 'IS_WINDY',
        'LAG_1H', 'LAG_24H', 'LAG_1W',
        'ROLL_MEAN_24H', 'ROLL_STD_24H', 'ROLL_MAX_24H', 'ROLL_MIN_24H',
        'ROLL_MEAN_1W', 'TREND_24H'
    ]
    TARGET_COL = 'HIST_PRODUCTION'

    # Prepare features and target
    X = pandas_df[FEATURE_COLS]
    y = pandas_df[TARGET_COL]

    # Train XGBoost
    model = XGBRegressor(
        n_estimators=50,
        max_depth=6,
        learning_rate=0.05,
        random_state=42
    )
    model.fit(X, y)

    # In-sample metrics only (no hold-out split); printed so they appear in the run logs
    preds = model.predict(X)
    mse = mean_squared_error(y, preds)
    r2 = r2_score(y, preds)

    print(f"Model trained for {partition_id}: train MSE={mse:.3f}, RMSE={mse ** 0.5:.3f}, R2={r2:.3f}")
    return model

In [None]:
# Optional step to scale to multiple nodes for speed up overall many model trainings.
# NOTE(review): notebook_name must match this notebook's actual name — confirm "SUBSTATION_MMT".
# NOTE(review): TOTAL_NODES presumably cannot exceed the compute pool's MAX_NODES (20 in the DDL cell) — confirm.
from snowflake.ml.runtime_cluster import cluster_manager
TOTAL_NODES=20
cluster_manager.scale_cluster(expected_cluster_size=TOTAL_NODES, notebook_name="SUBSTATION_MMT", options={"block_until_min_cluster_size": 2})

In [None]:
from snowflake.ml.modeling.distributors.many_model import ManyModelTraining
from snowflake.ml.modeling.distributors.distributed_partition_function.entities import (
    ExecutionOptions,
    RunStatus,
)

# One train_substation_model call per SUBSTATION_ID partition; artifacts go to the stage below.
# NOTE(review): the stage must already exist (CREATE STAGE IF NOT EXISTS POWERGRID_DEMO.ML.ML_STAGE).
trainer = ManyModelTraining(
     train_substation_model,    
    stage_name="@POWERGRID_DEMO.ML.ML_STAGE",    
)

# The inference cell later loads these models via training_run_id=run_id.
run_id="my_mmt_model_v1"
training_run = trainer.run(
    snowpark_dataframe=training_df,
    partition_by="SUBSTATION_ID",
    run_id=run_id,
    on_existing_artifacts="overwrite", # or "error"
    # execution_options is optional. When running in a multi-node setting, it's recommended setting use_head_node=False to exclude head node from doing actual training, this improves overall MMT training reliability.
    # execution_options=ExecutionOptions(use_head_node=False)
)

In [None]:
## Value filters for test 

#from snowflake.snowpark.functions import col
#values_to_filter = [1,2,3,4,5]

#training_df = training_df.filter(col("SUBSTATION_ID").in_(values_to_filter))

In [None]:
# Partition-level progress of the training run (keyed by status, e.g. "DONE" / "FAILED")
training_run.get_progress()

In [None]:
# Block until the training run finishes; fail the cell unless it reports RunStatus.SUCCESS
assert training_run.wait() == RunStatus.SUCCESS

In [None]:
# Run if you want to cancel the training job 
#training_run.cancel()

In [None]:
# Re-check progress (e.g. after wait() or cancel())
training_run.get_progress()

In [None]:
# Inspect the logs of the first partition that trained successfully
done_partitions = training_run.get_progress()["DONE"]
done_partitions[0].logs

# To inspect failures
# training_run.get_progress()["FAILED"][0].logs

In [None]:
# To inspect failures: show the first failed partition's logs, if any partition failed.
# Indexing ["FAILED"][0] directly raises when nothing failed, which would break
# Restart & Run All after a fully successful training run.
failed_partitions = training_run.get_progress().get("FAILED", [])
failed_partitions[0].logs if failed_partitions else "No failed partitions"

In [None]:
# To obtain models trained with each partition
# Loads every partition's model from the training run and checks it is an XGBRegressor.
# NOTE(review): with 5,000 partitions this loads 5,000 models — consider checking a sample.
# `xgb` is also used by the failure-demo cell further below.
import xgboost as xgb
for partition_id in training_run.partition_details.keys():
    model = training_run.get_model(partition_id)
    assert isinstance(model, xgb.XGBRegressor)

## 5. Running Inference on Trained Models

### 5.1 Register Models in the Snowflake Model Registry and Run Inference (Warehouse Execution) — GA Feature

In [None]:
#models = {
#    partition_id: training_run.get_model(partition_id)
#    for partition_id in training_run.partition_details
#}

In [None]:
# from typing import Optional
# from snowflake.ml.model import custom_model
# from snowflake.ml.registry import registry
# import pandas as pd


# # Log model to model registry
# class PartitionedModel(custom_model.CustomModel):
#     def __init__(self, context: Optional[custom_model.ModelContext] = None) -> None:
#         super().__init__(context)
#         self.partition_id = None
#         self.model = None

#     @custom_model.partitioned_api
#     def predict(self, input: pd.DataFrame) -> pd.DataFrame:
#         FEATURE_COLS = [
#     'TEMP_C', 'WIND_MPS', 'HOUR_OF_DAY', 'DAY_OF_WEEK',
#     'IS_DAYTIME', 'IS_WEEKEND', 'MONTH',
#     'TEMP_WIND_INTERACTION', 'IS_HOT', 'IS_WINDY',
#     'LAG_1H', 'LAG_24H', 'LAG_1W',
#     'ROLL_MEAN_24H', 'ROLL_STD_24H', 'ROLL_MAX_24H', 'ROLL_MIN_24H',
#     'ROLL_MEAN_1W', 'TREND_24H'
# ]

#         TARGET_COL = 'HIST_PRODUCTION'

#         model_id = str(input["SUBSTATION_ID"][0])
#         model = self.context.model_ref(model_id)

#         model_output = model.predict(input[FEATURE_COLS])
#         res = pd.DataFrame(model_output)
#         return res

In [None]:
# from snowflake.ml.model import custom_model

# # Models have been fit, and they can now be retrieved and registered to the model registry.
# model_context = custom_model.ModelContext(
#     models=models
# )

# my_stateful_model = PartitionedModel(context=model_context)
# reg = registry.Registry(session=session)
# options = {
#     "function_type": "TABLE_FUNCTION",
#     "relax_version": False
# }
# FEATURE_COLS = [
#     'TEMP_C', 'WIND_MPS', 'HOUR_OF_DAY', 'DAY_OF_WEEK',
#     'IS_DAYTIME', 'IS_WEEKEND', 'MONTH',
#     'TEMP_WIND_INTERACTION', 'IS_HOT', 'IS_WINDY',
#     'LAG_1H', 'LAG_24H', 'LAG_1W',
#     'ROLL_MEAN_24H', 'ROLL_STD_24H', 'ROLL_MAX_24H', 'ROLL_MIN_24H',
#     'ROLL_MEAN_1W', 'TREND_24H'
# ]

# TARGET_COL = 'HIST_PRODUCTION'

# import snowflake.snowpark.functions as F
# from functools import reduce

# cols = FEATURE_COLS + ["SUBSTATION_ID"]
# cond = reduce(lambda a, b: a & F.col(b).is_not_null(), cols[1:], F.col(cols[0]).is_not_null())

# ## want sample data with no nulls
# sample_input_data = (
#     training_df
#     .select(cols)
#     .filter(cond)      # keep only rows with no nulls in all required columns
#     .limit(1)
#     .to_pandas()
# )
# mv = reg.log_model(
#     my_stateful_model,
#     model_name="partitioned_model",
#     options=options,
#     conda_dependencies=["pandas", "xgboost"],
#     sample_input_data=sample_input_data,    
# )

In [None]:
# service_prediction = mv.run(
#     training_df,
#     partition_column="SUBSTATION_ID",
# )

### 5.2 Alternative: ManyModelInference (Container Execution) — Preview Feature

Best for scaling training or inference across thousands of partitions, or when you need custom Python dependencies. Container execution gives you the flexibility to bring any environment to Snowflake and scale it out.

In [None]:
from snowflake.ml.modeling.distributors.many_model import ManyModelInference
from snowflake.ml.data import DataConnector
from snowflake.ml.modeling.distributors.distributed_partition_function.partition_context import (
    PartitionContext,
)

# Model version recorded on every forecast row. Captured once here so that later
# cells reassigning `run_id` cannot change what this cell's inference writes.
INFERENCE_MODEL_VERSION = run_id


def xgb_inference_func(data_connector: DataConnector, model, context: PartitionContext):
    """Run inference for one SUBSTATION_ID partition and persist the forecasts.

    Args:
        data_connector: Connector for this partition's rows; ``to_pandas()`` must
            yield SUBSTATION_ID, TIMESTAMP_HOUR and every column in FEATURE_COLS.
        model: The model trained for this partition by the ManyModelTraining run.
        context: Partition context; ``context.session`` is used to write results.

    Returns:
        The array returned by ``model.predict``.
    """
    df = data_connector.to_pandas()
    FEATURE_COLS = [
        'TEMP_C', 'WIND_MPS', 'HOUR_OF_DAY', 'DAY_OF_WEEK',
        'IS_DAYTIME', 'IS_WEEKEND', 'MONTH',
        'TEMP_WIND_INTERACTION', 'IS_HOT', 'IS_WINDY',
        'LAG_1H', 'LAG_24H', 'LAG_1W',
        'ROLL_MEAN_24H', 'ROLL_STD_24H', 'ROLL_MAX_24H', 'ROLL_MIN_24H',
        'ROLL_MEAN_1W', 'TREND_24H'
    ]
    # Predict on a DataFrame rather than `.values` so column names/order reach
    # XGBoost, matching how the model was fit (a pandas frame in training).
    predictions = model.predict(df[FEATURE_COLS])

    # Two persistence strategies (choose one or both based on your needs):

    # Strategy 1: Stage artifacts - for framework management and debugging
    # results = df.assign(predictions=predictions)
    # context.upload_to_stage(results, "predictions.csv",
    #     write_function=lambda df, path: df.to_csv(path, index=False))

    # Strategy 2: Snowflake table - append to SUBSTATION_FORECASTS (created in the DDL section).
    # column_order="name" matches columns by name; CREATED_AT is not supplied and is
    # expected to take its DEFAULT. Rows are appended on every run, so re-running
    # inference adds duplicate rows.
    forecasts = df[["SUBSTATION_ID", "TIMESTAMP_HOUR"]].assign(
        PREDICTION=predictions.astype("float64"),
        MODEL_VERSION=INFERENCE_MODEL_VERSION,
    )
    context.session.create_dataframe(forecasts).write.mode("append").save_as_table(
        "SUBSTATION_FORECASTS", column_order="name"
    )

    return predictions


mmi = ManyModelInference(
    inference_func=xgb_inference_func,
    stage_name="@POWERGRID_DEMO.ML.ML_STAGE",
    training_run_id=run_id,  # run_id of the ManyModelTraining run above
)

from snowflake.snowpark.functions import col

# Limit the demo inference to a few substations. Use a new name instead of
# reassigning `training_df`, so re-running earlier cells still sees the full data.
INFERENCE_SUBSTATION_IDS = [1, 2, 3, 4, 5]
inference_df = training_df.filter(col("SUBSTATION_ID").in_(INFERENCE_SUBSTATION_IDS))

inference_run = mmi.run(
    partition_by="SUBSTATION_ID",
    snowpark_dataframe=inference_df,  # reusing training rows, mainly for illustration purposes
    run_id="basic_inference_run",
    on_existing_artifacts="overwrite",
)

In [None]:
#inference_run.cancel()

In [None]:
# Block until inference finishes; fail the cell unless the run reports RunStatus.SUCCESS
assert inference_run.wait() == RunStatus.SUCCESS

In [None]:
# Partition-level progress of the inference run; uncomment the next line to see the first failure's logs
inference_run.get_progress()
# inference_run.get_progress()["FAILED"][0].logs

## 6. Troubleshooting Failed Runs

Training functions can fail for various reasons. Below are some common causes:

- **User Code Errors**  
  Bugs or issues in the user-defined training function can cause failures.

- **Infrastructure Issues**  
  An *Out-of-Memory (OOM)* error occurs when the training function consumes more memory than the node can provide.

- **Unexpected Node Failures**  
  In some cases, a node might crash unexpectedly.

---

### Handling OOM and Node Failures

When an OOM error or fatal node failure occurs, the **MMT API will not automatically retry** the training function. Instead, it will mark the corresponding partition ID run as **`INTERNAL_ERROR`**. If a worker node crashes, logs might not be captured, making debugging more difficult.

For all other failure scenarios — that is, everything except a worker-node crash, so OOM errors are included — MMT provides:
- A **detailed error message**  
- A **stack trace** to help diagnose and fix the issue

---

### Retry Logic for Non-Fatal Errors

If the failure is not considered fatal (e.g., transient issues), MMT will automatically retry the training function with **exponential backoff**. This mechanism allows transient issues to resolve before the function ultimately fails.

**Retry Attempts:**
1. **First retry**: Wait for 2 seconds (`initial_delay`)
2. **Second retry**: Wait for 4 seconds (2 * `initial_delay`)
3. **Third retry**: Wait for 8 seconds (2^2 * `initial_delay`)
4. **Fourth retry**: Wait for 16 seconds (2^3 * `initial_delay`)
5. **Final retry**: No delay — if it fails again, an exception is raised


In [None]:
def user_func_error(data_connector: DataConnector, context: PartitionContext):
    """Deliberately broken training function used to demonstrate failure handling.

    Calls ``fitss``, a method XGBRegressor does not have, so every partition
    raises instead of returning a model.
    """
    pandas_df = data_connector.to_pandas()

    FEATURE_COLS = [
        'TEMP_C', 'WIND_MPS', 'HOUR_OF_DAY', 'DAY_OF_WEEK',
        'IS_DAYTIME', 'IS_WEEKEND', 'MONTH',
        'TEMP_WIND_INTERACTION', 'IS_HOT', 'IS_WINDY',
        'LAG_1H', 'LAG_24H', 'LAG_1W',
        'ROLL_MEAN_24H', 'ROLL_STD_24H', 'ROLL_MAX_24H', 'ROLL_MIN_24H',
        'ROLL_MEAN_1W', 'TREND_24H'
    ]
    TARGET_COL = 'HIST_PRODUCTION'
    model = xgb.XGBRegressor()

    # INTENTIONAL USER-CODE FAILURE: XGBRegressor has no `fitss` method
    model.fitss(pandas_df[FEATURE_COLS], pandas_df[TARGET_COL])

    return model


# Distinct names so this demo does not overwrite `run_id` / `trainer` from the
# successful training run (the inference cell above relies on `run_id`).
failed_model_name = "my_mmt_model"
failed_model_version = "v2"
failed_run_id = f"{failed_model_name}_{failed_model_version}"

# Keep the failure demo small: every partition fails (and may be retried).
failure_demo_df = training_df.filter(F.col("SUBSTATION_ID").in_([1, 2, 3, 4, 5]))

failing_trainer = ManyModelTraining(
    user_func_error,
    stage_name="@POWERGRID_DEMO.ML.ML_STAGE",
)

failed_trainer_run = failing_trainer.run(
    snowpark_dataframe=failure_demo_df,
    partition_by="SUBSTATION_ID",
    run_id=failed_run_id,
    on_existing_artifacts="overwrite",  # or "error"
)
