# Testing Arrow Decimal
## Design:
### 1. Test each piece of logic individually: Snowflake data downcast (DECIMAL(37,18)), PyArrow operand creation, and PyArrow `multiply()`.
### 2. Confirm each piece meets the response-time target, and identify bottlenecks and possible optimizations.
### 3. The combined testing logic could be registered as a Snowflake Model Object.
A Model Object can run an SPCS inference service intended for real-time use. This can avoid the per-call construction and teardown of the UDF Python environment, which accounts for most of the wall time on small Arrow decimal operations.
        

In [None]:
# Standard Python Libraries
import pandas as pd

# PyArrow for High-Performance Computing
import pyarrow as pa
import pyarrow.compute as pc

# Snowflake Snowpark (DataFrame API for Snowflake)
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col

# Initialize the Snowflake session: get_active_session() returns the session the
# notebook is already running in; later cells reuse `session` for every query.
session = get_active_session()

In [None]:
-- Setting Up Environment for Testing Downcasting Techniques in Snowflake --

-- Step 1: Set User Role and Access Permissions --
-- Granting all privileges on the database to the user role (only necessary if not previously configured) --
-- USE ROLE ACCOUNTADMIN;
-- GRANT ALL PRIVILEGES ON DATABASE CONTAINER_RUNTIME_LAB TO ROLE CONTAINER_RUNTIME_LAB_USER;

-- Step 2: Switch to the User Role and Database for Testing --
USE ROLE CONTAINER_RUNTIME_LAB_USER;
USE DATABASE CONTAINER_RUNTIME_LAB;

-- Step 3: Create a Dedicated Schema for Testing (Arrow) --
-- Creating the schema also makes ARROW the current schema for this session.
CREATE OR REPLACE SCHEMA ARROW;

-- Step 4: Create Sample Table for Downcasting Tests --
CREATE OR REPLACE TABLE SAMPLE_TBL (
    col_a DECIMAL(38,2),  -- Maximum-precision DECIMAL column (36 integer + 2 fractional digits)
    col_b DECIMAL(38,2)   -- Second DECIMAL operand for multi-column testing
);
-- (Alternative precision/scale to try: DECIMAL(38,12))

-- Step 5: Populate Sample Table with Random Test Data (2000 Rows) --
-- Using Snowflake's GENERATOR function for fast data creation.
-- RANDOM() returns a signed 64-bit integer, not a value in [0, 1), so it cannot be
-- scaled directly to a target range. UNIFORM(0, 99999999999999, RANDOM()) draws an
-- integer of at most 14 digits; dividing by 10 gives exact values in
-- [0, 9999999999999.9] with one decimal place (no FLOAT rounding involved).
INSERT INTO SAMPLE_TBL (col_a, col_b)
SELECT 
    UNIFORM(0, 99999999999999, RANDOM()) / 10 AS col_a,
    UNIFORM(0, 99999999999999, RANDOM()) / 10 AS col_b
FROM 
    TABLE(GENERATOR(ROWCOUNT => 2000));


In [None]:
-- Testing Various Techniques for Downcasting DECIMAL Columns in Snowflake --

-- New Table Creation Method for Downcasting DECIMAL Precision and Scale --
-- This approach creates a new table with the downcast column, preserving data integrity. --
-- Recommended: Use Snowpark for Lazy Evaluation in DAG Execution (Efficient and Scalable). --

-- Step 1: Create a New Table with Downcasted Column (DECIMAL(37,2)) --
-- CAST raises an error (it does not truncate) if a COL_A value needs more digits than DECIMAL(37,2) allows.
CREATE OR REPLACE TABLE SAMPLE_TBL_NEW AS 
SELECT 
    CAST(COL_A AS DECIMAL(37,2)) AS COL_A,  -- Downcasting COL_A to DECIMAL(37,2)
    -- Explicitly list other columns to maintain structure --
    COL_B  
