In [None]:
# --- Notebook configuration: every value a reader is expected to tune lives here. ---

# Snowflake database that is created (if missing) and used for all objects below.
DATABASE = 'test_numpyro'
# Schema inside DATABASE that is created (if missing) and used for all objects below.
SCHEMA = 'numpyro'
# Warehouse the session switches to before running any work.
WAREHOUSE = 'FD_WH'
# Number of shuffled copies of the base dataset to generate; each copy gets its own TS_ID
# and is later processed as one UDTF partition.
NUM_TIME_SERIES = 10
# Name under which the vectorized UDTF is registered.
UDTF_NAME = 'waffle_divorce_vect_udtf'
# Stage (with leading '@') used as the UDTF's stage_location; the '@' is stripped when the
# stage itself is created.
MODEL_STAGE_NAME = '@pymodels'
# Named connection passed to the Snowpark session builder as "connection_name".
SNOWPARK_CONNECTION_NAME = 'personal_sandbox'
# Default lower bound, in seconds, on the duration of a single forecast_function call
# (the function sleeps to pad shorter runs; used for benchmarking).
SINGLE_MODEL_MIN_RUN_TIME_SECONDS = 10
# Table the generated series are written to and the UDTF reads from.
INPUT_TABLE_NAME = 'SERIES_WAFFLE_DIVORCE'
# Table the UDTF results are written to.
OUTPUT_TABLE_NAME = 'DATA_WITH_PREDICTIONS'

In [93]:
import numpy as np
import pandas as pd
import numpyro
import tqdm

from snowflake.snowpark.session import Session
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import types as T
from snowflake.snowpark import functions as F

In [94]:
# Open a Snowpark session from the named connection configured locally.
session = (
    Session.builder
    .config("connection_name", SNOWPARK_CONNECTION_NAME)
    .create()
)

In [95]:
# Make sure the working database exists, then make it the session's current database.
create_database_sql = f'create database if not exists {DATABASE}'
session.sql(create_database_sql).collect()
session.use_database(DATABASE)

# Same for the schema (created inside the database selected above).
create_schema_sql = f'create schema if not exists {SCHEMA}'
session.sql(create_schema_sql).collect()
session.use_schema(SCHEMA)


In [96]:
session.use_warehouse(WAREHOUSE)

# Add stage for UDFs and Stored Procs.
# MODEL_STAGE_NAME carries a leading '@' (needed when it is used as a stage location),
# which must be removed to obtain the identifier for CREATE STAGE.
stage_identifier = MODEL_STAGE_NAME.replace('@', '')
create_stage_sql = f"""
    create stage if not exists {stage_identifier}
    """
session.sql(create_stage_sql).collect()

[Row(status='PYMODELS already exists, statement succeeded.')]

