In [0]:
from pyspark.sql.functions import col
from secure_core_data_loader import SecureCoreDataLoader

In [0]:
%run ./register_widgets

In [0]:
# register_widgets() and get_parameters() are not defined in this notebook; presumably
# they come from the ./register_widgets notebook pulled in by %run above — TODO confirm.
register_widgets()
# `params` is indexed by parameter name below and is also handed to SecureCoreDataLoader.
params = get_parameters()

Retrieve the parameters required for the notebook

In [0]:
# Pull the notebook-level settings out of `params` in one pass.
# The names on the left line up positionally with the parameter keys on the right.
(
    landing_to_vault_load_tracking_table,
    vault_to_anonymized_load_tracking_table,
    source_catalog,
    target_catalog,
    secure_core_catalog_name,
    secure_core_decrypt_dsk_function_name,
    secure_core_key_vault_name,
) = (
    params[parameter_key]
    for parameter_key in (
        "landing_to_vault_load_tracking_table",
        "vault_to_anonymized_load_tracking_table",
        "secure_core_landing_catalog_name",  # landing catalog -> source_catalog
        "secure_core_vault_catalog_name",  # vault catalog -> target_catalog
        "secure_core_catalog_name",
        "secure_core_decrypt_dsk_function_name",
        "secure_core_key_vault_name",
    )
)
# Not read here (previously present only as commented-out lookups):
# secure_core_secret_scope_name, secure_core_sp_tenant_id_key,
# secure_core_sp_client_id_key, secure_core_sp_client_secret_key.


Initialize the `SecureCoreDataLoader` helper with the Spark session and the notebook parameters

In [0]:
# Single loader instance shared by the rest of the notebook, including the worker
# threads started by process_tables_parallel.
# NOTE(review): thread-safety of SecureCoreDataLoader is not visible from here — confirm.
secure_core_data_loader = SecureCoreDataLoader(spark, params)

- Get the last loaded version of each landing table from the tracking table maintained for the landing catalog
- Using table history, identify whether there are any new versions with WRITE operations for the tables in the landing catalog
- If there is a WRITE operation with a version greater than the last loaded version, new data has been loaded to that table in the landing catalog and it is considered for the incremental data load
- Register the result as a temp table (`tables_with_new_versions`) that is used in the CTE below

In [0]:
# Called for its side effect only (the return value is discarded). The metadata query
# in the next code cell joins against `tables_with_new_versions`, which this call is
# presumably responsible for registering — TODO confirm in SecureCoreDataLoader.
secure_core_data_loader.get_tables_with_new_versions(landing_to_vault_load_tracking_table)

The below CTE will:
- Identify all the managed tables existing in landing catalog
- Identify all the managed tables existing in vault catalog
- Identify new tables that exist in landing catalog but not in vault catalog and derive column target_table_status as "new"
- Identify tables in landing catalog that already exist in vault for which new data has been loaded. Derive column new_data_available as "yes"
- Identify the list of columns in all the tables in landing catalog (this is to get the data type and ordinal position of the columns to generate the create table statement and align the hash columns next to the respective PCI / PII attribute in vault)
- Identify the list of columns that are tagged as PCI / PII for tables in landing catalog. This is to propagate the tags to the tables in vault catalog
- df_tables_to_load contains the list of new tables created and existing tables with new data in landing catalog that need to be ingested into vault catalog
- df_kek_dsk_to_be_retieved contains the unique list of PCI / PII attributes tagged across all the new tables created in the landing catalog, with the corresponding KEK & DSK names derived from the tag classification. This will be used to decrypt the DSK using the KEK to hash the PCI / PII attributes
- df_col_tags contains the column tags across all tables in the landing catalog

