In [None]:
-- Capture the notebook's current database and schema; the Python setup cell reads this result.
SELECT
    CURRENT_DATABASE() AS DATABASE_NAME,
    CURRENT_SCHEMA()   AS SCHEMA_NAME

In [None]:
# Import python packages
import logging

from snowflake.snowpark.context import get_active_session

# Notebook-wide Snowpark session, reused by the Python cells below.
session = get_active_session()

# Named logger for this notebook's progress messages.
logger = logging.getLogger("fred_logger")

# Current database and schema, read from the first row of the `sql_get_context` cell's result
# (column 0 = database, column 1 = schema).
current_context_df = cells.sql_get_context.to_pandas()
database_name, schema_name = current_context_df.iloc[0, :2]

logger.info("03_analytics_table_processing start")

In [None]:
CREATE OR REPLACE PROCEDURE merge_fred_updates_sp(DATABASE_NAME STRING, SCHEMA_NAME STRING, ENV STRING)
 RETURNS STRING
 LANGUAGE PYTHON
 RUNTIME_VERSION=3.10
 PACKAGES=('snowflake-snowpark-python')
 HANDLER='main'
AS
$$
from snowflake.snowpark import Session
import snowflake.snowpark.functions as F


def table_exists(session, schema='', name='', database=None):
    """Return True if `schema.name` is listed in INFORMATION_SCHEMA.TABLES.

    Args:
        session: active Snowpark session.
        schema: schema name, compared exactly against INFORMATION_SCHEMA.TABLES.TABLE_SCHEMA.
        name: table name, compared exactly against INFORMATION_SCHEMA.TABLES.TABLE_NAME.
        database: database whose INFORMATION_SCHEMA is searched; the session's current
            database when omitted.

    The names are compared through the DataFrame API, so they are sent as literals
    instead of being formatted into the SQL text.
    """
    info_schema_tables = f"{database}.INFORMATION_SCHEMA.TABLES" if database else "INFORMATION_SCHEMA.TABLES"
    matches = session.table(info_schema_tables).filter(
        (F.col("TABLE_SCHEMA") == schema) & (F.col("TABLE_NAME") == name)
    )
    return matches.limit(1).count() > 0


def create_fred_table(session, DATABASE_NAME, ENV, SCHEMA_NAME=None):
    """Create FRED_10Y_2Y with FRED_FLATTENED's columns plus a META_UPDATED_AT timestamp.

    When SCHEMA_NAME is given the table is created as DATABASE_NAME.SCHEMA_NAME.FRED_10Y_2Y,
    i.e. exactly where merge_fred_updates reads it; without it the name stays unqualified
    and resolves against the session's current schema.
    """
    target = f"{DATABASE_NAME}.{SCHEMA_NAME}.FRED_10Y_2Y" if SCHEMA_NAME else "FRED_10Y_2Y"
    session.sql(f"CREATE TABLE {target} LIKE {DATABASE_NAME}.{ENV}_HARMONIZED.FRED_FLATTENED").collect()
    session.sql(f"ALTER TABLE {target} ADD COLUMN META_UPDATED_AT TIMESTAMP").collect()

# Uncomment only if we need to process another table
# def create_fred_stream(session):
#     _ = session.sql("CREATE STREAM FRED_10Y_2Y_STREAM ON TABLE FRED_10Y_2Y").collect()

def merge_fred_updates(session, DATABASE_NAME, ENV, SCHEMA_NAME):
    """Upsert the rows pending in FRED_STREAM into FRED_10Y_2Y, matching on OBSERVATION_DATE.

    Matched target rows get every non-metadata stream column overwritten; unmatched
    rows are inserted. META_UPDATED_AT is set to CURRENT_TIMESTAMP() in both cases.
    """
    stream = session.table(f"{DATABASE_NAME}.{ENV}_HARMONIZED.FRED_STREAM")
    target = session.table(f"{DATABASE_NAME}.{SCHEMA_NAME}.FRED_10Y_2Y")

    # A standard stream records an UPDATE as a DELETE + INSERT pair for the same key and a
    # DELETE as a lone DELETE row. Merging those rows as-is feeds duplicate keys to MERGE
    # (nondeterministic) and upserts deleted rows back into the target, so only INSERT rows
    # (the current values) are merged. For an append-only stream this filter is a no-op.
    # NOTE(review): source-side deletes are not propagated to FRED_10Y_2Y.
    source = stream.filter(F.col("METADATA$ACTION") == "INSERT")

    # Copy every business column; the stream's METADATA$* columns do not exist on the target.
    cols_to_update = {
        c: source[c] for c in source.schema.names if not c.strip('"').startswith("METADATA$")
    }
    updates = {**cols_to_update, "META_UPDATED_AT": F.current_timestamp()}

    # merge into FRED_10Y_2Y
    target.merge(
        source,
        target['OBSERVATION_DATE'] == source['OBSERVATION_DATE'],
        [F.when_matched().update(updates), F.when_not_matched().insert(updates)],
    )

