# Snowpark AST Decoder Demo

In [1]:
from snowflake.snowpark import Session
from snowflake.snowpark._internal.utils import set_transmit_query_to_server
from snowflake.snowpark._internal.utils import set_ast_state, AstFlagSource
import logging
from snowflake.snowpark.functions import avg, count, max, min, udaf, udtf
from snowflake.snowpark.types import DoubleType, FloatType, IntegerType, StringType, StructField, StructType
import base64
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

# For displaying images.
from IPython.display import Image
from IPython.core.display import HTML 

# Connecting to my cloud workspace running the server.
CONNECTION_PARAMETERS = {
    "account": "s3testaccount",
    "host": "snowflake.reg.local",
    "user": "snowman",
    "password": "test",
    "role": "sysadmin",
    "warehouse": "regress",
    "database": "testdb",
    "schema": "public",
    "port": "53200",
    "protocol": "http",
}

In [2]:
session = (
    Session.builder.configs(CONNECTION_PARAMETERS)
    .config("local_testing", False)
    .getOrCreate()
)

In [3]:
# Configure logging.
logging.basicConfig(level=logging.INFO)  # Set to INFO or DEBUG for more details
logger = logging.getLogger("snowflake.snowpark")

# Verify logging level.
logger.setLevel(logging.INFO)

## Enabling Required Parameters

In [4]:
# ENABLE_DATAFRAME should be set to True in the account level:
# this parameter controls whether to use the SQL query or AST on the server-side.
session.sql("show parameters like 'ENABLE_DATAFRAME' in account").show()

INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-0014000112e6] show parameters like 'ENABLE_DATAFRAME' in account


-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"key"             |"value"  |"default"  |"level"  |"description"              |"type"   |"set_by_user"  |"set_in_job"                          |"set_on"                         |"set_by_thread_id"  |"set_by_thread_name"  |"set_by_class"  |"parameter_comment"  |"set_by_controlling_parameter"  |"activate_version"  |"partial_rollout"  |"jira_reference"  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [5]:
# The Dataframe Processor relies on a Snowflake Notebook for execution:
session.sql("show notebooks;").show()

INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-0014000112ea] show notebooks;


--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"created_on"                      |"name"               |"database_name"  |"schema_name"  |"comment"     |"owner"   |"query_warehouse"  |"url_id"              |"owner_role_type"  |"code_warehouse"  |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|2025-02-19 11:19:07.150000-08:00  |DATAFRAME_PROCESSOR  |TESTDB           |PUBLIC         |ready to use  |SYSADMIN  |REGRESS            |6vtrggcnvhw6iwcnvbow  |ROLE               |REGRESS           |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [6]:
# Since the decoder logic uses Python 3.10+ features, ensure that the notebook engine version uses Python 3.10+.
session.sql("show parameters like 'NOTEBOOK_ENGINE_VERSION'").show()

INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-0014000112ee] show parameters like 'NOTEBOOK_ENGINE_VERSION'


----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"key"                    |"value"                   |"default"                |"level"  |"description"                                       |"type"  |"set_by_user"  |"set_in_job"                          |"set_on"                         |"set_by_thread_id"  |"set_by_thread_name"  |"set_by_class"  |"parameter_comment"                                 |"set_by_controlling_parameter"  |"activate_version"  |"partial_rollout"  |"jira_reference"  |
------------------------------------------------------------------------------------------------------

In [7]:
# To record the ASTs, we need to enable the AST flag:
AST_ENABLED = True
set_ast_state(AstFlagSource.TEST, AST_ENABLED)

In [8]:
# This helper (flag setter) controls whether to send the actual Snowpark query or a fake query to the server.
# Transmits "SELECT 'This is a fake query!!';"
set_transmit_query_to_server(False)

## Testing a Basic Snowpark Query

In [9]:
# Testing a basic Snowpark query with the AST enabled:
with session.ast_listener() as al:
    result = session.create_dataframe([1, 2, 3, 4]).collect()
    print(result)

INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.connector.cursor:Number of results in first chunk: 4
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-0014000112f6] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(_1=1), Row(_1=2), Row(_1=3), Row(_1=4)]


In [10]:
# The recorded AST:
print(al.base64_batches)

['Cg8KDQj///////////8BEgAScApuCmSKBWEKUgpQChKiAg8KCxj///////////8BEAEKEqICDwoLGP///////////wEQAgoSogIPCgsY////////////ARADChKiAg8KCxj///////////8BEAQaCxj///////////8BEgAYASICCAESJAoiChjiBRUIARABGgIIATILGP///////////wESABgCIgIIAhIIEgYIAxICCAIYASIREg8KDQoFZmluYWwQAxgLIAkqBBABGBs=']


In [11]:
# In plaintext:
message = proto.Request()
message.ParseFromString(base64.b64decode(al.base64_batches[0]))
message

interned_value_table {
  string_values {
    key: -1
    value: ""
  }
}
body {
  assign {
    expr {
      sp_create_dataframe {
        data {
          sp_dataframe_data__list {
            vs {
              int64_val {
                src {
                  file: -1
                }
                v: 1
              }
            }
            vs {
              int64_val {
                src {
                  file: -1
                }
                v: 2
              }
            }
            vs {
              int64_val {
                src {
                  file: -1
                }
                v: 3
              }
            }
            vs {
              int64_val {
                src {
                  file: -1
                }
                v: 4
              }
            }
          }
        }
        src {
          file: -1
        }
      }
    }
    symbol {
    }
    uid: 1
    var_id {
      bitfield1: 1
    }
  }
}
body {
  assign {
    

## Testing a More Complex Example

In [12]:
# Create a table for diamonds!
session.sql("""
    CREATE OR REPLACE TABLE diamonds (
        'id' INTEGER,
        'carat' FLOAT,
        'cut' STRING,
        'color' STRING,
        'clarity' STRING,
        'depth' FLOAT,
        'table' INTEGER,
        'price' INTEGER,
        'x' FLOAT,
        'y' FLOAT,
        'z' FLOAT
    );
""")

<snowflake.snowpark.dataframe.DataFrame at 0x13e600ed0>

In [13]:
# Table data!
data = [
    [1, 0.23, "Ideal", "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43],
    [2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31],
    [3, 0.23, "Good", "E", "VS1", 56.9, 65, 327, 4.05, 4.07, 2.31],
    [4, 0.29, "Premium", "I", "VS2", 62.4, 58, 334, 4.2, 4.23, 2.63],
    [5, 0.31, "Good", "J", "SI2", 63.3, 58, 335, 4.34, 4.35, 2.75],
    [6, 0.24, "Very Good", "J", "VVS2", 62.8, 57, 336, 3.94, 3.96, 2.48],
    [7, 0.24, "Very Good", "I", "VVS1", 62.3, 57, 336, 3.95, 3.98, 2.47],
    [8, 0.26, "Very Good", "H", "SI1", 61.9, 55, 337, 4.07, 4.11, 2.53],
    [9, 0.22, "Fair", "E", "VS2", 65.1, 61, 337, 3.87, 3.78, 2.49],
    [10, 0.23, "Very Good", "H", "VS1", 59.4, 61, 338, 4.00, 4.05, 2.39],
    [11, 0.3, "Good", "J", "SI1", 64.2, 55, 339, 4.25, 4.28, 2.73],
    [12, 0.23, "Ideal", "J", "VS1", 62.8, 56, 340, 3.93, 3.9, 2.46],
    [13, 0.22, "Premium", "F", "SI1", 60.4, 61, 342, 3.88, 3.84, 2.33],
    [14, 0.31, "Ideal", "J", "SI2", 62.2, 54, 344, 4.35, 4.37, 2.71],
    [15, 0.2, "Premium", "E", "SI2", 60.2, 62, 345, 3.79, 3.75, 2.27],
    [16, 0.32, "Premium", "E", "I1", 60.9, 58, 345, 4.38, 4.42, 2.68],
    [17, 0.3, "Ideal", "I", "SI2", 62.5, 54, 348, 4.31, 4.34, 2.68],
    [18, 0.3, "Good", "J", "SI1", 63.4, 54, 351, 4.23, 4.29, 2.7],
    [19, 0.3, "Good", "J", "SI1", 63.8, 56, 351, 4.23, 4.26, 2.71],
    [20, 0.3, "Very Good", "J", "SI1", 62.7, 59, 351, 4.21, 4.27, 2.66],
    [21, 0.3, "Good", "I", "SI2", 63.3, 56, 351, 4.26, 4.3, 2.71],
    [22, 0.23, "Very Good", "E", "VS2", 63.8, 55, 352, 3.85, 3.92, 2.48],
    [23, 0.23, "Very Good", "H", "VS1", 61.0, 57, 353, 3.94, 3.96, 2.41],
    [24, 0.31, "Very Good", "J", "SI1", 59.4, 62, 353, 4.39, 4.43, 2.62],
    [25, 0.31, "Very Good", "J", "SI1", 58.1, 62, 353, 4.44, 4.47, 2.59],
    [26, 0.23, "Very Good", "G", "VVS2", 60.4, 58, 354, 3.97, 4.01, 2.41],
    [27, 0.24, "Premium", "I", "VS1", 62.5, 57, 355, 3.97, 3.94, 2.47],
    [28, 0.3, "Very Good", "J", "VS2", 62.2, 57, 357, 4.28, 4.3, 2.67],
    [29, 0.23, "Very Good", "D", "VS2", 60.5, 61, 357, 3.96, 3.97, 2.4]
]

# Define the table schema
schema = StructType([
    StructField("id", IntegerType()),
    StructField("carat", DoubleType()),
    StructField("cut", StringType()),
    StructField("color", StringType()),
    StructField("clarity", StringType()),
    StructField("depth", DoubleType()),
    StructField("table", IntegerType()),
    StructField("price", IntegerType()),
    StructField("x", DoubleType()),
    StructField("y", DoubleType()),
    StructField("z", DoubleType())
])

In [14]:
# Creating the dataframe.
df = session.create_dataframe(data, schema=schema)
df.limit(5).collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 5
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-00140001130e] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(ID=1, CARAT=0.23, CUT='Ideal', COLOR='E', CLARITY='SI2', DEPTH=61.5, TABLE=55, PRICE=326, X=3.95, Y=3.98, Z=2.43),
 Row(ID=2, CARAT=0.21, CUT='Premium', COLOR='E', CLARITY='SI1', DEPTH=59.8, TABLE=61, PRICE=326, X=3.89, Y=3.84, Z=2.31),
 Row(ID=3, CARAT=0.23, CUT='Good', COLOR='E', CLARITY='VS1', DEPTH=56.9, TABLE=65, PRICE=327, X=4.05, Y=4.07, Z=2.31),
 Row(ID=4, CARAT=0.29, CUT='Premium', COLOR='I', CLARITY='VS2', DEPTH=62.4, TABLE=58, PRICE=334, X=4.2, Y=4.23, Z=2.63),
 Row(ID=5, CARAT=0.31, CUT='Good', COLOR='J', CLARITY='SI2', DEPTH=63.3, TABLE=58, PRICE=335, X=4.34, Y=4.35, Z=2.75)]

In [15]:
# Writing the dataframe to a table.
df.write.save_as_table("t_diamonds", mode="overwrite")

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-00140001131e] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


