In [None]:
# Snowflake database and schema that hold the sample DATA table and the model registry.
DATABASE = 'EM_TESTS'
SCHEMA = 'PARTITIONED_MODELS'

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.mixture import GaussianMixture
import snowflake.snowpark.functions as F

from snowflake.ml.model import custom_model
from snowflake.ml.registry import Registry
from snowflake.ml.model.model_signature import ModelSignature, FeatureSpec, DataType


# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
# Snowpark session used below to upload the sample data and to open the model registry.
session = get_active_session()

In [None]:
n_rows = 2000
# Fixed seed so that "Restart Kernel -> Run All" regenerates exactly the same sample
# (and therefore the same fitted mixture parameters further down).
RANDOM_SEED = 0
rng = np.random.default_rng(RANDOM_SEED)

# Create the dataframe of data: 2 categories, 2 values
df = pd.DataFrame({
    'CURRENCY': rng.choice(['USD', 'EUR', 'CAN'], size=n_rows),
    'COMPANY': rng.choice(['COMP_A', 'COMP_B', 'COMP_C', 'COMP_D'], size=n_rows),
    'FOO': rng.normal(size=n_rows),
    # Binomial draws are integers; the column is kept as float, cast once at construction
    # instead of mutating the frame afterwards.
    'BAR': rng.binomial(n=5, p=.5, size=n_rows).astype(float),
})
df.head()

In [None]:
# Upload the local sample into Snowpark and persist it as <DATABASE>.<SCHEMA>.DATA.
# "overwrite" mode replaces the table on every run, so this cell can be re-executed safely.
data_table_name = f"{DATABASE}.{SCHEMA}.DATA"
snowpark_df = session.create_dataframe(df)
table_writer = snowpark_df.write.mode("overwrite")
table_writer.save_as_table(data_table_name)