def main(session: Session, DATABASE_NAME: str, SCHEMA_NAME: str, ENV: str) -> str:
    """Handler: create FRED_10Y_2Y on first run, then merge the pending stream rows into it."""
    if not table_exists(session, schema=SCHEMA_NAME, name='FRED_10Y_2Y', database=DATABASE_NAME):
        create_fred_table(session, DATABASE_NAME, ENV, SCHEMA_NAME)
        # create_fred_stream(session)
    # Process data incrementally
    merge_fred_updates(session, DATABASE_NAME, ENV, SCHEMA_NAME)
    return "FRED_10Y_2Y table updated successfully!"

$$;


In [None]:
# To call the sproc
# session.use_schema(f"{database_name}.{schema_name}")
# env = schema_name[:3]
# session.sql(f"CALL merge_fred_updates_sp('{database_name}', '{schema_name}', '{env}')").collect()


In [None]:
def create_spread_udf(session, schema_name, runtime_version="3.10"):
    """Create (or replace) the Python UDF `<schema_name>.calculate_spread`.

    The UDF returns ten_year_yield - two_year_yield, or NULL when either input is NULL.

    Args:
        session: Snowpark session used to run the DDL.
        schema_name: schema that will own the UDF.
        runtime_version: Python runtime for the UDF. Defaults to 3.10 to match the stored
            procedures in this notebook; the former 3.8 runtime is end-of-life.
    """
    # The handler imports nothing, so no PACKAGES clause is needed.
    session.sql(f"""
        CREATE OR REPLACE FUNCTION {schema_name}.calculate_spread(ten_year_yield FLOAT, two_year_yield FLOAT)
        RETURNS FLOAT
        LANGUAGE PYTHON
        RUNTIME_VERSION = '{runtime_version}'
        HANDLER = 'calculate_spread'
        AS
        $$
def calculate_spread(ten_year_yield, two_year_yield):
    if ten_year_yield is None or two_year_yield is None:
        return None
    return ten_year_yield - two_year_yield
        $$;
        """).collect()

# Register (or replace) calculate_spread in the notebook's current schema.
create_spread_udf(session=session, schema_name=schema_name)

In [None]:
-- Classify a yield spread as 'POSITIVE', 'NEGATIVE' or 'ZERO'.
-- A NULL spread (calculate_spread returns NULL when either yield is missing) yields NULL
-- rather than being reported as 'ZERO'.
CREATE OR REPLACE FUNCTION CHECK_SPREAD_STATUS(spread FLOAT)
RETURNS STRING
AS
$$
    CASE
        WHEN spread > 0 THEN 'POSITIVE'
        WHEN spread < 0 THEN 'NEGATIVE'
        WHEN spread = 0 THEN 'ZERO'
    END
$$;

In [None]:
CREATE OR REPLACE PROCEDURE create_analytical_tables_sp(SCHEMA_NAME STRING, TABLE_NAME STRING)
 RETURNS STRING
 LANGUAGE PYTHON
 RUNTIME_VERSION=3.10
 PACKAGES=('snowflake-snowpark-python')
 HANDLER='main'
AS
$$
from snowflake.snowpark import Session
import snowflake.snowpark.functions as F


def table_exists(session, schema='', name=''):
    """Return True if `schema.name` is listed in the current database's INFORMATION_SCHEMA.TABLES.

    The names are compared through the DataFrame API, so they are sent as literals
    instead of being formatted into the SQL text.
    """
    matches = session.table("INFORMATION_SCHEMA.TABLES").filter(
        (F.col("TABLE_SCHEMA") == schema) & (F.col("TABLE_NAME") == name)
    )
    return matches.limit(1).count() > 0


def _spread_expr(source_schema):
    """Per-row 10Y - 2Y spread via the calculate_spread UDF in `source_schema`.

    NOTE(review): assumes the UDFs live in the same schema as the source table, as the
    AVG_SPREAD aggregation already assumed.
    """
    return F.call_function(f"{source_schema}.calculate_spread", F.col('10Y_YIELD'), F.col('2Y_YIELD'))


def _spread_status_expr(source_schema, spread):
    """Classify a spread column via the CHECK_SPREAD_STATUS UDF in `source_schema`."""
    return F.call_function(f"{source_schema}.CHECK_SPREAD_STATUS", spread)


