<a href="https://colab.research.google.com/github/owlmt/PQC/blob/main/mlkem_backdoor_expriment2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://link.springer.com/chapter/10.1007/978-3-031-82852-2_11

In [2]:
!pip install numpy scipy



In [19]:
import numpy as np
import hashlib
import secrets
import time
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any


# ==============================================
# 1. Kyber parameters
# ==============================================

@dataclass
class KyberParams:
    n: int = 256        # polynomial degree
    q: int = 3329       # modulus
    k: int = 3          # Kyber768
    eta: int = 2

    @property
    def q_half(self) -> int:
        # centered range is (-q/2, q/2]
        return self.q // 2  # 1664

    @property
    def border_thresh(self) -> int:
        # (q - 3)/2 = 1663
        return (self.q - 3) // 2

    def total_coefficients(self) -> int:
        return self.k * self.n  # 768


# ==============================================
# 2. Central binomial and conditional distributions
# ==============================================

class CentralBinomialDistribution:
    @staticmethod
    def sample_B2_from_rng(rng: np.random.RandomState) -> int:
        """
        Sample from B_2 using a given RNG:
        B2 = (a1 + a2) - (b1 + b2), a_i, b_i in {0, 1}.
        """
        a1, a2, b1, b2 = rng.randint(0, 2, size=4)
        return int((a1 + a2) - (b1 + b2))

    @staticmethod
    def sample_conditional(target_lsb: int) -> int:
        """
        Conditional error sampling as in the paper:

        D0 (LSB(e) = 0): {-2: 1/8, 0: 3/4, 2: 1/8}
        D1 (LSB(e) = 1): {-1: 1/2, 1: 1/2}
        """
        if target_lsb == 0:
            r = secrets.randbits(3)
            if r == 0:
                return -2
            elif 1 <= r <= 6:
                return 0
            else:
                return 2
        else:
            return -1 if secrets.randbits(1) == 0 else 1

    @staticmethod
    def verify_distribution(num_samples: int = 100000) -> None:
        """
        Quick sanity check that unconditional B2 and conditional
        D0/D1 behave as expected.
        """
        from collections import Counter

        rng = np.random.RandomState(42)

        counts = Counter()
        for _ in range(num_samples):
            v = CentralBinomialDistribution.sample_B2_from_rng(rng)
            counts[v] += 1

        print("B2 distribution (empirical):")
        total = sum(counts.values())
        for v in [-2, -1, 0, 1, 2]:
            print(f"  {v:2d}: {counts[v] / total:.4f}")

        # check conditional D0
        counts_d0 = Counter()
        for _ in range(num_samples):
            v = CentralBinomialDistribution.sample_conditional(0)
            counts_d0[v] += 1
        print("\nD0 distribution (empirical):")
        total = sum(counts_d0.values())
        for v in [-2, 0, 2]:
            print(f"  {v:2d}: {counts_d0[v] / total:.4f}")

        # check conditional D1
        counts_d1 = Counter()
        for _ in range(num_samples):
            v = CentralBinomialDistribution.sample_conditional(1)
            counts_d1[v] += 1
        print("\nD1 distribution (empirical):")
        total = sum(counts_d1.values())
        for v in [-1, 1]:
            print(f"  {v:2d}: {counts_d1[v] / total:.4f}")


# ==============================================
# 3. Toy Classic McEliece with real decap
# ==============================================

@dataclass
class MockMcElieceSecret:
    table: Dict[tuple, bytes]


@dataclass
class MockMcEliecePublic:
    dummy: bytes = b"mc_pk"


class MockClassicMcEliece:
    """
    Toy Classic McEliece that is invertible:

    - keygen: returns pk, sk where sk holds a mapping C_bits -> K
    - encap:  generates a fresh K and C_bits, stores mapping in sk.table
    - decap:  looks up K from sk.table using C_bits as key
    """

    def __init__(self, ciphertext_bits: int = 768):
        self.ciphertext_bits = ciphertext_bits

    def keygen(self) -> Tuple[MockMcEliecePublic, MockMcElieceSecret]:
        pk = MockMcEliecePublic()
        sk = MockMcElieceSecret(table={})
        return pk, sk

    def encap(
        self,
        pk: MockMcEliecePublic,
        sk: MockMcElieceSecret,
    ) -> Tuple[bytes, List[int]]:
        # 256-bit session key
        K = secrets.token_bytes(32)

        # derive 768 pseudorandom bits from SHA3-512(K)
        h = hashlib.sha3_512(K).digest()
        bits: List[int] = []
        idx = 0
        while len(bits) < self.ciphertext_bits:
            b = h[idx % len(h)]
            for i in range(8):
                bits.append((b >> i) & 1)
            idx += 1
        bits = bits[: self.ciphertext_bits]

        # store mapping for decap
        sk.table[tuple(bits)] = K
        return K, bits

    def decap(
        self,
        sk: MockMcElieceSecret,
        C_bits: List[int],
    ) -> Optional[bytes]:
        return sk.table.get(tuple(C_bits))