In [0]:
# Metadata query: one row per column of every MANAGED table in the landing catalog
# ({source_catalog}), carrying
#   - target_table_status: 'new' when no table with the same schema/table name exists
#     among the MANAGED tables of the vault catalog ({target_catalog}), else 'existing'
#   - new_data_available: 'yes' when the table appears in tables_with_new_versions, else 'no'
#   - tag_name / tag_value: the column's sensitivity_type tag, populated only when the
#     column also has a sensitivity_category tag of 'pii' or 'pci' (NULL otherwise)
#   - kek_name / dsk_name: fixed prefixes concatenated with tag_value (NULL when untagged)
# The query reads `tables_with_new_versions`, so the cell that calls
# get_tables_with_new_versions must have run first.
# Catalog names are interpolated straight into the SQL text; they come from `params`.
tables_with_pci_pii_columns = f"""
--Identify all the managed tables existing in landing catalog
with tables_in_landing as (
    select table_catalog as catalog_name, table_schema as schema_name, table_name from system.information_schema.tables where table_catalog = '{source_catalog}' and table_type = 'MANAGED'
),
--Identify all the managed tables existing in vault catalog
tables_in_vault as (
    select table_catalog as catalog_name, table_schema as schema_name, table_name from system.information_schema.tables where table_catalog = '{target_catalog}' and table_type = 'MANAGED'
),
--Identify new tables that exists in landing catalog but not in vault catalog and tables where new data is available in landing catalog
tables_from_landing as (
    select 
    tl.catalog_name 
    ,tl.schema_name
    ,tl.table_name 
    ,case when tv.table_name is null then 'new' else 'existing' end as target_table_status
    ,case when twnv.table_name is not null then 'yes' else 'no' end as new_data_available
    from tables_in_landing tl
    left join tables_in_vault tv
    on tl.schema_name = tv.schema_name
    and tl.table_name = tv.table_name
    left join tables_with_new_versions twnv
    on tl.catalog_name = twnv.catalog_name
    and tl.schema_name = twnv.schema_name
    and tl.table_name = twnv.table_name
    --where tv.table_name is null
),
--Identify the list of columns in all the tables in landing catalog
table_columns as (
    select cols.table_catalog as catalog_name, cols.table_schema as schema_name, cols.table_name, cols.column_name, cols.full_data_type, cols.is_nullable, cols.ordinal_position, lndtbls.target_table_status, lndtbls.new_data_available
    from system.information_schema.columns cols
    inner join tables_from_landing lndtbls
    on cols.table_catalog = lndtbls.catalog_name
    and cols.table_schema = lndtbls.schema_name
    and cols.table_name = lndtbls.table_name
),
sensitivity_category as (
    select catalog_name, schema_name, table_name, column_name, tag_name, tag_value from system.information_schema.column_tags where tag_name = 'sensitivity_category' and tag_value in ('pii', 'pci')
),
sensitivity_type as (
    select catalog_name, schema_name, table_name, column_name, tag_name, tag_value from system.information_schema.column_tags where tag_name = 'sensitivity_type'
),
sensitive_col_tags as (
    select sensitivity_type.*
    from sensitivity_category
    inner join sensitivity_type
    on sensitivity_category.catalog_name = sensitivity_type.catalog_name
    and sensitivity_category.schema_name = sensitivity_type.schema_name
    and sensitivity_category.table_name = sensitivity_type.table_name
    and sensitivity_category.column_name = sensitivity_type.column_name
),
--Fetch the new tables that are created in landing catalog and existing tables with new data and what columns are tagged as PCI / PII in those tables
tables_in_landing_with_pci_pii_columns as (
    select distinct 
    tbl_cols.catalog_name
    ,tbl_cols.schema_name
    ,tbl_cols.table_name
    ,tbl_cols.column_name
    ,tbl_cols.full_data_type
    ,tbl_cols.is_nullable
    ,tbl_cols.ordinal_position
    ,coltg.tag_name
    ,coltg.tag_value
    ,concat('fis-key-encryption-key-', coltg.tag_value) as kek_name
    ,concat('fis-data-salt-key-', coltg.tag_value) as dsk_name
    ,tbl_cols.target_table_status
    ,tbl_cols.new_data_available
    from table_columns tbl_cols
    left join 
    sensitive_col_tags as coltg
    on coltg.catalog_name = tbl_cols.catalog_name
    and coltg.schema_name = tbl_cols.schema_name
    and coltg.table_name = tbl_cols.table_name
    and coltg.column_name = tbl_cols.column_name 
)
select * from tables_in_landing_with_pci_pii_columns 
order by catalog_name, schema_name, table_name, ordinal_position """

# Column-level metadata for every landing table. No status filter is applied here;
# this same DataFrame is passed to each process_single_table call.
df_source_tables_with_pci_pii_columns = spark.sql(tables_with_pci_pii_columns)
#display(df_source_tables_with_pci_pii_columns)
# Disabled filter, kept for reference (would keep only new tables and existing tables with new data):
#df_source_tables_with_pci_pii_columns = df_source_tables_with_pci_pii_columns.filter(
#    (col("target_table_status") == "new") | 
#    ((col("target_table_status") == "existing") & (col("new_data_available") == "yes"))
#)
#display(df_source_tables_with_pci_pii_columns)

