# R-Native Model Registry: Train, Register, and Inference

This notebook demonstrates an **R-native** workflow for registering and serving R models in Snowflake.

**Key difference from r_forecasting_demo.ipynb:** Instead of writing Python wrapper classes manually,
we use R wrapper functions (`snowflake_registry.R`) that handle all the Python plumbing automatically.

## What the R user does:

1. **Train** a model in R (as usual)
2. **Register** with one R function call: `sf_registry_log_model()`
3. **Test locally** with: `sf_registry_predict_local()`
4. **Run remote inference** with: `sf_registry_predict()`

## What happens under the hood:

- The R model is saved to an `.rds` file
- A Python `CustomModel` wrapper is auto-generated (it uses rpy2 to call R)
- The wrapper and the model are logged to the Snowflake Model Registry
- Inference runs on Snowpark Container Services (SPCS) or a warehouse, calling R through rpy2

---

## Table of Contents

1. [Configuration](#section-1-configuration)
2. [Environment Setup](#section-2-environment-setup)
3. [Data Exploration & Model Training (R)](#section-3-data-and-training)
4. [Local Testing (R)](#section-4-local-testing)
5. [Model Registration (R)](#section-5-registration)
6. [Remote Inference (R)](#section-6-inference)
7. [Model Management (R)](#section-7-management)
8. [Example: Linear Regression](#section-8-linear-model)
9. [Example: ARIMAX with Exogenous Vars](#section-9-arimax)
10. [Cleanup](#section-10-cleanup)

---

# Section 1: Configuration

In [None]:
# =============================================================================
# USER CONFIGURATION - Modify these values for your environment
# =============================================================================

MODEL_DATABASE = "SIMON"              # Your database
MODEL_SCHEMA   = "R_REGISTRY_DEMO"    # Schema for models
WAREHOUSE      = "SIMON_XS"           # Your warehouse

# SPCS resources (for remote inference)
COMPUTE_POOL   = "R_FORECAST_POOL"    # Compute pool
IMAGE_REPO     = "R_FORECAST_IMAGES"  # Image repository

# Data source (TPC-H sample data - available in all accounts)
SOURCE_DATABASE = "SNOWFLAKE_SAMPLE_DATA"
SOURCE_SCHEMA   = "TPCH_SF1"

print(f"Config: {MODEL_DATABASE}.{MODEL_SCHEMA}")
print(f"Warehouse: {WAREHOUSE}")

---

# Section 2: Environment Setup

Install R, configure the Python-R bridge, and load our R wrapper library.

In [None]:
# Install R environment (only needed once per session)
!bash setup_r_environment.sh --adbc 2>&1 | tail -20

In [None]:
# Configure Python-R bridge
from r_helpers import setup_r_environment

result = setup_r_environment()
if result['success']:
    print(f"R environment ready: {result['r_version']}")
else:
    print(f"Setup failed: {result['errors']}")

In [None]:
# Connect to Snowflake
from snowflake.snowpark import Session
from snowflake.snowpark.context import get_active_session
import pandas as pd

session = get_active_session()
session.sql(f"USE WAREHOUSE {WAREHOUSE}").collect()

print(f"Connected as: {session.get_current_user()}")
print(f"Warehouse: {session.get_current_warehouse()}")
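
The configuration values are interpolated directly into SQL text (for example `USE WAREHOUSE {WAREHOUSE}`), so a typo or stray character in Section 1 only surfaces as a SQL error later. As a minimal sketch, assuming the names are plain unquoted Snowflake identifiers, a small check can run first. The helpers `require_identifier` and `qualified_name` are illustrative and not part of `r_helpers` or Snowpark.

```python
import re

# Unquoted Snowflake identifiers start with a letter or underscore and may
# contain letters, digits, underscores, and dollar signs.
_UNQUOTED_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")


def require_identifier(name: str) -> str:
    """Return `name` unchanged if it is a valid unquoted identifier, else raise."""
    if not _UNQUOTED_IDENTIFIER.fullmatch(name):
        raise ValueError(f"Not a valid unquoted identifier: {name!r}")
    return name


def qualified_name(*parts: str) -> str:
    """Join validated identifier parts with dots, e.g. DATABASE.SCHEMA."""
    return ".".join(require_identifier(part) for part in parts)


# Example with the values from Section 1 (hypothetical helper usage)
fq_schema = qualified_name("SIMON", "R_REGISTRY_DEMO")
print(fq_schema)
```

In the notebook, the same check could wrap `WAREHOUSE` before it is placed into the `USE WAREHOUSE` statement.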

In [None]:
%%R
# Load the R wrapper library
# This sources snowflake_registry.R which provides all sf_registry_* functions
source("snowflake_registry.R")

# Initialize the wrapper (auto-detects the Snowpark session)
sf_registry_init()

# Show available functions
sf_registry_help()

In [None]:
%%R -i MODEL_DATABASE -i MODEL_SCHEMA -i WAREHOUSE
# Set up schema and artifacts stage, using the values from Section 1
sf_registry_setup(
  database  = MODEL_DATABASE,
  schema    = MODEL_SCHEMA,
  warehouse = WAREHOUSE
)

---

# Section 3: Data Exploration & Model Training (R)

The data is queried through Snowpark and passed to R. From there the training is pure R: the user trains a model exactly as they normally would.

In [None]:
# Query data from Snowflake into Python, then pass to R
orders_query = f"""
SELECT 
    DATE_TRUNC('MONTH', O_ORDERDATE) as ORDER_MONTH,
    COUNT(*) as ORDER_COUNT,
    SUM(O_TOTALPRICE) as TOTAL_REVENUE,
    AVG(O_TOTALPRICE) as AVG_ORDER_VALUE
FROM {SOURCE_DATABASE}.{SOURCE_SCHEMA}.ORDERS
GROUP BY DATE_TRUNC('MONTH', O_ORDERDATE)
ORDER BY ORDER_MONTH
"""

orders_df = session.sql(orders_query).to_pandas()
print(f"Loaded {len(orders_df)} months of data")
orders_df.head()

In [None]:
%%R -i orders_df
library(forecast)
library(ggplot2)

# Standard R time series workflow - nothing special here!
orders_df <- orders_df[order(orders_df$ORDER_MONTH), ]
order_counts <- orders_df$ORDER_COUNT

start_date  <- as.Date(min(orders_df$ORDER_MONTH))
start_year  <- as.numeric(format(start_date, "%Y"))
start_month <- as.numeric(format(start_date, "%m"))

orders_ts <- ts(order_counts, start = c(start_year, start_month), frequency = 12)

cat("Time Series:", length(orders_ts), "observations\n")
cat("Start:", paste(start(orders_ts), collapse = "/"), "\n")
cat("End:",   paste(end(orders_ts), collapse = "/"),   "\n")

In [None]:
%%R
# Train the model - standard R, nothing Snowflake-specific
arima_model <- auto.arima(orders_ts, seasonal = TRUE, stepwise = FALSE)

cat("\nModel:", arima_model$method, "\n")
writeLines(capture.output(summary(arima_model)))

---

# Section 4: Local Testing (R)

Before registering, test that the model works through the wrapper.
This uses the **exact same** rpy2 CustomModel pipeline that will run in SPCS.

In [None]:
%%R
# Test locally - this runs through the same CustomModel wrapper
# that will be deployed to Snowflake, ensuring no surprises.

test_input <- data.frame(period = 1:6)

local_preds <- sf_registry_predict_local(
  model        = arima_model,
  input_data   = test_input,
  predict_fn   = "forecast",
  predict_pkgs = c("forecast")
)

cat("Local prediction results (6-month forecast):\n")
rprint(local_preds)

---

# Section 5: Model Registration (R)

Register the model to Snowflake Model Registry with **one R function call**.

Compare this to the manual approach in `r_forecasting_demo.ipynb`, which required:
- Writing a Python wrapper class (~80 lines)
- Defining `ModelSignature` objects
- Creating a `ModelContext`
- Calling `reg.log_model()` with many Python arguments

Now it's just one R function:

In [None]:
%%R
# Register model to Snowflake Model Registry - ONE function call!
mv <- sf_registry_log_model(
  model        = arima_model,
  model_name   = "TPCH_ORDERS_FORECAST_R",

  # How to run inference in R
  predict_fn   = "forecast",
  predict_pkgs = c("forecast"),

  # Schema: what goes in, what comes out
  input_cols  = list(period = "integer"),
  output_cols = list(
    period         = "integer",
    point_forecast = "double",
    lower_80       = "double",
    upper_80       = "double",
    lower_95       = "double",
    upper_95       = "double"
  ),

  # Dependencies for the SPCS container
  conda_dependencies = c(
    "r-base>=4.1",
    "r-forecast>=8.0",
    "rpy2>=3.5"
  ),

  target_platforms = "SNOWPARK_CONTAINER_SERVICES",
  comment = "R ARIMA forecast model (registered via R wrapper)",
  metrics = list(aic = arima_model$aic, method = arima_model$method)
)
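
The `input_cols` and `output_cols` lists in the registration call use R type names, while a Model Registry signature is expressed in Snowflake ML data types. How `snowflake_registry.R` translates between the two is not shown in this notebook. The following Python sketch is one plausible mapping: the table contents and the helper name `to_feature_specs` are assumptions. The target names correspond to members of `snowflake.ml.model.model_signature.DataType`.

```python
# Assumed mapping from R type names to Snowflake ML signature dtype names.
# The real wrapper may choose differently (e.g. INT32 for R integers).
R_TO_SIGNATURE_DTYPE = {
    "integer": "INT64",
    "double": "DOUBLE",
    "numeric": "DOUBLE",
    "character": "STRING",
    "logical": "BOOL",
}


def to_feature_specs(cols: dict[str, str]) -> list[tuple[str, str]]:
    """Translate {column: r_type} into ordered (column, dtype_name) pairs."""
    specs = []
    for name, r_type in cols.items():
        if r_type not in R_TO_SIGNATURE_DTYPE:
            raise ValueError(f"Unsupported R type {r_type!r} for column {name!r}")
        specs.append((name, R_TO_SIGNATURE_DTYPE[r_type]))
    return specs


# The same schema as the ARIMA registration, written as Python dicts
input_specs = to_feature_specs({"period": "integer"})
output_specs = to_feature_specs({"period": "integer", "point_forecast": "double"})
print(input_specs, output_specs)
```

Failing early on an unknown type name keeps a schema typo from reaching the registry call.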

In [None]:
%%R
# View registered models
models_df <- sf_registry_show_models()
rprint(models_df)

In [None]:
%%R
# View versions of our model
versions_df <- sf_registry_show_versions("TPCH_ORDERS_FORECAST_R")
rprint(versions_df)

---

# Section 6: Remote Inference (R)

Deploy the model to SPCS and run predictions - all from R.

In [None]:
%%R -i COMPUTE_POOL -i IMAGE_REPO
# Deploy as SPCS service, using the pool and repository from Section 1
sf_registry_create_service(
  model_name   = "TPCH_ORDERS_FORECAST_R",
  version_name = mv$version_name,
  service_name = "orders_forecast_r_svc",
  compute_pool = COMPUTE_POOL,
  image_repo   = IMAGE_REPO
)

In [None]:
# Check service status (Python cell for SQL access)
import time

for i in range(20):
    status = session.sql("SHOW SERVICES LIKE 'orders_forecast_r_svc'").collect()
    if status:
        current = status[0]['status']
        print(f"Service status: {current}")
        if current == 'RUNNING':
            print("Service is running!")
            break
    time.sleep(30)
else:
    print("Service not ready yet")
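
The status loop can also be factored into a reusable helper, which makes the retry count and delay explicit and lets the logic be exercised without a live service. This is a minimal sketch: `wait_for_status` and its parameters are hypothetical names, and the status source is passed in as a callable rather than tied to `session.sql`.

```python
import time
from typing import Callable, Optional


def wait_for_status(
    get_status: Callable[[], Optional[str]],
    target: str = "RUNNING",
    attempts: int = 20,
    delay_s: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll `get_status` until it returns `target` or attempts run out."""
    for attempt in range(attempts):
        if get_status() == target:
            return True
        # Skip the pause after the final attempt.
        if attempt < attempts - 1:
            sleep(delay_s)
    return False


# Stand-in status source that becomes RUNNING on the third poll
fake_statuses = iter(["PENDING", "PENDING", "RUNNING"])
ready = wait_for_status(lambda: next(fake_statuses), delay_s=0)
print("Service is running!" if ready else "Service not ready yet")
```

In the notebook, `get_status` would wrap the `SHOW SERVICES` query and return the `status` field of the first row, or `None` when no row comes back.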

In [None]:
%%R
# Run remote inference - 12 month forecast
remote_preds <- sf_registry_predict(
  model_name   = "TPCH_ORDERS_FORECAST_R",
  input_data   = data.frame(period = 1:12),
  service_name = "orders_forecast_r_svc"
)

cat("Remote predictions (12-month forecast):\n")
rprint(remote_preds)

In [None]:
%%R -i orders_df -w 900 -h 500
library(ggplot2)
library(scales)

# Visualize forecast
orders_df$ORDER_MONTH <- as.Date(orders_df$ORDER_MONTH)
last_date <- max(orders_df$ORDER_MONTH)

# Forecast months start one calendar month after the last observed month
remote_preds$forecast_date <- seq.Date(
  from = last_date, by = "month", length.out = nrow(remote_preds) + 1
)[-1]

p <- ggplot() +
  geom_ribbon(data = remote_preds,
              aes(x = forecast_date, ymin = lower_95, ymax = upper_95),
              fill = "steelblue", alpha = 0.2) +
  geom_ribbon(data = remote_preds,
              aes(x = forecast_date, ymin = lower_80, ymax = upper_80),
              fill = "steelblue", alpha = 0.3) +
  geom_line(data = orders_df,
            aes(x = ORDER_MONTH, y = ORDER_COUNT), linewidth = 1) +
  geom_line(data = remote_preds,
            aes(x = forecast_date, y = point_forecast),
            color = "steelblue", linewidth = 1, linetype = "dashed") +
  scale_y_continuous(labels = comma) +
  labs(title = "TPC-H Orders Forecast (R Model via Registry)",
       subtitle = "Registered and served entirely from R",
       x = "Date", y = "Order Count") +
  theme_minimal(base_size = 12)

print(p)
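
The forecast dates on the x-axis are meant to fall one calendar month apart, beginning with the month after the last observation. Stepping by calendar month rather than by a fixed number of days keeps every date on the first of its month. A minimal Python sketch of that arithmetic, using only the standard library and assuming the last observation is a month start (as `DATE_TRUNC('MONTH', ...)` produces); the helper name `next_month_starts` and the example date are illustrative:

```python
from datetime import date


def next_month_starts(last_month: date, periods: int) -> list[date]:
    """Return the first day of each of the `periods` months after `last_month`."""
    dates = []
    year, month = last_month.year, last_month.month
    for _ in range(periods):
        month += 1
        if month > 12:  # roll over into the next year
            year, month = year + 1, 1
        dates.append(date(year, month, 1))
    return dates


# Example: three forecast months following November 2024
forecast_dates = next_month_starts(date(2024, 11, 1), 3)
print(forecast_dates)
```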

---

# Section 7: Model Management (R)

All model management operations are available from R.

In [None]:
%%R
# Set metrics on the model version
sf_registry_set_metric(
  model_name   = "TPCH_ORDERS_FORECAST_R",
  version_name = mv$version_name,
  metric_name  = "training_observations",
  metric_value = length(orders_ts)
)

# Retrieve metrics
metrics <- sf_registry_show_metrics(
  model_name   = "TPCH_ORDERS_FORECAST_R",
  version_name = mv$version_name
)
cat("Model metrics:\n")
str(metrics)

In [None]:
%%R
# Get detailed model info
model_info <- sf_registry_get_model("TPCH_ORDERS_FORECAST_R")
cat("Model:",   model_info$name, "\n")
cat("Comment:", model_info$comment, "\n")
cat("Versions:", paste(model_info$versions, collapse = ", "), "\n")
cat("Default:",  model_info$default_version, "\n")

---

# Section 8: Example - Linear Regression

The wrapper works with **any** R model that supports `predict()` or a custom function.
Here's an example with a simple linear model.

In [None]:
%%R
# Train a linear model (standard R)
lm_model <- lm(mpg ~ wt + hp + cyl, data = mtcars)
writeLines(capture.output(summary(lm_model)))

In [None]:
%%R
# Test locally first
test_data <- data.frame(
  wt  = c(2.5, 3.0, 3.5, 4.0),
  hp  = c(100, 150, 200, 250),
  cyl = c(4, 6, 6, 8)
)

local_lm_preds <- sf_registry_predict_local(
  model        = lm_model,
  input_data   = test_data,
  predict_fn   = "predict"
  # No predict_pkgs needed - predict.lm lives in the stats package, which R loads by default
)

cat("Local linear model predictions:\n")
rprint(local_lm_preds)

In [None]:
%%R
# Register the linear model
lm_mv <- sf_registry_log_model(
  model       = lm_model,
  model_name  = "MTCARS_MPG_MODEL",
  predict_fn  = "predict",
  input_cols  = list(wt = "double", hp = "double", cyl = "integer"),
  output_cols = list(prediction = "double"),
  comment     = "Linear regression: MPG ~ wt + hp + cyl",
  metrics     = list(
    r_squared     = summary(lm_model)$r.squared,
    adj_r_squared = summary(lm_model)$adj.r.squared
  )
)

---

# Section 9: Example - ARIMAX with Exogenous Variables

For models that need exogenous regressors (xreg), you can use the
`predict_body` parameter to provide custom R code.

In [None]:
%%R
library(forecast)

# Generate synthetic data with exogenous variables
set.seed(42)
n <- 100
exog1 <- rnorm(n, mean = 5, sd = 1)
exog2 <- rnorm(n, mean = 10, sd = 2)
y <- 50 + 0.5 * (1:n) + 10 * sin(2 * pi * (1:n) / 12) +
     2 * exog1 + 1.5 * exog2 + rnorm(n, sd = 2)

xreg <- cbind(exog1 = exog1, exog2 = exog2)
# Wrap y as a monthly ts so auto.arima can actually consider seasonal terms
arimax_model <- auto.arima(ts(y, frequency = 12), xreg = xreg, seasonal = TRUE)

cat("ARIMAX Model:", arimax_model$method, "\n")

In [None]:
%%R
# Custom R prediction code for ARIMAX (uses xreg)
# Template variables: {{MODEL}}, {{INPUT}}, {{UID}}, {{N}}
arimax_predict_body <- '
  xreg_{{UID}} <- as.matrix({{INPUT}}[, c("exog1", "exog2")])
  pred_{{UID}} <- forecast({{MODEL}}, xreg = xreg_{{UID}}, h = {{N}})
  result_{{UID}} <- data.frame(
    point_forecast = as.numeric(pred_{{UID}}$mean),
    lower_80 = as.matrix(pred_{{UID}}$lower)[, 1],
    upper_80 = as.matrix(pred_{{UID}}$upper)[, 1],
    lower_95 = as.matrix(pred_{{UID}}$lower)[, 2],
    upper_95 = as.matrix(pred_{{UID}}$upper)[, 2]
  )
'

# Test locally with custom predict body
test_xreg <- data.frame(
  exog1 = rnorm(6, mean = 5, sd = 1),
  exog2 = rnorm(6, mean = 10, sd = 2)
)

arimax_preds <- sf_registry_predict_local(
  model        = arimax_model,
  input_data   = test_xreg,
  predict_fn   = "forecast",
  predict_pkgs = c("forecast"),
  predict_body = arimax_predict_body
)

cat("ARIMAX local predictions:\n")
rprint(arimax_preds)
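
The `{{MODEL}}`, `{{INPUT}}`, `{{UID}}`, and `{{N}}` placeholders in the custom predict body are presumably replaced with concrete R variable names before the code is evaluated. The notebook does not show how the wrapper does this, so the following is only a sketch of such a substitution step in Python. The function `render_predict_body`, the example variable names, and the reading of `{{N}}` as the input row count are all assumptions.

```python
import re

# Matches placeholders of the form {{NAME}}
PLACEHOLDER = re.compile(r"\{\{(\w+)\}\}")


def render_predict_body(body: str, values: dict[str, str]) -> str:
    """Replace every {{NAME}} in `body`; raise if a name has no value."""

    def substitute(match: re.Match) -> str:
        key = match.group(1)
        if key not in values:
            raise KeyError(f"Unknown template variable: {key}")
        return values[key]

    return PLACEHOLDER.sub(substitute, body)


body = "pred_{{UID}} <- forecast({{MODEL}}, xreg = as.matrix({{INPUT}}), h = {{N}})"
rendered = render_predict_body(
    body,
    {
        "MODEL": "r_model_ab12",      # hypothetical R variable holding the model
        "INPUT": "input_ab12",        # hypothetical R variable holding the input
        "UID": "ab12",                # hypothetical per-call suffix
        "N": "nrow(input_ab12)",      # assumed meaning: number of input rows
    },
)
print(rendered)
```

A per-call suffix such as `{{UID}}` would keep temporary R variables from colliding when several predictions share one R session.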

In [None]:
%%R
# Register the ARIMAX model with custom predict body
arimax_mv <- sf_registry_log_model(
  model        = arimax_model,
  model_name   = "SYNTHETIC_ARIMAX_MODEL",
  predict_fn   = "forecast",
  predict_pkgs = c("forecast"),
  predict_body = arimax_predict_body,
  input_cols   = list(exog1 = "double", exog2 = "double"),
  output_cols  = list(
    point_forecast = "double",
    lower_80       = "double",
    upper_80       = "double",
    lower_95       = "double",
    upper_95       = "double"
  ),
  conda_dependencies = c("r-base>=4.1", "r-forecast>=8.0", "rpy2>=3.5"),
  comment = "ARIMAX model with exogenous regressors"
)

---

# Section 10: Cleanup

In [None]:
%%R
# Uncomment to clean up resources

# sf_registry_delete_service("TPCH_ORDERS_FORECAST_R", mv$version_name,
#                            "orders_forecast_r_svc")
# sf_registry_delete_model("TPCH_ORDERS_FORECAST_R")
# sf_registry_delete_model("MTCARS_MPG_MODEL")
# sf_registry_delete_model("SYNTHETIC_ARIMAX_MODEL")

cat("Cleanup section - uncomment to delete resources\n")

---

## Summary: Before vs After

### Before (r_forecasting_demo.ipynb)

R users had to:
1. Write ~80 lines of Python `CustomModel` wrapper class
2. Understand `snowflake.ml.model.custom_model` internals
3. Manually construct `ModelSignature` objects in Python
4. Handle rpy2 conversion details
5. Call `reg.log_model()` with many Python-specific arguments

### After (this notebook)

R users just call:
```r
sf_registry_log_model(model, model_name, predict_fn, predict_pkgs,
                      input_cols, output_cols)
```

### Function Reference

| Function | Purpose |
|----------|---------|
| `sf_registry_init()` | Initialize the wrapper |
| `sf_registry_setup()` | Set up schema and artifacts stage |
| `sf_registry_log_model()` | Register R model to registry |
| `sf_registry_predict_local()` | Test model locally |
| `sf_registry_predict()` | Run remote inference |
| `sf_registry_show_models()` | List registered models |
| `sf_registry_show_versions()` | List model versions |
| `sf_registry_get_model()` | Get model details |
| `sf_registry_set_metric()` | Add metrics |
| `sf_registry_show_metrics()` | View metrics |
| `sf_registry_create_service()` | Deploy to SPCS |
| `sf_registry_delete_service()` | Delete SPCS service |
| `sf_registry_delete_model()` | Delete model |
| `sf_registry_help()` | Show help |

### Supported Model Types

Any R model that can be:
1. Saved with `saveRDS()`
2. Loaded with `readRDS()`
3. Used for prediction with `predict()`, `forecast()`, or a custom function

Examples: `lm`, `glm`, `randomForest`, `xgboost`, `auto.arima`, `ets`, `nnetar`, `ranger`, `rpart`.