<a href="https://colab.research.google.com/github/owlmt/PQC/blob/main/mlkem_backdoor1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
import hashlib
import secrets
import struct
from typing import List, Optional, Tuple

# Banner printed once when the cell/module is executed, before any definition runs.
print("=== Backdooring Kyber768 (Ravi et al. GLSVLSI 2024) ===")

# ---------------------------------------------------------------------------
# 1. Kyber768 parameters (as in paper)
# ---------------------------------------------------------------------------

class KyberParams:
    """Constants shared by the demo, plus the bits-per-coefficient helper."""

    n = 256          # coefficients per polynomial
    k = 3            # module rank (Kyber-768)
    q = 3329         # coefficient modulus
    eta = 2          # CBD width used for Kyber-768

    # Byte length of the K-409 ECDH ciphertext carried by the backdoor
    # (Section 3.5.1): 104 bytes = 832 bits.
    ECDH_CIPHERTEXT_SIZE = 104

    @staticmethod
    def compute_c_bits(total_bits: int) -> int:
        """
        Return the number of payload bits stored per coefficient.

        This is c_bits = ceil(b / (k * n)) from Equation (1) of the paper,
        with b = total_bits the bit-length of the payload ciphertext and
        k * n the number of coefficients available in the public key.
        """
        capacity = KyberParams.k * KyberParams.n
        # Integer ceiling division: add (capacity - 1) before flooring.
        return (total_bits + capacity - 1) // capacity


# ---------------------------------------------------------------------------
# 2. NTT "stub" (we stay in coefficient domain for this demo)
# ---------------------------------------------------------------------------

def ntt(poly: List[int]) -> List[int]:
    """
    Stand-in for the forward NTT: return a fresh list with the same items.

    The reference C implementation stores NTT(t') in the public key.  This
    demo stays in the coefficient domain, so the transform is the identity
    and INTT(NTT(x)) = x holds trivially.
    """
    return [*poly]

def intt(poly_ntt: List[int]) -> List[int]:
    """
    Stand-in for the inverse NTT: return a fresh list with the same items.

    It mirrors the identity forward transform used by this demo, so the
    result is a copy of the input rather than a real inverse transform.
    """
    return [*poly_ntt]


# ---------------------------------------------------------------------------
# 3. CBD sampler (FIPS 203 Algorithm 8, eta = 2)
# ---------------------------------------------------------------------------

from hashlib import shake_256

def cbd_fips203(seed: bytes, eta: int = KyberParams.eta) -> List[int]:
    """
    Sample one polynomial from a centered binomial distribution.

    The seed is expanded with SHAKE-256 to 64 * eta bytes.  The resulting
    bit stream is read least-significant-bit first within each byte, and
    every coefficient consumes 2 * eta consecutive bits: the sum of the
    first eta bits minus the sum of the next eta bits.  This follows the
    layout of FIPS 203 Algorithm 8.

    Returns KyberParams.n coefficients, each in [-eta, eta].
    """
    stream = shake_256(seed).digest(64 * eta)

    # Unpack the whole stream once, little-endian inside each byte.
    bit_stream = [(byte >> shift) & 1 for byte in stream for shift in range(8)]

    width = 2 * eta
    samples: List[int] = []
    for idx in range(KyberParams.n):
        start = idx * width
        positive = sum(bit_stream[start:start + eta])
        negative = sum(bit_stream[start + eta:start + width])
        samples.append(positive - negative)

    return samples


# ---------------------------------------------------------------------------
# 4. Bit helpers (little-endian inside bytes, as in the paper / reference code)
# ---------------------------------------------------------------------------

def bits_from_bytes_le(data: bytes) -> List[int]:
    """
    Expand bytes into a flat list of bits, least-significant bit first.

    Output position 8*i + j holds (data[i] >> j) & 1, so the result has
    eight entries per input byte.
    """
    return [(byte >> shift) & 1 for byte in data for shift in range(8)]