In [97]:
# Load the base dataset (semicolon-separated) and upper-case the column names
# to avoid case changes with Snowflake.
df = pd.read_csv('WaffleDivorce.csv', sep=';').rename(columns=str.upper)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 50 entries, 0 to 49
Data columns (total 13 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   LOCATION           50 non-null     object 
 1   LOC                50 non-null     object 
 2   POPULATION         50 non-null     float64
 3   MEDIANAGEMARRIAGE  50 non-null     float64
 4   MARRIAGE           50 non-null     float64
 5   MARRIAGE SE        50 non-null     float64
 6   DIVORCE            50 non-null     float64
 7   DIVORCE SE         50 non-null     float64
 8   WAFFLEHOUSES       50 non-null     int64  
 9   SOUTH              50 non-null     int64  
 10  SLAVES1860         50 non-null     int64  
 11  POPULATION1860     50 non-null     int64  
 12  PROPSLAVES1860     50 non-null     float64
dtypes: float64(7), int64(4), object(2)
memory usage: 5.2+ KB


In [98]:
# Preview only the first rows instead of rendering the whole frame.
df.head(10)

Unnamed: 0,LOCATION,LOC,POPULATION,MEDIANAGEMARRIAGE,MARRIAGE,MARRIAGE SE,DIVORCE,DIVORCE SE,WAFFLEHOUSES,SOUTH,SLAVES1860,POPULATION1860,PROPSLAVES1860
0,Alabama,AL,4.78,25.3,20.2,1.27,12.7,0.79,128,1,435080,964201,0.45
1,Alaska,AK,0.71,25.2,26.0,2.93,12.5,2.05,0,0,0,0,0.0
2,Arizona,AZ,6.33,25.8,20.3,0.98,10.8,0.74,18,0,0,0,0.0
3,Arkansas,AR,2.92,24.3,26.4,1.7,13.5,1.22,41,1,111115,435450,0.26
4,California,CA,37.25,26.8,19.1,0.39,8.0,0.24,0,0,0,379994,0.0
5,Colorado,CO,5.03,25.7,23.5,1.24,11.6,0.94,11,0,0,34277,0.0
6,Connecticut,CT,3.57,27.6,17.1,1.06,6.7,0.77,0,0,0,460147,0.0
7,Delaware,DE,0.9,26.6,23.1,2.89,8.9,1.39,3,0,1798,112216,0.016
8,District of Columbia,DC,0.6,29.7,17.7,2.53,6.3,1.89,0,0,0,75080,0.0
9,Florida,FL,18.8,26.4,17.0,0.58,8.5,0.32,133,1,61745,140424,0.44


In [99]:
# Create the multiple series
# Each series is a copy of the base data tagged with its own TS_ID, with the modelled
# columns independently permuted so that every series yields a different fit.
# A seeded generator keeps the generated data identical across Restart & Run All.
SERIES_SHUFFLE_SEED = 0
shuffle_rng = np.random.default_rng(SERIES_SHUFFLE_SEED)

cols_to_shuffle = ['MARRIAGE', 'DIVORCE', 'MEDIANAGEMARRIAGE']
series_frames = []
for i in range(NUM_TIME_SERIES):
    d = df.copy()
    d.insert(0, 'TS_ID', i)
    for col in cols_to_shuffle:
        # Full-length permutation == sampling every value once without replacement.
        d[col] = shuffle_rng.permutation(d[col].to_numpy())
    series_frames.append(d)

# ignore_index gives the combined frame a clean 0..N-1 index; otherwise each copy would
# contribute the same repeated row labels.
dfs = pd.concat(series_frames, ignore_index=True)
dfs.TS_ID.nunique()


1000

In [100]:
# Write it to a table
# Upload the generated series as a permanent table (the default table type), replacing
# any previous contents, so the UDTF can read it back by name.
sf_df = session.create_dataframe(dfs)
sf_df.write.save_as_table(f'{DATABASE}.{SCHEMA}.{INPUT_TABLE_NAME}', mode="overwrite")

  success, _, _, ci_output = write_pandas(


In [101]:
# This will have to be customized for your needs
def forecast_function(df:pd.DataFrame, min_seconds=SINGLE_MODEL_MIN_RUN_TIME_SECONDS) -> pd.DataFrame:
    """
    Fit a Bayesian linear regression of divorce rate on marriage rate with NumPyro
    and attach the posterior-predictive mean to every input row.

    Follows the NumPyro 0.15.3 tutorial:
    https://num.pyro.ai/en/0.15.3/tutorials/bayesian_regression.html

    Parameters
    ----------
    df : pd.DataFrame
        Data for one series. Must contain the columns MEDIANAGEMARRIAGE, MARRIAGE
        and DIVORCE. The caller's frame is not modified (a copy is used).
    min_seconds : float
        Minimum duration of the call in seconds. If fitting finishes sooner, the
        function sleeps for the remainder. For benchmarking purposes only; the
        default is SINGLE_MODEL_MIN_RUN_TIME_SECONDS.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` with four extra columns: AGESCALED, MARRIAGESCALED and
        DIVORCESCALED (standardized inputs) and MEAN_PREDICTIONS (mean over posterior
        predictive draws of the standardized divorce rate, one value per row).
    """
    import time

    # Monotonic clock: elapsed-time measurement must not be affected by wall-clock jumps.
    start_time = time.monotonic()

    # Imports are local so the function is self-contained when it runs inside the UDTF.
    from jax import random
    import jax.numpy as jnp

    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS, Predictive

    def standardize(x):
        """Center to zero mean and scale to unit standard deviation."""
        return (x - x.mean()) / x.std()

    def model(marriage=None, age=None, divorce=None):
        """Linear model for standardized divorce rate; predictors left as None are omitted."""
        a = numpyro.sample("a", dist.Normal(0.0, 0.2))
        M, A = 0.0, 0.0
        if marriage is not None:
            bM = numpyro.sample("bM", dist.Normal(0.0, 0.5))
            M = bM * marriage
        if age is not None:
            bA = numpyro.sample("bA", dist.Normal(0.0, 0.5))
            A = bA * age
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        mu = a + M + A
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=divorce)

    # Work on a copy so the caller's frame is left untouched.
    df = df.copy()

    df["AGESCALED"] = df.MEDIANAGEMARRIAGE.pipe(standardize)
    df["MARRIAGESCALED"] = df.MARRIAGE.pipe(standardize)
    df["DIVORCESCALED"] = df.DIVORCE.pipe(standardize)

    # Start from this source of randomness. We will split keys for subsequent operations.
    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)

    # Run NUTS. Only the marriage-rate predictor is passed, so the age term is not fitted.
    kernel = NUTS(model)
    num_samples = 2000
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
    mcmc.run(
        rng_key_, marriage=df.MARRIAGESCALED.values, divorce=df.DIVORCESCALED.values
    )
    samples_1 = mcmc.get_samples()

    # Posterior predictive draws for the observed predictor values; averaging over the
    # sample axis leaves one mean prediction per input row.
    rng_key, rng_key_ = random.split(rng_key)
    predictive = Predictive(model, samples_1)
    predictions = predictive(rng_key_, marriage=df.MARRIAGESCALED.values)["obs"]
    df["MEAN_PREDICTIONS"] = jnp.mean(predictions, axis=0)

    # Pad the call up to the requested minimum duration (benchmarking aid).
    run_time = time.monotonic() - start_time
    if run_time < min_seconds:
        time.sleep(min_seconds - run_time)
    return df

