### Training a tensorflow Classification model on 1 million row training set using UDTF (user defined table function)
### Steps Followed
##### 1. Import Snowpark libraries
##### 2. Connect to snowflake
##### 3. Define model stage
##### 4. Define output schema
##### 5. Define UDTF and class for UDTF
##### 6. Load data to snowpark dataframe
##### 7. Define and build parameter dataframe with Epochs and Batch_Size
##### 8. Combine data and parameter dataframes (Cross Join)
##### 9. Train using the UDTF defined in step 5

#### 1. Import Snowpark libraries

In [None]:
import pandas as pd
import numpy as np
import joblib
import os
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from snowflake.snowpark.functions import (
    udtf,
    col,
    lit,
    row_number,
    table_function,
)
from snowflake.snowpark.types import (
    Variant,
    IntegerType,
    BooleanType,
    FloatType,
    StringType,
    DoubleType,
    BooleanType,
    DateType,
    StructType,
    StructField,
    LongType,
    DecimalType,
)
from functools import reduce
from snowflake.snowpark.window import Window
import json
from snowflake.snowpark import Session

### 2. Connect to snowflake


In [18]:
snowflake_connection_cfg = open('cred.json')
snowflake_connection_cfg = snowflake_connection_cfg.read()
snowflake_connection_cfg = json.loads(snowflake_connection_cfg)

'''
APP_WH XS
LAB_WH S
HMWH M optimized warehouse
DCR_MA_WH L
BANK1_WH XL
'''

# Creating Snowpark Session
staples_tf_session = Session.builder.configs(snowflake_connection_cfg).create()
print('Current Database:', staples_tf_session.get_current_database())
print('Current Schema:', staples_tf_session.get_current_schema())
print('Current Warehouse:', staples_tf_session.get_current_warehouse())
print("Warehouse set up:")
staples_tf_session.sql("show warehouses like 'APP_WH'").collect()

Current Database: "BANK1_CRM_DB"
Current Schema: "PUBLIC"
Current Warehouse: "APP_WH"
Warehouse set up:


[Row(name='APP_WH', state='STARTED', type='STANDARD', size='X-Small', min_cluster_count=1, max_cluster_count=1, started_clusters=1, running=0, queued=0, is_default='N', is_current='Y', auto_suspend=600, auto_resume='true', available=' 100', provisioning='0', quiescing='0', other='0', created_on=datetime.datetime(2022, 2, 27, 4, 51, 57, 85000, tzinfo=<DstTzInfo 'America/Los_Angeles' PST-1 day, 16:00:00 STD>), resumed_on=datetime.datetime(2022, 12, 30, 21, 27, 7, 83000, tzinfo=<DstTzInfo 'America/Los_Angeles' PST-1 day, 16:00:00 STD>), updated_on=datetime.datetime(2022, 12, 30, 21, 27, 7, 83000, tzinfo=<DstTzInfo 'America/Los_Angeles' PST-1 day, 16:00:00 STD>), owner='SYSADMIN', comment='', enable_query_acceleration='false', query_acceleration_max_scale_factor=8, resource_monitor='null', actives=1, pendings=0, failed=0, suspended=0, uuid='1463550724', scaling_policy='STANDARD')]

### 3. Define model stage

In [None]:
staples_tf_session.sql(
    """
create or replace stage udtfmodels
"""
).collect()

### 4. Define output schema

In [19]:
schema = StructType(
    [
        StructField("EPOCH", IntegerType()),
        StructField("BATCH_SIZE", IntegerType()),
        StructField("Accuracy", FloatType()),
        StructField("Checkpoint", StringType()),
    ]
)

### 5. Define UDTF and class for UDTF

