Define the widgets for the variables that are parameterized in this notebook.

In [0]:
%sql
-- Text widgets that parameterize this notebook. Every widget defaults to an
-- empty string, so a value has to be supplied before the cells below are run.
-- Catalog that holds both UDFs and the attribute table (the schema used below is always `default`).
CREATE WIDGET TEXT secure_core_catalog_name DEFAULT "";
-- Name under which the KEK/DSK generator UDF is created.
CREATE WIDGET TEXT secure_core_generate_kek_dsk_function_name DEFAULT "";
-- Name under which the DSK decryption UDF is created.
CREATE WIDGET TEXT secure_core_get_decrypted_dsk_function_name DEFAULT "";
-- Value passed as the first argument (key_vault_url) when the generator UDF is called.
CREATE WIDGET TEXT secure_core_key_vault_name DEFAULT "";
-- Table whose rows (pci_pii_attribute_name, kek_name, dsk_name) drive key generation.
CREATE WIDGET TEXT secure_core_pci_pii_attribute_table_name DEFAULT "";

In [0]:
%python
import os
DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME = os.getenv("DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME")
print(DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME)
dbutils.widgets.text("DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME", DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME)

Create a Python-backed SQL UDF that generates the KEK (Key Encryption Key) and DSK (Data Salt Key) for each PCI and PII attribute. The KEK is stored as a key and the DSK as a secret in Azure Key Vault; each DSK is encrypted with its respective KEK before it is stored as a secret.

In [0]:
%sql
-- Batch (PARAMETER STYLE PANDAS) Python UDF that generates the KEK/DSK pair.
--   * Created in the `default` schema of the catalog given by the
--     secure_core_catalog_name widget, under the name given by the
--     secure_core_generate_kek_dsk_function_name widget.
--   * Takes (key_vault_url, kek_name, dsk_secret_name) and returns one status
--     STRING per input row; the Python entry point is generate_kek_dsk_batch.
--   * Runs with the service credential named by the
--     DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME widget.
--   * The Azure SDK wheels are loaded from a Unity Catalog volume.
-- NOTE(review): the wheel paths are hard-coded to a single volume while the
-- target catalog is parameterized; confirm this is intended for every environment.
CREATE OR REPLACE FUNCTION `${secure_core_catalog_name}`.default.${secure_core_generate_kek_dsk_function_name}(
  key_vault_url    STRING,
  kek_name         STRING,
  dsk_secret_name  STRING
)
RETURNS STRING
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'generate_kek_dsk_batch'
CREDENTIALS (`${DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME}` DEFAULT)
ENVIRONMENT (
  dependencies = '[
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_identity-1.16.0-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_keyvault_keys-4.8.0-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_keyvault_secrets-4.8.0-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/msal-1.32.3-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/msal_extensions-1.3.1-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_common-1.1.28-py2.py3-none-any.whl"
  ]',
  environment_version = 'None'
)
AS
$$
from datetime import datetime, timedelta
import os, base64
import pandas as pd
from typing import Iterator, Tuple

from azure.identity import DefaultAzureCredential
from azure.keyvault.keys import KeyClient, KeyRotationPolicy, KeyRotationLifetimeAction, KeyRotationPolicyAction
from azure.keyvault.keys.crypto import CryptographyClient, EncryptionAlgorithm
from azure.keyvault.secrets import SecretClient
from azure.core.exceptions import ResourceNotFoundError

# Single credential object shared by every Key Vault client built in this UDF.
# NOTE(review): presumably resolved from the service credential named in the
# CREDENTIALS clause of the function definition; confirm.
_credential = DefaultAzureCredential()
# ------------------------------------------------------------------
# Small client cache keyed by Key Vault URL
# Each entry is {"clients": (KeyClient, SecretClient), "timestamp": <epoch seconds>}
# and is rebuilt by _get_clients once it is older than _CACHE_TTL_SEC.
# ------------------------------------------------------------------
_client_cache = {}
_CACHE_TTL_SEC = 300  # 5 minutes

