In [1]:
# Standard library
import os
from importlib.metadata import version as installed_version

# Environment / configuration
from dotenv import load_dotenv

# Snowflake ML / Snowpark imports
import snowflake.snowpark as snowpark
from snowflake.ml.model import custom_model
from snowflake.ml.registry import Registry

# Standard Python / ML imports
import pandas as pd
import joblib


In [2]:
# 1. Load Snowflake connection settings from .env into the process environment
#    (python-dotenv does not override variables that are already set).
load_dotenv()

SNOWFLAKE_USER = os.getenv("SNOWFLAKE_USER")
SNOWFLAKE_PASSWORD = os.getenv("SNOWFLAKE_PASSWORD")
SNOWFLAKE_ACCOUNT = os.getenv("SNOWFLAKE_ACCOUNT")
SNOWFLAKE_DATABASE = os.getenv("SNOWFLAKE_DATABASE")
SNOWFLAKE_SCHEMA = os.getenv("SNOWFLAKE_SCHEMA")
SNOWFLAKE_WAREHOUSE = os.getenv("SNOWFLAKE_WAREHOUSE")
SNOWFLAKE_ROLE = os.getenv("SNOWFLAKE_ROLE")

# Fail fast with a readable message instead of an opaque connection error in the
# next cells. Only the credentials are treated as mandatory; role, warehouse,
# database and schema may be left unset if the Snowflake user has defaults for them.
_missing_settings = [
    name
    for name, value in {
        "SNOWFLAKE_ACCOUNT": SNOWFLAKE_ACCOUNT,
        "SNOWFLAKE_USER": SNOWFLAKE_USER,
        "SNOWFLAKE_PASSWORD": SNOWFLAKE_PASSWORD,
    }.items()
    if not value
]
if _missing_settings:
    raise EnvironmentError(
        "Missing required Snowflake settings (set them in .env or the environment): "
        + ", ".join(_missing_settings)
    )

In [3]:
# 2. Assemble the Snowpark connection configuration from the settings loaded in
#    the previous cell; the keys are the connector's parameter names.
connection_parameters = dict(
    account=SNOWFLAKE_ACCOUNT,
    user=SNOWFLAKE_USER,
    password=SNOWFLAKE_PASSWORD,
    role=SNOWFLAKE_ROLE,
    warehouse=SNOWFLAKE_WAREHOUSE,
    database=SNOWFLAKE_DATABASE,
    schema=SNOWFLAKE_SCHEMA,
)

In [4]:
# 3. Open the Snowpark session that the model registry below runs against.
session = (
    snowpark.Session.builder
    .configs(connection_parameters)
    .create()
)
print("Snowpark session created.")

Snowpark session created.


In [5]:
# 4. Point the custom model at the local pickle. The path is stored under the
#    key "model_file", which AnomalyDetectionModel reads back via
#    self.context["model_file"].
MODEL_FILE = "anomaly_model.pkl"  # relative to the notebook's working directory

# Report a missing file here, with its resolved location, rather than later
# inside joblib.load or log_model.
if not os.path.isfile(MODEL_FILE):
    raise FileNotFoundError(f"Model pickle not found: {os.path.abspath(MODEL_FILE)}")

# `custom_model` is already imported in the first cell, so no extra import is needed.
model_context = custom_model.ModelContext(model_file=MODEL_FILE)

In [6]:
# 5. Custom model class that wraps the locally pickled estimator
class AnomalyDetectionModel(custom_model.CustomModel):
    """Snowflake ML custom model serving a joblib-pickled anomaly detector.

    The pickle path is taken from the ``model_file`` entry of the context passed
    at construction, and the estimator is loaded once, in ``__init__``.
    """

    def __init__(self, context):
        super().__init__(context)
        model_path = self.context["model_file"]
        print(f"Loading local model from: {model_path}")
        # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
        # only load model files that come from a trusted source.
        self.model = joblib.load(model_path)

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        """Run the wrapped estimator on ``X`` and label every row.

        Returns a DataFrame with two columns:
        - ``prediction``: the estimator's raw output (for an IsolationForest,
          -1 marks an anomaly and 1 a normal point);
        - ``anomaly_tag``: "anomaly" where the prediction is -1, otherwise "normal".
        """
        raw_labels = self.model.predict(X)
        anomaly_tags = ["normal" if label != -1 else "anomaly" for label in raw_labels]
        return pd.DataFrame({"prediction": raw_labels, "anomaly_tag": anomaly_tags})

