In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
# Grab the currently active Snowpark session; later cells use this handle to
# read tables and to issue SQL statements.
session = get_active_session()


In [None]:
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.modeling.preprocessing import KBinsDiscretizer, OneHotEncoder
from snowflake.ml.modeling.impute import SimpleImputer

from snowflake.ml.modeling.compose import ColumnTransformer
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.preprocessing import StandardScaler, OrdinalEncoder
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.xgboost import XGBRegressor

import snowflake.snowpark.functions as F

In [None]:
#!pip install xgboost==2.0.3

from time import time
import random

In [None]:
# Columns removed from the feature table before any further processing.
columns_to_drop = [
    'CA_ZIP',
    'CUSTOMER_SK',
    'C_CURRENT_HDEMO_SK',
    'C_CURRENT_ADDR_SK',
    'C_CUSTOMER_ID',
    'CA_ADDRESS_SK',
    'CD_DEMO_SK',
]

# Read the feature table and discard those columns in one chained expression.
snowdf = session.table("tpcds_xgboost.demo.feature_store").drop(columns_to_drop)

# Print a preview of the resulting DataFrame.
snowdf.show()

In [None]:
## Drop rows containing NULLs, then normalise the categorical values
from snowflake.snowpark.functions import col, is_null

# Build a single predicate that is true when ANY column of a row is NULL.
# Only NULLs are detected here; nothing in this predicate tests for NaN or
# infinite values, despite the historical variable name.
non_finite_filter = None
for column in snowdf.columns:
    current_filter = is_null(col(column))
    non_finite_filter = current_filter if non_finite_filter is None else (non_finite_filter | current_filter)

# Keep only the rows for which the predicate is false. If no predicate could be
# built (the DataFrame has no columns) there is nothing to filter on, and
# negating None would raise, so the DataFrame is passed through unchanged.
df_filtered = snowdf if non_finite_filter is None else snowdf.filter(~non_finite_filter)


## Clean up cats
def fix_values(columnn):
    """Build a cleaned-up, upper-cased expression for a column.

    Every run of one or more characters that are not ASCII letters or digits
    is replaced by a single underscore, and the result is upper-cased.

    Args:
        columnn: Name of the column to clean.

    Returns:
        The column expression produced by ``F.upper(F.regexp_replace(...))``.
    """
    return F.upper(F.regexp_replace(F.col(columnn), '[^a-zA-Z0-9]+', '_'))


categorical_cols = ['CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS']

# The loop variable is deliberately NOT named `col`: that would rebind the
# `col` function imported above to a plain string for the rest of the notebook.
for cat_col in categorical_cols:
    df_filtered = df_filtered.with_column(cat_col, fix_values(cat_col))
    

In [None]:
# Label column; every other column of the cleaned DataFrame is kept as a feature name.
target_col = 'TOTAL_SALES'
feature_cols = df_filtered.columns
feature_cols.remove(target_col)

# 80/20 random split with a fixed seed.
train_part, test_part = df_filtered.random_split([0.8, 0.2], seed=82)
snowdf_test = test_part

# Cap the training set at 1,000 rows for this first fit, then show its size.
snowdf_train = train_part.limit(1_000)
snowdf_train.count()

In [None]:
 ## Distributed Preprocessing - 25X to 50X faster

# Numeric inputs: missing values are filled with the column median.
numeric_features = ['C_BIRTH_YEAR', 'CD_DEP_COUNT']
numeric_steps = [('imputer', SimpleImputer(strategy='median'))]
numeric_transformer = Pipeline(steps=numeric_steps)

# Categorical inputs: missing values are filled with the most frequent value,
# then the column is one-hot encoded.
categorical_cols = ['CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS']
categorical_steps = [
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder()),
]
categorical_transformer = Pipeline(steps=categorical_steps)

# Send each group of columns through its own transformer.
column_routes = [
    ('num', numeric_transformer, numeric_features),
    ('cat', categorical_transformer, categorical_cols),
]
preprocessor = ColumnTransformer(transformers=column_routes)

# Full pipeline: preprocessing followed by an XGBoost regressor with default settings.
model_steps = [('preprocessor', preprocessor), ('model', XGBRegressor())]
pipeline = Pipeline(steps=model_steps)

In [None]:
 ## Distributed HyperParameter Optimization
# Search grid. Keys carry the 'model__' prefix, matching the name of the
# pipeline step that holds the regressor.
hyper_param = {
    'model__max_depth': [2, 4],
    'model__learning_rate': [0.1, 0.3],
}

# Columns handed to the search: numeric inputs followed by categorical inputs.
search_inputs = numeric_features + categorical_cols

# `cv` and `verbose` are not passed, so they stay at their defaults (the
# original author noted that `verbose` was not working).
xg_model = GridSearchCV(
    estimator=pipeline,
    param_grid=hyper_param,
    input_cols=search_inputs,
    label_cols=['TOTAL_SALES'],
    output_cols=["TOTAL_SALES_PREDICT"],
)

# Run the grid search on the current training DataFrame.
# Author's note: this took about 25 seconds.
xg_model.fit(snowdf_train)

In [None]:
# session.sql() only builds a lazy DataFrame; nothing is sent to Snowflake
# until an action such as collect() runs. Without it the ALTER SESSION is never
# executed and the session keeps reusing cached query results.
session.sql('ALTER SESSION SET USE_CACHED_RESULT=FALSE').collect()

In [None]:
# Training-set row caps to benchmark.
lengths = [1_000_000,5_000_000,10_000_000,25_000_000,50_000_000]
# Alternative sweep that skips the two smallest sizes:
#lengths = [10_000_000,25_000_000,50_000_000]

# Seed Python's RNG so the sequence of per-run split seeds drawn below is the
# same on every execution of this cell.
random.seed(9001)

for length in lengths:
    # Draw a different split seed for each size (the initial fit used seed=82).
    seedv = random.randint(1, 1000)
    snowdf_train, snowdf_test = df_filtered.random_split([0.8, 0.2], seed=seedv)
    snowdf_train = snowdf_train.limit(length)

    # limit() is only an upper bound, so count once and report the real number
    # of rows alongside the requested cap.
    actual_rows = snowdf_train.count()
    print(actual_rows)

    # Time the fit only; the count above is excluded from the measurement.
    init = time()
    xg_model.fit(snowdf_train)
    total_time = (time() - init) / 60  # seconds -> minutes
    print(f'total rows: {length} total time: {total_time} seed: {seedv} actual rows: {actual_rows}')

    # NOTE(review): the training handle is rebound to a tiny scratch DataFrame
    # after every run. The reason is not visible here (possibly to release the
    # large DataFrame) -- confirm before removing. As a consequence
    # `snowdf_train` no longer refers to training data once the loop ends.
    snowdf_train = session.create_dataframe([1, 2, 3, 4]).to_df("a")
    snowdf_train.show()