# ==============================================
# 4. Backdoored Kyber key generation (KeyGen*)
# ==============================================

class BackdooredKyber:
    """
    Python model of Algorithm 2 (KeyGen*).
    """

    def __init__(self, params: KyberParams = None):
        self.params = params or KyberParams()
        # For Kyber768 we have 768 coefficients and use 768-bit McEliece
        self.mce = MockClassicMcEliece(ciphertext_bits=self.params.total_coefficients())

    def sample_matrix_A(self, seed: bytes) -> np.ndarray:
        """
        Deterministic sampling of A from pk.seed.

        For the purposes of this lab we do not implement the exact
        Kyber XOF-based sampler but a seeded PRNG instead.
        """
        seed_int = int.from_bytes(hashlib.sha256(seed).digest()[:4], "big")
        rng = np.random.RandomState(seed_int)
        A = rng.randint(
            0,
            self.params.q,
            size=(self.params.k, self.params.k, self.params.n),
            dtype=np.int32,
        )
        return A

    def sample_secret_vector(self, sk_seed: bytes) -> np.ndarray:
        """
        Sample s deterministically from sk.seed in the B2 distribution.
        """
        seed_int = int.from_bytes(hashlib.sha256(sk_seed).digest()[:4], "big")
        rng = np.random.RandomState(seed_int)
        s = np.zeros((self.params.k, self.params.n), dtype=np.int32)
        for i in range(self.params.k):
            for j in range(self.params.n):
                s[i, j] = CentralBinomialDistribution.sample_B2_from_rng(rng)
        return s

    def compute_As_mod(self, A: np.ndarray, s: np.ndarray) -> np.ndarray:
        """
        Naive cyclic polynomial multiplication for As = A * s in R_q,
        with coefficients reduced mod q into [0, q-1].
        """
        k = self.params.k
        n = self.params.n
        q = self.params.q

        As = np.zeros((k, n), dtype=np.int32)

        for i in range(k):
            for j in range(k):
                for a_idx in range(n):
                    a_coeff = A[i, j, a_idx]
                    if a_coeff == 0:
                        continue
                    for s_idx in range(n):
                        val = s[j, s_idx]
                        if val == 0:
                            continue
                        pos = (a_idx + s_idx) % n
                        As[i, pos] = (As[i, pos] + a_coeff * val) % q

        return As

    @staticmethod
    def mod_center(x: int, q: int, q_half: int) -> int:
        """
        Centered representation mod ±q in the range [-q/2, q/2].
        """
        r = x % q
        if r > q_half:
            r -= q
        return r

    def keygen_backdoored(self) -> Dict[str, Any]:
        """
        Backdoored Kyber key generation (KeyGen*).

        Steps:
          1. Generate McEliece pk, sk and encapsulate to (K, C_bits)
          2. Derive (sk.seed, pk.seed) = G(K)
          3. Sample A from pk.seed, s from sk.seed
          4. Compute As, recenter As to mod± q
          5. For each coefficient, choose e from D0 or D1 so that LSB(t) = C_bit
             except when |t| is in the border region.
        """
        q = self.params.q
        q_half = self.params.q_half
        border_thresh = self.params.border_thresh

        # 1. McEliece KEM
        mc_pk, mc_sk = self.mce.keygen()
        K, C_bits = self.mce.encap(mc_pk, mc_sk)
        assert len(C_bits) == self.params.total_coefficients()

        # 2. Derive seeds from d = K
        h = hashlib.sha3_512(K).digest()
        sk_seed = h[:32]
        pk_seed = h[32:]

        # 3. Sample A and s
        A = self.sample_matrix_A(pk_seed)
        s = self.sample_secret_vector(sk_seed)

        # 4. Compute As mod q, then recenter to mod± q
        As_mod = self.compute_As_mod(A, s)
        As = np.zeros_like(As_mod, dtype=np.int32)
        for i in range(self.params.k):
            for j in range(self.params.n):
                As[i, j] = self.mod_center(int(As_mod[i, j]), q, q_half)

        e = np.zeros_like(As, dtype=np.int32)
        t = np.zeros_like(As, dtype=np.int32)

        border_cases = 0
        successful = 0

        bit_idx = 0
        for i in range(self.params.k):
            for j in range(self.params.n):
                As_ij = As[i, j]
                C_bit = C_bits[bit_idx]
                LSB_As = As_ij & 1

                # LSB(t) = LSB(As) XOR LSB(e) must equal C_bit
                required_e_lsb = LSB_As ^ C_bit
                e_ij = CentralBinomialDistribution.sample_conditional(required_e_lsb)
                e[i, j] = e_ij

                t_raw = As_ij + e_ij
                t_ij = self.mod_center(t_raw, q, q_half)
                t[i, j] = t_ij

                if abs(t_ij) >= border_thresh:
                    border_cases += 1
                else:
                    if (t_ij & 1) == C_bit:
                        successful += 1

                bit_idx += 1

        used_bits = bit_idx
        non_border = used_bits - border_cases
        accuracy = successful / non_border if non_border > 0 else 0.0

        return {
            "pk": (t, pk_seed),
            "sk": s,
            "A": A,
            "e": e,
            "d": K,
            "mc_sk": mc_sk,
            "C_bits": C_bits,
            "border_cases": border_cases,
            "accuracy": accuracy,
            "sk_seed": sk_seed,
        }