In [7]:
# 6. Instantiate the wrapper. This unpickles the estimator locally, so load
#    problems (e.g. scikit-learn unpickling warnings) surface in this cell.
anomaly_model = AnomalyDetectionModel(context=model_context)

Loading local model from: anomaly_model.pkl


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [8]:
# 7. Sample input that log_model uses to infer the model signature. Column names
#    and order must match the features the pickled estimator was fitted on. The
#    dtypes here become the registered input types (float "price" -> DOUBLE,
#    int "size" -> INT64 in this run's signature).
sample_data = pd.DataFrame({
    "price": [100.0, 105.5],
    "size": [10, 20]
})

# scikit-learn estimators fitted on a DataFrame record their feature names in
# `feature_names_in_`. When present, check them so a mismatch is reported here
# with a clear message instead of during registration or inference.
_expected_features = getattr(anomaly_model.model, "feature_names_in_", None)
if _expected_features is not None and list(_expected_features) != list(sample_data.columns):
    raise ValueError(
        f"sample_data columns {list(sample_data.columns)} do not match the features "
        f"the model was trained on {list(_expected_features)}"
    )

In [9]:
# 8. Register the wrapped model in the Snowflake Model Registry. With no explicit
#    database/schema, the registry uses the session's current ones.
MODEL_NAME = "anomaly_detection_model"
# NOTE(review): a fixed version name makes this cell non-idempotent. Registering
# the same version twice is expected to fail, so bump this when re-running.
MODEL_VERSION = "v1"

registry = Registry(session=session)

# Pickled scikit-learn estimators are only guaranteed to behave correctly with the
# scikit-learn version that loads them locally, so pin the version installed here
# instead of letting Snowflake resolve an arbitrary one. If Snowflake's Anaconda
# channel lacks this exact version, log_model will fail: re-pickle the model
# with a version that the channel provides.
SKLEARN_REQUIREMENT = f"scikit-learn=={installed_version('scikit-learn')}"

model_version = registry.log_model(
    model=anomaly_model,
    model_name=MODEL_NAME,
    version_name=MODEL_VERSION,
    conda_dependencies=[SKLEARN_REQUIREMENT, "joblib", "pandas"],
    sample_input_data=sample_data,
    comment="IsolationForest anomaly detection model",
)

print("Model registered successfully. Version info:")
model_version.show_functions()

  self.manifest.save(


Model registered successfully. Version info:


[{'name': 'PREDICT',
  'target_method': 'predict',
  'target_method_function_type': 'FUNCTION',
  'signature': ModelSignature(
                      inputs=[
                          FeatureSpec(dtype=DataType.DOUBLE, name='price', nullable=True),
  		FeatureSpec(dtype=DataType.INT64, name='size', nullable=True)
                      ],
                      outputs=[
                          FeatureSpec(dtype=DataType.INT64, name='prediction', nullable=True),
  		FeatureSpec(dtype=DataType.STRING, name='anomaly_tag', nullable=True)
                      ]
                  ),
  'is_partitioned': False}]

In [10]:
# 9. (Optional) Sanity check: list the registry's models to confirm that the
#    registration above is visible.
models_df = registry.show_models()
print("List of models in the registry:")
models_df.head(n=10)

List of models in the registry:


Unnamed: 0,created_on,name,model_type,database_name,schema_name,comment,owner,default_version_name,versions,aliases
0,2025-03-30 21:17:13.708000-07:00,ANOMALY_DETECTION_MODEL,USER_MODEL,TRADES_DB,TRADES_SCHEMA,,ACCOUNTADMIN,V1,"[""V1""]","{""DEFAULT"":""V1"",""FIRST"":""V1"",""LAST"":""V1""}"
