# Logging R ARIMAX Model to Snowflake Model Registry

This notebook demonstrates how to wrap an R model in Python and log it to Snowflake's Model Registry.

## Problem
- Snowflake Model Registry only supports Python models natively
- The customer's R ARIMAX model runs into `reticulate` package errors

## Solution
- Create Python wrapper class that calls R via subprocess
- Use SNOWPARK_CONTAINER_SERVICES target platform
- Handle data I/O via CSV interchange

## Step 1: Setup and Connect to Snowflake

In [None]:
import os
import pandas as pd
import numpy as np
from snowflake.snowpark import Session
from snowflake.ml.registry import Registry
from r_model_wrapper import ARIMAXModelWrapper

# Disabled alternative: build a session from a named connection, taking the
# name from the SNOWFLAKE_CONNECTION_NAME environment variable and falling
# back to "MY_DEMO". Presumably intended for runs where no session is already
# active (e.g. outside a Snowflake notebook) - confirm before enabling.
# connection_params = {
#     "connection_name": os.getenv("SNOWFLAKE_CONNECTION_NAME") or "MY_DEMO"
# }

# session = Session.builder.configs(connection_params).create()
# print(f"Connected to Snowflake: {session.get_current_database()}.{session.get_current_schema()}")

# Active path: reuse the session that is already active in this environment.
# Every later cell reads the module-level `session` defined here.
from snowflake.snowpark.context import get_active_session
session = get_active_session()

## Step 2: Generate Synthetic Test Data

Create exogenous variables for forecasting:

In [None]:
# Fix the global NumPy seed so the synthetic exogenous inputs are reproducible.
np.random.seed(42)
n_forecast = 10

# The two columns are drawn in this order (exog_var1, then exog_var2); the
# order matters because both draws consume the same seeded random stream.
test_data = pd.DataFrame(
    dict(
        exog_var1=np.random.normal(loc=5, scale=1, size=n_forecast),
        exog_var2=np.random.normal(loc=10, scale=2, size=n_forecast),
    )
)

print("Test Data:", test_data, sep="\n")

## Step 3: Fetch Files

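Download the R model artifact and the R prediction script from the Snowflake stage to `/tmp/`. Local verification is skipped because R is not available in the notebook environment; the model is tested after it is logged and deployed to SPCS: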

In [None]:
# Stage locations of the serialized R model and the R prediction script.
r_artifact_stage_path = "@E2E_SNOW_MLOPS_DB.MLOPS_SCHEMA.ML_ARTIFACTS_STAGE/r_models/arimax_model_artifact.rds"
r_script_stage_path = "@E2E_SNOW_MLOPS_DB.MLOPS_SCHEMA.ML_ARTIFACTS_STAGE/r_models/predict_arimax.R"

# Download each file from the stage to /tmp/. Fail fast when a download
# returns no result, so a wrong stage path is reported here rather than by a
# later cell that reads the files from /tmp/.
for stage_path in (r_artifact_stage_path, r_script_stage_path):
    get_results = session.file.get(stage_path, "/tmp/")
    if not get_results:
        raise FileNotFoundError(
            f"No file was downloaded from stage path: {stage_path}"
        )

print("Downloaded R artifacts from stage to /tmp/")
print("Note: Skipping local testing - R not available in notebook environment")
print("Model will be tested after logging to registry via SPCS")

## Step 4: Create Model Registry

In [None]:
# Open the Model Registry in the database/schema that will hold the model.
reg = Registry(session=session, database_name='E2E_SNOW_MLOPS_DB', schema_name='MLOPS_SCHEMA')

print("Registry initialized")

## Step 5: Log Model to Snowflake Registry

### Key Configuration:
- **target_platforms**: `["SNOWPARK_CONTAINER_SERVICES"]` - Required for subprocess/R execution
- **conda_dependencies**: Include R and forecast package
- **sample_input_data**: Define expected input schema

In [None]:
from snowflake.ml.model import custom_model
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature

# Bundle the R artifacts downloaded to /tmp/ into a ModelContext; the wrapper
# receives them under the keys `model_rds` and `predict_script`.
model_context = custom_model.ModelContext(
    model_rds='/tmp/arimax_model_artifact.rds',
    predict_script='/tmp/predict_arimax.R',
)

# Python wrapper around the R model, built from the context above.
my_model = ARIMAXModelWrapper(model_context)

# Explicit signature so packaging does not need to run the model locally.
# All inputs and outputs are DOUBLE columns.
predict_signature = ModelSignature(
    inputs=[
        FeatureSpec(name=column, dtype=DataType.DOUBLE)
        for column in ("exog_var1", "exog_var2")
    ],
    outputs=[
        FeatureSpec(name=column, dtype=DataType.DOUBLE)
        for column in ("forecast", "lower_80", "upper_80", "lower_95", "upper_95")
    ],
)

# Log to the registry, targeting SPCS and declaring the R runtime packages.
# NOTE(review): both `signatures` and `sample_input_data` are passed here;
# confirm the installed snowflake-ml-python version accepts the combination.
model_version = reg.log_model(
    my_model,
    model_name="ARIMAX_R_MODEL",
    version_name="V1",
    comment="R ARIMAX model wrapped in Python for Snowflake Model Registry",
    target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
    conda_dependencies=["r-base>=4.1", "r-forecast"],
    signatures={"predict": predict_signature},
    sample_input_data=test_data,
)

print("\nModel logged successfully!")
print(f"Model: {model_version.model_name}")
print(f"Version: {model_version.version_name}")

## Step 6: Create SPCS Resources

Create compute pool and image repository for model deployment:

In [None]:
-- Run the account-level grant below as ACCOUNTADMIN.
USE ROLE ACCOUNTADMIN;

-- Give the working role the privilege to create compute pools.
GRANT CREATE COMPUTE POOL ON ACCOUNT TO ROLE SNOWFLAKE_INTELLIGENCE_ADMIN_RL;     
-- Continue as the role that just received the grant.
USE ROLE SNOWFLAKE_INTELLIGENCE_ADMIN_RL;
                                                                

In [None]:
-- Make SNOWFLAKE_INTELLIGENCE_ADMIN_RL the active role for this session.
USE ROLE SNOWFLAKE_INTELLIGENCE_ADMIN_RL;

In [None]:
# Option 1 (active in this cell): create new resources.
# The current role needs the CREATE COMPUTE POOL privilege, e.g.:
#GRANT CREATE COMPUTE POOL ON ACCOUNT TO ROLE <role name>;
#
# Both statements use IF NOT EXISTS, so re-running the cell is safe; the
# confirmation is printed whether or not the object already existed.
spcs_setup_steps = (
    (
        """
CREATE COMPUTE POOL IF NOT EXISTS R_MODEL_POOL
    MIN_NODES = 1
    MAX_NODES = 2
    INSTANCE_FAMILY = 'CPU_X64_M'
    AUTO_RESUME = TRUE
    COMMENT = 'Compute pool for R ARIMAX model inference'
""",
        "✓ Compute pool created: R_MODEL_POOL",
    ),
    (
        """
CREATE IMAGE REPOSITORY IF NOT EXISTS E2E_SNOW_MLOPS_DB.MLOPS_SCHEMA.R_MODEL_IMAGE_REPO
    COMMENT = 'Repository for R model container images'
""",
        "✓ Image repository created: R_MODEL_IMAGE_REPO",
    ),
)

# Run the DDL statements in order: compute pool first, then image repository.
for ddl, confirmation in spcs_setup_steps:
    session.sql(ddl).collect()
    print(confirmation)

# Option 2 (disabled): use existing resources instead of creating new ones.
# Set these to your existing compute pool and image repository names:
# COMPUTE_POOL_NAME = "DEMO_POOL_CPU"  # Replace with your compute pool
# IMAGE_REPO_NAME = "ML_IMAGE_REPO"    # Replace with your image repo (or leave empty to use default)

# print(f"✓ Using compute pool: {COMPUTE_POOL_NAME}")
# print(f"✓ Using image repository: {IMAGE_REPO_NAME or 'default'}")

## Step 6 (continued): Deploy Model to SPCS

Since we used `SNOWPARK_CONTAINER_SERVICES` target platform, we must deploy the model before inference:

In [None]:
# Deploy the logged model version as an SPCS service.
# create_service waits for the service by default, so the progress notice is
# printed before the call; the success message follows once the call returns.
print("Building container image and starting service...")
print("This may take 5-10 minutes for first deployment.")

model_version.create_service(
    service_name="arimax_deployment",
    service_compute_pool="R_MODEL_POOL",
    image_repo="R_MODEL_IMAGE_REPO",
    ingress_enabled=True,
    max_instances=1
)

print("Model deployed to SPCS: arimax_deployment")

In [None]:
# Report the current service status (informational only; this does not wait).
# The lookup is scoped to the schema the model was logged in, so the service
# is found even when the session's current database/schema is different.
service_status = session.sql(
    "SHOW SERVICES LIKE 'arimax_deployment' IN SCHEMA E2E_SNOW_MLOPS_DB.MLOPS_SCHEMA"
).collect()
print(f"Service status: {service_status[0]['status'] if service_status else 'Not found'}")

# Upload the synthetic exogenous inputs as a Snowpark DataFrame for inference.
test_snowpark_df = session.create_dataframe(test_data)

# Call the model's `predict` function through the deployed SPCS service.
print("\n=== Running Inference via SPCS ===")
predictions = model_version.run(
    test_snowpark_df,
    function_name="predict",
    service_name="arimax_deployment"
)

print("\nPredictions:")
predictions.show()

## Step 7: View Model in Registry

In [None]:
# List the models in the registry and print only the summary columns.
models_df = reg.show_models()
print("\nRegistered Models:", models_df.loc[:, ['name', 'versions', 'comment']], sep="\n")

## Cleanup (Optional)

In [None]:
# Uncomment to clean up resources

# Delete the service
# model_version.delete_service("arimax_deployment")

# Delete the model
# reg.delete_model("ARIMAX_R_MODEL")

# Drop compute pool and image repo
# session.sql("DROP COMPUTE POOL IF EXISTS R_MODEL_POOL").collect()
# session.sql("DROP IMAGE REPOSITORY IF EXISTS R_MODEL_IMAGE_REPO").collect()

# print("Resources cleaned up")

# reg.delete_model("ARIMAX_R_MODEL")
# print("Model deleted from registry")