In [102]:
%%time
# Getting a baseline local single run on the un-shuffled base data.
# The measured wall time is at least SINGLE_MODEL_MIN_RUN_TIME_SECONDS because
# forecast_function sleeps to pad shorter runs.
predictions = forecast_function(df)

sample: 100%|██████████| 3000/3000 [00:00<00:00, 4244.91it/s, 3 steps of size 6.91e-01. acc. prob=0.93]


CPU times: user 1.25 s, sys: 96.5 ms, total: 1.35 s
Wall time: 10 s


In [103]:
## Settings for the input, output and UDTF.
# The UDTF input mirrors the input table's schema exactly.
input_df = session.table(INPUT_TABLE_NAME)
input_col_names = input_df.columns
input_dtypes = [schema_field.datatype for schema_field in input_df.schema.fields]
vect_udtf_input_dtypes = [T.PandasDataFrameType(input_dtypes)]

# The UDTF output is the input columns followed by the float columns that
# forecast_function appends.
forecast_col_names = ["AGESCALED", "MARRIAGESCALED", "DIVORCESCALED", "MEAN_PREDICTIONS"]
forecast_col_dtypes = [T.FloatType() for _ in forecast_col_names]
vect_udtf_output_schema = T.PandasDataFrameType(
    input_dtypes + forecast_col_dtypes,
    input_col_names + forecast_col_names,
)

In [104]:
%%time
# Defining the UDTF, should not necessarily need to be edited
@F.udtf(
    name=UDTF_NAME,
    input_types=vect_udtf_input_dtypes,
    output_schema=vect_udtf_output_schema,
    packages=["pandas", "jax", "numpyro"],
    is_permanent=True,
    replace=True,
    stage_location=MODEL_STAGE_NAME,
    session=session,
)
class Forecast:
    """Vectorized UDTF handler that runs forecast_function once per partition."""

    def end_partition(self, df):
        """Yield the forecast DataFrame for the partition's rows passed in as ``df``."""
        # NOTE: In Vectorized udtf you have to put the column names back into the df
        df.columns = input_col_names
        yield forecast_function(df)

CPU times: user 41.2 ms, sys: 15.5 ms, total: 56.7 ms
Wall time: 4.38 s


In [105]:
%%time
# Call the UDTF
# Build the (lazy) query: every input column is passed to the UDTF, one partition per
# TS_ID, ordered by LOCATION within each partition.
numpyro_test = F.table_function(f"{DATABASE}.{SCHEMA}.{UDTF_NAME}")
udtf_args = [F.col(col_nm) for col_nm in input_col_names]
model_vect_udtf = input_df.select(
    numpyro_test(*udtf_args).over(partition_by=["TS_ID"], order_by=["LOCATION"])
)

CPU times: user 22.4 ms, sys: 5.01 ms, total: 27.4 ms
Wall time: 346 ms


In [106]:
%%time
# Write to table
# Saving the result is what actually executes the UDTF query built above.
output_writer = model_vect_udtf.write
output_writer.save_as_table(OUTPUT_TABLE_NAME, mode="overwrite")

CPU times: user 33.6 ms, sys: 8.12 ms, total: 41.7 ms
Wall time: 4min 16s