def bytes_from_bits_le(bits: List[int], out_len: int) -> bytes:
    """
    Pack a little-endian bit list back into exactly out_len bytes.

    Bit position 8*i + j becomes bit j of output byte i.  Positions past
    the end of `bits` are treated as zero, and bits beyond 8 * out_len are
    ignored.
    """
    packed = bytearray(out_len)
    # Only positions that exist in the input and fit in the output matter.
    usable = min(len(bits), 8 * out_len)
    for pos in range(usable):
        if bits[pos]:
            packed[pos >> 3] |= 1 << (pos & 7)
    return bytes(packed)


# ---------------------------------------------------------------------------
# 5. Encoding payload ct_bd into polynomial p (Section 3.3, Equation (1))
# ---------------------------------------------------------------------------

def encode_payload_to_poly(ct: bytes) -> Tuple[List[int], int, int, int, int]:
    """
    Encode the ciphertext ct_bd into a polynomial module p in R_q^k, as in
    Section 3.3. We follow the reference C code rather than the slightly
    ambiguous notation "mod c" in the paper.

    Let b  = |ct| in bits.
    Let c_bits = ceil(b / (k*n)).
    Let M = 2^c_bits.

    For i from 0 to num_used-1, coefficient p[i] stores c_bits bits of ct
    in little-endian order:
        p[i] = sum_{j=0}^{c_bits-1} bit[ i * c_bits + j ] * 2^j

    where bit[8*a + j] = (ct[a] >> j) & 1 and bits past the end of ct count
    as zero.  Remaining coefficients are zero.

    An empty payload is encoded as the all-zero polynomial with
    c_bits = 0, M = 1, b_bits = 0 and num_used = 0 (no coefficient carries
    data), instead of failing with a division by zero.

    Returns:
        (p, c_bits, M, b_bits, num_used) where p has k*n entries, each in
        [0, M-1].

    Raises:
        ValueError: if the payload would need more than k*n coefficients.
    """
    k, n = KyberParams.k, KyberParams.n
    capacity = k * n
    b_bits = len(ct) * 8

    if b_bits == 0:
        # c_bits would be 0 here, and ceil(b_bits / c_bits) is undefined.
        return [0] * capacity, 0, 1, 0, 0

    c_bits = KyberParams.compute_c_bits(b_bits)
    M = 1 << c_bits
    num_used = (b_bits + c_bits - 1) // c_bits  # ceil(b_bits / c_bits)

    # Defensive invariant check: with c_bits chosen as above, num_used
    # cannot exceed the capacity.
    if num_used > capacity:
        raise ValueError("Payload too large to embed in Kyber768 public key")

    # Viewing ct as one little-endian integer gives the same bit numbering
    # as the per-byte little-endian layout described above.
    payload = int.from_bytes(ct, "little")
    mask = M - 1

    p = [0] * capacity
    for i in range(num_used):
        p[i] = (payload >> (i * c_bits)) & mask  # in [0, M-1]

    return p, c_bits, M, b_bits, num_used


# ---------------------------------------------------------------------------
# 6. Compensation vector h and modified t' (Section 3.3, "additively integrate")
# ---------------------------------------------------------------------------