FROM 
    SAMPLE_TBL;

-- Step 2: Swap Tables with two renames (metadata-only operations) --
-- Note: the two statements are not atomic; between them SAMPLE_TBL does not exist.
-- ALTER TABLE ... SWAP WITH is the atomic alternative.
ALTER TABLE SAMPLE_TBL RENAME TO SAMPLE_TBL_OLD;  -- Rename original table for backup
ALTER TABLE SAMPLE_TBL_NEW RENAME TO SAMPLE_TBL;  -- Replace with the new downcasted version


In [None]:
-- Drop the backup table left by the rename swap (cleanup).
-- IF EXISTS keeps this cell re-runnable once the backup is gone.
DROP TABLE IF EXISTS SAMPLE_TBL_OLD;

-- Verify the Precision and Scale of the Downcasted Column --
-- INFORMATION_SCHEMA.COLUMNS covers every schema in the database, so filter on
-- TABLE_SCHEMA; otherwise a SAMPLE_TBL in another schema adds extra rows.
SELECT 
    COLUMN_NAME, 
    DATA_TYPE, 
    NUMERIC_PRECISION, 
    NUMERIC_SCALE 
FROM 
    INFORMATION_SCHEMA.COLUMNS 
WHERE 
    TABLE_SCHEMA = CURRENT_SCHEMA()
    AND TABLE_NAME = 'SAMPLE_TBL' 
    AND COLUMN_NAME = 'COL_A';
-- Note: without the TABLE_SCHEMA filter, same-named tables in other schemas also
-- appear, which can look like multiple versions of the same column.

In [None]:
-- Testing Various Techniques for Downcasting DECIMAL Columns in Snowflake --

-- View Creation Method for Downcasting DECIMAL Precision and Scale --
-- This approach creates a view that applies the downcast at query time without modifying the original table data.
-- Note: despite the _MAT_VIEW suffix this is a regular (non-materialized) view, so the
-- CAST is re-evaluated on every query. The name is kept because later cells read it.

-- Recommended: Use Snowpark for Lazy Evaluation in DAG Execution (Efficient and Scalable). --
CREATE OR REPLACE VIEW SAMPLE_TBL_MAT_VIEW AS 
SELECT 
    CAST(COL_A AS DECIMAL(36,2)) AS COL_A,  -- Downcasting COL_A to DECIMAL(36,2)
    -- Explicitly list other columns to maintain structure --
    COL_B  
FROM 
    SAMPLE_TBL;

In [None]:
-- Verifying the Precision and Scale of the Columns in the View --
-- Filter on TABLE_SCHEMA so same-named views in other schemas are excluded.
SELECT 
    COLUMN_NAME, 
    DATA_TYPE, 
    NUMERIC_PRECISION, 
    NUMERIC_SCALE 
FROM 
    INFORMATION_SCHEMA.COLUMNS 
WHERE 
    TABLE_SCHEMA = CURRENT_SCHEMA()
    AND TABLE_NAME = 'SAMPLE_TBL_MAT_VIEW'
    -- AND COLUMN_NAME = 'COL_A'  -- uncomment to show only the downcast column
ORDER BY 
    ORDINAL_POSITION;

In [None]:
# Snowpark version of DOWNCAST_CREATE_VIEW.
# The Snowpark DataFrame API has no create_or_replace_view option that casts columns
# by itself, so the view is created with plain SQL through session.sql().
# See Dataframe Class - https://github.com/snowflakedb/snowpark-python/blob/0511a45947242ae1f7deb18126886d59bc711926/src/snowflake/snowpark/dataframe.py#L4992

# Point the Snowpark session at the lab database and schema (the role is not changed here).
session.use_database("CONTAINER_RUNTIME_LAB")
session.use_schema("ARROW")

