In [1]:
# Snowpark
import snowflake.snowpark
import snowflake.snowpark.functions as F
from snowflake.snowpark.functions import sproc, udf, udtf, col, call_table_function
from snowflake.snowpark.session import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.window import Window
import json

import pandas as pd
import datetime
from prophet import Prophet
from prophet.plot import plot_plotly

with open('creds.json') as f:
    connection_parameters = json.load(f)


In [2]:

session = Session.builder.configs(connection_parameters).create()
session.add_packages('snowflake-snowpark-python', 'scikit-learn', 'pandas', 'numpy', 'joblib', 'cachetools')

The version of package cachetools in the local environment is 5.2.0, which does not fit the criteria for the requirement cachetools. Your UDF might not work when the package version is different between the server and your local environment


In [3]:
session.sql("CREATE OR REPLACE STAGE MODEL").collect()

[Row(status='Stage area MODEL successfully created.')]

### Explore Corn Price history

In [4]:
corn_df = session.table('corn_price_daily')

In [5]:
corn_df.count()

11128

In [6]:
corn_df.describe().to_pandas()

Unnamed: 0,SUMMARY,COMMODITY,COMMODITY_NAME,SYMBOL,VALUE,UNITS
0,max,KN.AGR42,US Corn Futures,ZC,831.25,USD per 1 Bushel
1,stddev,,,,146.407447,
2,min,KN.AGR42,US Corn Futures,ZC,142.75,USD per 1 Bushel
3,count,11128,11128,11128,11128.0,11128
4,mean,,,,340.053779,


In [7]:
corn_df.limit(100).to_pandas()

Unnamed: 0,COMMODITY,COMMODITY_NAME,SYMBOL,DATE,VALUE,UNITS
0,KN.AGR42,US Corn Futures,ZC,1979-12-27,289.25,USD per 1 Bushel
1,KN.AGR42,US Corn Futures,ZC,1979-12-28,291.25,USD per 1 Bushel
2,KN.AGR42,US Corn Futures,ZC,1979-12-31,289.20,USD per 1 Bushel
3,KN.AGR42,US Corn Futures,ZC,1980-01-02,286.50,USD per 1 Bushel
4,KN.AGR42,US Corn Futures,ZC,1980-01-03,286.50,USD per 1 Bushel
...,...,...,...,...,...,...
95,KN.AGR42,US Corn Futures,ZC,1980-05-15,273.25,USD per 1 Bushel
96,KN.AGR42,US Corn Futures,ZC,1980-05-16,271.25,USD per 1 Bushel
97,KN.AGR42,US Corn Futures,ZC,1980-05-19,272.00,USD per 1 Bushel
98,KN.AGR42,US Corn Futures,ZC,1980-05-20,271.75,USD per 1 Bushel


### Train Prophet forecasting model

In [8]:
# get data to pandas
train_df = corn_df.select('DATE','VALUE').to_pandas()
train_df.columns = ['ds', 'y']

In [9]:
m = Prophet()
m.fit(train_df)


INFO:prophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
  components = components.append(new_comp)


Initial log joint probability = -266.764
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
      99         19047     0.0617229       1856.89      0.9647      0.9647      120   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     199       20358.6      0.032306       1565.39      0.4455           1      232   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     299       20951.3     0.0259927       2714.52           1           1      337   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     399       21265.8     0.0472744       1904.33           1           1      447   
    Iter      log prob        ||dx||      ||grad||       alpha      alpha0  # evals  Notes 
     499       21656.3     0.0046457       468.828           1           1      557   
    Iter      log prob        ||dx||      ||grad||       alpha  

<prophet.forecaster.Prophet at 0x7fd94a06b1c0>

In [10]:
future = m.make_future_dataframe(periods=365)
forecast = m.predict(future)

  components = components.append(new_comp)
  components = components.append(new_comp)


In [11]:
forecast.loc[forecast['ds'] > datetime.datetime(2021, 1, 1, 0, 0, 0)]

