In [None]:
# Import python packages
import streamlit as st
import pandas as pd
import json
import sys

# Grab the currently active Snowpark session; every SQL statement in the
# cells below is issued through this `session` object.
from snowflake.snowpark.context import get_active_session
session = get_active_session()

## Parameter instruction

In [None]:
'''
If calling the notebook from a task using EXECUTE NOTEBOOK MY_DATABASE.PUBLIC.MY_NOTEBOOK
please pass the following three params

1. Source Account Identifier
2. Data Share Name
3. Source Database environment Identifier (e.g. _PRD if the database on Production follows a suffix DB_CORE_HR_PRD)

The values will be read internally by notebook as follows
sys.argv = ['SRCACCOUNT', 'SHARENAME', '_PRD']

If no params are passed, the notebook uses the default values from the config.yaml file
and sys.argv will be empty
sys.argv = []
'''


In [None]:
import yaml

# Load the notebook configuration from config.yaml (path is relative to the
# notebook's working directory). The encoding is pinned so the result does not
# depend on the platform's default locale.
with open("config.yaml", "r", encoding="utf-8") as f:
    app_config = yaml.safe_load(f)

# safe_load returns None for an empty document; fail here with a clear message
# instead of an opaque TypeError when later cells index into app_config.
if not isinstance(app_config, dict):
    raise ValueError("config.yaml must contain a YAML mapping at the top level")

# Display the loaded configuration as the cell output
app_config

In [None]:
# Share parameters: taken from sys.argv when exactly three values were passed,
# otherwise from the `default` section of the configuration.
if len(sys.argv) == 3:
    V_source_account, V_share_name, V_source_environment_identifier = sys.argv
else:
    _defaults = app_config['default']
    V_source_account = _defaults['source_account']
    V_share_name = _defaults['share_name']
    V_source_environment_identifier = _defaults['source_environment_identifier']

# Organisation name: first column of the first row returned by the query
_org_rows = session.sql('SELECT CURRENT_ORGANIZATION_NAME()').collect()
V_organisation_name = _org_rows[0][0]

# Settings that are always read from the configuration
V_this_environment_identifier = app_config['this_environment_identifier']
V_log_table = app_config['log_table']
V_execution_role = app_config['execution_role']
V_create_table_if_not_exists = app_config['create_table_if_not_exists']
V_create_schema_if_not_exists = app_config['create_schema_if_not_exists']
V_mask_column_list = app_config['mask_column_list']

# Echo the resolved settings so they show up in the cell output
print(f"{V_organisation_name}.{V_source_account} | {V_share_name} | {V_source_environment_identifier} | {V_this_environment_identifier}")
print(f"{V_execution_role} | {V_create_schema_if_not_exists} | {V_create_table_if_not_exists}")

In [None]:
# Switch the session to the role named in the configuration before doing any work
sql = "USE ROLE {}".format(V_execution_role)
print(sql)
session.sql(sql).collect()

In [None]:
# Look up the share; its COMMENT field carries the share definition
# (tables, database name, ...) that drives the rest of the notebook.
sql = f"SHOW SHARES LIKE '{V_share_name}'"
share = session.sql(sql).collect()

if not share:
    raise ValueError(f"Share '{V_share_name}' was not found by SHOW SHARES")

# Only the first matching row is used
comment = share[0]["comment"]
if not comment:
    raise ValueError(f"Share '{V_share_name}' has no comment to read the share definition from")

# The comment is written with single quotes; swap them for double quotes so
# it parses as JSON.
json_comment = comment.replace("'", '"')
try:
    json_data = json.loads(json_comment)
except json.JSONDecodeError as exc:
    raise ValueError(f"Comment on share '{V_share_name}' is not valid JSON: {exc}") from exc

table_details = json_data["tables"]
database_name = json_data["database_name"]

# Derive the local database name by swapping the environment identifier.
# str.replace with an empty search string would insert the replacement between
# every character, so an empty source identifier leaves the name unchanged.
# NOTE(review): confirm "unchanged" is the intended result when the source has no identifier.
if V_source_environment_identifier:
    target_database = database_name.replace(V_source_environment_identifier, V_this_environment_identifier)
else:
    target_database = database_name
print(f"{database_name} --> {target_database}")

In [None]:
# Mount the inbound share as a local database named READER<share name>
source_shared_db = "READER{}".format(V_share_name)
sql = "CREATE OR REPLACE DATABASE {} FROM SHARE {}.{}.{}".format(
    source_shared_db,
    V_organisation_name,
    V_source_account,
    V_share_name,
)
print(sql)
session.sql(sql).collect()

In [None]:
def data_share_log_table_info(share_name, log_table, target_db, target_schema, target_table, date_col, filter_date, creator_id, config, refresh_ts):
    """Insert one row describing a refreshed table into the log table.

    Parameters
    ----------
    share_name, target_db, target_schema, target_table, date_col, creator_id, refresh_ts
        Values written as string literals into the corresponding log columns.
    log_table
        Name of the table to insert into; interpolated as an identifier.
    filter_date
        Written as a string literal, or as SQL NULL when it is the empty string.
    config
        JSON-serialisable object; stored via PARSE_JSON of its json.dumps text.

    The statement is executed through the module-level `session`.
    """

    def _sql_literal(value):
        # Render `value` as a single-quoted string constant. Backslashes are
        # doubled first so JSON escapes such as \" survive the SQL string
        # parser, then single quotes are doubled so they cannot end the literal.
        text = str(value).replace("\\", "\\\\").replace("'", "''")
        return f"'{text}'"

    filter_date_sql = 'NULL' if filter_date == '' else _sql_literal(filter_date)
    config_sql_value = f"PARSE_JSON({_sql_literal(json.dumps(config))})"

    log_sql = f"""
    INSERT INTO {log_table} (SHARE_NAME, DATABASE_NAME, SCHEMA_NAME, TABLE_NAME, FILTER_DATE_COLUMN, FILTER_DATE, DATA_SHARE_CREATOR_ID, DATA_SHARE_CONFIG, REFRESH_TS)
        SELECT
            {_sql_literal(share_name)},
            {_sql_literal(target_db)},
            {_sql_literal(target_schema)},
            {_sql_literal(target_table)},
            {_sql_literal(date_col)},
            {filter_date_sql},
            {_sql_literal(creator_id)},
            {config_sql_value},
            {_sql_literal(refresh_ts)}
    """

    session.sql(log_sql).collect()