## Testing some operations

In [16]:
# Count of diamonds per cut.
with session.ast_listener() as al:
    df_count_per_cut = df.group_by("cut").agg(count("*").alias("count"))
    print(df_count_per_cut.collect())

INFO:snowflake.connector.cursor:Number of results in first chunk: 5
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-00140001132a] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(CUT='Good', COUNT=6), Row(CUT='Ideal', COUNT=4), Row(CUT='Premium', COUNT=6), Row(CUT='Very Good', COUNT=12), Row(CUT='Fair', COUNT=1)]


In [17]:
print(al.base64_batches)

['Cg8KDQj///////////8BEgASQAo+CjTaBjEKGQoVigwSCgsY////////////ARIDY3V0EAESB4ICBAoCCAUaCxj///////////8BEgAYDCICCAwSfQp7CnHKCm4KVwpTwgNQCjaCATMKDxoNCgsKCQoHCgVjb3VudBoTigwQCgsY////////////ARIBKiILGP///////////wESAggBGgVjb3VudCILGP///////////wEQARIGUgQKAggMGgsY////////////ARIAGA0iAggNEiQKIgoY4gUVCAEQARoCCA0yCxj///////////8BEgAYDiICCA4SCBIGCA8SAggOGAEiERIPCg0KBWZpbmFsEAMYCyAJKgQQARgb']


