SOP0127 - Restore Keys For Encryption At Rest
=============================================

Description
-----------

Use this notebook to connect to the `controller` database and restore
both system managed and external keys for encryption at rest.

Steps
-----

### Parameters

Set `file_to_restore_from` to the path of the backup file from which your
encryption keys will be restored. Please make sure it has a `.json` file extension.

Set the `certificate_protection_password`. This is the password which
was used to encrypt your certificate.

In [None]:
import os.path
import json

# Notebook parameters.
# Path of the JSON backup file holding the keys exported from the BDC cluster.
file_to_restore_from = r"your\path\to\bdcEncryptionKeys.json"

# Password used later to load the encrypted private key stored in the backup.
certificate_protection_password = "your_password"

print(f"Key(s) will be restored from: '{file_to_restore_from}'. Please make sure you have permission to access this path.")

# Stop right away when the path is missing or is not a regular file.
if not os.path.isfile(file_to_restore_from):
    raise SystemExit(f"{file_to_restore_from} does not point to a valid file.")

# Parse the backup; the parsed document stays in 'backup_json_data' for the later cells.
with open(file_to_restore_from) as json_file:
    try:
        backup_json_data = json.loads(json_file.read())
    except ValueError:
        raise SystemExit(f"{file_to_restore_from} does not have JSON content")

### Instantiate Kubernetes client

In [None]:
# Instantiate the Python Kubernetes client into 'api' variable

import os
from IPython.display import Markdown

# Import the Kubernetes client; if the package is missing, install it with pip
# using the current interpreter and retry the import once.
try:
    from kubernetes import client, config
    from kubernetes.stream import stream
except ImportError: 

    # Install the Kubernetes module
    import sys
    !{sys.executable} -m pip install kubernetes    
    
    try:
        from kubernetes import client, config
        from kubernetes.stream import stream
    except ImportError:
        # Still not importable after the install attempt: show a hint and re-raise.
        display(Markdown(f'HINT: Use [SOP059 - Install Kubernetes Python module](../install/sop059-install-kubernetes-module.ipynb) to resolve this issue.'))
        raise

# When both Kubernetes service variables are set, load the in-cluster
# configuration; otherwise load the local kube config.
if "KUBERNETES_SERVICE_PORT" in os.environ and "KUBERNETES_SERVICE_HOST" in os.environ:
    config.load_incluster_config()
else:
    try:
        config.load_kube_config()
    except:
        # Any failure loading the kube config shows the hint and is then re-raised.
        display(Markdown(f'HINT: Use [TSG118 - Configure Kubernetes config](../repair/tsg118-configure-kube-config.ipynb) to resolve this issue.'))
        raise

# Core V1 API client used by the following cells.
api = client.CoreV1Api()

print('Kubernetes client instantiated')

### Get the namespace for the big data cluster

Get the namespace of the Big Data Cluster from the Kubernetes API.

**NOTE:**

If there is more than one Big Data Cluster in the target Kubernetes
cluster, then either:

-   set \[0\] to the correct value for the big data cluster.
-   set the environment variable AZDATA\_NAMESPACE, before starting
    Azure Data Studio.

In [None]:
# Place Kubernetes namespace name for BDC into 'namespace' variable

if "AZDATA_NAMESPACE" not in os.environ:
    # No override in the environment: take the first namespace returned for
    # the 'MSSQL_CLUSTER' label selector.
    try:
        namespace = api.list_namespace(label_selector='MSSQL_CLUSTER').items[0].metadata.name
    except IndexError:
        # The list was empty; show the troubleshooting hints, then re-raise.
        from IPython.display import Markdown
        display(Markdown(f'HINT: Use [TSG081 - Get namespaces (Kubernetes)](../monitor-k8s/tsg081-get-kubernetes-namespaces.ipynb) to resolve this issue.'))
        display(Markdown(f'HINT: Use [TSG010 - Get configuration contexts](../monitor-k8s/tsg010-get-kubernetes-contexts.ipynb) to resolve this issue.'))
        display(Markdown(f'HINT: Use [SOP011 - Set kubernetes configuration context](../common/sop011-set-kubernetes-context.ipynb) to resolve this issue.'))
        raise
else:
    # The environment variable takes precedence over the API lookup.
    namespace = os.environ["AZDATA_NAMESPACE"]