In [None]:
class ExamplePartitionedModel(custom_model.CustomModel):
    """Custom model that summarises the frame it receives with a 2-component Gaussian mixture."""

    @custom_model.partitioned_inference_api
    def predict(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Fit a Gaussian mixture on the `__TARGET__` column of `df`.

        Returns a single-row frame with two list-valued columns:
        MEANS (flattened component means) and COVARIANCES (flattened component covariances).
        """
        # Imported inside the method, as in the original
        # (presumably so the name resolves wherever the logged model executes - TODO confirm).
        from sklearn.mixture import GaussianMixture

        # random_state is pinned so the same input always yields the same fit.
        mixture = GaussianMixture(n_components=2, random_state=0)
        mixture.fit(df[['__TARGET__']])

        summary_row = {
            'MEANS': list(mixture.means_.flatten()),
            'COVARIANCES': list(mixture.covariances_.flatten()),
        }
        return pd.DataFrame([summary_row])

my_model = ExamplePartitionedModel()

In [None]:
# Test that it works: run the model locally on the USD rows only, with FOO as the target.
usd_sample = (
    snowpark_df
    .rename(F.col("FOO"), "__TARGET__")
    .filter(F.col("CURRENCY") == 'USD')
    .to_pandas()
)
preds = my_model.predict(usd_sample)
preds

In [None]:
reg = Registry(session=session, database_name=DATABASE, schema_name=SCHEMA)
# Uncomment to clear the model for rapid iteration:
#reg.delete_model("my_model")

# Declared signature of `predict`: a string PARTITION key plus the float __TARGET__ go in,
# the fitted MEANS and COVARIANCES come out.
predict_signature = ModelSignature(
    inputs=[
        FeatureSpec(dtype=DataType.STRING, name='PARTITION'),
        FeatureSpec(dtype=DataType.FLOAT, name='__TARGET__'),
    ],
    outputs=[
        FeatureSpec(dtype=DataType.FLOAT, name='MEANS', shape=(1, 2)),
        FeatureSpec(dtype=DataType.FLOAT, name='COVARIANCES', shape=(1, 2)),
    ],
)

# Log the model as a table function so it can be run with a partition column.
model_version = reg.log_model(
    my_model,
    model_name="my_model",
    version_name="v1",
    options={"function_type": "TABLE_FUNCTION"},
    conda_dependencies=["pandas", "scikit-learn"],
    signatures={"predict": predict_signature},
)

In [None]:
# Cumbersome way of dealing with column renames:
# the model only knows PARTITION and __TARGET__, so rename on the way in and back on the way out.
model_input = (
    snowpark_df
    .rename(F.col("FOO"), "__TARGET__")
    .rename(F.col("COMPANY"), "PARTITION")
    .select("PARTITION", "__TARGET__")
)
result = model_version.run(model_input, partition_column='PARTITION')
result.rename(F.col("PARTITION"), "COMPANY")

In [None]:
def split_cols(snowpark_df, column_to_split, output_column_names, delimiter='|'):
    """
    Snowpark operations to extract the components in a combined string.

    Splits `column_to_split` on `delimiter` and writes the i-th component into the i-th
    name of `output_column_names`, so the names must be listed in the same order as the
    components appear in the combined string.

    Returns the dataframe with the new columns added and `column_to_split` (plus the
    temporary split column) dropped.
    """
    split_column = "___SPLIT"
    snowpark_df = snowpark_df.with_column(
        split_column, F.split(F.col(column_to_split), F.lit(delimiter))
    )
    for position, output_name in enumerate(output_column_names):
        component = F.get(F.col(split_column), F.lit(position))
        # Strip the surrounding double quotes left on the extracted element.
        # Raw strings are required here: a plain '\1' is the control character \x01 in
        # Python, not the regex back-reference to the first capture group.
        snowpark_df = snowpark_df.with_column(
            output_name, F.regexp_replace(component, r'^"(.*)"$', r'\1')
        )
    return snowpark_df.drop(split_column, column_to_split)

In [None]:
def call_model(model_version, snowpark_df, group_columns, target_column):
    """
    Run the partitioned model once per group of `snowpark_df` and return the results.

    model_version: object whose `run(dataframe, partition_column=...)` performs inference.
    snowpark_df: Snowpark dataframe containing the group column(s) and the target column.
    group_columns: a single column name (str), or a non-empty list/tuple of column names
        forming a composite group. Composite groups are concatenated into one
        "|"-delimited PARTITION string before inference and split back into the original
        column names afterwards, so group values must not themselves contain "|".
    target_column: column renamed to `__TARGET__` before the model is called.

    Raises TypeError if `group_columns` is neither a str nor a list/tuple, and ValueError
    if it is an empty list/tuple.
    """
    delimiter = "|"  # single definition shared by the concat and the split below

    # Validate up front: previously an empty sequence died with an IndexError from pop(),
    # and any other type silently skipped creating PARTITION and failed later in run().
    is_single_column = isinstance(group_columns, str)
    if not is_single_column and not isinstance(group_columns, (list, tuple)):
        raise TypeError(
            "group_columns must be a column name or a list/tuple of column names, "
            f"got {type(group_columns).__name__}"
        )
    if not is_single_column and len(group_columns) == 0:
        raise ValueError("group_columns must contain at least one column name")

    # Prepare the grouping column
    if is_single_column:
        snowpark_df = snowpark_df.rename(F.col(group_columns), "PARTITION")
    else:
        # Interleave the group columns with the delimiter: col, "|", col, "|", col ...
        concat_columns = [F.col(group_columns[0])]
        for group_column in group_columns[1:]:
            concat_columns.append(F.lit(delimiter))
            concat_columns.append(F.col(group_column))
        snowpark_df = snowpark_df.with_column("PARTITION", F.concat(*concat_columns))

    # Prepare the target column
    snowpark_df = snowpark_df.rename(F.col(target_column), '__TARGET__')

    # Call the model predictions
    result_df = model_version.run(snowpark_df, partition_column='PARTITION')

    # Prepare the grouping column for output
    if is_single_column:
        result_df = result_df.rename(F.col("PARTITION"), group_columns)
    else:
        result_df = split_cols(result_df, 'PARTITION', group_columns, delimiter=delimiter)

    return result_df.drop("__TARGET__")

# Calling a model with a single column group

In [None]:
# Group by CURRENCY alone (str form of group_columns), using FOO as the target.
call_model(model_version, snowpark_df, group_columns='CURRENCY', target_column='FOO')

# Calling a model with a composite group

In [None]:
# Group by each (COMPANY, CURRENCY) pair (tuple form of group_columns), using FOO as the target.
call_model(model_version, snowpark_df, group_columns=('COMPANY', 'CURRENCY'), target_column='FOO')