# Forecast with Snowflake Cortex

> Generic method template to forecast with Snowflake Cortex

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#| default_exp forecast

In [3]:
#| hide
from nbdev.showdoc import *

In [4]:
#| export
import yaml
import random
import string
import logging
import numpy as np
import streamlit as st
import altair as alt
import pandas as pd

from datetime import datetime
from cortex_forecast.connection import SnowparkConnection


# Raise the 'snowflake.snowpark' logger threshold to WARNING so that
# lower-severity records from that logger are not emitted.
logging.getLogger('snowflake.snowpark').setLevel(logging.WARNING)
 

In [5]:
#| export

class SnowflakeMLForecast(SnowparkConnection):
    def __init__(self, config_file, connection_config=None):
        """Initialise the connection and load the forecast configuration.

        Args:
            config_file: Path to a YAML file. Other methods of this class read
                its ``model``, ``input_data``, ``forecast_config`` and
                ``output`` sections.
            connection_config: Optional connection settings, forwarded
                unchanged to the parent class initialiser.
        """
        super().__init__(connection_config=connection_config)
        # Decode explicitly as UTF-8: relying on the platform default encoding
        # makes the same YAML file load differently (or fail) across machines.
        with open(config_file, 'r', encoding='utf-8') as file:
            self.config = yaml.safe_load(file)
        self.model_name = self._generate_unique_model_name()

    def _generate_unique_model_name(self):
        suffix = ''.join(random.choices(string.ascii_lowercase, k=5))
        timestamp = datetime.now().strftime("%Y%m%d")
        return f"{self.config['model']['name']}_{timestamp}_{suffix}"
    
    def _generate_input_data_sql(self):
        table = self.config['input_data']['table']
        timestamp_col = self.config['input_data']['timestamp_column']
        target_col = self.config['input_data']['target_column']
        exogenous_cols = self.config['input_data'].get('exogenous_columns') or []
        training_days = self.config['forecast_config'].get('training_days')

        columns = [f"TO_TIMESTAMP_NTZ({timestamp_col}) AS {timestamp_col}",
                f"{target_col} AS {target_col}"]

        if exogenous_cols:
            columns.extend(exogenous_cols)
        else:
            columns.append("*")

        sql = f"""
        CREATE OR REPLACE TEMPORARY TABLE {self.model_name}_train AS
        SELECT {', '.join(columns)} EXCLUDE ({timestamp_col}, {target_col})
        FROM {table}
        """

        if training_days:
            sql += f"""
            WHERE TO_TIMESTAMP_NTZ({timestamp_col}) BETWEEN 
            DATEADD(day, -{training_days}, (SELECT MAX({timestamp_col}) FROM {table})) 
            AND 
            (SELECT MAX({timestamp_col}) FROM {table})
            """

        sql += ";"

        print("Generated SQL:")
        print(sql)
        return sql

    def _generate_create_model_sql(self):
        input_data = f"SYSTEM$REFERENCE('{self.config['input_data']['table_type']}', '{self.model_name}_train')"
        timestamp_col = self.config['input_data']['timestamp_column']
        target_col = self.config['input_data']['target_column']
        series_col = self.config['input_data'].get('series_column')
        config_object = self.config['forecast_config'].get('config_object', {})
    
        sql = f"""
        CREATE OR REPLACE SNOWFLAKE.ML.FORECAST {self.model_name}(
            INPUT_DATA => {input_data},
            TIMESTAMP_COLNAME => '{timestamp_col}',
            TARGET_COLNAME => '{target_col}',
        """
        
        if series_col:
            sql += f"""SERIES_COLNAME => '{series_col}',\n"""
        
        config_sql = "{"
        for key, value in config_object.items():
            if isinstance(value, dict):
                nested_config = "{"
                nested_config += ", ".join([f"'{k}': {self._format_value(v)}" for k, v in value.items()])
                nested_config += "}"
                config_sql += f"'{key}': {nested_config}, "
            else:
                config_sql += f"'{key}': {self._format_value(value)}, "
        config_sql = config_sql.rstrip(", ") + "}"

        sql += f"CONFIG_OBJECT => {config_sql},"
        
        sql = sql.rstrip(',')  # Clean up trailing commas
        sql += ")"
        tags = self.config['model'].get('tags')
        comment = self.config['model'].get('comment')
        
        if tags:
            tag_str = ", ".join([f"{k} = '{v}'" for k, v in tags.items()])
            sql += f" WITH TAG ({tag_str})"
        
        if comment:
            sql += f" COMMENT = '{comment}'"
        
        sql += ";"

        print("Generated SQL:")
        print(sql)
        
        return sql

    def create_tags(self):
        """
        Create the necessary tags in Snowflake before running the forecast creation.
        If a tag already exists, it will notify the user instead of raising an error.
        """
        tags = self.config['model'].get('tags')
        if not tags:
            self.display("No tags to create.", content_type="text")
            return

        for tag_name, tag_comment in tags.items():
            create_tag_sql = f"CREATE TAG {tag_name} COMMENT = 'Specifies the {tag_comment.lower()}';"
            try:
                self.display(f"Attempting to create tag: {tag_name}", content_type="text")
                self.run_command(create_tag_sql)
                self.display(f"Tag '{tag_name}' created successfully.", content_type="text")
            except Exception as e:
                if "already exists" in str(e):
                    self.display(f"Tag '{tag_name}' already exists.", content_type="text")
                else:
                    self.display(f"Error creating tag '{tag_name}': {e}", content_type="text")

    def _format_value(self, value):
        """
        Helper function to format values for SQL. Converts Python None to SQL NULL,
        and ensures strings are correctly quoted.
        """
        if value is None:
            return "NULL"
        elif isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        elif isinstance(value, (int, float)):
            return str(value)
        elif isinstance(value, str):
            return f"'{value}'"
        return str(value)

    def _generate_forecast_sql(self):
        try:
            # Fetch values from the configuration
            forecast_days = self.config['forecast_config'].get('forecast_days')
            output_table = self.config['output']['table']
            input_data_table = self.config['forecast_config'].get('table')  # Table to be used for prediction, if any
            config_object = self.config['forecast_config'].get('config_object', {})
            evaluation_config = config_object.get('evaluation_config', {})

            print("Configuration Details:")
            print(f"Forecast Days: {forecast_days}")
            print(f"Output Table: {output_table}")
            print(f"Evaluation Config: {evaluation_config}")
            prediction_interval = evaluation_config.get('prediction_interval', 0.95)
            series_col = self.config['input_data'].get('series_column')
            timestamp_col = self.config['input_data']['timestamp_column']
            sql = f"""
            CREATE OR REPLACE TABLE {output_table} AS
            SELECT 
            """
            if series_col:
                sql += f"series::string as {series_col},\n"

            sql += f"""
                ts AS {timestamp_col},
                CASE WHEN forecast < 0 THEN 0 ELSE forecast END AS forecast,
                CASE WHEN lower_bound < 0 THEN 0 ELSE lower_bound END AS lower_bound,
                CASE WHEN upper_bound < 0 THEN 0 ELSE upper_bound END AS upper_bound
            FROM
                TABLE({self.model_name}!FORECAST(
            """

            # Include the INPUT_DATA if input_data_table is provided in the configuration
            if input_data_table:
                sql += f"""
                INPUT_DATA => SYSTEM$REFERENCE('TABLE', '{input_data_table}'),
                TIMESTAMP_COLNAME => '{timestamp_col}',\n"""

            # Add optional series column; use NULL if series_col is None
            if series_col:
                sql += f"""
                    SERIES_COLNAME => '{series_col}',\n
                """

            sql += f"CONFIG_OBJECT => {{'prediction_interval': {prediction_interval}}}\n"
            
            # Include FORECASTING_PERIODS only if forecast_days is provided
            if forecast_days:
                sql += f", FORECASTING_PERIODS => {forecast_days}"
            
            sql += "));"

            print("Generated Forecast SQL:")
            print(sql)
            return sql

        except KeyError as e:
            print(f"KeyError encountered: {e}")
            raise e


    def run_query(self, query):
        """
        Execute a query and return the result as a Pandas DataFrame.
        """
        df = self.session.sql(query).to_pandas() if self.session else None
        return df

    def run_command(self, query):
        """
        Execute a command and return the result.
        """
        result = self.session.sql(query).collect() if self.session else None
        return result

    def create_and_run_forecast(self):
        self.create_tags()

        print("Step 1/4: Creating training table...")
        self.run_command(self._generate_input_data_sql())

        print("Step 2/4: Creating forecast model...")
        self.run_command(self._generate_create_model_sql())

        print("Step 3/4: Generating forecasts...")
        self.run_command(self._generate_forecast_sql())

        print("Step 4/4: Fetching forecast results...")
        forecast_data = self.run_query(f"SELECT * FROM {self.config['output']['table']} ORDER BY {self.config['input_data']['timestamp_column']}")

        return forecast_data

    def cleanup(self):
        print("Cleaning up temporary tables and models...")
        cleanup_commands = f"""
        DROP TABLE IF EXISTS {self.model_name}_train;
        DROP TABLE IF EXISTS {self.config['output']['table']};
        """
        # DROP MODEL IF EXISTS {self.model_name};

        for command in cleanup_commands.split(';'):
            if command.strip():
                self.run_command(command)

    # Other existing methods...

    def is_streamlit(self):
        """
        Check if the environment is Streamlit.
        """
        try:
            return st._is_running_with_streamlit
        except AttributeError:
            return False

    def display(self, content, content_type="text", **kwargs):
        """
        Display content based on the environment (Streamlit or console).
        """
        if self.is_streamlit():
            if content_type == "text":
                st.write(content)
            elif content_type == "chart":
                st.altair_chart(content, use_container_width=True)
            elif content_type == "dataframe":
                st.write(content)
            elif content_type == "code":
                st.code(content, language=kwargs.get('language', ''))
        else:
            if content_type == "text":
                print(content)
            elif content_type == "chart":
                content.show()
            elif content_type == "dataframe":
                print(content)
            elif content_type == "code":
                print(content)

    def create_visualization(self, df, max_historic_date):
        """Build the three chart layers for the forecast plot.

        Args:
            df: Frame plotted as lines; the encodings reference the columns
                ``TS``, ``Volume``, ``Value Type`` and ``Type``.
            max_historic_date: Position of the dashed vertical rule and of the
                "Forecast -->" label.

        Returns:
            tuple: ``(line_chart, max_historic_date_rule, max_historic_date_label)``.
        """
        # Vertical dashed rule at the last historic date.
        rule_source = pd.DataFrame({'x': [max_historic_date]})
        max_historic_date_rule = (
            alt.Chart(rule_source)
            .mark_rule(color='orange', strokeDash=[5, 5])
            .encode(x='x:T')
        )

        # Text label anchored at the same date.
        label_source = pd.DataFrame({'x': [max_historic_date], 'label': ['Forecast -->']})
        max_historic_date_label = (
            alt.Chart(label_source)
            .mark_text(align='left', baseline='bottom', dx=5, dy=5, fontSize=12)
            .encode(x='x:T', y=alt.value(5), text='label:N')
        )

        chart_title = {
            "text": ["Forecast and Historic Volume"],
            "subtitle": ["Comparing forecasted volume with historic data"],
            "color": "black",
            "subtitleColor": "gray"
        }
        line_chart = alt.Chart(df).mark_line(point=True).encode(
            x=alt.X("TS:T", axis=alt.Axis(title="Date")),
            y=alt.Y("Volume:Q"),
            color=alt.Color('Value Type:N', legend=alt.Legend(title="Forecast Type")),
            strokeDash=alt.StrokeDash('Type:N', legend=alt.Legend(title="Data Type"))
        ).properties(title=chart_title, width=800, height=400)

        return line_chart, max_historic_date_rule, max_historic_date_label

    def generate_forecast_and_visualization(self, forecasting_period, confidence_interval):
        if self.config['input_data'].get('series_column') is None:
            df_forecast = self.session.sql(f"""
                CALL {self.model_name}!FORECAST(
                    FORECASTING_PERIODS => {forecasting_period},
                    CONFIG_OBJECT => {{'prediction_interval': {confidence_interval}}}
                );
            """).collect()
            df_forecast = pd.DataFrame(df_forecast)
            df_actuals = self.load_historic_actuals()
            timestamp_col = self.config['input_data']['timestamp_column']
            target_col = self.config['input_data']['target_column']
            df_actuals = df_actuals.rename(columns={timestamp_col.upper(): 'TS', target_col.upper(): 'FORECAST'})

            try:
                print('Getting historical max date') 
                max_historic_date = df_actuals['TS'].max()
                df_actuals['LOWER_BOUND'] = np.NaN
                df_actuals['UPPER_BOUND'] = np.NaN
                df_actuals['Type'] = 'Historic'
                df_forecast['Type'] = 'Forecast'
                df_combined = pd.concat([df_forecast, df_actuals], ignore_index=True)
                df_combined['LOWER_BOUND'] = np.where(df_combined['LOWER_BOUND'] < 0, 0, df_combined['LOWER_BOUND'])
                df = df_combined.melt(id_vars=['TS', 'Type'], value_vars=['FORECAST', 'LOWER_BOUND', 'UPPER_BOUND'], var_name='Value Type', value_name='Volume')
                df = df.dropna(subset=['Volume'])
                line_chart, max_historic_date_rule, max_historic_date_label = self.create_visualization(df, max_historic_date)
                if self.is_streamlit():
                    st.session_state['chart'] = alt.layer(line_chart, max_historic_date_rule, max_historic_date_label)
                    st.session_state['df'] = df
                else:
                    self.display(alt.layer(line_chart, max_historic_date_rule, max_historic_date_label), content_type="chart")
                    self.display(df, content_type="dataframe")
            except KeyError as e:
                print(f"KeyError encountered: {e}")
        else:
            print("Currently Plotting is not supported with series this is a POC might come back and implement")
        self.show_key_data_aspects()
       

    def show_key_data_aspects(self):
        self.display("Top 10 Feature Importances", content_type="text")
        feature_importance = f"CALL {self.model_name}!EXPLAIN_FEATURE_IMPORTANCE();"
        f_i = self.session.sql(feature_importance).collect()[:10]
        df_fi = pd.DataFrame(f_i)
        df_fi = df_fi.drop(columns=['SERIES'])
        chart = alt.Chart(df_fi).mark_bar().encode(
            x=alt.X('SCORE:Q', title='Feature Importance'),
            y=alt.Y('FEATURE_NAME:N', title='Feature', sort='-x')
        ).properties(
            title="Feature Importance Plot",
            width=600,
            height=300
        )
        self.display(chart, content_type="chart")
        self.display(df_fi, content_type="dataframe")
        
        self.display("Underlying Model Metrics", content_type="text")
        metric_call = f"CALL {self.model_name}!SHOW_EVALUATION_METRICS();"
        metrics = self.session.sql(metric_call).collect()
        metrics = [metric.as_dict() for metric in metrics]
        metrics = pd.DataFrame(metrics)
        metrics = metrics.drop(columns=['SERIES'])
        self.display(metrics, content_type="dataframe")

    # Custom method to load historical data
    def load_historic_actuals(self):
        return self.session.table(self.config['input_data']['table']).to_pandas()

