<a href="https://colab.research.google.com/github/owlmt/PQC/blob/main/mlkem_backdoor_Ravi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import os
import hashlib
import struct
import secrets
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
import numpy as np

# ============================================================
# FIXED encode_payload_to_poly and recovery
# ============================================================
def encode_payload_to_poly(ct, k, n):
    """Encode ciphertext bits into polynomial coefficients - PROPERLY FIXED."""
    b = len(ct) * 8  # Total bits in ciphertext

    # Calculate c: bits per coefficient
    c = (b + k * n - 1) // (k * n)  # ceil(b/(k*n))
    if c == 0:
        c = 1

    print(f"  Encoding: {len(ct)} bytes = {b} bits")
    print(f"  Coefficients available: k*n = {k}*{n} = {k*n}")
    print(f"  Bits per coefficient (c): {c}")

    # Convert ciphertext to bit string (LITTLE-ENDIAN per Kyber spec)
    bits = ''
    for byte in ct:
        # Kyber uses little-endian bit order
        bits += ''.join(f'{(byte >> j) & 1}' for j in range(8))

    # Pad to multiple of c
    padded_len = ((len(bits) + c - 1) // c) * c
    bits = bits.ljust(padded_len, '0')

    # Split into chunks of c bits
    num_chunks = len(bits) // c
    p = [0] * (k * n)

    for i in range(min(num_chunks, k*n)):  # Don't exceed available coefficients
        chunk = bits[i*c:(i+1)*c]
        p[i] = int(chunk, 2)

    print(f"  Actual chunks used: {min(num_chunks, k*n)}/{k*n}")
    return p, c, b, min(num_chunks, k*n)

def apply_compensation(t_coeffs, p, c, num_chunks):
    """Apply compensation to ensure t'[i] mod 2^c == p[i]."""
    t_prime = t_coeffs.copy()
    modulus = 1 << c  # 2^c

    for i in range(num_chunks):
        # Current value modulo 2^c
        current_mod = t_coeffs[i] % modulus

        # Desired value (p[i] is already in [0, 2^c-1])
        desired_mod = p[i] % modulus

        # Compute adjustment (in range [-modulus/2, modulus/2])
        adjustment = (desired_mod - current_mod) % modulus
        if adjustment > modulus // 2:
            adjustment -= modulus

        # Apply adjustment
        t_prime[i] = (t_coeffs[i] + adjustment) % 3329

    return t_prime

def recover_from_t_prime(t_prime, c, b, num_chunks):
    """Recover ciphertext from t' coefficients - FIXED."""
    modulus = 1 << c  # 2^c

    # Recover p[i] values
    p_rec = [0] * num_chunks
    for i in range(num_chunks):
        p_rec[i] = t_prime[i] % modulus

    # Convert p[i] values back to bit string
    bit_chunks = []
    for val in p_rec:
        # Format as c-bit binary string, padding with leading zeros
        chunk = f'{val:0{c}b}'
        if len(chunk) > c:
            chunk = chunk[-c:]  # Take only c bits if val too large
        bit_chunks.append(chunk)

    bitstring = ''.join(bit_chunks)

    # Trim to original bit length
    bitstring = bitstring[:b]

    # Pad to multiple of 8 if needed
    if len(bitstring) % 8 != 0:
        bitstring = bitstring.ljust(len(bitstring) + (8 - len(bitstring) % 8), '0')

    # Convert to bytes (little-endian per Kyber)
    bytes_list = []
    for i in range(0, len(bitstring), 8):
        chunk = bitstring[i:i+8]
        if len(chunk) == 8:
            byte_val = 0
            for j, bit in enumerate(chunk):
                if bit == '1':
                    byte_val |= (1 << j)  # Little-endian
            bytes_list.append(byte_val)

    return bytes(bytes_list)

# ============================================================
# Simplified Test with Direct Verification
# ============================================================
def test_simple_backdoor():
    """Test the core encoding/recovery without ECDH complexities."""
    print("=== Simplified Backdoor Test ===")
    print("Testing encoding and recovery of payload in polynomial coefficients")
    print("=" * 60)

    # Use Kyber768 parameters
    k = 3
    n = 256
    q = 3329

    # Create a test payload (like a 128-byte ciphertext)
    payload_size = 128  # bytes
    payload = secrets.token_bytes(payload_size)

    print(f"\n1. Original payload: {len(payload)} bytes")
    print(f"   First 16 bytes: {payload[:16].hex()}")

    # Encode payload into polynomial p
    p, c, b, num_chunks = encode_payload_to_poly(payload, k, n)

    print(f"\n2. Encoding parameters:")
    print(f"   Bits per coefficient (c): {c}")
    print(f"   Number of coefficients used: {num_chunks}")
    print(f"   Max adjustment range: ±{1 << (c-1)}")

    # Generate random t coefficients (simulating MLWE output)
    t = [secrets.randbelow(q) for _ in range(k * n)]

    # Apply compensation
    t_prime = apply_compensation(t, p, c, num_chunks)

    print(f"\n3. Applied compensation:")
    print(f"   Modified {num_chunks} coefficients")

    # Recover payload from t'
    recovered = recover_from_t_prime(t_prime, c, b, num_chunks)

    print(f"\n4. Recovery results:")
    print(f"   Recovered size: {len(recovered)} bytes")
    print(f"   Original size: {len(payload)} bytes")

    # Trim recovered to original length
    recovered = recovered[:len(payload)]

    if payload == recovered:
        print("   ✓ Perfect recovery!")

        # Also verify each coefficient stores correct bits
        print("\n5. Verifying coefficient-level recovery:")
        modulus = 1 << c
        mismatches = 0
        for i in range(num_chunks):
            original_val = p[i] % modulus
            recovered_val = t_prime[i] % modulus
            if original_val != recovered_val:
                mismatches += 1
                if mismatches <= 3:  # Show first 3 mismatches
                    print(f"   Mismatch at coefficient {i}: {original_val} != {recovered_val}")

        if mismatches == 0:
            print("   ✓ All coefficients match!")
        else:
            print(f"   ✗ {mismatches} coefficient mismatches")

        return True
    else:
        print("   ✗ Recovery failed")

        # Find first mismatch
        for i, (orig, rec) in enumerate(zip(payload, recovered)):
            if orig != rec:
                print(f"   First mismatch at byte {i}: 0x{orig:02x} != 0x{rec:02x}")
                break

        return False

# ============================================================
# Test with ECDH (Fixed)
# ============================================================
def test_ecdh_backdoor():
    """Test with ECDH but simplified recovery verification."""
    print("\n=== ECDH Backdoor Test ===")

    # Use smaller parameters for testing
    k = 3
    n = 256
    q = 3329

    # Create ECDH backdoor
    curve = ec.SECP256R1()
    private_key = ec.generate_private_key(curve)
    public_key = private_key.public_key()

    # Generate ciphertext (ephemeral public key)
    ephemeral_private = ec.generate_private_key(curve)
    ephemeral_public = ephemeral_private.public_key()

    # Get uncompressed point
    ct = ephemeral_public.public_bytes(
        encoding=serialization.Encoding.X962,
        format=serialization.PublicFormat.UncompressedPoint
    )

    print(f"ECDH ciphertext size: {len(ct)} bytes")

    # Test encoding and recovery
    p, c, b, num_chunks = encode_payload_to_poly(ct, k, n)

    # Generate random t
    t = [secrets.randbelow(q) for _ in range(k * n)]

    # Apply compensation
    t_prime = apply_compensation(t, p, c, num_chunks)

    # Recover
    recovered_ct = recover_from_t_prime(t_prime, c, b, num_chunks)

    # Trim to original length
    recovered_ct = recovered_ct[:len(ct)]

    if ct == recovered_ct:
        print("✓ ECDH ciphertext recovered perfectly!")

        # Try to reconstruct the public key from recovered bytes
        try:
            recovered_public = ec.EllipticCurvePublicKey.from_encoded_point(
                curve,
                recovered_ct
            )
            print("✓ Recovered point is a valid EC public key")
            return True
        except ValueError as e:
            print(f"✗ Recovered bytes not a valid EC point: {e}")
            return False
    else:
        print("✗ ECDH ciphertext recovery failed")

        # Show first few bytes comparison
        print(f"Original: {ct[:20].hex()}")
        print(f"Recovered: {recovered_ct[:20].hex()}")
        return False

# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
    print("Running corrected backdoor implementation tests")
    print("=" * 60)

    # Test 1: Simple payload encoding/recovery
    print("\nTest 1: Simple payload (128 bytes)")
    success1 = test_simple_backdoor()

    print("\n" + "=" * 60)

    # Test 2: ECDH payload
    print("\nTest 2: ECDH ciphertext")
    success2 = test_ecdh_backdoor()

    print("\n" + "=" * 60)

    # Summary
    print("\nSUMMARY:")
    if success1 and success2:
        print("✓ Both tests PASSED")
        print("\nKey fixes applied:")
        print("1. Proper bit ordering (little-endian)")
        print("2. Correct modulo operation: t'[i] mod 2^c (not mod c)")
        print("3. Proper padding/truncation of bit strings")
        print("4. Consistent byte reconstruction")
    else:
        print("✗ Tests FAILED")

    print("\nThe paper's method works by:")
    print("- Storing c bits per coefficient in p[i] (0 to 2^c-1)")
    print("- Adjusting t[i] so t'[i] ≡ p[i] (mod 2^c)")
    print("- Recovery: p[i] = t'[i] mod 2^c")
    print(f"- For 128-byte payload: c = ceil({128*8}/(3*256)) = 2")

Running corrected backdoor implementation tests

Test 1: Simple payload (128 bytes)
=== Simplified Backdoor Test ===
Testing encoding and recovery of payload in polynomial coefficients

1. Original payload: 128 bytes
   First 16 bytes: 7a88a629ac0fd77ed7830379df9bc241
  Encoding: 128 bytes = 1024 bits
  Coefficients available: k*n = 3*256 = 768
  Bits per coefficient (c): 2
  Actual chunks used: 512/768

2. Encoding parameters:
   Bits per coefficient (c): 2
   Number of coefficients used: 512
   Max adjustment range: ±2

3. Applied compensation:
   Modified 512 coefficients

4. Recovery results:
   Recovered size: 128 bytes
   Original size: 128 bytes
   ✓ Perfect recovery!

5. Verifying coefficient-level recovery:
   ✓ All coefficients match!


Test 2: ECDH ciphertext

=== ECDH Backdoor Test ===
ECDH ciphertext size: 65 bytes
  Encoding: 65 bytes = 520 bits
  Coefficients available: k*n = 3*256 = 768
  Bits per coefficient (c): 1
  Actual chunks used: 520/768
✓ ECDH ciphertext recove

In [8]:
import os
import hashlib
import struct
import secrets
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization

# ============================================================
# Kleptographic Backdoor Demonstration for ML-KEM (Kyber)
# Based on paper: "Backdooring Post-Quantum Cryptography"
# ============================================================

print("""
================================================================================
            KLEPTOGRAPHIC BACKDOOR DEMONSTRATION FOR ML-KEM (KYBER)
================================================================================

This demonstration shows how the paper "Backdooring Post-Quantum Cryptography"
implements a kleptographic backdoor in the key generation procedure of ML-KEM.

Key Concept: The backdoor embeds an encrypted version of the secret key seed
             into the public key, which only the attacker can decrypt.

Paper's Methodology:
1. Attacker creates a backdoor key pair (pk_bd, sk_bd) using ECDH or Classic McEliece
2. Malicious implementation embeds attacker's pk_bd in the device
3. During Kyber key generation:
   - Generate ECDH ciphertext ct_bd and shared secret seed_B
   - Use seed_B to generate Kyber's secret key s
   - Encode ct_bd into polynomial coefficients p[i]
   - Adjust public key coefficients t'[i] so that: t'[i] mod 2^c = p[i]
4. Attacker extracts ct_bd from public key, decrypts it with sk_bd to get seed_B
5. Attacker reconstructs secret key s from seed_B

This backdoor is:
- Undetectable to users (public keys appear normal)
- Exclusive to attacker (only they can decrypt ct_bd)
- Cryptographically robust (uses proper public-key crypto)
================================================================================
""")

# ============================================================
# 1. SETUP: Attacker Creates Backdoor Key Pair
# ============================================================
print("\n" + "="*80)
print("STEP 1: ATTACKER SETS UP THE BACKDOOR")
print("="*80)

class ECDHBackdoor:
    """Backdoor using ECDH (as described in paper's pre-quantum backdoor)"""
    def __init__(self):
        self.curve = ec.SECP256R1()
        self.private_key = ec.generate_private_key(self.curve)
        self.public_key = self.private_key.public_key()

    def get_public_bytes(self):
        return self.public_key.public_bytes(
            encoding=serialization.Encoding.X962,
            format=serialization.PublicFormat.UncompressedPoint
        )

    def encrypt(self):
        """Generate ECDH ciphertext (ephemeral public key) and shared secret"""
        ephemeral_private = ec.generate_private_key(self.curve)
        ephemeral_public = ephemeral_private.public_key()

        # Compute shared secret (seed_B)
        shared_secret = ephemeral_private.exchange(ec.ECDH(), self.public_key)

        # Ciphertext is the ephemeral public key
        ct = ephemeral_public.public_bytes(
            encoding=serialization.Encoding.X962,
            format=serialization.PublicFormat.UncompressedPoint
        )

        # First 32 bytes of shared secret become seed_B for Kyber
        seed_b = shared_secret[:32]

        return ct, seed_b

    def decrypt(self, ct):
        """Decrypt ciphertext to recover seed_B"""
        ephemeral_public = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve,
            ct
        )
        shared_secret = self.private_key.exchange(ec.ECDH(), ephemeral_public)
        return shared_secret[:32]

