<a href="https://colab.research.google.com/github/owlmt/PQC/blob/main/mldsa_backdoor_experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://link.springer.com/chapter/10.1007/978-3-031-83885-9_30

In [14]:
import random
import numpy as np
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import hashlib

# =============================================================================
# DILITHIUM PARAMETERS (minimal set needed for Algorithms 4–6 demonstration)
# =============================================================================

@dataclass
class DilithiumParams:
    q: int = 8380417
    n: int = 256
    k: int = 4
    gamma1: int = (8380417 - 1) // 16
    gamma2: int = (8380417 - 1) // 88
    beta: int = 39
    tau: int = 60


# =============================================================================
# PARITY-BASED BACKDOOR (Hybrid implementation)
# =============================================================================

class DilithiumParityBackdoorHybrid:
    """
    Hybrid implementation of the parity-based backdoor from Algorithms 4–6.
    """

    def __init__(self, params: Optional[DilithiumParams] = None):
        self.params = params or DilithiumParams()
        self.total_coeffs = self.params.n * self.params.k  # 256k coefficients

    # -------------------------------------------------------------------------
    # Helpers for even/odd sampling
    # -------------------------------------------------------------------------

    def _sample_even(self) -> int:
        q = self.params.q
        x = random.randrange(0, q // 2 + 1)
        return (2 * x) % q

    def _sample_odd(self) -> int:
        q = self.params.q
        x = random.randrange(0, q // 2)
        return (2 * x + 1) % q

    # -------------------------------------------------------------------------
    # Algorithm 6: ParityKernel
    # -------------------------------------------------------------------------

    def parity_kernel(self, M_rho: List[int], theta: int, x: int,
                      target_len: int) -> List[int]:
        q = self.params.q
        k = self.params.k

        z = []

        for i in range(theta):
            want_parity = M_rho[i] ^ x
            coeff = self._sample_odd() if want_parity == 1 else self._sample_even()
            z.append(coeff % q)

        msg_cap = 255 * k - 1
        if theta <= msg_cap:
            noise_len = msg_cap - theta + 1
            for _ in range(noise_len):
                z.append(random.randrange(0, q))

        return z

    # -------------------------------------------------------------------------
    # Algorithm 5: Parity
    # -------------------------------------------------------------------------

    def parity(self, M_bits: List[int]):
        theta = len(M_bits)
        k = self.params.k
        msg_cap = 255 * k - 1

        x = random.randint(0, 1)
        rho = [random.randint(0, 1) for _ in range(theta)]
        M_rho = [m ^ r for m, r in zip(M_bits, rho)]

        if theta <= msg_cap:
            z_partial = self.parity_kernel(M_rho, theta, x, target_len=msg_cap + 1)
            z = self._rebuild_full(z_partial)
            return z, x, theta, rho, None

        else:
            M1 = M_rho[:msg_cap + 1]
            M2 = M_rho[msg_cap + 1:]

            z1 = self.parity_kernel(M1, len(M1), x, target_len=msg_cap + 1)

            z2_indexes = [i for i, bit in enumerate(M2) if bit == 1]

            z = self._embed_z2_into_z1(z1, z2_indexes)

            return z, x, theta, rho, z2_indexes

    # -------------------------------------------------------------------------
    # Hybrid z2 embedding
    # -------------------------------------------------------------------------

    def _embed_z2_into_z1(self, z1: List[int], z2_indexes: List[int]):
        q = self.params.q
        full_len = self.total_coeffs

        if len(z1) >= full_len:
            return self._rebuild_full(z1)

        remaining = full_len - len(z1)
        if remaining < len(z2_indexes):
            return self._rebuild_full(z1)

        embedding = [(idx + 2) % q for idx in z2_indexes]

        noise_len = remaining - len(embedding)
        noise = [random.randrange(0, q) for _ in range(noise_len)]

        z = z1 + embedding + noise
        return self._rebuild_full(z)

    # -------------------------------------------------------------------------
    # Final rebuild to size 256k
    # -------------------------------------------------------------------------

    def _rebuild_full(self, z_partial: List[int]) -> List[int]:
        q = self.params.q
        out = z_partial[:]

        if len(out) < self.total_coeffs:
            pad_len = self.total_coeffs - len(out)
            out.extend(random.randrange(0, q) for _ in range(pad_len))
        else:
            out = out[:self.total_coeffs]

        return out

    # -------------------------------------------------------------------------
    # Algorithm 4 (simulated)
    # -------------------------------------------------------------------------

    def kleptographic_sign(self, M_bits: List[int], max_attempts: int = 50):
        attempt = 0
        z = None
        x = theta = rho = None
        z2_indexes = None

        while attempt < max_attempts:
            attempt += 1
            y = self._generate_y()
            w1 = self._simulate_highbits(y)
            chash = self._hash_message(M_bits, w1)
            c = self._hash_to_challenge(chash)
            z, x, theta, rho, z2_indexes = self.parity(M_bits)

            if self._accept_signature(z, y):
                break

        z_matrix = np.array(z, dtype=np.int64).reshape(self.params.k, self.params.n)

        info = {
            "x": x,
            "theta": theta,
            "rho": rho,
            "z2_indexes": z2_indexes,
            "attempts": attempt
        }
        return z_matrix, c, info

    # -------------------------------------------------------------------------
    # Decoder
    # -------------------------------------------------------------------------

    def decode_message(self, z_matrix: np.ndarray, info: Dict) -> List[int]:
        x = info["x"]
        theta = info["theta"]
        rho = info["rho"]
        z2_indexes = info["z2_indexes"]

        z_flat = z_matrix.flatten().tolist()
        k = self.params.k
        msg_cap = 255 * k - 1

        M1_rho = []
        for i in range(min(theta, msg_cap + 1)):
            parity = z_flat[i] & 1
            M1_rho.append(parity ^ x)

        if theta <= msg_cap:
            M_rho = M1_rho
        else:
            M2_len = theta - (msg_cap + 1)
            M2_rho = [0] * M2_len

            if z2_indexes is not None:
                for idx in z2_indexes:
                    if idx < M2_len:
                        M2_rho[idx] = 1
            else:
                tail = z_flat[msg_cap + 1:]
                for coeff in tail:
                    idx = (coeff - 2)
                    if 0 <= idx < M2_len:
                        M2_rho[idx] = 1

            M_rho = M1_rho + M2_rho

        M_bits = [mr ^ r for mr, r in zip(M_rho, rho)]
        return M_bits[:theta]

    # -------------------------------------------------------------------------
    # Simulated Dilithium internals
    # -------------------------------------------------------------------------

    def _generate_y(self):
        g1 = self.params.gamma1
        return [random.randint(-(g1 - 1), g1 - 1)
                for _ in range(self.params.k * self.params.n)]

    def _simulate_highbits(self, y):
        return hashlib.sha256(str(y).encode()).digest()[:16]

    def _hash_message(self, M_bits, w1):
        return hashlib.sha256(self._bits_to_bytes(M_bits) + w1).digest()

    def _hash_to_challenge(self, h):
        return [1 if (b & 1) else -1 for b in h[: self.params.tau]]

    def _accept_signature(self, z, y):
        g1 = self.params.gamma1
        beta = self.params.beta

        centered = [abs(v - self.params.q // 2) for v in z]
        cond1 = max(centered) < (g1 - beta)
        cond2 = random.random() < 0.9
        return cond1 and cond2

    @staticmethod
    def _bits_to_bytes(bits):
        out = bytearray()
        for i in range(0, len(bits), 8):
            v = 0
            for j in range(8):
                if i + j < len(bits):
                    v |= (bits[i + j] & 1) << j
            out.append(v)
        return bytes(out)


# =============================================================================
# TEST UTILITIES (needed by run_test)
# =============================================================================

def to_bits(msg: str):
    bits = []
    for ch in msg:
        v = ord(ch)
        for i in range(8):
            bits.append((v >> i) & 1)
    return bits


def from_bits(bits):
    chars = []
    for i in range(0, len(bits), 8):
        block = bits[i:i+8]
        if len(block) < 8:
            break
        v = sum((block[j] & 1) << j for j in range(8))
        chars.append(chr(v) if 32 <= v <= 126 else '?')
    return ''.join(chars)


def bits_to_hex_preview(bits, max_bytes=32):
    """Convert first few bytes of bits to hex string for debugging"""
    hex_chars = []
    byte_count = min(max_bytes, len(bits) // 8)

    for i in range(0, byte_count * 8, 8):
        byte_val = 0
        for j in range(8):
            if i + j < len(bits):
                byte_val |= (bits[i + j] & 1) << j
        hex_chars.append(f"{byte_val:02x}")

    preview = ' '.join(hex_chars)
    if len(bits) // 8 > max_bytes:
        preview += " ..."
    return preview


def run_test(name: str, M_bits, bd: DilithiumParityBackdoorHybrid):
    print(f"\n=== TEST: {name} ===")
    print(f"Message length: {len(M_bits)} bits")

    z_matrix, c, info = bd.kleptographic_sign(M_bits)
    recovered_bits = bd.decode_message(z_matrix, info)

    ok = recovered_bits == M_bits
    print(f"Recovered correct: {ok}")
    print(f"Attempts: {info['attempts']}")

    if info["z2_indexes"] is not None:
        print(f"Academic z2 metadata used (indexes={len(info['z2_indexes'])})")
    else:
        print("z2 fully embedded inside z")

    # HEX preview for debugging
    print(f"Original hex  : {bits_to_hex_preview(M_bits)}")
    print(f"Recovered hex : {bits_to_hex_preview(recovered_bits)}")

    return ok


# =============================================================================
# FULL TEST HARNESS (minimal, edge, split, maximal)
# =============================================================================

if __name__ == "__main__":

    params = DilithiumParams()
    bd = DilithiumParityBackdoorHybrid(params)

    k = params.k
    cap = 255 * k - 1
    total = params.k * params.n

    print("\n==============================================")
    print(" RUNNING HYBRID DILITHIUM PARITY BACKDOOR TESTS")
    print("==============================================")

    # 1. Minimal message
    run_test("Minimal message (1 bit)", [1], bd)

    # 2. Short-message edge
    M_bits_edge = [random.randint(0, 1) for _ in range(cap)]
    run_test(f"Short-message edge (θ = {cap})", M_bits_edge, bd)

    # 3. Split threshold exact
    M_bits_split_exact = [random.randint(0, 1) for _ in range(cap + 1)]
    run_test(f"Split threshold (θ = {cap + 1})", M_bits_split_exact, bd)

    # 4. Large message
    M_bits_large = [random.randint(0, 1) for _ in range(4096)]
    run_test("Moderately large message (4096 bits)", M_bits_large, bd)

    # 5. Maximal embeddable message
    embed_capacity = total - (cap + 1)
    max_embed_msg = (cap + 1) + embed_capacity
    M_bits_max_embed = [random.randint(0, 1) for _ in range(max_embed_msg)]
    run_test(
        f"Maximal embed message (θ ≈ {max_embed_msg})",
        M_bits_max_embed, bd
    )

    # 6. Forced academic mode
    M_bits_force_academic = [random.randint(0, 1)
                             for _ in range(max_embed_msg + 100)]
    run_test(
        "Message exceeding embedding-capacity (forces academic z2 metadata)",
        M_bits_force_academic, bd
    )


 RUNNING HYBRID DILITHIUM PARITY BACKDOOR TESTS

=== TEST: Minimal message (1 bit) ===
Message length: 1 bits
Recovered correct: True
Attempts: 50
z2 fully embedded inside z
Original hex  : 
Recovered hex : 

=== TEST: Short-message edge (θ = 1019) ===
Message length: 1019 bits
Recovered correct: True
Attempts: 50
z2 fully embedded inside z
Original hex  : f9 ae b4 ca 74 3a fa 61 f0 52 16 04 bb 96 c9 e1 6c 85 d6 e1 bf 96 fe 00 de 1d 63 60 b3 1b fb 64 ...
Recovered hex : f9 ae b4 ca 74 3a fa 61 f0 52 16 04 bb 96 c9 e1 6c 85 d6 e1 bf 96 fe 00 de 1d 63 60 b3 1b fb 64 ...

=== TEST: Split threshold (θ = 1020) ===
Message length: 1020 bits
Recovered correct: True
Attempts: 50
Academic z2 metadata used (indexes=0)
Original hex  : 4e ed 49 0d fa a4 8a db 4f 5a bb e0 59 f2 44 0b ca dd 5b 36 e8 4a 2c cd e6 12 91 33 2b d6 a8 a1 ...
Recovered hex : 4e ed 49 0d fa a4 8a db 4f 5a bb e0 59 f2 44 0b ca dd 5b 36 e8 4a 2c cd e6 12 91 33 2b d6 a8 a1 ...

=== TEST: Moderately large message (4096 bits) =