# Calculate treeSHAP on Confidential containers on ACI Demo

### Step 1 : Setup 

In this step we will 
- Setup the necessary libraries and customizable variables 
- Request the public key from attestation well known endpoint and create jwks object that will be used to verify the attestation tokens later in the demo. 

In [6]:

import subprocess
import json
import base64
import requests
import jwt
import json
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from hashlib import sha256

# TODO: update to reflect your setup
registry_name = 'ttdungacr.azurecr.io' # e.g. docker.io/pawankhandavillims
attestation_endpoint = 'sharedneu.neu.attest.azure.net' #this is a shared instance of MAA that you can use for testing
#runtime data is data you want reflected in the attestation token, this is not used in a meaningful way in this demo. 
runtime_data = 'eyJrZXlzIjpbeyJlIjoiQVFBQiIsImtleV9vcHMiOlsiZW5jcnlwdCJdLCJraWQiOiJOdmhmdXEyY0NJT0FCOFhSNFhpOVByME5QXzlDZU16V1FHdFdfSEFMel93Iiwia3R5IjoiUlNBIiwibiI6InY5NjVTUm15cDh6Ykc1ZU5GdURDbW1pU2VhSHB1akcyYkNfa2VMU3V6dkRNTE8xV3lyVUp2ZWFhNWJ6TW9PMHBBNDZwWGttYnFIaXNvelZ6cGlORExDbzZkM3o0VHJHTWVGUGYyQVBJTXUtUlNyek41NnF2SFZ5SXI1Y2FXZkhXay1GTVJEd0FlZnlOWVJIa2RZWWtnbUZLNDRoaFVkdGxDQUtFdjVVUXBGWmp2aDRpSTlqVkJkR1lNeUJhS1FMaGpJNVdJaC1RRzZaYTVzU3VPQ0ZNbm11eXV2TjVEZmxwTEZ6NTk1U3MtRW9CSVktTmlsNmxDdHZjR2dSLUlialVZSEFPczVhamFtVHpnZU84a3gzVkNFOUhjeUtteVVac2l5aUY2SURScDJCcHkzTkhUakl6N3Rta3BUSHg3dEhuUnRsZkUyRlV2MEI2aV9RWWxfWkE1USJ9XX0='

#extract the public key from the openid-configuration and create a JWKS object

def rsa_public_key_from_pem(cert_pem):
    cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend())
    return cert.public_key()

response = requests.get(f"https://{attestation_endpoint}/certs")

if response.status_code == 200:
    cert_data = response.json()
    keys = cert_data['keys']

    # Step 2: Create a JWKS object
    jwks = []

    for key_data in keys:
        x5c = key_data.get('x5c', [])
        if x5c:
            cert_pem = "-----BEGIN CERTIFICATE-----\n" + x5c[0] + "\n-----END CERTIFICATE-----"
            public_key = rsa_public_key_from_pem(cert_pem)
            jwks.append((key_data['kid'], public_key))

    print("JWKS object created successfully.")
else:
    print("Failed to retrieve the signing keys.")

JWKS object created successfully.


### Step 2 : Generate security policy for sum container 

We will use the confcom tooling to generate a security policy from the Azure Resource Manager template. We will further generate a SHA-256 hash of the security policy which will be used later in the demo to verify whether the container group is running the right configuration. 

Note : The "ccePolicy" attribute of the ARM template must be set to a null string "" for this step to work. The tooling requires user input to override the policy if already present and user input is not supported in the notebook. 

In [7]:
# get the hash of the security policy
with open("./cal_shap/template.json", "r") as f:
    # open the template and grab the cce policy
    template = json.loads(f.read())
    security_policy = template.get('resources')[0]['properties']['confidentialComputeProperties']['ccePolicy']
    # decode the base64 encoded policy and hash it
    sha256_hash_sum = sha256(base64.b64decode(security_policy)).hexdigest()
    # print the hash
    print("hash of security policy: ", sha256_hash_sum)


hash of security policy:  526fa8e3c54fbe320686a0d5f00247e9617631ef6cea305096eddc231eece681


#### Step 3 : Check for successful deployment on Azure Portal and get attestation token
In this step we will check for the successful deployment and get the attestation token verified by MAA. We will compare the contents of the "x-ms-sevsnpvm-hostdata" claim and check whether it matches the policy hash from step 3

In [None]:

# TODO: update the public_ip_address to the public ip address of your deployed container group. You can obtain the ip address from azure portal.  

public_ip_address = '52.155.181.121'