print('The kubernetes namespace for your big data cluster is: ' + namespace)

### Python function that queries the `controller` database and returns the results.

### Create helper function to run `sqlcmd` against the controller database

In [None]:
# Pod and container in which sqlcmd is executed.
name = 'controldb-0'
container = 'mssql-server'

import base64

def run_sqlcmd(query, show_count):
    """Run a T-SQL query with sqlcmd inside the controller database pod.

    The command is executed through `/bin/sh -c` in container `container` of
    pod `name` (namespace `namespace`). The sqlcmd output is base64-encoded
    inside the pod and decoded again here.

    Args:
        query: T-SQL text to run against the `controller` database.
        show_count: when falsy, `SET NOCOUNT ON; ` is prepended to the query.

    Returns:
        bytes: the base64-decoded output of the command.
    """
    no_count_string = "" if show_count else "SET NOCOUNT ON; "

    # NOTE(review): the query is placed inside a double-quoted shell argument,
    # so double quotes, `$` and backticks in `query` are seen by the shell;
    # callers have to escape them.
    command=f"""export SQLCMDPASSWORD=$(cat /var/run/secrets/credentials/mssql-sa-password/password);
    /opt/mssql-tools/bin/sqlcmd -b -S . -U sa -y0 -Q "{no_count_string}
    {query}" -d controller  -s"^" | base64 -w 0
    """
    output = stream(api.connect_get_namespaced_pod_exec, name, namespace, command=['/bin/sh', '-c', command], container=container, stderr=True, stdout=True)
    return base64.b64decode(output)
print("Function 'run_sqlcmd' defined")

### Python function to execute kubernetes command.

In [None]:
# Target pod and container for the exec call below.
name = 'controldb-0'
container = 'mssql-server'

def execute_k8scommand(command):
    """Run `command` through `/bin/sh -c` in the target container.

    Args:
        command: shell command line to execute.

    Returns:
        str: the value returned by the exec stream, converted with `str()`.
    """
    shell_invocation = ['/bin/sh', '-c', command]
    exec_result = stream(
        api.connect_get_namespaced_pod_exec,
        name,
        namespace,
        command=shell_invocation,
        container=container,
        stderr=True,
        stdout=True,
    )
    return str(exec_result)
print("Function 'execute_k8scommand' defined")

### Define function to check presence of keys.

In [None]:
import json

def can_import_keys(backup_json_file_path):
    """Check that none of the backed-up keys already exists in the Credentials table.

    The key ids are read from the notebook-level `backup_json_data`; the
    `backup_json_file_path` argument is accepted but not used by this function.

    Returns:
        bool: True when no key with a matching name (type 2 or 3) was found,
        False otherwise (the names found are printed).
    """
    # Unique, quoted, alphabetically ordered key ids for the SQL IN (...) list.
    quoted_ids = sorted({f"'{entry['id']}'" for entry in backup_json_data['keys']})
    keys_list_string = ','.join(quoted_ids)

    print(f"Key names found in backup file: {keys_list_string}")
    sql_check_existing_keys = f"""
    select account_name from Credentials
    where account_name in ({keys_list_string}) and type in ('2','3')
    group by account_name, type
    order by account_name asc
    FOR JSON AUTO"""

    raw_result = run_sqlcmd(sql_check_existing_keys, False)
    if raw_result == b'':
        # Empty output is treated as "no matching keys".
        already_present = []
    else:
        already_present = [f"'{row['account_name']}'" for row in json.loads(raw_result)]

    if len(already_present) == 0:
        return True

    keys_list_string = ','.join(already_present)
    print(f"Keys {keys_list_string} already exist in the Control plane. Please delete all the keys before importing them.")
    return False
print("Function 'can_import_keys' defined")

### Restore encryption keys.

Restore the keys from the key backup JSON file into the Big Data Cluster
control plane.

In [None]:
import base64
import json

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa

def decrypt_rsa_private_key(pem_with_encrypted_private_key, certificate_protection_password):
    """Load a password-protected private key from PEM text.

    Args:
        pem_with_encrypted_private_key: PEM document as a string.
        certificate_protection_password: password string protecting the key.

    Returns:
        The private key object produced by `serialization.load_pem_private_key`.
    """
    pem_bytes = pem_with_encrypted_private_key.encode('utf-8')
    password_bytes = certificate_protection_password.encode('utf-8')
    return serialization.load_pem_private_key(
        data=pem_bytes,
        password=password_bytes,
        backend=default_backend())

