<a href="https://colab.research.google.com/github/owlmt/PQC/blob/main/mlkem_backdoor_experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference paper (ACM Digital Library): https://dl.acm.org/doi/10.1145/3649476.3660373


In [None]:
import secrets
import hashlib
import struct
from hashlib import shake_256
from typing import List, Tuple

# Module-level banner: printed whenever this cell / module is executed or imported.
print("=== Backdoored ML KEM core (Kyber768 style) ===")


# =====================================================================
# 1. Parameters (Kyber768 style)
# =====================================================================

class MLKEMParams:
    """Kyber768 / ML-KEM-768 style parameter set shared by the whole demo."""

    n = 256        # polynomial degree (coefficients per ring element)
    k = 3          # module rank; ML-KEM-768 uses k = 3
    q = 3329       # coefficient modulus
    eta = 2        # centered-binomial parameter used for s and e

    # Size of the ECDH (K 409) ciphertext that the paper embeds as payload.
    ECDH_CT_SIZE = 104  # bytes

    @staticmethod
    def compute_c_bits(total_bits: int) -> int:
        """
        Return how many payload bits each public-key coefficient must carry.

        This is Equation (1) of the paper: c_bits = ceil(b / (k * n)), where b
        is the payload length in bits and k * n is the number of coefficients
        available in t.
        """
        slots = MLKEMParams.k * MLKEMParams.n
        # Ceiling division via negated floor division.
        return -(-total_bits // slots)


# Short module-level aliases for the parameters used by the helpers below.
q, n, k = MLKEMParams.q, MLKEMParams.n, MLKEMParams.k


# =====================================================================
# 2. CBD sampler (FIPS 203 Algorithm 8, eta = 2)
# =====================================================================

def cbd_fips203(seed: bytes, eta: int = MLKEMParams.eta) -> List[int]:
    """
    Sample n small coefficients from the centered binomial distribution.

    The seed is expanded with SHAKE256 to 64 * eta bytes.  Those bytes are read
    as a bit stream with the least significant bit of each byte first.  Each
    coefficient consumes 2 * eta consecutive bits: the first eta bits are
    summed into x, the next eta into y, and the coefficient is x - y, which
    lies in [-eta, eta].  This bit layout follows FIPS 203 SamplePolyCBD.
    """
    stream = shake_256(seed).digest(64 * eta)
    bits = [(octet >> pos) & 1 for octet in stream for pos in range(8)]
    stride = 2 * eta

    samples: List[int] = []
    for idx in range(n):
        base = idx * stride
        ones_x = sum(bits[base:base + eta])
        ones_y = sum(bits[base + eta:base + stride])
        samples.append(ones_x - ones_y)
    return samples


# =====================================================================
# 3. Polynomial arithmetic in R_q = Z_q[x] / (x^n + 1)
# =====================================================================

def poly_add(a: List[int], b: List[int]) -> List[int]:
    """Return a + b in R_q: the coefficient-wise sum, reduced into [0, q)."""
    return [(lhs + rhs) % q for lhs, rhs in zip(a, b)]


def poly_sub(a: List[int], b: List[int]) -> List[int]:
    """Return a - b in R_q: the coefficient-wise difference, reduced into [0, q)."""
    return [(lhs - rhs) % q for lhs, rhs in zip(a, b)]


def poly_mul(a: List[int], b: List[int]) -> List[int]:
    """
    Multiply two polynomials in R_q = Z_q[x] / (x^n + 1) (schoolbook, O(n^2)).

    A product landing on x^d with d >= n is folded back onto x^(d - n) with a
    sign flip, because x^n ≡ -1.  Python integers do not overflow, so partial
    sums are kept exact and reduced modulo q once at the end.  Every output
    coefficient lies in [0, q).
    """
    acc = [0] * n
    for i in range(n):
        ai = a[i]
        for j in range(n):
            term = ai * b[j]
            deg = i + j
            if deg < n:
                acc[deg] += term
            else:
                # x^deg = x^(deg - n) * x^n ≡ -x^(deg - n)
                acc[deg - n] -= term
    return [coef % q for coef in acc]


# =====================================================================
# 4. Matrix A generation (Kyber style but simplified)
# =====================================================================

def _sample_uniform_poly(xof_seed: bytes) -> List[int]:
    """Rejection-sample n coefficients uniform in [0, q) from SHAKE256(xof_seed)."""
    coeffs: List[int] = []
    stream = b""
    pos = 0
    # Multiple of 3; for n = 256 this is usually enough (about 315 candidates
    # are needed on average at an acceptance rate of q / 4096).
    squeeze_len = 504
    while len(coeffs) < n:
        if pos + 3 > len(stream):
            # hashlib's SHAKE digest(L) is the first L bytes of the XOF output,
            # so requesting a longer digest just extends the same stream and
            # the already-consumed prefix is unchanged.
            stream = shake_256(xof_seed).digest(squeeze_len)
            squeeze_len *= 2
        b0, b1, b2 = stream[pos], stream[pos + 1], stream[pos + 2]
        pos += 3
        # Two 12-bit candidates per 3 bytes, as in FIPS 203 SampleNTT.
        d1 = b0 | ((b1 & 0x0F) << 8)
        d2 = (b1 >> 4) | (b2 << 4)
        if d1 < q:
            coeffs.append(d1)
        if d2 < q and len(coeffs) < n:
            coeffs.append(d2)
    return coeffs


def gen_matrix(seed_A: bytes) -> List[List[List[int]]]:
    """
    Generate the k x k matrix A of polynomials with coefficients uniform in Z_q.

    Entry A[i][j] is expanded from SHAKE256(seed_A || i || j) by rejection
    sampling.  Every 3 XOF bytes yield two 12-bit candidates, and a candidate
    is kept only if it is < q.  The previous "16-bit word mod q" reduction was
    biased: 65536 is not a multiple of q, so small residues were slightly
    over-represented.

    The output is still not bit-compatible with FIPS 203, which uses SHAKE128
    with input seed || j || i and interprets A in the NTT domain.  Here A is
    used directly in the coefficient domain by the schoolbook multiplier.
    """
    return [
        [_sample_uniform_poly(seed_A + bytes([i, j])) for j in range(k)]
        for i in range(k)
    ]


def mat_vec_mul(A: List[List[List[int]]],
                s_vec: List[List[int]]) -> List[List[int]]:
    """
    Return the module product A * s as a vector of k polynomials.

    Entry i is sum_j A[i][j] * s[j], accumulated with poly_mul and poly_add.
    Only the first k rows/columns of A and the first k entries of s are read.
    """
    def row_times_vector(row_idx: int) -> List[int]:
        total = [0] * n
        for col_idx in range(k):
            total = poly_add(total, poly_mul(A[row_idx][col_idx], s_vec[col_idx]))
        return total

    return [row_times_vector(row_idx) for row_idx in range(k)]


# =====================================================================
# 5. Bit helpers and payload encoding (Equation (1) of the paper)
# =====================================================================

def bits_from_bytes_le(data: bytes) -> List[int]:
    """
    Expand bytes into a flat list of 0/1 bits, least significant bit first.

    The result has 8 * len(data) entries; bit j of byte i lands at index 8*i + j.
    """
    return [(octet >> shift) & 1 for octet in data for shift in range(8)]


def bytes_from_bits_le(bits: List[int], out_len: int) -> bytes:
    """
    Pack bits (LSB first within each byte) into exactly out_len bytes.

    Inverse of bits_from_bytes_le.  Missing trailing bits are treated as 0,
    and bits beyond 8 * out_len are ignored.
    """
    packed = bytearray(out_len)
    for pos in range(out_len):
        chunk = bits[8 * pos:8 * pos + 8]
        packed[pos] = sum(1 << shift for shift, bit in enumerate(chunk) if bit)
    return bytes(packed)


def encode_payload_to_poly(ct: bytes) -> Tuple[List[int], int, int, int, int]:
    """
    Pack the K 409 ciphertext ct_bd into a flat coefficient vector p of length k*n.

      b_bits   = 8 * len(ct)
      c_bits   = ceil(b_bits / (k*n))     (payload bits carried per coefficient)
      M        = 2^c_bits
      num_used = ceil(b_bits / c_bits)

    For i < num_used, coefficient p[i] holds bits [i*c_bits, (i+1)*c_bits) of
    ct.  Bits are read LSB-first within each byte, and bit j of the chunk is
    weighted 2^j.  Remaining coefficients are zero.

    An empty ct yields c_bits = 0, M = 1, num_used = 0 and an all-zero p.
    Previously this case raised ZeroDivisionError.

    Returns:
        (p, c_bits, M, b_bits, num_used)

    Raises:
        ValueError: if the payload needs more than k*n coefficients, or a
            modulus M larger than q.  In the latter case residues mod M cannot
            all be represented by coefficients in [0, q).
    """
    b_bits = len(ct) * 8
    c_bits = MLKEMParams.compute_c_bits(b_bits)
    M = 1 << c_bits

    bits = bits_from_bytes_le(ct)
    # compute_c_bits returns 0 only for a zero-length payload; nothing to place then.
    num_used = (b_bits + c_bits - 1) // c_bits if c_bits else 0

    if num_used > k * n or M > q:
        raise ValueError("Payload too large to embed in ML KEM 768 public key")

    p = [0] * (k * n)
    for i in range(num_used):
        # The last chunk may be shorter than c_bits; its missing high bits are 0.
        chunk = bits[i * c_bits:(i + 1) * c_bits]
        p[i] = sum(bit << j for j, bit in enumerate(chunk))

    return p, c_bits, M, b_bits, num_used


# =====================================================================
# 6. Compensation vector h and modified t' (Section 3.3)
# =====================================================================

def apply_compensation(t_flat: List[int],
                       p: List[int],
                       c_bits: int,
                       M: int,
                       num_used: int) -> Tuple[List[int], List[int]]:
    """
    Embed payload residues into the LWE component t with a small offset h.

    For every 0 <= i < num_used, choose h[i] with |h[i]| <= M/2 such that

       t'[i] = (t[i] + h[i]) mod q      and      t'[i] ≡ p[i] (mod M).

    Coefficients at index >= num_used are copied unchanged (h[i] = 0).
    For the K 409 example c_bits = 2 and M = 4, so h[i] is in [-2, 2].

    Offsets are tried in the order 0, +1, -1, +2, -2, ..., +M/2, -M/2, and the
    mod-M test is applied *after* reducing modulo q.  This matters at the
    wrap-around point of Z_q.  Example with q = 3329, M = 4, t = 0, p = 3:
    the closed-form offset -1 gives t' = 3328 ≡ 0 (mod 4), so the previous
    implementation raised AssertionError.  The search picks h = -2 instead,
    giving t' = 3327 ≡ 3 (mod 4).  For M dividing q - 1 (such as M = 4), the
    search returns the same h as the closed-form (p - t) mod M choice whenever
    that choice does not wrap around q.

    Args:
        t_flat: flattened t, coefficients expected in [0, q).
        p: encoded payload residues, p[i] in [0, M).
        c_bits: bits per coefficient (kept for interface compatibility; unused).
        M: residue modulus 2^c_bits.
        num_used: number of leading coefficients that carry payload.

    Returns:
        (t_prime, h): the modified flat vector and the offsets applied.

    Raises:
        AssertionError: if no offset with |h| <= M/2 reaches p[i] modulo M.
            For M dividing q - 1 a valid offset always exists.
    """
    halfM = M // 2
    # Smallest |h| wins; on ties the positive offset is preferred, matching
    # the original closed-form choice of diff in (-M/2, M/2].
    offsets = [0]
    for step in range(1, halfM + 1):
        offsets.extend((step, -step))

    t_prime = list(t_flat)
    h = [0] * len(t_flat)

    for i in range(num_used):
        desired = p[i]
        for delta in offsets:
            candidate = (t_flat[i] + delta) % q
            if candidate % M == desired:
                break
        else:
            raise AssertionError(f"Compensation failed at coefficient {i}")

        h[i] = delta
        t_prime[i] = candidate

    return t_prime, h


def recover_payload_from_tprime(t_prime_flat: List[int],
                                c_bits: int,
                                M: int,
                                b_bits: int,
                                num_used: int,
                                out_ct_len: int) -> bytes:
    """
    Undo the payload encoding: read the embedded ciphertext back out of t'.

    Each of the first num_used coefficients is reduced modulo M to recover
    its c_bits-bit chunk (LSB first).  The concatenated bit stream is cut to
    b_bits and packed into out_ct_len bytes.
    """
    bit_stream = [
        ((t_prime_flat[idx] % M) >> shift) & 1
        for idx in range(num_used)
        for shift in range(c_bits)
    ]
    return bytes_from_bits_le(bit_stream[:b_bits], out_ct_len)


# =====================================================================
# 7. Simulated ECDH on curve K 409 (Section 3.5.1)
# =====================================================================

class SimulatedK409:
    """
    Stand-in for the ECDH-based Enc_bd / Dec_bd pair described in the paper.

    The demo only needs a deterministic map from the backdoor ciphertext
    ct_bd to seed_B that the attacker side can recompute, so:
      ct_bd  = CT_SIZE random bytes
      seed_B = SHA-256(ct_bd)   (the full 32-byte digest)
    No elliptic-curve operation is actually performed.
    """

    CT_SIZE = MLKEMParams.ECDH_CT_SIZE

    @staticmethod
    def dec(ct: bytes) -> bytes:
        """Derive seed_B from ct_bd: the first 32 bytes of SHA-256(ct)."""
        digest = hashlib.sha256(ct).digest()
        return digest[:32]

    @staticmethod
    def enc() -> Tuple[bytes, bytes]:
        """Return a fresh random ct_bd together with the seed_B it decodes to."""
        ct = secrets.token_bytes(SimulatedK409.CT_SIZE)
        return ct, SimulatedK409.dec(ct)


# =====================================================================
# 8. Helpers for flattening and unflattening vectors of polynomials
# =====================================================================

def flatten_polys(poly_vec: List[List[int]]) -> List[int]:
    """Concatenate a vector of polynomials into one flat coefficient list, in order."""
    return [coef for poly in poly_vec for coef in poly]


def unflatten_polys(flat: List[int]) -> List[List[int]]:
    """
    Split a flat coefficient list into k consecutive slices of length n.

    For a list of length k*n this is the inverse of flatten_polys.
    """
    return [flat[idx * n:(idx + 1) * n] for idx in range(k)]


# =====================================================================
# 9. Backdoored ML KEM key generation (Algorithm 2, device side)
# =====================================================================

def backdoored_keygen() -> Tuple[dict, dict, dict]:
    """
    Device-side key generation of the backdoored ML-KEM-768 demo.

    This follows Algorithm 2 of the paper, specialised to the K 409 instantiation:
      1-2. Enc_bd (simulated by SimulatedK409.enc) gives ct_bd and seed_B.
      3.   s and e are CBD-sampled from seed_B, domain-separated by
           b"\\x00" / b"\\x10" plus the polynomial index.
      4.   A is expanded from a fresh random seed_A, and t = A * s + e.
      5.   ct_bd is encoded into payload coefficients p.
      6.   A small offset h is added so that t' ≡ p (mod M).

    Returns:
        pk:    seed_A, t_vec (= t'), plus metadata stored only to simplify the
               demo (c_bits, M, b_bits, num_used).
        sk:    s_vec, seed_B, ct_bd.
        debug: t_vec_original, t_vec_backdoored, h.
    """
    print("=== Backdoored ML KEM KeyGen (device) ===")

    # Steps 1-2: simulated Enc_bd.
    ct_bd, seed_B = SimulatedK409.enc()

    # Step 3: secret and error polynomials derived from seed_B.
    def sample_from_seed_b(tag: bytes, index: int) -> List[int]:
        return cbd_fips203(seed_B + tag + bytes([index]))

    s_vec = [sample_from_seed_b(b"\x00", idx) for idx in range(k)]
    e_vec = [sample_from_seed_b(b"\x10", idx) for idx in range(k)]

    # Step 4: t = A * s + e with a fresh public seed.
    seed_A = secrets.token_bytes(32)
    A = gen_matrix(seed_A)
    t_vec = [poly_add(as_poly, err) for as_poly, err in zip(mat_vec_mul(A, s_vec), e_vec)]

    # Steps 5-6: embed ct_bd into t' modulo M.
    p, c_bits, M, b_bits, num_used = encode_payload_to_poly(ct_bd)
    t_prime_flat, h = apply_compensation(flatten_polys(t_vec), p, c_bits, M, num_used)
    t_prime_vec = unflatten_polys(t_prime_flat)

    pk = {
        "seed_A": seed_A,
        "t_vec": t_prime_vec,
        # Demo-only metadata (not needed in a real backdoor).
        "c_bits": c_bits,
        "M": M,
        "b_bits": b_bits,
        "num_used": num_used,
    }
    sk = {"s_vec": s_vec, "seed_B": seed_B, "ct_bd": ct_bd}
    debug = {"t_vec_original": t_vec, "t_vec_backdoored": t_prime_vec, "h": h}
    return pk, sk, debug


# =====================================================================
# 10. Attacker recovery (Algorithm 2, adversary side)
# =====================================================================

def attacker_recovery(pk: dict) -> Tuple[bytes, bytes, List[List[int]]]:
    """
    Adversary side of Algorithm 2: rebuild the secret s from the public key.

    Reads t' and the demo metadata (c_bits, M, b_bits, num_used) from pk.
    It reduces t' modulo M to recover ct_bd, runs the simulated Dec_bd to get
    seed_B, and re-samples s with the same CBD domain separation
    (b"\\x00" + index) used at key generation.

    Returns:
        (recovered ct_bd, recovered seed_B, recovered s vector)
    """
    print("=== Attacker recovery from public key ===")

    # Steps 1-3: read ct_bd back out of t'.
    recovered_ct = recover_payload_from_tprime(
        flatten_polys(pk["t_vec"]),
        pk["c_bits"],
        pk["M"],
        pk["b_bits"],
        pk["num_used"],
        MLKEMParams.ECDH_CT_SIZE,
    )

    # Step 4: simulated Dec_bd.
    recovered_seed_B = SimulatedK409.dec(recovered_ct)

    # Step 5: regenerate s exactly as key generation did.
    s_vec_rec = [
        cbd_fips203(recovered_seed_B + b"\x00" + bytes([idx]))
        for idx in range(k)
    ]
    return recovered_ct, recovered_seed_B, s_vec_rec


# =====================================================================
# 11. End to end demonstration
# =====================================================================

def main() -> None:
    """Run the end-to-end demo: backdoored keygen, attacker recovery, comparison."""
    for banner in (
        "=== Attacker installs backdoor (modeled by SimulatedK409) ===",
        "    (In practice this corresponds to a malicious ECDH public key on device.)\n",
    ):
        print(banner)

    print("=== Victim runs backdoored ML KEM KeyGen ===")
    pk, sk, _debug = backdoored_keygen()

    print("\n=== Attacker observes public key and recovers secret ===")
    rec_ct, rec_seed_B, s_vec_rec = attacker_recovery(pk)

    print("\n=== Comparison ===")
    checks = {
        "Ciphertext match:": rec_ct == sk["ct_bd"],
        "seed_B match:    ": rec_seed_B == sk["seed_B"],
        "Secret s match:  ": all(
            true_poly == rec_poly for true_poly, rec_poly in zip(sk["s_vec"], s_vec_rec)
        ),
    }
    for label, ok in checks.items():
        print(label, ok)

    print("\nFirst 20 coefficients of s[0] (true):     ", sk["s_vec"][0][:20])
    print("First 20 coefficients of s[0] (recovered):", s_vec_rec[0][:20])

    print("\nMATCH:", all(checks.values()))


if __name__ == "__main__":
    main()


=== Backdoored ML KEM core (Kyber768 style) ===
=== Attacker installs backdoor (modeled by SimulatedK409) ===
    (In practice this corresponds to a malicious ECDH public key on device.)

=== Victim runs backdoored ML KEM KeyGen ===
=== Backdoored ML KEM KeyGen (device) ===

=== Attacker observes public key and recovers secret ===
=== Attacker recovery from public key ===

=== Comparison ===
Ciphertext match: True
seed_B match:     True
Secret s match:   True

First 20 coefficients of s[0] (true):      [-1, 0, 0, 1, 1, 1, 0, -1, -1, 0, 1, 0, -1, 0, 2, -1, 0, 0, 0, 2]
First 20 coefficients of s[0] (recovered): [-1, 0, 0, 1, 1, 1, 0, -1, -1, 0, 1, 0, -1, 0, 2, -1, 0, 0, 0, 2]

MATCH: True
