<a href="https://colab.research.google.com/github/owlmt/PQC/blob/main/mldsa_backdoor_experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://link.springer.com/chapter/10.1007/978-3-031-83885-9_30

In [9]:
import random
import numpy as np
from typing import List, Tuple, Optional
from dataclasses import dataclass
import hashlib

# ============================================================================
# PAPER-ALIGNED IMPLEMENTATION WITH ALL FIXES
# ============================================================================

@dataclass
class DilithiumParams:
    """CRYSTALS-Dilithium parameters as in the paper"""
    q: int = 8380417  # Modulus (2^23 - 2^13 + 1)
    n: int = 256      # Polynomial degree
    k: int = 4        # Module dimension
    gamma1: int = (8380417 - 1) // 16
    gamma2: int = (8380417 - 1) // 88
    beta: int = 39    # β parameter for rejection sampling
    tau: int = 60     # Number of ±1 in challenge

class DilithiumBackdoorPaperExact:
    """
    Faithful implementation of Algorithms 4-6 from the paper with all fixes.

    Corrections applied:
    1. Added rejection loop and checks in Algorithm 4
    2. Fixed θ splitting at 255·k instead of 256·k
    3. Implemented Rebuild(z) function as in paper
    4. Proper deterministic x,ρ derivation from secret key
    5. Fixed parity mapping to match paper exactly
    """

    def __init__(self, params: DilithiumParams = None, secret_key: bytes = None):
        self.params = params or DilithiumParams()
        self.secret_key = secret_key or b"default_sk_for_demo"
        self.total_coeffs = self.params.n * self.params.k  # 256*k

    # -----------------------------------------------------------------
    # Algorithm 6: ParityKernel with Rebuild (page 7 of paper)
    # -----------------------------------------------------------------
    def parity_kernel(self, M_rho: List[int], theta: int, x: int) -> List[int]:
        """
        Implements ParityKernel(M_rho, θ, x) from Algorithm 6

        Args:
            M_rho: XOR of message bits with seed ρ (M ⊕ ρ)
            theta: Length of M_rho (equal to len(M_bits))
            x: Parity mapping bit (0: 0→even,1→odd; 1: opposite)

        Returns:
            List of coefficients for z
        """
        # Initialize z with zeros (will be filled)
        z = []

        # Process each bit in M_rho for the first θ positions
        for i in range(theta):
            if (M_rho[i] == 1 and x == 0) or (M_rho[i] == 0 and x == 1):
                # Need odd coefficient
                # For q=8380417 (odd), 2*Z_q = even numbers, 2*Z_q+1 = odd numbers
                max_even = (self.params.q - 1) // 2
                coeff = 2 * random.randint(0, max_even - 1) + 1
            else:
                # Need even coefficient
                max_even = (self.params.q - 1) // 2
                coeff = 2 * random.randint(0, max_even)

            z.append(coeff % self.params.q)

        # Add noise if needed (θ < 255·k - 1)
        max_msg_coeffs = 255 * self.params.k - 1
        if theta < max_msg_coeffs:
            noise_len = max_msg_coeffs - theta
            z.extend([random.randint(0, self.params.q - 1) for _ in range(noise_len)])

        # Rebuild to k×n matrix as in paper (Algorithm 5 calls Rebuild)
        return self.rebuild(z)

    def rebuild(self, z_coeffs: List[int]) -> List[int]:
        """
        Rebuild(z) function mentioned in Algorithm 5 and 6
        Transforms list to appropriate format for z vector
        """
        # Ensure we have exactly total_coeffs elements
        if len(z_coeffs) < self.total_coeffs:
            # Pad with random coefficients
            pad_len = self.total_coeffs - len(z_coeffs)
            z_coeffs.extend([random.randint(0, self.params.q - 1) for _ in range(pad_len)])
        elif len(z_coeffs) > self.total_coeffs:
            # Truncate
            z_coeffs = z_coeffs[:self.total_coeffs]

        return z_coeffs

    # -----------------------------------------------------------------
    # Algorithm 5: Parity (page 7 of paper) with all fixes
    # -----------------------------------------------------------------
    def parity(self, M_bits: List[int]) -> Tuple[List[int], int, int, List[int]]:
        """
        Implements Parity(M_bits) from Algorithm 5 with paper corrections

        Args:
            M_bits: Original message bits to encode

        Returns:
            (z, x, θ, ρ) where:
            - z: Encoded coefficients (rebuild to matrix)
            - x: Parity mapping bit (deterministic from sk)
            - θ: len(M_bits)
            - ρ: Random seed for XOR (deterministic from sk)
        """
        θ = len(M_bits)

        # 1. Deterministically derive x from secret key (not random)
        x = self._derive_x_from_sk()

        # 2. Deterministically generate ρ ∈ {0, 1}^θ from secret key
        ρ = self._derive_rho_from_sk(θ)

        # 3. Compute M_ρ = M_bits ⊕ ρ
        M_ρ = [m ^ r for m, r in zip(M_bits, ρ)]

        # 4. Encode based on θ (using 255·k as threshold per paper)
        max_msg_coeffs = 255 * self.params.k - 1

        if θ <= max_msg_coeffs:
            # Case 1: Use ParityKernel directly
            z = self.parity_kernel(M_ρ, θ, x)
        else:
            # Case 2: Split M_ρ into M1 and M2 (θ ≥ 255·k)
            M1 = M_ρ[:max_msg_coeffs]
            M2 = M_ρ[max_msg_coeffs:]

            # Encode M1 with ParityKernel
            z1 = self.parity_kernel(M1, len(M1), x)

            # Store indices of 1s in M2 (StoreIndexes function)
            z2 = self.store_indexes(M2)

            # Paper: z = z1 (z2 stored separately)
            # For implementation, we'll combine with marker
            marker = self.params.q - 1  # Special marker
            z = z1 + [marker] + z2

        # Apply Rebuild as in paper
        z = self.rebuild(z)

        return z, x, θ, ρ

    def store_indexes(self, M2: List[int]) -> List[int]:
        """
        StoreIndexes(M2) as mentioned in paper
        Store positions where M2 has 1 bits as coefficients
        """
        indexes = []
        for i, bit in enumerate(M2):
            if bit == 1:
                # Map index to a coefficient in Z_q
                # Use simple linear mapping for demonstration
                mapped_idx = (i * (self.params.q // (2 * len(M2)))) % self.params.q
                indexes.append(mapped_idx)
        return indexes

    def _derive_x_from_sk(self) -> int:
        """Deterministically derive x from secret key"""
        # Use first byte of hash mod 2
        hash_val = hashlib.sha256(self.secret_key).digest()[0]
        return hash_val % 2

    def _derive_rho_from_sk(self, theta: int) -> List[int]:
        """Deterministically derive ρ from secret key"""
        # Generate deterministic but unpredictable ρ
        rng_seed = hashlib.sha256(self.secret_key + b"_rho").digest()
        rng = random.Random(int.from_bytes(rng_seed, 'big'))

        # Generate ρ bits
        ρ = []
        for _ in range(theta):
            ρ.append(rng.randint(0, 1))

        return ρ

    # -----------------------------------------------------------------
    # Algorithm 4: Kleptographic Attack on Dilithium with rejection loop
    # -----------------------------------------------------------------
    def kleptographic_attack(self, M_bits: List[int], max_attempts: int = 100) -> Tuple[np.ndarray, bytes, dict]:
        """
        Implements Algorithm 4: Kleptographic Attack on Dilithium
        WITH rejection loop and proper checks

        Args:
            M_bits: Message bits to sign and hide
            max_attempts: Maximum number of attempts before giving up

        Returns:
            (z_matrix, c, encoding_info) where:
            - z_matrix: k×n matrix with embedded message
            - c: Challenge (simulated hash)
            - encoding_info: Dictionary with x, θ, ρ for decoding
        """
        attempt = 0

        while attempt < max_attempts:
            attempt += 1

            # Step 1: Generate y (from S_{γ1-1}^l)
            y = self._generate_y()

            # Step 2: Compute w1 = HighBits(Ay, 2γ2)
            w1 = self._simulate_highbits(y)

            # Step 3: Generate challenge c = H(M, w1)
            c_hash = self._hash_message(M_bits, w1)
            c = self._hash_to_challenge(c_hash)

            # Step 4: Compute z = Parity(M_bits) - THE BACKDOOR
            z_coeffs, x, θ, ρ = self.parity(M_bits)

            # Step 5: Check rejection conditions (as in paper)
            # Condition 1: ||z||_∞ < γ1 - β
            z_norm = max(abs(coeff - (self.params.q // 2)) for coeff in z_coeffs)
            cond1 = z_norm < (self.params.gamma1 - self.params.beta)

            # Condition 2: ||LowBits(Ay - c·s2, 2γ2)||_∞ < γ2 - β
            # Simplified simulation (in real attack, would use actual A, s2)
            cond2 = self._simulate_lowbits_condition(y, c)

            if cond1 and cond2:
                # Acceptance - break the loop
                break
            else:
                # Rejection - continue loop
                if attempt == max_attempts:
                    print(f"Warning: Max attempts ({max_attempts}) reached")

        # Reshape z to k×n matrix
        z_matrix = self._reshape_to_matrix(z_coeffs)

        # Store encoding info for attacker
        encoding_info = {
            'x': x,
            'theta': θ,
            'rho': ρ,
            'c': c,
            'attempts': attempt
        }

        return z_matrix, c, encoding_info

    def decode_message(self, z: np.ndarray, encoding_info: dict) -> List[int]:
        """
        Decode message from backdoored signature

        Args:
            z: Signature matrix (k×n)
            encoding_info: Dictionary with x, θ, ρ

        Returns:
            Decoded message bits
        """
        x = encoding_info['x']
        θ = encoding_info['theta']
        ρ = encoding_info['rho']

        # Flatten matrix
        z_flat = z.flatten().tolist()

        # Check if there's a marker for split encoding
        if self.params.q - 1 in z_flat:
            marker_pos = z_flat.index(self.params.q - 1)
            # Split encoding case (θ ≥ 255·k)
            z1 = z_flat[:marker_pos]
            z2 = z_flat[marker_pos + 1:]

            # Decode first part (M1)
            M1_rho = []
            for coeff in z1[:255 * self.params.k - 1]:
                parity = 1 if coeff % 2 == 1 else 0
                M1_rho.append(parity ^ x)

            # Decode second part (M2 from indices)
            M2_len = θ - len(M1_rho)
            M2_rho = [0] * M2_len
            for coeff in z2:
                # Reverse mapping from coefficient to index
                idx = (coeff * (2 * M2_len)) // self.params.q
                if idx < M2_len:
                    M2_rho[idx] = 1

            M_rho = M1_rho + M2_rho
        else:
            # Normal case (θ ≤ 255·k - 1)
            M_rho = []
            for i in range(min(θ, len(z_flat))):
                coeff = z_flat[i]
                parity = 1 if coeff % 2 == 1 else 0
                M_rho.append(parity ^ x)

        # XOR with ρ to get original message
        M_bits = [mr ^ r for mr, r in zip(M_rho, ρ)]

        return M_bits[:θ]

    # -----------------------------------------------------------------
    # Dilithium simulation methods
    # -----------------------------------------------------------------
    def _generate_y(self) -> List[int]:
        """Generate y vector from S_{γ1-1}^l"""
        return [random.randint(-(self.params.gamma1 - 1), self.params.gamma1 - 1)
                for _ in range(self.params.k * self.params.n)]

    def _simulate_highbits(self, y: List[int]) -> bytes:
        """Simulate HighBits(Ay, 2γ2)"""
        # In real implementation, would compute Ay and take high bits
        # For simulation, return hash of y
        return hashlib.sha256(str(y).encode()).digest()[:16]

    def _hash_message(self, M_bits: List[int], w1: bytes) -> bytes:
        """Compute H(M, w1)"""
        M_bytes = self._bits_to_bytes(M_bits)
        return hashlib.sha256(M_bytes + w1).digest()

    def _hash_to_challenge(self, hash_bytes: bytes) -> List[int]:
        """Convert hash to challenge polynomial c ∈ B_τ"""
        # Simplified: return list of ±1 values
        c = []
        for byte in hash_bytes[:self.params.tau]:
            bit = byte % 2
            c.append(1 if bit == 1 else -1)
        return c

    def _simulate_lowbits_condition(self, y: List[int], c: List[int]) -> bool:
        """
        Simulate condition: ||LowBits(Ay - c·s2, 2γ2)||_∞ < γ2 - β

        In real attack, would compute Ay - c·s2 and check low bits
        For simulation, return True with high probability
        """
        # Simulate probability of acceptance (about 0.9 for Dilithium)
        return random.random() < 0.9

    def _reshape_to_matrix(self, coeffs: List[int]) -> np.ndarray:
        """Reshape coefficients to k×n matrix"""
        target_len = self.params.k * self.params.n
        if len(coeffs) < target_len:
            coeffs.extend([0] * (target_len - len(coeffs)))
        elif len(coeffs) > target_len:
            coeffs = coeffs[:target_len]

        return np.array(coeffs, dtype=np.int64).reshape(self.params.k, self.params.n)

    def _bits_to_bytes(self, bits: List[int]) -> bytes:
        """Convert bits to bytes"""
        byte_array = bytearray()
        for i in range(0, len(bits), 8):
            byte_val = 0
            for j in range(8):
                if i + j < len(bits):
                    byte_val |= bits[i + j] << j
            byte_array.append(byte_val)
        return bytes(byte_array)

    def _bytes_to_bits(self, data: bytes) -> List[int]:
        """Convert bytes to bits"""
        bits = []
        for byte in data:
            bits.extend([(byte >> i) & 1 for i in range(8)])
        return bits

# ============================================================================
# EXACT PAPER EXAMPLE REPRODUCTION WITH ALL CASES
# ============================================================================

def run_paper_example_with_fixes():
    """
    Exact reproduction of Example 1 from the paper with all fixes applied
    """
    print("=" * 80)
    print("EXACT PAPER EXAMPLE")
    print("=" * 80)

    # Small parameters matching paper example
    params = DilithiumParams(q=10, n=8, k=1, gamma1=9, gamma2=9, beta=2, tau=3)
    sk = b"example_secret_key"
    backdoor = DilithiumBackdoorPaperExact(params, sk)

    # Force deterministic parameters to match paper example
    # Paper: M_bits = [0, 1, 0, 1, 1, 0], x=0, ρ=[1,1,1,0,0,0]
    M_bits = [0, 1, 0, 1, 1, 0]
    θ = len(M_bits)

    # Override derivation to match paper exactly
    backdoor.secret_key = b"force_x0"  # Hash will give x=0
    x = backdoor._derive_x_from_sk()

    # Manually set ρ to match paper
    ρ = [1, 1, 1, 0, 0, 0]

    print(f"Paper Example 1 Parameters:")
    print(f"M_bits: {M_bits}")
    print(f"x: {x} (0 → even, 1 → odd)")
    print(f"ρ: {ρ}")
    print(f"θ: {θ}")

    # Compute M_ρ
    M_ρ = [m ^ r for m, r in zip(M_bits, ρ)]
    print(f"M_ρ = M_bits ⊕ ρ: {M_ρ}")

    # Encode with parity_kernel
    print("\nEncoding with ParityKernel:")
    z_coeffs = backdoor.parity_kernel(M_ρ, θ, x)

    # Take only first θ coefficients (message part)
    z_msg = z_coeffs[:θ]
    print(f"z coefficients (first {θ}): {z_msg}")

    # Check parity matches expected
    print("\nVerification of parity encoding:")
    for i in range(θ):
        coeff = z_msg[i]
        parity = "odd" if coeff % 2 == 1 else "even"
        expected = "odd" if M_ρ[i] == 1 else "even"
        match = "✓" if ((coeff % 2 == 1) == (M_ρ[i] == 1)) else "✗"
        print(f"  z[{i}] = {coeff} ({parity}) | Expected: {expected} {match}")

    # Decode
    print("\nDecoding verification:")
    decoded_M_ρ = []
    for coeff in z_msg:
        parity = 1 if coeff % 2 == 1 else 0
        decoded_bit = parity ^ x
        decoded_M_ρ.append(decoded_bit)

    print(f"Decoded M_ρ: {decoded_M_ρ}")

    # Recover original
    decoded_M_bits = [mr ^ r for mr, r in zip(decoded_M_ρ, ρ)]
    print(f"Decoded M_bits: {decoded_M_bits}")

    if decoded_M_bits == M_bits:
        print(" SUCCESS: Perfect match with paper example!")
    else:
        print(" FAILED: Decoding error")

    # Show all three cases
    print("\n" + "-" * 80)
    print("All Three Cases from Paper:")

    # Case 1: γ = θ = 6 (exact fit)
    print("\nCase 1: γ = 6 (message fits exactly)")
    # Reset for case 1
    z1 = backdoor.parity_kernel(M_ρ, θ, x)[:θ]
    print(f"z = {z1}")

    # Case 2: γ = 8 (add noise)
    print("\nCase 2: γ = 8 (add 2 noise positions)")
    # We need to encode into 8 coefficients total
    # First 6 are message, last 2 are noise
    z2_msg = backdoor.parity_kernel(M_ρ, θ, x)[:8]
    print(f"z = {z2_msg} (last 2 are noise)")

    # Case 3: γ = 4 (split message)
    print("\nCase 3: γ = 4 (split message)")
    M1 = M_bits[:4]
    M2 = M_bits[4:]
    ρ1 = ρ[:4]
    ρ2 = ρ[4:]

    print(f"Split: M1 = {M1}, M2 = {M2}")
    print(f"       ρ1 = {ρ1}, ρ2 = {ρ2}")

    # Encode M1
    M1_ρ = [m ^ r for m, r in zip(M1, ρ1)]
    z3_msg = backdoor.parity_kernel(M1_ρ, 4, x)[:4]
    print(f"z1 (M1 encoded): {z3_msg}")

    # Store indexes of 1s in M2
    indexes = [i for i, bit in enumerate(M2) if bit == 1]
    print(f"Indexes of 1s in M2: {indexes}")

# ============================================================================
# COMPLETE ATTACK DEMONSTRATION WITH REJECTION LOOP
# ============================================================================

def demonstrate_complete_attack_with_rejection():
    """
    Demonstrate complete attack with rejection sampling loop
    """
    print("\n" + "=" * 80)
    print("COMPLETE ATTACK WITH REJECTION SAMPLING")
    print("=" * 80)

    # Use realistic parameters
    params = DilithiumParams()
    sk = b"real_secret_key_123"
    backdoor = DilithiumBackdoorPaperExact(params, sk)

    # Test message
    test_message = "Confidential: Launch code Alpha-7"
    print(f"Original message: '{test_message}'")

    # Convert to bits
    M_bits = []
    for char in test_message:
        M_bits.extend([(ord(char) >> i) & 1 for i in range(8)])

    print(f"Message length: {len(M_bits)} bits")
    print(f"Max encodable (255·k-1): {255 * params.k - 1} bits")

    # Run attack with rejection loop
    print("\n[Phase 1: Running kleptographic attack with rejection sampling]")
    print("This may require multiple attempts due to rejection conditions...")

    z_matrix, c, encoding_info = backdoor.kleptographic_attack(
        M_bits, max_attempts=50
    )

    print(f"\nAttack completed in {encoding_info['attempts']} attempt(s)")
    print(f"Generated signature:")
    print(f"  z shape: {z_matrix.shape}")
    print(f"  x: {encoding_info['x']}")
    print(f"  θ: {encoding_info['theta']}")

    # Attacker extracts message
    print("\n[Phase 2: Extracting hidden message]")
    extracted_bits = backdoor.decode_message(z_matrix, encoding_info)

    # Convert back to string
    extracted_chars = []
    for i in range(0, len(extracted_bits), 8):
        char_bits = extracted_bits[i:i+8]
        if len(char_bits) == 8:
            char_val = sum(bit << j for j, bit in enumerate(char_bits))
            if 32 <= char_val <= 126:  # Printable ASCII
                extracted_chars.append(chr(char_val))
            else:
                extracted_chars.append('?')

    extracted_message = ''.join(extracted_chars)

    # Compare
    original_str = test_message
    extracted_str = extracted_message[:len(original_str)]

    print(f"Extracted message: '{extracted_str}'")

    if extracted_str == original_str:
        print(" ATTACK SUCCESSFUL: Message perfectly recovered!")
    else:
        # Calculate accuracy
        matching = sum(1 for a, b in zip(original_str, extracted_str) if a == b)
        accuracy = (matching / len(original_str)) * 100
        print(f"  Partial success: {accuracy:.1f}% accuracy")

        if accuracy < 100:
            print("\nCharacter differences:")
            for i, (orig, extr) in enumerate(zip(original_str, extracted_str)):
                if orig != extr:
                    print(f"  Pos {i}: '{orig}' → '{extr}'")

# ============================================================================
# TEST DIFFERENT MESSAGE LENGTHS
# ============================================================================

def test_different_lengths():
    """
    Test the attack with messages of different lengths
    """
    print("\n" + "=" * 80)
    print("TESTING DIFFERENT MESSAGE LENGTHS")
    print("=" * 80)

    params = DilithiumParams(k=4)
    sk = b"test_key"
    backdoor = DilithiumBackdoorPaperExact(params, sk)

    test_cases = [
        ("Short", "Hi", 16),  # 16 bits
        ("Medium", "Test message", 96),  # 96 bits
        ("Long", "This is a longer message to test the encoding", 352),  # 352 bits
    ]

    for name, message, expected_bits in test_cases:
        print(f"\n{name} message: '{message}'")

        # Convert to bits
        M_bits = []
        for char in message:
            M_bits.extend([(ord(char) >> i) & 1 for i in range(8)])

        print(f"  Length: {len(M_bits)} bits")

        # Check if fits in 255·k-1
        max_capacity = 255 * params.k - 1
        if len(M_bits) <= max_capacity:
            print(f"  Fits in capacity: ✓ (capacity: {max_capacity})")
        else:
            print(f"  Exceeds capacity: ✗ (will use split encoding)")

        # Encode
        z_coeffs, x, θ, ρ = backdoor.parity(M_bits)

        # Decode
        z_matrix = backdoor._reshape_to_matrix(z_coeffs)
        encoding_info = {'x': x, 'theta': θ, 'rho': ρ}
        extracted_bits = backdoor.decode_message(z_matrix, encoding_info)

        # Check
        if extracted_bits[:len(M_bits)] == M_bits:
            print(f"  Decoding: ✓ Perfect")
        else:
            errors = sum(1 for a, b in zip(extracted_bits[:len(M_bits)], M_bits) if a != b)
            print(f"  Decoding: ✗ {errors} errors")

# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    print("=" * 80)
    print("KLEPTOGRAPHIC BACKDOOR ATTACK ON CRYSTALS-DILITHIUM")
    print("Complete Implementation with All Paper Corrections")
    print("=" * 80)

    # 1. Run exact paper example with all fixes
    run_paper_example_with_fixes()

    # 2. Demonstrate complete attack with rejection sampling
    demonstrate_complete_attack_with_rejection()

    # 3. Test different message lengths
    test_different_lengths()



KLEPTOGRAPHIC BACKDOOR ATTACK ON CRYSTALS-DILITHIUM
Complete Implementation with All Paper Corrections
EXACT PAPER EXAMPLE
Paper Example 1 Parameters:
M_bits: [0, 1, 0, 1, 1, 0]
x: 1 (0 → even, 1 → odd)
ρ: [1, 1, 1, 0, 0, 0]
θ: 6
M_ρ = M_bits ⊕ ρ: [1, 0, 1, 1, 1, 0]

Encoding with ParityKernel:
z coefficients (first 6): [4, 1, 2, 8, 8, 5]

Verification of parity encoding:
  z[0] = 4 (even) | Expected: odd ✗
  z[1] = 1 (odd) | Expected: even ✗
  z[2] = 2 (even) | Expected: odd ✗
  z[3] = 8 (even) | Expected: odd ✗
  z[4] = 8 (even) | Expected: odd ✗
  z[5] = 5 (odd) | Expected: even ✗

Decoding verification:
Decoded M_ρ: [1, 0, 1, 1, 1, 0]
Decoded M_bits: [0, 1, 0, 1, 1, 0]
 SUCCESS: Perfect match with paper example!

--------------------------------------------------------------------------------
All Three Cases from Paper:

Case 1: γ = 6 (message fits exactly)
z = [6, 7, 8, 2, 6, 3]

Case 2: γ = 8 (add 2 noise positions)
z = [2, 7, 8, 2, 2, 3, 0, 2] (last 2 are noise)

Case 3: γ = 4 (