# %%

# Example Usage

In [6]:
#| skip
from snowflake.snowpark.version import VERSION
import os

In [8]:
#| skip
# Connect using credentials taken from the environment.
forecast_model = SnowflakeMLForecast(
    config_file='./cortex_forecast/files/yaml/storage_forecast_config.yaml',
    connection_config=dict(
        user=os.getenv('SNOWFLAKE_USER'),
        password=os.getenv('SNOWFLAKE_PASSWORD'),
        account=os.getenv('SNOWFLAKE_ACCOUNT'),
        database='CORTEX',
        warehouse='CORTEX_WH',
        schema='DEV',
        role='CORTEX_USER_ROLE',  # Use the desired role
    ),
)

# Report the session that was established.
snowflake_environment = forecast_model.session.sql('SELECT current_user(), current_version()').collect()
snowpark_version = VERSION
print('\nConnection Established with the following parameters:')
print(f'User                        : {snowflake_environment[0][0]}')
print(f'Role                        : {forecast_model.session.get_current_role()}')
print(f'Database                    : {forecast_model.session.get_current_database()}')
print(f'Schema                      : {forecast_model.session.get_current_schema()}')
print(f'Warehouse                   : {forecast_model.session.get_current_warehouse()}')
print(f'Snowflake version           : {snowflake_environment[0][1]}')
print(f'Snowpark for Python version : {snowpark_version[0]}.{snowpark_version[1]}.{snowpark_version[2]}')

