In [None]:
# Import necessary packages
import logging
from snowflake.snowpark.context import get_active_session

# Setup Logger
# Notebook cells are routinely re-executed; attaching a handler unconditionally
# would add one more StreamHandler per run and print every message N times.
# Only configure the handler when this named logger has none yet.
logger = logging.getLogger("travel_time_metrics_logger")
logger.setLevel(logging.INFO)
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

# Coordinates of the table this notebook maintains
# (TOM_DB.ANALYTICS.TRAVEL_TIME_METRICS).
database_name, schema_name, table_name = "TOM_DB", "ANALYTICS", "TRAVEL_TIME_METRICS"

# Reuse the Snowpark session the notebook is already connected with.
session = get_active_session()

logger.info("Session initialized and variables set.")


In [None]:
# Helper: does DATABASE.SCHEMA.TABLE already exist?
def table_exists(session, database_name='', schema_name='', table_name=''):
    """Report whether ``database_name.schema_name.table_name`` exists.

    Issues a ``SELECT EXISTS`` against ``<database_name>.INFORMATION_SCHEMA.TABLES``.
    The schema and table names are interpolated as string literals and compared
    to TABLE_SCHEMA / TABLE_NAME, so pass them in the case they are stored in
    (this notebook uses upper-case names).

    Args:
        session: Snowpark session used to run the query.
        database_name: Database whose INFORMATION_SCHEMA is queried.
        schema_name: Value matched against TABLE_SCHEMA.
        table_name: Value matched against TABLE_NAME.

    Returns:
        The TABLE_EXISTS column of the first (and only) result row.
    """
    rows = session.sql(f"""
        SELECT EXISTS (
            SELECT 1 
            FROM {database_name}.INFORMATION_SCHEMA.TABLES 
            WHERE TABLE_SCHEMA = '{schema_name}' 
            AND TABLE_NAME = '{table_name}'
        ) AS TABLE_EXISTS
    """).collect()
    first_row = rows[0]
    return first_row['TABLE_EXISTS']

# Optional sanity check: report whether the target table is already present.
table_exists_flag = table_exists(
    session,
    database_name=database_name,
    schema_name=schema_name,
    table_name=table_name,
)
logger.info(f"Does the table {table_name} exist? {table_exists_flag}")


In [None]:
# Peek at the source table before aggregating it: a handle to the table plus
# a 5-row sample printed to the cell output.
harmonized_tom = session.table("TOM_DB.HARMONIZED_TOM.harmonized_tom")
harmonized_tom.show(n=5)

logger.info("Loaded harmonized_tom table.")


In [None]:
from snowflake.snowpark import functions as F

# Build one row per (RANK, CITY, COUNTRY) from the harmonized data and score
# each group with the CONGESTION_CHANGE_CONSISTENCY UDF. RANK keeps its
# original column name (it is also used as a merge key further down).
# NOTE(review): LATEST_CHANGE is the MAX of CHANGE_IN_CONGESTION within the
# group, not necessarily the most recent value — confirm this is intended.
final_agg = (
    session.table("TOM_DB.HARMONIZED_TOM.HARMONIZED_TOM")
    .group_by("RANK", "CITY", "COUNTRY")
    .agg(
        F.avg("AVG_TIME_PER_6_MILES").alias("AVG_TRAVEL_TIME"),
        F.max("CHANGE_IN_CONGESTION").alias("LATEST_CHANGE"),
        F.avg("CONGESTION_LEVEL").alias("AVG_CONGESTION_LEVEL"),
        F.avg("YEARLY_DELAY_HOURS").alias("AVG_TIME_LOST"),
    )
    .select(
        "RANK",
        "CITY",
        "COUNTRY",
        "AVG_TRAVEL_TIME",
        F.call_udf(
            "TOM_DB.ANALYTICS.CONGESTION_CHANGE_CONSISTENCY",
            F.col("LATEST_CHANGE"),
        ).alias("CONSISTENCY_SCORE"),
        "AVG_CONGESTION_LEVEL",
        "AVG_TIME_LOST",
    )
)

# Show the output column names so they can be checked before writing.
print(final_agg.schema.names)

In [None]:
# Create the target table on the first run; on later runs upsert the fresh
# aggregates from final_agg into it.
target_fqn = f"{database_name}.{schema_name}.{table_name}"

if not table_exists(session, database_name, schema_name, table_name):
    final_agg.write.mode("overwrite").save_as_table(target_fqn)
    logger.info(f"Successfully created {table_name}")

else:
    existing_table = session.table(target_fqn)

    # Reference columns through their DataFrames (existing_table["X"],
    # final_agg["X"]): that is how Snowpark distinguishes the two sides of a
    # MERGE. Strings such as "target.`RANK`" or "source.RANK" are not valid
    # Snowflake column references (backticks are Spark syntax).
    cols_to_update = {c: final_agg[c] for c in final_agg.schema.names}

    # NOTE(review): RANK is part of the merge key, so a city whose rank changes
    # is inserted as a new row and its previous row is left in place. Confirm
    # whether (CITY, COUNTRY) alone should identify a row.
    merge_condition = (
        (existing_table["RANK"] == final_agg["RANK"])
        & (existing_table["CITY"] == final_agg["CITY"])
        & (existing_table["COUNTRY"] == final_agg["COUNTRY"])
    )

    # merge() is a Table method, so call it on the Table from session.table().
    merge_result = existing_table.merge(
        final_agg,
        merge_condition,
        [
            F.when_matched().update(cols_to_update),
            F.when_not_matched().insert(cols_to_update),
        ],
    )

    logger.info(f"Successfully updated {table_name}")
    logger.info(f"Merge result: {merge_result}")