In [18]:
# In plaintext:
message = proto.Request()
message.ParseFromString(base64.b64decode(al.base64_batches[0]))
message

interned_value_table {
  string_values {
    key: -1
    value: ""
  }
}
body {
  assign {
    expr {
      sp_dataframe_group_by {
        cols {
          args {
            string_val {
              src {
                file: -1
              }
              v: "cut"
            }
          }
          variadic: true
        }
        df {
          sp_dataframe_ref {
            id {
              bitfield1: 5
            }
          }
        }
        src {
          file: -1
        }
      }
    }
    symbol {
    }
    uid: 12
    var_id {
      bitfield1: 12
    }
  }
}
body {
  assign {
    expr {
      sp_relational_grouped_dataframe_agg {
        exprs {
          args {
            sp_column_alias {
              col {
                apply_expr {
                  fn {
                    builtin_fn {
                      name {
                        name {
                          sp_name_flat {
                            name: "count"
                          }

In [20]:
# Average, maximum, and minimum price per cut.
df_avg_price_per_cut = df.group_by("cut").agg(
    avg("price").alias("avg_price"), 
    max("price").alias("max_price"), 
    min("price").alias("min_price")
)
df_avg_price_per_cut.sort("avg_price", ascending=False).collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 5
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba8229-0100-0001-0000-001400011336] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(CUT='Very Good', AVG_PRICE=Decimal('348.083333'), MAX_PRICE=357, MIN_PRICE=336),
 Row(CUT='Good', AVG_PRICE=Decimal('342.333333'), MAX_PRICE=351, MIN_PRICE=327),
 Row(CUT='Premium', AVG_PRICE=Decimal('341.166667'), MAX_PRICE=355, MIN_PRICE=326),
 Row(CUT='Ideal', AVG_PRICE=Decimal('339.500000'), MAX_PRICE=348, MIN_PRICE=326),
 Row(CUT='Fair', AVG_PRICE=Decimal('337.000000'), MAX_PRICE=337, MIN_PRICE=337)]

In [21]:
Image(url="https://assets.vrai.com/25216/1692052168-diamond-cut-1-3.jpg")
# source: https://www.vrai.com/journal/post/diamond-cut

In [22]:
# Depth and table are one way to determine the quality of a diamond.
df_aggregated = df.group_by("cut").agg(
    avg("depth").alias("avg_depth"),
    avg("table").alias("avg_table"),
    avg("depth_table_ratio").alias("avg_depth_table_ratio")
)

## Performing a join

In [23]:
Image(url="https://www.brilliance.com/front/img/brilliance-diamond-cut-chart.jpg")
# source: https://www.brilliance.com/education/diamonds/cut

In [24]:
cut_info_data = [
    ("Ideal", "Highest quality cut"),
    ("Premium", "High quality cut, but not as much as Ideal"),
    ("Very Good", "Generally high quality cut"),
    ("Good", "Average cut with some flaws"),
    ("Fair", "Below average cut with noticeable flaws")
]

cut_info_schema = StructType([
    StructField("cut", StringType()),
    StructField("description", StringType())
])

# Create a DataFrame with cut information.
df_cut_info = session.create_dataframe(cut_info_data, schema=cut_info_schema)
df_cut_info.collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.connector.cursor:Number of results in first chunk: 5
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-001400011346] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(CUT='Ideal', DESCRIPTION='Highest quality cut'),
 Row(CUT='Premium', DESCRIPTION='High quality cut, but not as much as Ideal'),
 Row(CUT='Very Good', DESCRIPTION='Generally high quality cut'),
 Row(CUT='Good', DESCRIPTION='Average cut with some flaws'),
 Row(CUT='Fair', DESCRIPTION='Below average cut with noticeable flaws')]

In [25]:
# Write the data to a new table.
df_cut_info.write.save_as_table("cut_info", mode="overwrite")

INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-001400011356] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


In [26]:
# Load the cut_info table into a DataFrame.
df_cut_info = session.table("cut_info")
df_cut_info.limit(2).collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 2
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-001400011366] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(CUT='Ideal', DESCRIPTION='Highest quality cut'),
 Row(CUT='Premium', DESCRIPTION='High quality cut, but not as much as Ideal')]

In [27]:
# Join the diamonds DataFrame with the cut_info DataFrame.
df_joined = df.join(df_cut_info, df["cut"] == df_cut_info["cut"], how="inner")

# Select the columns to display.
df_joined = df_joined.select(df["id"], df["carat"], df["cut"], df_cut_info["description"])

# Show the results.
df_joined.collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 29
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-00140001137e] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(ID=1, CARAT=0.23, l_0000_CUT='Ideal', DESCRIPTION='Highest quality cut'),
 Row(ID=2, CARAT=0.21, l_0000_CUT='Premium', DESCRIPTION='High quality cut, but not as much as Ideal'),
 Row(ID=3, CARAT=0.23, l_0000_CUT='Good', DESCRIPTION='Average cut with some flaws'),
 Row(ID=4, CARAT=0.29, l_0000_CUT='Premium', DESCRIPTION='High quality cut, but not as much as Ideal'),
 Row(ID=5, CARAT=0.31, l_0000_CUT='Good', DESCRIPTION='Average cut with some flaws'),
 Row(ID=6, CARAT=0.24, l_0000_CUT='Very Good', DESCRIPTION='Generally high quality cut'),
 Row(ID=7, CARAT=0.24, l_0000_CUT='Very Good', DESCRIPTION='Generally high quality cut'),
 Row(ID=8, CARAT=0.26, l_0000_CUT='Very Good', DESCRIPTION='Generally high quality cut'),
 Row(ID=9, CARAT=0.22, l_0000_CUT='Fair', DESCRIPTION='Below average cut with noticeable flaws'),
 Row(ID=10, CARAT=0.23, l_0000_CUT='Very Good', DESCRIPTION='Generally high quality cut'),
 Row(ID=11, CARAT=0.3, l_0000_CUT='Good', DESCRIPTION='Average cut with some flaws

## Testing a UDF

In [28]:
# Testing a simple UDF.
def calculate_diamond_volume(x: float, y: float, z: float) -> float:
    return x * y * z

# Register the UDF
set_transmit_query_to_server(True)
calculate_diamond_volume_udf = session.udf.register(
    func=calculate_diamond_volume,
    return_type=FloatType(),
    input_types=[FloatType(), FloatType(), FloatType()],
    name="calculate_diamond_volume",
    is_permanent=False,
    replace=True
)
set_transmit_query_to_server(False)

# Use the UDF to calculate the volume of diamonds in table t_diamonds
df_diamonds = session.table("t_diamonds")
df_volumes = df_diamonds.with_column(
    "volume",
    calculate_diamond_volume_udf(
        df_diamonds["x"], 
        df_diamonds["y"], 
        df_diamonds["z"]
    )
)

INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-001400011392] create SCOPED TEMPORARY                     stage if not exists "TESTDB"."PUBLIC".SNOWPARK_TEMP_STAGE_67GYH2U1R8
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-001400011396] ls '@"TESTDB"."PUBLIC".SNOWPARK_TEMP_STAGE_67GYH2U1R8'
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-00140001139a]  SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01ba822a-0100-0001-0000-001400011396')))
INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-00140001139e] 
CR

In [29]:
df_volumes.select("id", "carat", "volume").order_by("volume", ascending=False).collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 29
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-0014000113a6] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(ID=5, CARAT=0.31, VOLUME=51.917249999999996),
 Row(ID=16, CARAT=0.32, VOLUME=51.883728000000005),
 Row(ID=14, CARAT=0.31, VOLUME=51.515744999999995),
 Row(ID=25, CARAT=0.31, VOLUME=51.403212),
 Row(ID=24, CARAT=0.31, VOLUME=50.952974),
 Row(ID=17, CARAT=0.3, VOLUME=50.130472),
 Row(ID=11, CARAT=0.3, VOLUME=49.6587),
 Row(ID=21, CARAT=0.3, VOLUME=49.64177999999999),
 Row(ID=28, CARAT=0.3, VOLUME=49.13868),
 Row(ID=18, CARAT=0.3, VOLUME=48.99609000000001),
 Row(ID=19, CARAT=0.3, VOLUME=48.833658),
 Row(ID=20, CARAT=0.3, VOLUME=47.818022),
 Row(ID=4, CARAT=0.29, VOLUME=46.72458),
 Row(ID=8, CARAT=0.26, VOLUME=42.321081),
 Row(ID=7, CARAT=0.24, VOLUME=38.830870000000004),
 Row(ID=10, CARAT=0.23, VOLUME=38.718),
 Row(ID=6, CARAT=0.24, VOLUME=38.693951999999996),
 Row(ID=27, CARAT=0.24, VOLUME=38.635246),
 Row(ID=26, CARAT=0.23, VOLUME=38.366477),
 Row(ID=1, CARAT=0.23, VOLUME=38.20203),
 Row(ID=3, CARAT=0.23, VOLUME=38.076885),
 Row(ID=29, CARAT=0.23, VOLUME=37.73088),
 Row(ID=12, CARA

## Testing a UDAF

In [30]:
# Define the UDAF class
class AveragePricePerCarat:
    def __init__(self):
        self.total_price = 0.0
        self.total_carat = 0.0

    @property
    def aggregate_state(self):
        return self.total_price, self.total_carat

    def accumulate(self, price, carat):
        if carat > 0:
            self.total_price += price
            self.total_carat += carat

    def merge(self, other):
        self.total_price += other[0]
        self.total_carat += other[1]

    def finish(self):
        return self.total_price / self.total_carat if self.total_carat > 0 else 0.0

# Register the UDAF
set_transmit_query_to_server(True)
avg_price_per_carat_udaf = udaf(AveragePricePerCarat, return_type=FloatType(), input_types=[FloatType(), FloatType()])
set_transmit_query_to_server(False)

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-0014000113b2] ls '@"TESTDB"."PUBLIC".SNOWPARK_TEMP_STAGE_67GYH2U1R8'
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-0014000113b6]  SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01ba822a-0100-0001-0000-0014000113b2')))
INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822a-0100-0001-0000-0014000113ba] 
CREATE
TEMPORARY  AGGREGATE FUNCTION  "TESTDB"."PUBLIC".SNOWPARK_TEMP_AGGREGATE_FUNCTION_76TJSAU2F8(arg1 FLOAT,arg2 FLOAT)

