# Snowpark AST Decoder Demo

In [None]:
from snowflake.snowpark import Session
from snowflake.snowpark._internal.utils import set_transmit_query_to_server
from snowflake.snowpark._internal.utils import set_ast_state, AstFlagSource
import logging
from snowflake.snowpark.functions import avg, count, max, min, udaf, udtf
from snowflake.snowpark.types import DoubleType, FloatType, IntegerType, StringType, StructField, StructType
import base64
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

# For displaying images.
from IPython.display import Image
from IPython.core.display import HTML 

# Connection settings for the cloud workspace that runs the server under test.
# NOTE(review): these are hard-coded local regression-test credentials; never
# commit real account credentials this way.
CONNECTION_PARAMETERS = dict(
    account="s3testaccount",
    host="snowflake.reg.local",
    user="snowman",
    password="test",
    role="sysadmin",
    warehouse="regress",
    database="testdb",
    schema="public",
    port="53200",
    protocol="http",
)

In [None]:
# Create the session (or reuse an existing one) with the parameters above.
# local_testing=False sends queries to the real server rather than using
# Snowpark's local testing mode.
session = (
    Session.builder.configs(CONNECTION_PARAMETERS)
    .config("local_testing", False)
    .getOrCreate()
)

In [None]:
# Make the Snowpark logger emit INFO-level records (use logging.DEBUG for more
# detail), and give the root logger a handler at INFO so those records are shown.
logger = logging.getLogger("snowflake.snowpark")
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)

## Enabling Required Parameters

In [None]:
# ENABLE_DATAFRAME should be set to True at the account level:
# this parameter controls whether the server executes the SQL query or the AST.
session.sql("show parameters like 'ENABLE_DATAFRAME' in account").show()

In [None]:
# The DataFrame processor relies on a Snowflake Notebook for execution,
# so list the available notebooks:
session.sql("show notebooks;").show()

In [None]:
# The decoder logic uses Python 3.10+ features, so make sure the notebook
# engine version provides Python 3.10 or newer.
session.sql("show parameters like 'NOTEBOOK_ENGINE_VERSION'").show()

In [None]:
# ASTs are only recorded while the AST flag is enabled.
AST_ENABLED = True
set_ast_state(AstFlagSource.TEST, AST_ENABLED)

In [None]:
# This flag setter controls whether the actual Snowpark query or a placeholder
# query is sent to the server. With False, the server receives
# "SELECT 'This is a fake query!!';" instead of the real SQL.
set_transmit_query_to_server(False)

## Testing a Basic Snowpark Query

In [None]:
# Run a basic Snowpark query while an AST listener records the ASTs it produces.
with session.ast_listener() as al:
    result = session.create_dataframe([1, 2, 3, 4]).collect()
    print(result)

In [None]:
# The recorded AST batches, base64-encoded:
print(al.base64_batches)

In [None]:
# In plaintext: decode the first recorded batch into a proto.Request message.
# As the cell's last expression, the message is displayed in readable form.
message = proto.Request()
message.ParseFromString(base64.b64decode(al.base64_batches[0]))
message

## Testing a More Complex Example

In [None]:
# Create a table for diamonds.
# session.sql() only builds a lazy DataFrame, so .collect() is required for the
# DDL to be sent at all. Column names are plain identifiers: single quotes
# would make them string literals, which is invalid DDL. TABLE is a reserved
# keyword, so that column name is double-quoted.
# NOTE(review): query transmission is disabled at this point in the notebook,
# so whether the table is really created depends on the server executing the
# AST — confirm.
session.sql("""
    CREATE OR REPLACE TABLE diamonds (
        id INTEGER,
        carat FLOAT,
        cut STRING,
        color STRING,
        clarity STRING,
        depth FLOAT,
        "TABLE" INTEGER,
        price INTEGER,
        x FLOAT,
        y FLOAT,
        z FLOAT
    );
""").collect()