def apply_compensation(t: List[int],
                       p: List[int],
                       c_bits: int,
                       M: int,
                       num_used: int,
                       q: Optional[int] = None) -> Tuple[List[int], List[int]]:
    """
    Given original LWE component t (flattened, length k*n) and encoded
    payload coefficients p[i] in [0, M-1], compute a small compensation
    vector h so that

        t'[i] = t[i] + h[i] (mod q),  with t'[i] in [0, q)
        t'[i] mod M = p[i]            for 0 <= i < num_used

    and h[i] is in [-M/2, M/2].  For Kyber768 with K-409 we have
    c_bits = 2, M = 4, so h[i] ∈ [-2, 2], matching the analysis in
    Section 3.4 where the error span becomes [-4, 4].

    The residue is read from the reduced representative t'[i], so an offset
    that pushes t[i] + h[i] outside [0, q) changes the residue once the sum
    wraps around q.  For those boundary coefficients the offset closest to
    zero that is still correct *after* reduction is used instead of the
    naive centred difference.

    Args:
        t: original coefficients; not modified.
        p: payload residues, one per used coefficient.
        c_bits: bits per coefficient (unused here; kept for the call
            signature shared with the encoder/decoder).
        M: residue modulus, 2^c_bits.
        num_used: number of leading coefficients that carry payload.
        q: coefficient modulus; defaults to KyberParams.q.

    Returns:
        (t_prime, h), both the same length as t.  Entries at index
        >= num_used are copied from t / left at 0.

    Raises:
        AssertionError: if no offset in [-M/2, M/2] yields the requested
            residue (for example when p[i] is outside [0, M-1]).
    """
    if q is None:
        q = KyberParams.q
    half_m = M // 2
    t_prime = list(t)
    h = [0] * len(t)

    for i in range(num_used):
        desired = p[i]

        # Fast path: centred representative of (desired - t[i]) mod M,
        # i.e. the value in (-M/2, M/2] congruent to the difference.
        diff = (desired - t[i] % M) % M
        if diff > half_m:
            diff -= M

        if (t[i] + diff) % q % M != desired:
            # t[i] + diff left [0, q) and the wrap-around shifted the
            # residue.  Scan offsets by increasing magnitude (positive
            # first, like the fast path's tie-break) for one that is
            # correct once reduced mod q.
            for magnitude in range(half_m + 1):
                if (t[i] + magnitude) % q % M == desired:
                    diff = magnitude
                    break
                if (t[i] - magnitude) % q % M == desired:
                    diff = -magnitude
                    break
            else:
                raise AssertionError(
                    "Compensation failed at coefficient {}".format(i))

        # Both paths only produce offsets with |diff| <= M // 2.
        h[i] = diff
        t_prime[i] = (t[i] + diff) % q

    return t_prime, h


# ---------------------------------------------------------------------------
# 7. Recover payload from modified t' (Algorithm 2, "Secret Key Recovery")
# ---------------------------------------------------------------------------