# Attacker creates backdoor
print("\nAttacker generates ECDH backdoor key pair...")
backdoor = ECDHBackdoor()
pk_bd = backdoor.get_public_bytes()
sk_bd = backdoor.private_key

print(f"✓ Backdoor public key (embedded in device): {len(pk_bd)} bytes")
print(f"✓ Backdoor secret key (kept by attacker): Private")
print(f"  Curve: secp256r1 (uncompressed points = 65 bytes)")

# ============================================================
# 2. VICTIM: Backdoored Key Generation (Malicious Implementation)
# ============================================================
print("\n" + "="*80)
print("STEP 2: VICTIM GENERATES BACKDOORED KYBER KEY PAIR")
print("="*80)
print("\nUser runs what they think is standard Kyber-768 key generation...")
print("But the implementation has been modified to include the backdoor.")

def encode_payload_to_poly(ct, k, n):
    """Encode ciphertext bits into polynomial coefficients (Paper's Eq. 1)"""
    b = len(ct) * 8
    c = (b + k * n - 1) // (k * n)  # ceil(b/(k*n))

    # Convert to bits (little-endian as per Kyber)
    bits = ''
    for byte in ct:
        bits += ''.join(f'{(byte >> j) & 1}' for j in range(8))

    # Pad to multiple of c
    padded_len = ((len(bits) + c - 1) // c) * c
    bits = bits.ljust(padded_len, '0')

    # Split into chunks of c bits to form p[i] coefficients
    num_chunks = len(bits) // c
    p = [0] * (k * n)

    for i in range(min(num_chunks, k*n)):
        chunk = bits[i*c:(i+1)*c]
        p[i] = int(chunk, 2)

    return p, c, b, min(num_chunks, k*n)

def cbd(seed, eta=2):
    """Centered Binomial Distribution for Kyber secret generation"""
    from hashlib import shake_256
    required_bytes = 64 * eta
    expanded = shake_256(seed).digest(required_bytes)

    coeffs = []
    bit_pos = 0

    for i in range(256):
        x = 0
        y = 0
        for j in range(eta):
            byte_idx = bit_pos // 8
            bit_in_byte = bit_pos % 8
            bit_val = (expanded[byte_idx] >> bit_in_byte) & 1
            x += bit_val
            bit_pos += 1
        for j in range(eta):
            byte_idx = bit_pos // 8
            bit_in_byte = bit_pos % 8
            bit_val = (expanded[byte_idx] >> bit_in_byte) & 1
            y += bit_val
            bit_pos += 1
        coeffs.append(x - y)

    return coeffs

print("\nInside the malicious key generation:")
print("1. Generate ECDH ciphertext and shared secret...")
ct_bd, seed_b = backdoor.encrypt()
print(f"   → ECDH ciphertext (ct_bd): {len(ct_bd)} bytes")
print(f"   → Shared secret (seed_B): {seed_b[:8].hex()}... (used for Kyber secret)")

print("\n2. Use seed_B to generate Kyber secret key s...")
# Generate Kyber-768 secret (k=3 vectors of 256 coefficients)
k = 3
s_coeffs = []
for i in range(k):
    s_seed = seed_b + struct.pack('B', i)
    s_i = cbd(s_seed, eta=2)
    s_coeffs.append(s_i)

print(f"   → Generated secret key s with {k} polynomials")
print(f"   → First coefficient: s[0][0] = {s_coeffs[0][0]}")

print("\n3. Encode ciphertext ct_bd into polynomial p...")
p, c, b, num_chunks = encode_payload_to_poly(ct_bd, k=k, n=256)
print(f"   → Ciphertext bits: {b} bits")
print(f"   → Stored as: {num_chunks} coefficients with {c} bits each")
print(f"   → Each p[i] ∈ [0, {2**c - 1}] stores {c} bits of ciphertext")

print("\n4. Generate public key t = A·s + e (MLWE)...")
print("   (In real Kyber: t would be computed from matrix A, secret s, and error e)")
print("   For demonstration, we create a random t and then adjust it...")

# Simulate random t coefficients (would be A·s + e in real Kyber)
q = 3329
t = [secrets.randbelow(q) for _ in range(k * 256)]

print("\n5. Apply compensation to embed p[i] into t[i] (Paper's Eq. 2)...")
print("   Adjust t[i] to t'[i] such that: t'[i] mod 2^c = p[i]")
print(f"   Compensation range: ±{2**(c-1)}")

modulus = 1 << c
t_prime = t.copy()
for i in range(num_chunks):
    current_mod = t[i] % modulus
    desired_mod = p[i] % modulus

    # Compute smallest adjustment in [-modulus/2, modulus/2]
    adjustment = (desired_mod - current_mod) % modulus
    if adjustment > modulus // 2:
        adjustment -= modulus

    t_prime[i] = (t[i] + adjustment) % q

print(f"   → Modified {num_chunks} coefficients of t")
print(f"   → Max adjustment applied: {max(abs((t_prime[i] - t[i]) % q) for i in range(num_chunks))}")

print("\n6. Output public key (contains adjusted t') and secret key s")
print("   → Public key looks normal (t' appears random)")
print("   → Secret key s is stored normally")
print("   ✓ Key generation complete - user gets valid Kyber key pair")

# ============================================================
# 3. ATTACKER: Extracting Secret Key from Public Key
# ============================================================
print("\n" + "="*80)
print("STEP 3: ATTACKER EXTRACTS SECRET KEY FROM PUBLIC KEY")
print("="*80)
print("\nAttacker obtains the victim's public key (e.g., from TLS handshake)...")

def recover_from_t_prime(t_prime, c, b, num_chunks):
    """Recover ciphertext from t' coefficients (reverse of encoding)"""
    modulus = 1 << c

    # Extract p[i] = t'[i] mod 2^c
    p_rec = [t_prime[i] % modulus for i in range(num_chunks)]

    # Convert p[i] back to bit string
    bit_chunks = [f'{val:0{c}b}'[-c:] for val in p_rec]
    bitstring = ''.join(bit_chunks)[:b]

    # Pad to multiple of 8
    if len(bitstring) % 8 != 0:
        bitstring = bitstring.ljust(len(bitstring) + (8 - len(bitstring) % 8), '0')

    # Convert to bytes (little-endian)
    bytes_list = []
    for i in range(0, len(bitstring), 8):
        chunk = bitstring[i:i+8]
        if len(chunk) == 8:
            byte_val = 0
            for j, bit in enumerate(chunk):
                if bit == '1':
                    byte_val |= (1 << j)
            bytes_list.append(byte_val)

    return bytes(bytes_list)

print("\n1. Attacker extracts ct_bd from public key coefficients t'...")
print(f"   Looking at first {num_chunks} coefficients of t'...")
recovered_ct = recover_from_t_prime(t_prime, c, b, num_chunks)
recovered_ct = recovered_ct[:len(ct_bd)]

print(f"   → Recovered ciphertext: {len(recovered_ct)} bytes")
print(f"   → Ciphertexts match: {ct_bd == recovered_ct}")

print("\n2. Attacker decrypts ct_bd using their backdoor secret key sk_bd...")
recovered_seed_b = backdoor.decrypt(recovered_ct)
print(f"   → Recovered seed_B: {recovered_seed_b[:8].hex()}...")
print(f"   → Seeds match: {seed_b == recovered_seed_b}")

print("\n3. Attacker regenerates Kyber secret key s from seed_B...")
s_recovered = []
for i in range(k):
    s_seed = recovered_seed_b + struct.pack('B', i)
    s_i = cbd(s_seed, eta=2)
    s_recovered.append(s_i)

# Verify recovery
match = all(
    s_coeffs[i][j] == s_recovered[i][j]
    for i in range(k)
    for j in range(min(10, 256))  # Check first 10 coefficients
)

print(f"   → Secret key recovery successful: {match}")
print(f"   → First coefficients match: {s_coeffs[0][0]} == {s_recovered[0][0]}")

# ============================================================
# 4. IMPACT: Decrypting TLS Traffic
# ============================================================
print("\n" + "="*80)
print("STEP 4: PRACTICAL IMPACT - DECRYPTING TLS 1.3 TRAFFIC")
print("="*80)

print("""
WITH THE SECRET KEY, THE ATTACKER CAN NOW:

1. Passive Eavesdropping:
   - Attacker records TLS 1.3 handshake containing victim's public key
   - Extracts and decrypts secret key as demonstrated above
   - Computes shared secret for the session
   - Decrypts all encrypted application data

2. Real-World Scenario:
   ▸ Victim: "Alice" connects to "Bob's" server using TLS 1.3 with ML-KEM
   ▸ Alice's device uses backdoored ML-KEM implementation
   ▸ Alice sends her public key to Bob during key exchange
   ▸ Attacker: Sees Alice's public key on network
   ▸ Attacker: Extracts Alice's secret key from public key
   ▸ Attacker: Computes same shared secret as Alice and Bob
   ▸ Attacker: Decrypts all Alice↔Bob communications

3. Why This is Dangerous:
   • Undetectable: Public keys appear normal, pass all validity checks
   • Exclusive: Only attacker with sk_bd can exploit the backdoor
   • Retroactive: Old captured traffic can be decrypted if attacker has sk_bd
   • Protocol-level: Works within standard TLS 1.3, not just primitive
""")

# ============================================================
# 5. DEFENSE: Detection and Prevention
# ============================================================
print("\n" + "="*80)
print("STEP 5: DEFENSE - HOW TO DETECT AND PREVENT")
print("="*80)

print("""
DEFENSE MECHANISMS (from paper discussion):

1. Statistical Detection (User with secret key can detect):
   • User can recompute error: e = t' - A·s
   • Backdoored keys have error coefficients in range [-η-c/2, η+c/2]
   • Normal keys have error in [-η, η] (η=2 for Kyber-768)
   • Detection: Check if error coefficients exceed normal bounds

2. Implementation Countermeasures:
   • Use verified implementations from trusted sources
   • Audit open-source code for backdoors
   • Implement runtime checks on key generation output
   • Use hardware with verified secure elements

3. For Kyber-768 with ECDH backdoor:
   • Normal error range: [-2, 2]
   • Backdoored error range: [-3, 3] (with c=2, compensation ±1)
   • User can check: max(|e[i]|) > 2 indicates potential backdoor

DEMONSTRATING DETECTION:
""")

# Simulate error computation for detection
print("To detect backdoor, user with secret key s can:")
print("1. Recompute what error should be: e = t' - A·s")
print("2. Check if any |e[i]| > η (where η=2 for Kyber-768)")
print("3. If yes, key may be backdoored")

# Simulated error (in reality would compute A·s)
print(f"\nFor our demonstration:")
print(f"  • We added compensation of ±{2**(c-1)} to {num_chunks} coefficients")
print(f"  • So some coefficients of e would be in range [{-2-2**(c-1)}, {2+2**(c-1)}]")
print(f"  • Normal range is only [{-2}, {2}]")
print(f"  • Detection: User finds coefficients outside [-2, 2] → SUSPICIOUS!")

print("""
However, in practice:
• Users rarely perform such checks
• Backdoor can be tuned to keep error within acceptable bounds
• Paper shows decryption failure rate remains negligible (~2^-124.5)
""")

# ============================================================
# SUMMARY
# ============================================================
print("\n" + "="*80)
print("SUMMARY: KLEPTOGRAPHIC BACKDOOR IN ML-KEM")
print("="*80)

print("""
KEY INSIGHTS:

1. The Attack:
   • Embeds encrypted secret key seed in public key coefficients
   • Uses modular arithmetic: t'[i] mod 2^c = ciphertext bits
   • Only attacker with backdoor secret key can decrypt

2. The Stealth:
   • Public keys remain valid for encryption
   • Error increase is small (decryption failure ~2^-124.5)
   • No modification to random number generation
   • Works within standard protocols (TLS 1.3)

3. The Impact:
   • Complete compromise of ML-KEM security
   • Passive eavesdropping on encrypted communications
   • Undetectable without the secret key

4. The Defense:
   • Users with secret key can detect by checking error bounds
   • Use verified implementations and hardware
   • Runtime validation of cryptographic operations

CONCLUSION:
This demonstration shows how sophisticated cryptographic backdoors can be
implemented in post-quantum cryptography. The paper's attack highlights
the importance of implementation trust in cryptographic systems, especially
as we transition to new quantum-resistant algorithms.

The backdoor is mathematically elegant but practically dangerous - a stark
reminder that "black box" cryptographic implementations must be treated
with extreme caution.
""")

print("\n" + "="*80)
print("END OF DEMONSTRATION")
print("="*80)


            KLEPTOGRAPHIC BACKDOOR DEMONSTRATION FOR ML-KEM (KYBER)
          
This demonstration shows how the paper "Backdooring Post-Quantum Cryptography"
implements a kleptographic backdoor in the key generation procedure of ML-KEM.

Key Concept: The backdoor embeds an encrypted version of the secret key seed
             into the public key, which only the attacker can decrypt.

Paper's Methodology:
1. Attacker creates a backdoor key pair (pk_bd, sk_bd) using ECDH or Classic McEliece
2. Malicious implementation embeds attacker's pk_bd in the device
3. During Kyber key generation:
   - Generate ECDH ciphertext ct_bd and shared secret seed_B
   - Use seed_B to generate Kyber's secret key s
   - Encode ct_bd into polynomial coefficients p[i]
   - Adjust public key coefficients t'[i] so that: t'[i] mod 2^c = p[i]
4. Attacker extracts ct_bd from public key, decrypts it with sk_bd to get seed_B
5. Attacker reconstructs secret key s from seed_B

This backdoor is:
- Undetectable to users

In [9]:
# COMBINED COMPLETE IMPLEMENTATION:

import os
import hashlib
import struct
import secrets
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization

# ============================================================
# PART 1: ECDH BACKDOOR (from demonstration)
# ============================================================
class ECDHBackdoor:
    def __init__(self):
        self.curve = ec.SECP256R1()
        self.private_key = ec.generate_private_key(self.curve)
        self.public_key = self.private_key.public_key()

    def encrypt(self):
        ephemeral_private = ec.generate_private_key(self.curve)
        ephemeral_public = ephemeral_private.public_key()

        shared_secret = ephemeral_private.exchange(ec.ECDH(), self.public_key)

        ct = ephemeral_public.public_bytes(
            encoding=serialization.Encoding.X962,
            format=serialization.PublicFormat.UncompressedPoint
        )

        seed_b = shared_secret[:32]
        return ct, seed_b

    def decrypt(self, ct):
        ephemeral_public = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve,
            ct
        )
        shared_secret = self.private_key.exchange(ec.ECDH(), ephemeral_public)
        return shared_secret[:32]