# Create the view with direct SQL. DECIMAL(38,2) is Snowflake's maximum precision, so this
# cast does not narrow COL_A; change the target type to test a real downcast.
session.sql("""
CREATE OR REPLACE VIEW SAMPLE_TBL_SNOW_VIEW AS 
SELECT 
    CAST(COL_A AS DECIMAL(38,2)) AS COL_A,  -- Cast COL_A to DECIMAL(38,2)
    COL_B  
FROM 
    SAMPLE_TBL;
""").collect()  # collect() executes the DDL immediately

# Load the created view as a Snowpark DataFrame and print its first 5 rows.
sample_tbl_view_df = session.table("CONTAINER_RUNTIME_LAB.ARROW.SAMPLE_TBL_SNOW_VIEW")
sample_tbl_view_df.show(5)

# Optional check (disabled): inspect the runtime type of COL_A in the view.
# result = session.sql("""
#     SELECT 
#         COL_A, 
#         typeof(COL_A) AS column_type 
#     FROM SAMPLE_TBL_SNOW_VIEW 
#     LIMIT 1;
# """)
# print(result.collect())

In [None]:
-- Verifying the Precision and Scale of COL_A in the Snowpark-created View --
-- Filter on TABLE_SCHEMA so same-named views in other schemas are excluded.
SELECT 
    COLUMN_NAME, 
    DATA_TYPE, 
    NUMERIC_PRECISION, 
    NUMERIC_SCALE 
FROM 
    INFORMATION_SCHEMA.COLUMNS 
WHERE 
    TABLE_SCHEMA = CURRENT_SCHEMA()
    AND TABLE_NAME = 'SAMPLE_TBL_SNOW_VIEW' 
    AND COLUMN_NAME = 'COL_A';

## Testing the largest and smallest values of Snowflake DECIMAL
#### Overflow on downcast: either deliberately clamp to a specified value (value clamping) or raise an error.

In [None]:
-- Testing the largest exact integer and decimal in Snowflake
-- TRY_CAST only accepts string input, so the overflow probe casts a string literal.
-- '1' followed by 38 zeros is the DECIMAL(38,0) maximum (38 nines) plus one, so
-- TRY_CAST returns NULL instead of raising an error.
SELECT 
    999999999999999::DECIMAL(38,0) AS max_dec,
    9999999999999999::DECIMAL(38,0) AS max_dec_mag_plus_one,
   -999999999999999999999999999999999999::DECIMAL(38,0) AS min_dec,
    TRY_CAST('100000000000000000000000000000000000000' AS DECIMAL(38,0)) AS max_plus_one;

In [None]:
# Probe DECIMAL limits through Snowpark.
# TRY_CAST only accepts string input, so the overflow probe casts a string literal:
# '1' followed by 38 zeros is the DECIMAL(38,0) maximum (38 nines) plus one, so
# max_plus_one comes back as None (SQL NULL) instead of the query failing.
query = """
SELECT 
    999999999999999::DECIMAL(38,0) AS max_dec,
    9999999999999999::DECIMAL(38,0) AS max_dec_mag_plus_one,
    -999999999999999999999999999999999999::DECIMAL(38,0) AS min_dec,
    TRY_CAST('100000000000000000000000000000000000000' AS DECIMAL(38,0)) AS max_plus_one;
"""

# Executing the query; collect() returns a list of Row objects.
result_df = session.sql(query).collect()

# Displaying the result
for row in result_df:
    print(row)

In [None]:
-- Testing the largest DECIMAL value in Snowflake
-- SELECT 
--     CAST('999999999999999' AS DECIMAL(38,0)) AS max_decimal;

-- DECIMAL with no precision/scale defaults to NUMBER(38,0); the query below reads
-- back the precision and scale Snowflake assigned to the view column.
CREATE OR REPLACE VIEW SAMPLE_TBL_VIEW AS 
SELECT 
    CAST('999999999999999' AS DECIMAL) AS max_decimal;