Unnamed: 0,ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,weekly,weekly_lower,weekly_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
10575,2021-01-04,506.929927,424.763838,583.257580,506.929927,506.929927,-1.941577,-1.941577,-1.941577,1.087242,1.087242,1.087242,-3.028818,-3.028818,-3.028818,0.0,0.0,0.0,504.988350
10576,2021-01-05,507.019743,427.447497,583.110421,507.019743,507.019743,-1.427423,-1.427423,-1.427423,1.376658,1.376658,1.376658,-2.804082,-2.804082,-2.804082,0.0,0.0,0.0,505.592320
10577,2021-01-06,507.109560,427.312271,579.863542,507.109560,507.109560,-1.123614,-1.123614,-1.123614,1.479052,1.479052,1.479052,-2.602665,-2.602665,-2.602665,0.0,0.0,0.0,505.985946
10578,2021-01-07,507.199376,433.856338,582.787013,507.199376,507.199376,-1.467316,-1.467316,-1.467316,0.956681,0.956681,0.956681,-2.423997,-2.423997,-2.423997,0.0,0.0,0.0,505.732061
10579,2021-01-08,507.289193,431.123845,578.809075,507.289193,507.289193,-0.752807,-0.752807,-0.752807,1.514418,1.514418,1.514418,-2.267225,-2.267225,-2.267225,0.0,0.0,0.0,506.536386
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11488,2023-12-09,602.943804,488.097965,665.308309,585.103293,625.437695,-27.758019,-27.758019,-27.758019,-14.139626,-14.139626,-14.139626,-13.618393,-13.618393,-13.618393,0.0,0.0,0.0,575.185785
11489,2023-12-10,603.033621,518.155290,684.640239,585.123556,625.721939,-5.576856,-5.576856,-5.576856,7.725575,7.725575,7.725575,-13.302431,-13.302431,-13.302431,0.0,0.0,0.0,597.456764
11490,2023-12-11,603.123437,509.565128,675.357348,585.143818,626.005727,-11.873573,-11.873573,-11.873573,1.087242,1.087242,1.087242,-12.960814,-12.960814,-12.960814,0.0,0.0,0.0,591.249864
11491,2023-12-12,603.213254,509.784754,675.760091,585.164081,626.281754,-11.219084,-11.219084,-11.219084,1.376658,1.376658,1.376658,-12.595742,-12.595742,-12.595742,0.0,0.0,0.0,591.994170


In [12]:
plot_plotly(m, forecast)

In [13]:
class Forecast: 
      def __init__(self):
            self.n_periods = None
            
      def process(self, n_periods:int):
            future = m.make_future_dataframe(periods=n_periods)
            forecast = m.predict(future)
            
            forecast = forecast[['ds','yhat', 'yhat_upper', 'yhat_lower']]

            for row in forecast.itertuples():
                  yield (row[1].date(), row[2], row[3], row[4])
                  
forecast_udtf = udtf(
      Forecast, 
      name='forecast', 
      input_types=[T.IntegerType()], 
      output_schema=T.StructType([T.StructField('date', T.DateType()), 
                              T.StructField('yhat', T.FloatType()),
                              T.StructField('yhat_upper', T.FloatType()),
                              T.StructField('yhat_lower', T.FloatType())]),
      is_permanent=True,
      packages=['prophet','pandas'],
      stage_location='@model',
      replace=True
)