# Table-level list driving the load; process_single_table reads catalog_name, schema_name,
# table_name, target_table_status and new_data_available from each collected row.
df_tables_to_load = secure_core_data_loader.get_distinct_tables_to_load(df_source_tables_with_pci_pii_columns)
#display(df_tables_to_load)

# KEK / DSK names derived from the column metadata (the variable name's spelling,
# "retieved", is kept as-is because it is a notebook-level name).
df_kek_dsk_to_be_retieved = secure_core_data_loader.get_kek_dsk(df_source_tables_with_pci_pii_columns)
#display(df_kek_dsk_to_be_retieved)

# Expose the KEK / DSK list to SQL under the (correctly spelled) view name kek_dsk_to_be_retrieved.
df_kek_dsk_to_be_retieved.createOrReplaceTempView("kek_dsk_to_be_retrieved")

# Column tags for the landing catalog; handed to process_single_table for tag propagation.
df_col_tags = secure_core_data_loader.get_column_tags(source_catalog)
#display(df_col_tags)


In [0]:
# Read back the temp view registered in the previous cell.
# NOTE(review): df_kek_dsk is not referenced by any later cell visible in this notebook;
# it looks like an inspection aid for the display() call below — confirm before removing.
df_kek_dsk = spark.sql("select * from kek_dsk_to_be_retrieved")
#display(df_kek_dsk)

SQL to fetch the decrypted Data Salt Key (DSK) for all the PCI / PII attributes by calling the decrypt UDF; the query wraps the UDF result in `AES_ENCRYPT`. The SQL will be used as a CTE when hashing the data

In [0]:
# SQL text (not executed here) that yields one row per entry of
# `<secure_core_catalog_name>`.default.pci_pii_attributes with two columns:
#   attribute     - the row's pci_pii_attribute_name
#   decrypted_dsk - result of the decrypt-DSK function, called with the key vault name
#                   and the row's kek_name / dsk_name, then passed through AES_ENCRYPT
# The string is handed to process_single_table, which forwards it to
# SecureCoreDataLoader.generate_insert_statement.
# NOTE(review): the AES key 'dskencryptionkey' is a hard-coded literal in notebook source;
#   consider sourcing it from a secret store instead — confirm with the security owner.
# NOTE(review): despite its alias, decrypted_dsk holds the AES_ENCRYPT output rather than
#   the plain DSK; whoever consumes this column must account for that — confirm.
fetch_decrypted_dsk_sql = f"""
  SELECT pci_pii_attribute_name as attribute,
          AES_ENCRYPT(`{secure_core_catalog_name}`.default.`{secure_core_decrypt_dsk_function_name}`(
            '{secure_core_key_vault_name}',
            kek_name,
            dsk_name
          ), 'dskencryptionkey') as decrypted_dsk
  FROM `{secure_core_catalog_name}`.default.pci_pii_attributes
  """

In [0]:
#fetch_decrypted_dsk_sql = f"""
#  SELECT tag_value as attribute,
#          AES_ENCRYPT(`{secure_core_catalog_name}`.default.`get_decrypted_dsk`(
#            '{secure_core_key_vault_name}',
#            kek_name,
#            dsk_name
#          ), 'dskencryptionkey') as decrypted_dsk
#  FROM kek_dsk_to_be_retrieved
#  """

- process_single_table - Processes the data for a single table.
- Check if CDF is enabled on the source (landing) table. If it is not enabled, CDF is enabled on the source table and an entry is added to the tracking table
- The PCI / PII columns are collected from the dataframe (which is read from the source table) and hash statements are created.
- The column names and datatypes are identified to create tables in the vault catalog and add an entry to the vault_to_anonymized tracking table to track the data loads between vault catalog and anonymized catalog
- The insert statement is generated by identifying the data between the last loaded version and any new version of data loaded to the landing catalog. This is done for existing tables where CDF is enabled. If it is a new table where CDF is not enabled in the source, all the data is read
- Alter table statements are created to apply the masking functions to the PCI / PII attributes of the tables in the vault catalog
- Alter table statements are created to propagate the tags from tables in landing catalog to tables in vault catalog
- The landing_to_vault_load tracking table is updated with the latest version loaded from the landing tables.



