# Partitioned Custom ML Model with Model Registry

This notebook includes two different models and datasets. They are both capable of being tested locally as well as run entirely in Snowflake. I have also made it so you can push the datasets into a Snowflake table for running the inference from the Snowflake model registry.

### Partitioned restaurant traffic forecasting model

The dataset is loaded locally from the `Partitioned_Custom_Model_Restaurant_Traffic_Data.csv` file.

Change `"MY_DB"` and `"MY_SCHEMA"` to your desired existing database and schema.

In [10]:
import json
from snowflake.snowpark import Session
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend

with open('../../creds.json') as f:
    data = json.load(f)
    USERNAME = data['user']
    SF_ACCOUNT = data['account']
    SF_WH = data['warehouse']
    passphrase = data['passphrase']

# Read the private key from the .p8 file
with open('../../rsa_key.p8', 'rb') as key_file:
    private_key = key_file.read()

# If the private key is encrypted, load it with a passphrase
# Replace 'your_key_passphrase' with your actual passphrase if needed
private_key_obj = serialization.load_pem_private_key(
    private_key,
    password=passphrase.encode() if passphrase else None,
    backend=default_backend()
)

# Define connection parameters including the private key
CONNECTION_PARAMETERS = {
    'user': USERNAME,
    'account': SF_ACCOUNT,
    'private_key': private_key_obj,
    'warehouse': SF_WH,
}

# Create a session with the specified connection parameters
session = Session.builder.configs(CONNECTION_PARAMETERS).create()

from snowflake.core.warehouse import Warehouse
from snowflake.core import Root
root = Root(session)
from snowflake.snowpark.functions import col 
from time import time


In [12]:
from datetime import timedelta

import pandas as pd

from snowflake.ml.model import custom_model
from snowflake.ml.model import model_signature
from snowflake.ml.registry import registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import functions as F

In [13]:
from snowflake.snowpark.version import VERSION
snowflake_environment = session.sql('select current_user(), current_version()').collect()
snowpark_version = VERSION

from snowflake.ml import version
mlversion = version.VERSION

# Current Environment Details
print('User                        : {}'.format(snowflake_environment[0][0]))
print('Role                        : {}'.format(session.get_current_role()))
print('Database                    : {}'.format(session.get_current_database()))
print('Schema                      : {}'.format(session.get_current_schema()))
print('Warehouse                   : {}'.format(session.get_current_warehouse()))
print('Snowflake version           : {}'.format(snowflake_environment[0][1]))
print('Snowpark for Python version : {}.{}.{}'.format(snowpark_version[0],snowpark_version[1],snowpark_version[2]))
print('Snowflake ML version        : {}.{}.{}'.format(mlversion[0],mlversion[2],mlversion[4]))

User                        : RSHAH
Role                        : "RAJIV"
Database                    : "RAJIV"
Schema                      : "PUBLIC"
Warehouse                   : "RAJIV"
Snowflake version           : 8.30.2
Snowpark for Python version : 1.20.0
Snowflake ML version        : 1.6.1


In [5]:
REGISTRY_DATABASE_NAME = "TPCDS_XGBOOST"
REGISTRY_SCHEMA_NAME = "DEMO"
reg = registry.Registry(session=session, database_name=REGISTRY_DATABASE_NAME, schema_name=REGISTRY_SCHEMA_NAME)

#### The dataset contains an epoch timestamp in milliseconds, a store ID which will later be used as a partition column, a feature column `COLLEGE_TOWN`, and a target to be forecasted, `HOURLY_TRAFFIC`.

In [63]:
# Load data from csv file into pandas dataframe.
test_df_pandas = pd.read_csv("Partitioned_Custom_Model_Restaurant_Traffic_Data.csv")
test_df = session.create_dataframe(test_df_pandas)
test_df.show()

--------------------------------------------------------------------
|"EPOCH"          |"STORE_ID"  |"COLLEGE_TOWN"  |"HOURLY_TRAFFIC"  |
--------------------------------------------------------------------
|1529154000000.0  |1.0         |1.0             |82                |
|1529182800000.0  |1.0         |1.0             |2                 |
|1529247600000.0  |1.0         |1.0             |35                |
|1529269200000.0  |1.0         |1.0             |9                 |
|1529326800000.0  |1.0         |1.0             |114               |
|1529514000000.0  |1.0         |1.0             |24                |
|1529697600000.0  |1.0         |1.0             |31                |
|1529424000000.0  |1.0         |1.0             |28                |
|1529575200000.0  |1.0         |1.0             |13                |
|1529931600000.0  |1.0         |1.0             |110               |
--------------------------------------------------------------------