def _get_clients(vault_url: str):
    """Return a ``(KeyClient, SecretClient)`` pair for ``vault_url``.

    Client pairs are cached per vault URL in ``_client_cache`` and rebuilt
    once the cached entry is older than ``_CACHE_TTL_SEC`` seconds.

    Args:
        vault_url: URL of the Key Vault the clients should talk to.

    Returns:
        Tuple ``(key_client, secret_client)`` bound to ``vault_url``.

    Raises:
        PermissionError: if building the clients raises
            ``ClientAuthenticationError``.
    """
    from time import time
    # Imported locally because this UDF's module-level imports only bring in
    # ResourceNotFoundError; without this the except clause below would fail
    # with NameError instead of raising the intended PermissionError.
    from azure.core.exceptions import ClientAuthenticationError

    current_time = time()
    cached = _client_cache.get(vault_url)

    if cached is None or current_time - cached["timestamp"] > _CACHE_TTL_SEC:
        try:
            cached = {
                "clients": (
                    KeyClient(vault_url=vault_url, credential=_credential),
                    SecretClient(vault_url=vault_url, credential=_credential)
                ),
                "timestamp": current_time
            }
        except ClientAuthenticationError as e:
            raise PermissionError(f"Authentication failed for {vault_url}: {str(e)}") from e
        # Only publish the entry once both clients were built successfully.
        _client_cache[vault_url] = cached

    return cached["clients"]
    