RETURNS FLOAT
LANGUAGE PYTHON 
VOLATILE
RUNTIME_VERSION=3.11

PACKAGES=('cloudpickle==2.2.1')


HANDLER='compute'

AS $$
import pickle

func = pickle.loads(bytes.fromhex('80

In [32]:
# Use the UDAF on the diamonds DataFrame
df = session.table("t_diamonds")
df_avg_price_per_carat = df.group_by(df["cut"]).agg(avg_price_per_carat_udaf(df["price"], df["carat"]).alias("avg_price_per_carat"))

# Show the results
df_avg_price_per_carat.order_by("avg_price_per_carat", ascending=False).collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.connector.cursor:Number of results in first chunk: 5
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822b-0100-0001-0000-0014000113d6] SELECT 'This is a fake query!!'; --No actual query sent, the server should rely on the provided AST!


[Row(CUT='Fair', AVG_PRICE_PER_CARAT=1531.8181818181818),
 Row(CUT='Premium', AVG_PRICE_PER_CARAT=1383.1081081081081),
 Row(CUT='Very Good', AVG_PRICE_PER_CARAT=1343.0868167202573),
 Row(CUT='Ideal', AVG_PRICE_PER_CARAT=1269.1588785046729),
 Row(CUT='Good', AVG_PRICE_PER_CARAT=1180.4597701149423)]

## Testing a UDTF

In [None]:
# This is a UDTF which provides information on the color and clarity grading of a diamond!

In [33]:
Image(url="https://4cs.gia.edu/wp-content/uploads/2024/07/02_Color-D-Z-Scale_960x800.jpg")
# source: https://4cs.gia.edu/en-us/diamond-color/

In [34]:
Image(url="https://lisarobinjewelry.com/cdn/shop/files/Diamond_Clarity_Scale_Graphic_1600x.jpg?v=1686052968")
# source: https://lisarobinjewelry.com/pages/what-is-diamond-clarity

In [35]:
class ColorClarityDetails:
    def process(self, color, clarity):
        color_desc = self.get_color_description(color)
        clarity_desc = self.get_clarity_description(clarity)
        yield (color, color_desc, clarity, clarity_desc)

    def get_color_description(self, color):
        color_descriptions = {
            "D": "Colorless",
            "E": "Colorless",
            "F": "Colorless",
            "G": "Near Colorless",
            "H": "Near Colorless",
            "I": "Near Colorless",
            "J": "Near Colorless"
        }
        return color_descriptions.get(color, "Unknown")

    def get_clarity_description(self, clarity):
        clarity_descriptions = {
            "IF": "Internally Flawless",
            "VVS1": "Very, Very Slightly Included 1",
            "VVS2": "Very, Very Slightly Included 2",
            "VS1": "Very Slightly Included 1",
            "VS2": "Very Slightly Included 2",
            "SI1": "Slightly Included 1",
            "SI2": "Slightly Included 2",
            "I1": "Included 1",
            "I2": "Included 2",
            "I3": "Included 3"
        }
        return clarity_descriptions.get(clarity, "Unknown")

# Define the schema of the output table
output_schema = StructType([
    StructField("color", StringType()),
    StructField("color_description", StringType()),
    StructField("clarity", StringType()),
    StructField("clarity_description", StringType())
])

# Register the UDTF
set_transmit_query_to_server(True)
color_clarity_udtf = udtf(ColorClarityDetails, output_schema=output_schema, input_types=[StringType(), StringType()])
set_transmit_query_to_server(False)

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822b-0100-0001-0000-0014000113e6] ls '@"TESTDB"."PUBLIC".SNOWPARK_TEMP_STAGE_67GYH2U1R8'
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822b-0100-0001-0000-0014000113ea]  SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01ba822b-0100-0001-0000-0014000113e6')))
INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822b-0100-0001-0000-0014000113ee] 
CREATE
TEMPORARY  FUNCTION  "TESTDB"."PUBLIC".SNOWPARK_TEMP_TABLE_FUNCTION_XD5OWEP6HZ(arg1 STRING,arg2 STRING)