def decrypt_aes_key(encrypted_aes_key, rsa_private_key):
    """Decrypt the wrapped AES key with the given private key.

    Uses OAEP padding with SHA-256 for both the hash and the MGF1 mask
    function, and no label.

    Args:
        encrypted_aes_key: the encrypted key bytes.
        rsa_private_key: private key object exposing `decrypt(ciphertext, padding)`.

    Returns:
        The value returned by `rsa_private_key.decrypt`.
    """
    mask_function = padding.MGF1(algorithm=hashes.SHA256())
    oaep_padding = padding.OAEP(mgf=mask_function, algorithm=hashes.SHA256(), label=None)
    return rsa_private_key.decrypt(encrypted_aes_key, oaep_padding)

def decrypt_data_with_aes_key(encrypted_message, aes_key, iv):
    """Decrypt an AES-CBC message and return its length-prefixed text payload.

    The decrypted bytes are read as a 4-byte big-endian length followed by
    that many payload bytes; any bytes after the payload are discarded.

    Args:
        encrypted_message: ciphertext bytes.
        aes_key: AES key bytes.
        iv: initialization vector bytes for CBC mode.

    Returns:
        str: the payload decoded as UTF-8.
    """
    cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    decrypted_message = decryptor.update(encrypted_message) + decryptor.finalize()

    # Width of the big-endian length prefix that precedes the payload.
    size_bytes = 4
    payload_length = int.from_bytes(decrypted_message[:size_bytes], byteorder='big')
    payload = decrypted_message[size_bytes:size_bytes + payload_length]
    return payload.decode('utf-8')

# Private key (PEM text) taken from the backup's 'encryptor' section.
pem_with_encrypted_private_key = backup_json_data['encryptor']['private_key']
# The symmetric key is stored as URL-safe base64 text; decode it to bytes.
encrypted_aes_key = base64.urlsafe_b64decode(backup_json_data['encryptor']['symmetric_key'])
# Load the private key with the certificate protection password, then use it
# to decrypt the AES key.
rsa_private_key = decrypt_rsa_private_key(pem_with_encrypted_private_key, certificate_protection_password)
decrypted_aes_key = decrypt_aes_key(encrypted_aes_key, rsa_private_key)

def rebuild_credential_sequence_id_map(backup_json_data):
    """Assign a new negative sequence id to every backed-up credential.

    Entries are grouped by their 'id'. An entry without a 'version' field is
    treated as having sequence id 0. Within each key the original ids are
    ordered from highest to lowest and given consecutive descending negative
    numbers; numbering starts at -1 and continues across keys.

    Args:
        backup_json_data: parsed backup document with a 'keys' list.

    Returns:
        dict: key name -> {original sequence id -> new negative sequence id}.
    """
    original_ids_by_key = {}
    for entry in backup_json_data['keys']:
        original_id = entry['credential_sequence_id'] if 'version' in entry else 0
        original_ids_by_key.setdefault(entry['id'], []).append(original_id)

    remapped_ids_by_key = {}
    next_free_id = -1
    for key_name, original_ids in original_ids_by_key.items():
        highest_first = sorted(original_ids)[::-1]
        remapped_ids_by_key[key_name] = {
            original_id: next_free_id - position
            for position, original_id in enumerate(highest_first)
        }
        next_free_id -= len(highest_first)
    return remapped_ids_by_key

# Once the accumulated VALUES text grows past this many characters, the rows
# collected so far are inserted and a new batch is started.
command_length_limit = 30000
if not can_import_keys(file_to_restore_from):
    from IPython.display import Markdown
    print("The keys cannot be imported since one or more keys exist. Delete the keys and re-run this cell before proceeding")
    display(Markdown(f'HINT: Use [SOP0124 - List Keys For Encryption At Rest.](../tde/sop124-list-keys-encryption-at-rest.ipynb) to resolve this issue.'))
    display(Markdown(f'HINT: Use [SOP0125 - Delete Key For Encryption At Rest](../tde/sop125-delete-keys-encryption-at-rest.ipynb) to resolve this issue.'))
