In [0]:
from pyspark.sql.functions import col
from envlp_encryp_data_loader import EnvelopeEncryptionDataLoader
import uuid

In [0]:
%run ./register_widgets

Define the widgets for all the variables that can be parameterized

Register the widgets and read the values of all defined widgets into a dictionary

In [0]:
# Create the notebook widgets, then read their current values into `params`.
# NOTE(review): both helpers are presumably defined by the `./register_widgets`
# notebook pulled in with %run above — confirm.
register_widgets()
params = get_parameters()  # indexed by parameter name in the cells below

Retrieve the parameters required for the notebook

In [0]:
# Catalogs: tables are read from the landing catalog and written to the vault
# catalog; the core catalog qualifies the DSK-decryption function in the SQL below.
source_catalog = params["envlp_encryp_landing_catalog_name"]
target_catalog = params["envlp_encryp_vault_catalog_name"]
envlp_encryp_core_catalog_name = params["envlp_encryp_core_catalog_name"]

# Folder prefix under which the intermediate metadata snapshot is written.
dataframe_results_volume_location = params["dataframe_results_volume_location"]

# Key vault name passed to the DSK-decryption function.
envlp_encryp_key_vault_name = params["envlp_encryp_key_vault_name"]

# NOTE(review): this value is not referenced again in the visible cells; the
# DSK query hard-codes `get_decrypted_dsk` — confirm whether it should use it.
envlp_encryp_decrypt_dsk_function_name = params["envlp_encryp_decrypt_dsk_function_name"]

# Secret scope and the names of the secrets looked up in it via try_secret().
envlp_encryp_secret_scope_name = params["envlp_encryp_secret_scope_name"]
envlp_encryp_sp_tenant_id_key = params["envlp_encryp_sp_tenant_id_key"]
envlp_encryp_sp_client_id_key = params["envlp_encryp_sp_client_id_key"]
envlp_encryp_sp_client_secret_key = params["envlp_encryp_sp_client_secret_key"]
envlp_encryp_aes_decrypt_key = params["envlp_encryp_aes_decrypt_key"]

Create an instance of `EnvelopeEncryptionDataLoader`, which provides the helper methods used in the rest of the notebook

In [0]:
# `spark` is not defined in this notebook; presumably the runtime-provided
# SparkSession — confirm. The full `params` dictionary is handed to the loader.
envlp_encryp_data_loader = EnvelopeEncryptionDataLoader(spark, params)

The query below uses a chain of CTEs to:
- Identify all the managed tables existing in the landing catalog
- Identify all the managed tables existing in the vault catalog
- Identify new tables that exist in the landing catalog but not in the vault catalog
- Identify the list of columns in all the new tables in the landing catalog (this is to get the data type and ordinal position of the columns, to generate the create table statement and to align each hash column next to its respective PCI / PII attribute in the vault)
- Identify the list of columns that are tagged as PCI / PII for tables in the landing catalog. This is to propagate the tags to the tables in the vault catalog

The cells that follow derive three dataframes from the result:
- `df_tables_to_load` contains the list of newly created tables that need to be ingested into the vault catalog
- `df_kek_dsk_to_be_retieved` contains the unique list of PCI / PII attributes tagged across all the new tables created in the landing catalog, and derives the corresponding KEK & DSK names based on the tag classification. This will be used to decrypt the DSK using the KEK in order to hash the PCI / PII attributes
- `df_col_tags` contains the column tags across all tables in the landing catalog