RETURNS TABLE (COLOR STRING,COLOR_DESCRIPTION STRING,CLARITY STRING,CLARITY_DESCRIPTION STRING)
LANGUAGE PYTHON 
VOLATILE
RUNTIME_VERSION=3.11

PACKAGES=('cloudpickle==2.2.1')


HANDLER

In [38]:
df = session.table("t_diamonds")
df_color_clarity_details = df.join_table_function(
    color_clarity_udtf(df["color"], df["clarity"])
).collect()

INFO:snowflake.connector.cursor:Number of results in first chunk: 0


SnowparkSQLException: (1304): 01ba822c-0100-0001-0000-00140001140e: 000603 (XX000): 01ba822c-0100-0001-0000-00140001140e: SQL execution internal error:
NON_FATAL: DataframeExecutionUtils::executeDataframe():dataframe_processing_failure(com.snowflake.core.DataframeExecutionUtils:executeDataframe:55) - Dataframe Execution Failed caused by exception: [java.lang.RuntimeException: java.util.concurrent.ExecutionException: java.lang.RuntimeException: Expression type not implemented yet: sp_dataframe_join_table_function at com.snowflake.resources.dataframe.NotebookDataframeProcessor.evalImpl(NotebookDataframeProcessor.java:365)]

In [40]:
# Same UDTF logic but this cell is to display the AST information.
with session.ast_listener() as al:
    # Register the UDTF.
    set_transmit_query_to_server(True)
    color_clarity_udtf = udtf(ColorClarityDetails, output_schema=output_schema, input_types=[StringType(), StringType()])
    set_transmit_query_to_server(False)
    
    # Calling the UDTF.
    df = session.table("t_diamonds")
    df_color_clarity_details = df.join_table_function(
        color_clarity_udtf(df["color"], df["clarity"])
    ).collect()
    df_color_clarity_details