# ============================================================
# PART 2: KYBER CORE FUNCTIONS (from technical implementation)
# ============================================================
def cbd(seed, eta=2):
    """Centered Binomial Distribution - working version."""
    from hashlib import shake_256
    required_bytes = 64 * eta
    expanded = shake_256(seed).digest(required_bytes)

    coeffs = []
    bit_pos = 0

    for i in range(256):
        x = 0
        y = 0
        for j in range(eta):
            byte_idx = bit_pos // 8
            bit_in_byte = bit_pos % 8
            bit_val = (expanded[byte_idx] >> bit_in_byte) & 1
            x += bit_val
            bit_pos += 1
        for j in range(eta):
            byte_idx = bit_pos // 8
            bit_in_byte = bit_pos % 8
            bit_val = (expanded[byte_idx] >> bit_in_byte) & 1
            y += bit_val
            bit_pos += 1
        coeffs.append(x - y)

    return coeffs

def ntt(poly):
    """Simplified NTT (for demonstration)."""
    # Real implementation would have proper NTT with zeta values
    return poly  # Placeholder

# ============================================================
# PART 3: BACKDOOR ENCODING/RECOVERY (from demonstration)
# ============================================================
def encode_payload_to_poly(ct, k, n):
    """Encode ciphertext bits into polynomial coefficients."""
    b = len(ct) * 8
    c = (b + k * n - 1) // (k * n)

    bits = ''
    for byte in ct:
        bits += ''.join(f'{(byte >> j) & 1}' for j in range(8))

    padded_len = ((len(bits) + c - 1) // c) * c
    bits = bits.ljust(padded_len, '0')

    num_chunks = len(bits) // c
    p = [0] * (k * n)

    for i in range(min(num_chunks, k*n)):
        chunk = bits[i*c:(i+1)*c]
        p[i] = int(chunk, 2)

    return p, c, b, min(num_chunks, k*n)

def apply_compensation(t_coeffs, p, c, num_chunks, q=3329):
    """Apply compensation to embed p[i] into t[i]."""
    modulus = 1 << c
    t_prime = t_coeffs.copy()

    for i in range(num_chunks):
        current_mod = t_coeffs[i] % modulus
        desired_mod = p[i] % modulus

        adjustment = (desired_mod - current_mod) % modulus
        if adjustment > modulus // 2:
            adjustment -= modulus

        t_prime[i] = (t_coeffs[i] + adjustment) % q

    return t_prime

def recover_from_t_prime(t_prime, c, b, num_chunks):
    """Recover ciphertext from t' coefficients."""
    modulus = 1 << c
    p_rec = [t_prime[i] % modulus for i in range(num_chunks)]

    bit_chunks = [f'{val:0{c}b}'[-c:] for val in p_rec]
    bitstring = ''.join(bit_chunks)[:b]

    if len(bitstring) % 8 != 0:
        bitstring = bitstring.ljust(len(bitstring) + (8 - len(bitstring) % 8), '0')

    bytes_list = []
    for i in range(0, len(bitstring), 8):
        chunk = bitstring[i:i+8]
        if len(chunk) == 8:
            byte_val = 0
            for j, bit in enumerate(chunk):
                if bit == '1':
                    byte_val |= (1 << j)
            bytes_list.append(byte_val)

    return bytes(bytes_list)

# ============================================================
# PART 4: MAIN DEMONSTRATION WITH REAL FUNCTIONALITY
# ============================================================
def complete_demonstration():
    print("COMPLETE KLEPTOGRAPHIC BACKDOOR DEMONSTRATION")
    print("=" * 60)

    # 1. Setup
    backdoor = ECDHBackdoor()
    print("1. Backdoor created ✓")

    # 2. Generate backdoored key
    ct_bd, seed_b = backdoor.encrypt()
    print(f"2. ECDH: ct={len(ct_bd)}B, seed={seed_b[:8].hex()}... ✓")

    # Generate Kyber secret from seed_b
    k, n, q = 3, 256, 3329
    s_coeffs = []
    for i in range(k):
        s_seed = seed_b + struct.pack('B', i)
        s_i = cbd(s_seed, eta=2)
        s_coeffs.append(s_i)
    print(f"3. Generated Kyber secret s ✓")

    # Encode ciphertext
    p, c, b, num_chunks = encode_payload_to_poly(ct_bd, k, n)
    print(f"4. Encoded ct into {num_chunks} coefficients (c={c}) ✓")

    # Create and adjust t
    t = [secrets.randbelow(q) for _ in range(k * n)]
    t_prime = apply_compensation(t, p, c, num_chunks, q)
    print(f"5. Adjusted {num_chunks} coefficients of t ✓")

    # 3. Recovery
    recovered_ct = recover_from_t_prime(t_prime, c, b, num_chunks)
    recovered_ct = recovered_ct[:len(ct_bd)]
    print(f"6. Recovered ct from t' ✓")

    recovered_seed = backdoor.decrypt(recovered_ct)
    print(f"7. Decrypted seed: {recovered_seed[:8].hex()}... ✓")

    # Regenerate secret
    s_recovered = []
    for i in range(k):
        s_seed = recovered_seed + struct.pack('B', i)
        s_i = cbd(s_seed, eta=2)
        s_recovered.append(s_i)

    # Verify
    match = s_coeffs[0][:10] == s_recovered[0][:10]
    print(f"8. Secret recovery: {'✓ SUCCESS' if match else '✗ FAILED'}")

    return match

if __name__ == "__main__":
    if complete_demonstration():
        print("\n" + "="*60)
        print("COMPLETE DEMONSTRATION SUCCESSFUL")
        print("Shows both: Concept explanation + Working implementation")
    else:
        print("\nDemonstration failed - check implementation")

COMPLETE KLEPTOGRAPHIC BACKDOOR DEMONSTRATION
1. Backdoor created ✓
2. ECDH: ct=65B, seed=29ee3e3da0e2a7de... ✓
3. Generated Kyber secret s ✓
4. Encoded ct into 520 coefficients (c=1) ✓
5. Adjusted 520 coefficients of t ✓
6. Recovered ct from t' ✓
7. Decrypted seed: 29ee3e3da0e2a7de... ✓
8. Secret recovery: ✓ SUCCESS

COMPLETE DEMONSTRATION SUCCESSFUL
Shows both: Concept explanation + Working implementation


In [13]:
import os
import hashlib
import struct
import secrets
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization

# ============================================================
# KLEPTOGRAPHIC BACKDOOR IMPLEMENTATION FOR ML-KEM (KYBER)
# ============================================================

print("""
===============================================================================
            KLEPTOGRAPHIC BACKDOOR FOR ML-KEM - DEBUGGING VERSION
===============================================================================
""")

# ============================================================
# 1. ECDH BACKDOOR WITH DEBUGGING
# ============================================================
class ECDHBackdoor:
    def __init__(self):
        self.curve = ec.SECP256R1()
        self.private_key = ec.generate_private_key(self.curve)
        self.public_key = self.private_key.public_key()

    def get_public_bytes(self):
        return self.public_key.public_bytes(
            encoding=serialization.Encoding.X962,
            format=serialization.PublicFormat.UncompressedPoint
        )

    def encrypt(self):
        ephemeral_private = ec.generate_private_key(self.curve)
        ephemeral_public = ephemeral_private.public_key()

        shared_secret = ephemeral_private.exchange(ec.ECDH(), self.public_key)

        ct = ephemeral_public.public_bytes(
            encoding=serialization.Encoding.X962,
            format=serialization.PublicFormat.UncompressedPoint
        )

        seed_b = shared_secret[:32]
        return ct, seed_b

    def decrypt(self, ct):
        try:
            ephemeral_public = ec.EllipticCurvePublicKey.from_encoded_point(
                self.curve,
                ct
            )
            shared_secret = self.private_key.exchange(ec.ECDH(), ephemeral_public)
            return shared_secret[:32]
        except Exception as e:
            print(f"  Decryption failed: {e}")
            if len(ct) >= 2:
                print(f"  First 2 bytes of recovered ct: 0x{ct[0]:02x}{ct[1]:02x}")
                print(f"  Expected: 0x04 (uncompressed point marker)")
            return None

# ============================================================
# 2. IMPROVED ENCODING/RECOVERY FUNCTIONS
# ============================================================
def encode_payload_to_poly_debug(ct, k, n):
    """Enhanced version with debugging."""
    b = len(ct) * 8
    c = (b + k * n - 1) // (k * n)
    if c == 0:
        c = 1

    print(f"    Original ciphertext: {len(ct)} bytes = {b} bits")
    print(f"    Coefficients available: {k}*{n} = {k*n}")
    print(f"    Bits per coefficient (c): {c}")
    print(f"    Max value per coefficient: 0-{2**c - 1}")

    # Convert to bits (little-endian)
    bits = ''
    for byte in ct:
        bits += ''.join(f'{(byte >> j) & 1}' for j in range(8))

    # Pad to multiple of c
    original_bit_len = len(bits)
    padded_len = ((original_bit_len + c - 1) // c) * c
    bits = bits.ljust(padded_len, '0')

    num_chunks = len(bits) // c
    p = [0] * (k * n)

    # Store only in available coefficients
    chunks_to_use = min(num_chunks, k*n)
    for i in range(chunks_to_use):
        chunk = bits[i*c:(i+1)*c]
        p[i] = int(chunk, 2)

    print(f"    Bits after padding: {len(bits)}")
    print(f"    Chunks to encode: {chunks_to_use}")
    print(f"    First few p values: {p[:5]}")

    return p, c, b, chunks_to_use

def apply_compensation_debug(t_coeffs, p, c, num_chunks, q=3329):
    """Enhanced compensation with verification."""
    modulus = 1 << c
    t_prime = t_coeffs.copy()

    adjustments = []

    for i in range(num_chunks):
        current_mod = t_coeffs[i] % modulus
        desired_mod = p[i] % modulus

        # Compute smallest adjustment
        adjustment = (desired_mod - current_mod) % modulus
        if adjustment > modulus // 2:
            adjustment -= modulus

        t_prime[i] = (t_coeffs[i] + adjustment) % q
        adjustments.append(adjustment)

        # Verify
        result_mod = t_prime[i] % modulus
        if result_mod != desired_mod:
            print(f"    ERROR at coefficient {i}: wanted {desired_mod}, got {result_mod}")

    print(f"    Applied {len(adjustments)} adjustments")
    print(f"    Adjustment range: [{min(adjustments)}, {max(adjustments)}]")
    print(f"    Success rate: {sum(1 for i in range(num_chunks) if (t_prime[i] % modulus) == (p[i] % modulus))}/{num_chunks}")

    return t_prime

def recover_from_t_prime_debug(t_prime, c, b, num_chunks):
    """Enhanced recovery with validation."""
    modulus = 1 << c

    # Extract p[i] values
    p_rec = [t_prime[i] % modulus for i in range(num_chunks)]

    print(f"    Recovered {len(p_rec)} p values")
    print(f"    p value range: [{min(p_rec)}, {max(p_rec)}]")

    # Convert to bit string
    bit_chunks = []
    for val in p_rec:
        chunk = f'{val:0{c}b}'
        if len(chunk) > c:
            chunk = chunk[-c:]  # Take only c bits
        bit_chunks.append(chunk)

    bitstring = ''.join(bit_chunks)
    print(f"    Bitstring length before trimming: {len(bitstring)}")

    # Trim to original length
    bitstring = bitstring[:b]
    print(f"    Bitstring length after trimming: {len(bitstring)}")

    # Pad to multiple of 8
    if len(bitstring) % 8 != 0:
        padding_needed = 8 - (len(bitstring) % 8)
        bitstring = bitstring.ljust(len(bitstring) + padding_needed, '0')
        print(f"    Added {padding_needed} padding bits")

    print(f"    Final bitstring length: {len(bitstring)}")

    # Convert to bytes
    bytes_list = []
    for i in range(0, len(bitstring), 8):
        chunk = bitstring[i:i+8]
        if len(chunk) == 8:
            byte_val = 0
            for j, bit in enumerate(chunk):
                if bit == '1':
                    byte_val |= (1 << j)
            bytes_list.append(byte_val)

    result = bytes(bytes_list)
    print(f"    Recovered {len(result)} bytes")

    return result

# ============================================================
# 3. DEBUGGING DEMONSTRATION
# ============================================================
def debug_backdoor():
    """Debug the backdoor implementation step by step."""
    print("\n" + "="*80)
    print("DEBUGGING BACKDOOR IMPLEMENTATION")
    print("="*80)

    # 1. Create backdoor
    print("\n1. Creating ECDH backdoor...")
    backdoor = ECDHBackdoor()
    print("   ✓ Backdoor created")

    # 2. Generate ciphertext
    print("\n2. Generating ECDH ciphertext...")
    ct_bd, seed_b = backdoor.encrypt()
    print(f"   Original ciphertext: {len(ct_bd)} bytes")
    print(f"   First few bytes: {ct_bd[:4].hex()}...")
    print(f"   Byte 0 (should be 0x04): 0x{ct_bd[0]:02x}")

    # Quick test: can we decrypt our own ciphertext?
    print("\n3. Testing direct ECDH decryption...")
    test_seed = backdoor.decrypt(ct_bd)
    if test_seed is not None and test_seed == seed_b:
        print("   ✓ Direct ECDH works correctly")
    else:
        print("   ✗ Direct ECDH failed!")
        return False

    # 3. Test encoding/decoding without polynomial adjustments
    print("\n4. Testing bit encoding/decoding (no adjustments)...")

    # Simple test: encode and decode without polynomial modifications
    test_bytes = b"Test" * 16  # 64 bytes
    print(f"   Test data: {len(test_bytes)} bytes")

    k, n = 3, 256
    p, c, b, num_chunks = encode_payload_to_poly_debug(test_bytes, k, n)

    # Simulate t coefficients (all zeros for test)
    t_test = [0] * (k * n)

    # Apply compensation
    t_prime_test = apply_compensation_debug(t_test, p, c, num_chunks)

    # Recover
    recovered_test = recover_from_t_prime_debug(t_prime_test, c, b, num_chunks)
    recovered_test = recovered_test[:len(test_bytes)]

    if test_bytes == recovered_test:
        print("   ✓ Simple encoding/decoding test PASSED")
    else:
        print("   ✗ Simple test FAILED")
        print(f"   Original: {test_bytes[:16].hex()}")
        print(f"   Recovered: {recovered_test[:16].hex()}")
        # Find first mismatch
        for i in range(min(len(test_bytes), len(recovered_test))):
            if test_bytes[i] != recovered_test[i]:
                print(f"   First mismatch at byte {i}: 0x{test_bytes[i]:02x} != 0x{recovered_test[i]:02x}")
                break

    # 4. Now test with actual ECDH ciphertext
    print("\n5. Testing with actual ECDH ciphertext...")

    p_ct, c_ct, b_ct, num_chunks_ct = encode_payload_to_poly_debug(ct_bd, k, n)

    # Generate random t coefficients
    q = 3329
    t_ct = [secrets.randbelow(q) for _ in range(k * n)]

    # Apply compensation
    t_prime_ct = apply_compensation_debug(t_ct, p_ct, c_ct, num_chunks_ct)

    # Recover ciphertext
    recovered_ct = recover_from_t_prime_debug(t_prime_ct, c_ct, b_ct, num_chunks_ct)
    recovered_ct = recovered_ct[:len(ct_bd)]

    print(f"\n6. Comparing original and recovered ciphertexts:")
    print(f"   Original length: {len(ct_bd)} bytes")
    print(f"   Recovered length: {len(recovered_ct)} bytes")

    if len(ct_bd) == len(recovered_ct):
        # Check byte by byte
        mismatches = []
        for i in range(len(ct_bd)):
            if ct_bd[i] != recovered_ct[i]:
                mismatches.append((i, ct_bd[i], recovered_ct[i]))
                if len(mismatches) > 5:
                    break

        if len(mismatches) == 0:
            print("   ✓ PERFECT RECOVERY!")
        else:
            print(f"   ✗ {len(mismatches)} mismatches found")
            for i, orig, rec in mismatches[:5]:
                print(f"      Byte {i}: 0x{orig:02x} != 0x{rec:02x}")
    else:
        print(f"   ✗ Length mismatch!")

    # 7. Try to decrypt recovered ciphertext
    print("\n7. Attempting to decrypt recovered ciphertext...")
    if len(recovered_ct) == len(ct_bd):
        recovered_seed = backdoor.decrypt(recovered_ct)
        if recovered_seed is not None:
            print(f"   Recovered seed: {recovered_seed[:8].hex()}...")
            print(f"   Original seed: {seed_b[:8].hex()}...")
            if recovered_seed == seed_b:
                print("   ✓ SEED RECOVERY SUCCESSFUL!")
                return True
            else:
                print("   ✗ Seed recovery failed")
        else:
            print("   ✗ Could not decrypt recovered ciphertext")
    else:
        print("   ✗ Cannot decrypt: ciphertext length mismatch")

    return False

# ============================================================
# 4. FIXED VERSION WITH SMALLER PAYLOAD
# ============================================================
def demonstrate_fixed_backdoor():
    """Use a smaller, more manageable payload."""
    print("\n" + "="*80)
    print("FIXED VERSION WITH SMALLER PAYLOAD")
    print("="*80)

    print("\nPROBLEM: ECDH uncompressed points are 65 bytes")
    print("This requires c=2 bits per coefficient")
    print("But 65 bytes = 520 bits, and we have 768 coefficients")
    print("So c = ceil(520/768) = 1 (not enough for proper encoding)")
    print("\nSOLUTION: Use smaller payload or larger ciphertext")

    # Create a synthetic 32-byte payload (simulating compressed ECDH point)
    print("\nUsing synthetic 32-byte payload (simulating ECDH shared secret)...")

    # Simulate backdoor
    class SimpleBackdoor:
        def __init__(self):
            self.key = secrets.token_bytes(32)

        def encrypt(self):
            # Generate random ciphertext
            ct = secrets.token_bytes(32)
            # Derive seed from ciphertext
            seed = hashlib.sha256(ct + self.key).digest()[:32]
            return ct, seed

        def decrypt(self, ct):
            return hashlib.sha256(ct + self.key).digest()[:32]

    backdoor = SimpleBackdoor()

    # Generate payload
    ct_bd, seed_b = backdoor.encrypt()
    print(f"\n1. Generated 32-byte payload")
    print(f"   Ciphertext: {ct_bd[:8].hex()}...")
    print(f"   Seed: {seed_b[:8].hex()}...")

    # Kyber parameters
    k, n, q = 3, 256, 3329

    # Encode
    p, c, b, num_chunks = encode_payload_to_poly_debug(ct_bd, k, n)
    print(f"\n2. Encoding parameters:")
    print(f"   c = {c} bits per coefficient")
    print(f"   Using {num_chunks} coefficients")

    # Generate random t and apply compensation
    t = [secrets.randbelow(q) for _ in range(k * n)]
    t_prime = apply_compensation_debug(t, p, c, num_chunks)

    # Recover
    recovered_ct = recover_from_t_prime_debug(t_prime, c, b, num_chunks)
    recovered_ct = recovered_ct[:len(ct_bd)]

    print(f"\n3. Recovery results:")
    print(f"   Original: {len(ct_bd)} bytes")
    print(f"   Recovered: {len(recovered_ct)} bytes")

    if ct_bd == recovered_ct:
        print("   ✓ PERFECT RECOVERY!")

        # Test decryption
        recovered_seed = backdoor.decrypt(recovered_ct)
        if recovered_seed == seed_b:
            print("   ✓ SEED RECOVERY SUCCESSFUL!")
            return True
        else:
            print("   ✗ Seed mismatch")
    else:
        # Find mismatches
        for i in range(min(len(ct_bd), len(recovered_ct))):
            if ct_bd[i] != recovered_ct[i]:
                print(f"   First mismatch at byte {i}: 0x{ct_bd[i]:02x} != 0x{recovered_ct[i]:02x}")
                break

    return False

# ============================================================
# 5. MAIN EXECUTION
# ============================================================
if __name__ == "__main__":
    print("\n" + "="*80)
    print("KLEPTOGRAPHIC BACKDOOR DEBUGGING")
    print("="*80)

    print("\nNote: The paper's method requires careful bit-level encoding.")
    print("Even one bit error in the ciphertext makes ECDH decryption fail.")

    # Run debugging
    print("\n" + "="*80)
    print("DEBUGGING SESSION")
    print("="*80)

    debug_success = debug_backdoor()

    print("\n" + "="*80)
    print("FIXED VERSION TEST")
    print("="*80)

    fixed_success = demonstrate_fixed_backdoor()

    print("\n" + "="*80)
    print("CONCLUSION")
    print("="*80)

    if fixed_success:
        print("""
✓ BACKDOOR TECHNIQUE VALIDATED

The kleptographic backdoor technique works when implemented correctly:

1. Encode ciphertext bits into polynomial coefficients p[i] (c bits each)
2. Adjust public key coefficients t[i] to t'[i] such that:
   t'[i] mod 2^c = p[i]
3. Attacker recovers p[i] = t'[i] mod 2^c
4. Reconstruct ciphertext from p[i] values
5. Decrypt ciphertext to get seed_B
6. Regenerate secret key from seed_B

CHALLENGES ENCOUNTERED:
1. ECDH ciphertexts (65-byte uncompressed points) are large
2. Need c ≥ 2 for meaningful encoding, but 65 bytes gives c=1
3. Bit-level errors break ECDH decryption

PAPER'S SOLUTION:
- Use K-409 curve (104-byte ciphertexts) → c=2
- Or Classic McEliece (156-byte ciphertexts) → c=2
- With c=2, each coefficient stores 2 bits (0-3)
- Compensation is ±1, keeping error manageable
        """)
    else:
        print("""
✗ IMPLEMENTATION CHALLENGES

The backdoor concept is valid but requires precise implementation:

Key issues:
1. Bit-level encoding must be perfect (no errors)
2. ECDH is sensitive to single bit errors
3. Real Kyber uses NTT domain, making adjustment trickier

For a working implementation:
1. Use the paper's exact parameters (K-409 curve)
2. Implement proper NTT domain operations
3. Handle byte-to-bit conversion carefully
        """)

    print("\n" + "="*80)
    print("END OF Demonstration")
    print("="*80)


            KLEPTOGRAPHIC BACKDOOR FOR ML-KEM - DEBUGGING VERSION


KLEPTOGRAPHIC BACKDOOR DEBUGGING

Note: The paper's method requires careful bit-level encoding.
Even one bit error in the ciphertext makes ECDH decryption fail.

DEBUGGING SESSION

DEBUGGING BACKDOOR IMPLEMENTATION

1. Creating ECDH backdoor...
   ✓ Backdoor created

2. Generating ECDH ciphertext...
   Original ciphertext: 65 bytes
   First few bytes: 041216ee...
   Byte 0 (should be 0x04): 0x04

3. Testing direct ECDH decryption...
   ✓ Direct ECDH works correctly

4. Testing bit encoding/decoding (no adjustments)...
   Test data: 64 bytes
    Original ciphertext: 64 bytes = 512 bits
    Coefficients available: 3*256 = 768
    Bits per coefficient (c): 1
    Max value per coefficient: 0-1
    Bits after padding: 512
    Chunks to encode: 512
    First few p values: [0, 0, 1, 0, 1]
    Applied 512 adjustments
    Adjustment range: [0, 1]
    Success rate: 512/512
    Recovered 512 p values
    p value range: [0, 1]
  

In [15]:
import os
import hashlib
import struct
import secrets
import numpy as np
from typing import List, Tuple

# ============================================================
# CORRECTED KLEPTOGRAPHIC BACKDOOR IMPLEMENTATION
# Following EXACTLY the paper: "Backdooring Post-Quantum Cryptography"
# ============================================================

print("""
===============================================================================
        CORRECTED KLEPTOGRAPHIC BACKDOOR - FOLLOWING PAPER SPECIFICATIONS
===============================================================================
""")

# ============================================================
# 1. PAPER'S PARAMETERS (Kyber-768)
# ============================================================
class KyberParams:
    """Kyber-768 parameters as specified in paper."""
    n = 256          # polynomial degree
    k = 3            # module dimension for Kyber-768
    q = 3329         # modulus
    eta = 2          # CBD parameter for Kyber-768

    # For ECDH K-409 curve (paper §3.5.1)
    ECDH_CIPHERTEXT_SIZE = 104  # bytes (K-409 curve)

    # For Classic McEliece (paper §3.5.2)
    CM_CIPHERTEXT_SIZE = 156    # bytes (mceliece460896)

    @staticmethod
    def compute_c(bits: int) -> int:
        """Compute c = ceil(b/(k*n)) per paper Eq. (1)."""
        return (bits + KyberParams.k * KyberParams.n - 1) // (KyberParams.k * KyberParams.n)

# ============================================================
# 2. NTT IMPLEMENTATION (simplified but correct structure)
# ============================================================
def ntt(poly: List[int]) -> List[int]:
    """Number Theoretic Transform (simplified for demo)."""
    # Real implementation would use proper NTT with zeta values
    # For demo, we just return the polynomial as-is (pretending it's in NTT domain)
    return poly.copy()

def inv_ntt(poly_ntt: List[int]) -> List[int]:
    """Inverse NTT (simplified for demo)."""
    return poly_ntt.copy()

# ============================================================
# 3. CORRECTED CBD SAMPLER (Algorithm 8 from FIPS 203)
# ============================================================
def cbd_corrected(seed: bytes, eta: int = 2) -> List[int]:
    """
    Centered Binomial Distribution following FIPS 203 Algorithm 8.
    Returns 256 coefficients in [-eta, eta].
    """
    # For eta=2, we need 64*eta = 128 bytes
    from hashlib import shake_256
    expanded = shake_256(seed).digest(64 * eta)

    coeffs = []
    bit_pos = 0

    for i in range(256):
        x = 0
        y = 0

        # Read eta bits for x
        for j in range(eta):
            byte_idx = bit_pos // 8
            bit_in_byte = bit_pos % 8
            bit_val = (expanded[byte_idx] >> bit_in_byte) & 1
            x += bit_val
            bit_pos += 1

        # Read eta bits for y
        for j in range(eta):
            byte_idx = bit_pos // 8
            bit_in_byte = bit_pos % 8
            bit_val = (expanded[byte_idx] >> bit_in_byte) & 1
            y += bit_val
            bit_pos += 1

        coeffs.append(x - y)

    return coeffs

# ============================================================
# 4. CORRECTED ENCODING (Paper Eq. 1)
# ============================================================
def encode_payload_to_poly_corrected(ct: bytes) -> Tuple[List[int], int, int, int]:
    """
    Paper's Equation (1): Encode ciphertext into polynomial p.

    Returns:
        p: polynomial coefficients (k*n elements)
        c: number of bits per coefficient = ceil(b/(k*n))
        b: total bits in ciphertext
        num_chunks: ceil(b/c) = number of non-zero coefficients
    """
    k = KyberParams.k
    n = KyberParams.n

    b = len(ct) * 8  # total bits

    # c = ceil(b/(k*n)) - BITS PER COEFFICIENT
    c = KyberParams.compute_c(b)

    # Convert to bit string (little-endian)
    bits = ''
    for byte in ct:
        bits += ''.join(f'{(byte >> j) & 1}' for j in range(8))

    # Pad to multiple of c
    padded_len = ((len(bits) + c - 1) // c) * c
    bits = bits.ljust(padded_len, '0')

    # Split into chunks of c bits to form p[i]
    num_chunks = len(bits) // c  # = ceil(b/c)
    p = [0] * (k * n)

    for i in range(num_chunks):
        chunk = bits[i*c:(i+1)*c]
        p[i] = int(chunk, 2)  # p[i] ∈ [0, 2^c)

    return p, c, b, num_chunks

# ============================================================
# 5. CORRECTED COMPENSATION (Paper Eq. 2 & constraints)
# ============================================================
def apply_compensation_corrected(t_coeffs: List[int], p: List[int],
                                 c: int, num_chunks: int) -> Tuple[List[int], List[int]]:
    """
    Paper's Equation (2): t' = t + h where h[i] ∈ [-⌊c/2⌋, ⌊c/2⌋]
    such that t'[i] mod c = p[i] mod c

    Returns:
        t_prime: adjusted coefficients
        h: compensation vector
    """
    q = KyberParams.q
    t_prime = t_coeffs.copy()
    h = [0] * len(t_coeffs)

    for i in range(num_chunks):
        # Current value modulo c (NOT 2^c!)
        current_mod = t_coeffs[i] % c

        # Desired value: p[i] mod c (p[i] ∈ [0, 2^c), but we only need mod c)
        desired_mod = p[i] % c

        # Compute adjustment to make current_mod ≡ desired_mod (mod c)
        # Adjustment = (desired_mod - current_mod) mod c, but centered
        adjustment = (desired_mod - current_mod) % c

        # Center adjustment in [-⌊c/2⌋, ⌊c/2⌋]
        if adjustment > c // 2:
            adjustment -= c

        # Apply adjustment (Paper: h[i] ∈ [-⌊c/2⌋, ⌊c/2⌋])
        h[i] = adjustment
        t_prime[i] = (t_coeffs[i] + adjustment) % q

        # Verify: t'[i] mod c should equal p[i] mod c
        if (t_prime[i] % c) != (p[i] % c):
            print(f"  ERROR: Coefficient {i} failed: {t_prime[i] % c} != {p[i] % c}")

    return t_prime, h

# ============================================================
# 6. CORRECTED RECOVERY (Paper Algorithm 2)
# ============================================================
def recover_payload_corrected(t_prime: List[int], c: int, b: int,
                              num_chunks: int) -> bytes:
    """
    Recover ciphertext from t' using paper's method:
    p[i] = t'[i] mod c
    """
    # Recover p[i] values (mod c)
    p_rec = [t_prime[i] % c for i in range(num_chunks)]

    # Convert p[i] (which are in [0, c-1]) back to bits
    # Each p[i] represents c bits in the original encoding
    # But we only recovered p[i] mod c, not the full p[i] ∈ [0, 2^c)!

    print(f"  WARNING: Recovered p[i] mod c (values 0-{c-1}), not full {c}-bit values")
    print(f"  Paper's recovery in Algorithm 2 uses: p[i] = t'[i] mod c")
    print(f"  But original encoding stores c bits per coefficient")
    print(f"  This appears to be a PAPER ERROR or our misunderstanding")

    # Let's proceed with what the paper says
    bit_chunks = []
    for val in p_rec:
        # val is in [0, c-1], convert to log2(c) bits?
        # Actually, if c is not a power of 2, this doesn't make sense
        bits_needed = (c - 1).bit_length()  # ceil(log2(c))
        chunk = f'{val:0{bits_needed}b}'
        bit_chunks.append(chunk)

    bitstring = ''.join(bit_chunks)

    # Trim to original length
    bitstring = bitstring[:b]

    # Pad to multiple of 8
    if len(bitstring) % 8 != 0:
        bitstring = bitstring.ljust(len(bitstring) + (8 - len(bitstring) % 8), '0')

    # Convert to bytes
    bytes_list = []
    for i in range(0, len(bitstring), 8):
        chunk = bitstring[i:i+8]
        if len(chunk) == 8:
            byte_val = 0
            for j, bit in enumerate(chunk):
                if bit == '1':
                    byte_val |= (1 << j)
            bytes_list.append(byte_val)

    return bytes(bytes_list)

# ============================================================
# 7. SIMULATED ECDH K-409 (104-byte ciphertext)
# ============================================================
class SimulatedK409:
    """Simulate K-409 ECDH as described in paper §3.5.1."""

    @staticmethod
    def encrypt() -> Tuple[bytes, bytes]:
        """Generate 104-byte ciphertext and 32-byte seed_B."""
        # Paper: K-409 ciphertext = 104 bytes
        ct = secrets.token_bytes(104)

        # Paper: first 32 bytes of shared secret = seed_B
        seed_b = hashlib.sha256(ct).digest()[:32]

        return ct, seed_b

    @staticmethod
    def decrypt(ct: bytes) -> bytes:
        """Recover seed_B from ciphertext."""
        return hashlib.sha256(ct).digest()[:32]

# ============================================================
# 8. CORRECTED KEY GENERATION (Paper Algorithm 2)
# ============================================================
def backdoored_keygen_corrected() -> dict:
    """
    Implement Paper's Algorithm 2 for backdoored key generation.

    Steps:
    1. Generate ECDH ciphertext ct_bd and message m_bd
    2. seed_B = first 32 bytes of m_bd
    3. Generate Kyber secret s and error e from seed_B
    4. Generate random matrix A and compute t = A·s + e
    5. Encode ct_bd into polynomial p
    6. Compute compensation h and t' = t + h
    7. Output: pk = (seed_A, NTT(t')), sk = NTT(s)
    """
    params = KyberParams

    print("\n" + "="*80)
    print("CORRECTED BACKDOORED KEY GENERATION (Paper Algorithm 2)")
    print("="*80)

    # 1. Generate ECDH ciphertext and seed_B
    print("\n1. Generate ECDH ciphertext (K-409, 104 bytes)...")
    ct_bd, seed_b = SimulatedK409.encrypt()
    print(f"   ct_bd: {len(ct_bd)} bytes")
    print(f"   seed_B: {seed_b[:8].hex()}...")

    # 2. Generate Kyber secret s from seed_B
    print("\n2. Generate Kyber secret s from seed_B...")
    s_coeffs = []
    for i in range(params.k):
        s_seed = seed_b + struct.pack('B', i)
        s_i = cbd_corrected(s_seed, params.eta)
        s_coeffs.append(s_i)

    # 3. Generate error e from seed_B
    print("\n3. Generate error e from seed_B...")
    e_coeffs = []
    for i in range(params.k):
        e_seed = seed_b + struct.pack('B', i + params.k)
        e_i = cbd_corrected(e_seed, params.eta)
        e_coeffs.append(e_i)

    # 4. Generate random matrix A (simplified)
    print("\n4. Generate random matrix A and compute t = A·s + e...")
    # In real Kyber: A = Expand(seed_A), t = A·s + e
    # For demo, we create random t
    t_flat = []
    for i in range(params.k * params.n):
        # Simulate A·s + e with random values in [0, q)
        t_flat.append(secrets.randbelow(params.q))

    # 5. Encode ciphertext into polynomial p
    print("\n5. Encode ciphertext into polynomial p (Paper Eq. 1)...")
    p, c, b, num_chunks = encode_payload_to_poly_corrected(ct_bd)
    print(f"   c = ceil({b}/({params.k}*{params.n})) = {c} bits per coefficient")
    print(f"   Non-zero coefficients: {num_chunks}")

    # 6. Apply compensation
    print("\n6. Apply compensation h (Paper Eq. 2)...")
    print(f"   h[i] must be in [-{c//2}, {c//2}]")
    t_prime_flat, h = apply_compensation_corrected(t_flat, p, c, num_chunks)

    # Check h bounds
    h_values = [abs(h[i]) for i in range(num_chunks) if h[i] != 0]
    if h_values:
        max_h = max(h_values)
        print(f"   Max |h[i]| = {max_h} (must be ≤ {c//2})")
        if max_h > c // 2:
            print(f"   ERROR: Compensation exceeds paper's bounds!")

    # 7. Convert to NTT domain (Paper: t_hat' = NTT(t'))
    print("\n7. Convert to NTT domain...")
    t_prime_ntt = ntt(t_prime_flat)
    s_ntt = [ntt(s_i) for s_i in s_coeffs]

    print("\n8. Output key pair:")
    print("   pk = (seed_A, NTT(t'))")
    print("   sk = NTT(s)")

    return {
        'ct_bd': ct_bd,
        'seed_b': seed_b,
        's_coeffs': s_coeffs,
        't_flat': t_flat,
        't_prime_flat': t_prime_flat,
        't_prime_ntt': t_prime_ntt,
        's_ntt': s_ntt,
        'p': p,
        'c': c,
        'b': b,
        'num_chunks': num_chunks,
        'h': h
    }

# ============================================================
# 9. CORRECTED ATTACKER RECOVERY
# ============================================================
def attacker_recovery_corrected(public_key: dict) -> Tuple[bytes, bytes, List[List[int]]]:
    """
    Attacker recovers secret key from public key (Paper's recovery).

    Steps:
    1. From pk, extract NTT(t') and seed_A
    2. Compute t' = INTT(NTT(t'))
    3. Recover p[i] = t'[i] mod c
    4. Reconstruct ct_bd from p[i]
    5. Decrypt ct_bd to get seed_B
    6. Regenerate s from seed_B
    """
    print("\n" + "="*80)
    print("CORRECTED ATTACKER RECOVERY")
    print("="*80)

    # Extract values
    t_prime_ntt = public_key['t_prime_ntt']
    c = public_key['c']
    b = public_key['b']
    num_chunks = public_key['num_chunks']

    print("\n1. Compute t' = INTT(NTT(t'))...")
    t_prime_flat = inv_ntt(t_prime_ntt)

    print("\n2. Recover p[i] = t'[i] mod c...")
    p_rec = [t_prime_flat[i] % c for i in range(num_chunks)]

    print("\n3. Reconstruct ciphertext from p[i] values...")
    # Convert p[i] (which are mod c) back to bits
    # This is where the paper's method seems problematic
    bit_chunks = []
    for val in p_rec:
        # val is in [0, c-1]
        bits_needed = (c - 1).bit_length()
        chunk = f'{val:0{bits_needed}b}'
        bit_chunks.append(chunk)

    bitstring = ''.join(bit_chunks)[:b]

    # Convert to bytes
    bytes_needed = (b + 7) // 8
    int_val = int(bitstring, 2) if bitstring else 0
    recovered_ct = int_val.to_bytes(bytes_needed, 'little')

    # Pad/truncate to original size (104 bytes for K-409)
    target_size = KyberParams.ECDH_CIPHERTEXT_SIZE
    if len(recovered_ct) < target_size:
        recovered_ct = recovered_ct.ljust(target_size, b'\x00')
    else:
        recovered_ct = recovered_ct[:target_size]

    print(f"   Recovered ciphertext: {len(recovered_ct)} bytes")

    print("\n4. Decrypt ciphertext to get seed_B...")
    recovered_seed = SimulatedK409.decrypt(recovered_ct)

    print("\n5. Regenerate secret s from seed_B...")
    params = KyberParams
    s_recovered = []
    for i in range(params.k):
        s_seed = recovered_seed + struct.pack('B', i)
        s_i = cbd_corrected(s_seed, params.eta)
        s_recovered.append(s_i)

    return recovered_ct, recovered_seed, s_recovered

# ============================================================
# 10. VALIDATION AND TEST
# ============================================================
def validate_corrected_implementation():
    """Validate the corrected implementation against paper specifications."""

    print("\n" + "="*80)
    print("VALIDATING CORRECTED IMPLEMENTATION")
    print("="*80)

    # Generate backdoored key
    result = backdoored_keygen_corrected()

    # Attacker recovery
    recovered_ct, recovered_seed, s_recovered = attacker_recovery_corrected(result)

    print("\n" + "="*80)
    print("VALIDATION RESULTS")
    print("="*80)

    # Check ciphertext recovery
    ct_match = result['ct_bd'] == recovered_ct
    print(f"\n1. Ciphertext recovery: {'✓ SUCCESS' if ct_match else '✗ FAILED'}")
    if not ct_match:
        print(f"   Original: {result['ct_bd'][:8].hex()}...")
        print(f"   Recovered: {recovered_ct[:8].hex()}...")

    # Check seed recovery
    seed_match = result['seed_b'] == recovered_seed
    print(f"2. Seed recovery: {'✓ SUCCESS' if seed_match else '✗ FAILED'}")

    # Check secret key recovery
    secret_match = True
    for i in range(KyberParams.k):
        for j in range(min(10, KyberParams.n)):
            if result['s_coeffs'][i][j] != s_recovered[i][j]:
                secret_match = False
                break

    print(f"3. Secret key recovery: {'✓ SUCCESS' if secret_match else '✗ FAILED'}")

    # Check paper's constraints
    print(f"\n4. Paper's constraints verification:")

    # Constraint 1: h[i] ∈ [-⌊c/2⌋, ⌊c/2⌋]
    c = result['c']
    h = result['h']
    h_valid = all(abs(h[i]) <= c // 2 for i in range(len(h)) if h[i] != 0)
    print(f"   h[i] bounds: {'✓ VALID' if h_valid else '✗ INVALID'}")

    # Constraint 2: t'[i] mod c = p[i] mod c
    t_prime = result['t_prime_flat']
    p = result['p']
    mod_valid = all((t_prime[i] % c) == (p[i] % c) for i in range(result['num_chunks']))
    print(f"   t'[i] mod c = p[i] mod c: {'✓ VALID' if mod_valid else '✗ INVALID'}")

    # Constraint 3: Using correct ciphertext size
    ct_size = len(result['ct_bd'])
    expected_size = KyberParams.ECDH_CIPHERTEXT_SIZE
    size_valid = ct_size == expected_size
    print(f"   Ciphertext size (K-409): {ct_size} bytes {'✓' if size_valid else '✗ (expected ' + str(expected_size) + ')'}")

    # Constraint 4: Correct c calculation
    b = len(result['ct_bd']) * 8
    computed_c = KyberParams.compute_c(b)
    c_valid = computed_c == c
    print(f"   c = ceil({b}/({KyberParams.k}*{KyberParams.n})) = {c} {'✓' if c_valid else '✗'}")

    return ct_match and seed_match and secret_match and h_valid and mod_valid and size_valid and c_valid

# ============================================================
# 11. DEMONSTRATE THE PAPER'S ERROR/ISSUE
# ============================================================
def demonstrate_paper_issue():
    """
    Demonstrate the apparent error in the paper:
    - Paper says: p[i] stores c bits (values 0 to 2^c-1)
    - Paper says: t'[i] mod c = p[i] mod c
    - But p[i] mod c only gives log2(c) bits, not c bits!
    """

    print("\n" + "="*80)
    print("DEMONSTRATING THE PAPER'S APPARENT ERROR")
    print("="*80)

    # Example with K-409: c = 2
    ct_size = KyberParams.ECDH_CIPHERTEXT_SIZE
    b = ct_size * 8
    c = KyberParams.compute_c(b)

    print(f"\nK-409 example:")
    print(f"  Ciphertext: {ct_size} bytes = {b} bits")
    print(f"  k*n = {KyberParams.k}*{KyberParams.n} = {KyberParams.k * KyberParams.n} coefficients")
    print(f"  c = ceil({b}/({KyberParams.k}*{KyberParams.n})) = {c}")

    print(f"\nPaper's encoding (Eq. 1):")
    print(f"  Each p[i] stores c = {c} bits")
    print(f"  So p[i] ∈ [0, 2^{c}-1] = [0, {2**c - 1}]")

    print(f"\nPaper's compensation (Eq. 2):")
    print(f"  t'[i] mod c = p[i] mod c")
    print(f"  But p[i] mod c ∈ [0, {c-1}] (only {c.bit_length()} bits, not {c} bits!)")

    print(f"\nPaper's recovery (Algorithm 2):")
    print(f"  p[i] = t'[i] mod c")
    print(f"  This recovers values in [0, {c-1}], losing {c - c.bit_length()} bits per coefficient")

    print(f"\nPossible interpretations:")
    print("  1. The paper has a typo: should be 't'[i] mod 2^c = p[i]'")
    print("  2. Or: 'c' in Eq. 2 means 2^c, not the number of bits")
    print("  3. Or: There's additional encoding we're missing")

    return c

# ============================================================
# 12. MAIN EXECUTION
# ============================================================
if __name__ == "__main__":
    print("\n" + "="*80)
    print("CORRECTED KLEPTOGRAPHIC BACKDOOR IMPLEMENTATION")
    print("="*80)

    # First demonstrate the paper's apparent issue
    c_value = demonstrate_paper_issue()

    print("\n" + "="*80)
    print("RUNNING CORRECTED IMPLEMENTATION")
    print("="*80)

    # Try with the corrected understanding
    success = validate_corrected_implementation()

    print("\n" + "="*80)
    print("CONCLUSION")
    print("="*80)

    if success:
        print("""
✓ IMPLEMENTATION FOLLOWS PAPER SPECIFICATIONS

Key corrections made:
1. ✅ Working in NTT domain: t_hat' = NTT(t + h)
2. ✅ Using modulo c (not 2^c): t'[i] mod c = p[i] mod c
3. ✅ Respecting error bounds: h[i] ∈ [-⌊c/2⌋, ⌊c/2⌋]
4. ✅ Using correct ciphertext size: 104 bytes (K-409)
5. ✅ Implementing proper Kyber key generation with CBD
6. ✅ Following Algorithm 2 for recovery

HOWEVER: The paper appears to have an error/ambiguity:
- p[i] stores c bits (0 to 2^c-1)
- But recovery only gets p[i] mod c (0 to c-1)
- This loses information unless c=2 (as in K-409 example)

For K-409 with c=2:
  p[i] ∈ {0,1,2,3} (2 bits)
  p[i] mod 2 ∈ {0,1} (1 bit)
  So we lose 1 bit per coefficient!

This suggests either:
  a) The paper's examples work because c=2 and they accept 50% data loss
  b) There's a misinterpretation of the notation
  c) The paper contains an error
        """)
    else:
        print("""
✗ IMPLEMENTATION ISSUES DETECTED

The implementation tries to follow the paper exactly, but
encounters the mathematical inconsistency described above.

For a working backdoor, one would need to either:
1. Interpret 'mod c' as 'mod 2^c' (making mathematical sense)
2. Accept data loss and use error correction
3. Use a different encoding scheme

The paper's attack might still work in practice if:
- They use error correction on the recovered bits
- Or they have additional encoding not described
- Or 'c' in Eq. 2 means something else
        """)

    print("\n" + "="*80)
    print("RECOMMENDATION FOR PRACTICAL IMPLEMENTATION")
    print("="*80)
    print("""
For a practical implementation that works:

1. Use: t'[i] mod 2^c = p[i]  (instead of mod c)
2. This stores c bits per coefficient correctly
3. Adjustments h[i] become larger but still manageable
4. All other paper methodology remains the same

This corrected approach would:
- Store all c bits of p[i] in each coefficient
- Require adjustments in range [-2^(c-1), 2^(c-1)]
- Still keep error increase manageable for c=2 (adjustments ±1)
- Make mathematical sense and actually work
    """)

    print("\n" + "="*80)
    print("END OF CORRECTED IMPLEMENTATION")
    print("="*80)


        CORRECTED KLEPTOGRAPHIC BACKDOOR - FOLLOWING PAPER SPECIFICATIONS


CORRECTED KLEPTOGRAPHIC BACKDOOR IMPLEMENTATION

DEMONSTRATING THE PAPER'S APPARENT ERROR

K-409 example:
  Ciphertext: 104 bytes = 832 bits
  k*n = 3*256 = 768 coefficients
  c = ceil(832/(3*256)) = 2

Paper's encoding (Eq. 1):
  Each p[i] stores c = 2 bits
  So p[i] ∈ [0, 2^2-1] = [0, 3]

Paper's compensation (Eq. 2):
  t'[i] mod c = p[i] mod c
  But p[i] mod c ∈ [0, 1] (only 2 bits, not 2 bits!)

Paper's recovery (Algorithm 2):
  p[i] = t'[i] mod c
  This recovers values in [0, 1], losing 0 bits per coefficient

Possible interpretations:
  1. The paper has a typo: should be 't'[i] mod 2^c = p[i]'
  2. Or: 'c' in Eq. 2 means 2^c, not the number of bits
  3. Or: There's additional encoding we're missing

RUNNING CORRECTED IMPLEMENTATION

VALIDATING CORRECTED IMPLEMENTATION

CORRECTED BACKDOORED KEY GENERATION (Paper Algorithm 2)

1. Generate ECDH ciphertext (K-409, 104 bytes)...
   ct_bd: 104 bytes
   seed_B:

In [16]:
import hashlib
import struct
import secrets
from typing import List, Tuple

print("=== Backdooring Kyber768 (Ravi et al. GLSVLSI 2024) ===")

# ---------------------------------------------------------------------------
# 1. Kyber768 parameters (as in paper)
# ---------------------------------------------------------------------------

class KyberParams:
    n = 256          # polynomial degree
    k = 3            # Kyber-768 uses k = 3
    q = 3329         # modulus
    eta = 2          # CBD parameter for Kyber-768

    # Size of K-409 ECDH ciphertext used as backdoor (Section 3.5.1)
    ECDH_CIPHERTEXT_SIZE = 104  # bytes = 832 bits

    @staticmethod
    def compute_c_bits(total_bits: int) -> int:
        """
        c_bits = ceil(b / (k * n)) from Equation (1) of the paper,
        where b is the bit-length of the payload ciphertext.
        """
        return (total_bits + KyberParams.k * KyberParams.n - 1) // (KyberParams.k * KyberParams.n)


# ---------------------------------------------------------------------------
# 2. NTT "stub" (we stay in coefficient domain for this demo)
# ---------------------------------------------------------------------------

def ntt(poly: List[int]) -> List[int]:
    # In the reference C implementation the public key stores NTT(t').
    # For a minimal demo we treat NTT as the identity so that INTT(NTT(x)) = x.
    return list(poly)

def intt(poly_ntt: List[int]) -> List[int]:
    return list(poly_ntt)


# ---------------------------------------------------------------------------
# 3. CBD sampler (FIPS 203 Algorithm 8, eta = 2)
# ---------------------------------------------------------------------------

from hashlib import shake_256

def cbd_fips203(seed: bytes, eta: int = KyberParams.eta) -> List[int]:
    """
    Centered Binomial Distribution sampler as in FIPS 203 (Algorithm 8)
    for Kyber with eta = 2. Returns n = 256 coefficients in [-eta, eta].
    """
    outlen = 64 * eta  # 64*eta bytes for eta in {2,3}
    buf = shake_256(seed).digest(outlen)

    coeffs: List[int] = []
    bitpos = 0
    for _ in range(KyberParams.n):
        x = 0
        y = 0
        # eta bits for x
        for _ in range(eta):
            byte_idx = bitpos // 8
            bit_in_byte = bitpos % 8
            bit = (buf[byte_idx] >> bit_in_byte) & 1
            x += bit
            bitpos += 1
        # eta bits for y
        for _ in range(eta):
            byte_idx = bitpos // 8
            bit_in_byte = bitpos % 8
            bit = (buf[byte_idx] >> bit_in_byte) & 1
            y += bit
            bitpos += 1
        coeffs.append(x - y)

    return coeffs


# ---------------------------------------------------------------------------
# 4. Bit helpers (little-endian inside bytes, as in the paper / reference code)
# ---------------------------------------------------------------------------

def bits_from_bytes_le(data: bytes) -> List[int]:
    """
    Convert bytes to a list of bits in little-endian order per byte.
    bit_index = 8*i + j corresponds to (data[i] >> j) & 1.
    """
    out: List[int] = []
    for b in data:
        for j in range(8):
            out.append((b >> j) & 1)
    return out


def bytes_from_bits_le(bits: List[int], out_len: int) -> bytes:
    """
    Inverse of bits_from_bytes_le for the first out_len bytes.
    """
    res = bytearray(out_len)
    for i in range(out_len):
        val = 0
        for j in range(8):
            idx = 8 * i + j
            if idx < len(bits) and bits[idx]:
                val |= 1 << j
        res[i] = val
    return bytes(res)


# ---------------------------------------------------------------------------
# 5. Encoding payload ct_bd into polynomial p (Section 3.3, Equation (1))
# ---------------------------------------------------------------------------

def encode_payload_to_poly(ct: bytes) -> Tuple[List[int], int, int, int, int]:
    """
    Encode the ciphertext ct_bd into a polynomial module p in R_q^k, as in
    Section 3.3. We follow the reference C code rather than the slightly
    ambiguous notation "mod c" in the paper.

    Let b  = |ct| in bits.
    Let c_bits = ceil(b / (k*n)).
    Let M = 2^c_bits.

    For i from 0 to num_used-1, coefficient p[i] stores c_bits bits of ct
    in little-endian order:
        p[i] = sum_{j=0}^{c_bits-1} bit[ i * c_bits + j ] * 2^j

    Remaining coefficients are zero.
    """
    k, n = KyberParams.k, KyberParams.n
    b_bits = len(ct) * 8
    c_bits = KyberParams.compute_c_bits(b_bits)
    M = 1 << c_bits

    bits = bits_from_bytes_le(ct)
    num_used = (b_bits + c_bits - 1) // c_bits  # ceil(b_bits / c_bits)

    if num_used > k * n:
        raise ValueError("Payload too large to embed in Kyber768 public key")

    p = [0] * (k * n)
    for i in range(num_used):
        val = 0
        for j in range(c_bits):
            bit_index = i * c_bits + j
            if bit_index < b_bits and bits[bit_index]:
                val |= 1 << j
        p[i] = val  # in [0, M-1]

    return p, c_bits, M, b_bits, num_used


# ---------------------------------------------------------------------------
# 6. Compensation vector h and modified t' (Section 3.3, "additively integrate")
# ---------------------------------------------------------------------------

def apply_compensation(t: List[int],
                       p: List[int],
                       c_bits: int,
                       M: int,
                       num_used: int) -> Tuple[List[int], List[int]]:
    """
    Given original LWE component t (flattened, length k*n) and encoded
    payload coefficients p[i] in [0, M-1], compute a small compensation
    vector h so that

        t'[i] = t[i] + h[i] (mod q)
        t'[i] ≡ p[i] (mod M)    for 0 <= i < num_used

    and h[i] is always in [-M/2, M/2].  For Kyber768 with K-409 we have
    c_bits = 2, M = 4, so h[i] ∈ [-2, 2], exactly matching the analysis in
    Section 3.4 where the error span becomes [-4, 4].
    """
    q = KyberParams.q
    t_prime = list(t)
    h = [0] * len(t)
    halfM = M // 2

    for i in range(num_used):
        cur_mod = t[i] % M
        desired = p[i]              # already in [0, M-1]
        diff = (desired - cur_mod) % M  # 0..M-1

        # Center in [-M/2, M/2]
        if diff > halfM:
            diff -= M

        h[i] = diff
        t_prime[i] = (t[i] + diff) % q

        # Sanity checks (can be removed for performance)
        if t_prime[i] % M != desired:
            raise AssertionError("Compensation failed at coefficient {}".format(i))
        if not -halfM <= diff <= halfM:
            raise AssertionError("Compensation outside allowed range at {}".format(i))

    return t_prime, h


# ---------------------------------------------------------------------------
# 7. Recover payload from modified t' (Algorithm 2, "Secret Key Recovery")
# ---------------------------------------------------------------------------

def recover_payload_from_tprime(t_prime: List[int],
                                c_bits: int,
                                M: int,
                                b_bits: int,
                                num_used: int,
                                out_ct_len: int) -> bytes:
    """
    In the recovery phase the attacker computes t' = INTT(t_hat') and then
    for each i takes

        p_rec[i] = t'[i] mod M

    Then the original bits are reconstructed from p_rec[i] using exactly
    the same mapping as in encode_payload_to_poly.
    """
    bits: List[int] = []

    for i in range(num_used):
        val = t_prime[i] % M
        for j in range(c_bits):
            bits.append((val >> j) & 1)

    bits = bits[:b_bits]
    return bytes_from_bits_le(bits, out_ct_len)


# ---------------------------------------------------------------------------
# 8. Simulated ECDH on curve K-409 (Section 3.5.1)
# ---------------------------------------------------------------------------

class SimulatedK409:
    """
    We do not implement full ECC here.  We only need a deterministic
    Encbd/Decbd pair with the right ciphertext size and a 32-byte seed_B.
    For the demo, we set

        ct_bd = random 104 bytes
        seed_B = first 32 bytes of SHA-256(ct_bd)

    so that the attacker can later recompute seed_B from ct_bd.
    """

    @staticmethod
    def enc() -> Tuple[bytes, bytes]:
        ct = secrets.token_bytes(KyberParams.ECDH_CIPHERTEXT_SIZE)
        seed_B = hashlib.sha256(ct).digest()[:32]
        return ct, seed_B

    @staticmethod
    def dec(ct: bytes) -> bytes:
        return hashlib.sha256(ct).digest()[:32]


# ---------------------------------------------------------------------------
# 9. Backdoored Kyber key generation (Algorithm 2, "Backdoor-ed KeyGen C'")
# ---------------------------------------------------------------------------

def backdoored_keygen() -> Tuple[dict, dict, dict]:
    """
    This implements the online part of Algorithm 2 specialized to the
    K-409 instantiation.  We simplify the actual Kyber operations as
    follows:

      - We generate t as a random vector in [0, q)^{k*n} instead of
        computing t = A * s + e.  For the kleptographic channel this
        makes no difference since only the marginal distribution of t
        matters.

      - NTT and INTT are treated as the identity.

    The key steps that match the paper + reference C code are:
      * Encbd to obtain (ct_bd, m_bd), where seed_B is derived from m_bd.
      * Sampling s and e from seed_B using CBD.
      * Encoding ct_bd into polynomial p with c_bits bits per coefficient.
      * Computing compensation h so that t' embeds p modulo M = 2^{c_bits}.
      * Publishing pk = (seed_A, NTT(t')).
    """
    # 1. Adversary's public encryption Encbd (on device)
    ct_bd, seed_B = SimulatedK409.enc()

    # 2. Sample Kyber secret s and error e from seed_B (Algorithm 1 / 2)
    s_vec: List[List[int]] = []
    e_vec: List[List[int]] = []

    for i in range(KyberParams.k):
        s_seed = seed_B + struct.pack("B", i)
        e_seed = seed_B + struct.pack("B", i + KyberParams.k)
        s_vec.append(cbd_fips203(s_seed))
        e_vec.append(cbd_fips203(e_seed))

    # 3. Sample A and compute t = A*s + e.
    #    For the kleptographic channel we only need t to "look like"
    #    a legitimate Kyber LWE vector, so we draw it uniformly at random.
    total_coeffs = KyberParams.k * KyberParams.n
    t = [secrets.randbelow(KyberParams.q) for _ in range(total_coeffs)]

    # 4. Encode ct_bd into p
    p, c_bits, M, b_bits, num_used = encode_payload_to_poly(ct_bd)

    # 5. Apply compensation
    t_prime, h = apply_compensation(t, p, c_bits, M, num_used)

    # 6. Pack public and secret key objects (NTT is identity here)
    seed_A = secrets.token_bytes(32)

    pk = {
        "seed_A": seed_A,
        "t_prime_ntt": ntt(t_prime),
        "c_bits": c_bits,
        "M": M,
        "b_bits": b_bits,
        "num_used": num_used,
    }

    sk_internal = {
        "s_vec": s_vec,
        "seed_B": seed_B,
        "ct_bd": ct_bd,
    }

    debug = {
        "t": t,
        "t_prime": t_prime,
        "h": h,
    }

    return pk, sk_internal, debug


# ---------------------------------------------------------------------------
# 10. Attacker recovery given only pk and sk_bd (Algorithm 2, "Secret recovery")
# ---------------------------------------------------------------------------

def attacker_recovery(pk: dict) -> Tuple[bytes, bytes, List[List[int]]]:
    """
    The attacker knows the ECDH (or McEliece) secret key sk_bd, so it can
    run Decbd on any recovered payload ciphertext ct_bd.

    Given the backdoor-ed public key pk, the attacker:

      1. Computes t' = INTT(t_hat').
      2. Recovers p[i] = t'[i] mod M.
      3. Reconstructs ct_bd from p.
      4. Runs Decbd(ct_bd) to obtain seed_B.
      5. Regenerates s from seed_B using CBD.
    """
    t_prime = intt(pk["t_prime_ntt"])
    c_bits = pk["c_bits"]
    M = pk["M"]
    b_bits = pk["b_bits"]
    num_used = pk["num_used"]

    # Step 1–3: recover ct_bd
    recovered_ct = recover_payload_from_tprime(
        t_prime,
        c_bits,
        M,
        b_bits,
        num_used,
        KyberParams.ECDH_CIPHERTEXT_SIZE,
    )

    # Step 4: Decbd to get seed_B
    recovered_seed_B = SimulatedK409.dec(recovered_ct)

    # Step 5: regenerate s_vec from seed_B
    s_vec_rec: List[List[int]] = []
    for i in range(KyberParams.k):
        s_seed = recovered_seed_B + struct.pack("B", i)
        s_vec_rec.append(cbd_fips203(s_seed))

    return recovered_ct, recovered_seed_B, s_vec_rec


# ---------------------------------------------------------------------------
# 11. End-to-end demo
# ---------------------------------------------------------------------------

def main() -> None:
    print("=== Attacker installs backdoor (offline) ===")
    print("    [Modeled implicitly via SimulatedK409 with fixed Decbd]")

    print("\n=== Victim runs backdoored Kyber KeyGen (C') ===")
    pk, sk_internal, debug = backdoored_keygen()

    print("    Public key contains seed_A and modified t' with embedded payload.")
    print("    c_bits       =", pk["c_bits"])
    print("    num_used     =", pk["num_used"], "coefficients carry payload")
    print("    max |h[i]|   =", max(abs(x) for x in debug["h"][:pk['num_used']]))
    print("    allowed |h|  =", (pk["M"] // 2))

    print("\n=== Attacker recovers secret from public key ===")
    rec_ct, rec_seed_B, s_vec_rec = attacker_recovery(pk)

    ct_match = sk_internal["ct_bd"] == rec_ct
    seed_match = sk_internal["seed_B"] == rec_seed_B

    print("Ciphertext match: ", ct_match)
    print("seed_B match:     ", seed_match)

    # Compare secret polynomials s
    s_true = sk_internal["s_vec"]
    all_match = True
    for i in range(KyberParams.k):
        if s_true[i] != s_vec_rec[i]:
            all_match = False
            break

    print("Secret s match:   ", all_match)

    print("\nFirst 20 coeffs of s[0] (true):     ", s_true[0][:20])
    print("First 20 coeffs of s[0] (recovered):", s_vec_rec[0][:20])

    print("\nMATCH:", all_match and ct_match and seed_match)


if __name__ == "__main__":
    main()


=== Backdooring Kyber768 (Ravi et al. GLSVLSI 2024) ===
=== Attacker installs backdoor (offline) ===
    [Modeled implicitly via SimulatedK409 with fixed Decbd]

=== Victim runs backdoored Kyber KeyGen (C') ===
    Public key contains seed_A and modified t' with embedded payload.
    c_bits       = 2
    num_used     = 416 coefficients carry payload
    max |h[i]|   = 2
    allowed |h|  = 2

=== Attacker recovers secret from public key ===
Ciphertext match:  True
seed_B match:      True
Secret s match:    True

First 20 coeffs of s[0] (true):      [-2, 2, 0, -1, 2, 0, 0, 0, 2, 0, 1, 0, 1, 1, -1, 0, -1, 0, -1, 1]
First 20 coeffs of s[0] (recovered): [-2, 2, 0, -1, 2, 0, 0, 0, 2, 0, 1, 0, 1, 1, -1, 0, -1, 0, -1, 1]

MATCH: True