INFO:snowflake.connector.cursor:query: [ls '@model']
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01a8f7e7-0000-df07-00...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT  *  FROM ( SELECT  *  FROM (information_schema.packages)) WHERE (("LANGUA...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "PACKAGE_NAME", array_agg("VERSION") AS "ARRAY_AGG(VERSION)" FROM ( SELEC...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [PUT 'file:///tmp/placeholder/udf_py_1082791944.zip' '@model/forecast' PARALLEL =...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [CREATE OR REPLACE FUNCTION forecast(arg1 INT) RETURNS TABLE (date DATE,yhat FLOA...]
INFO:snowflake.connector.cursor:query execution done


In [14]:
session.table_function(forecast_udtf(F.lit(30))).filter((F.col("DATE") > F.current_date())).to_pandas()

INFO:snowflake.connector.cursor:query: [SELECT  *  FROM ( SELECT  *  FROM ( TABLE (forecast(30 :: INT) ))) WHERE ("DATE"...]
INFO:snowflake.connector.cursor:query execution done


Unnamed: 0,DATE,YHAT,YHAT_UPPER,YHAT_LOWER
0,2022-12-15,560.381053,635.723097,480.167783
1,2022-12-16,561.467158,633.761627,483.800398
2,2022-12-17,546.352333,627.845392,471.37248
3,2022-12-18,568.765087,644.355614,484.600814
4,2022-12-19,562.680113,636.782813,487.650945
5,2022-12-20,563.526165,644.854246,489.352148
6,2022-12-21,564.185949,638.802963,488.799328
7,2022-12-22,564.219231,646.180081,486.73794
8,2022-12-23,565.328424,646.672481,488.621278
9,2022-12-24,550.219239,623.205246,474.531501


### Create Stored Procedure for ongoing trainging of Prophet model

In [15]:
actuals = session.table('corn_price_daily').select(F.col('DATE'),F.col('VALUE')).to_pandas() 
actuals.columns = ['ds', 'actual']

forecast.set_index('ds').join(actuals.set_index('ds'), how='left')


INFO:snowflake.connector.cursor:query: [SELECT "DATE", "VALUE" FROM ( SELECT  *  FROM (corn_price_daily))]
INFO:snowflake.connector.cursor:query execution done


Unnamed: 0_level_0,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,weekly,weekly_lower,weekly_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat,actual
ds,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
1979-12-27,309.967589,225.178579,383.209380,309.967589,309.967589,-4.993774,-4.993774,-4.993774,0.956681,0.956681,0.956681,-5.950455,-5.950455,-5.950455,0.0,0.0,0.0,304.973815,289.25
1979-12-28,309.955250,229.186504,383.792535,309.955250,309.955250,-4.025625,-4.025625,-4.025625,1.514418,1.514418,1.514418,-5.540043,-5.540043,-5.540043,0.0,0.0,0.0,305.929625,291.25
1979-12-31,309.918234,231.439974,387.025879,309.918234,309.918234,-3.328009,-3.328009,-3.328009,1.087242,1.087242,1.087242,-4.415250,-4.415250,-4.415250,0.0,0.0,0.0,306.590226,289.20
1980-01-02,309.893557,225.266577,376.393765,309.893557,309.893557,-2.289328,-2.289328,-2.289328,1.479052,1.479052,1.479052,-3.768380,-3.768380,-3.768380,0.0,0.0,0.0,307.604229,286.50
1980-01-03,309.881219,227.093228,384.105890,309.881219,309.881219,-2.522233,-2.522233,-2.522233,0.956681,0.956681,0.956681,-3.478915,-3.478915,-3.478915,0.0,0.0,0.0,307.358985,286.50
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2023-12-09,602.943804,488.097965,665.308309,585.103293,625.437695,-27.758019,-27.758019,-27.758019,-14.139626,-14.139626,-14.139626,-13.618393,-13.618393,-13.618393,0.0,0.0,0.0,575.185785,
2023-12-10,603.033621,518.155290,684.640239,585.123556,625.721939,-5.576856,-5.576856,-5.576856,7.725575,7.725575,7.725575,-13.302431,-13.302431,-13.302431,0.0,0.0,0.0,597.456764,
2023-12-11,603.123437,509.565128,675.357348,585.143818,626.005727,-11.873573,-11.873573,-11.873573,1.087242,1.087242,1.087242,-12.960814,-12.960814,-12.960814,0.0,0.0,0.0,591.249864,
2023-12-12,603.213254,509.784754,675.760091,585.164081,626.281754,-11.219084,-11.219084,-11.219084,1.376658,1.376658,1.376658,-12.595742,-12.595742,-12.595742,0.0,0.0,0.0,591.994170,


In [16]:
train_df = (session.table('corn_price_daily')
            .with_column('ROWNUM', F.row_number().over(Window.order_by(F.col('DATE').desc()) ) )
            .filter(F.col('ROWNUM') < 10)
            .select('DATE','VALUE')
           ).to_pandas()
train_df

INFO:snowflake.connector.cursor:query: [SELECT  *  FROM (corn_price_daily)]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "DATE", "VALUE" FROM ( SELECT  *  FROM ( SELECT "COMMODITY", "COMMODITY_N...]
INFO:snowflake.connector.cursor:query execution done


Unnamed: 0,DATE,VALUE
0,2022-12-13,654.88
1,2022-12-12,641.25
2,2022-12-09,634.75
3,2022-12-08,642.5
4,2022-12-07,641.25
5,2022-12-06,637.25
6,2022-12-05,640.5
7,2022-12-02,646.25
8,2022-12-01,650.0


In [28]:
session.sql('drop procedure train_prophet(int)').collect()

INFO:snowflake.connector.cursor:query: [drop procedure train_prophet(int)]
INFO:snowflake.connector.cursor:query execution done


[Row(status='TRAIN_PROPHET successfully dropped.')]

In [32]:
def train_prophet(session: snowflake.snowpark.Session, n_periods:int) -> str:
    # gether data
    train_df = (session.table('corn_price_daily')
                .with_column('ROWNUM', F.row_number().over(Window.order_by(F.col('DATE').desc()) ) )
                .filter(F.col('ROWNUM') < n_periods)
                .select('DATE','VALUE')
                ).to_pandas()
    train_df.columns = ['ds','y']

    #fit model
    m = Prophet()
    m.fit(train_df)

    # register newly trained model as udtf for ongoing inference
    class Forecast: 
        def process(self, n_periods:int):
                future = m.make_future_dataframe(periods=n_periods)
                forecast = m.predict(future)
                
                forecast = forecast[['ds','yhat', 'yhat_upper', 'yhat_lower']]

                for row in forecast.itertuples():
                    yield (row[1].date(), row[2], row[3], row[4])
                    
    forecast_udtf = udtf(
        Forecast, 
        name='forecast', 
        input_types=[T.IntegerType()], 
        output_schema=T.StructType([T.StructField('date', T.DateType()), 
                                    T.StructField('yhat', T.FloatType()),
                                    T.StructField('yhat_upper', T.FloatType()),
                                    T.StructField('yhat_lower', T.FloatType())]),
        is_permanent=True,
        packages=['prophet','pandas'],
        stage_location='@model',
        replace=True
    )


    return "Model trained."
# register function as sproc
train_prophet = sproc(train_prophet, 
                      name='train_prophet', 
                      stage_location='@model', 
                      is_permanent=True, 
                      replace=True, 
                      packages=['prophet', 'snowflake-snowpark-python', 'pandas', 'joblib'])

INFO:snowflake.connector.cursor:query: [ls '@model']
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01a8f810-0000-de52-00...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT  *  FROM ( SELECT  *  FROM (information_schema.packages)) WHERE (("LANGUA...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [SELECT "PACKAGE_NAME", array_agg("VERSION") AS "ARRAY_AGG(VERSION)" FROM ( SELEC...]
INFO:snowflake.connector.cursor:query execution done
INFO:snowflake.connector.cursor:query: [CREATE OR REPLACE PROCEDURE train_prophet(arg1 BIGINT) RETURNS STRING LANGUAGE P...]
INFO:snowflake.connector.cursor:query execution done


In [33]:
session.sql("call train_prophet(365)").collect()

INFO:snowflake.connector.cursor:query: [call train_prophet(365)]
INFO:snowflake.connector.cursor:query execution done


[Row(TRAIN_PROPHET='Model trained.')]

In [34]:
session.call('train_prophet', 365)

INFO:snowflake.connector.cursor:query: [CALL train_prophet(365 :: INT)]
INFO:snowflake.connector.cursor:query execution done


'Model trained.'

In [35]:
session.sql('select * from table(forecast(30))').to_pandas()

INFO:snowflake.connector.cursor:query: [select * from table(forecast(30))]
INFO:snowflake.connector.cursor:query execution done


Unnamed: 0,DATE,YHAT,YHAT_UPPER,YHAT_LOWER
0,2021-08-29,501.224411,532.896270,471.563992
1,2021-08-30,506.383947,535.394252,476.201088
2,2021-08-31,507.332986,539.309776,476.514081
3,2021-09-01,509.170293,539.943173,479.047029
4,2021-09-02,509.115726,539.676885,480.474777
...,...,...,...,...
389,2023-01-08,667.724367,700.170862,633.172999
390,2023-01-09,672.158016,707.032804,638.680016
391,2023-01-10,672.381169,706.647560,640.429165
392,2023-01-11,673.492589,708.282518,640.389476