def generate_kek_dsk_batch(
    batches: Iterator[Tuple[pd.Series, pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """Batch handler: make sure a KEK and an encrypted DSK exist for every row.

    For each row the handler
      1. fetches the KEK ``kek_name``; if it is missing, creates a 2048-bit
         RSA key, sets an expiry 365 days ahead and tries to attach a
         rotation policy (best effort),
      2. skips the row if the secret ``dsk_name`` already exists,
      3. otherwise generates a random 256-bit DSK, encrypts its unpadded
         url-safe base64 text with the KEK (RSA-OAEP-256) and
      4. stores the base64 ciphertext as secret ``dsk_name`` tagged with the
         KEK version that encrypted it.

    Args:
        batches: iterator of ``(key_vault_url, kek_name, dsk_secret_name)``
            series tuples, one tuple per batch.

    Yields:
        One ``pd.Series`` of human-readable status messages per batch, in
        input row order.

    Raises:
        Any Key Vault error other than the ``ResourceNotFoundError`` lookups
        and the rotation-policy failure handled below propagates to the caller.
    """
    for key_vault_url_s, kek_name_s, dsk_secret_name_s in batches:
        results = []

        for vault_url, kek_name, dsk_name in zip(key_vault_url_s, kek_name_s, dsk_secret_name_s):

            key_client, secret_client = _get_clients(vault_url)

            # ---- 1. ensure KEK exists (& add rotation policy) ----------
            try:
                kek = key_client.get_key(kek_name)
                kek_message = f"KEK '{kek_name}' exists (version: {kek.properties.version})."
            except ResourceNotFoundError:
                expires = datetime.utcnow() + timedelta(days=365)
                kek = key_client.create_rsa_key(
                    name       = kek_name,
                    size       = 2048
                )

                key_client.update_key_properties(
                    name=kek_name,
                    expires_on=expires
                )

                # NOTE(review): the key version expires after 365 days while
                # the policy below rotates 18 months after creation; confirm
                # these lifetimes are meant to differ.
                try:
                    rotation_policy = KeyRotationPolicy(
                        expires_in="P2Y",
                        lifetime_actions=[
                            KeyRotationLifetimeAction(
                                time_after_create="P18M",
                                action=KeyRotationPolicyAction.rotate
                            ),
                            KeyRotationLifetimeAction(
                                time_before_expiry="P30D",
                                action=KeyRotationPolicyAction.notify
                            )
                        ]
                    )
                    key_client.update_key_rotation_policy(kek_name, rotation_policy)
                    kek_message = f"KEK {kek_name} created with rotation policy enabled."
                except Exception as e:
                    # Best effort: the KEK itself was created above, only the
                    # rotation policy failed, so report exactly that and carry
                    # on with DSK creation.
                    kek_message = (
                        f"KEK {kek_name} created, but applying the rotation policy failed: {str(e)}"
                    )

            # ---- 2. skip if DSK already exists -------------------------
            try:
                secret_client.get_secret(dsk_name)
                results.append(f"{kek_message} DSK secret '{dsk_name}' already exists – skipped creation.")
                continue
            except ResourceNotFoundError:
                pass

            # ---- 3. generate & encrypt new DSK ------------------------
            dsk_raw = os.urandom(32)                                    # 256-bit
            # Stored form is the url-safe base64 text without '=' padding;
            # the decrypt UDF restores the padding before decoding.
            dsk_b64 = base64.urlsafe_b64encode(dsk_raw).decode().rstrip("=")

            crypto_client = CryptographyClient(kek.id, credential=_credential)
            enc = crypto_client.encrypt(EncryptionAlgorithm.rsa_oaep_256,
                                        dsk_b64.encode("utf-8"))
            encrypted_dsk_b64 = base64.b64encode(enc.ciphertext).decode()

            # ---- 4. store DSK with KEK-version tag --------------------
            # The tag records which KEK version encrypted this DSK.
            secret_client.set_secret(
                name = dsk_name,
                value = encrypted_dsk_b64,
                tags  = {"kek_version": kek.properties.version},
                content_type = "encrypted_dsk_base64"
            )

            results.append(f"{kek_message} DSK '{dsk_name}' encrypted & stored.")

        yield pd.Series(results)
$$;

Call the function to generate the KEK and DSK for each PCI / PII attribute

In [0]:
%sql
-- Run the generator UDF once per row of the PCI / PII attribute table and
-- show the status message it returns next to the attribute and key names.
-- NOTE(review): the first argument is bound to the UDF's key_vault_url
-- parameter but comes from a widget named ..._key_vault_name; confirm the
-- widget holds the full vault URL.
SELECT
  pci_pii_attribute_name,
  kek_name,
  dsk_name,
  `${secure_core_catalog_name}`.default.${secure_core_generate_kek_dsk_function_name}(
    '${secure_core_key_vault_name}',
    kek_name,
    dsk_name
  ) AS generated_kek_dsk
FROM `${secure_core_catalog_name}`.default.${secure_core_pci_pii_attribute_table_name};

Create a Python-backed SQL UDF that returns the decrypted DSK (Data Salt Key), which is used for hashing the PCI / PII attributes while loading them into the vault catalog.

In [0]:
%sql
-- Batch (PARAMETER STYLE PANDAS) Python UDF that returns the decrypted DSK.
--   * Created in the `default` schema of the catalog given by the
--     secure_core_catalog_name widget, under the name given by the
--     secure_core_get_decrypted_dsk_function_name widget.
--   * Takes (key_vault_url, kek_name, dsk_secret_name) and returns one STRING
--     per input row; the Python entry point is decrypt_dsk_batch.
--   * Runs with the service credential named by the
--     DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME widget.
-- The explicit %sql magic above matches the other SQL cells of this notebook,
-- so the cell does not depend on the notebook's default language.
-- NOTE(review): the wheel paths are hard-coded to a single volume while the
-- target catalog is parameterized; confirm this is intended for every environment.
CREATE OR REPLACE FUNCTION `${secure_core_catalog_name}`.default.${secure_core_get_decrypted_dsk_function_name}(
  key_vault_url    STRING,
  kek_name         STRING,
  dsk_secret_name  STRING
)
RETURNS STRING
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'decrypt_dsk_batch'
CREDENTIALS (`${DATABRICKS_DEFAULT_SERVICE_CREDENTIAL_NAME}` DEFAULT)
ENVIRONMENT (
  dependencies = '[
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_identity-1.16.0-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_keyvault_keys-4.8.0-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_keyvault_secrets-4.8.0-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/msal-1.32.3-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/msal_extensions-1.3.1-py3-none-any.whl",
      "/Volumes/edai-00001-dev-a01-masking/default/python_packages/azure_common-1.1.28-py2.py3-none-any.whl"
  ]',
  environment_version = 'None'
)
AS
$$
from datetime import datetime
import base64, pandas as pd, os
from typing import Iterator, Tuple

from azure.identity             import DefaultAzureCredential
from azure.keyvault.keys        import KeyClient
from azure.keyvault.keys.crypto import CryptographyClient, EncryptionAlgorithm
from azure.keyvault.secrets     import SecretClient
from azure.core.exceptions      import ResourceNotFoundError, ClientAuthenticationError

# ------------------------------------------------------------------
# Databricks-patched credential from the service-credential
# Shared by every Key Vault client and CryptographyClient built in this UDF.
# NOTE(review): presumably resolved from the CREDENTIALS clause of the
# function definition; confirm.
# ------------------------------------------------------------------
_credential = DefaultAzureCredential()

# ------------------------------------------------------------------
# Small client cache keyed by Key Vault URL
# Each entry is {"clients": (KeyClient, SecretClient), "timestamp": <epoch seconds>}
# and is rebuilt by _get_clients once it is older than _CACHE_TTL_SEC.
# ------------------------------------------------------------------
_client_cache = {}
_CACHE_TTL_SEC = 300  # 5 minutes

def _get_clients(vault_url: str):
    """Look up, or build and cache, the Key Vault clients for ``vault_url``.

    A cached pair is reused while it is at most ``_CACHE_TTL_SEC`` seconds
    old; otherwise a fresh ``(KeyClient, SecretClient)`` pair replaces it.

    Args:
        vault_url: URL of the Key Vault the clients should talk to.

    Returns:
        Tuple ``(key_client, secret_client)`` bound to ``vault_url``.

    Raises:
        PermissionError: if building the clients raises
            ``ClientAuthenticationError``.
    """
    from time import time

    now = time()
    entry = _client_cache.get(vault_url)
    still_fresh = entry is not None and now - entry["timestamp"] <= _CACHE_TTL_SEC

    if not still_fresh:
        try:
            key_client = KeyClient(vault_url=vault_url, credential=_credential)
            secret_client = SecretClient(vault_url=vault_url, credential=_credential)
        except ClientAuthenticationError as e:
            raise PermissionError(f"Authentication failed for {vault_url}: {str(e)}")
        _client_cache[vault_url] = {
            "clients": (key_client, secret_client),
            "timestamp": now,
        }

    return _client_cache[vault_url]["clients"]

# ------------------------------------------------------------------
# Batch handler: decrypt DSK for each row
# ------------------------------------------------------------------
def decrypt_dsk_batch(
    batches: Iterator[Tuple[pd.Series, pd.Series, pd.Series]]
) -> Iterator[pd.Series]:
    """Batch handler: return the decrypted DSK of every row as a hex string.

    For each row the handler reads the secret ``dsk_secret_name``, resolves
    the KEK version recorded in the secret's ``kek_version`` tag (or the
    latest version when the tag is absent), decrypts the stored ciphertext
    with RSA-OAEP-256 and decodes the resulting url-safe base64 text.

    Args:
        batches: iterator of ``(key_vault_url, kek_name, dsk_secret_name)``
            series tuples, one tuple per batch.

    Yields:
        One ``pd.Series`` per batch holding the hex encoding of each decoded
        DSK, in input row order.
    """
    for url_col, kek_col, secret_col in batches:
        hex_keys = []

        for vault_url, kek_name, dsk_name in zip(url_col, kek_col, secret_col):
            key_client, secret_client = _get_clients(vault_url)

            # Encrypted DSK plus the KEK version the generator UDF tagged it with.
            stored = secret_client.get_secret(dsk_name)
            ciphertext_b64 = stored.value
            tags = stored.properties.tags
            pinned_version = tags.get("kek_version") if tags else None

            # Use the tagged KEK version; fall back to the latest one if untagged.
            kek = (
                key_client.get_key(kek_name, pinned_version)
                if pinned_version
                else key_client.get_key(kek_name)
            )

            ciphertext = base64.b64decode(ciphertext_b64)
            decryptor = CryptographyClient(kek.id, credential=_credential)
            decrypted = decryptor.decrypt(EncryptionAlgorithm.rsa_oaep_256, ciphertext)

            # The plaintext is the url-safe base64 text of the raw DSK with its
            # '=' padding stripped; pad back to a multiple of four, then decode.
            dsk_text = decrypted.plaintext.decode()
            dsk_text += "=" * (-len(dsk_text) % 4)

            hex_keys.append(base64.urlsafe_b64decode(dsk_text).hex())

        yield pd.Series(hex_keys)
$$;