# Snowpark AST Decoder Demo

In [2]:
from snowflake.snowpark import Session
from snowflake.snowpark._internal.utils import set_transmit_query_to_server
from snowflake.snowpark._internal.utils import set_ast_state, AstFlagSource

# Connection settings for a Snowflake server running on the local machine.
# Every value, including the port, is given as a string.
CONNECTION_PARAMETERS = {
    # Account name and the host the client connects to.
    "account": "s3testaccount",
    "host": "snowflake.reg.local",
    # Login credentials and the role used for the session.
    "user": "snowman",
    "password": "test",
    "role": "sysadmin",
    # Warehouse, database and schema the session works against.
    "warehouse": "regress",
    "database": "testdb",
    "schema": "public",
    # Port and protocol of the endpoint (plain HTTP).
    "port": "53200",
    "protocol": "http",
}

In [None]:
# Configure the builder one step at a time: load the shared connection
# settings, explicitly set "local_testing" to False, then obtain the session
# through getOrCreate().
session_builder = Session.builder.configs(CONNECTION_PARAMETERS)
session_builder = session_builder.config("local_testing", False)
session = session_builder.getOrCreate()

### Enabling Required Parameters

In [None]:
# ENABLE_DATAFRAME should be set to True at the account level:
# this parameter controls whether the server side uses the SQL query or the AST.
# The statement below only displays the parameter; it does not change it.
session.sql("show parameters like 'ENABLE_DATAFRAME' in account").show()

In [None]:
# The Dataframe Processor relies on a Snowflake Notebook for execution,
# so list the notebooks to check what is available.
session.sql("show notebooks;").show()

In [None]:
# The decoder logic uses Python 3.10+ features, so the notebook engine version
# must run Python 3.10 or newer. Display the NOTEBOOK_ENGINE_VERSION parameter
# to check which engine version is configured.
session.sql("show parameters like 'NOTEBOOK_ENGINE_VERSION'").show()

In [None]:
# ASTs are only recorded when the AST flag is enabled, so turn it on here.
# AstFlagSource.TEST is passed as the first argument (presumably identifying
# who is setting the flag; confirm against set_ast_state).
AST_ENABLED = True
set_ast_state(AstFlagSource.TEST, AST_ENABLED)

In [None]:
# set_transmit_query_to_server is a flag setter that controls whether the
# actual Snowpark query or a fake query is sent to the server.
# With False, the fake query "SELECT 'This is a fake query!!';" is transmitted.
set_transmit_query_to_server(False)

### Testing a Basic Snowpark Query

In [None]:
# Run a basic Snowpark query with the AST enabled. The listener is bound to
# `al`, which a later cell reads to print what was recorded.
with session.ast_listener() as al:
    result = session.create_dataframe([1, 2, 3, 4]).collect()
    print(result)

In [None]:
# Print the AST recorded by the listener `al`, exposed through its
# base64_batches attribute.
print(al.base64_batches)

### Testing a More Complex Example

In [None]:
# Create (or replace) the temporary table t_diamonds with three integer
# columns (x, y, z) and one string column (cut).
session.sql("create or replace temp table t_diamonds (x int, y int, z int, cut string)").collect()

In [None]:
# Insert eight rows into t_diamonds, one INSERT statement per row.
# The numeric columns count up from 1 to 24 across the rows; the cut values
# 'Ideal', 'Premium' and 'Good' each appear twice, 'Very Good' and 'Fair' once.
session.sql("insert into t_diamonds values (1, 2, 3, 'Ideal')").collect()
session.sql("insert into t_diamonds values (4, 5, 6, 'Premium')").collect()
session.sql("insert into t_diamonds values (7, 8, 9, 'Good')").collect()
session.sql("insert into t_diamonds values (10, 11, 12, 'Very Good')").collect()
session.sql("insert into t_diamonds values (13, 14, 15, 'Fair')").collect()
session.sql("insert into t_diamonds values (16, 17, 18, 'Ideal')").collect()
session.sql("insert into t_diamonds values (19, 20, 21, 'Premium')").collect()
session.sql("insert into t_diamonds values (22, 23, 24, 'Good')").collect()

In [None]:
# Create a dataframe backed by the t_diamonds table and collect its rows.
# `df` is reused by the cells that follow.
df = session.table("t_diamonds")
df.collect()

In [None]:
# Keep only the rows whose cut equals "Ideal", then select column x and the
# expression y + 1, and collect the result.
df.filter(df["cut"] == "Ideal").select(df["x"], df["y"] + 1).collect()

In [None]:
# Simple aggregation: group the rows by cut and compute, per group, the sum
# of x, the average of y and the maximum of z.
df.group_by(df["cut"]).agg({"x": "sum", "y": "avg", "z": "max"}).collect()

In [None]:
# Simple join of t_diamonds with itself: df2 is a second dataframe over the
# same table, extended with new_col = x * 2. The two dataframes are joined on
# equal x, and the result keeps df's x together with df2's new_col.
df2 = session.table("t_diamonds")
df2 = df2.with_column("new_col", df2["x"] * 2)
df.join(df2, df["x"] == df2["x"]).select(df["x"], df2["new_col"]).collect()