In [9]:
test_df.write.mode('overwrite').save_as_table('TPCDS_XGBOOST.DEMO.Restaurant_Traffic_Data')

In [53]:
test_df = session.table('TPCDS_XGBOOST.DEMO.Restaurant_Traffic_Data')

Data set is 
5209585 rows with 200 unique store IDs

In [61]:
unique_store_count = test_df.select(test_df['STORE_ID']).distinct().count()
print(unique_store_count)

201


In [60]:
test_df.count()

5209585

In [64]:
class ForecastingModel(custom_model.CustomModel):

    # Use the same decorator as for methods with FUNCTION inference.
    @custom_model.partitioned_inference_api
    def predict(self, df: pd.DataFrame) -> pd.DataFrame:        
        import xgboost

        # Set the time column as our index.
        input_df = df.set_index('EPOCH')
        input_df.index = pd.to_datetime(df['EPOCH'], unit='ms')

        # Generate categorical features using the datetime index.
        input_df['HOUR'] = input_df.index.hour.astype("category")
        input_df['DAY_OF_WEEK'] = input_df.index.dayofweek.astype("category")
        input_df['MONTH'] = input_df.index.month.astype("category")
        input_df['YEAR'] = input_df.index.year.astype("category")
        
        input_df['COLLEGE_TOWN'] = input_df['COLLEGE_TOWN'].astype("category")
        
        # Use get_dummies (one-hot encoding) for categorical features.
        final = pd.get_dummies(data=input_df, columns=['COLLEGE_TOWN', 'HOUR', 'MONTH', 'YEAR', 'DAY_OF_WEEK'])

        # Define the train & forecast split thresholds.
        today = pd.to_datetime('2022-10-01')
        yesterday = today - timedelta(days=1)
        four_weeks = today + timedelta(days=28)
        tomorrow = today + timedelta(days=1)

        # Train data starts on June 16th 2018 and ends on September 30th.
        train = final[(final.index >= pd.to_datetime('16-Jun-2018')) & (final.index <= pd.to_datetime(yesterday))]
        
        # The forecast starts from October 1st 2022 and goes 4 weeks into the future.
        forecast = final[(final.index >= pd.to_datetime(tomorrow)) & (final.index <= pd.to_datetime(four_weeks))]

        # Remove the target from the input dataset, and construct target dataset.
        X_train = train.drop('HOURLY_TRAFFIC', axis=1)
        y_train = train['HOURLY_TRAFFIC']

        X_forecast = forecast.drop('HOURLY_TRAFFIC', axis=1)
        
        # Train an XGBoost regression model.
        model = xgboost.XGBRegressor(n_estimators=200, n_jobs=1)
        model.fit(X_train, y_train, verbose=False)

        # Predict the hourly forecast for the future dates and make sure no predictions are less than zero.
        forecast['PREDICTION'] = model.predict(X_forecast)
        forecast['EPOCH_OUT'] = [t.value // 10**9 for t in forecast.index]
        forecast = forecast[['EPOCH_OUT', 'PREDICTION']]
        forecast = forecast.sort_index()
        forecast.loc[forecast['PREDICTION'] < 0, 'PREDICTION'] = 0

        return forecast

In [66]:
my_forecasting_model = ForecastingModel()

#### The predict method can be tested locally by using a pandas dataframe directly. Here we can run `predict` for a single partition.

In [67]:
my_forecasting_model.predict(test_df_pandas.loc[test_df_pandas['STORE_ID'] == 1])
#my_forecasting_model.predict(test_df_pandas)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  forecast['PREDICTION'] = model.predict(X_forecast)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  forecast['EPOCH_OUT'] = [t.value // 10**9 for t in forecast.index]


Unnamed: 0_level_0,EPOCH_OUT,PREDICTION
EPOCH,Unnamed: 1_level_1,Unnamed: 2_level_1
2022-10-02 07:00:00,1664694000,77.787636
2022-10-02 08:00:00,1664697600,76.995056
2022-10-02 09:00:00,1664701200,76.751877
2022-10-02 10:00:00,1664704800,76.600456
2022-10-02 11:00:00,1664708400,96.709358
...,...,...
2022-10-28 18:00:00,1666980000,30.727842
2022-10-28 19:00:00,1666983600,30.608843
2022-10-28 20:00:00,1666987200,31.198336
2022-10-28 21:00:00,1666990800,4.420763


#### Log the model, specifying the `function_type: "TABLE_FUNCTION"` option.

In [8]:
options = {
    "function_type": "TABLE_FUNCTION",
}

mv = reg.log_model(
    my_forecasting_model,
    model_name="forecast",
    version_name="v13",
    conda_dependencies=["pandas", "scikit-learn", "xgboost"],
    options=options,
    signatures={
        "predict": model_signature.ModelSignature(
            inputs=[
                model_signature.FeatureSpec(name="EPOCH", dtype=model_signature.DataType.DOUBLE),
                model_signature.FeatureSpec(name="STORE_ID", dtype=model_signature.DataType.DOUBLE),
                model_signature.FeatureSpec(name="COLLEGE_TOWN", dtype=model_signature.DataType.DOUBLE),
                model_signature.FeatureSpec(name="HOURLY_TRAFFIC", dtype=model_signature.DataType.INT64),
            ],
            outputs=[
                model_signature.FeatureSpec(name="EPOCH_OUT", dtype=model_signature.DataType.FLOAT),
                model_signature.FeatureSpec(name="PREDICTION", dtype=model_signature.DataType.FLOAT),
            ],
        )
    },
)

  return next(self.gen)


#### Use the `run` method for inference, specifying the partition column.

In [56]:
mv = reg.get_model("forecast").version("v13")

In [57]:
snowpark_opt_wh = Warehouse(
  name="snowpark_opt_wh",
  warehouse_size="LARGE",
  warehouse_type = "SNOWPARK-OPTIMIZED",
  auto_suspend=600,
)
warehouses = root.warehouses["snowpark_opt_wh"]
warehouses.create_or_alter(snowpark_opt_wh)

session.sql('USE WAREHOUSE SNOWPARK_OPT_WH').collect()
session.sql('alter session set USE_CACHED_RESULT = FALSE').collect()
session.sql('alter session set query_tag = "TS_XG_LARGE" ').collect()
print(session.get_current_warehouse())

"SNOWPARK_OPT_WH"


In [58]:
result = mv.run(test_df, partition_column="STORE_ID")
result.select("EPOCH_OUT", "PREDICTION", "STORE_ID").to_pandas()

Unnamed: 0,EPOCH_OUT,PREDICTION,STORE_ID
0,1.664694e+09,74.812241,7.0
1,1.664698e+09,77.070419,7.0
2,1.664701e+09,75.961517,7.0
3,1.664705e+09,75.930870,7.0
4,1.664708e+09,96.267212,7.0
...,...,...,...
86395,1.666980e+09,31.088211,78.0
86396,1.666984e+09,30.569330,78.0
86397,1.666987e+09,32.221611,78.0
86398,1.666991e+09,6.004010,78.0


In [None]:
## Optimized S it ran in 24 seconds 
## Optimized M is ran in 21 seconds 


Raj test:
## Optimized L ran in 15 seconds

## Local test - one thread it ran in 1 minute 40 seconds


### StatsForecast Arima Model on Generated Data

In [23]:
#Generate Series - Takes 2 minutes to run
#Only need to run this the first time
from statsforecast.utils import generate_series
for length in [10_000, 100_000, 500_000, 1_000_000, 2_000_000]:
		print(f'length: {length}')
		series = generate_series(n_series=length, seed=1)

series

length: 10000
length: 100000
length: 500000
length: 1000000
length: 2000000


In [7]:
## Save series to Snowflake table
#Only need to run this the first time
df = pd.DataFrame(series)
df_reset = df.reset_index()
df_reset.columns = ['ID', 'DS', 'Y']

test_df = session.create_dataframe(df_reset)
test_df.write.mode('overwrite').save_as_table('TPCDS_XGBOOST.DEMO.Series2M')
train_df = session.table('TPCDS_XGBOOST.DEMO.SERIES2M')

In [6]:
#Retrieve from Snowflake -- 
train_df = session.table('TPCDS_XGBOOST.DEMO.SERIES2M')
train_df.show()

--------------------------------------------------------
|"ID"    |"DS"                 |"Y"                    |
--------------------------------------------------------
|181695  |2000-02-22 00:00:00  |6.219272538160337      |
|181695  |2000-02-23 00:00:00  |0.3076429294607981     |
|181695  |2000-02-24 00:00:00  |1.197810254827208      |
|181695  |2000-02-25 00:00:00  |2.173458515198763      |
|181695  |2000-02-26 00:00:00  |3.102199405394565      |
|181695  |2000-02-27 00:00:00  |4.376139372280642      |
|181695  |2000-02-28 00:00:00  |5.375742028359614      |
|181695  |2000-02-29 00:00:00  |6.147630148293396      |
|181695  |2000-03-01 00:00:00  |0.0025383417716690615  |
|181695  |2000-03-02 00:00:00  |1.0790424184236609     |
--------------------------------------------------------



In [31]:
### Local Test for Model
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA, Naive

model = StatsForecast(models=[AutoARIMA(), Naive()],
                      freq='D',
                      n_jobs=-1)
df_reset.columns = ['unique_id', 'ds', 'y']
forecasts_df = model.forecast(df=df_reset, h=7)
forecasts_df.head()

  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')


Unnamed: 0_level_0,ds,AutoARIMA,Naive
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,2000-03-28,1.626143,2.053747
0,2000-03-29,1.287569,2.053747
0,2000-03-30,1.019489,2.053747
0,2000-03-31,0.807224,2.053747
0,2000-04-01,0.639155,2.053747


In [11]:
class ForecastingModel(custom_model.CustomModel):

    # Use the same decorator as for methods with FUNCTION inference.
    @custom_model.partitioned_inference_api
    def predict(self, df: pd.DataFrame) -> pd.DataFrame:        
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, Naive
        df.columns = ['unique_id', 'ds', 'y']
        model = StatsForecast(models=[AutoARIMA()],
                      freq='D',
                      n_jobs=-1)

        forecasts_df = model.forecast(df=df, h=7)
        forecasts_df.columns = ['DSOUT', 'AUTOARIMA']

        return forecasts_df

In [12]:
my_forecasting_model = ForecastingModel()

In [62]:
my_forecasting_model.predict(df_reset)

  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')
  multiarray.copyto(a, fill_value, casting='unsafe')


Unnamed: 0_level_0,DSOUT,AUTOARIMA
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,2000-03-28,1.626143
0,2000-03-29,1.287569
0,2000-03-30,1.019489
0,2000-03-31,0.807224
0,2000-04-01,0.639155
...,...,...
99,2000-08-30,3.246663
99,2000-08-31,2.810550
99,2000-09-01,2.433019
99,2000-09-02,2.106200


In [58]:
options = {
    "function_type": "TABLE_FUNCTION",
}

mv = reg.log_model(
    my_forecasting_model,
    model_name="statsforecast",
    version_name="v8",
    conda_dependencies=["pandas", "statsforecast"],
    options=options,
    signatures={
        "predict": model_signature.ModelSignature(
            inputs=[
                model_signature.FeatureSpec(name="ID", dtype=model_signature.DataType.INT64),
                model_signature.FeatureSpec(name="DS", dtype=model_signature.DataType.TIMESTAMP_NTZ),
                model_signature.FeatureSpec(name="Y", dtype=model_signature.DataType.DOUBLE),
            ],
            outputs=[
               # model_signature.FeatureSpec(name="ID", dtype=model_signature.DataType.INT64),
                model_signature.FeatureSpec(name="DSOUT", dtype=model_signature.DataType.TIMESTAMP_NTZ),
                model_signature.FeatureSpec(name="AUTOARIMA", dtype=model_signature.DataType.FLOAT),
            ],
        )
    },
)

  return next(self.gen)


In [7]:
reg_model = reg.get_model("statsforecast").version("v8")

In [14]:
snowpark_opt_wh = Warehouse(
  name="snowpark_opt_wh",
  warehouse_size="LARGE",
  #warehouse_type = "SNOWPARK-OPTIMIZED",
  auto_suspend=600,
)
warehouses = root.warehouses["snowpark_opt_wh"]
warehouses.create_or_alter(snowpark_opt_wh)
session.use_warehouse("snowpark_opt_wh")

session.sql('alter session set USE_CACHED_RESULT = FALSE').collect()
session.sql('alter session set query_tag = "TS-LARGE" ').collect()

print(session.get_current_warehouse())

"SNOWPARK_OPT_WH"


In [9]:
lengths = [10_000, 50_000, 100_000, 500_000, 1_000_000,2_000_000]
#lengths = [1_000_000]

for length in lengths:
  unique_ids_df = train_df.select("ID").distinct().limit(length)
  filtered_df = train_df.join(unique_ids_df, on="ID", how="inner")
  print(unique_ids_df.count())
  init = time()
  # Run the regression model
  result = reg_model.run(filtered_df, partition_column="ID").collect()
  total_time = (time() - init) / 60
  print(f'n_series: {length} total time: {total_time} total rows: {filtered_df.count()}')

10000
n_series: 10000 total time: 72.435857351621 total rows: 2740194
50000
n_series: 50000 total time: 86.72789065043132 total rows: 13750257
100000
n_series: 100000 total time: 44.73047570387522 total rows: 27524168
500000
n_series: 500000 total time: 144.72759178082148 total rows: 137357685
1000000
n_series: 1000000 total time: 267.6801359653473 total rows: 274943496
2000000
n_series: 2000000 total time: 358.4139559308688 total rows: 549884998