In [0]:
# Metadata query: one row per column of every MANAGED landing table that has no
# same-named (schema, table) MANAGED counterpart in the vault catalog. Columns
# tagged sensitivity_category in ('pii', 'pci') additionally carry their
# sensitivity_type tag and the KEK / DSK names derived from it; for all other
# columns tag_name, tag_value, kek_name and dsk_name are NULL (left join).
#
# Only tables that are absent from the vault are selected; tables that already
# exist there are not re-examined for new data.
#
# The two column_tags CTEs are restricted to the landing catalog. Tags from
# other catalogs could never match the later join on catalog_name, so the
# result is unchanged while the scan of system.information_schema.column_tags
# no longer covers every catalog.
#
# NOTE(review): catalog names are interpolated straight into the SQL text;
# they come from notebook parameters and are assumed to be trusted.
tables_with_pci_pii_columns = f"""
--Identify all the managed tables existing in landing catalog
with tables_in_landing as (
    select table_catalog as catalog_name, table_schema as schema_name, table_name from system.information_schema.tables where table_catalog = '{source_catalog}' and table_type = 'MANAGED'
),
--Identify all the managed tables existing in vault catalog
tables_in_vault as (
    select table_catalog as catalog_name, table_schema as schema_name, table_name from system.information_schema.tables where table_catalog = '{target_catalog}' and table_type = 'MANAGED'
),
--Identify new tables that exists in landing catalog but not in vault catalog and tables where new data is available in landing catalog
tables_from_landing as (
    select
    tl.catalog_name
    ,tl.schema_name
    ,tl.table_name
    from tables_in_landing tl
    left join tables_in_vault tv
    on tl.schema_name = tv.schema_name
    and tl.table_name = tv.table_name
    where tv.table_name is null
),
--Identify the list of columns in all the tables in landing catalog
table_columns as (
    select cols.table_catalog as catalog_name, cols.table_schema as schema_name, cols.table_name, cols.column_name, cols.full_data_type, cols.is_nullable, cols.ordinal_position
    from system.information_schema.columns cols
    inner join tables_from_landing lndtbls
    on cols.table_catalog = lndtbls.catalog_name
    and cols.table_schema = lndtbls.schema_name
    and cols.table_name = lndtbls.table_name
),
--Columns in the landing catalog classified as PCI / PII
sensitivity_category as (
    select catalog_name, schema_name, table_name, column_name, tag_name, tag_value from system.information_schema.column_tags where catalog_name = '{source_catalog}' and tag_name = 'sensitivity_category' and tag_value in ('pii', 'pci')
),
--sensitivity_type tags of the columns in the landing catalog
sensitivity_type as (
    select catalog_name, schema_name, table_name, column_name, tag_name, tag_value from system.information_schema.column_tags where catalog_name = '{source_catalog}' and tag_name = 'sensitivity_type'
),
sensitive_col_tags as (
    select sensitivity_type.*
    from sensitivity_category
    inner join sensitivity_type
    on sensitivity_category.catalog_name = sensitivity_type.catalog_name
    and sensitivity_category.schema_name = sensitivity_type.schema_name
    and sensitivity_category.table_name = sensitivity_type.table_name
    and sensitivity_category.column_name = sensitivity_type.column_name
),
--Fetch the new tables that are created in landing catalog and what columns are tagged as PCI / PII in those tables
tables_in_landing_with_pci_pii_columns as (
    select distinct
    tbl_cols.catalog_name
    ,tbl_cols.schema_name
    ,tbl_cols.table_name
    ,tbl_cols.column_name
    ,tbl_cols.full_data_type
    ,tbl_cols.is_nullable
    ,tbl_cols.ordinal_position
    ,coltg.tag_name
    ,coltg.tag_value
    ,concat('fis-key-encryption-key-', coltg.tag_value) as kek_name
    ,concat('fis-data-salt-key-', coltg.tag_value) as dsk_name
    from table_columns tbl_cols
    left join
    sensitive_col_tags as coltg
    on coltg.catalog_name = tbl_cols.catalog_name
    and coltg.schema_name = tbl_cols.schema_name
    and coltg.table_name = tbl_cols.table_name
    and coltg.column_name = tbl_cols.column_name
)
select * from tables_in_landing_with_pci_pii_columns
order by catalog_name, schema_name, table_name, ordinal_position """

# Write the metadata query result to a uniquely named folder under the results
# volume, then continue from that stored copy instead of the live query.
# `path` is removed again by the final cell of the notebook.
unique_id = str(uuid.uuid4())
path = f"{dataframe_results_volume_location}/{unique_id}"
spark.sql(tables_with_pci_pii_columns).write.mode("overwrite").parquet(path)

# From here on the dataframe is backed by the persisted parquet files.
df_source_tables_with_pci_pii_columns = spark.read.parquet(path)

In [0]:
# Tables to ingest into the vault; collected and iterated by the processing
# loop further down.
df_tables_to_load = envlp_encryp_data_loader.get_distinct_tables_to_load(df_source_tables_with_pci_pii_columns)

# KEK / DSK names to resolve, derived from the same metadata snapshot.
df_kek_dsk_to_be_retieved = envlp_encryp_data_loader.get_kek_dsk(df_source_tables_with_pci_pii_columns)

# Expose it to SQL: `fetch_decrypted_dsk_sql` selects FROM this temp view name.
df_kek_dsk_to_be_retieved.createOrReplaceTempView("kek_dsk_to_be_retrieved")

# Column tags for the landing catalog, used later to propagate tags to the vault.
df_col_tags = envlp_encryp_data_loader.get_column_tags(source_catalog)
#display(df_col_tags)

SQL to fetch decrypted Data Salt Key for all the PCI / PII attributes by calling the UDF. The SQL will be used as a CTE when hashing the data

In [0]:
# SQL text (not executed here) that yields one row per entry of the
# `kek_dsk_to_be_retrieved` temp view: the tag value as `attribute` plus the
# output of the core catalog's `get_decrypted_dsk` function wrapped in
# AES_ENCRYPT. It is handed to process_single_table() below.
#
# The secret scope / secret key names come from the variables extracted from
# `params` earlier, so the rendered SQL is the same as indexing `params` here.
#
# NOTE(review): the column is named `decrypted_dsk` although the value is
# passed through AES_ENCRYPT — presumably it is decrypted again where the
# hash is computed; confirm.
# NOTE(review): the function name is hard-coded although
# `envlp_encryp_decrypt_dsk_function_name` is read from the parameters.
fetch_decrypted_dsk_sql = f"""
 SELECT tag_value as attribute,
         AES_ENCRYPT(`{envlp_encryp_core_catalog_name}`.default.`get_decrypted_dsk`(
           try_secret('{envlp_encryp_secret_scope_name}', '{envlp_encryp_sp_tenant_id_key}'),
           try_secret('{envlp_encryp_secret_scope_name}', '{envlp_encryp_sp_client_id_key}'),
           try_secret('{envlp_encryp_secret_scope_name}', '{envlp_encryp_sp_client_secret_key}'),
           '{envlp_encryp_key_vault_name}',
           kek_name,
           dsk_name
         ), try_secret('{envlp_encryp_secret_scope_name}', '{envlp_encryp_aes_decrypt_key}')) as decrypted_dsk
 FROM kek_dsk_to_be_retrieved
 """