print("Number of AST messages recorded: ", len(al.base64_batches))

INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822c-0100-0001-0000-001400011416] ls '@"TESTDB"."PUBLIC".SNOWPARK_TEMP_STAGE_67GYH2U1R8'
INFO:snowflake.connector.cursor:Number of results in first chunk: 0
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822c-0100-0001-0000-00140001141a]  SELECT "name" FROM ( SELECT  *  FROM  TABLE ( RESULT_SCAN('01ba822c-0100-0001-0000-001400011416')))
INFO:snowflake.connector.cursor:Number of results in first chunk: 1
INFO:snowflake.snowpark._internal.server_connection:Execute query [queryID: 01ba822c-0100-0001-0000-00140001141e] 
CREATE
TEMPORARY  FUNCTION  "TESTDB"."PUBLIC".SNOWPARK_TEMP_TABLE_FUNCTION_YFC4IPSNCL(arg1 STRING,arg2 STRING)

RETURNS TABLE (COLOR STRING,COLOR_DESCRIPTION STRING,CLARITY STRING,CLARITY_DESCRIPTION STRING)
LANGUAGE PYTHON 
VOLATILE
RUNTIME_VERSION=3.11

PACKAGES=('cloudpickle==2.2.1')


HANDLER

SnowparkSQLException: (1304): 01ba822c-0100-0001-0000-001400011426: 000603 (XX000): 01ba822c-0100-0001-0000-001400011426: SQL execution internal error:
NON_FATAL: DataframeExecutionUtils::executeDataframe():dataframe_processing_failure(com.snowflake.core.DataframeExecutionUtils:executeDataframe:55) - Dataframe Execution Failed caused by exception: [java.lang.RuntimeException: java.util.concurrent.ExecutionException: java.lang.RuntimeException: Expression type not implemented yet: sp_dataframe_join_table_function at com.snowflake.resources.dataframe.NotebookDataframeProcessor.evalImpl(NotebookDataframeProcessor.java:365)]

In [41]:
# Print the AST recorded:
print(al.base64_batches)

[]


In [None]:
# In plaintext:
message = proto.Request()
message.ParseFromString(base64.b64decode("\n".join(al.base64_batches)).strip())
message