# Create Training Data
training_days = 365
predicted_days = 30

# Daily storage in GB for the last `training_days` days, excluding today.
create_training_table_sql = f'''CREATE OR REPLACE TABLE storage_usage_train AS
    SELECT 
        TO_TIMESTAMP_NTZ(usage_date) AS usage_date,
        storage_bytes / POWER(1024, 3) AS storage_gb
    FROM 
    (
        SELECT * 
            FROM snowflake.account_usage.storage_usage
            WHERE usage_date < CURRENT_DATE()
    )
    WHERE TO_TIMESTAMP_NTZ(usage_date) > DATEADD(day, -{training_days}, CURRENT_DATE())
'''
forecast_model.session.sql(create_training_table_sql).collect()

preview_sql = 'SELECT * FROM storage_usage_train ORDER BY usage_date DESC LIMIT 10'
forecast_model.session.sql(preview_sql).show()


Connection Established with the following parameters:
User                        : JD_SERVICE_ACCOUNT_ADMIN
Role                        : "CORTEX_USER_ROLE"
Database                    : "CORTEX"
Schema                      : "DEV"
Warehouse                   : "CORTEX_WH"
Snowflake version           : 8.31.2
Snowpark for Python version : 1.19.0
--------------------------------------------
|"USAGE_DATE"         |"STORAGE_GB"        |
--------------------------------------------
|2024-08-21 00:00:00  |263.8299332438037   |
|2024-08-20 00:00:00  |263.74338578339666  |
|2024-08-19 00:00:00  |263.7401140583679   |
|2024-08-18 00:00:00  |263.7366697881371   |
|2024-08-17 00:00:00  |263.73793654982     |
|2024-08-16 00:00:00  |264.9268829189241   |
|2024-08-15 00:00:00  |264.94884450919926  |
|2024-08-14 00:00:00  |274.8103531319648   |
|2024-08-13 00:00:00  |275.6245334902778   |
|2024-08-12 00:00:00  |264.13781207147986  |
--------------------------------------------