def recover_payload_from_tprime(t_prime: List[int],
                                c_bits: int,
                                M: int,
                                b_bits: int,
                                num_used: int,
                                out_ct_len: int) -> bytes:
    """
    Rebuild the embedded payload bytes from the coefficients of t'.

    For each of the first num_used coefficients the attacker takes

        p_rec[i] = t'[i] mod M

    and reads its c_bits low bits, least-significant first.  The bits are
    concatenated in coefficient order, cut to b_bits, and packed
    little-endian into exactly out_ct_len bytes (missing bits are zero).
    This is the inverse of the mapping used when the payload was encoded.
    """
    residues = [t_prime[idx] % M for idx in range(num_used)]
    stream = [
        (residue >> shift) & 1
        for residue in residues
        for shift in range(c_bits)
    ]
    stream = stream[:b_bits]

    # Little-endian packing: bit position 8*a + j goes to bit j of byte a.
    packed = bytearray(out_ct_len)
    for pos in range(min(len(stream), 8 * out_ct_len)):
        if stream[pos]:
            packed[pos // 8] |= 1 << (pos % 8)
    return bytes(packed)


# ---------------------------------------------------------------------------
# 8. Simulated ECDH on curve K-409 (Section 3.5.1)
# ---------------------------------------------------------------------------

class SimulatedK409:
    """
    Placeholder for the K-409 ECDH backdoor encryption (no real ECC).

    The demo only needs an Encbd/Decbd pair with the right ciphertext size
    and a 32-byte seed_B, so it uses

        ct_bd  = 104 random bytes
        seed_B = first 32 bytes of SHA-256(ct_bd)

    which lets the attacker recompute seed_B from ct_bd alone.
    """

    @staticmethod
    def enc() -> Tuple[bytes, bytes]:
        """Return a fresh (ct_bd, seed_B) pair."""
        ciphertext = secrets.token_bytes(KyberParams.ECDH_CIPHERTEXT_SIZE)
        # seed_B is by construction exactly what Decbd derives from ct_bd.
        return ciphertext, SimulatedK409.dec(ciphertext)

    @staticmethod
    def dec(ct: bytes) -> bytes:
        """Return the 32-byte seed_B derived from ciphertext ct."""
        digest = hashlib.sha256(ct).digest()
        return digest[:32]


# ---------------------------------------------------------------------------
# 9. Backdoored Kyber key generation (Algorithm 2, "Backdoor-ed KeyGen C'")
# ---------------------------------------------------------------------------

def backdoored_keygen() -> Tuple[dict, dict, dict]:
    """
    Run the simplified backdoored key generation (online part of
    Algorithm 2, specialised to the K-409 instantiation).

    Simplifications relative to real Kyber:

      - t is drawn uniformly from [0, q)^{k*n} instead of being computed
        as t = A * s + e.  For the kleptographic channel only the marginal
        distribution of t matters.

      - NTT and INTT are treated as the identity.

    Steps that follow the paper / reference C code:
      * Encbd yields (ct_bd, seed_B).
      * s and e are sampled from seed_B with the CBD sampler.
      * ct_bd is encoded into p with c_bits bits per coefficient.
      * A compensation h is computed so that t' embeds p modulo
        M = 2^{c_bits}.
      * pk = (seed_A, NTT(t')) is published.

    Returns:
        (pk, sk_internal, debug):
          pk          - "seed_A", "t_prime_ntt" and the embedding
                        parameters "c_bits", "M", "b_bits", "num_used".
          sk_internal - "s_vec", "seed_B", "ct_bd" (ground truth for the
                        demo's comparison).
          debug       - "t", "t_prime", "h".
    """
    rank = KyberParams.k

    # Step 1: the adversary's public encryption Encbd, run on the device.
    ct_bd, seed_B = SimulatedK409.enc()

    # Step 2: derive the secret s and error e from seed_B.  Each polynomial
    # gets its own one-byte domain separator: 0..k-1 for s, k..2k-1 for e.
    s_vec: List[List[int]] = [
        cbd_fips203(seed_B + struct.pack("B", idx)) for idx in range(rank)
    ]
    # e is sampled as in Algorithm 2 but not consumed below, because t is
    # drawn at random instead of being computed from A, s and e.
    e_vec: List[List[int]] = [
        cbd_fips203(seed_B + struct.pack("B", idx + rank)) for idx in range(rank)
    ]

    # Step 3: stand-in for t = A*s + e, one uniform value per coefficient.
    coeff_count = rank * KyberParams.n
    t = [secrets.randbelow(KyberParams.q) for _ in range(coeff_count)]

    # Step 4: spread ct_bd over the leading coefficients.
    p, c_bits, M, b_bits, num_used = encode_payload_to_poly(ct_bd)

    # Step 5: nudge t so its residues modulo M spell out the payload.
    t_prime, h = apply_compensation(t, p, c_bits, M, num_used)

    # Step 6: assemble the key objects (NTT is the identity here).
    seed_A = secrets.token_bytes(32)

    pk = {
        "seed_A": seed_A,
        "t_prime_ntt": ntt(t_prime),
        "c_bits": c_bits,
        "M": M,
        "b_bits": b_bits,
        "num_used": num_used,
    }
    sk_internal = {"s_vec": s_vec, "seed_B": seed_B, "ct_bd": ct_bd}
    debug = {"t": t, "t_prime": t_prime, "h": h}

    return pk, sk_internal, debug


# ---------------------------------------------------------------------------
# 10. Attacker recovery given only pk and sk_bd (Algorithm 2, "Secret recovery")
# ---------------------------------------------------------------------------

def attacker_recovery(pk: dict) -> Tuple[bytes, bytes, List[List[int]]]:
    """
    Recover the victim's secret from the backdoored public key alone.

    The attacker holds the backdoor secret key sk_bd (ECDH or McEliece),
    so Decbd can be run on any payload ciphertext extracted from pk:

      1. Compute t' = INTT(t_hat').
      2. Read p[i] = t'[i] mod M.
      3. Reassemble ct_bd from p.
      4. Run Decbd(ct_bd) to obtain seed_B.
      5. Re-derive s from seed_B with the CBD sampler.

    Returns:
        (recovered_ct, recovered_seed_B, s_vec_rec).
    """
    # Steps 1-3: pull the payload ciphertext back out of t'.
    recovered_ct = recover_payload_from_tprime(
        t_prime=intt(pk["t_prime_ntt"]),
        c_bits=pk["c_bits"],
        M=pk["M"],
        b_bits=pk["b_bits"],
        num_used=pk["num_used"],
        out_ct_len=KyberParams.ECDH_CIPHERTEXT_SIZE,
    )

    # Step 4: Decbd turns the ciphertext into seed_B.
    recovered_seed_B = SimulatedK409.dec(recovered_ct)

    # Step 5: same per-polynomial domain separator as key generation.
    s_vec_rec: List[List[int]] = [
        cbd_fips203(recovered_seed_B + struct.pack("B", idx))
        for idx in range(KyberParams.k)
    ]

    return recovered_ct, recovered_seed_B, s_vec_rec


# ---------------------------------------------------------------------------
# 11. End-to-end demo
# ---------------------------------------------------------------------------

def main() -> None:
    """
    Run the end-to-end demo and print a report.

    Generates a backdoored key pair, runs the attacker's recovery on the
    public key only, and prints whether the recovered ciphertext, seed_B
    and secret polynomials match the values used during key generation.
    """
    print("=== Attacker installs backdoor (offline) ===")
    print("    [Modeled implicitly via SimulatedK409 with fixed Decbd]")

    print("\n=== Victim runs backdoored Kyber KeyGen (C') ===")
    pk, sk_internal, debug = backdoored_keygen()

    print("    Public key contains seed_A and modified t' with embedded payload.")
    print("    c_bits       =", pk["c_bits"])
    carriers = pk["num_used"]
    print("    num_used     =", carriers, "coefficients carry payload")
    # Only the payload-carrying coefficients receive a compensation.
    print("    max |h[i]|   =", max(abs(x) for x in debug["h"][:carriers]))
    print("    allowed |h|  =", (pk["M"] // 2))

    print("\n=== Attacker recovers secret from public key ===")
    rec_ct, rec_seed_B, s_vec_rec = attacker_recovery(pk)

    ct_match = rec_ct == sk_internal["ct_bd"]
    seed_match = rec_seed_B == sk_internal["seed_B"]

    print("Ciphertext match: ", ct_match)
    print("seed_B match:     ", seed_match)

    # The secret matches only if every one of the k polynomials agrees.
    s_true = sk_internal["s_vec"]
    all_match = all(
        s_true[idx] == s_vec_rec[idx] for idx in range(KyberParams.k)
    )

    print("Secret s match:   ", all_match)

    print("\nFirst 20 coeffs of s[0] (true):     ", s_true[0][:20])
    print("First 20 coeffs of s[0] (recovered):", s_vec_rec[0][:20])

    print("\nMATCH:", all_match and ct_match and seed_match)


# Run the end-to-end demo when this code is executed directly (script or
# notebook cell), but not when it is imported as a module.
if __name__ == "__main__":
    main()


=== Backdooring Kyber768 (Ravi et al. GLSVLSI 2024) ===
=== Attacker installs backdoor (offline) ===
    [Modeled implicitly via SimulatedK409 with fixed Decbd]

=== Victim runs backdoored Kyber KeyGen (C') ===
    Public key contains seed_A and modified t' with embedded payload.
    c_bits       = 2
    num_used     = 416 coefficients carry payload
    max |h[i]|   = 2
    allowed |h|  = 2

=== Attacker recovers secret from public key ===
Ciphertext match:  True
seed_B match:      True
Secret s match:    True

First 20 coeffs of s[0] (true):      [-2, 2, 0, -1, 2, 0, 0, 0, 2, 0, 1, 0, 1, 1, -1, 0, -1, 0, -1, 1]
First 20 coeffs of s[0] (recovered): [-2, 2, 0, -1, 2, 0, 0, 0, 2, 0, 1, 0, 1, 1, -1, 0, -1, 0, -1, 1]

MATCH: True