In [None]:
# Sample diamond rows. Each row lists its values in the column order of the
# schema defined below: id, carat, cut, color, clarity, depth, table, price, x, y, z.
data = [
    [1, 0.23, "Ideal", "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43],
    [2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31],
    [3, 0.23, "Good", "E", "VS1", 56.9, 65, 327, 4.05, 4.07, 2.31],
    [4, 0.29, "Premium", "I", "VS2", 62.4, 58, 334, 4.2, 4.23, 2.63],
    [5, 0.31, "Good", "J", "SI2", 63.3, 58, 335, 4.34, 4.35, 2.75],
    [6, 0.24, "Very Good", "J", "VVS2", 62.8, 57, 336, 3.94, 3.96, 2.48],
    [7, 0.24, "Very Good", "I", "VVS1", 62.3, 57, 336, 3.95, 3.98, 2.47],
    [8, 0.26, "Very Good", "H", "SI1", 61.9, 55, 337, 4.07, 4.11, 2.53],
    [9, 0.22, "Fair", "E", "VS2", 65.1, 61, 337, 3.87, 3.78, 2.49],
    [10, 0.23, "Very Good", "H", "VS1", 59.4, 61, 338, 4.00, 4.05, 2.39],
    [11, 0.3, "Good", "J", "SI1", 64.2, 55, 339, 4.25, 4.28, 2.73],
    [12, 0.23, "Ideal", "J", "VS1", 62.8, 56, 340, 3.93, 3.9, 2.46],
    [13, 0.22, "Premium", "F", "SI1", 60.4, 61, 342, 3.88, 3.84, 2.33],
    [14, 0.31, "Ideal", "J", "SI2", 62.2, 54, 344, 4.35, 4.37, 2.71],
    [15, 0.2, "Premium", "E", "SI2", 60.2, 62, 345, 3.79, 3.75, 2.27],
    [16, 0.32, "Premium", "E", "I1", 60.9, 58, 345, 4.38, 4.42, 2.68],
    [17, 0.3, "Ideal", "I", "SI2", 62.5, 54, 348, 4.31, 4.34, 2.68],
    [18, 0.3, "Good", "J", "SI1", 63.4, 54, 351, 4.23, 4.29, 2.7],
    [19, 0.3, "Good", "J", "SI1", 63.8, 56, 351, 4.23, 4.26, 2.71],
    [20, 0.3, "Very Good", "J", "SI1", 62.7, 59, 351, 4.21, 4.27, 2.66],
    [21, 0.3, "Good", "I", "SI2", 63.3, 56, 351, 4.26, 4.3, 2.71],
    [22, 0.23, "Very Good", "E", "VS2", 63.8, 55, 352, 3.85, 3.92, 2.48],
    [23, 0.23, "Very Good", "H", "VS1", 61.0, 57, 353, 3.94, 3.96, 2.41],
    [24, 0.31, "Very Good", "J", "SI1", 59.4, 62, 353, 4.39, 4.43, 2.62],
    [25, 0.31, "Very Good", "J", "SI1", 58.1, 62, 353, 4.44, 4.47, 2.59],
    [26, 0.23, "Very Good", "G", "VVS2", 60.4, 58, 354, 3.97, 4.01, 2.41],
    [27, 0.24, "Premium", "I", "VS1", 62.5, 57, 355, 3.97, 3.94, 2.47],
    [28, 0.3, "Very Good", "J", "VS2", 62.2, 57, 357, 4.28, 4.3, 2.67],
    [29, 0.23, "Very Good", "D", "VS2", 60.5, 61, 357, 3.96, 3.97, 2.4]
]

# Table schema: one (column name, Snowpark type) pair for each position in a
# row of ``data``, in the same order.
schema = StructType(
    [
        StructField(column_name, column_type)
        for column_name, column_type in [
            ("id", IntegerType()),
            ("carat", DoubleType()),
            ("cut", StringType()),
            ("color", StringType()),
            ("clarity", StringType()),
            ("depth", DoubleType()),
            ("table", IntegerType()),
            ("price", IntegerType()),
            ("x", DoubleType()),
            ("y", DoubleType()),
            ("z", DoubleType()),
        ]
    ]
)

In [None]:
# Create a DataFrame from the sample rows and fetch the first five.
df = session.create_dataframe(data, schema=schema)
df.limit(5).collect()

In [None]:
# Write the DataFrame to table t_diamonds, replacing the table if it already exists.
df.write.save_as_table("t_diamonds", mode="overwrite")

## Testing Some Operations

In [None]:
# Count of diamonds per cut, with an AST listener recording the query's ASTs.
with session.ast_listener() as al:
    df_count_per_cut = df.group_by("cut").agg(count("*").alias("number of diamonds"))
    print(df_count_per_cut.collect())

In [None]:
# The recorded AST batches, base64-encoded:
print(al.base64_batches)

In [None]:
# In plaintext: decode the first recorded batch into a proto.Request message.
# As the cell's last expression, the message is displayed in readable form.
message = proto.Request()
message.ParseFromString(base64.b64decode(al.base64_batches[0]))
message

In [None]:
# Average, maximum, and minimum price per cut, sorted by average price (highest first).
# NOTE: max/min here are the Snowpark column functions imported at the top of the
# notebook; that import shadows Python's builtin max/min in this namespace.
df_avg_price_per_cut = df.group_by("cut").agg(
    avg("price").alias("avg_price"), 
    max("price").alias("max_price"), 
    min("price").alias("min_price")
)
df_avg_price_per_cut.sort("avg_price", ascending=False).collect()

In [None]:
# Illustration of diamond cut.
Image(url="https://assets.vrai.com/25216/1692052168-diamond-cut-1-3.jpg")
# source: https://www.vrai.com/journal/post/diamond-cut

In [None]:
# Depth and table are one way to determine the quality of a diamond.
# Compute the average depth, average table, and average depth/table ratio per
# cut, sorted by that ratio (highest first).
df = session.table("t_diamonds")
df_aggregated = df.group_by("cut").agg(
    avg("depth").alias("avg_depth"),
    avg("table").alias("avg_table"),
    avg(df["depth"] / df["table"]).alias("avg_depth_table_ratio")
)
df_aggregated.sort("avg_depth_table_ratio", ascending=False).collect()

## Performing a Join

In [None]:
# Diamond cut chart.
Image(url="https://www.brilliance.com/front/img/brilliance-diamond-cut-chart.jpg")
# source: https://www.brilliance.com/education/diamonds/cut

In [None]:
# One (cut, description) row per cut grade used in the diamonds data.
cut_info_data = [
    ("Ideal", "Highest quality cut"),
    ("Premium", "High quality cut, but not as much as Ideal"),
    ("Very Good", "Generally high quality cut"),
    ("Good", "Average cut with some flaws"),
    ("Fair", "Below average cut with noticeable flaws")
]

# Schema for the cut descriptions: two string columns.
cut_info_schema = StructType([
    StructField("cut", StringType()),
    StructField("description", StringType())
])

# Create a DataFrame with cut information.
df_cut_info = session.create_dataframe(cut_info_data, schema=cut_info_schema)
df_cut_info.collect()

In [None]:
# Write the data to table cut_info, replacing the table if it already exists.
df_cut_info.write.save_as_table("cut_info", mode="overwrite")

In [None]:
# Load the cut_info table into a DataFrame and fetch two rows.
df_cut_info = session.table("cut_info")
df_cut_info.limit(2).collect()

In [None]:
# Join the diamonds DataFrame (df, loaded from t_diamonds above) with the
# cut_info DataFrame. Both sides have a "cut" column, so every column is
# referenced through its source DataFrame to keep it unambiguous.
df_joined = df.join(df_cut_info, df["cut"] == df_cut_info["cut"], how="inner")

# Select the columns to display.
df_joined = df_joined.select(df["id"], df["carat"], df["cut"], df_cut_info["description"])

# Show the results.
df_joined.collect()

## Testing a UDF

In [None]:
# Testing a simple UDF.
def calculate_diamond_volume(x: float, y: float, z: float) -> float:
    """Return ``x * y * z``, the volume of a box with the diamond's three measured dimensions."""
    volume = x * y * z
    return volume

# Register the UDF. Query transmission is switched on only for this call. The
# finally block switches it back off even if registration raises, so later cells
# are not silently left sending real queries.
set_transmit_query_to_server(True)
try:
    calculate_diamond_volume_udf = session.udf.register(
        func=calculate_diamond_volume,
        return_type=FloatType(),
        input_types=[FloatType(), FloatType(), FloatType()],
        name="calculate_diamond_volume",
        is_permanent=False,
        replace=True
    )
finally:
    set_transmit_query_to_server(False)

# Use the UDF to calculate the volume of diamonds in table t_diamonds.
# No action runs in this cell; the next cell collects the results.
df_diamonds = session.table("t_diamonds")
df_volumes = df_diamonds.with_column(
    "volume",
    calculate_diamond_volume_udf(
        df_diamonds["x"], 
        df_diamonds["y"], 
        df_diamonds["z"]
    )
)

In [None]:
# Show id, carat, and computed volume for every diamond, largest volume first.
df_volumes.select("id", "carat", "volume").order_by("volume", ascending=False).collect()

## Testing a UDAF

In [None]:
# Define the UDAF class.
class AveragePricePerCarat:
    """UDAF handler computing total price divided by total carat weight.

    The aggregation state is the running ``(total_price, total_carat)`` pair,
    which ``merge`` combines across partial aggregates. Rows with a missing
    price or carat (``None``), or with a non-positive carat, are ignored.
    """

    def __init__(self):
        # Running sums over the accepted rows.
        self.total_price = 0.0
        self.total_carat = 0.0

    @property
    def aggregate_state(self):
        """Return the partial state ``(total_price, total_carat)`` handed to ``merge``."""
        return self.total_price, self.total_carat

    def accumulate(self, price, carat):
        """Add one row to the running sums.

        SQL NULL presumably reaches the handler as ``None``. Such rows are
        skipped, because ``None > 0`` or ``float += None`` would raise
        ``TypeError`` and fail the whole aggregation.
        """
        if price is None or carat is None:
            return
        if carat > 0:
            self.total_price += price
            self.total_carat += carat

    def merge(self, other):
        """Fold another instance's ``aggregate_state`` tuple into this one."""
        self.total_price += other[0]
        self.total_carat += other[1]

    def finish(self):
        """Return price per carat, or ``0.0`` when no positive carat weight was accumulated."""
        return self.total_price / self.total_carat if self.total_carat > 0 else 0.0

# Register the UDAF. Query transmission is switched on only for this call. The
# finally block switches it back off even if registration raises, so later cells
# are not silently left sending real queries.
set_transmit_query_to_server(True)
try:
    avg_price_per_carat_udaf = udaf(AveragePricePerCarat, return_type=FloatType(), input_types=[FloatType(), FloatType()])
finally:
    set_transmit_query_to_server(False)

In [None]:
# Use the UDAF on the diamonds DataFrame: price per carat for each cut.
df = session.table("t_diamonds")
df_avg_price_per_carat = df.group_by(df["cut"]).agg(avg_price_per_carat_udaf(df["price"], df["carat"]).alias("avg_price_per_carat"))

# Show the results, highest price per carat first.
df_avg_price_per_carat.order_by("avg_price_per_carat", ascending=False).collect()

## Testing a UDTF

In [None]:
# The cells below define and run a UDTF that describes a diamond's color and clarity grades.

In [None]:
# Diamond color grading scale (D to Z).
Image(url="https://4cs.gia.edu/wp-content/uploads/2024/07/02_Color-D-Z-Scale_960x800.jpg")
# source: https://4cs.gia.edu/en-us/diamond-color/

In [None]:
# Diamond clarity grading scale.
Image(url="https://lisarobinjewelry.com/cdn/shop/files/Diamond_Clarity_Scale_Graphic_1600x.jpg?v=1686052968")
# source: https://lisarobinjewelry.com/pages/what-is-diamond-clarity

In [None]:
class ColorClarityDetails:
    """UDTF handler that describes a diamond's color and clarity grades.

    For each input row ``(color, clarity)`` it emits one output row
    ``(color, color_description, clarity, clarity_description)``. Codes that
    are not on the grading scales map to ``"Unknown"``.
    """

    # Color grade -> description. The lookup tables are built once per class
    # instead of on every call. The color scale runs from D to Z, so grades
    # after J are mapped too rather than coming back as "Unknown".
    _COLOR_DESCRIPTIONS = {
        **dict.fromkeys("DEF", "Colorless"),
        **dict.fromkeys("GHIJ", "Near Colorless"),
        **dict.fromkeys("KLM", "Faint"),
        **dict.fromkeys("NOPQR", "Very Light"),
        **dict.fromkeys("STUVWXYZ", "Light"),
    }

    # Clarity grade -> description, including the top grade "FL" (Flawless).
    _CLARITY_DESCRIPTIONS = {
        "FL": "Flawless",
        "IF": "Internally Flawless",
        "VVS1": "Very, Very Slightly Included 1",
        "VVS2": "Very, Very Slightly Included 2",
        "VS1": "Very Slightly Included 1",
        "VS2": "Very Slightly Included 2",
        "SI1": "Slightly Included 1",
        "SI2": "Slightly Included 2",
        "I1": "Included 1",
        "I2": "Included 2",
        "I3": "Included 3",
    }

    def process(self, color, clarity):
        """Yield a single row pairing each grade with its description."""
        color_desc = self.get_color_description(color)
        clarity_desc = self.get_clarity_description(clarity)
        yield (color, color_desc, clarity, clarity_desc)

    def get_color_description(self, color):
        """Return the description for a D-Z ``color`` grade, or ``"Unknown"``."""
        return self._COLOR_DESCRIPTIONS.get(color, "Unknown")

    def get_clarity_description(self, clarity):
        """Return the description for a ``clarity`` grade, or ``"Unknown"``."""
        return self._CLARITY_DESCRIPTIONS.get(clarity, "Unknown")

# Schema of the UDTF's output table: four string columns, in the same order as
# the values yielded by ColorClarityDetails.process.
output_schema = StructType(
    [
        StructField(column_name, StringType())
        for column_name in ("color", "color_description", "clarity", "clarity_description")
    ]
)

# Register the UDTF. Query transmission is switched on only for this call. The
# finally block switches it back off even if registration raises, so later cells
# are not silently left sending real queries.
set_transmit_query_to_server(True)
try:
    color_clarity_udtf = udtf(ColorClarityDetails, output_schema=output_schema, input_types=[StringType(), StringType()])
finally:
    set_transmit_query_to_server(False)

In [None]:
# Run the UDTF on each diamond's color and clarity and fetch the first 10 result rows.
df = session.table("t_diamonds")
df.select(color_clarity_udtf(df["color"], df["clarity"])).limit(10).collect()

In [None]:
# Same UDTF logic, but run inside an AST listener so the recorded ASTs can be inspected below.
with session.ast_listener() as al:
    # Register the UDTF. Query transmission is switched on only for this call;
    # the finally block switches it back off even if registration raises.
    set_transmit_query_to_server(True)
    try:
        color_clarity_udtf = udtf(ColorClarityDetails, output_schema=output_schema, input_types=[StringType(), StringType()])
    finally:
        set_transmit_query_to_server(False)

    # Calling the UDTF.
    df = session.table("t_diamonds")
    df.select(color_clarity_udtf(df["color"], df["clarity"])).limit(10).collect()
print("Number of AST messages recorded: ", len(al.base64_batches))

In [None]:
# Print the recorded AST batches (base64-encoded):
print(al.base64_batches)

In [None]:
# In plaintext: decode the first recorded batch into a proto.Request message.
# Only batch 0 is decoded here; the count printed by the listener cell shows how
# many batches were recorded in total.
message = proto.Request()
message.ParseFromString(base64.b64decode(al.base64_batches[0]))
message