def aggregate_fred_daily(session, source_schema, source_table):
    """Write one row per observation, with its spread and spread status, to FRED_COMBINED_DAILY."""
    fred_combined = session.table(f"{source_schema}.{source_table}")

    fred_daily = fred_combined.select(
        F.col('OBSERVATION_DATE').alias('OBS_DATE'),
        F.col('10Y_YIELD'),
        F.col('2Y_YIELD'),
        _spread_expr(source_schema).alias('SPREAD'),
    )
    # Classify the SPREAD column computed above instead of invoking the spread UDF a second time.
    fred_daily = fred_daily.with_column('SPREAD_STATUS', _spread_status_expr(source_schema, F.col('SPREAD')))

    fred_daily.write.mode('overwrite').save_as_table(f'{source_schema}.FRED_COMBINED_DAILY')
    print("FRED_COMBINED_DAILY table aggregated")


def _aggregate_fred_by_period(session, source_schema, source_table, period, period_col, target_table):
    """Aggregate yields per DATE_TRUNC(period, OBSERVATION_DATE) and overwrite `target_table`.

    Output columns: <period_col>, TOTAL_RECORDS, MIN/MAX/AVG of each yield (rounded to 2 dp),
    AVG_SPREAD (rounded to 2 dp), SPREAD_STATUS (from the unrounded average spread) and
    META_UPDATED_AT.
    """
    fred_combined = session.table(f"{source_schema}.{source_table}")
    avg_spread = F.avg(_spread_expr(source_schema))

    aggregated = fred_combined.group_by(
        F.date_trunc(period, F.col('OBSERVATION_DATE')).alias(period_col)
    ).agg(
        F.count('OBSERVATION_DATE').alias("TOTAL_RECORDS"),
        F.round(F.min('10Y_YIELD'), 2).alias("MIN_10Y"),
        F.round(F.max('10Y_YIELD'), 2).alias("MAX_10Y"),
        F.round(F.avg('10Y_YIELD'), 2).alias("AVG_10Y"),
        F.round(F.min('2Y_YIELD'), 2).alias("MIN_2Y"),
        F.round(F.max('2Y_YIELD'), 2).alias("MAX_2Y"),
        F.round(F.avg('2Y_YIELD'), 2).alias("AVG_2Y"),
        F.round(avg_spread, 2).alias("AVG_SPREAD"),
        _spread_status_expr(source_schema, avg_spread).alias("SPREAD_STATUS"),
    ).with_column("META_UPDATED_AT", F.current_timestamp())

    # mode('overwrite') drops and recreates the table from this DataFrame's schema.
    aggregated.write.mode('overwrite').save_as_table(f'{source_schema}.{target_table}')


def aggregate_fred_weekly(session, source_schema, source_table):
    """Aggregate the source table per week into FRED_COMBINED_WEEKLY (keyed by WEEK_START)."""
    _aggregate_fred_by_period(session, source_schema, source_table, 'WEEK', 'WEEK_START', 'FRED_COMBINED_WEEKLY')
    print("FRED_COMBINED_WEEKLY table aggregated")


def aggregate_fred_monthly(session, source_schema, source_table):
    """Aggregate the source table per month into FRED_COMBINED_MONTHLY (keyed by MONTH_START)."""
    _aggregate_fred_by_period(session, source_schema, source_table, 'MONTH', 'MONTH_START', 'FRED_COMBINED_MONTHLY')
    print("FRED_COMBINED_MONTHLY table aggregated")


def main(session: Session, SCHEMA_NAME, TABLE_NAME) -> str:
    """Handler: rebuild the daily, weekly and monthly FRED analytics tables from SCHEMA_NAME.TABLE_NAME."""
    if not table_exists(session, schema=SCHEMA_NAME, name=TABLE_NAME):
        return f"Error: Source table {SCHEMA_NAME}.{TABLE_NAME} does not exist"

    print(f"Source table {SCHEMA_NAME}.{TABLE_NAME} found")

    # Each aggregation writes with mode('overwrite'), which drops and recreates its target
    # from the aggregate's own schema, so no separate pre-creation step is needed.
    aggregate_fred_daily(session, SCHEMA_NAME, TABLE_NAME)
    aggregate_fred_weekly(session, SCHEMA_NAME, TABLE_NAME)
    aggregate_fred_monthly(session, SCHEMA_NAME, TABLE_NAME)

    return f"Successfully aggregated analytics tables from {SCHEMA_NAME}.{TABLE_NAME}"
$$;


In [None]:
# To call the sproc
# session.sql(f"CALL create_analytical_tables_sp('{schema_name}', 'FRED_10Y_2Y')").collect()

In [None]:
-- drop table FRED_COMBINED_MONTHLY;
-- drop table FRED_COMBINED_WEEKLY;
-- drop table FRED_DAILY;
-- drop table FRED_MONTHLY;
-- drop table FRED_WEEKLY;