In [1]:
# Snowpark
import snowflake.snowpark
import snowflake.snowpark.functions as F
from snowflake.snowpark.functions import sproc, udf, udtf, col, call_table_function
from snowflake.snowpark.session import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.window import Window
import json

import pandas as pd
import datetime
from prophet import Prophet
from prophet.plot import plot_plotly

with open('creds.json') as f:
    connection_parameters = json.load(f)


In [2]:

session = Session.builder.configs(connection_parameters).create()
session.add_packages('snowflake-snowpark-python', 'scikit-learn', 'pandas', 'numpy', 'joblib', 'cachetools')

The version of package cachetools in the local environment is 5.2.0, which does not fit the criteria for the requirement cachetools. Your UDF might not work when the package version is different between the server and your local environment


In [3]:
session.sql("CREATE OR REPLACE STAGE MODEL").collect()

[Row(status='Stage area MODEL successfully created.')]

### Explore Corn Price history

In [4]:
input_table_name = 'milk_price_daily'
price_df = session.table(input_table_name)

In [5]:
price_df.count()

2299

In [6]:
price_df.describe().to_pandas()

Unnamed: 0,SUMMARY,COMMODITY,COMMODITY_NAME,SYMBOL,CLOSE,HIGH,LOW
0,mean,,,,17.556955,17.606416,17.606416
1,min,KN.AGR5,Class III Milk Futures,DCSc1,11.23,11.45,11.45
2,stddev,,,,3.122225,3.127184,3.127184
3,count,2299,2299,2299,2299.0,2299.0,2299.0
4,max,KN.AGR5,Class III Milk Futures,DCSc1,25.2,25.2,25.2


In [7]:
price_df.limit(100).to_pandas()

Unnamed: 0,COMMODITY,COMMODITY_NAME,SYMBOL,DATE,CLOSE,HIGH,LOW
0,KN.AGR5,Class III Milk Futures,DCSc1,2013-10-22,18.25,18.25,18.25
1,KN.AGR5,Class III Milk Futures,DCSc1,2013-10-23,18.24,18.24,18.24
2,KN.AGR5,Class III Milk Futures,DCSc1,2013-10-24,18.25,18.27,18.27
3,KN.AGR5,Class III Milk Futures,DCSc1,2013-10-25,18.25,18.25,18.25
4,KN.AGR5,Class III Milk Futures,DCSc1,2013-10-28,18.25,18.26,18.26
...,...,...,...,...,...,...,...
95,KN.AGR5,Class III Milk Futures,DCSc1,2014-03-12,22.91,23.30,23.30
96,KN.AGR5,Class III Milk Futures,DCSc1,2014-03-13,23.17,23.23,23.23
97,KN.AGR5,Class III Milk Futures,DCSc1,2014-03-14,23.29,23.30,23.30
98,KN.AGR5,Class III Milk Futures,DCSc1,2014-03-17,23.49,23.61,23.61


### Train Prophet forecasting model

In [8]:
# get data to pandas
train_df = price_df.select('DATE','CLOSE').to_pandas()
train_df.columns = ['ds', 'y']

In [9]:
m = Prophet()
m.fit(train_df)


INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)


Initial log joint probability = -25.5017
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
      99       4670.18     0.0803523       507.149           1           1      111   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     199       4704.16    0.00260017       99.2784      0.7898      0.7898      223   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     299       4715.53    0.00171605       220.268           1           1      356   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     371       4723.18   0.000582173       114.221   6.119e-06       0.001      477  LS failed, Hessian reset 
     399       4727.25    0.00216777       60.4951           1           1      516   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     499       4728.86   7.06144e-05    

<prophet.forecaster.Prophet at 0x7fc7c1e22d00>

In [10]:
future = m.make_future_dataframe(periods=365)
forecast = m.predict(future)

  components = components.append(new_comp)
  components = components.append(new_comp)


In [11]:
forecast.loc[forecast['ds'] > datetime.datetime(2021, 1, 1, 0, 0, 0)]