In [0]:
def process_single_table(row, df_source_tables_with_pci_pii_columns, df_col_tags, fetch_decrypted_dsk_sql, target_catalog, secure_core_data_loader, landing_to_vault_load_tracking_table, vault_to_anonymized_load_tracking_table):
    """Create (when needed) and load the vault copy of one landing table.

    Steps, in execution order:
      1. Make sure Change Data Feed (CDF) is enabled on the source table.
      2. Collect the PCI/PII columns, the hash statements and the column/datatype list.
      3. When the target table status is "new", create the target table and add its
         entry to the vault-to-anonymized tracking table.
      4. When the table is new, or existing with new data, run the insert, then the
         masking and tag ALTER statements, then update the landing-to-vault tracking table.
         Otherwise nothing further is executed.

    Besides its parameters, this function reads the notebook globals ``spark`` and
    ``secure_core_catalog_name``.

    Args:
        row: Table metadata row with keys ``catalog_name``, ``schema_name``,
            ``table_name``, ``target_table_status`` and ``new_data_available``.
        df_source_tables_with_pci_pii_columns: Column-level metadata DataFrame passed
            to the loader when collecting PCI/PII columns and column datatypes.
        df_col_tags: Column-tag DataFrame passed to the loader for tag propagation.
        fetch_decrypted_dsk_sql: SQL text forwarded to ``generate_insert_statement``.
        target_catalog: Name of the vault catalog the table is loaded into.
        secure_core_data_loader: ``SecureCoreDataLoader`` instance that builds the statements.
        landing_to_vault_load_tracking_table: Tracking table holding ``last_loaded_version``
            per landing table.
        vault_to_anonymized_load_tracking_table: Tracking table that receives an entry
            for every newly created vault table.

    Returns:
        None.

    Raises:
        LookupError: If the landing-to-vault tracking table has no row for this table
            (previously an uninformative ``IndexError`` from ``collect()[0][0]``).
    """
    # Extract relevant table metadata from the input row
    source_catalog = row["catalog_name"]
    schema = row["schema_name"]
    table = row["table_name"]
    target_table_status = row["target_table_status"]
    new_data_available = row["new_data_available"]

    # Check if Change Data Feed (CDF) is enabled on the source table; enable it if not.
    # `cdf_enabled` deliberately keeps the status observed *before* enabling, because that
    # is the value later handed to generate_insert_statement.
    cdf_enabled = secure_core_data_loader.get_cdf_enabled_status(source_catalog, schema, table)
    if not cdf_enabled:
        secure_core_data_loader.enable_cdf(
            source_catalog, schema, table, landing_to_vault_load_tracking_table)

    # Retrieve PCI/PII tagged columns and the hash expressions built from them
    pci_pii_columns = secure_core_data_loader.collect_pci_pii_columns(source_catalog, schema, table, df_source_tables_with_pci_pii_columns)
    hash_statements = secure_core_data_loader.generate_hash_statements(pci_pii_columns)

    # Fetch full list of column names and data types from the source table
    columns_datatype = secure_core_data_loader.collect_source_table_columns_datatype(source_catalog, schema, table, df_source_tables_with_pci_pii_columns)

    # If the target table is marked as 'new', create it and add its tracking entry
    is_new_target = target_table_status == "new"
    if is_new_target:
        target_table = f"`{target_catalog}`.`{schema}`.`{table}`"
        create_target_table_statement = secure_core_data_loader.generate_create_table_statement(
           source_catalog, schema, table, columns_datatype, hash_statements, target_table
        )
        add_entry_to_vault_tracking_table_statement = secure_core_data_loader.add_entry_to_tracking_table(target_catalog, schema, table, vault_to_anonymized_load_tracking_table)
        spark.sql(create_target_table_statement)
        spark.sql(add_entry_to_vault_tracking_table_statement)

    # Only load when it's a new vault table, or an existing one with new data available
    has_new_data = target_table_status == "existing" and new_data_available == "yes"
    if not (is_new_target or has_new_data):
        return

    # Look up the last loaded version from the tracking table
    last_loaded_version_query = f"""
        SELECT last_loaded_version FROM 
        `{secure_core_catalog_name}`.`default`.`{landing_to_vault_load_tracking_table}`
        WHERE catalog_name = '{source_catalog}' AND schema_name = '{schema}' AND table_name = '{table}'
    """
    tracking_rows = spark.sql(last_loaded_version_query).collect()
    if not tracking_rows:
        # Fail with a message that names the table instead of "list index out of range".
        raise LookupError(
            f"No entry for {source_catalog}.{schema}.{table} in tracking table "
            f"{secure_core_catalog_name}.default.{landing_to_vault_load_tracking_table}"
        )
    last_loaded_version = tracking_rows[0][0]

    # Build every statement first, in the original order, before executing any of them
    insert_statements = secure_core_data_loader.generate_insert_statement(
        source_catalog, schema, table, target_catalog, columns_datatype, hash_statements, last_loaded_version, cdf_enabled, fetch_decrypted_dsk_sql
    )

    # ALTER statements for masking functions (materialised here, as before)
    apply_masking_statements = list(
        secure_core_data_loader.apply_masking_function_statement(schema, table, target_catalog, pci_pii_columns)
    )

    # ALTER statements that propagate the source column tags
    column_tags = secure_core_data_loader.collect_column_tags(source_catalog, schema, table, df_col_tags)
    propagate_tags_statements = list(
        secure_core_data_loader.propagate_tags_statement(schema, table, target_catalog, column_tags, hash_statements)
    )

    # UPDATE statement that records the new version in the tracking table
    update_tracking_table_statements = secure_core_data_loader.update_tracking_table_statement(source_catalog, schema, table, landing_to_vault_load_tracking_table)

    # Execute data insert into the vault table
    spark.sql(insert_statements)

    # Apply masking functions to the appropriate columns
    for apply_mask in apply_masking_statements:
        spark.sql(apply_mask)

    # Apply tags to both raw and hash columns
    for apply_tags in propagate_tags_statements:
        spark.sql(apply_tags)

    # Update the version tracking table only after the insert and ALTERs have run
    spark.sql(update_tracking_table_statements)