- `process_single_table` processes the data for a single table.
- The PCI / PII columns are collected from the dataframe (which is read from the source table) and hash statements are created.
- The column names and data types are identified to create the table in the vault catalog.
- The insert statement is generated by identifying the hash columns.
- Alter table statements are created to apply the masking functions to the PCI / PII attributes of the table in the vault catalog.
- Alter table statements are created to propagate the tags from the table in the landing catalog to the table in the vault catalog.

In [0]:
def process_single_table(row, df_source_tables_with_pci_pii_columns, df_col_tags, fetch_decrypted_dsk_sql, target_catalog, envlp_encryp_data_loader):
    """Create one landing table in the vault catalog, load it, then mask and tag it.

    All statement text is produced by ``envlp_encryp_data_loader``; this function
    only sequences the helper calls and executes the resulting SQL through the
    notebook-global ``spark`` session.

    Args:
        row: Mapping-like record providing ``catalog_name``, ``schema_name`` and
            ``table_name`` of the source table.
        df_source_tables_with_pci_pii_columns: Column metadata dataframe passed to
            the loader to look up this table's columns and PCI / PII columns.
        df_col_tags: Column-tag dataframe passed to the loader to look up this
            table's tags.
        fetch_decrypted_dsk_sql: SQL text forwarded to the loader when it builds
            the INSERT statement.
        target_catalog: Name of the catalog the table is created in.
        envlp_encryp_data_loader: Loader object exposing the helper methods used
            below.

    Returns:
        None.

    NOTE(review): the INSERT runs before the masking ALTER statements, and
    nothing is rolled back if a later statement raises. A table left behind by
    a failed run already exists in the vault, so the "missing from vault"
    selection query will not pick it up again — confirm this is acceptable.
    """
    loader = envlp_encryp_data_loader
    landing_catalog, schema, table = row["catalog_name"], row["schema_name"], row["table_name"]

    # Gather the inputs for statement generation: PCI/PII columns, the hash
    # expressions built from them, and the full column / data type list.
    sensitive_columns = loader.collect_pci_pii_columns(
        landing_catalog, schema, table, df_source_tables_with_pci_pii_columns
    )
    hash_statements = loader.generate_hash_statements(sensitive_columns)
    columns_datatype = loader.collect_source_table_columns_datatype(
        landing_catalog, schema, table, df_source_tables_with_pci_pii_columns
    )

    # Make sure the schema exists, then create the (empty) target table.
    target_table = f"`{target_catalog}`.`{schema}`.`{table}`"
    loader.create_schema(target_catalog, schema)
    spark.sql(
        loader.generate_create_table_statement(
            landing_catalog, schema, table, columns_datatype, hash_statements, target_table
        )
    )

    # Build every remaining statement up front, before any of them is executed.
    insert_statement = loader.generate_insert_statement(
        landing_catalog, schema, table, target_catalog, columns_datatype, hash_statements, fetch_decrypted_dsk_sql
    )
    masking_statements = list(
        loader.apply_masking_function_statement(schema, table, target_catalog, sensitive_columns)
    )
    column_tags = loader.collect_column_tags(landing_catalog, schema, table, df_col_tags)
    tag_statements = list(
        loader.propagate_tags_statement(schema, table, target_catalog, column_tags, hash_statements)
    )

    # Execute in order: load the data, then apply masks, then apply tags.
    spark.sql(insert_statement)
    for statement in masking_statements + tag_statements:
        spark.sql(statement)

Process each new table into the vault catalog, with hashing and masking functions applied

In [0]:
# Bring the list of tables to the driver and process them one after another.
# There is no per-table error handling: an exception from one table propagates
# and the remaining tables are not processed.
rows = df_tables_to_load.collect()
for row in rows:
    process_single_table(
        row=row,
        df_source_tables_with_pci_pii_columns=df_source_tables_with_pci_pii_columns,
        df_col_tags=df_col_tags,
        fetch_decrypted_dsk_sql=fetch_decrypted_dsk_sql,
        target_catalog=target_catalog,
        envlp_encryp_data_loader=envlp_encryp_data_loader,
    )

Delete the dataframe results stored in the volume

In [0]:
# Remove the parquet snapshot written earlier under `path`; recurse=True
# deletes everything beneath that location.
# NOTE(review): this runs only if the preceding cells complete — a failed run
# presumably leaves the snapshot behind in the volume; confirm cleanup policy.
dbutils.fs.rm(path, recurse=True)

True