else:
    # The secret value is base64 text; decode it to get the password that opens
    # the controller database symmetric key.
    key_password = base64.b64decode(str(api.read_namespaced_secret("controller-db-rw-secret", namespace).data['encryptionPassword'])).decode('utf-8')
    # One row of the INSERT ... VALUES list.
    # NOTE(review): the fields are interpolated into T-SQL text, which run_sqlcmd
    # then places inside a double-quoted shell argument. Only double quotes in
    # 'tags' are escaped below; confirm that no other field can contain quotes,
    # `$` or backticks.
    tsql_values_template = """('{account_name}', {type}, EncryptByKey(Key_GUID('ControllerDbSymmetricKey'), N'{secret_credential}'),
    N'{application_metadata}', N'{version}', {creation_timestamp_utc}, {credential_sequence_id})"""
    value_list = []
    
    imported_sequence_ids_map_for_keys = rebuild_credential_sequence_id_map(backup_json_data)

    for backup_entry in backup_json_data['keys']:
        secret_value = ""
        entry_type = backup_entry['type']
        if entry_type == 2:
            # Type 2 values are used as stored in the backup.
            secret_value = backup_entry['value']
        elif entry_type == 3:
            # Type 3 values are AES-encrypted; decrypt them with the unwrapped AES key and the entry's IV.
            aes_encrypted_jwk = base64.urlsafe_b64decode(backup_entry['value'])
            iv = base64.urlsafe_b64decode(backup_entry['iv'])
            secret_value = decrypt_data_with_aes_key(aes_encrypted_jwk, decrypted_aes_key, iv)
        else:
            raise SystemExit(f"""Invalid type found {entry_type}""")
        if 'version' in backup_entry:
            version = backup_entry['version']
            creation_timestamp_utc = "N'{0}'".format(backup_entry['creation_timestamp_utc'])
            credential_sequence_id = backup_entry['credential_sequence_id']
        else:
            # Entries without version information get defaults; the timestamp is
            # evaluated by SQL Server at insert time.
            version = "0"
            creation_timestamp_utc = "SYSUTCDATETIME()"
            credential_sequence_id = 0

        # Make credential_sequence_id negative to treat imported keys as older than any new keys created on this BDC deployment
        credential_sequence_id = imported_sequence_ids_map_for_keys[backup_entry["id"]][credential_sequence_id]

        value_list.append(tsql_values_template.format(
            account_name = backup_entry["id"],
            type = backup_entry["type"],
            secret_credential = secret_value,
            # Escape double quotes because the statement is passed inside a double-quoted shell argument.
            application_metadata = backup_entry["tags"].replace('"', '\\"'),
            version = version,
            creation_timestamp_utc = creation_timestamp_utc,
            credential_sequence_id = credential_sequence_id))
        print(f"""Encryption Key {backup_entry["id"]} processed.""")
        # Due to some command length limit, we will restore rows in batch.
        if len(','.join(value_list)) > command_length_limit:
            tsql_value_string = ','.join(value_list)
            tsql_insert_statement = f"""
            OPEN SYMMETRIC KEY ControllerDbSymmetricKey DECRYPTION BY PASSWORD = '{key_password}'
            INSERT INTO Credentials(account_name, type, encrypted_password, application_metadata, version, creation_timestamp_utc, credential_sequence_id) VALUES {tsql_value_string}"""
            run_output = run_sqlcmd(tsql_insert_statement, True)
            value_list = []

    # Insert whatever is left after the last full batch. A truthiness test is
    # used because comparing an int to a literal with 'is' checks identity,
    # not value.
    if value_list:
        tsql_value_string = ','.join(value_list)
        tsql_insert_statement = f"""
        OPEN SYMMETRIC KEY ControllerDbSymmetricKey DECRYPTION BY PASSWORD = '{key_password}'
        INSERT INTO Credentials(account_name, type, encrypted_password, application_metadata, version, creation_timestamp_utc, credential_sequence_id) VALUES {tsql_value_string}
        """
        run_output = run_sqlcmd(tsql_insert_statement, True)

In [None]:
# Report that the final cell of the notebook has run.
print("Notebook execution is complete.")