In [0]:
#rows = df_tables_to_load.collect()
#for row in rows:
 # process_single_table(row, df_source_tables_with_pci_pii_columns, df_col_tags, fetch_decrypted_dsk_sql, target_catalog, secure_core_data_loader, landing_to_vault_load_tracking_table, vault_to_anonymized_load_tracking_table)

process_tables_parallel - Function definition to load the tables in parallel using ThreadPoolExecutor

In [0]:
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_tables_parallel(df_tables_to_load, df_source_tables_with_pci_pii_columns, df_col_tags, fetch_decrypted_dsk_sql, target_catalog, secure_core_data_loader, landing_to_vault_load_tracking_table, vault_to_anonymized_load_tracking_table, max_workers=5):
    """Run process_single_table for every table in df_tables_to_load on a thread pool.

    Processing is best-effort: a failure in one table is reported and recorded, and the
    remaining tables are still processed. Nothing is raised for per-table failures.

    Args:
        df_tables_to_load: DataFrame whose rows each describe one table to process; it is
            collected to the driver before any work is submitted.
        df_source_tables_with_pci_pii_columns, df_col_tags, fetch_decrypted_dsk_sql,
        target_catalog, secure_core_data_loader, landing_to_vault_load_tracking_table,
        vault_to_anonymized_load_tracking_table: Passed through unchanged to
            process_single_table for every row.
        max_workers: Number of worker threads (default 5).

    Returns:
        list[tuple[str, Exception]]: One ``("catalog.schema.table", exception)`` entry per
        table that failed; empty when every table succeeded. Earlier versions returned
        None, so callers that ignore the return value are unaffected.
    """
    # Collect all rows from the DataFrame into a local list.
    # Each row contains metadata for a single table to be processed.
    rows = df_tables_to_load.collect()
    failures = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember which row each future belongs to, so a failure can be attributed to its table.
        future_to_row = {
            executor.submit(
                process_single_table, row, df_source_tables_with_pci_pii_columns, df_col_tags, fetch_decrypted_dsk_sql, target_catalog, secure_core_data_loader, landing_to_vault_load_tracking_table, vault_to_anonymized_load_tracking_table
            ): row
            for row in rows
        }

        # Handle each future as soon as it finishes, in completion order.
        for future in as_completed(future_to_row):
            failed_row = future_to_row[future]
            try:
                # result() re-raises any exception that occurred in the worker thread.
                future.result()
            except Exception as e:
                table_label = f"{failed_row['catalog_name']}.{failed_row['schema_name']}.{failed_row['table_name']}"
                print(f"Error processing table {table_label}: {e}")
                failures.append((table_label, e))

    return failures

Call the process_tables_parallel function to process the tables from landing to vault catalog in parallel

In [0]:
# Kick off the landing-to-vault load for every table in df_tables_to_load,
# using five worker threads.
process_tables_parallel(
    df_tables_to_load=df_tables_to_load,
    df_source_tables_with_pci_pii_columns=df_source_tables_with_pci_pii_columns,
    df_col_tags=df_col_tags,
    fetch_decrypted_dsk_sql=fetch_decrypted_dsk_sql,
    target_catalog=target_catalog,
    secure_core_data_loader=secure_core_data_loader,
    landing_to_vault_load_tracking_table=landing_to_vault_load_tracking_table,
    vault_to_anonymized_load_tracking_table=vault_to_anonymized_load_tracking_table,
    max_workers=5,
)