# ==============================================
# 5. Backdoor key recovery (KeyRec*)
# ==============================================

class BackdoorRecovery:
    """
    Python model of Algorithm 3 (KeyRec*).
    """

    def __init__(self, params: KyberParams = None):
        self.params = params or KyberParams()
        self.mce = MockClassicMcEliece(ciphertext_bits=self.params.total_coefficients())

    def extract_ciphertext_bits(
        self, t: np.ndarray
    ) -> Tuple[List[Optional[int]], List[int]]:
        """
        Extract C bits and mark border indices:

        - If |t_i| >= (q - 3)/2, we record None and index i
        - Otherwise we record LSB(t_i)
        """
        border_thresh = self.params.border_thresh
        flat = t.flatten()
        bits: List[Optional[int]] = []
        borders: List[int] = []

        for idx, val in enumerate(flat):
            v = int(val)
            if abs(v) >= border_thresh:
                bits.append(None)
                borders.append(idx)
            else:
                bits.append(v & 1)

        return bits, borders

    def key_recovery(
        self,
        pk: Tuple[np.ndarray, bytes],
        mc_sk: MockMcElieceSecret,
        max_border_attempts: int = 8,
    ) -> Optional[np.ndarray]:
        """
        Attempt to recover the Kyber secret s from pk and mc_sk.

        We try all 2^i assignments to the i border bits, with
        i capped at max_border_attempts.
        """
        t, pk_seed = pk
        extracted_bits, border_indices = self.extract_ciphertext_bits(t)

        if len(border_indices) > max_border_attempts:
            border_indices = border_indices[:max_border_attempts]

        num_borders = len(border_indices)
        num_candidates = 1 << num_borders

        for mask in range(num_candidates):
            candidate = extracted_bits.copy()

            for bit_pos, idx in enumerate(border_indices):
                bit_val = (mask >> bit_pos) & 1
                candidate[idx] = bit_val

            # Replace any remaining None by 0 (they should not appear
            # after we capped border_indices)
            C_candidate = [0 if b is None else int(b) for b in candidate]

            K_candidate = self.mce.decap(mc_sk, C_candidate)
            if K_candidate is None:
                continue

            h = hashlib.sha3_512(K_candidate).digest()
            sk_seed_cand = h[:32]
            pk_seed_cand = h[32:]

            if pk_seed_cand != pk_seed:
                continue

            # Reconstruct s deterministically from sk_seed_cand
            seed_int = int.from_bytes(
                hashlib.sha256(sk_seed_cand).digest()[:4], "big"
            )
            rng = np.random.RandomState(seed_int)
            s_rec = np.zeros((self.params.k, self.params.n), dtype=np.int32)
            for i in range(self.params.k):
                for j in range(self.params.n):
                    s_rec[i, j] = CentralBinomialDistribution.sample_B2_from_rng(rng)
            return s_rec

        return None


# ==============================================
# 6. Undetectability analysis helper (optional)
# ==============================================