In [None]:
from datetime import datetime


def _normalize_mask_patterns(raw_patterns):
    """Return the configured mask patterns as a list of non-empty, stripped strings.

    Accepts a list of patterns, a single comma-separated string, or None.
    Empty entries are dropped: an empty pattern is a substring of every column
    name and would mask every column.
    """
    if raw_patterns is None:
        return []
    if isinstance(raw_patterns, str):
        raw_patterns = raw_patterns.split(",")
    stripped = (str(item).strip() for item in raw_patterns)
    return [item for item in stripped if item]


# Loop-invariant setup, done once instead of once per table.
# Patterns are lower-cased here because matching below is case-insensitive.
mask_patterns = [pattern.lower() for pattern in _normalize_mask_patterns(V_mask_column_list)]
numeric_type_markers = ["number", "int", "float", "decimal", "boolean"]

# The per-table SQL below is submitted as several statements in one call
session.sql("ALTER SESSION SET MULTI_STATEMENT_COUNT = 0").collect()

# One timestamp for the whole run so all log rows of this run share it
refresh_ts = datetime.utcnow()
for table_full_name, share_config in table_details.items():
    schema_name, table_name = table_full_name.split('.')

    column_sql = f"""
        SELECT COLUMN_NAME, DATA_TYPE
        FROM {source_shared_db}.INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = '{schema_name}' AND TABLE_NAME = '{table_name}'
        ORDER BY ORDINAL_POSITION
    """

    df_cols = session.sql(column_sql)
    df_pd = df_cols.to_pandas()

    # Original column list (raw, unchanged)
    col_list_original = df_pd["COLUMN_NAME"].tolist()
    if not col_list_original:
        # Without columns the SELECT below would be invalid SQL; report the real cause
        raise ValueError(f"No columns found for {schema_name}.{table_name} in {source_shared_db}")
    raw_cols = ", ".join(col_list_original)

    # Select expressions, one per column; start as the plain column names
    col_list = col_list_original.copy()

    # Lookup of column name -> data type
    col_type_map = dict(zip(df_pd["COLUMN_NAME"], df_pd["DATA_TYPE"]))

    # Explicit per-column masking from the share definition
    if "mask" in share_config:
        for mask_info in share_config["mask"]["columns"]:
            mask_column = mask_info["mask_column"]
            masked_tag = mask_info["masked_tag"]
            for i in range(len(col_list)):
                if mask_column == col_list[i]:
                    if masked_tag.startswith("="):
                        # A leading "=" means the rest is used as a SQL expression
                        col_list[i] = f"{masked_tag[1:]} AS {mask_column}"
                    else:
                        col_data_type = col_type_map.get(mask_column, "").lower()
                        if any(t in col_data_type for t in numeric_type_markers):
                            col_list[i] = f"{masked_tag} AS {mask_column}"  # no quotes
                        else:
                            col_list[i] = f"'{masked_tag}' AS {mask_column}"  # with quotes

    # Pattern-based masking: any column whose name contains a configured
    # pattern (case-insensitive) is replaced by a literal '*'. The original
    # names are matched so aliases added above are not considered.
    for i, column_name in enumerate(col_list_original):
        column_name_lower = column_name.lower()
        if any(pattern in column_name_lower for pattern in mask_patterns):
            col_list[i] = f"'*' AS {column_name}"

    val_cols = ", ".join(col_list)

    target_table = f"{target_database}.{schema_name}.{table_name}"
    source_table = f"{source_shared_db}.{schema_name}.{table_name}"

    select_sql = f"SELECT {val_cols} FROM {source_table}"
    # Apply filter if provided
    if "filter" in share_config:
        date_column = share_config['filter']['date_column']
        filter_date = json_data['filter_date']
        select_sql += f" WHERE TO_DATE({date_column}) > '{filter_date}'"
    else:
        date_column = ''
        filter_date = ''

    execute_sql = ""
    if V_create_schema_if_not_exists:
        execute_sql += f"CREATE OR ALTER SCHEMA {target_database}.{schema_name}; "

    print(f"Loading table {target_table}")
    print(f"{select_sql}")

    if V_create_table_if_not_exists:
        execute_sql += f"CREATE OR REPLACE TABLE {target_table} AS {select_sql}; "
    else:
        execute_sql += f"TRUNCATE TABLE {target_table};"
        execute_sql += f"INSERT INTO {target_table} ({raw_cols}) {select_sql}; "

    session.sql(execute_sql).collect()

    # Record what was loaded for this table
    data_share_log_table_info(
        V_share_name,
        V_log_table,
        target_database,
        schema_name,
        table_name,
        date_column,
        filter_date,
        json_data['share_created_by'],
        share_config,
        refresh_ts
    )

print(f"Finished loading tables to database {target_database}")

In [None]:
# Read back the log rows stamped with this run's refresh timestamp and display them
_log_query = f"select * from {V_log_table} where refresh_ts = '{refresh_ts}' order by record_id"
df = session.sql(_log_query).collect()
df