In [20]:
@udtf(
    output_schema=schema,
    input_types=[
        IntegerType(),
        IntegerType(),
        DoubleType(),
        DoubleType(),
        DoubleType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        FloatType(),
        IntegerType(),
    ],
    name="hyperparameter_staples_tf_tuning",
    session=staples_tf_session,
    is_permanent=True,
    stage_location="@udtfmodels",
    packages=["snowflake-snowpark-python", "pandas", "scikit-learn","tensorflow","dill","joblib"],
    replace=True,
)
class forecast:
    #Initializes state for stateful processing of input partitions
    def __init__(self):
        self.EPOCH = None
        self.BATCH_SIZE = None
        self.RECENCY_DAY = []
        self.FREQUENCY = []
        self.MONETORY = []
        self.RMF_SCORE = []
        self.DOTCOM = []
        self.REWARDS_ACCOUNT = []
        self.FREQ_1 = []
        self.FREQ_2 = []
        self.FREQ_3 = []
        self.FREQ_4 = []
        self.FREQ_5 = []
        self.FREQ_6 = []
        self.FREQ_7 = []
        self.FREQ_8 = []
        self.FREQ_9 = []
        self.FREQ_10 = []
        self.FREQ_11 = []
        self.FREQ_12 = []
        self.CNT_PER_PDT = []
        self.CNT_PER_PDT_SFC = []
        self.CNT_PER_PDT_VFC = []
        self.NO_DISCOUNT = []
        self.DISCOUNT_PROMOTION = []
        self.LABEL = []
        self.processedFirstRow = False
    # Processes each input row, returning a tabular value as tuples. Snowflake invokes this method, passing input from the UDTF's arguments.
    def process(
        self,
        EPOCH,
        BATCH_SIZE,
        RECENCY_DAY,
        FREQUENCY,
        MONETORY,
        RMF_SCORE,
        DOTCOM,
        REWARDS_ACCOUNT,
        FREQ_1,
        FREQ_2,
        FREQ_3,
        FREQ_4,
        FREQ_5,
        FREQ_6,
        FREQ_7,
        FREQ_8,
        FREQ_9,
        FREQ_10,
        FREQ_11,
        FREQ_12,
        CNT_PER_PDT,
        CNT_PER_PDT_SFC,
        CNT_PER_PDT_VFC,
        NO_DISCOUNT,
        DISCOUNT_PROMOTION,
        LABEL,
    ):
    # We are telling the UDTF only to run 288 models based on the unique combination of hyperparameters rather than every record in our data set.
        if not self.processedFirstRow:
            self.EPOCH = EPOCH
            self.BATCH_SIZE = BATCH_SIZE
            self.processedFirstRow = True
        self.RECENCY_DAY.append(RECENCY_DAY)
        self.FREQUENCY.append(FREQUENCY)
        self.MONETORY.append(MONETORY)
        self.RMF_SCORE.append(RMF_SCORE)
        self.DOTCOM.append(DOTCOM)
        self.REWARDS_ACCOUNT.append(REWARDS_ACCOUNT)
        self.FREQ_1.append(FREQ_1)
        self.FREQ_2.append(FREQ_2)
        self.FREQ_3.append(FREQ_3)
        self.FREQ_4.append(FREQ_4)
        self.FREQ_5.append(FREQ_5)
        self.FREQ_6.append(FREQ_6)
        self.FREQ_7.append(FREQ_7)
        self.FREQ_8.append(FREQ_8)
        self.FREQ_9.append(FREQ_9)
        self.FREQ_10.append(FREQ_10)
        self.FREQ_11.append(FREQ_11)
        self.FREQ_12.append(FREQ_12)
        self.CNT_PER_PDT.append(CNT_PER_PDT)
        self.CNT_PER_PDT_SFC.append(CNT_PER_PDT_SFC)
        self.CNT_PER_PDT_VFC.append(CNT_PER_PDT_VFC)
        self.NO_DISCOUNT.append(NO_DISCOUNT)
        self.DISCOUNT_PROMOTION.append(DISCOUNT_PROMOTION)
        self.LABEL.append(LABEL)

    def end_partition(self):
    # Finalizes processing of input partitions, returning a tabular value as tuples.
    # Together the process is used to add the rows of a partition into the x1, x2 … lists, and then 
    # end_partition is called when all rows have been processed, and we then can train the model based on the hyperparameters in the first row

        df = pd.DataFrame(
            zip(self.RECENCY_DAY,self.FREQUENCY,self.MONETORY,self.RMF_SCORE,self.DOTCOM,self.REWARDS_ACCOUNT,self.FREQ_1,self.FREQ_2,self.FREQ_3,self.FREQ_4,self.FREQ_5,self.FREQ_6,self.FREQ_7,self.FREQ_8,self.FREQ_9,self.FREQ_10,self.FREQ_11,self.FREQ_12,self.CNT_PER_PDT,self.CNT_PER_PDT_SFC,self.CNT_PER_PDT_VFC,self.NO_DISCOUNT,self.DISCOUNT_PROMOTION,self.LABEL),
            columns=["RECENCY_DAY","FREQUENCY","MONETORY","RMF_SCORE","DOTCOM","REWARDS_ACCOUNT","FREQ_1","FREQ_2","FREQ_3","FREQ_4","FREQ_5","FREQ_6","FREQ_7","FREQ_8","FREQ_9","FREQ_10","FREQ_11","FREQ_12","CNT_PER_PDT","CNT_PER_PDT_SFC","CNT_PER_PDT_VFC","NO_DISCOUNT","DISCOUNT_PROMOTION","LABEL"],
            # zip(self.RECENCY_DAY,self.FREQUENCY,self.MONETORY,self.LABEL),
            # columns=["RECENCY_DAY","FREQUENCY","MONETORY","LABEL"],

        )

        dfx = df.loc[:, df.columns != "LABEL"]
        dfy = df.loc[:, df.columns == "LABEL"]

        X_train, X_test, y_train, y_test = train_test_split(
            dfx, dfy, test_size=0.25, random_state=43
        )
        import tensorflow as tf
        import os
        from tensorflow.keras.models import Sequential
        from tensorflow.keras.layers import Dense, Conv2D , MaxPool2D , Flatten , Dropout , BatchNormalization
        from tensorflow.keras import datasets, layers, models
        from tensorflow.keras import callbacks
        
        clf = Sequential()
        clf.add(Dense(units = 10 , activation = 'relu'))
        clf.add(Dropout(0.2))
        clf.add(Dense(units = 1 , activation = 'sigmoid')) # Tanh
        clf.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=["accuracy"])

        clf.fit(X_train, y_train.values.ravel(), 
                       batch_size=self.BATCH_SIZE, 
                       epochs=self.EPOCH)
        
        y_pred = clf.predict(X_test)
        score = clf.evaluate(X_test, y_test, verbose=0)
        