Unnamed: 0,ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,weekly,weekly_lower,weekly_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
1809,2021-01-04,17.255429,14.546803,18.993261,17.255429,17.255429,-0.379449,-0.379449,-0.379449,0.454796,0.454796,0.454796,-0.834245,-0.834245,-0.834245,0.0,0.0,0.0,16.875980
1810,2021-01-05,17.253715,14.517001,19.089053,17.253715,17.253715,-0.391578,-0.391578,-0.391578,0.460926,0.460926,0.460926,-0.852504,-0.852504,-0.852504,0.0,0.0,0.0,16.862137
1811,2021-01-06,17.252000,14.706706,19.182731,17.252000,17.252000,-0.426407,-0.426407,-0.426407,0.444475,0.444475,0.444475,-0.870881,-0.870881,-0.870881,0.0,0.0,0.0,16.825593
1812,2021-01-07,17.250285,14.660693,19.014572,17.250285,17.250285,-0.433322,-0.433322,-0.433322,0.455632,0.455632,0.455632,-0.888954,-0.888954,-0.888954,0.0,0.0,0.0,16.816963
1813,2021-01-08,17.248571,14.520337,19.006922,17.248571,17.248571,-0.449281,-0.449281,-0.449281,0.457020,0.457020,0.457020,-0.906301,-0.906301,-0.906301,0.0,0.0,0.0,16.799289
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2659,2023-12-08,24.122884,21.178649,27.150131,22.417363,25.859796,0.022328,0.022328,0.022328,0.457020,0.457020,0.457020,-0.434692,-0.434692,-0.434692,0.0,0.0,0.0,24.145212
2660,2023-12-09,24.129656,19.437141,25.406248,22.419214,25.875045,-1.625734,-1.625734,-1.625734,-1.136424,-1.136424,-1.136424,-0.489310,-0.489310,-0.489310,0.0,0.0,0.0,22.503922
2661,2023-12-10,24.136428,19.532361,25.420449,22.418979,25.890295,-1.674262,-1.674262,-1.674262,-1.136424,-1.136424,-1.136424,-0.537838,-0.537838,-0.537838,0.0,0.0,0.0,22.462167
2662,2023-12-11,24.143200,20.994815,26.810953,22.418483,25.907085,-0.125399,-0.125399,-0.125399,0.454796,0.454796,0.454796,-0.580195,-0.580195,-0.580195,0.0,0.0,0.0,24.017801


In [12]:
plot_plotly(m, forecast)

In [13]:
class Forecast: 
      def __init__(self):
            self.n_periods = None
            
      def process(self, n_periods:int):
            future = m.make_future_dataframe(periods=n_periods)
            forecast = m.predict(future)
            
            forecast = forecast[['ds','yhat', 'yhat_upper', 'yhat_lower']]

            for row in forecast.itertuples():
                  yield (row[1].date(), row[2], row[3], row[4])
                  
forecast_udtf = udtf(
      Forecast, 
      name='forecast', 
      input_types=[T.IntegerType()], 
      output_schema=T.StructType([T.StructField('date', T.DateType()), 
                              T.StructField('yhat', T.FloatType()),
                              T.StructField('yhat_upper', T.FloatType()),
                              T.StructField('yhat_lower', T.FloatType())]),
      is_permanent=True,
      packages=['prophet','pandas'],
      stage_location='@model',
      replace=True
)

