# Model Registry: Train, Register, and Serve R Models

This notebook demonstrates the `snowflakeR` package workflow for registering R models
in the Snowflake Model Registry and running inference.

**What you'll do:**
1. Train models in R (as usual)
2. Test predictions locally
3. Register to Snowflake Model Registry with one function call
4. Manage versions and metrics
5. Run remote inference via SPCS

**Under the hood:** `snowflakeR` auto-generates a Python `CustomModel` wrapper
that uses `rpy2` to load and call your R model. You never write Python.

This notebook is for **Snowflake Workspace Notebooks** (Python kernel + `%%R` magic).
For local environments (RStudio, Posit, JupyterLab), use `local_model_registry.ipynb`.

**Before you start:** Copy `notebook_config.yaml.template` to `notebook_config.yaml`
and edit it with your warehouse, database, and schema.

**Sections:**
1. [Setup](#section-1-setup)
2. [Connect](#section-2-connect)
3. [Train a Model](#section-3-train)
4. [Test Locally](#section-4-local-test)
5. [Register to Snowflake](#section-5-register)
6. [Manage Models](#section-6-manage)
7. [Remote Inference](#section-7-inference)
8. [Advanced: Custom Predict Code](#section-8-custom-predict)
9. [Cleanup](#section-9-cleanup)

---

## 1. Setup

### Step 1: Install R environment (~3 minutes, first time only)

In [None]:
# Install R + rpy2 via setup script (included in this directory)
!bash setup_r_environment.sh --basic

### Step 2: Configure rpy2 and register `%%R` magic

In [None]:
from r_helpers import setup_r_environment
result = setup_r_environment()

if result['success']:
    print(f"R {result['r_version']} ready. %%R magic registered.")
else:
    print("Setup failed:", result['errors'])

### Step 3: Install and load snowflakeR

In [None]:
# Resolve the absolute path to the snowflakeR package root.
# This notebook lives at snowflakeR/inst/notebooks/, so the package root
# (the directory containing DESCRIPTION) is two levels up.
import os
snowflaker_path = os.path.normpath(os.path.join(os.getcwd(), "..", ".."))
print(f"snowflakeR path: {snowflaker_path}")
assert os.path.isfile(os.path.join(snowflaker_path, "DESCRIPTION")), \
    f"DESCRIPTION not found in {snowflaker_path} -- check your working directory"

# Export as env var so R can read it via Sys.getenv()
os.environ["SNOWFLAKER_PATH"] = snowflaker_path

In [None]:
%%R
# Set a CRAN mirror so install.packages() never prompts for one
# (Workspace Notebooks have no stdin)
options(repos = c(CRAN = "https://cloud.r-project.org"))

# Remove stale install (if any) so we always get the latest source
try(remove.packages("snowflakeR"), silent = TRUE)

# Install dependencies from CRAN first: the local install below uses
# repos = NULL, so it cannot fetch missing dependencies on its own
deps <- c("DBI", "methods", "reticulate", "cli", "rlang")
for (pkg in deps) {
  if (!requireNamespace(pkg, quietly = TRUE))
    install.packages(pkg, type = "source", quiet = TRUE)
}

# Option 1: Install from local repo cloned into the Workspace
# (absolute path resolved in the previous Python cell via env var)
install.packages(Sys.getenv("SNOWFLAKER_PATH"), repos = NULL, type = "source")

# Option 2: Install from GitHub via pak (once published to public repo)
# install.packages("pak", type = "source", quiet = TRUE)
# pak::pak("Snowflake-Labs/snowflakeR", ask = FALSE, upgrade = FALSE)

library(snowflakeR)

---
## 2. Connect & Set Execution Context

Workspace Notebooks do **not** auto-set database or schema.
`sfr_load_notebook_config()` reads `notebook_config.yaml` and runs
`USE WAREHOUSE / DATABASE / SCHEMA` to set the execution context.

All table references in this notebook use fully qualified names via `sfr_fqn()`.
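
As a rough illustration of what this context setup amounts to, the Python sketch below turns a config dict into `USE` statements and qualifies a bare object name as `DATABASE.SCHEMA.NAME`. It runs in the notebook's Python kernel without a Snowflake connection. The config keys (`warehouse`, `database`, `schema`) and the helpers `use_statements()` and `fqn()` are hypothetical. This is not how `sfr_load_notebook_config()` or `sfr_fqn()` are implemented, and it skips identifier quoting.

```python
def use_statements(config: dict) -> list[str]:
    """Build USE statements from a config dict (hypothetical keys), skipping unset ones."""
    order = [("warehouse", "WAREHOUSE"), ("database", "DATABASE"), ("schema", "SCHEMA")]
    return [f"USE {keyword} {config[key]}" for key, keyword in order if config.get(key)]


def fqn(database: str, schema: str, name: str) -> str:
    """Qualify a bare object name as DATABASE.SCHEMA.NAME (no identifier quoting)."""
    return ".".join([database, schema, name])


cfg = {"warehouse": "MY_WH", "database": "MY_DB", "schema": "MY_SCHEMA"}
for stmt in use_statements(cfg):
    print(stmt)
print(fqn(cfg["database"], cfg["schema"], "R_FORECAST_IMAGES"))  # MY_DB.MY_SCHEMA.R_FORECAST_IMAGES
```

Fully qualified names keep later references working regardless of which schema the session happens to be using.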

In [None]:
%%R
# Connect (auto-detects Workspace session)
conn <- sfr_connect()

# Load config and set execution context
conn <- sfr_load_notebook_config(conn)
conn

# Create a Model Registry context
reg <- sfr_model_registry(conn)
reg

---
## 3. Train a Model

### Example A: Linear regression on mtcars

Train any R model as you normally would -- nothing snowflakeR-specific here.

In [None]:
%%R
# Train a simple linear model
model_lm <- lm(mpg ~ wt + hp + cyl, data = mtcars)
summary(model_lm)

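### Example B: Time series with forecast

This example is not run in this notebook. To try it, paste the code into a `%%R` cell. It requires the `forecast` package.

```r
library(forecast)
model_arima <- auto.arima(AirPassengers)
summary(model_arima)
```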

---

## 4. Test Locally

`sfr_predict_local()` runs the **exact same prediction logic** that will execute
inside Snowflake, but entirely in R. Use this to verify before registering.

In [None]:
%%R
# Create test data
test_data <- data.frame(
  wt  = c(2.5, 3.0, 3.5, 4.0),
  hp  = c(110, 150, 200, 245),
  cyl = c(4L, 6L, 8L, 8L)
)

# Test locally -- same predict path as remote
preds <- sfr_predict_local(model_lm, test_data)
rprint(cbind(test_data, preds))

---
## 5. Register to Snowflake

One call to `sfr_log_model()` handles everything:
- Saves the R model to `.rds`
- Auto-generates a Python `CustomModel` wrapper
- Registers the model in the Snowflake Model Registry

In [None]:
%%R
mv <- sfr_log_model(
  reg,
  model        = model_lm,
  model_name   = "SFR_DEMO_MPG",
  version_name = "V1",
  input_cols   = list(wt = "double", hp = "double", cyl = "integer"),
  output_cols  = list(prediction = "double"),
  comment      = "Linear regression: MPG from weight, horsepower, cylinders"
)

mv

### Key parameters for `sfr_log_model()`

| Parameter | Description |
|-----------|-------------|
| `model` | Any R object that `saveRDS()` can serialise |
| `model_name` | Registry name (uppercase recommended) |
| `version_name` | Version label, e.g. `"V1"` |
| `input_cols` | Named list: column name -> type (`double`, `integer`, `string`, `boolean`) |
| `output_cols` | Named list: output column name -> type |
| `predict_fn` | R function name (default: `"predict"`) |
| `predict_body` | Custom R prediction code (see Section 8) |
| `predict_pkgs` | R packages needed at inference time |
| `conda_deps` | Extra conda packages (r-base and rpy2 always included) |
| `target_platforms` | `"SNOWPARK_CONTAINER_SERVICES"` (default) or `"WAREHOUSE"`; warehouse inference is not supported for R models (see Section 7) |
| `comment` | Free-text description of the version |
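
The `input_cols` and `output_cols` lists act as the model's signature. To show what such a signature pins down, here is a minimal Python sketch that checks rows against a spec using the same four type names. The Python type mapping and `check_rows()` are assumptions for illustration, not snowflakeR internals.

```python
# Assumed mapping from the signature's type names to Python types
# (illustration only; the real wrapper's conversion rules may differ).
PY_TYPES = {
    "double": (int, float),  # whole numbers are accepted where a double is expected
    "integer": (int,),
    "string": (str,),
    "boolean": (bool,),
}


def check_rows(rows: list[dict], cols: dict[str, str]) -> list[str]:
    """Return a list of problems; an empty list means every row matches `cols`."""
    problems = []
    for i, row in enumerate(rows):
        missing = sorted(set(cols) - set(row))
        if missing:
            problems.append(f"row {i}: missing {missing}")
        for name, type_name in cols.items():
            if name not in row:
                continue
            value = row[name]
            # bool is a subclass of int in Python, so reject it for numeric columns.
            if isinstance(value, bool) and type_name != "boolean":
                problems.append(f"row {i}: {name} is boolean, expected {type_name}")
            elif not isinstance(value, PY_TYPES[type_name]):
                problems.append(f"row {i}: {name}={value!r} is not {type_name}")
    return problems


input_cols = {"wt": "double", "hp": "double", "cyl": "integer"}
print(check_rows([{"wt": 2.5, "hp": 110, "cyl": 4}], input_cols))  # []
print(check_rows([{"wt": 2.5, "cyl": 4.0}], input_cols))  # flags missing hp and non-integer cyl
```

Under a check like this, a double such as `4.0` is rejected for an `integer` column, which is presumably why the test data in this notebook writes `cyl` as `4L` rather than `4`. The sketch ignores extra columns.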

---

## 6. Manage Models

### List and inspect models

In [None]:
%%R
# List all models in the registry
models <- sfr_show_models(reg)
rprint(models)

In [None]:
%%R
# Get a specific model
m <- sfr_get_model(reg, "SFR_DEMO_MPG")
m

# Show versions
sfr_show_model_versions(reg, "SFR_DEMO_MPG")

### Metrics

Attach evaluation metrics to model versions:

In [None]:
%%R
# Compute in-sample metrics (on the training data) and attach them to V1
preds_train <- predict(model_lm, mtcars)
rmse <- sqrt(mean((mtcars$mpg - preds_train)^2))
r_sq <- summary(model_lm)$r.squared

sfr_set_model_metric(reg, "SFR_DEMO_MPG", "V1", "rmse", rmse)
sfr_set_model_metric(reg, "SFR_DEMO_MPG", "V1", "r_squared", r_sq)

rcat(sprintf("RMSE: %.3f, R-squared: %.3f", rmse, r_sq))

In [None]:
%%R
# Retrieve metrics
sfr_show_model_metrics(reg, "SFR_DEMO_MPG", "V1")
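
For reference, the two metrics follow the usual definitions: RMSE is the square root of the mean squared residual, and for an `lm()` fit with an intercept, `summary()$r.squared` equals 1 - RSS/TSS. A small Python version of the same arithmetic can help sanity-check values read back from the registry. The helper names `rmse()` and `r_squared()` are illustrative.

```python
import math


def rmse(actual: list[float], predicted: list[float]) -> float:
    """Root mean squared error: sqrt(mean((actual - predicted)^2))."""
    if len(actual) != len(predicted) or not actual:
        raise ValueError("inputs must be non-empty and the same length")
    return math.sqrt(sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual))


def r_squared(actual: list[float], predicted: list[float]) -> float:
    """1 - RSS/TSS; matches summary(lm)$r.squared for in-sample fits with an intercept."""
    mean_a = sum(actual) / len(actual)
    rss = sum((a - p) ** 2 for a, p in zip(actual, predicted))
    tss = sum((a - mean_a) ** 2 for a in actual)
    return 1 - rss / tss


y = [21.0, 22.5, 18.5, 14.0]
y_hat = [22.0, 23.5, 17.5, 15.0]  # every prediction is off by exactly 1
print(rmse(y, y_hat))  # 1.0
print(round(r_squared(y, y_hat), 3))
```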

### Log a second version

In [None]:
%%R
# Train a second model that also uses displacement (disp)
model_v2 <- lm(mpg ~ wt + hp + cyl + disp, data = mtcars)

mv2 <- sfr_log_model(
  reg,
  model        = model_v2,
  model_name   = "SFR_DEMO_MPG",
  version_name = "V2",
  input_cols   = list(wt = "double", hp = "double", cyl = "integer", disp = "double"),
  output_cols  = list(prediction = "double"),
  comment      = "V2: added displacement"
)

# Set as default
sfr_set_default_model_version(reg, "SFR_DEMO_MPG", "V2")

---
## 7. Inference

### Local prediction (works everywhere)

`sfr_predict_local()` runs the **exact same prediction logic** that the registered model
uses, but entirely in your local R session. Use this to verify predictions before
deploying to production.

> **Note:** R models require `rpy2` and `r-base` at inference time, which are
> not available in the Snowflake warehouse Anaconda channel. Therefore, warehouse
> inference (`sfr_predict()` without a service) is not supported for R models.
> For production inference in Snowflake, deploy via SPCS (see below).

In [None]:
%%R
# Test data
new_data <- data.frame(
  wt   = c(2.62, 3.44, 3.57),
  hp   = c(110, 175, 245),
  cyl  = c(4L, 6L, 8L),
  disp = c(120.3, 258.0, 360.0)
)

# Predict locally with V2, now the default version (it also uses disp)
preds <- sfr_predict_local(model_v2, new_data)
rprint(cbind(new_data, preds))

### Production inference via SPCS

For production inference **inside Snowflake**, deploy the model as a
Snowpark Container Services (SPCS) service. This creates a container with
R, rpy2, and your model, then serves predictions via a REST endpoint.

The cells below create the SPCS infrastructure (compute pool and image repository)
if it doesn't already exist, then deploy the model, check its status, wait until it
is ready, run and benchmark inference, and finally undeploy it. Each step is in its
own cell so you can re-run steps individually.

In [None]:
%%R
# ── 0. Create SPCS infrastructure (if not exists) ────────────────────────
sfr_create_compute_pool(conn, "R_FORECAST_POOL", instance_family = "CPU_X64_M")
sfr_create_image_repo(conn, sfr_fqn(conn, "R_FORECAST_IMAGES"))

# ── 1. Deploy as SPCS service (force = TRUE drops existing service first)
sfr_deploy_model(
  reg,
  model_name   = "SFR_DEMO_MPG",
  version_name = "V2",
  service_name = "mpg_service",
  compute_pool = "R_FORECAST_POOL",
  image_repo   = sfr_fqn(conn, "R_FORECAST_IMAGES"),
  force        = TRUE
)

In [None]:
%%R
# ── 2. Check service status (one-off) ────────────────────────────────────
st <- sfr_get_service_status(reg, "mpg_service")
rcat(sprintf("Status: %s | Message: %s | FQN: %s", st$status, st$message, st$fqn))
if (!is.null(st$containers)) rprint(st$containers[, c("containerName", "status", "message")])

In [None]:
%%R
# ── 3. Wait for service to be ready ──────────────────────────────────────
# Polls every 15 seconds, times out after 10 minutes
sfr_wait_for_service(reg, "mpg_service", timeout_min = 10, poll_sec = 15)
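
`sfr_wait_for_service()` hides a poll-until-ready loop. A minimal Python sketch of that pattern follows, with an injectable status function and clock so it can be exercised without a live service. The `wait_for()` helper and the status strings (`READY`, `FAILED`, `PENDING`) are assumptions, not the snowflakeR implementation.

```python
import time
from typing import Callable


def wait_for(
    get_status: Callable[[], str],
    timeout_sec: float = 600,  # 10 minutes, as in the cell above
    poll_sec: float = 15,
    ready: str = "READY",
    failed: tuple[str, ...] = ("FAILED",),
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status() until it returns `ready`, a failure status, or time runs out."""
    deadline = clock() + timeout_sec
    while True:
        status = get_status()
        if status == ready:
            return status
        if status in failed:
            raise RuntimeError(f"service entered {status}")
        if clock() >= deadline:
            raise TimeoutError(f"still {status} after {timeout_sec} s")
        sleep(poll_sec)


# Simulated run: the service reports PENDING twice, then READY.
statuses = iter(["PENDING", "PENDING", "READY"])
fake_now = [0.0]
result = wait_for(
    lambda: next(statuses),
    clock=lambda: fake_now[0],
    sleep=lambda s: fake_now.__setitem__(0, fake_now[0] + s),
)
print(result, fake_now[0])  # READY 30.0
```

Injecting `clock` and `sleep` keeps the loop testable; in real use the defaults wait in wall-clock time.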

In [None]:
%%R
# ── 4. Run inference via the service ─────────────────────────────────────
preds <- sfr_predict(reg, "SFR_DEMO_MPG", new_data, service_name = "mpg_service")
rprint(preds)

# ── 5. Benchmark: run N iterations and report latency stats ──────────────
bench <- sfr_benchmark_inference(
  reg, "SFR_DEMO_MPG", new_data,
  service_name = "mpg_service",
  n = 20                                        # adjust iterations as needed
)
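
A latency benchmark boils down to timing repeated calls and summarising the distribution. This Python sketch shows one way to do it. The `benchmark()` helper, the untimed warm-up call, and the choice of statistics (min, median, mean, nearest-rank p95, max) are assumptions and may not match what `sfr_benchmark_inference()` reports.

```python
import math
import statistics
import time
from typing import Callable


def benchmark(call: Callable[[], object], n: int = 20, warmup: int = 1) -> dict[str, float]:
    """Call `call` n times after `warmup` untimed calls; return latency stats in milliseconds."""
    if n < 1:
        raise ValueError("n must be at least 1")
    for _ in range(warmup):
        call()  # untimed: a first request may carry connection or cold-start overhead
    samples = []
    for _ in range(n):
        start = time.perf_counter()
        call()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    p95_index = math.ceil(0.95 * n) - 1  # nearest-rank 95th percentile
    return {
        "min_ms": samples[0],
        "median_ms": statistics.median(samples),
        "mean_ms": statistics.fmean(samples),
        "p95_ms": samples[p95_index],
        "max_ms": samples[-1],
    }


# Stand-in for a service call: sleep about 10 ms per "request".
stats = benchmark(lambda: time.sleep(0.01), n=10)
print({name: round(ms, 1) for name, ms in stats.items()})
```

Reporting a median and p95 alongside the mean keeps one slow outlier from hiding typical latency.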

In [None]:
%%R
# ── 6. Undeploy when done ────────────────────────────────────────────────
sfr_undeploy_model(reg, "SFR_DEMO_MPG", "V2", "mpg_service")

---

## 8. Advanced: Custom Predict Code

For models that need special prediction logic (e.g., `forecast`, multi-step pipelines),
pass your own R code to `sfr_log_model()` via `predict_body`. The code can use the
template variables below.

### Template variables

| Variable | Description |
|----------|-------------|
| `{{MODEL}}` | The loaded R model object |
| `{{INPUT}}` | The input data.frame |
| `{{UID}}` | Unique ID for variable naming |
| `{{N}}` | Number of rows in input |
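
Conceptually, the template variables are placeholders that are replaced with concrete R names or values before the code runs. The Python sketch below shows plain placeholder substitution of that kind. The helper `render_predict_body()` and the replacement values (`model_obj`, `input_df`, a random hex UID, the row count) are hypothetical, since the exact substitutions snowflakeR makes are not shown here. A per-call `{{UID}}` presumably keeps temporaries such as `pred_{{UID}}` from colliding.

```python
import re
import uuid


def render_predict_body(template: str, values: dict[str, str]) -> str:
    """Replace {{NAME}} placeholders; fail loudly on any placeholder without a value."""
    def substitute(match: re.Match) -> str:
        name = match.group(1)
        if name not in values:
            raise KeyError(f"no value for template variable {{{{{name}}}}}")
        return values[name]

    return re.sub(r"\{\{(\w+)\}\}", substitute, template)


body = """
pred_{{UID}} <- forecast({{MODEL}}, h = {{N}})
point_{{UID}} <- as.numeric(pred_{{UID}}$mean)
"""
values = {
    "MODEL": "model_obj",         # hypothetical name of the loaded R model
    "INPUT": "input_df",          # hypothetical name of the input data.frame
    "UID": uuid.uuid4().hex[:8],  # short random suffix for temporary variable names
    "N": str(3),                  # number of input rows
}
print(render_predict_body(body, values))
```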

In [None]:
%%R
# Example: forecast model with custom output
# (Uncomment and adapt for your use case)

# library(forecast)
# arima_model <- auto.arima(AirPassengers)
#
# mv_forecast <- sfr_log_model(
#   reg,
#   model        = arima_model,
#   model_name   = "SFR_DEMO_FORECAST",
#   predict_fn   = "forecast",
#   predict_pkgs = c("forecast"),
#   predict_body = '
#     pred_{{UID}} <- forecast({{MODEL}}, h = {{N}})
#     result_{{UID}} <- data.frame(
#       period = seq_len({{N}}),
#       point_forecast = as.numeric(pred_{{UID}}$mean),
#       lower_95 = as.numeric(pred_{{UID}}$lower[,2]),
#       upper_95 = as.numeric(pred_{{UID}}$upper[,2])
#     )
#   ',
#   input_cols  = list(period = "integer"),
#   output_cols = list(
#     period = "integer", point_forecast = "double",
#     lower_95 = "double", upper_95 = "double"
#   )
# )

---

## 9. Cleanup

In [None]:
%%R
# Uncomment to clean up demo objects
# (commented out to avoid accidental deletion on Run All)
#
# sfr_delete_model(reg, "SFR_DEMO_MPG")
# sfr_execute(conn, paste("DROP TABLE IF EXISTS", sfr_fqn(conn, "SFR_DEMO_PREDICT_INPUT")))
# sfr_disconnect(conn)
# rcat("Cleanup complete.")

---

## Supported model types

Any R model that can be serialised with `saveRDS()` works, provided the packages it
needs at prediction time are listed in `predict_pkgs`. For example:

- `lm()`, `glm()` (base R)
- `randomForest::randomForest()`
- `xgboost::xgb.train()`
- `ranger::ranger()`
- `forecast::auto.arima()`, `forecast::ets()`
- `tidymodels` workflows
- Custom S3/S4 model objects

## Next steps

- **Feature Store:** See `workspace_feature_store.ipynb`
- **Full documentation:** `vignette("model-registry", package = "snowflakeR")`