-- Verifying the Precision and Scale of the View Column --
-- Filter on TABLE_SCHEMA so same-named views in other schemas are excluded.
SELECT 
    COLUMN_NAME, 
    DATA_TYPE, 
    NUMERIC_PRECISION, 
    NUMERIC_SCALE 
FROM 
    INFORMATION_SCHEMA.COLUMNS 
WHERE 
    TABLE_SCHEMA = CURRENT_SCHEMA()
    AND TABLE_NAME = 'SAMPLE_TBL_VIEW';

In [None]:
-- Ad-hoc check: read COL_A narrowed to DECIMAL(35,2) without persisting anything.
SELECT COL_A::DECIMAL(35,2) AS column_downcasted
FROM SAMPLE_TBL;

In [None]:
from snowflake.snowpark import functions as F

# Clamp bounds for the sample data: 13 integer digits and 2 fractional digits, well
# inside the DECIMAL(37,2) target type used below.
# They are kept as strings so they are spliced into the SQL verbatim. A Python float
# would be rendered via repr(), which switches to scientific notation (e.g. 1e+16) for
# larger bounds, and Snowflake would then parse the literal as FLOAT.
max_decimal_value = "9999999999999.99"
min_decimal_value = "-9999999999999.99"

# Round to one decimal, clamp into [min, max] with GREATEST/LEAST, then cast to
# DECIMAL(37,2). The same expression is applied to both columns.
snow_df_input = session.table("SAMPLE_TBL_MAT_VIEW").select(
    *[
        F.sql_expr(
            f"CAST(LEAST(GREATEST(ROUND({name}, 1), {min_decimal_value}), "
            f"{max_decimal_value}) AS DECIMAL(37,2))"
        ).alias(name)
        for name in ("col_a", "col_b")
    ]
)

snow_df_input.collect()

## Registering the `multiply()` logic in the Snowflake Model Registry

### BNY Team: `prep_precision()` can be any Python function that operates on data already in memory. The code above clamps inside Snowflake, but preprocessing with pandas in the Python runtime is also efficient. Treat the decorated inference function in the Model Object as a driver that calls whatever custom functions you write.


In [None]:
from snowflake.ml.registry import Registry
from snowflake.ml.model import custom_model
import pyarrow as pa
import pyarrow.compute as pc
import pandas as pd

