# Scikit-Learn ML with Snowpark: Train Locally via `to_pandas()`, Predict with Warehouse Compute

# Create Session

In [5]:
# Open the Snowpark session used by the later data-loading, registry and
# inference cells. Connection details come from the project-local `config`
# module (not shown in this notebook).
from config import get_snowpark_session

session = get_snowpark_session()

session_id: 14413643953
version: 1.36.0
database: "_DEV_ANALYTICS"
schema: "ASTAUS"
user: "astaus"


# Load Data

In [6]:
# Lazily reference the transactions table and keep only the two feature
# columns plus the regression target; show() executes the query and prints
# a preview of the result.
df = session.table(
    ["_dev_analytics", "transaction_db__astaus", "transactions"]
).select("sales_channel", "transaction_revenue", "transaction_margin")

df.show()


------------------------------------------------------------------
|"SALES_CHANNEL"  |"TRANSACTION_REVENUE"  |"TRANSACTION_MARGIN"  |
------------------------------------------------------------------
|web              |500.00                 |50.00                 |
|web              |500.00                 |50.00                 |
|web              |500.00                 |50.00                 |
|web              |500.00                 |50.00                 |
|store            |500.00                 |50.00                 |
|store            |500.00                 |50.00                 |
|store            |500.00                 |50.00                 |
|store            |500.00                 |50.00                 |
------------------------------------------------------------------



# Train Model

In [7]:
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

In [None]:
# Execute the Snowpark query and bring the result into client memory so
# scikit-learn can fit on it.
pd_df = df.to_pandas()

# Features: every column except the target.
# Target: kept as a one-column DataFrame (list selector), not a Series.
X = pd_df.drop(columns="TRANSACTION_MARGIN")
y = pd_df.loc[:, ["TRANSACTION_MARGIN"]]

In [9]:
# Build and fit the margin model:
#   SALES_CHANNEL       -> one-hot encoded
#   TRANSACTION_REVENUE -> standardised (zero mean, unit variance)
# followed by an ordinary least-squares LinearRegression on y (TRANSACTION_MARGIN).
#
# handle_unknown="ignore": this pipeline is registered and later scored over the
# whole transactions table inside the warehouse. With the encoder's default
# ("error"), any SALES_CHANNEL value not present at fit time (a new channel, or
# presumably a NULL) raises inside predict() and would likely fail the entire
# inference query. With "ignore" such rows get an all-zero one-hot vector and
# are still scored.
preprocessor = ColumnTransformer(
    transformers=[
        ("onehot", OneHotEncoder(handle_unknown="ignore"), ["SALES_CHANNEL"]),
        ("scale", StandardScaler(), ["TRANSACTION_REVENUE"]),
    ]
)

pipe = Pipeline(steps=[
    ("preprocess", preprocessor),
    ("linreg", LinearRegression()),
])

pipe.fit(X, y) # type: ignore

In [10]:
# In-sample R²: scored on the same X, y the pipeline was fitted on, so it does
# not measure generalisation to unseen transactions.
pipe.score(X=X, y=y) # type: ignore

1.0

# Register Model

In [11]:
import pathlib

from snowflake.ml.model.task import Task
from snowflake.ml.registry import Registry

# NOTE(review): rebinding PosixPath presumably works around model logging
# trying to instantiate `pathlib.PosixPath`, which raises "cannot instantiate
# 'PosixPath' on your system" on Windows — TODO confirm.
# The rebinding is kernel-wide, so it is applied only where PosixPath cannot be
# instantiated natively (os.name == "nt"). On POSIX systems the real PosixPath
# already works, and swapping in the pure class would strip filesystem methods
# (exists(), open(), mkdir(), ...) from every later user of pathlib.
import os

if os.name == "nt":
    pathlib.PosixPath = pathlib.PurePosixPath

In [12]:
# Model Registry handle scoped to the same database/schema that holds the
# transactions table read earlier in this notebook.
registry = Registry(
    schema_name="transaction_db__astaus",
    database_name="_dev_analytics",
    session=session,
)

In [13]:
# Register the fitted pipeline as version "v1" of sk_margin_prediction.
# The sample input is the Snowpark frame without the target column; the
# registry uses it to infer the model's input signature.
# relax_version=True lets the registry relax the pinned dependency versions.
# NOTE(review): a re-run will likely fail once version "v1" already exists —
# bump version_name or delete the old version before "Run All".
model_ref = registry.log_model(
    pipe,
    model_name="sk_margin_prediction",
    version_name="v1",
    task=Task.TABULAR_REGRESSION,
    comment="Scikit-Model for predicting transaction margin.",
    sample_input_data=df.drop("transaction_margin"),
    metrics={},
    options={"relax_version": True},
)

Logging model: creating model manifest...:  33%|███▎      | 2/6 [00:02<00:05,  1.37s/it]  

  core.DataType.from_snowpark_type(data_type)


Model logged successfully.: 100%|██████████| 6/6 [00:14<00:00,  2.44s/it]                          


# Inference

In [14]:
# Score every row of the transactions table inside the warehouse. The SQL calls
# the registered model's predict method via `mv!predict(...)`, takes its
# 'output_feature_0' field and casts it to NUMBER(12,2) as
# transaction_margin_pred.
inference_sql = """ --begin-sql

    with mv as model _dev_analytics.transaction_db__astaus.sk_margin_prediction
    select
        transaction_id,
        product_id,
        sales_channel,
        transaction_revenue,
        mv!predict(
            sales_channel,
            transaction_revenue
        )['output_feature_0']::number(12,2) transaction_margin_pred
    from _dev_analytics.transaction_db__astaus.transactions
    ;

"""
results = session.sql(inference_sql)

results.show()

---------------------------------------------------------------------------------------------------------
|"TRANSACTION_ID"  |"PRODUCT_ID"  |"SALES_CHANNEL"  |"TRANSACTION_REVENUE"  |"TRANSACTION_MARGIN_PRED"  |
---------------------------------------------------------------------------------------------------------
|78654             |1556434       |store            |500.00                 |50.00                      |
|12345             |1556434       |web              |500.00                 |50.00                      |
|12345             |1556434       |web              |500.00                 |50.00                      |
|56789             |1556434       |web              |500.00                 |50.00                      |
|56789             |1556434       |web              |500.00                 |50.00                      |
|99999             |1556434       |store            |500.00                 |50.00                      |
|88888             |1556434       |store      