# Model Registry Demo

## Setup Notebook and Import Path

In [1]:
# Scale cell width with the browser window to accommodate .show() commands for wider tables.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import sys
import os

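# Simplify reading from the local repository
cwd = os.getcwd()
REPO_PREFIX = "snowflake/ml"
prefix_index = cwd.find(REPO_PREFIX)
# str.find returns -1 when the prefix is absent; fall back to the current directory in that case.
LOCAL_REPO_PATH = (cwd[:prefix_index] if prefix_index >= 0 else cwd).rstrip('/')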

if LOCAL_REPO_PATH not in sys.path:
    print(f"Adding {LOCAL_REPO_PATH} to system path")
    sys.path.append(LOCAL_REPO_PATH)

Adding /Users/amauser/src/snowml to system path


## Train A Small Model

The cell below trains a small model for demonstration purposes. The nature of the model does not matter; it serves only to demonstrate the usage of the Model Registry.

In [3]:
from sklearn import svm
from sklearn.datasets import load_digits

digits = load_digits()
target_digit = 6
num_training_examples = 10
svc_gamma = 0.001
svc_C = 10.

clf = svm.SVC(gamma=svc_gamma, C=svc_C, probability=True)


def one_vs_all(dataset, digit):
    return [x == digit for x in dataset]

# Train a classifier on the first num_training_examples examples and use the last 100 examples for testing.
train_features = digits.data[:num_training_examples]
train_labels = one_vs_all(digits.target[:num_training_examples], target_digit)
clf.fit(train_features, train_labels)

test_features = digits.data[-100:]
test_labels = one_vs_all(digits.target[-100:], target_digit)
prediction = clf.predict(test_features)


## Start Snowpark Session

To avoid exposing credentials in GitHub, we use a small utility, `SnowflakeLoginOptions`. It allows you to store your default credentials in `~/.snowsql/config` in the following format:
```
[connections]
accountname = <string>   # Account identifier to connect to Snowflake.
username = <string>      # User name in the account. Optional.
password = <string>      # User password. Optional.
dbname = <string>        # Default database. Optional.
schemaname = <string>    # Default schema. Optional.
warehousename = <string> # Default warehouse. Optional.
#rolename = <string>      # Default role. Optional.
#authenticator = <string> # Authenticator: 'snowflake', 'externalbrowser', etc
```
See the [SnowSQL documentation on default connection settings](https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings) for more details.
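
For illustration, a file in this format can be read with Python's standard `configparser` module. The sketch below is not how `SnowflakeLoginOptions` is implemented; the helper name `read_connection_defaults` and the sample values are hypothetical placeholders.

```python
import configparser

# Hypothetical sample content in the format shown above; the values are placeholders.
SAMPLE_CONFIG = """
[connections]
accountname = my_account   # Account identifier to connect to Snowflake.
username = my_user         # User name in the account. Optional.
dbname = my_db             # Default database. Optional.
#rolename = my_role        # Commented out, so it is not read.
"""


def read_connection_defaults(config_text):
    """Return the [connections] section as a plain dict (illustrative helper)."""
    # inline_comment_prefixes strips the trailing "# ..." comments from values.
    parser = configparser.ConfigParser(inline_comment_prefixes=("#",))
    parser.read_string(config_text)
    return dict(parser["connections"])


defaults = read_connection_defaults(SAMPLE_CONFIG)
print(defaults)
```

Lines that start with `#` are skipped entirely, which is why the commented-out `rolename` entry does not appear in the result.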

In [4]:
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session, Column, functions

session = Session.builder.configs(SnowflakeLoginOptions()).create()

## Open/Create Model Registry

A model registry must be created before it can be used. Creation adds a new database to the current account, so the active role needs permission to create a database. Once created, the model registry can be opened without creating it again.

In [5]:
import importlib
from snowflake.ml.registry import model_registry
# Force re-loading model_registry in case we updated the package during the runtime of this notebook.
importlib.reload(model_registry)

<module 'snowflake.ml.registry.model_registry' from '/opt/homebrew/anaconda3/envs/snowpark/lib/python3.8/site-packages/snowflake/ml/registry/model_registry.py'>

In [6]:
# Create a new model registry. This will be a no-op if the registry already exists.
create_result = model_registry.create_model_registry(session)



In [7]:
registry = model_registry.ModelRegistry(session=session)

There are two functionally equivalent APIs to interact with the model registry.

* A _relational API_, where all operations are performed as methods of the `registry` object, and
* an _object API_, where operations on a specific model are performed as methods of a `ModelReference` object.

The usage examples below will add some color to the two APIs and how they behave.
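
One way to picture the relationship between the two styles is a reference object that stores a model ID and forwards each call to the registry. This is an assumption made for illustration; the library's internals are not shown in this notebook, and `ToyRegistry` and `ToyModelReference` are hypothetical in-memory stand-ins.

```python
class ToyRegistry:
    """In-memory stand-in for a registry (hypothetical, for illustration only)."""

    def __init__(self):
        self._tags = {}  # model id -> dict of tags

    def set_tag(self, id, name, value):
        self._tags.setdefault(id, {})[name] = value

    def get_tags(self, id):
        return dict(self._tags.get(id, {}))


class ToyModelReference:
    """Binds a registry and a model id so callers need not repeat the id."""

    def __init__(self, registry, id):
        self._registry = registry
        self._id = id

    def set_tag(self, name, value):
        # Delegate to the registry-level method with the stored id.
        self._registry.set_tag(self._id, name, value)

    def get_tags(self):
        return self._registry.get_tags(self._id)


toy_registry = ToyRegistry()
toy_model = ToyModelReference(registry=toy_registry, id="model-1")

toy_registry.set_tag("model-1", "stage", "testing")  # relational style
toy_model.set_tag("minor_version", "23")             # object style

# Both views observe the same state.
print(toy_registry.get_tags("model-1"))
print(toy_model.get_tags())
```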

## Register a new Model

Registering a new model is always performed through the relational API. 

The call to `log_model` executes a few steps:
1. The given model object is serialized and uploaded to a stage.
1. An entry in the Model Registry is created for the model, referencing the model stage location.
1. Additional metadata is updated for the model as provided in the call.

For the serialization to work, the model object needs to be serializable in Python.
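
Whether an object is serializable can be checked locally before registering it. The sketch below uses the standard `pickle` module as a stand-in; the serialization mechanism the registry actually uses is not shown in this notebook, and `is_picklable` is a hypothetical helper.

```python
import pickle


def is_picklable(obj):
    """Return True if obj survives a pickle round trip (illustrative helper)."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


plain_params = {"svc_gamma": 0.001, "svc_C": 10.0}
anonymous_fn = lambda x: x + 1  # lambdas cannot be pickled by the standard pickle module

print(is_picklable(plain_params))
print(is_picklable(anonymous_fn))
```

Objects that hold lambdas, open file handles, or similar non-serializable members are the usual cause of failures in a check like this.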

In [8]:
# A name and model tags can be added to the model at registration time.
model_id = registry.log_model(model=clf, name="my_model", tags={
    "stage": "testing", "classifier_type": "svm.SVC", "svc_gamma": svc_gamma, "svc_C": svc_C})

# The object API can be used to reference a model after creation.
model = model_registry.ModelReference(registry=registry, id=model_id)
print("Registered new model:", model_id)

Registered new model: 89f1499eb70011ed9cd3e289b4f89202


## Add Metrics

Metrics are a type of metadata annotation that can be associated with models stored in the Model Registry. Metrics often take the form of scalars, but we also support more complex objects such as arrays or dictionaries. In the examples below, we add scalars, dictionaries, and a 2-dimensional numpy array as metrics.

In [9]:
from sklearn import metrics

test_accuracy = metrics.accuracy_score(test_labels, prediction)
print("Model test accuracy:", test_accuracy)

# Simple scalar metrics.

# Relational API
registry.set_metric(id=model_id, name="test_accuracy", value=test_accuracy)
# Object API
model.set_metric(name="num_training_examples", value=num_training_examples)

# Hierarchical metric.
registry.set_metric(id=model_id, name="dataset_test", value={"accuracy": test_accuracy})

# Multivalent metric:
test_confusion_matrix = metrics.confusion_matrix(test_labels, prediction)
print("Confusion matrix:", test_confusion_matrix)

registry.set_metric(id=model_id, name="confusion_matrix", value=test_confusion_matrix)

Model test accuracy: 0.97
Confusion matrix: [[90  0]
 [ 3  7]]
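
As a cross-check, the reported accuracy can be recomputed from the confusion matrix alone: the diagonal holds the correctly classified examples. The snippet below is a plain-Python sketch that copies the values from the output above.

```python
# Confusion matrix from the output above: rows are true labels, columns are predictions.
confusion = [[90, 0],
             [3, 7]]

correct = sum(confusion[i][i] for i in range(len(confusion)))  # diagonal entries
total = sum(sum(row) for row in confusion)
accuracy = correct / total

print(accuracy)  # 0.97
```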


## List Model in Registry

Listing models in the registry returns a Snowpark DataFrame, which allows the caller to select and filter the models as needed. In the example below, we list the name, tags, and metrics for the model we just added.

In [10]:
model_list = registry.list_models()

model_list.filter(model_list["ID"] == model_id).select("NAME","TAGS","METRICS").show()

---------------------------------------------------------------------------------
|"NAME"      |"TAGS"                           |"METRICS"                       |
---------------------------------------------------------------------------------
|"my_model"  |{                                |{                               |
|            |  "classifier_type": "svm.SVC",  |  "confusion_matrix": [         |
|            |  "stage": "testing",            |    [                           |
|            |  "svc_C": 10,                   |      90,                       |
|            |  "svc_gamma": 0.001             |      0                         |
|            |}                                |    ],                          |
|            |                                 |    [                           |
|            |                                 |      3,                        |
|            |                                 |      7                         |
|            |  

## Metadata: Tags and Name

Similar to how we changed metrics in the example above, we can also edit tags and names of models both with the relational API and with the object API.

### Relational API

In [11]:
print("Old tags:", registry.get_tags(model_id))

registry.set_tag(model_id, "minor_version", "23")
print("Added tag:", registry.get_tags(model_id))

registry.remove_tag(model_id, "minor_version")
print("Removed tag", registry.get_tags(model_id))
registry.set_tag(model_id, "stage", "production")
print("Updated tag:", registry.get_tags(model_id))

# Rename Model
print("Old name:", registry.get_model_name(model_id))

new_model_name = f"target_digit_{target_digit}"
registry.set_model_name(id=model_id, name=new_model_name)

print("New name:", registry.get_model_name(model_id))

Old tags: {'classifier_type': 'svm.SVC', 'stage': 'testing', 'svc_C': 10, 'svc_gamma': 0.001}
Added tag: {'classifier_type': 'svm.SVC', 'minor_version': '23', 'stage': 'testing', 'svc_C': 10, 'svc_gamma': 0.001}
Removed tag {'classifier_type': 'svm.SVC', 'stage': 'testing', 'svc_C': 10, 'svc_gamma': 0.001}
Updated tag: {'classifier_type': 'svm.SVC', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}
Old name: "my_model"
New name: "target_digit_6"


### Object API

In [12]:
print("Old tags:", model.get_tags())

model.set_tag("minor_version", "23")
print("Added tag:", model.get_tags())

model.remove_tag("minor_version")
print("Removed tag", model.get_tags())
model.set_tag("stage", "production")
print("Updated tag:", model.get_tags())

# Rename Model
print("Old name:", model.get_model_name())

new_model_name = f"target_digit_{target_digit}"
model.set_model_name(name=new_model_name)

print("New name:", model.get_model_name())

Old tags: {'classifier_type': 'svm.SVC', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}
Added tag: {'classifier_type': 'svm.SVC', 'minor_version': '23', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}
Removed tag {'classifier_type': 'svm.SVC', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}
Updated tag: {'classifier_type': 'svm.SVC', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}
Old name: "target_digit_6"
New name: "target_digit_6"


## List recent Models in Registry

Listing the models in the Model Registry returns a dataframe that allows us to conveniently manipulate the model list. In the example below, we show all models in the Model Registry sorted by recency.

In [13]:
model_list.select("ID","NAME","CREATION_TIME","TAGS").order_by("CREATION_TIME", ascending=False).show()

---------------------------------------------------------------------------------------------------------------------------------------------
|"ID"                              |"NAME"                |"CREATION_TIME"                   |"TAGS"                                        |
---------------------------------------------------------------------------------------------------------------------------------------------
|89f1499eb70011ed9cd3e289b4f89202  |"target_digit_6"      |2023-02-27 16:40:47.108000-08:00  |{                                             |
|                                  |                      |                                  |  "classifier_type": "svm.SVC",               |
|                                  |                      |                                  |  "stage": "production",                      |
|                                  |                      |                                  |  "svc_C": 10,                                |
|     

## List all versions of a Model ordered by test set accuracy

With similar logic, we can also list all versions of a model with a given name sorted by a metric, in this case test accuracy.

In [14]:
model_list.select("ID","NAME","TAGS","METRICS").filter(
    Column("NAME") == new_model_name).order_by(Column("METRICS")["test_accuracy"], ascending=False 
).show()                                                                                                             

--------------------------------------------------------------------------------------------------------------------------
|"ID"                              |"NAME"            |"TAGS"                           |"METRICS"                       |
--------------------------------------------------------------------------------------------------------------------------
|3df8e89099ef11edbd00e289b4f89203  |"target_digit_6"  |{                                |{                               |
|                                  |                  |  "classifier_type": "svm.SVC",  |  "confusion_matrix": [         |
|                                  |                  |  "stage": "production",         |    [                           |
|                                  |                  |  "svc_C": 100,                  |      88,                       |
|                                  |                  |  "svc_gamma": 0.0001            |      2                         |
|               
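
The same filter-then-sort logic can be pictured on plain Python dictionaries. The sketch below does not use Snowpark, and the records are hypothetical placeholders, not rows from the registry.

```python
# Hypothetical records standing in for registry rows.
records = [
    {"ID": "a", "NAME": "target_digit_6", "METRICS": {"test_accuracy": 0.91}},
    {"ID": "b", "NAME": "other_model", "METRICS": {"test_accuracy": 0.99}},
    {"ID": "c", "NAME": "target_digit_6", "METRICS": {"test_accuracy": 0.97}},
]

# Keep only versions with the requested name, then sort by accuracy, best first.
versions = [r for r in records if r["NAME"] == "target_digit_6"]
versions.sort(key=lambda r: r["METRICS"]["test_accuracy"], reverse=True)

print([r["ID"] for r in versions])  # ['c', 'a']
```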

## Examine Model History

In addition to the current state of the model metadata, we also give access to the history of all changes to it. This includes the registration event itself as well as every change to the model's metadata, when it happened, and who initiated it.

### Relational API

In [15]:
registry.get_model_history(id=model_id).select("EVENT_TIMESTAMP", "ROLE", "ATTRIBUTE_NAME","VALUE[ATTRIBUTE_NAME]").show()

-----------------------------------------------------------------------------------------------------------------------------------
|"EVENT_TIMESTAMP"                 |"ROLE"                |"ATTRIBUTE_NAME"  |"VALUE[ATTRIBUTE_NAME]"                             |
-----------------------------------------------------------------------------------------------------------------------------------
|2023-02-27 16:40:50.433000-08:00  |"ENG_ML_MODELING_RL"  |REGISTRATION      |{                                                   |
|                                  |                      |                  |  "CREATION_ENVIRONMENT_SPEC": {                    |
|                                  |                      |                  |    "python": "3.8.16"                              |
|                                  |                      |                  |  },                                                |
|                                  |                      |                 

### Object API

In [16]:
model.get_model_history().select("EVENT_TIMESTAMP", "ROLE", "ATTRIBUTE_NAME","VALUE[ATTRIBUTE_NAME]").show()

-----------------------------------------------------------------------------------------------------------------------------------
|"EVENT_TIMESTAMP"                 |"ROLE"                |"ATTRIBUTE_NAME"  |"VALUE[ATTRIBUTE_NAME]"                             |
-----------------------------------------------------------------------------------------------------------------------------------
|2023-02-27 16:40:50.433000-08:00  |"ENG_ML_MODELING_RL"  |REGISTRATION      |{                                                   |
|                                  |                      |                  |  "CREATION_ENVIRONMENT_SPEC": {                    |
|                                  |                      |                  |    "python": "3.8.16"                              |
|                                  |                      |                  |  },                                                |
|                                  |                      |                 
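
A history like this can be thought of as an append-only event log. The sketch below is a minimal illustration, not the registry's storage model: it reuses the column names from the output above, while `record_event`, the role name, and the `TAGS` attribute name are hypothetical.

```python
from datetime import datetime, timezone

history = []  # append-only list of change events (hypothetical structure)


def record_event(role, attribute_name, value):
    """Append one change event; earlier events are never modified."""
    history.append({
        "EVENT_TIMESTAMP": datetime.now(timezone.utc),
        "ROLE": role,
        "ATTRIBUTE_NAME": attribute_name,
        "VALUE": value,
    })


record_event("EXAMPLE_ROLE", "REGISTRATION", {"NAME": "my_model"})
record_event("EXAMPLE_ROLE", "TAGS", {"stage": "production"})

for event in history:
    print(event["EVENT_TIMESTAMP"].isoformat(), event["ROLE"], event["ATTRIBUTE_NAME"])
```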

## Load Model

We can also restore the model we saved to the registry and load it back into the local context to make predictions.

### Relational API

In [17]:
registry = model_registry.ModelRegistry(session=session)

restored_clf = registry.load_model(id=model_id)

restored_prediction = restored_clf.predict(test_features)

print("Original prediction:", prediction[:10])
print("Restored prediction:", restored_prediction[:10])

Original prediction: [False False False False False False False False False False]
Restored prediction: [False False False False False False False False False False]


### Object API

In [18]:
registry = model_registry.ModelRegistry(session=session)
model = model_registry.ModelReference(registry=registry, id=model_id)
restored_clf = model.load_model()

restored_prediction = restored_clf.predict(test_features)

print("Original prediction:", prediction[:10])
print("Restored prediction:", restored_prediction[:10])

Original prediction: [False False False False False False False False False False]
Restored prediction: [False False False False False False False False False False]
