# Model Registry: Train, Register, and Serve R Models

This notebook demonstrates the `snowflakeR` package workflow for registering R models
in the Snowflake Model Registry and running inference.

**What you'll do:**
1. Train models in R (as usual)
2. Test predictions locally
3. Register to Snowflake Model Registry with one function call
4. Manage versions and metrics
5. Run remote inference via Snowpark Container Services (SPCS)

**Under the hood:** `snowflakeR` auto-generates a Python `CustomModel` wrapper
that uses `rpy2` to load and call your R model. You never write Python.

**Sections:**
1. [Setup](#section-1-setup)
2. [Train a Model](#section-2-train)
3. [Test Locally](#section-3-local-test)
4. [Register to Snowflake](#section-4-register)
5. [Manage Models](#section-5-manage)
6. [Remote Inference](#section-6-inference)
7. [Advanced: Custom Predict Code](#section-7-custom-predict)
8. [Cleanup](#section-8-cleanup)

---

# Section 1: Setup

## Workspace Notebook

If you haven't already, run the R environment setup from `quickstart.ipynb` first.
Then load the package:

In [None]:
# Workspace Notebook: configure rpy2 (skip if already done in this session)
import sys
sys.path.insert(0, '..')

from r_helpers import setup_r_environment
result = setup_r_environment()
print(f"R ready: {result['success']}")

In [None]:
%%R
library(snowflakeR)

# Connect to Snowflake
conn <- sfr_connect()

# Create a Model Registry context. Add database = and schema = arguments
# to target a specific db/schema for model storage (see Local Environment below).
reg <- sfr_model_registry(conn)
reg

## Local Environment

If you are running locally with an R kernel, skip the Python setup cell above, drop the `%%R` cell magics from the cells in this notebook, and run:

```r
library(snowflakeR)
conn <- sfr_connect()
reg  <- sfr_model_registry(conn, database = "ML_DB", schema = "MODELS")
```

---

# Section 2: Train a Model

## Example A: Linear regression on mtcars

Train any R model as you normally would -- nothing snowflakeR-specific here.

In [None]:
%%R
# Train a simple linear model
model_lm <- lm(mpg ~ wt + hp + cyl, data = mtcars)
summary(model_lm)

## Example B: Time series with forecast

```r
# install.packages("forecast")  # if not already installed
library(forecast)
model_arima <- auto.arima(AirPassengers)
summary(model_arima)
```

---

# Section 3: Test Locally

`sfr_predict_local()` runs the **exact same prediction logic** that will execute
inside Snowflake, but entirely in R. Use it to check predictions before registering.

In [None]:
%%R
# Create test data
test_data <- data.frame(
  wt  = c(2.5, 3.0, 3.5, 4.0),
  hp  = c(110, 150, 200, 245),
  cyl = c(4L, 6L, 8L, 8L)
)

# Test locally -- same predict path as remote
preds <- sfr_predict_local(model_lm, test_data)
cbind(test_data, preds)

---

# Section 4: Register to Snowflake

One call to `sfr_log_model()` handles everything:
- Saves the R model to an `.rds` file
- Auto-generates a Python `CustomModel` wrapper that loads and calls it via `rpy2`
- Registers the model in the Snowflake Model Registry
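
The `.rds` step can be reproduced in plain R to see what the wrapper has to load. The sketch below is an assumption about the general shape of the R side (serialise, reload, predict, return a data.frame whose column matches `output_cols`), not the code snowflakeR generates; `fit`, `new_rows` and `rds_path` are illustrative names.

```r
# Sketch only: approximates the R side of the round trip, not snowflakeR's generated code.
fit <- lm(mpg ~ wt + hp + cyl, data = mtcars)

# 1. Serialise the fitted model to an .rds file
rds_path <- tempfile(fileext = ".rds")
saveRDS(fit, rds_path)

# 2. Load it back, as a separate R process would
loaded <- readRDS(rds_path)
unlink(rds_path)

# 3. Predict on new rows and shape the output to match
#    output_cols = list(prediction = "double")
new_rows <- data.frame(
  wt  = c(2.5, 3.0, 3.5, 4.0),
  hp  = c(110, 150, 200, 245),
  cyl = c(4L, 6L, 8L, 8L)
)
result <- data.frame(prediction = as.numeric(predict(loaded, newdata = new_rows)))

# The reloaded model reproduces the original model's predictions
stopifnot(isTRUE(all.equal(result$prediction,
                           as.numeric(predict(fit, newdata = new_rows)))))
result
```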

In [None]:
%%R
mv <- sfr_log_model(
  reg,
  model       = model_lm,
  model_name  = "SFR_DEMO_MPG",
  input_cols  = list(wt = "double", hp = "double", cyl = "integer"),
  output_cols = list(prediction = "double"),
  comment     = "Linear regression: MPG from weight, horsepower, cylinders"
)

mv

### Key parameters for `sfr_log_model()`

| Parameter | Description |
|-----------|-------------|
| `model` | Any R object that can be serialised with `saveRDS()` |
| `model_name` | Registry name (uppercase recommended) |
| `version_name` | Version label, e.g. `"V2"` |
| `input_cols` | Named list: column name -> type (`double`, `integer`, `string`, `boolean`) |
| `output_cols` | Named list: output column name -> type |
| `predict_fn` | R function name (default: `"predict"`) |
| `predict_body` | Custom R prediction code template (see Section 7) |
| `predict_pkgs` | R packages needed at inference time |
| `conda_deps` | Extra conda packages (`r-base` and `rpy2` are always included) |
| `target_platforms` | `"SNOWPARK_CONTAINER_SERVICES"` (default) or `"WAREHOUSE"` |
| `comment` | Free-text description stored in the registry |
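
The type names accepted by `input_cols` correspond closely to base R vector types, so a quick pre-flight check can catch mismatches before calling `sfr_log_model()`. `check_input_cols()` below is a hypothetical helper, not part of snowflakeR, and its mapping (`double` to `is.double()`, `integer` to `is.integer()`, `string` to `is.character()`, `boolean` to `is.logical()`) is an assumption about how the registry types line up with R.

```r
# Hypothetical helper (not part of snowflakeR): checks that a data.frame has
# every column named in `input_cols` with a matching base R type.
check_input_cols <- function(df, input_cols) {
  # Assumed mapping from input_cols type names to base R predicates
  type_checks <- list(
    double  = is.double,
    integer = is.integer,
    string  = is.character,
    boolean = is.logical
  )
  problems <- character(0)
  for (col in names(input_cols)) {
    type <- input_cols[[col]]
    if (!col %in% names(df)) {
      problems <- c(problems, sprintf("missing column '%s'", col))
    } else if (!type %in% names(type_checks)) {
      problems <- c(problems, sprintf("unknown type '%s' for '%s'", type, col))
    } else if (!type_checks[[type]](df[[col]])) {
      problems <- c(problems, sprintf("column '%s' is %s, expected %s",
                                      col, class(df[[col]])[1], type))
    }
  }
  problems  # character(0) means the data.frame matches the spec
}

input_cols <- list(wt = "double", hp = "double", cyl = "integer")

good <- data.frame(wt = c(2.5, 3.0), hp = c(110, 150), cyl = c(4L, 6L))
bad  <- data.frame(wt = c(2.5, 3.0), hp = c(110, 150), cyl = c(4, 6))

check_input_cols(good, input_cols)  # character(0)
check_input_cols(bad, input_cols)   # "column 'cyl' is numeric, expected integer"
```

Note that `mtcars` stores `cyl` as a double, which is why the Section 3 test data builds it from integer literals such as `4L`.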

---

# Section 5: Manage Models

## List and inspect models

In [None]:
%%R
# List all models in the registry
models <- sfr_show_models(reg)
models

In [None]:
%%R
# Get a specific model
m <- sfr_get_model(reg, "SFR_DEMO_MPG")
m

# Show versions
sfr_show_model_versions(reg, "SFR_DEMO_MPG")

## Metrics

Attach evaluation metrics to model versions:

In [None]:
%%R
# Calculate and set metrics
preds_train <- predict(model_lm, mtcars)
rmse <- sqrt(mean((mtcars$mpg - preds_train)^2))
r_sq <- summary(model_lm)$r.squared

sfr_set_model_metric(reg, "SFR_DEMO_MPG", "V1", "rmse", rmse)
sfr_set_model_metric(reg, "SFR_DEMO_MPG", "V1", "r_squared", r_sq)

cat(sprintf("RMSE: %.3f, R-squared: %.3f\n", rmse, r_sq))

In [None]:
%%R
# Retrieve metrics
sfr_show_model_metrics(reg, "SFR_DEMO_MPG", "V1")
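
The metrics above are computed on the same rows the model was trained on, so they describe in-sample fit. Writing the formulas out makes that explicit and lets you reuse them on other data. This is a base-R sketch; `rmse()` and `r_squared()` are illustrative helper names.

```r
# Sketch: RMSE and R-squared from their definitions (illustrative helper names).
fit <- lm(mpg ~ wt + hp + cyl, data = mtcars)

# RMSE = sqrt(mean((y - y_hat)^2))
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

# R-squared = 1 - SS_res / SS_tot
r_squared <- function(actual, predicted) {
  ss_res <- sum((actual - predicted)^2)     # residual sum of squares
  ss_tot <- sum((actual - mean(actual))^2)  # total sum of squares
  1 - ss_res / ss_tot
}

fitted_mpg <- predict(fit, mtcars)
my_rmse <- rmse(mtcars$mpg, fitted_mpg)
my_r2   <- r_squared(mtcars$mpg, fitted_mpg)

# For an lm with an intercept this agrees with summary()$r.squared
stopifnot(isTRUE(all.equal(my_r2, summary(fit)$r.squared)))

# RMSE divides by n, while summary()$sigma divides by the residual
# degrees of freedom (n - 4 here), so RMSE is slightly smaller
c(rmse = my_rmse, sigma = summary(fit)$sigma)
```

To estimate out-of-sample error, the same helpers can be applied to rows held out from training before passing the values to `sfr_set_model_metric()`.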

## Log a second version

In [None]:
%%R
# Train a second candidate that adds displacement (disp) as a predictor
model_v2 <- lm(mpg ~ wt + hp + cyl + disp, data = mtcars)

mv2 <- sfr_log_model(
  reg,
  model        = model_v2,
  model_name   = "SFR_DEMO_MPG",
  version_name = "V2",
  input_cols   = list(wt = "double", hp = "double", cyl = "integer", disp = "double"),
  output_cols  = list(prediction = "double"),
  comment      = "V2: added displacement"
)

# Set as default
sfr_set_default_model_version(reg, "SFR_DEMO_MPG", "V2")
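
Adding a predictor can only raise in-sample R-squared, so that alone does not show V2 is the better model. Before promoting a version to default, you can compare the fits with measures that penalise extra parameters. The sketch uses only base R (`summary()`, `AIC()`, `anova()`); the resulting numbers could then be attached to each version with `sfr_set_model_metric()` as shown earlier.

```r
# Sketch: compare the two candidate models with base R before promoting V2.
fit_v1 <- lm(mpg ~ wt + hp + cyl, data = mtcars)
fit_v2 <- lm(mpg ~ wt + hp + cyl + disp, data = mtcars)

comparison <- data.frame(
  version       = c("V1", "V2"),
  adj_r_squared = c(summary(fit_v1)$adj.r.squared, summary(fit_v2)$adj.r.squared),
  aic           = c(AIC(fit_v1), AIC(fit_v2))  # lower AIC is preferred
)
comparison

# V1 is nested in V2, so an F-test asks whether adding disp
# improves the fit by more than chance would explain
anova(fit_v1, fit_v2)
```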

---

# Section 6: Remote Inference (SPCS)

Once the model is registered and served (see the SPCS deployment example below), you can run predictions directly in Snowflake.
The R model executes inside Snowpark Container Services via the auto-generated rpy2 wrapper.

In [None]:
%%R
# New rows to score; disp is included because the default version (V2) uses it
new_data <- data.frame(
  wt   = c(2.62, 3.44, 3.57),
  hp   = c(110, 175, 245),
  cyl  = c(4L, 6L, 8L),
  disp = c(120.3, 258.0, 360.0)
)
# Write the same rows to a Snowflake table
sfr_write_table(conn, "SFR_DEMO_PREDICT_INPUT", new_data, overwrite = TRUE)

# Predict using the registered model
# (Requires a deployed SPCS service; see the deployment example below.
#  Uncomment once the service is running.)
# preds <- sfr_predict(reg, "SFR_DEMO_MPG", new_data)
# preds
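
With `lm` models, `predict()` ignores input columns the formula does not use but stops when a required one is missing. The base-R sketch below shows both cases for the two versions trained earlier; how the remote wrapper surfaces such an error is not covered here.

```r
# Sketch: how predict() on lm models treats extra and missing input columns.
fit_v1 <- lm(mpg ~ wt + hp + cyl, data = mtcars)
fit_v2 <- lm(mpg ~ wt + hp + cyl + disp, data = mtcars)

new_data <- data.frame(
  wt   = c(2.62, 3.44, 3.57),
  hp   = c(110, 175, 245),
  cyl  = c(4L, 6L, 8L),
  disp = c(120.3, 258.0, 360.0)
)

# V1's formula does not use disp, so the extra column is ignored
pred_v1 <- predict(fit_v1, newdata = new_data)

# V2's formula needs disp; without it predict() stops with an error
# (assuming no object named disp exists in the formula's environment)
without_disp <- new_data[, c("wt", "hp", "cyl")]
err <- tryCatch(
  predict(fit_v2, newdata = without_disp),
  error = function(e) conditionMessage(e)
)
err
```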

### Deploying as an SPCS service

```r
# Deploy the model as an SPCS service (requires a compute pool and an image repository)
sfr_deploy_model(
  reg,
  model_name   = "SFR_DEMO_MPG",
  version_name = "V2",
  service_name = "mpg_service",
  compute_pool = "ML_POOL",
  image_repo   = "my_db.my_schema.my_repo"
)

# Predict via the service
preds <- sfr_predict(reg, "SFR_DEMO_MPG", new_data, service_name = "mpg_service")

# Undeploy when done
sfr_undeploy_model(reg, "SFR_DEMO_MPG", "V2", "mpg_service")
```

---

# Section 7: Advanced -- Custom Predict Code

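For models that need special prediction logic (e.g., `forecast`, multi-step pipelines),
pass your own R code through the `predict_body` argument of `sfr_log_model()`, using the template variables below to refer to the model and its input.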

### Template variables

| Variable | Description |
|----------|-------------|
| `{{MODEL}}` | The loaded R model object |
| `{{INPUT}}` | The input data.frame |
| `{{UID}}` | Unique ID for variable naming |
| `{{N}}` | Number of rows in input |
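
To make the placeholders concrete, here is a plain-R illustration of template substitution: textual replacement followed by evaluation. `render_predict_body()` and the variable names in it are hypothetical; this shows the general idea only, not how snowflakeR actually generates or runs the wrapper code.

```r
# Illustration only: a hypothetical renderer that fills in the
# {{MODEL}}, {{INPUT}}, {{UID}} and {{N}} placeholders with gsub().
render_predict_body <- function(body, model_var, input_var, uid, n) {
  body <- gsub("{{MODEL}}", model_var, body, fixed = TRUE)
  body <- gsub("{{INPUT}}", input_var, body, fixed = TRUE)
  body <- gsub("{{UID}}",   uid,       body, fixed = TRUE)
  gsub("{{N}}", as.character(n), body, fixed = TRUE)
}

body <- '
  pred_{{UID}} <- predict({{MODEL}}, newdata = {{INPUT}})
  result_{{UID}} <- data.frame(row = seq_len({{N}}), prediction = as.numeric(pred_{{UID}}))
'

fit <- lm(mpg ~ wt + hp + cyl, data = mtcars)
input_df <- head(mtcars[, c("wt", "hp", "cyl")], 3)

code <- render_predict_body(body, "fit", "input_df", uid = "a1b2", n = nrow(input_df))
cat(code)

# Evaluate the rendered code in its own environment; the UID suffix keeps
# the temporaries from colliding with other variables there
env <- new.env()
assign("fit", fit, envir = env)
assign("input_df", input_df, envir = env)
eval(parse(text = code), envir = env)
get("result_a1b2", envir = env)
```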

In [None]:
%%R
# Example: forecast model with custom output
# (Uncomment and adapt for your use case)

# library(forecast)
# arima_model <- auto.arima(AirPassengers)
#
# mv_forecast <- sfr_log_model(
#   reg,
#   model        = arima_model,
#   model_name   = "SFR_DEMO_FORECAST",
#   predict_fn   = "forecast",
#   predict_pkgs = c("forecast"),
#   predict_body = '
#     pred_{{UID}} <- forecast({{MODEL}}, h = {{N}})
#     result_{{UID}} <- data.frame(
#       period = seq_len({{N}}),
#       point_forecast = as.numeric(pred_{{UID}}$mean),
#       lower_95 = as.numeric(pred_{{UID}}$lower[,2]),
#       upper_95 = as.numeric(pred_{{UID}}$upper[,2])
#     )
#   ',
#   input_cols  = list(period = "integer"),
#   output_cols = list(
#     period = "integer", point_forecast = "double",
#     lower_95 = "double", upper_95 = "double"
#   )
# )

---

# Section 8: Cleanup

In [None]:
%%R
# Delete test models and tables
sfr_delete_model(reg, "SFR_DEMO_MPG")
sfr_execute(conn, "DROP TABLE IF EXISTS SFR_DEMO_PREDICT_INPUT")

sfr_disconnect(conn)
cat("Cleanup complete.\n")

---

## Supported model types

Any R model that can be serialised with `saveRDS()` should work, as long as the packages its prediction method needs are listed in `predict_pkgs`:

- `lm()`, `glm()` (base R)
- `randomForest::randomForest()`
- `xgboost::xgb.train()`
- `ranger::ranger()`
- `forecast::auto.arima()`, `forecast::ets()`
- `tidymodels` workflows
- Custom S3/S4 model objects
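
Serialisability is necessary but not always sufficient: the default output of a model's `predict()` method may not be the scale you want to serve. For example, with the default `predict_fn = "predict"`, a binomial `glm()` presumably returns log-odds, since that is what `predict()` gives unless `type = "response"` is passed. The `predict_body` template from Section 7 is one way to request a different scale; the sketch below only demonstrates the base R behaviour.

```r
# Sketch: predict() on a binomial glm defaults to the link (log-odds) scale.
fit_glm <- glm(am ~ wt, data = mtcars, family = binomial)

new_cars <- data.frame(wt = c(2.0, 3.0, 4.0))

log_odds <- predict(fit_glm, newdata = new_cars)                     # default: type = "link"
probs    <- predict(fit_glm, newdata = new_cars, type = "response")  # probabilities

# The two scales are related by the inverse logit, plogis()
stopifnot(isTRUE(all.equal(plogis(log_odds), probs)))
data.frame(wt = new_cars$wt, log_odds = as.numeric(log_odds), prob = as.numeric(probs))
```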

## Next steps

- **Feature Store:** See `feature_store_demo.ipynb`
- **Full documentation:** `vignette("model-registry", package = "snowflakeR")`