In [9]:
#| skip
# Run Forecast
forecast_model = SnowflakeMLForecast(
    config_file='./cortex_forecast/files/yaml/storage_forecast_config.yaml',
    connection_config=dict(
        user=os.getenv('SNOWFLAKE_USER'),
        password=os.getenv('SNOWFLAKE_PASSWORD'),
        account=os.getenv('SNOWFLAKE_ACCOUNT'),
        database='CORTEX',
        warehouse='CORTEX_WH',
        schema='DEV',
        role='CORTEX_USER_ROLE',  # Use the desired role
    ),
)

# Train the model and write the forecast table, then plot it with diagnostics.
forecast_data = forecast_model.create_and_run_forecast()
forecasting_period, confidence_interval = 30, 0.95
forecast_model.generate_forecast_and_visualization(forecasting_period, confidence_interval)

Attempting to create tag: environment
Tag 'environment' already exists.
Attempting to create tag: team
Tag 'team' already exists.
Step 1/4: Creating training table...
Generated SQL:

        CREATE OR REPLACE TEMPORARY TABLE my_forecast_model_20240822_jvwbh_train AS
        SELECT TO_TIMESTAMP_NTZ(usage_date) AS usage_date, storage_gb AS storage_gb, * EXCLUDE (usage_date, storage_gb)
        FROM storage_usage_train
        
            WHERE TO_TIMESTAMP_NTZ(usage_date) BETWEEN 
            DATEADD(day, -365, (SELECT MAX(usage_date) FROM storage_usage_train)) 
            AND 
            (SELECT MAX(usage_date) FROM storage_usage_train)
            ;