In [None]:
-- Preview the maintained table. Without ORDER BY, LIMIT returns an arbitrary
-- 10 rows, so the cell output would change between runs.
SELECT *
FROM TOM_DB.ANALYTICS.TRAVEL_TIME_METRICS
ORDER BY RANK, CITY, COUNTRY
LIMIT 10;


In [None]:
-- Stored procedure that (re)builds TOM_DB.ANALYTICS.TRAVEL_TIME_METRICS from
-- TOM_DB.HARMONIZED_TOM.HARMONIZED_TOM: create on first run, MERGE afterwards.
-- NOTE(review): RUNTIME_VERSION '3.8' is an old Python runtime; confirm it is
-- still accepted in this account and consider a newer one.
CREATE OR REPLACE PROCEDURE TOM_DB.ANALYTICS.UPDATE_TRAVEL_TIME_METRICS()
RETURNS STRING
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
HANDLER = 'run'
PACKAGES = ('snowflake-snowpark-python')
EXECUTE AS CALLER
AS
$$
import logging
from snowflake.snowpark import Session
from snowflake.snowpark import functions as F

# Configured once at module load so the except-branch in run() can always use it.
logger = logging.getLogger("TRAVEL_TIME_METRICS_LOGGER")
logger.setLevel(logging.INFO)

DATABASE_NAME = "TOM_DB"
SCHEMA_NAME = "ANALYTICS"
TABLE_NAME = "TRAVEL_TIME_METRICS"
SOURCE_TABLE = "TOM_DB.HARMONIZED_TOM.HARMONIZED_TOM"


def table_exists(session, database_name, schema_name, table_name):
    """Return whether database_name.schema_name.table_name appears in INFORMATION_SCHEMA.TABLES.

    Returns False if the query yields no rows.
    """
    query = f"""
        SELECT EXISTS (
            SELECT 1 
            FROM {database_name}.INFORMATION_SCHEMA.TABLES 
            WHERE TABLE_SCHEMA = '{schema_name}' 
            AND TABLE_NAME = '{table_name}'
        ) AS TABLE_EXISTS
    """
    result = session.sql(query).collect()
    return result[0]['TABLE_EXISTS'] if result else False


def _build_final_agg(session):
    """Aggregate SOURCE_TABLE to one row per (RANK, CITY, COUNTRY) plus a consistency score."""
    return (
        session.table(SOURCE_TABLE)
        .group_by(
            F.col("RANK"),
            F.col("CITY"),
            F.col("COUNTRY")
        ).agg(
            F.avg("AVG_TIME_PER_6_MILES").alias("AVG_TRAVEL_TIME"),
            # NOTE(review): MAX change within the group, not necessarily the latest value.
            F.max("CHANGE_IN_CONGESTION").alias("LATEST_CHANGE"),
            F.avg("CONGESTION_LEVEL").alias("AVG_CONGESTION_LEVEL"),
            F.avg("YEARLY_DELAY_HOURS").alias("AVG_TIME_LOST")
        ).select(
            F.col("RANK"),
            F.col("CITY"),
            F.col("COUNTRY"),
            F.col("AVG_TRAVEL_TIME"),
            F.call_udf("TOM_DB.ANALYTICS.CONGESTION_CHANGE_CONSISTENCY", F.col("LATEST_CHANGE")).alias("CONSISTENCY_SCORE"),
            F.col("AVG_CONGESTION_LEVEL"),
            F.col("AVG_TIME_LOST")
        )
    )


def _merge_into(session, final_agg, full_name):
    """Upsert final_agg into the table full_name keyed on (RANK, CITY, COUNTRY); return the merge result."""
    target = session.table(full_name)
    # Columns must be referenced through their DataFrames (target["X"],
    # final_agg["X"]) so Snowpark can tell the two MERGE sides apart;
    # "target.`X`" / "source.X" strings are not valid Snowflake references.
    assignments = {name: final_agg[name] for name in final_agg.schema.names}
    # NOTE(review): with RANK in the key, a rank change inserts a new row and
    # leaves the old one behind — confirm the intended row identity.
    key_match = (
        (target["RANK"] == final_agg["RANK"])
        & (target["CITY"] == final_agg["CITY"])
        & (target["COUNTRY"] == final_agg["COUNTRY"])
    )
    return target.merge(
        final_agg,
        key_match,
        [
            F.when_matched().update(assignments),
            F.when_not_matched().insert(assignments),
        ],
    )


def run(session: Session):
    """Procedure handler: create TRAVEL_TIME_METRICS if missing, otherwise merge fresh aggregates into it.

    Returns a status string; on failure the error is logged (with traceback)
    and returned as "Error: <message>" instead of being raised.
    """
    try:
        final_agg = _build_final_agg(session)
        logger.info("Final aggregation schema: " + str(final_agg.schema.names))

        full_name = f"{DATABASE_NAME}.{SCHEMA_NAME}.{TABLE_NAME}"

        if not table_exists(session, DATABASE_NAME, SCHEMA_NAME, TABLE_NAME):
            final_agg.write.mode("overwrite").save_as_table(full_name)
            logger.info(f"Successfully created {TABLE_NAME}")
            return f"Successfully created {TABLE_NAME}"

        merge_result = _merge_into(session, final_agg, full_name)
        logger.info(f"Successfully updated {TABLE_NAME}")
        logger.info(f"Merge result: {merge_result}")
        return f"Successfully updated {TABLE_NAME}"

    except Exception as e:
        # NOTE(review): returning the message makes CALL succeed even when the
        # refresh failed, which hides errors from tasks/schedulers; consider
        # re-raising once no caller depends on the "Error: ..." string.
        logger.exception(f"Error: {str(e)}")
        return f"Error: {str(e)}"
$$;