INFO:snowflake.connector.cursor:query: [ls '@model']
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01a8fe54-0000-e13e-00...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT  *  FROM ( SELECT  *  FROM (information_schema.packages)) WHERE (("LANGUA...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "PACKAGE_NAME", array_agg("VERSION") AS "ARRAY_AGG(VERSION)" FROM ( SELEC...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [PUT 'file:///tmp/placeholder/udf_py_1450781152.zip' '@model/forecast' PARALLEL =...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [CREATE OR REPLACE FUNCTION forecast(arg1 INT) RETURNS TABLE (date DATE,yhat FLOA...]
INFO:snowflake.connector.cursor:query execution done


In [14]:
session.table_function(forecast_udtf(F.lit(30))).filter((F.col("DATE") > F.current_date())).to_pandas()

INFO:snowflake.connector.cursor:query: [SELECT  *  FROM ( SELECT  *  FROM ( TABLE (forecast(30 :: INT) ))) WHERE ("DATE"...]
INFO:snowflake.connector.cursor:query execution done


Unnamed: 0,DATE,YHAT,YHAT_UPPER,YHAT_LOWER
0,2022-12-16,21.453846,23.67972,19.144364
1,2022-12-17,19.857643,22.215691,17.524976
2,2022-12-18,19.858294,22.202398,17.691725
3,2022-12-19,21.452913,23.798225,19.143648
4,2022-12-20,21.464517,23.619407,19.206926
5,2022-12-21,21.454944,23.553508,19.086299
6,2022-12-22,21.473737,23.876374,19.137958
7,2022-12-23,21.482909,23.638082,19.254296
8,2022-12-24,19.896843,22.333944,17.710527
9,2022-12-25,19.903326,22.138369,17.571351


### Create Stored Procedure for ongoing trainging of Prophet model

In [15]:
actuals = session.table(input_table_name).select(F.col('DATE'),F.col('CLOSE')).to_pandas() 
actuals.columns = ['ds', 'actual']

forecast.set_index('ds').join(actuals.set_index('ds'), how='left')


INFO:snowflake.connector.cursor:query: [SELECT "DATE", "CLOSE" FROM ( SELECT  *  FROM (milk_price_daily))]
INFO:snowflake.connector.cursor:query execution done


Unnamed: 0_level_0,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,weekly,weekly_lower,weekly_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat,actual
ds,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
2013-10-22,20.414031,19.304933,23.978891,20.414031,20.414031,1.221935,1.221935,1.221935,0.460926,0.460926,0.460926,0.761009,0.761009,0.761009,0.0,0.0,0.0,21.635966,18.25
2013-10-23,20.419908,19.334258,23.891992,20.419908,20.419908,1.175435,1.175435,1.175435,0.444475,0.444475,0.444475,0.730961,0.730961,0.730961,0.0,0.0,0.0,21.595343,18.24
2013-10-24,20.425785,19.128253,23.866086,20.425785,20.425785,1.159263,1.159263,1.159263,0.455632,0.455632,0.455632,0.703631,0.703631,0.703631,0.0,0.0,0.0,21.585048,18.25
2013-10-25,20.431661,19.349399,23.807144,20.431661,20.431661,1.136817,1.136817,1.136817,0.457020,0.457020,0.457020,0.679797,0.679797,0.679797,0.0,0.0,0.0,21.568478,18.25
2013-10-28,20.449292,19.329883,23.944360,20.449292,20.449292,1.090134,1.090134,1.090134,0.454796,0.454796,0.454796,0.635339,0.635339,0.635339,0.0,0.0,0.0,21.539426,18.25
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2023-12-08,24.122884,21.178649,27.150131,22.417363,25.859796,0.022328,0.022328,0.022328,0.457020,0.457020,0.457020,-0.434692,-0.434692,-0.434692,0.0,0.0,0.0,24.145212,
2023-12-09,24.129656,19.437141,25.406248,22.419214,25.875045,-1.625734,-1.625734,-1.625734,-1.136424,-1.136424,-1.136424,-0.489310,-0.489310,-0.489310,0.0,0.0,0.0,22.503922,
2023-12-10,24.136428,19.532361,25.420449,22.418979,25.890295,-1.674262,-1.674262,-1.674262,-1.136424,-1.136424,-1.136424,-0.537838,-0.537838,-0.537838,0.0,0.0,0.0,22.462167,
2023-12-11,24.143200,20.994815,26.810953,22.418483,25.907085,-0.125399,-0.125399,-0.125399,0.454796,0.454796,0.454796,-0.580195,-0.580195,-0.580195,0.0,0.0,0.0,24.017801,


In [16]:
session.sql('drop procedure train_prophet(int)').collect()

INFO:snowflake.connector.cursor:query: [drop procedure train_prophet(int)]
INFO:snowflake.connector.cursor:query execution done


[Row(status='TRAIN_PROPHET successfully dropped.')]

In [17]:
def train_prophet(session: snowflake.snowpark.Session, n_periods:int) -> str:
    # gether data
    train_df = (session.table(input_table_name)
                .with_column('ROWNUM', F.row_number().over(Window.order_by(F.col('DATE').desc()) ) )
                .filter(F.col('ROWNUM') < n_periods)
                .select('DATE','CLOSE')
                ).to_pandas()
    train_df.columns = ['ds','y']

    #fit model
    m = Prophet()
    m.fit(train_df)

    # register newly trained model as udtf for ongoing inference
    class Forecast: 
        def process(self, n_periods:int):
                future = m.make_future_dataframe(periods=n_periods)
                forecast = m.predict(future)
                
                forecast = forecast[['ds','yhat', 'yhat_upper', 'yhat_lower']]

                for row in forecast.itertuples():
                    yield (row[1].date(), row[2], row[3], row[4])
                    
    forecast_udtf = udtf(
        Forecast, 
        name='forecast', 
        input_types=[T.IntegerType()], 
        output_schema=T.StructType([T.StructField('date', T.DateType()), 
                                    T.StructField('yhat', T.FloatType()),
                                    T.StructField('yhat_upper', T.FloatType()),
                                    T.StructField('yhat_lower', T.FloatType())]),
        is_permanent=True,
        packages=['prophet','pandas'],
        stage_location='@model',
        replace=True
    )


    return "Model trained."
# register function as sproc
train_prophet = sproc(train_prophet, 
                      name='train_prophet', 
                      stage_location='@model', 
                      is_permanent=True, 
                      replace=True, 
                      packages=['prophet', 'snowflake-snowpark-python', 'pandas', 'joblib'])

INFO:snowflake.connector.cursor:query: [ls '@model']
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01a8fe54-0000-e173-00...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT  *  FROM ( SELECT  *  FROM (information_schema.packages)) WHERE (("LANGUA...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "PACKAGE_NAME", array_agg("VERSION") AS "ARRAY_AGG(VERSION)" FROM ( SELEC...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [CREATE OR REPLACE PROCEDURE train_prophet(arg1 BIGINT) RETURNS STRING LANGUAGE P...]
INFO:snowflake.connector.cursor:query execution done


In [18]:
session.sql("call train_prophet(365)").collect()

INFO:snowflake.connector.cursor:query: [call train_prophet(365)]
INFO:snowflake.connector.cursor:query execution done


[Row(TRAIN_PROPHET='Model trained.')]

In [19]:
session.call('train_prophet', 365)

INFO:snowflake.connector.cursor:query: [CALL train_prophet(365 :: INT)]
INFO:snowflake.connector.cursor:query execution done


'Model trained.'

In [20]:
session.sql('select * from table(forecast(30))').to_pandas()

INFO:snowflake.connector.cursor:query: [select * from table(forecast(30))]
INFO:snowflake.connector.cursor:query execution done


Unnamed: 0,DATE,YHAT,YHAT_UPPER,YHAT_LOWER
0,2021-07-06,16.435837,17.132498,15.753358
1,2021-07-07,16.468016,17.108788,15.806244
2,2021-07-08,16.470716,17.141840,15.789390
3,2021-07-09,16.461695,17.099632,15.787107
4,2021-07-12,16.441097,17.155198,15.776797
...,...,...,...,...
389,2023-01-07,21.600675,22.351872,20.884637
390,2023-01-08,21.610132,22.366169,20.879423
391,2023-01-09,21.473740,22.181413,20.710655
392,2023-01-10,21.475603,22.160518,20.758560