#runtime data is data you want reflected in the attestation token, this is not used in a meaningful way in this demo. 
runtime_data = 'eyJrZXlzIjpbeyJlIjoiQVFBQiIsImtleV9vcHMiOlsiZW5jcnlwdCJdLCJraWQiOiJOdmhmdXEyY0NJT0FCOFhSNFhpOVByME5QXzlDZU16V1FHdFdfSEFMel93Iiwia3R5IjoiUlNBIiwibiI6InY5NjVTUm15cDh6Ykc1ZU5GdURDbW1pU2VhSHB1akcyYkNfa2VMU3V6dkRNTE8xV3lyVUp2ZWFhNWJ6TW9PMHBBNDZwWGttYnFIaXNvelZ6cGlORExDbzZkM3o0VHJHTWVGUGYyQVBJTXUtUlNyek41NnF2SFZ5SXI1Y2FXZkhXay1GTVJEd0FlZnlOWVJIa2RZWWtnbUZLNDRoaFVkdGxDQUtFdjVVUXBGWmp2aDRpSTlqVkJkR1lNeUJhS1FMaGpJNVdJaC1RRzZaYTVzU3VPQ0ZNbm11eXV2TjVEZmxwTEZ6NTk1U3MtRW9CSVktTmlsNmxDdHZjR2dSLUlialVZSEFPczVhamFtVHpnZU84a3gzVkNFOUhjeUtteVVac2l5aUY2SURScDJCcHkzTkhUakl6N3Rta3BUSHg3dEhuUnRsZkUyRlV2MEI2aV9RWWxfWkE1USJ9XX0='

# call the maa endpoint
maa_response = requests.post(
    f'http://{public_ip_address}/attest/maa',  
    json={"runtime_data": runtime_data, "maa_endpoint": attestation_endpoint}
)
print("Maa Response Status Code: ", maa_response.status_code)

# 1. Check for HTTP errors (e.g., 404, 500)
try:
    maa_response.raise_for_status() 
except requests.exceptions.HTTPError as e:
    print(f"Error: Request failed with HTTP status {maa_response.status_code}")
    print("Full Response Text:", maa_response.text)
    # Stop execution or return
    # If the container group is down or the IP is wrong, this will catch it.
    raise

# 2. Check the response content for application-level errors
try:
    response_json = maa_response.json()
except requests.exceptions.JSONDecodeError:
    # If the response isn't even valid JSON, print the raw text and exit
    print("Error: Received non-JSON response.")
    print("Raw Response Text:", maa_response.text)
    raise

print("Maa Response JSON: ", response_json)

# Check if the response contains an expected error message instead of the token
result = response_json.get("result")
if result == '404 page not found' or 'error' in response_json or (isinstance(result, str) and 'error' in result.lower()):
    print(f"Error: Application-level error reported in response: {response_json}")
    # Stop execution or return
    # This catches the specific '404 page not found' result you saw.
    raise Exception("Application Attestation Error") 

# If the `result` is a JSON string, it needs to be parsed (as intended by your original code)
if isinstance(result, str):
    try:
        # Attempt to parse the `result` string, which should contain the actual token payload
        parsed_result = json.loads(result)
        token = parsed_result.get("token")
        if not token:
            raise ValueError("Token not found in parsed result.")
    except (json.JSONDecodeError, ValueError) as e:
        print(f"Error: Failed to parse 'result' string into JSON or extract 'token'.")
        print(f"Result string content: {result}")
        raise
else:
    # Handle the case where the token might be directly in the top-level JSON or another key
    # (though your original code structure suggests a nested token within a 'result' string)
    # For now, we'll assume the nested structure is correct, but this is a point to check.
    print("Error: 'result' key was not a string as expected.")
    raise
# verify the token

header = jwt.get_unverified_header(token)
kid = header['kid']

# Find the key with a matching 'kid' in the JWKS
key_to_use = None
for key_kid, key in jwks:
        if key_kid == kid:
            key_to_use = key
            break

if key_to_use is not None:
        try:
            payload = jwt.decode(token, key=key_to_use, algorithms=["RS256"])
            print("Valid JWT : Attestation token signature verified:", payload)
            
        except jwt.InvalidTokenError:
            print("Invalid JWT")
else:
        print("No matching key found in JWKS")

print("SEV-SNP Host Data:\n", payload.get("x-ms-sevsnpvm-hostdata"))
if(sha256_hash_sum == payload.get("x-ms-sevsnpvm-hostdata")):
    print("Security Policy Hash Matches")
    print("Host is Trusted")
else:
    print("Security Policy Hash Does Not Match")
    print("Host is Not Trusted")

### Step 4 : Check for successful deployment on Azure Portal and get key release
In this step we will check for the successful deployment and get the key verified by MAA.

In [5]:
public_ip_address = '52.155.181.121'
maa_endpoint = 'sharedneu.neu.attest.azure.net'
akv_endpoint = 'testvaultkhtn.vault.azure.net'
kid = 'mykeyv19'

# call the /key/release
key_response = requests.post(
    f'http://{public_ip_address}/key/release',  
    json={"maa_endpoint": maa_endpoint, "akv_endpoint": akv_endpoint, "kid": kid}
)
print("Key release Response Status Code: ", key_response.status_code)

