
# Blinded Proof of Possession Keys for WSCD-protected EUDIWs

In brief:

- Presentations will include a Proof of Possession key protected by the EUDIW Secure Cryptographic Environment (SCE).
- To ensure privacy, the Proof of Possession key material must be unique per presented attestation.
- Generating a unique key pair for every attestation is a key management problem.
- One possible solution is Hierarchical Deterministic Key derivation, where child keys are generated from parent keys.
- Each child key appears unrelated to the parent key and to every other child key.
- Each child key is also a private/public key pair and requires the same SCE as the parent key.
- Hardened derivation uses the parent private key to derive child keys.
- Regular derivation uses the parent public key to derive child keys.
- Derivation has both depth (level) and width (index).
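To make the hardened/regular and depth/width points concrete, here is a minimal pure-Python sketch on P-256. It is a deliberately simplified toy and not BIP32, SLIP-0010 or HDK: there are no chain codes, the tweak `SHA-256(data || index) mod n` is an assumption, and all helper names (`tweak`, `regular_child_pk`, `regular_child_sk`, `hardened_child_sk`, `derive_path_sk`, `derive_path_pk`) are hypothetical. The curve arithmetic is textbook double-and-add and not constant time.

```python
import hashlib

# NIST P-256 (secp256r1) domain parameters: field prime p, group order n, generator G
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
G = (0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
     0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5)


def point_add(P, Q):
    """Affine addition on y^2 = x^3 - 3x + b; None is the point at infinity."""
    if P is None:
        return Q
    if Q is None:
        return P
    if P[0] == Q[0] and (P[1] + Q[1]) % p == 0:
        return None
    if P == Q:
        lam = (3 * P[0] * P[0] - 3) * pow(2 * P[1], -1, p) % p
    else:
        lam = (Q[1] - P[1]) * pow((Q[0] - P[0]) % p, -1, p) % p
    x = (lam * lam - P[0] - Q[0]) % p
    return (x, (lam * (P[0] - x) - P[1]) % p)


def scalar_mult(k, P):
    """Double-and-add; not constant time, for illustration only."""
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P = point_add(P, P)
        k >>= 1
    return R


def encode(P):
    # SEC1 uncompressed encoding: 0x04 || x || y
    return b"\x04" + P[0].to_bytes(32, "big") + P[1].to_bytes(32, "big")


def tweak(data, index):
    # Hypothetical tweak: SHA-256(data || index) reduced modulo n
    digest = hashlib.sha256(data + index.to_bytes(4, "big")).digest()
    return int.from_bytes(digest, "big") % n


def regular_child_pk(pk, index):
    # Regular derivation needs only the parent public key: pk' = pk + t*G
    return point_add(pk, scalar_mult(tweak(encode(pk), index), G))


def regular_child_sk(sk, index):
    # Matching private key: sk' = sk + t mod n, with t computed from the parent public key
    return (sk + tweak(encode(scalar_mult(sk, G)), index)) % n


def hardened_child_sk(sk, index):
    # Hardened derivation hashes the parent private key, so pk alone is not enough
    return (sk + tweak(sk.to_bytes(32, "big"), index)) % n


def derive_path_sk(sk, path):
    # Each entry is one level (depth); its value is the index (width) at that level
    for index in path:
        sk = regular_child_sk(sk, index)
    return sk


def derive_path_pk(pk, path):
    for index in path:
        pk = regular_child_pk(pk, index)
    return pk


parent_sk = int.from_bytes(hashlib.sha256(b"example parent key").digest(), "big") % n
parent_pk = scalar_mult(parent_sk, G)
assert scalar_mult(derive_path_sk(parent_sk, [0, 5, 2]), G) == derive_path_pk(parent_pk, [0, 5, 2])
```

Regular derivation lets a party holding only the parent public key compute child public keys, while the private-key holder derives the matching child private keys; the hardened variant cannot be reproduced from the public key at all.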

Existing alternatives:

- [BIP32](https://en.bitcoin.it/wiki/BIP_0032) used in Bitcoin
- [SLIP-0010](https://slips.readthedocs.io/en/latest/slip-0010/) used in Ethereum
- [ARKG](https://datatracker.ietf.org/doc/draft-bradleylundberg-cfrg-arkg/) by Yubico
- [HDK](https://sander.github.io/hierarchical-deterministic-keys/feat/bs-restructure/draft-dijkhuis-cfrg-hdkeys.html#name-introduction), a collaboration with Sander
- PoA by Eric

The ARKG and HDK approaches are both designed with the EUDIW in mind.

## Hierarchical Deterministic Key derivation for EUDIW

### Introduction

With HDK, a large set of seemingly unrelated keys can be bound to an SCE that protects a single private key. Each HDK is deterministically defined by a path, optionally interleaved with key handles provided by another party, e.g., an Issuer. HDK has several instantiations, of which those based on ECDH and ECSDSA are supported by existing SCEs.

The approach builds on and aligns with the following specifications:

- [RFC 9180](https://datatracker.ietf.org/doc/rfc9180/) for the KEM
- [RFC 9380](https://datatracker.ietf.org/doc/rfc9380/) used in key blinding
- [KBSS](https://datatracker.ietf.org/doc/draft-irtf-cfrg-signature-key-blinding/) details derivation approaches
- [ARKG](https://datatracker.ietf.org/doc/draft-bradleylundberg-cfrg-arkg/) related key derivation specification that HDK extends

Two HDK derivation approaches are possible: local and remote. Local derivation lets the User derive keys on their own device. Remote derivation lets the User derive keys from a key handle generated remotely, e.g., by an Issuer.
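The following minimal pure-Python sketch illustrates the remote case under simplifying assumptions: the Issuer runs a plain ECDH with a fresh ephemeral key against the User's public key and hashes the shared x-coordinate into a salt (a stand-in for the RFC 9180 DH-KEM described below), then blinds the User's public key additively; the ephemeral public key `enc` plays the role of the key handle. The helpers `ecdh_salt`, `blinding_factor`, `issuer_derive` and `user_derive` are hypothetical, SHA-256 mod n stands in for `hash_to_field`, and the curve arithmetic is not constant time.

```python
import hashlib
import secrets

# NIST P-256 (secp256r1) domain parameters: field prime p, group order n, generator G
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
G = (0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
     0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5)


def point_add(P, Q):
    """Affine addition on y^2 = x^3 - 3x + b; None is the point at infinity."""
    if P is None:
        return Q
    if Q is None:
        return P
    if P[0] == Q[0] and (P[1] + Q[1]) % p == 0:
        return None
    if P == Q:
        lam = (3 * P[0] * P[0] - 3) * pow(2 * P[1], -1, p) % p
    else:
        lam = (Q[1] - P[1]) * pow((Q[0] - P[0]) % p, -1, p) % p
    x = (lam * lam - P[0] - Q[0]) % p
    return (x, (lam * (P[0] - x) - P[1]) % p)


def scalar_mult(k, P):
    """Double-and-add; not constant time, for illustration only."""
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P = point_add(P, P)
        k >>= 1
    return R


def encode(P):
    # SEC1 uncompressed encoding: 0x04 || x || y
    return b"\x04" + P[0].to_bytes(32, "big") + P[1].to_bytes(32, "big")


def ecdh_salt(k, P):
    # Hypothetical stand-in for the RFC 9180 KEM: SHA-256 of the shared x-coordinate
    return hashlib.sha256(scalar_mult(k, P)[0].to_bytes(32, "big")).digest()


def blinding_factor(salt, pk, index):
    # Hypothetical stand-in for hash_to_field: SHA-256(salt || pk || index) mod n
    data = salt + encode(pk) + index.to_bytes(4, "big")
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % n


def issuer_derive(pk_user, index):
    """Issuer side: needs only the User's public key; returns (key handle, blinded pk)."""
    e = secrets.randbelow(n - 1) + 1          # ephemeral private key
    enc = scalar_mult(e, G)                   # key handle sent to the User
    bf = blinding_factor(ecdh_salt(e, pk_user), pk_user, index)
    return enc, point_add(pk_user, scalar_mult(bf, G))  # ADD: pk + bf*G


def user_derive(sk_user, enc, index):
    """User side: recompute the salt from the key handle and blind the private key."""
    pk_user = scalar_mult(sk_user, G)
    bf = blinding_factor(ecdh_salt(sk_user, enc), pk_user, index)
    return (sk_user + bf) % n                 # ADD: sk + bf mod n


sk_user = secrets.randbelow(n - 1) + 1
pk_user = scalar_mult(sk_user, G)
enc, pk_blinded = issuer_derive(pk_user, 0)
assert scalar_mult(user_derive(sk_user, enc, 0), G) == pk_blinded
```

Both sides arrive at the same blinding factor because `e * pk_user` and `sk_user * enc` are the same point. In the local case, the same `blinding_factor` step would simply be fed a salt held by the User instead of one obtained from `enc`.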

Each HDK instantiation builds on the following:

1. A Key Encapsulation Mechanism (KEM)
2. A key blinding (KB) scheme
3. A derivation approach that can be either additive (ADD) or multiplicative (MUL)
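For a parent key pair `(sk, pk = sk·G)`, a blinding factor `bf` and the group order `n`, the two approaches amount to the following relations, which match the `Combine` and `DeriveBlindPublicKey` functions in the Local Derivation part of this notebook:

```latex
\begin{aligned}
\text{ADD:}\quad & sk' = (sk + bf) \bmod n, & pk' &= pk + bf \cdot G = sk' \cdot G\\
\text{MUL:}\quad & sk' = (sk \cdot bf) \bmod n, & pk' &= bf \cdot pk = sk' \cdot G
\end{aligned}
```

In both modes the blinded public key is computed from `pk` alone, while the blinded private key needs `sk`.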

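In our examples below we use:

- DHKEM(P-256, HKDF-SHA256), with the following sizes in bytes as specified in RFC 9180:
  - Nsecret 32: length of the KEM shared secret
  - Nenc 65: length of the encapsulated key `enc`
  - Npk 65: length of a serialized public key
  - Nsk 32: length of a serialized private key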

### Building blocks

#### RFC 9180 parts

The relevant parts of RFC 9180 are those that detail the Key Encapsulation Mechanism (KEM). The KEM lets a sender generate a shared secret together with an encapsulation of it that only the holder of the receiver's private key can decapsulate. This shared secret enables two parties, the Issuer and the User, to derive blinded keys in such a way that the derived keys seem unrelated to anyone who does not know the shared secret.

The KEM used in key derivation is from [RFC 9180 HPKE](https://www.rfc-editor.org/rfc/rfc9180.html), more specifically DH-KEM as detailed in section [4.1](https://www.rfc-editor.org/rfc/rfc9180.html#section-4.1). The DH-KEM relies on a number of primitives:

* Both `GenerateKeyPair()` and `DeriveKeyPair(ikm)` output a private/public key pair `(sk, pk)`; `DeriveKeyPair` does so deterministically from the input keying material `ikm`.
* `Encap(pkR)` uses the receiver public key `pkR` to create a shared secret and its encapsulation `enc`; `Decap(enc, skR)` recovers the same shared secret using the receiver private key `skR`.
* `SerializePublicKey(pkX)` and `DeserializePublicKey(pkXm)` perform the conversions between uncompressed elliptic curve points and octet strings according to [SECG](https://secg.org/sec1-v2.pdf).
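A dependency-free sketch of DHKEM(P-256, HKDF-SHA256) may help to see how these primitives fit together. It follows the RFC 9180 definitions, builds HKDF from Python's standard `hmac` module, and uses textbook affine point arithmetic that is not constant time, so it is for illustration only. The optional `ikmE` argument of `encap` is an addition made here so that the RFC 9180 test vector used later in this notebook can be reproduced; a real `Encap` draws the ephemeral key at random.

```python
import hashlib
import hmac
import secrets

# NIST P-256 domain parameters: field prime p, coefficient b, order n, generator G
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
b = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
G = (0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
     0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5)
SUITE_ID = b"KEM" + (0x0010).to_bytes(2, "big")  # DHKEM(P-256, HKDF-SHA256)

def point_add(P, Q):
    # Affine addition on y^2 = x^3 - 3x + b; None is the point at infinity
    if P is None:
        return Q
    if Q is None:
        return P
    if P[0] == Q[0] and (P[1] + Q[1]) % p == 0:
        return None
    if P == Q:
        lam = (3 * P[0] * P[0] - 3) * pow(2 * P[1], -1, p) % p
    else:
        lam = (Q[1] - P[1]) * pow((Q[0] - P[0]) % p, -1, p) % p
    x = (lam * lam - P[0] - Q[0]) % p
    return (x, (lam * (P[0] - x) - P[1]) % p)

def scalar_mult(k, P):
    # Double-and-add; not constant time, for illustration only
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P, k = point_add(P, P), k >> 1
    return R

def serialize_public_key(P):
    # SEC1 uncompressed point 0x04 || x || y, so Npk = Nenc = 65
    return b"\x04" + P[0].to_bytes(32, "big") + P[1].to_bytes(32, "big")

def deserialize_public_key(data):
    if len(data) != 65 or data[0] != 0x04:
        raise ValueError("expected a 65-byte uncompressed point")
    x, y = int.from_bytes(data[1:33], "big"), int.from_bytes(data[33:], "big")
    if x >= p or y >= p or (y * y - x ** 3 + 3 * x - b) % p != 0:
        raise ValueError("not a point on P-256")
    return (x, y)

def labeled_extract(salt, label, ikm):
    # HKDF-Extract = HMAC-SHA256(salt, "HPKE-v1" || suite_id || label || ikm)
    return hmac.new(salt, b"HPKE-v1" + SUITE_ID + label + ikm, hashlib.sha256).digest()

def labeled_expand(prk, label, info, length):
    # HKDF-Expand with info = I2OSP(L, 2) || "HPKE-v1" || suite_id || label || info
    labeled_info = length.to_bytes(2, "big") + b"HPKE-v1" + SUITE_ID + label + info
    okm, t = b"", b""
    for i in range(1, -(-length // 32) + 1):
        t = hmac.new(prk, t + labeled_info + bytes([i]), hashlib.sha256).digest()
        okm += t
    return okm[:length]

def derive_key_pair(ikm):
    dkp_prk = labeled_extract(b"", b"dkp_prk", ikm)
    for counter in range(256):  # rejection sampling; bitmask 0xFF is a no-op for P-256
        sk = int.from_bytes(labeled_expand(dkp_prk, b"candidate", bytes([counter]), 32), "big")
        if 0 < sk < n:
            return sk, scalar_mult(sk, G)
    raise ValueError("DeriveKeyPairError")

def extract_and_expand(dh, kem_context):
    eae_prk = labeled_extract(b"", b"eae_prk", dh)
    return labeled_expand(eae_prk, b"shared_secret", kem_context, 32)

def encap(pkR, ikmE=None):
    # ikmE is an added, optional parameter to make test vectors reproducible
    skE, pkE = derive_key_pair(ikmE if ikmE is not None else secrets.token_bytes(32))
    dh = scalar_mult(skE, pkR)[0].to_bytes(32, "big")  # ECDH output: x-coordinate
    enc = serialize_public_key(pkE)
    return extract_and_expand(dh, enc + serialize_public_key(pkR)), enc

def decap(enc, skR):
    dh = scalar_mult(skR, deserialize_public_key(enc))[0].to_bytes(32, "big")
    return extract_and_expand(dh, enc + serialize_public_key(scalar_mult(skR, G)))

skR, pkR = derive_key_pair(secrets.token_bytes(32))
shared_secret, enc = encap(pkR)
assert decap(enc, skR) == shared_secret
```

The sender keeps `shared_secret` and transmits only `enc`; the receiver recovers the same value with `decap`.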

In [3]:
%%capture
!pip install ecpy

In [4]:
from cryptography.hazmat.primitives             import hashes
from cryptography.hazmat.primitives.asymmetric  import ec
from cryptography.hazmat.primitives.kdf.hkdf    import HKDF
from cryptography.hazmat.primitives             import serialization

import hmac
import secrets
from ecpy.curves                                import Curve, Point
from math                                       import ceil
import types
from dataclasses                                import dataclass

In [5]:
## Shortened functions
SEC1 = serialization.Encoding.X962
UCOMP = serialization.PublicFormat.UncompressedPoint
COMP = serialization.PublicFormat.CompressedPoint

## Static values for DHKEM(P-256, HKDF-SHA256)
Nsecret = 32
Nsk = 32
crv = Curve.get_curve('secp256r1')
G = crv.generator
salt = b""
kem_id = 0x0010
suite_id = b"KEM" + kem_id.to_bytes(2)
bitmask = 0xFF


## Test vector for DHKEM(P-256, HKDF-SHA256)
mode = 0
kem_id = 16
kdf_id = 1
aead_id = 1
info = 0x4f6465206f6e2061204772656369616e2055726e
ikmE = 0x4270e54ffd08d79d5928020af4686d8f6b7d35dbe470265f1f5aa22816ce860e
pkEm = 0x04a92719c6195d5085104f469a8b9814d5838ff72b60501e2c4466e5e67b325ac98536d7b61a1af4b78e5b7f951c0900be863c403ce65c9bfcb9382657222d18c4
skEm = 0x4995788ef4b9d6132b249ce59a77281493eb39af373d236a1fe415cb0c2d7beb
ikmR = 0x668b37171f1072f3cf12ea8a236a45df23fc13b82af3609ad1e354f6ef817550
pkRm = 0x04fe8c19ce0905191ebc298a9245792531f26f0cece2460639e8bc39cb7f706a826a779b4cf969b8a0e539c7f62fb3d30ad6aa8f80e30f1d128aafd68a2ce72ea0
skRm = 0xf3ce7fdae57e1a310d87f1ebbde6f328be0a99cdbcadf4d6589cf29de4b8ffd2
ss = 0xc0d26aeab536609a572b07695d933b589dcf363ff9d93c93adea537aeabb8cb8

In [6]:
## Functions for RFC 9180
def ECDH(a, B):
  return a.exchange(ec.ECDH(), B)

def ExtractAndExpand(ikm, label_extract, label_expand, info, L):
  labeled_ikm = b"HPKE-v1" + suite_id + label_extract + ikm
  labeled_info = L.to_bytes(2) + b"HPKE-v1" + suite_id + label_expand + info
  return HKDF(
        algorithm=hashes.SHA256(),
        length=L,
        salt=salt,
        info=labeled_info
        ).derive(labeled_ikm)

class DeriveKeyPairError(Exception):
  """Raised when no valid private key is found after 256 candidates."""

def DeriveKeyPair(ikm):
  label_extract = b"dkp_prk"
  label_expand = b"candidate"
  sk = 0
  counter = 0

  # Rejection sampling as in RFC 9180: retry until 0 < sk < order
  while sk == 0 or sk >= crv.order:
    if counter > 255:
      raise DeriveKeyPairError
    sk = bytearray(ExtractAndExpand(ikm, label_extract, label_expand, counter.to_bytes(1), Nsk))
    sk[0] &= bitmask
    sk = int.from_bytes(sk)
    counter += 1

  sk = ec.derive_private_key(sk, ec.SECP256R1())
  return sk, sk.public_key()

In [7]:
## Setup based on above
skE, pkE = DeriveKeyPair(ikmE.to_bytes((ikmE.bit_length()+7)//8))
skR, pkR = DeriveKeyPair(ikmR.to_bytes((ikmR.bit_length()+7)//8))

# Compute shared_secret
dh = ECDH(skE, pkR)
kem_context = pkE.public_bytes(SEC1, UCOMP) + pkR.public_bytes(SEC1, UCOMP)
shared_secret = ExtractAndExpand(dh,  b"eae_prk", b"shared_secret", kem_context, Nsecret)
assert (shared_secret.hex() == hex(ss)[2:])

#### RFC 9380 parts

With the shared secret generated from the KEM, we now need a standardized way to generate a field element that in turn can be used to derive child keys for any parent key. The key derivation relies on the [hash to field function](https://www.rfc-editor.org/rfc/rfc9380.html#name-hash_to_field-implementatio) from RFC 9380.
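The field-element length `L = 48` in the suite parameters below follows from the RFC 9380 formula, with the 256-bit P-256 prime `p` and security level `k = 128`; each of the `count` elements is then read from its own `L`-byte slice of the `expand_message_xmd` output:

```latex
L = \left\lceil \frac{\lceil \log_2 p \rceil + k}{8} \right\rceil
  = \left\lceil \frac{256 + 128}{8} \right\rceil = 48,
\qquad
u_i = \mathrm{OS2IP}\big(\mathrm{uniform\_bytes}[L \cdot i \,:\, L \cdot (i+1)]\big) \bmod p,
\quad i = 0, \dots, count - 1
```

The extra `k` bits before the reduction modulo `p` keep the bias of the resulting field element negligible.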

In [8]:
## Suite parameters
@dataclass(frozen=True)
class CurveParamsWeierstrass:
    a: int
    b: int
    n: int
    p: int
    m: int
    k: int
    H: type[hashes.HashAlgorithm]
    L: int
    Z: int
    h_eff: int
    g: tuple[int, int]
    dst: str
    curve: ec.EllipticCurve

    @property
    def h(self) -> hashes.Hash:
      # Fresh hash context for the suite's hash function H
      return hashes.Hash(self.H())

SECP256R1 = CurveParamsWeierstrass(
    a=-3,
    b=0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B,
    n=0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551,  # order
    p=0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF,
    m=1,
    k=128,
    H=hashes.SHA256,
    L=48,
    Z=-10,
    h_eff=1,
    g=(
        0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
        0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5,
    ),
    dst="HashToGroup-",
    curve=ec.SECP256R1(),
)

# # Test vectors
# name    = expand_message_xmd
# DST     = "QUUX-V01-CS02-with-expander-SHA256-128"
# hash    = SHA256
# k       = 128
# len_in_bytes = 0x20

# msg     = ""
# uniform_bytes = "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235"

# msg     = "abc"
# uniform_bytes = "d8ccab23b5985ccea865c6c97b6e5b8350e794e603b4b97902f53a8a0d605615"

suite = SECP256R1

In [9]:
def strxor(str1, str2):
    # Byte-wise XOR of two equal-length byte strings (keeps leading zero bytes)
    return bytes(a ^ b for a, b in zip(str1, str2))

def expand_message_xmd(msg, DST, len_in_bytes):
    """
    https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xmdsha-256
    """
    b_in_bytes = suite.H.digest_size
    s_in_bytes = suite.H.block_size
    ell = ceil(len_in_bytes / b_in_bytes)
    if any([ell > 255, len_in_bytes > 65535, len(DST) > 255]):
        raise ValueError("Input values out of range.")
    DST_len_I2OSP = int.to_bytes(len(DST), 1)
    DST_prime = DST + DST_len_I2OSP
    Z_pad = b"\x00" * s_in_bytes
    l_i_b_str = int.to_bytes(len_in_bytes, 2)
    msg_prime = Z_pad + msg + l_i_b_str + int.to_bytes(0, 1) + DST_prime

    b_0 = suite.h
    b_0.update(msg_prime)
    b_0 = b_0.finalize()

    b_1 = suite.h
    b_1.update(b_0 + int.to_bytes(1, 1) + DST_prime)
    b_1 = b_1.finalize()

    b = [b_0, b_1]
    for i in range(2, ell + 1):
        hash_input_xor = strxor(b[0], b[i - 1])
        str_xor = suite.h
        str_xor.update(hash_input_xor + int.to_bytes(i, 1) + DST_prime)
        b.append(str_xor.finalize())

    uniform_bytes = b[1]
    for i in b[2:]:
        uniform_bytes += i

    return uniform_bytes[:len_in_bytes]

In [10]:
assert (expand_message_xmd(b"", b"QUUX-V01-CS02-with-expander-SHA256-128", 0x20).hex() == "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235")
assert (expand_message_xmd(b"abc", b"QUUX-V01-CS02-with-expander-SHA256-128", 0x20).hex() == "d8ccab23b5985ccea865c6c97b6e5b8350e794e603b4b97902f53a8a0d605615")

In [11]:
def hash_to_field(msg, count):
    """
    https://www.rfc-editor.org/rfc/rfc9380.html#name-p256_xmdsha-256_sswu_ro_
    """
    m, L, p = suite.m, suite.L, suite.p

    len_in_bytes = count * m * L
    uniform_bytes = expand_message_xmd(
        msg=msg,
        DST=b"QUUX-V01-CS02-with-P256_XMD:SHA-256_SSWU_RO_",
        len_in_bytes=len_in_bytes,
    )

    u = []
    for i in range(count):
        for j in range(m):
            elm_offset = L * (j + i * m)
            tv = uniform_bytes[elm_offset : elm_offset + L]
            e_j = int.from_bytes(tv) % p
            u.append(e_j)
    return u

In [12]:
assert (hash_to_field(b"", 2)[0] == 0xad5342c66a6dd0ff080df1da0ea1c04b96e0330dd89406465eeba11582515009)
assert (hash_to_field(b"", 2)[1] == 0x8c0f1d43204bd6f6ea70ae8013070a1518b43873bcd850aafa0a9e220e2eea5a)

#### Local Derivation

With a way to derive field elements from a secret salt or shared secret, we can now derive keys. A User can derive keys locally using Local Derivation as detailed below.
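Reading the functions in the next cell, one local derivation step for a parent public key `pk`, a `salt`, an `index` and a mode can be summarised as below, where `hash_to_field` is the function defined above, `G` is the P-256 generator, and `bk` is the minimal-length big-endian encoding of the field element, as in `DeriveBlindKey`:

```latex
\begin{aligned}
ctx   &= \mathrm{SerializePublicKey}(pk) \parallel \mathrm{I2OSP}(index, 4)\\
bk    &= \mathrm{hash\_to\_field}(salt, 1)[0] \ \text{as bytes}\\
bf    &= \mathrm{hash\_to\_field}(bk \parallel \mathtt{0x00} \parallel ctx, 1)[0]\\
salt' &= \mathrm{SHA256}(salt \parallel ctx)\\
pk'   &= \begin{cases} pk + bf \cdot G & \text{ADD}\\ bf \cdot pk & \text{MUL} \end{cases}
\end{aligned}
```

`HDK` returns `[pk', salt']`.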

In [13]:
## X is the point corresponding to pkX
E = Point(pkE.public_numbers().x, pkE.public_numbers().y, crv)
## Local derivation seed: the 32-byte ikmE from the test vector
seed = ikmE.to_bytes(32, "big")

In [14]:
def encoded_point_to_key(Em):
  return point_to_key(crv.decode_point(Em))

def point_to_key(E):
  return ec.EllipticCurvePublicNumbers(E.x, E.y, SECP256R1.curve)

## HDK building blocks
def CreateContext(pk, index):
  return pk.public_bytes(SEC1, UCOMP) + index.to_bytes(4)

def DeriveSalt(salt, ctx):
  derived_salt = suite.h
  derived_salt.update(salt + ctx)
  return derived_salt.finalize()

def DeriveBlindKey(ikm):
  bk_scalar = hash_to_field(ikm, 1)[0]
  l = (bk_scalar.bit_length() + 7) // 8
  return bk_scalar.to_bytes(l)

def DeriveBlindingFactor(bk, ctx):
  msg = bk + b'\x00' + ctx
  return hash_to_field(msg, 1)[0]

def DeriveBlindPublicKey(pk, bf, mode):
  P = Point(pk.public_numbers().x, pk.public_numbers().y, crv)
  BP = P + bf * G if mode == "ADD" else (P * bf if mode == "MUL" else None)
  return point_to_key(BP)

def Combine(s_0, s_1, mode):
  return (s_0 + s_1) % crv.order if mode == "ADD" else ((s_0 * s_1) % crv.order if mode == "MUL" else None)

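def HDK(pk, salt, index, mode, bf=None, sk=None):
  # One local derivation step: returns the blinded public key and a derived salt.
  # bf and sk are not used yet; a blinded private scalar could be obtained as Combine(sk, bf_0, mode).
  ctx = CreateContext(pk, index)
  bk = DeriveBlindKey(salt)
  bf_0 = DeriveBlindingFactor(bk, ctx)

  salt_b = DeriveSalt(salt, ctx)
  pk_b = DeriveBlindPublicKey(pk, bf_0, mode)

  return [pk_b, salt_b]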

In [15]:
print("Local additive key derivation in compressed form for pkE index 0: ")
HDK_ADD_0_0 = HDK(pk=pkE, salt=seed, index=0, mode="ADD")[0].public_key().public_bytes(SEC1, COMP)
print(HDK_ADD_0_0.hex())

print("\nLocal additive key derivation in compressed form for pkE index 1: ")
HDK_ADD_0_1 = HDK(pk=pkE, salt=seed, index=1, mode="ADD")[0].public_key().public_bytes(SEC1, COMP)
print(HDK_ADD_0_1.hex())

Local additive key derivation in compressed form for pkE index 0: 
03f02b710bd856c813d4308b3e676bba0721916b87625e396aa3e8b456c90881e9

Local additive key derivation in compressed form for pkE index 1: 
026fb0592172aefe8d7c189ba05bd5904ade835cb88137bfd850e3d2b0bae51df8


In [16]:
print("Local multiplicative key derivation in compressed form for pkE index 0: ")
HDK_MUL_0_0 = HDK(pk=pkE, salt=seed, index=0, mode="MUL")[0].public_key().public_bytes(SEC1, COMP)
print(HDK_MUL_0_0.hex())

print("\nLocal multiplicative key derivation in compressed form for pkE index 1: ")
HDK_MUL_0_1 = HDK(pk=pkE, salt=seed, index=1, mode="MUL")[0].public_key().public_bytes(SEC1, COMP)
print(HDK_MUL_0_1.hex())

Local multiplicative key derivation in compressed form for pkE index 0: 
0308708c34ce059a6bfaeb838f70bec0c8a90cfb388eb5654e7a09addf6e2e7fb5

Local multiplicative key derivation in compressed form for pkE index 1: 
0385ba7b098f407481e9d4ee0af51eaf2f1760d1160ac31ae675b0eff975a419f1
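`HDK` returns only the blinded public key; its `sk` and `bf` parameters and `Combine` are not used yet. The following self-contained pure-Python sketch (textbook P-256 arithmetic, not constant time) checks that blinding a private scalar with `Combine` yields the private key of the point produced by `DeriveBlindPublicKey`, in both modes. `combine` and `blind_public_key` are hypothetical re-implementations of those two functions, and the parent key and blinding factor are arbitrary example values derived from SHA-256.

```python
import hashlib

# NIST P-256 (secp256r1) domain parameters: field prime p, group order n, generator G
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
G = (0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296,
     0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5)


def point_add(P, Q):
    """Affine addition on y^2 = x^3 - 3x + b; None is the point at infinity."""
    if P is None:
        return Q
    if Q is None:
        return P
    if P[0] == Q[0] and (P[1] + Q[1]) % p == 0:
        return None
    if P == Q:
        lam = (3 * P[0] * P[0] - 3) * pow(2 * P[1], -1, p) % p
    else:
        lam = (Q[1] - P[1]) * pow((Q[0] - P[0]) % p, -1, p) % p
    x = (lam * lam - P[0] - Q[0]) % p
    return (x, (lam * (P[0] - x) - P[1]) % p)


def scalar_mult(k, P):
    """Double-and-add; not constant time, for illustration only."""
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P = point_add(P, P)
        k >>= 1
    return R


def combine(s_0, s_1, mode):
    # Mirrors Combine: blind a private scalar additively or multiplicatively mod n
    if mode == "ADD":
        return (s_0 + s_1) % n
    if mode == "MUL":
        return (s_0 * s_1) % n
    raise ValueError(f"unknown mode {mode!r}")


def blind_public_key(pk, bf, mode):
    # Mirrors DeriveBlindPublicKey: pk + bf*G (ADD) or bf*pk (MUL)
    if mode == "ADD":
        return point_add(pk, scalar_mult(bf, G))
    if mode == "MUL":
        return scalar_mult(bf, pk)
    raise ValueError(f"unknown mode {mode!r}")


sk = int.from_bytes(hashlib.sha256(b"example parent key").digest(), "big") % n
pk = scalar_mult(sk, G)
# Example blinding factor reduced mod p, as hash_to_field does; the point is unaffected
# because scalar multiplication only depends on the scalar modulo n
bf = int.from_bytes(hashlib.sha256(b"example blinding factor").digest(), "big") % p
for mode in ("ADD", "MUL"):
    assert scalar_mult(combine(sk, bf, mode), G) == blind_public_key(pk, bf, mode)
```

Because the check holds for both modes, a blinded private key computed with `Combine` inside the SCE would sign for the blinded public key that `HDK` outputs.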