class UndetectabilityAnalyzer:
    def __init__(self, params: KyberParams = None):
        self.params = params or KyberParams()
        self.backdoor = BackdooredKyber(self.params)

    def collect_statistics(self, num_keys: int = 50) -> Dict[str, np.ndarray]:
        accuracies: List[float] = []
        border_counts: List[int] = []
        errors: List[int] = []
        lsbs: List[int] = []

        for _ in range(num_keys):
            k = self.backdoor.keygen_backdoored()
            accuracies.append(k["accuracy"])
            border_counts.append(k["border_cases"])

            e_flat = k["e"].flatten()
            errors.extend(int(x) for x in e_flat)

            t_flat = k["pk"][0].flatten()
            lsbs.extend(int(v) & 1 for v in t_flat)

        return {
            "accuracies": np.array(accuracies),
            "border_counts": np.array(border_counts),
            "errors": np.array(errors),
            "lsbs": np.array(lsbs),
        }

    def analyze(self, stats: Dict[str, np.ndarray]) -> None:
        errors = stats["errors"]
        border_counts = stats["border_counts"]
        lsbs = stats["lsbs"]
        accuracies = stats["accuracies"]

        print("\n=== Error distribution (should look like B2) ===")
        for v in [-2, -1, 0, 1, 2]:
            p = float(np.mean(errors == v))
            print(f"  P(e = {v:2d}) = {p:.4f}")

        print("\n=== Border cases ===")
        avg_borders = float(np.mean(border_counts))
        expected_borders = self.params.total_coefficients() * (4 / self.params.q)
        print(f"  average border count per key = {avg_borders:.3f}")
        print(f"  rough expected count          = {expected_borders:.3f}")

        print("\n=== LSB distribution of t ===")
        p0 = float(np.mean(lsbs == 0))
        p1 = 1.0 - p0
        print(f"  P(LSB(t) = 0) = {p0:.4f}")
        print(f"  P(LSB(t) = 1) = {p1:.4f}")

        print("\n=== Embedding accuracy (non-border positions) ===")
        print(f"  mean accuracy = {float(np.mean(accuracies)):.4f}")
        print(f"  min accuracy  = {float(np.min(accuracies)):.4f}")
        print(f"  max accuracy  = {float(np.max(accuracies)):.4f}")


# ==============================================
# 7. End-to-end demonstration
# ==============================================

def demonstrate_paper_results() -> None:
    """
    Simple demonstration:

      1. Check B2 and conditional distributions
      2. Run one backdoored KeyGen*
      3. Recover the key with KeyRec*
      4. Optionally sample many keys for undetectability stats
    """
    params = KyberParams()
    backdoor = BackdooredKyber(params)
    recovery = BackdoorRecovery(params)

    print("=" * 70)
    print("Kyber-768 Backdoor (Xia-Wang-Gu style) - Python model")
    print("=" * 70)

    print("\n[1] Verifying distributions")
    CentralBinomialDistribution.verify_distribution(20000)

    print("\n[2] Generating one backdoored key")
    start = time.time()
    key = backdoor.keygen_backdoored()
    elapsed = (time.time() - start) * 1000.0

    t, pk_seed = key["pk"]
    print(f"  keygen_backdoored time : {elapsed:.2f} ms")
    print(f"  t shape                : {t.shape}")
    print(f"  s shape                : {key['sk'].shape}")
    print(f"  border cases           : {key['border_cases']}")
    print(f"  embedding accuracy     : {key['accuracy']:.4f}")

    print("\n[3] Attempting key recovery")
    s_rec = recovery.key_recovery(key["pk"], key["mc_sk"], max_border_attempts=8)
    if s_rec is None:
        print("  recovery failed")
    else:
        ok = np.array_equal(s_rec, key["sk"])
        print(f"  recovery succeeded     : {ok}")

    print("\n[4] Undetectability sampling (small sample)")
    analyzer = UndetectabilityAnalyzer(params)
    stats = analyzer.collect_statistics(num_keys=10)
    analyzer.analyze(stats)


if __name__ == "__main__":
    demonstrate_paper_results()


Kyber-768 Backdoor (Xia-Wang-Gu style) - Python model

[1] Verifying distributions
B2 distribution (empirical):
  -2: 0.0604
  -1: 0.2492
   0: 0.3797
   1: 0.2482
   2: 0.0624

D0 distribution (empirical):
  -2: 0.1278
   0: 0.7462
   2: 0.1259

D1 distribution (empirical):
  -1: 0.5020
   1: 0.4980

[2] Generating one backdoored key
  keygen_backdoored time : 216.23 ms
  t shape                : (3, 256)
  s shape                : (3, 256)
  border cases           : 1
  embedding accuracy     : 1.0000

[3] Attempting key recovery
  recovery succeeded     : True

[4] Undetectability sampling (small sample)

=== Error distribution (should look like B2) ===
  P(e = -2) = 0.0632
  P(e = -1) = 0.2483
  P(e =  0) = 0.3749
  P(e =  1) = 0.2499
  P(e =  2) = 0.0638

=== Border cases ===
  average border count per key = 1.300
  rough expected count          = 0.923

=== LSB distribution of t ===
  P(LSB(t) = 0) = 0.5056
  P(LSB(t) = 1) = 0.4944

=== Embedding accuracy (non-border positions) =