# 1. Check for HTTP errors (e.g., 404, 500)
try:
    key_response.raise_for_status() 
except requests.exceptions.HTTPError as e:
    print(f"Error: Request failed with HTTP status {key_response.status_code}")
    print("Full Response Text:", key_response.text)
    # Stop execution or return
    # If the container group is down or the IP is wrong, this will catch it.
    raise

# 2. Check the response content for application-level errors
try:
    response_json = key_response.json()
except requests.exceptions.JSONDecodeError:
    # If the response isn't even valid JSON, print the raw text and exit
    print("Error: Received non-JSON response.")
    print("Raw Response Text:", key_response.text)
    raise

print("Maa Response JSON: ", response_json)

Key release Response Status Code:  200
Maa Response JSON:  {'public_key_pem': '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwiIUA0ALmgb2L33MxIVq\neYbgnnQn8CJ3w/b8HKjLOcG9upmQ3cZgkI+MEqxngeby58VoT79cbqo3t7c+gO3S\nF1ByKWzmU3Mg8QqrPiX3NgOVx4Cxjc/HQw40eLvYFjstYRpYtD6wkpebR0gUEYRN\nWeGMPe2Now6NUl11Dq6h+7+Wp2BX/DJy3YnHOgcL2Qebby4Qx5kp93UBeT/3tJ3m\nm4gmzEetvA5OOhiNsH+s6ivue4qASIOXrC825bFM/D6sfhuNxnCL4lPEd0wfSVPD\nNf+Vd1uvcJwUD7mqDLWAm6tA9/4XExq3VqU9ORLn2vfY4TLvK8SO78EKD/x/t5Cq\nwwIDAQAB\n-----END PUBLIC KEY-----\n'}


In [9]:
# The base URL of your running Flask API
public_ip_address = "52.155.181.121"

# The item ID you want to get
item_id = 2

# Perform a GET request to /items/<item_id>
response = requests.get(f"http://{public_ip_address}/items/{item_id}")

# Check the status code
print("Status Code:", response.status_code)

# Parse JSON response
if response.ok:
    item_data = response.json()
    print("Item Data:", item_data)
else:
    print("Error:", response.text)

ConnectTimeout: HTTPConnectionPool(host='20.107.164.222', port=80): Max retries exceeded with url: /items/2 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPConnection object at 0x1055a5ba0>, 'Connection to 20.107.164.222 timed out. (connect timeout=None)'))

In [37]:
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.backends import default_backend
import os
import base64

# Extract the PEM string and convert to bytes
public_key_pem = response_json["public_key_pem"].encode("utf-8")

# Parse the PEM into a cryptography public key object
public_key = serialization.load_pem_public_key(public_key_pem, backend=default_backend())

# Generate a random 256-bit (32-byte) symmetric key (e.g., for AES-256)
symmetric_key = os.urandom(32)
print("Generated symmetric key:", base64.b64encode(symmetric_key).decode())

# Encrypt the symmetric key with the RSA public key (using OAEP padding)
encrypted_key = public_key.encrypt(
    symmetric_key,
    padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None
    )
)

# Encode the result in base64 (so you can send/store it easily)
encrypted_key_b64 = base64.b64encode(encrypted_key).decode()

print("\nEncrypted symmetric key (base64):")
print(encrypted_key_b64)

Generated symmetric key: EilZRpz21M6uXtkIsI4n/gHeuEXv9deX9S0ajBxJd/w=

Encrypted symmetric key (base64):
aNFLSyGdovrX31qChpbmha+3LeI3YbASonD9gVo7s+3m8a4kfaN7/kuJRudxBgSoAz95gBKzHQOJcjyKCEJl3wVN12JO4ZzI062fdRrUg6ImDLGVKn+bt2wECMshsVOcIusa3jmfK143taYd40W8xlDDQz87E/lqRf8FtJulb5aGyvhPJHcwaGknY9bnvjOqmS/6jOzlea5dtgCLDauCalOaLwPqTk2RQcoeTK7L+OQlfeM3vC6CfFseXU5iIxDldwOZ3YfDMHnhQGmb5Sf+K6UwV7nkxjqMW/dUyZeGQD6Xx30/RHx+VXs/sjDurd0fUHqzKZieXHIKE06C4TDHYg==


In [35]:
url = f"http://{public_ip_address}/decrypt_key"
payload = {
    "encrypted_key_b64": encrypted_key_b64
}

headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)
print("Status code:", response.status_code)

# Pretty-print JSON result
print(json.dumps(response.json(), indent=2))

Status code: 200
{
  "symmetric_key_b64": "WftrHxAft/0aAe4n0pARVyCnqJ1x1jx4VGzNikgL3lA="
}