#         model_dir = '@udtfmodels'
        
#         model_file = os.path.join('/tmp',str(self.EPOCH), 'keras_model.h5')
#         clf.save(model_file)
        
        filepath = "saved-model-{epoch:02d}-{val_acc:.2f}.hdf5"
        checkpoint = callbacks.ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=False, mode='max')        
        #session.file.put(checkpoint, modeldir,overwrite=True)

        yield (
            self.EPOCH,
            self.BATCH_SIZE,
            score[1], #adds accuracy to the training
            str(checkpoint)
        )

The version of package tensorflow in the local environment is 2.10.0, which does not fit the criteria for the requirement tensorflow. Your UDF might not work when the package version is different between the server and your local environment
package dill is not installed in the local environmentYour UDF might not work when the package is installed on the server but not on your local environment.
The version of package joblib in the local environment is 1.2.0, which does not fit the criteria for the requirement joblib. Your UDF might not work when the package version is different between the server and your local environment


In [None]:
### Define the features

In [21]:
features=["RECENCY_DAY",
"FREQUENCY",
"MONETORY",
"RMF_SCORE",
"DOTCOM",
"REWARDS_ACCOUNT",
"FREQ_1",
"FREQ_2",
"FREQ_3",
"FREQ_4",
"FREQ_5",
"FREQ_6",
"FREQ_7",
"FREQ_8",
"FREQ_9",
"FREQ_10",
"FREQ_11",
"FREQ_12",
"CNT_PER_PDT",
"CNT_PER_PDT_SFC",
"CNT_PER_PDT_VFC",
"NO_DISCOUNT",
"DISCOUNT_PROMOTION"]

### 6. Load data to snowpark dataframe

In [51]:
# staples_tf_session.sql("use warehouse APP_WH").collect()
table_name = 'STAPLES_DATA_TRAIN_1M'
features_df = staples_tf_session.table(table_name).select(*features,'LABEL')
features_df.show()

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"RECENCY_DAY"  |"FREQUENCY"  |"MONETORY"  |"RMF_SCORE"  |"DOTCOM"  |"REWARDS_ACCOUNT"  |"FREQ_1"  |"FREQ_2"  |"FREQ_3"  |"FREQ_4"  |"FREQ_5"  |"FREQ_6"  |"FREQ_7"  |"FREQ_8"  |"FREQ_9"  |"FREQ_10"  |"FREQ_11"  |"FREQ_12"  |"CNT_PER_PDT"  |"CNT_PER_PDT_SFC"  |"CNT_PER_PDT_VFC"  |"NO_DISCOUNT"  |"DISCOUNT_PROMOTION"  |"LABEL"  |
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|6.0      

