# Key-Recovery Plaintext-Checking Attack on Kyber
$m[j] = \texttt{Compress}(1, v[j] - s_i[j] \cdot u_i[j])$

Depending on the value of $s_i[j]$ (the $j^\text{th}$ coefficient of the $i^\text{th}$ polynomial in the secret key $\mathbf{s}\in R^k$), the range of $v[j]$ for which $m[j]$ is $1$ will change, so depending on the response for some chosen values of $v[j]$, we can pin down the value of $s_i[j]$. We then only need to repeat for all $1 \leq i \leq k, 1 \leq j \leq n$.
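
As a worked restatement of why the window moves (a sketch, assuming the round-to-nearest definition of $\texttt{Compress}$ and $q = 3329$; the exact endpoints depend on how the target implementation rounds):

$$\texttt{Compress}(1, x) = 1 \iff \frac{q}{4} \lesssim x \lesssim \frac{3q}{4}$$

$$m[j] = 1 \iff \left(v[j] - s_i[j] \cdot u_i[j]\right) \bmod q \in \left[\tfrac{q}{4}, \tfrac{3q}{4}\right] \text{ (approximately)}$$

Each unit increase in $s_i[j]$ therefore shifts the window of $v[j]$ values that decrypt to $1$ by $u_i[j]$, and probing near the edges of the window reveals the size of the shift.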

## Compression and decompression
```rust
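pub const KYBER_Q: u32 = 3329;
/// ceil(q / 2), the rounding offset used by `compress`.
pub const HALF_Q: u32 = 1665;

/// Maps a signed coefficient in (-q, q) to its representative in [0, q).
pub fn to_positive_repr(mut val: i16) -> u32 {
    // `val >> 15` is all ones for negative inputs, so q is added only then.
    val += (val >> 15) & (KYBER_Q as i16);
    val as u32
}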

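/// Integer version of Compress_d for `val` in [0, q).
/// Division by q is replaced by a multiply-and-shift with 80635 = floor(2^28 / q).
/// Only called with d = 1 here; for d >= 4 the product can overflow a u32.
pub fn compress(d: usize, mut val: u32) -> u32 {
    val <<= d;
    val += HALF_Q;
    val *= 80635;
    val >>= 28;
    val & (u32::MAX >> (32 - d))
}

/// Decompress_d(val) = round(q * val / 2^d) for `val` in [0, 2^d).
pub fn decompress(d: usize, mut val: u32) -> u32 {
    val *= KYBER_Q;
    val += 1 << (d - 1);
    val >>= d;
    val
}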

fn main() {
    // ML-KEM-512
    let (du, dv, eta1) = (10usize, 4usize, 3i16);
    // ML-KEM-768
    // let (du, dv, eta1) = (10usize, 4usize, 2i16);
    // ML-KEM-1024
    // let (du, dv, eta1) = (11usize, 5usize, 2i16);
    let compressed_u = 1<<5;
    let u = decompress(du, compressed_u);
    
    for s in -eta1..=eta1 {
        println!("s: {s}");
        let s_unsigned = to_positive_repr(s);
        for compressed_v in 0..(1<<dv) {
            let v = decompress(dv, compressed_v);
            let decryption = compress(
                1, 
                (v + KYBER_Q - (s_unsigned * u % KYBER_Q)) % KYBER_Q
            );
            println!("compressed_v: {compressed_v}, m: {decryption}");
        }
    }
}
```

```python
KYBER_Q = 3329
KYBER_CEIL_HALF_Q = 1665

def compress(d: int, num: int) -> int:
    """Compress_d, Eq. (4.7) of FIPS 203"""
    num <<= d
    num += KYBER_CEIL_HALF_Q
    num //= KYBER_Q

    return num % (1 << d)

def decompress(d: int, num: int) -> int:
    """Decompress_d, Eq. (4.8) of FIPS 203"""
    num *= KYBER_Q
    num += (1 << (d - 1))
    num >>= d
    return num

du, dv, eta1 = 10, 4, 3  # ML-KEM-512
# du, dv, eta1 = 10, 4, 2  # ML-KEM-768
# du, dv, eta1 = 11, 5, 2  # ML-KEM-1024
```
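
As a sanity check, the integer routine can be compared against the exact round-half-up definition using rational arithmetic. The sketch below is illustrative only: the helper names `compress_exact` and `compress_with_offset` are hypothetical and not part of the original notebook. It suggests that, for $d = 1$, adding $\lceil q/2 \rceil = 1665$ before an exact floor division disagrees with exact rounding at a single input, $x = 832$ (where $2x/q \approx 0.4998$), while the offset $\lfloor q/2 \rfloor = 1664$ agrees everywhere.

```python
from fractions import Fraction
import math

KYBER_Q = 3329

def compress_exact(d: int, x: int) -> int:
    """Round-half-up of (2^d / q) * x, reduced mod 2^d, using exact rationals."""
    scaled = Fraction((1 << d) * x, KYBER_Q)
    return math.floor(scaled + Fraction(1, 2)) % (1 << d)

def compress_with_offset(d: int, x: int, offset: int) -> int:
    """Integer routine: floor((2^d * x + offset) / q) mod 2^d."""
    return (((x << d) + offset) // KYBER_Q) % (1 << d)

# Inputs in [0, q) where each integer variant differs from exact rounding (d = 1).
mismatch_ceil = [
    x for x in range(KYBER_Q)
    if compress_exact(1, x) != compress_with_offset(1, x, 1665)
]
mismatch_floor = [
    x for x in range(KYBER_Q)
    if compress_exact(1, x) != compress_with_offset(1, x, 1664)
]
print("offset 1665 mismatches:", mismatch_ceil)
print("offset 1664 mismatches:", mismatch_floor)
```

The boundary is relevant here because $v[j] - s_i[j] \cdot u_i[j]$ can land exactly on $832$, so the responses at the edge of each window may depend on which rounding variant the target implementation uses.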

For a given choice of $u_i[j]$ and $v[j]$, the value of $\hat{m}[j] \leftarrow \texttt{compress}_{d=1}(v[j] - s_i[j]\cdot u_i[j])$ depends on the value of $s_i[j]$, which is how we can pinpoint a specific coefficient of the secret key. The strategy looks something like this:

- Given some $u_i[j]$, the range of $v[j]$ such that $\hat{m}[j]$ is $1$ is "this to that" if $s_i[j] = 0$, is "this to that" if $s_i[j] = 1$, etc.

```python
# fix some value for u
compressed_u = 1 << 6
u = decompress(du, compressed_u)

for compressed_v in range(1 << dv):
    v = decompress(dv, compressed_v)
    pattern = []
    for s_ij in range(-eta1, eta1 + 1):
        m = compress(1, (v + KYBER_Q - s_ij * u) % KYBER_Q)
        pattern.append(m)
    print(f"compressed_v: {compressed_v:05b} pattern: {pattern}")
```

```text
compressed_v: 00000 pattern: [0, 0, 0, 0, 0, 0, 0]
compressed_v: 00001 pattern: [1, 0, 0, 0, 0, 0, 0]
compressed_v: 00010 pattern: [1, 1, 0, 0, 0, 0, 0]
compressed_v: 00011 pattern: [1, 1, 1, 0, 0, 0, 0]
compressed_v: 00100 pattern: [1, 1, 1, 1, 0, 0, 0]
compressed_v: 00101 pattern: [1, 1, 1, 1, 1, 0, 0]
compressed_v: 00110 pattern: [1, 1, 1, 1, 1, 1, 0]
compressed_v: 00111 pattern: [1, 1, 1, 1, 1, 1, 1]
compressed_v: 01000 pattern: [1, 1, 1, 1, 1, 1, 1]
compressed_v: 01001 pattern: [0, 1, 1, 1, 1, 1, 1]
compressed_v: 01010 pattern: [0, 0, 1, 1, 1, 1, 1]
compressed_v: 01011 pattern: [0, 0, 0, 1, 1, 1, 1]
compressed_v: 01100 pattern: [0, 0, 0, 0, 1, 1, 1]
compressed_v: 01101 pattern: [0, 0, 0, 0, 0, 1, 1]
compressed_v: 01110 pattern: [0, 0, 0, 0, 0, 0, 1]
compressed_v: 01111 pattern: [0, 0, 0, 0, 0, 0, 0]
```


Using the snippet above, it is easy to develop the concrete strategy for finding $s_i[j]$ (for ML-KEM-512) using up to three queries:

![ML-KEM-512 KR-PCA queries](./kyber512-ky-pca-queries.png)
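
One way to turn the bit patterns into an adaptive query procedure is sketched below. This is a minimal illustration under stated assumptions, not necessarily the decision tree shown in the figure: the greedy "most even split" rule, the `predicted_bit` and `recover` helpers, and the simulated oracle are all hypothetical choices made for this example. `compress` and `decompress` repeat the notebook's integer routines so that the block is self-contained.

```python
KYBER_Q = 3329

def compress(d: int, num: int) -> int:
    # Same integer routine as the notebook cell (offset 1665).
    return (((num << d) + 1665) // KYBER_Q) % (1 << d)

def decompress(d: int, num: int) -> int:
    return (num * KYBER_Q + (1 << (d - 1))) >> d

du, dv, eta1 = 10, 4, 3  # ML-KEM-512
u = decompress(du, 1 << 6)
candidates = list(range(-eta1, eta1 + 1))

def predicted_bit(compressed_v: int, s: int) -> int:
    """Bit the decryption would produce if the secret coefficient were s."""
    v = decompress(dv, compressed_v)
    return compress(1, (v - s * u) % KYBER_Q)

def recover(oracle):
    """Narrow down s adaptively; oracle(compressed_v) returns the decrypted bit."""
    remaining = list(candidates)
    queries = 0
    while len(remaining) > 1:
        # Choose the compressed_v whose predicted bits split `remaining` most evenly.
        def worst_case(cv: int) -> int:
            ones = sum(predicted_bit(cv, s) for s in remaining)
            return max(ones, len(remaining) - ones)
        cv = min(range(1 << dv), key=worst_case)
        bit = oracle(cv)
        queries += 1
        remaining = [s for s in remaining if predicted_bit(cv, s) == bit]
    return remaining[0], queries

# Simulated oracle: answers as a decryption holding the true coefficient would.
results = {
    s_true: recover(lambda cv, s_true=s_true: predicted_bit(cv, s_true))
    for s_true in candidates
}
for s_true, (guess, queries) in results.items():
    print(f"s = {s_true:2d}: recovered {guess:2d} in {queries} queries")
```

In this simulation every candidate in $[-\eta_1, \eta_1]$ is recovered with at most three queries, which matches the bound stated above. Against a real target the oracle would be the plaintext-checking response rather than `predicted_bit`.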

There are $k \times n$ secret coefficients, and each takes at most $\lceil \log_2(2\eta_1 + 1) \rceil$ queries, so at most $\lceil \log_2(2\eta_1 + 1) \rceil \cdot k \cdot n$ decryption queries are needed to recover the entire secret key:

|parameter set|$k$|$n$|$\eta_1$|decryption queries needed|
|:--|:--|:--|:--|:--|
|ML-KEM-512|2|256|3|1536|
|ML-KEM-768|3|256|2|2304|
|ML-KEM-1024|4|256|2|3072|
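
The counts in the table can be reproduced with a few lines of arithmetic. The sketch below assumes the standard parameters ($k = 2, 3, 4$ for ML-KEM-512, ML-KEM-768 and ML-KEM-1024, with $n = 256$); the dictionary layout and variable names are illustrative.

```python
# (k, eta1) per parameter set; n is the polynomial degree.
PARAMS = {
    "ML-KEM-512": (2, 3),
    "ML-KEM-768": (3, 2),
    "ML-KEM-1024": (4, 2),
}
N = 256

query_bounds = {}
for name, (k, eta1) in PARAMS.items():
    num_values = 2 * eta1 + 1                    # possible values of one coefficient
    per_coeff = (num_values - 1).bit_length()    # ceil(log2(num_values)) without floats
    query_bounds[name] = per_coeff * k * N
    print(f"{name}: {per_coeff} queries/coefficient, {query_bounds[name]} total")
```

Because $\lceil \log_2 5 \rceil = \lceil \log_2 7 \rceil = 3$, all three parameter sets need at most three queries per coefficient, and the total scales only with $k$.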