Step 2/4: Creating forecast model...
Generated SQL:

        CREATE OR REPLACE SNOWFLAKE.ML.FORECAST my_forecast_model_20240822_jvwbh(
            INPUT_DATA => SYSTEM$REFERENCE('table', 'my_forecast_model_20240822_jvwbh_train'),
            TIMESTAMP_COLNAME => 'usage_date',
            TARGET_COLNAME => 'storage_gb',
        CONFIG_OBJ

            TS      Type   Value Type      Volume
0   2024-08-22  Forecast     FORECAST  265.244948
1   2024-08-23  Forecast     FORECAST  265.466347
2   2024-08-24  Forecast     FORECAST  265.396539
3   2024-08-25  Forecast     FORECAST  265.399145
4   2024-08-26  Forecast     FORECAST  265.559049
..         ...       ...          ...         ...
513 2024-09-16  Forecast  UPPER_BOUND  289.018949
514 2024-09-17  Forecast  UPPER_BOUND  289.780607
515 2024-09-18  Forecast  UPPER_BOUND  289.955011
516 2024-09-19  Forecast  UPPER_BOUND  290.950464
517 2024-09-20  Forecast  UPPER_BOUND  291.996777

[304 rows x 4 columns]
Top 10 Feature Importances


   RANK                          FEATURE_NAME  SCORE             FEATURE_TYPE
0     1                                  lag7   0.09  derived_from_endogenous
1     2                                 lag91   0.08  derived_from_endogenous
2     3  aggregated_endogenous_trend_features   0.07  derived_from_endogenous
3     4                                 lag14   0.07  derived_from_endogenous
4     5                                 lag21   0.07  derived_from_endogenous
5     6                                 lag42   0.07  derived_from_endogenous
6     7                                 lag28   0.06  derived_from_endogenous
7     8                                 lag49   0.06  derived_from_endogenous
8     9                                 lag56   0.06  derived_from_endogenous
9    10                                   day   0.05   derived_from_timestamp
Underlying Model Metrics
             ERROR_METRIC  METRIC_VALUE  STANDARD_DEVIATION  LOGS
0                     MAE        12.399            

In [7]:
#| hide
# Run nbdev's export step for this project.
import nbdev
nbdev.nbdev_export()