### 7. Define and build parameter dataframe with Epochs and Batch_Size

In [52]:
param_grid = {
    "EPOCH": [x-(x-1) for x in range(1,201)],
    "BATCH_SIZE": [100],
}

In [53]:
dfs = []
for k, v in param_grid.items():
    df = pd.DataFrame(v, columns=[k])
    dfs.append(df)

df = reduce(lambda left, right: pd.merge(left, right, how="cross"), dfs)
params_df = staples_tf_session.createDataFrame(df)

params_df = params_df.select(
    "*", row_number().over(Window.order_by(lit(1))).as_("EPOCH#")
)

In [54]:
params_df.show()

-------------------------------------
|"EPOCH"  |"BATCH_SIZE"  |"EPOCH#"  |
-------------------------------------
|1        |100           |1         |
|1        |100           |2         |
|1        |100           |3         |
|1        |100           |4         |
|1        |100           |5         |
|1        |100           |6         |
|1        |100           |7         |
|1        |100           |8         |
|1        |100           |9         |
|1        |100           |10        |
-------------------------------------



### 8. Combine data and parameter dataframes (Cross Join)

In [55]:
df = params_df.crossJoin(features_df)

df.show()

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"EPOCH"  |"BATCH_SIZE"  |"EPOCH#"  |"RECENCY_DAY"  |"FREQUENCY"  |"MONETORY"  |"RMF_SCORE"  |"DOTCOM"  |"REWARDS_ACCOUNT"  |"FREQ_1"  |"FREQ_2"  |"FREQ_3"  |"FREQ_4"  |"FREQ_5"  |"FREQ_6"  |"FREQ_7"  |"FREQ_8"  |"FREQ_9"  |"FREQ_10"  |"FREQ_11"  |"FREQ_12"  |"CNT_PER_PDT"  |"CNT_PER_PDT_SFC"  |"CNT_PER_PDT_VFC"  |"NO_DISCOUNT"  |"DISCOUNT_PROMOTION"  |"LABEL"  |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

### Use snowpark optimized warehouse

In [43]:
staples_tf_session.sql("use warehouse HMWH").collect()

[Row(status='Statement executed successfully.')]

### 9. Train using the UDTF defined in step 5

In [56]:
tf_TUNING = table_function("hyperparameter_staples_tf_tuning")
tensorflow_training = df.select(
    df["EPOCH#"],
    (
        tf_TUNING(
            df["EPOCH"],
            df["BATCH_SIZE"],
            df["RECENCY_DAY"],
            df["FREQUENCY"],
            df["MONETORY"],
            df["RMF_SCORE"],
            df["DOTCOM"],
            df["REWARDS_ACCOUNT"],
            df["FREQ_1"],
            df["FREQ_2"],
            df["FREQ_3"],
            df["FREQ_4"],
            df["FREQ_5"],
            df["FREQ_6"],
            df["FREQ_7"],
            df["FREQ_8"],
            df["FREQ_9"],
            df["FREQ_10"],
            df["FREQ_11"],
            df["FREQ_12"],
            df["CNT_PER_PDT"],
            df["CNT_PER_PDT_SFC"],
            df["CNT_PER_PDT_VFC"],
            df["NO_DISCOUNT"],
            df["DISCOUNT_PROMOTION"],
            df["LABEL"]
        ).over(partition_by=df["EPOCH#"])
    )
).sort(col('Accuracy').desc())
tensorflow_training.show()

--------------------------------------------------------------------------------------
|"EPOCH#"  |"ACCURACY"          |"CHECKPOINT"                                        |
--------------------------------------------------------------------------------------
|13        |0.7091280221939087  |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|32        |0.7064279913902283  |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|87        |0.7064080238342285  |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|78        |0.7058680057525635  |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|72        |0.705839991569519   |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|86        |0.7058159708976746  |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|25        |0.7057880163192749  |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|24        |0.7057480216026306  |<keras.callbacks.ModelCheckpoint object at 0xff...  |
|99        |0.7055479884147644  |<keras.cal

In [59]:
staples_tf_session.close()
print('Finished!!!')

Finished!!!