# Custom Snowflake ML model that performs exact DECIMAL multiplication with PyArrow.
class arrow_decimal_fast(custom_model.CustomModel):
    """Multiply COL_A by COL_B exactly using PyArrow decimal arithmetic.

    Both operands are clamped to the DECIMAL(37,1) range (``prep_precision``),
    converted to Arrow ``decimal128(37, 1)``, widened to ``decimal256`` and
    multiplied. Arrow types the product of decimal(p1, s1) * decimal(p2, s2) as
    decimal(p1 + p2 + 1, s1 + s2), i.e. decimal(75, 2) here. That exceeds
    decimal128's maximum precision of 38, so the multiply has to run on
    decimal256 (maximum precision 76).
    """

    # Precision/scale of the Arrow operand type: DECIMAL(37,1).
    OPERAND_PRECISION = 37
    OPERAND_SCALE = 1

    def prep_precision(self, input_df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``input_df`` with COL_A and COL_B clamped to the operand range.

        The largest DECIMAL(37,1) magnitude is 10**36 - 0.1 (36 integer nines and one
        fractional nine). Values beyond it are saturated to the largest float that does
        not exceed that maximum, so the later Arrow conversion cannot overflow.

        Args:
            input_df: DataFrame with numeric COL_A and COL_B columns. It is not modified.

        Returns:
            A new DataFrame with both columns clipped to the symmetric range.
        """
        import math
        from fractions import Fraction

        # Exact DECIMAL(P,S) maximum: (10**P - 1) / 10**S.
        exact_max = Fraction(10**self.OPERAND_PRECISION - 1, 10**self.OPERAND_SCALE)
        # float() rounds to nearest and may land just above the exact maximum;
        # step down one ulp in that case so the bound itself still fits.
        max_decimal_value = float(exact_max)
        if Fraction(max_decimal_value) > exact_max:
            max_decimal_value = math.nextafter(max_decimal_value, 0.0)
        min_decimal_value = -max_decimal_value

        # Work on a copy so the caller's DataFrame is left untouched.
        clamped = input_df.copy()
        for column in ("COL_A", "COL_B"):
            clamped[column] = clamped[column].clip(
                lower=min_decimal_value, upper=max_decimal_value
            )
        return clamped

    @custom_model.inference_api
    def arrow_multiply(self, X: pd.DataFrame) -> pd.DataFrame:
        """Return a one-column DataFrame ``PRODUCT = COL_A * COL_B`` computed exactly.

        Args:
            X: DataFrame with numeric COL_A and COL_B columns.

        Returns:
            DataFrame with a PRODUCT column holding the Arrow decimal256(75, 2)
            products converted to pandas (``decimal.Decimal`` objects), one row per
            input row.
        """
        # Preparing input data with DECIMAL precision (returns a clamped copy).
        X = self.prep_precision(X)

        operand_type = pa.decimal128(self.OPERAND_PRECISION, self.OPERAND_SCALE)
        # Widen before multiplying: a decimal128 x decimal128 multiply at this
        # precision would need result precision 75 and raises in Arrow.
        wide_type = pa.decimal256(self.OPERAND_PRECISION, self.OPERAND_SCALE)
        col_a = pa.array(X["COL_A"], type=operand_type).cast(wide_type)
        col_b = pa.array(X["COL_B"], type=operand_type).cast(wide_type)

        # Perform precise DECIMAL multiplication.
        product = pc.multiply(col_a, col_b)

        # NOTE(review): the output column holds decimal.Decimal objects; confirm the
        # registry signature accepts them (otherwise emit strings explicitly).
        return pd.DataFrame({"PRODUCT": product.to_pandas()})

In [None]:
# Instantiate model with context and register
model_context = custom_model.ModelContext()
model = arrow_decimal_fast(model_context)

# Sample input for schema validation
sample_input = pd.DataFrame({
    "COL_A": [9999999999.9],
    "COL_B": [2.0]
})

# Register the model
registry = Registry(session=session)
registry.log_model(
    model=model,
    model_name="pyarrow_decimal_fast",
    conda_dependencies=["pyarrow"],
    version_name="v1",
    sample_input_data=sample_input,
    target_platform=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"],
    options={"overwrite": True},
)

# Running the model (Example)
result = model.arrow_multiply(sample_input)
print(result)

## Model Inference Service (avoids UDxF overhead)

In [None]:
# Configuration for the SPCS inference service.
image_repo_name = "arrow_inference_image"  # image repository; needs to be created beforehand
cp_name = "E2E_ML_GPU_NV_S"                # compute pool; should be created beforehand
num_spcs_nodes = '2'                       # max service instances (converted with int() at use)
service_name = 'ARROW_DEC_SERVICE'
# Schema holding the image repository and the service. It is hard-coded and is not
# the notebook's working schema, so it is defined once here for every reference.
service_schema = "DEFAULT_SCHEMA"

current_database = session.get_current_database().replace('"', '')
current_schema = session.get_current_schema().replace('"', '')  # not used for the service names
extended_image_repo_name = f"{current_database}.{service_schema}.{image_repo_name}"
extended_service_name = f"{current_database}.{service_schema}.{service_name}"

In [None]:
-- Drop any previous deployment of the inference service before recreating it.
-- Use the fully qualified name: the service lives in DEFAULT_SCHEMA, while the
-- session's current schema is ARROW, so the bare name would miss it.
DROP SERVICE IF EXISTS {{extended_service_name}};

In [None]:
# Deploy the registered model version as an SPCS inference service.
# Look the version up here so this cell does not depend on a variable that is only
# defined further down the notebook.
mv_base = registry.get_model("pyarrow_decimal_fast").version("v1")
mv_base.create_service(
    service_name=extended_service_name,
    service_compute_pool=cp_name,
    image_repo=extended_image_repo_name,
    ingress_enabled=True,
    max_instances=int(num_spcs_nodes),
    build_external_access_integration="ALLOW_ALL_INTEGRATION"
)

In [None]:
-- List services to confirm that the inference service was created.
SHOW SERVICES;

In [None]:
# Look up the model version registered above (model "pyarrow_decimal_fast", version "v1")
# and list the services created from it.
mv_base = registry.get_model("pyarrow_decimal_fast").version("v1")
mv_base.list_services()

In [None]:
# Invoke arrow_multiply through the SPCS inference service instead of the warehouse.
# sample_input has the same columns the model signature was logged with.
mv_base.run(
    sample_input,
    function_name="arrow_multiply",
    service_name=extended_service_name,
)

## Further Testing / Debugging

In [None]:
from snowflake.snowpark import functions as F

def safe_numeric_37_1(column_name):
    """
    Return a Snowpark expression that casts *column_name* to NUMERIC(37,1) without overflow.

    - Rounds to one decimal place, then caps the value to the full NUMERIC(37,1) range,
      +/-99...9.9 (36 integer nines and one fractional nine).
    - The bounds are DECIMAL(37,1) values built from string literals, not Python floats.
      A float literal makes GREATEST/LEAST compare as FLOAT, which silently drops digits
      beyond ~15-17 significant figures before the final CAST.
    - A plain CAST is then safe because the range is strictly controlled.
    """
    decimal_type = "DECIMAL(37,1)"
    max_literal = "9" * 36 + ".9"
    max_value = F.cast(F.lit(max_literal), decimal_type)
    min_value = F.cast(F.lit("-" + max_literal), decimal_type)

    # Applying safe cast with range control using standard CAST
    return F.cast(
        F.least(
            F.greatest(F.round(F.col(column_name), 1), min_value),
            max_value,
        ),
        decimal_type,
    )

# Project SAMPLE_TBL through the clamping helper, keeping the original column names.
snow_df_input = session.table("SAMPLE_TBL").select(
    *[safe_numeric_37_1(name).alias(name) for name in ("col_a", "col_b")]
)

# Materialize the rows; the helper clamps values into range before its CAST.
result_df = snow_df_input.collect()
print(result_df)


In [None]:
# Fetch the clamped rows as Arrow data (presumably a pyarrow.Table, given the
# .column() calls in the next cell) and print it.
arrow_tbl = snow_df_input.to_arrow()
print(arrow_tbl)

In [None]:
# Step 2: Multiply COL_A by COL_B exactly with Arrow decimals.
# Arrow types decimal(p1, s1) * decimal(p2, s2) as decimal(p1 + p2 + 1, s1 + s2).
# For two decimal128(37, 1) columns that is precision 75, beyond decimal128's limit
# of 38, so the multiply raises. Widen decimal128 columns to decimal256 (max
# precision 76) first; this works while p1 + p2 + 1 <= 76.
def _widen_decimal(column):
    """Return *column* cast to decimal256 with the same precision/scale if it is decimal128."""
    if pa.types.is_decimal128(column.type):
        return column.cast(pa.decimal256(column.type.precision, column.type.scale))
    return column


a = _widen_decimal(arrow_tbl.column("COL_A"))
b = _widen_decimal(arrow_tbl.column("COL_B"))
product = pc.multiply(a, b)
print(product)
# Optional: narrow the product afterwards (a safe cast fails if values do not fit), e.g.
# rounded = pc.cast((pc.cast(product, pa.decimal256(75, 1))), pa.decimal128(38, 1))