# Post-Quantum Key Establishment with ML-KEM (Kyber)

This notebook provides a practical, hands-on introduction to the **ML-KEM** post-quantum key
encapsulation mechanism (formerly known as Kyber), including:

- Why post-quantum cryptography (PQC) matters
- A high-level view of the ML-KEM algorithm
- A working Python example using the `pyoqs` library
- **Step-by-step sender/receiver key sharing** with detailed explanations
- Simple performance measurements
- A light mathematical appendix (Module-LWE, keygen, encaps/decaps)

> **Note:** This notebook uses the `pyoqs` library, which provides Python bindings for the
> Open Quantum Safe (OQS) library. Make sure `pyoqs` is installed before running the cells,
> for example with `conda install pyoqs` in your activated environment.

## 1. Environment Check

In this section we check the Python version and verify that `pyoqs` is installed and that
the ML-KEM KEM algorithms are available.

In [None]:
import sys
print("Python version:", sys.version)

try:
    import oqs
    # oqs_version() reports the version of the underlying liboqs C library
    print("liboqs version:", oqs.oqs_version())
    
    # Check what methods are available in the oqs module
    print("Available oqs methods:")
    available_methods = [attr for attr in dir(oqs) if not attr.startswith('_') and callable(getattr(oqs, attr))]
    for method in available_methods:
        print(f"  - {method}")
    
    # Try different possible method names for getting KEM mechanisms
    kems = None
    if hasattr(oqs, 'get_enabled_KEM_mechanisms'):
        kems = oqs.get_enabled_KEM_mechanisms()
    elif hasattr(oqs, 'get_supported_kem_mechanisms'):
        kems = oqs.get_supported_kem_mechanisms()
    elif hasattr(oqs, 'supported_kems'):
        kems = oqs.supported_kems()
    else:
        print("Could not find method to list KEM mechanisms. Let's try to create Kyber768 directly.")
    
    if kems:
        mlkem_algorithms = [kem for kem in kems if 'Kyber' in kem or 'ML-KEM' in kem]
        
        print(f"\nFound {len(kems)} total KEM algorithms")
        print("ML-KEM/Kyber algorithms:")
        for alg in mlkem_algorithms:
            print(f"  - {alg}")
        
        if not mlkem_algorithms:
            print("Warning: No ML-KEM/Kyber algorithms found!")
            print("First few available KEM algorithms:")
            for kem in kems[:5]:
                print(f"  - {kem}")
    
    # Test creating a Kyber768 KEM object
    try:
        kem_test = oqs.KeyEncapsulation("Kyber768")
        print("✅ Kyber768 (ML-KEM-768) is available and ready to use.")
    except Exception as e:
        print(f"❌ Error creating Kyber768 KEM: {e}")
        # Try other possible Kyber variants
        for variant in ["Kyber512", "Kyber1024", "ML-KEM-512", "ML-KEM-768", "ML-KEM-1024"]:
            try:
                oqs.KeyEncapsulation(variant)
                print(f"✅ {variant} is available")
                break
            except Exception:
                continue
        else:
            print("❌ No Kyber/ML-KEM variants found")
        
except ImportError as e:
    print("[ERROR] pyoqs library is not installed.")
    print("        Please install it first (see the note at the top of this notebook).")
    raise
except Exception as e:
    print("[ERROR] An unexpected error occurred while importing pyoqs:", e)
    raise

## 2. Why Post-Quantum Cryptography?

Classical public-key cryptography is built on problems like:

- Integer factorization (RSA)
- Discrete logarithms over finite fields or elliptic curves (DH, ECDH)

With large-scale quantum computers and algorithms like **Shor's algorithm**, these problems
become solvable in polynomial time. That means a sufficiently powerful quantum computer could
break RSA and ECC.

**Post-Quantum Cryptography (PQC)** aims to design schemes that remain secure even against
quantum adversaries. ML-KEM, standardized by NIST as FIPS 203, is one of the main building
blocks for post-quantum key establishment.


## 3. High-Level ML-KEM Overview

ML-KEM is a **Key Encapsulation Mechanism (KEM)**. Its API is conceptually:

1. `KeyGen() → (pk, sk)`  
2. `Encaps(pk) → (ct, ss)`  
3. `Decaps(ct, sk) → ss`  

Where:

- `pk`: public key
- `sk`: secret key
- `ct`: ciphertext (encapsulation)
- `ss`: shared secret

The shared secret `ss` is then used as input to a **Key Derivation Function (KDF)** to produce
symmetric keys (e.g., for AES-GCM or ChaCha20-Poly1305).

At an abstract level, ML-KEM is based on the hardness of a **Module-LWE** problem over
polynomial rings. The details are in the math appendix at the end of this notebook.
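
To make the shape of the three-function API concrete, here is a minimal sketch using only the standard library. It is **not** ML-KEM and **not** post-quantum: it is a toy Diffie-Hellman-style KEM over a small prime, chosen only because it fits in a few lines. The names `keygen`, `encaps`, `decaps` and the parameters `P` and `G` are illustrative assumptions, not part of `pyoqs`.

```python
import hashlib
import secrets

# Toy parameters (assumption for illustration): a 127-bit Mersenne prime and a small base.
# Far too small to be secure, and breakable by Shor's algorithm in any case.
P = 2**127 - 1
G = 3

def keygen():
    """KeyGen() -> (pk, sk)."""
    sk = secrets.randbelow(P - 2) + 1
    pk = pow(G, sk, P)
    return pk, sk

def encaps(pk):
    """Encaps(pk) -> (ct, ss): the sender needs only the public key."""
    y = secrets.randbelow(P - 2) + 1
    ct = pow(G, y, P)
    ss = hashlib.sha256(pow(pk, y, P).to_bytes(16, "big")).digest()
    return ct, ss

def decaps(ct, sk):
    """Decaps(ct, sk) -> ss: the receiver needs the ciphertext and the secret key."""
    return hashlib.sha256(pow(ct, sk, P).to_bytes(16, "big")).digest()

pk, sk = keygen()
ct, ss_sender = encaps(pk)
ss_receiver = decaps(ct, sk)
print("Shared secrets match:", ss_sender == ss_receiver)
```

ML-KEM exposes the same three operations; only the mathematics behind them changes.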


## 4. ML-KEM in Python with `pyoqs`

In this section we use the `pyoqs` library to:

1. Generate an ML-KEM-768 key pair (using Kyber768 from pyoqs)
2. Encapsulate a shared secret using the public key
3. Decapsulate the ciphertext using the secret key
4. Verify that both sides derive the same shared secret

In [None]:
import oqs

# 1. Key Generation
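# We use Kyber768, the CRYSTALS-Kyber parameter set on which ML-KEM-768 is based.
# Newer liboqs builds may expose the standardized variant under the name "ML-KEM-768".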
kem = oqs.KeyEncapsulation("Kyber768")
public_key_bytes = kem.generate_keypair()
print("Generated ML-KEM-768 (Kyber768) key pair.")

# Show available methods on the KEM object to debug
print("Available KEM methods:", [method for method in dir(kem) if not method.startswith('_')])

# 2. Encapsulation - use the correct pyoqs API
try:
    # Create a new KEM instance for the sender (encapsulation)
    sender_kem = oqs.KeyEncapsulation("Kyber768")
    
    # encap_secret takes the receiver's public key and returns (ciphertext, shared_secret)
    ciphertext, shared_secret_enc = sender_kem.encap_secret(public_key_bytes)
    
    print("Encapsulation complete.")
    print("Ciphertext length:", len(ciphertext))
    print("Shared secret (encaps side) length:", len(shared_secret_enc))

    # 3. Decapsulation - use the correct pyoqs API
    # The receiver uses its secret key to decapsulate
    shared_secret_dec = kem.decap_secret(ciphertext)
        
    print("Decapsulation complete.")
    print("Shared secret (decaps side) length:", len(shared_secret_dec))

    # 4. Verify
    print("Shared secrets match:", shared_secret_enc == shared_secret_dec)

    # Additional information about the algorithm
    print(f"\nAlgorithm details:")
    print(f"  Public key length: {len(public_key_bytes)} bytes")
    print(f"  Secret key length: {kem.length_secret_key} bytes")
    print(f"  Ciphertext length: {kem.length_ciphertext} bytes")
    print(f"  Shared secret length: {kem.length_shared_secret} bytes")

except Exception as e:
    print(f"Error: {e}")
    print("Available methods on KEM object:")
    for method in dir(kem):
        if 'encap' in method.lower() or 'decap' in method.lower():
            print(f"  - {method}")

### 4.1 Working with Keys

In `pyoqs`, keys are handled as raw bytes. The public key is returned directly from 
`generate_keypair()`, and the secret key is stored internally in the KEM object.
Let's explore how to work with these keys.

In [None]:
# Public key is already in bytes format
pk_bytes = public_key_bytes
print("Public key length:", len(pk_bytes))
print("Public key (first 32 bytes):", pk_bytes[:32].hex())

# In pyoqs, the secret key is stored internally in the KEM object
# We can get its expected length using the length_secret_key property
print("Secret key length:", kem.length_secret_key, "bytes")

# We can also export the secret key if needed
try:
    secret_key_bytes = kem.export_secret_key()
    print("Secret key exported successfully, length:", len(secret_key_bytes), "bytes")
    print("Secret key (first 16 bytes):", secret_key_bytes[:16].hex())
except Exception as e:
    print("Secret key export not available or failed:", e)

# Create a new KEM instance for demonstration
kem_receiver = oqs.KeyEncapsulation("Kyber768")
print("Created new KEM instance for the same algorithm")

## 5. Step-by-Step Sender/Receiver Key Sharing

This is the core section of the notebook. We look at how **two parties** use ML-KEM to establish
a shared secret.

We call them:

- **Receiver** (server-like): owns the long-term ML-KEM key pair.
- **Sender** (client-like): initiates the session using the receiver's public key.

We split the flow into three phases:

1. **Phase 0 – Setup (one-time or infrequent)**
2. **Phase 1 – Session Establishment (per session)**
3. **Phase 2 – Secure Data Exchange (using the shared secret)**


### 5.1 Phase 0 – Setup (Receiver)

This phase is usually performed **once** or very infrequently (depending on your key rotation
policy).

1. **Receiver generates an ML-KEM key pair**
   - Calls ML-KEM key generation using a secure random number generator.
   - Obtains `(pk, sk)` where `pk` is the public key and `sk` is the secret key.

2. **Receiver stores `sk` securely**
   - Secret key resides in secure storage (HSM, secure enclave, protected file storage, etc.).

3. **Receiver distributes `pk`**
   - Makes the public key available via a trusted mechanism:
     - TLS certificate
     - Configuration or provisioning files
     - Embedded firmware image

Once this phase is complete, any Sender that trusts the Receiver can obtain `pk` and use it to
encapsulate a shared secret.


### 5.2 Phase 1 – Session Establishment (Per Session)

**Goal:** use ML-KEM to create a fresh, session-specific shared secret `ss` known only to Sender
and Receiver.

We will first describe the steps conceptually, then show a concrete Python implementation.

**Sender side (client-like):**

1S. **Obtain Receiver's public key `pk`**  
    The Sender loads `pk` (for example, from a certificate, configuration file, or prior
    provisioning). The Sender must also verify the authenticity of `pk` (e.g., certificate
    validation or a separate signature).

2S. **Encapsulate**  
    The Sender calls `Encaps(pk)`, which internally samples randomness and computes an ML-KEM
    ciphertext. The function returns `(ct, ss_sender)`:
    - `ct`: the encapsulation ciphertext.
    - `ss_sender`: the shared secret on the Sender side.

3S. **Send ciphertext**  
    The Sender transmits `ct` to the Receiver over the (possibly insecure) channel.

**Receiver side (server-like):**

1R. **Receive ciphertext `ct`**  
    The Receiver listens on the communication channel and receives the ciphertext from the Sender.

2R. **Decapsulate**  
    The Receiver calls `Decaps(ct, sk)` using its secret key `sk`. This returns `ss_receiver`. If
    `ct` was honestly generated and arrives unmodified, then except with negligible probability:

    `ss_receiver == ss_sender`

At the end of Phase 1:

- Sender holds `ss_sender`.
- Receiver holds `ss_receiver`.
- These values are equal and secret from eavesdroppers.
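
Step 1S requires the Sender to check that `pk` is authentic before encapsulating. One simple option is fingerprint pinning, sketched below under the assumption that a SHA-256 fingerprint of the Receiver's key was provisioned out of band. The helper names `fingerprint` and `is_trusted` are hypothetical, and random bytes stand in for a real public key (1184 bytes for ML-KEM-768).

```python
import hashlib
import hmac
import os

def fingerprint(pk: bytes) -> str:
    """Hex SHA-256 digest of the encoded public key."""
    return hashlib.sha256(pk).hexdigest()

def is_trusted(pk: bytes, pinned_fingerprint: str) -> bool:
    """Compare against the pinned value without leaking timing information."""
    return hmac.compare_digest(fingerprint(pk), pinned_fingerprint)

# Placeholder for a real ML-KEM-768 public key (assumption: random bytes of the same length).
receiver_pk = os.urandom(1184)
pinned = fingerprint(receiver_pk)  # in practice, delivered through a trusted channel

# Flip one bit to simulate a key that was altered in transit.
tampered_pk = bytes([receiver_pk[0] ^ 0x01]) + receiver_pk[1:]

print("Genuine key accepted:", is_trusted(receiver_pk, pinned))
print("Tampered key accepted:", is_trusted(tampered_pk, pinned))
```

Certificate validation or a signature over `pk` serves the same purpose in larger deployments.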


#### 5.2.1 Python Simulation of Sender and Receiver

The following code block explicitly simulates the **Receiver** and **Sender** roles using ML-KEM.
We will:

- Generate a Receiver key pair.
- Have the Sender encapsulate using the Receiver's public key.
- Send `ct` to the Receiver (simulated as passing a variable).
- Have the Receiver decapsulate and derive the shared secret.
- Check that both shared secrets match.


In [None]:
import oqs

print("[Receiver] Phase 0 – Setup: generating long-term key pair...")
receiver_kem = oqs.KeyEncapsulation("Kyber768")
receiver_pk = receiver_kem.generate_keypair()  # Public key as bytes
print("[Receiver] Public key and secret key generated.")
print("[Receiver] Public key length:", len(receiver_pk), "bytes\n")

print("[Sender] Phase 1 – Session Establishment: obtaining receiver's public key...")
# In a real protocol, the Sender would verify the authenticity of receiver_pk here.
sender_kem = oqs.KeyEncapsulation("Kyber768")
sender_view_of_pk = receiver_pk
print("[Sender] Public key obtained and trusted.\n")

print("[Sender] Encapsulating to derive a shared secret...")
try:
    # Use the correct pyoqs API method
    ct, ss_sender = sender_kem.encap_secret(sender_view_of_pk)
        
    print("[Sender] Encapsulation complete.")
    print("[Sender] Ciphertext length:", len(ct))
    print("[Sender] Shared secret length:", len(ss_sender), "bytes")
    print("[Sender] Shared secret:", ss_sender.hex()[:32], "...\n")

    print("[Sender] Sending ciphertext to Receiver (simulated by variable passing)...\n")

    print("[Receiver] Receiving ciphertext and performing decapsulation...")
    # Use the correct pyoqs API method
    ss_receiver = receiver_kem.decap_secret(ct)
        
    print("[Receiver] Decapsulation complete.")
    print("[Receiver] Shared secret length:", len(ss_receiver), "bytes")
    print("[Receiver] Shared secret:", ss_receiver.hex()[:32], "...\n")

    print("[Check] Do Sender and Receiver share the same secret?", ss_sender == ss_receiver)
    if ss_sender == ss_receiver:
        print("✅ Key establishment successful!")
    else:
        print("❌ Key establishment failed!")

except Exception as e:
    print(f"Error: {e}")
    print("Available methods on KEM object:")
    for method in dir(sender_kem):
        if 'encap' in method.lower() or 'decap' in method.lower():
            print(f"  - {method}")

### 5.3 Phase 2 – Secure Data Exchange (Key Derivation)

ML-KEM produces a **shared secret**, but it does not directly encrypt arbitrary data. Instead,
the shared secret is passed through a **Key Derivation Function (KDF)** to derive symmetric keys.

A common choice is HKDF (HMAC-based Key Derivation Function). Conceptually, we do:

- `k_app = HKDF(ss, info = "app data key")`
- `k_ctrl = HKDF(ss, info = "control channel key")`

These derived keys are then used for symmetric encryption with an AEAD cipher, such as AES-GCM or
ChaCha20-Poly1305.

The following code demonstrates how both Sender and Receiver would derive the same symmetric key
from the shared secret using HKDF. `pyoqs` covers only the KEM step, so we implement HKDF
(RFC 5869) with SHA-256 using Python's standard `hashlib` and `hmac` modules.

In [None]:
import hashlib
import hmac

def hkdf_expand(prk: bytes, info: bytes, length: int) -> bytes:
    """HKDF-Expand function using SHA-256."""
    hash_len = hashlib.sha256().digest_size
    n = (length + hash_len - 1) // hash_len  # Ceiling division
    
    okm = b""
    previous = b""
    for i in range(n):
        previous = hmac.new(prk, previous + info + bytes([i + 1]), hashlib.sha256).digest()
        okm += previous
    
    return okm[:length]

def hkdf_extract(salt: bytes, ikm: bytes) -> bytes:
    """HKDF-Extract function using SHA-256."""
    if not salt:
        salt = b'\x00' * hashlib.sha256().digest_size
    return hmac.new(salt, ikm, hashlib.sha256).digest()

def derive_app_key(shared_secret: bytes, context: bytes, length: int = 32) -> bytes:
    """Derive a symmetric key from the ML-KEM shared secret using HKDF."""
    # Extract step
    prk = hkdf_extract(b"", shared_secret)
    # Expand step
    return hkdf_expand(prk, context, length)

context_info = b"mlkem-demo app key"

sender_app_key = derive_app_key(ss_sender, context_info)
receiver_app_key = derive_app_key(ss_receiver, context_info)

print("Sender app key:", sender_app_key.hex())
print("Receiver app key:", receiver_app_key.hex())
print("Keys match:", sender_app_key == receiver_app_key)

# Derive additional keys for different purposes
auth_key = derive_app_key(ss_sender, b"authentication key")
enc_key = derive_app_key(ss_sender, b"encryption key")

print(f"\nDerived keys:")
print(f"Authentication key: {auth_key.hex()[:16]}...")
print(f"Encryption key: {enc_key.hex()[:16]}...")
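
As a small illustration of what a derived key is for, the sketch below uses a 32-byte key to authenticate a message with HMAC-SHA256. It assumes the key came from a derivation like the one above; here random bytes stand in for it so the example runs on its own. The helper names `make_tag` and `verify_tag` are hypothetical.

```python
import hashlib
import hmac
import os

# Placeholder for a key derived from the shared secret (assumption: 32 random bytes).
auth_key = os.urandom(32)

def make_tag(key: bytes, message: bytes) -> bytes:
    """Compute an HMAC-SHA256 authentication tag."""
    return hmac.new(key, message, hashlib.sha256).digest()

def verify_tag(key: bytes, message: bytes, tag: bytes) -> bool:
    """Recompute the tag and compare in constant time."""
    return hmac.compare_digest(make_tag(key, message), tag)

message = b"hello from the sender"
tag = make_tag(auth_key, message)

print("Original message verifies:", verify_tag(auth_key, message, tag))
print("Modified message verifies:", verify_tag(auth_key, message + b"!", tag))
```

For confidentiality as well as integrity, an AEAD cipher such as AES-GCM or ChaCha20-Poly1305 from a vetted library is the usual choice.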

## 6. Simple Performance Measurement

Let us get a rough feeling for the performance of ML-KEM-768 on this machine. This is **not**
a rigorous benchmark, but it gives an order-of-magnitude estimate for key generation,
encapsulation, and decapsulation.


In [None]:
import time
import oqs

N = 100  # number of iterations for each operation

def bench_keygen(algorithm="Kyber768", n=N):
    """Benchmark key generation."""
    times = []
    for _ in range(n):
        kem = oqs.KeyEncapsulation(algorithm)
        t0 = time.perf_counter()
        _ = kem.generate_keypair()
        t1 = time.perf_counter()
        times.append(t1 - t0)
    return sum(times) / len(times)

def bench_encaps(algorithm="Kyber768", n=N):
    """Benchmark encapsulation."""
    # Setup: generate a key pair once
    kem_setup = oqs.KeyEncapsulation(algorithm)
    pk = kem_setup.generate_keypair()
    
    times = []
    for _ in range(n):
        kem = oqs.KeyEncapsulation(algorithm)
        t0 = time.perf_counter()
        # Use the correct pyoqs API method
        _ = kem.encap_secret(pk)
        t1 = time.perf_counter()
        times.append(t1 - t0)
    return sum(times) / len(times)

def bench_decaps(algorithm="Kyber768", n=N):
    """Benchmark decapsulation."""
    # Setup: generate a key pair and ciphertext once
    kem_setup = oqs.KeyEncapsulation(algorithm)
    pk = kem_setup.generate_keypair()
    kem_sender = oqs.KeyEncapsulation(algorithm)
    
    # Get ciphertext using the correct method
    ct, _ = kem_sender.encap_secret(pk)
    
    times = []
    for _ in range(n):
        t0 = time.perf_counter()
        # Use the correct pyoqs API method
        _ = kem_setup.decap_secret(ct)
        t1 = time.perf_counter()
        times.append(t1 - t0)
    return sum(times) / len(times)

print("Benchmarking ML-KEM (Kyber768) operations...")
print(f"Running {N} iterations each...\n")

try:
    keygen_time = bench_keygen()
    encaps_time = bench_encaps()
    decaps_time = bench_decaps()

    print(f"Average keygen time: {keygen_time*1e6:.1f} µs")
    print(f"Average encaps time: {encaps_time*1e6:.1f} µs")
    print(f"Average decaps time: {decaps_time*1e6:.1f} µs")

    # Compare with other Kyber variants if available
    print(f"\nComparing different Kyber variants:")
    variants = ["Kyber512", "Kyber768", "Kyber1024"]

    for variant in variants:
        try:
            # Test if we can create the KEM object
            test_kem = oqs.KeyEncapsulation(variant)
            kg_time = bench_keygen(variant, n=10)  # Fewer iterations for comparison
            print(f"{variant:10}: {kg_time*1e6:.1f} µs keygen")
        except Exception as e:
            print(f"{variant:10}: Not available ({str(e)[:30]}...)")

except Exception as e:
    print(f"Error during benchmarking: {e}")
    print("Make sure Kyber768 is available in your pyoqs installation.")
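
A plain average can be skewed by a few slow iterations (cache misses, scheduler noise). The sketch below is one possible refinement, not a rigorous benchmark harness: it adds warm-up runs and reports the median and standard deviation alongside the mean. The helper `time_operation` is a hypothetical name, and a trivial workload stands in for a KEM call so the block runs without `pyoqs`.

```python
import statistics
import time

def time_operation(func, n=100, warmup=5):
    """Time func() n times after a few warm-up calls; return summary statistics in seconds."""
    for _ in range(warmup):
        func()
    samples = []
    for _ in range(n):
        t0 = time.perf_counter()
        func()
        samples.append(time.perf_counter() - t0)
    return {
        "mean": statistics.mean(samples),
        "median": statistics.median(samples),
        "stdev": statistics.stdev(samples),
    }

# Stand-in workload (assumption); replace with a KEM operation when pyoqs is available.
stats = time_operation(lambda: sum(range(1000)))
print(f"mean   {stats['mean'] * 1e6:.1f} µs")
print(f"median {stats['median'] * 1e6:.1f} µs")
print(f"stdev  {stats['stdev'] * 1e6:.1f} µs")
```

With the objects from this notebook, a call such as `time_operation(lambda: kem.encap_secret(pk))` would time encapsulation.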

## 7. Mathematical Appendix (Light Version)

This section gives a **high-level** view of the math behind ML-KEM. It is not a full formal
proof, but enough to understand the main ideas.

### 7.1 Polynomial Rings and Parameters

ML-KEM works over a polynomial ring of the form:

$$ R_q = \mathbb{Z}_q[x] / (x^n + 1) $$

Typical parameters for Kyber/ML-KEM are:

- $n = 256$ (degree of polynomials)
- $q = 3329$ (modulus)
- Dimension parameter $k \in \{2, 3, 4\}$, depending on the security level

Elements of $R_q$ are polynomials of degree < $n$ with coefficients in
$\{0, 1, \dots, q-1\}$, and arithmetic is done modulo $x^n + 1$ and $q$.
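
Arithmetic modulo $x^n + 1$ means that $x^n$ is replaced by $-1$, so terms that overflow the top degree wrap around with a sign flip. The sketch below shows schoolbook multiplication in this ring with a toy degree $n = 4$ (ML-KEM uses $n = 256$ and a faster number-theoretic transform). The function name `poly_mul` is an illustrative choice.

```python
Q = 3329  # ML-KEM modulus
N = 4     # toy degree for readability (assumption); ML-KEM uses 256

def poly_mul(a, b):
    """Multiply two polynomials in Z_q[x] / (x^N + 1).

    Polynomials are coefficient lists, lowest degree first.
    """
    res = [0] * N
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            if i + j < N:
                res[i + j] += ai * bj
            else:
                # x^(i+j) = x^N * x^(i+j-N) = -x^(i+j-N)
                res[i + j - N] -= ai * bj
    return [c % Q for c in res]

x = [0, 1, 0, 0]        # the polynomial x
x_cubed = [0, 0, 0, 1]  # the polynomial x^3

print(poly_mul(x, x_cubed))                 # x^4 = -1, i.e. [3328, 0, 0, 0]
print(poly_mul([1, 1, 0, 0], [1, 1, 0, 0])) # (1 + x)^2 = [1, 2, 1, 0]
```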

### 7.2 Module-LWE Problem

The **Module Learning-With-Errors (Module-LWE)** problem is the core hardness assumption.

Given:

- A matrix $A \in R_q^{k \times k}$ (sampled from a suitable distribution)
- A secret vector $s \in R_q^k$ sampled from a small (error) distribution $\chi$
- An error vector $e \in R_q^k$ also sampled from $\chi$

We compute:

$$ t = A s + e \pmod{q} $$

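The search version of the Module-LWE problem says that, given $(A, t)$, it is computationally
hard (for both classical and quantum adversaries) to recover $s$. ML-KEM's security argument
relies on the closely related decision version: $(A, t)$ is hard to distinguish from a uniformly
random pair.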
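
The following sketch builds one toy Module-LWE sample $t = A s + e$ and checks that, for someone who knows $s$, the residual $t - A s$ is exactly the small error $e$. The parameters ($n = 8$, $k = 2$, noise in $[-2, 2]$), the fixed non-cryptographic seed, and all helper names are assumptions made for illustration.

```python
import random

Q = 3329  # ML-KEM modulus
N = 8     # toy degree (ML-KEM uses 256)
K = 2     # module rank
ETA = 2   # noise coefficients lie in [-ETA, ETA]

rng = random.Random(0)  # fixed seed for repeatability; NOT cryptographically secure

def poly_add(a, b):
    return [(x + y) % Q for x, y in zip(a, b)]

def poly_mul(a, b):
    """Multiply in Z_q[x] / (x^N + 1): x^N wraps around to -1."""
    res = [0] * N
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            if i + j < N:
                res[i + j] += ai * bj
            else:
                res[i + j - N] -= ai * bj
    return [c % Q for c in res]

def uniform_poly():
    return [rng.randrange(Q) for _ in range(N)]

def small_poly():
    """Centered binomial noise, stored modulo Q."""
    return [(sum(rng.getrandbits(1) for _ in range(ETA))
             - sum(rng.getrandbits(1) for _ in range(ETA))) % Q
            for _ in range(N)]

def mat_vec(A, s):
    """Compute A s for a K x K matrix and a length-K vector of polynomials."""
    out = []
    for row in A:
        acc = [0] * N
        for a_ij, s_j in zip(row, s):
            acc = poly_add(acc, poly_mul(a_ij, s_j))
        out.append(acc)
    return out

def centered(c):
    """Map a residue in [0, Q) to its representative in (-Q/2, Q/2]."""
    return c - Q if c > Q // 2 else c

A = [[uniform_poly() for _ in range(K)] for _ in range(K)]
s = [small_poly() for _ in range(K)]
e = [small_poly() for _ in range(K)]
t = [poly_add(p, err) for p, err in zip(mat_vec(A, s), e)]  # t = A s + e

# With s known, subtracting A s exposes the error; without s, t looks uniformly random.
residual = [[centered((ti - pi) % Q) for ti, pi in zip(tp, pp)]
            for tp, pp in zip(t, mat_vec(A, s))]
print("t[0]      :", t[0])
print("t - A s   :", residual[0])
```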

### 7.3 Key Generation

In ML-KEM key generation, we:

1. Sample a seed that determines $A$.
2. Sample $s, e \leftarrow \chi^k$.
3. Compute:

$$ t = A s + e \pmod{q}. $$

The public key is essentially $(A, t)$, with $A$ represented by its short seed rather than
written out in full, and the secret key is $s$ (plus some additional data such as a copy of the
public key, its hash, and a random value used when a ciphertext is rejected).

### 7.4 Encapsulation

Encapsulation uses a fresh random vector $r$ and fresh noise terms $(e_1, e_2)$:

1. Sample $r, e_1, e_2$ from suitable noise distributions.
2. Compute:

$$ u = A^T r + e_1 \pmod{q}, $$
$$ v = t^T r + e_2 + \text{encode}(m) \pmod{q}, $$

for a message $m$ derived from random bits.

3. The ciphertext is $(u, v)$.
4. The shared secret is derived as:

$$ ss = \text{KDF}(m \parallel \text{hash}(u, v)). $$

### 7.5 Decapsulation

The holder of the secret key $s$ computes:

$$ v - u^T s = (t^T r + e_2 + \text{encode}(m)) - (A^T r + e_1)^T s. $$

Using $t = A s + e$, we can rewrite the right-hand side. The key intuition is that most of the
terms **cancel out**, leaving:

$$ v - u^T s \approx \text{encode}(m) + \text{small noise}. $$
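
The cancellation can be written out explicitly. The worked step below is a sketch that assumes the standard Kyber convention $u = A^T r + e_1$ (with the transpose) and uses the fact that $R_q$ is commutative, so that $r^T A s = s^T A^T r$:

$$ v - u^T s = (A s + e)^T r + e_2 + \text{encode}(m) - (A^T r + e_1)^T s $$
$$ = s^T A^T r + e^T r + e_2 + \text{encode}(m) - r^T A s - e_1^T s $$
$$ = \text{encode}(m) + \underbrace{e^T r + e_2 - e_1^T s}_{\text{small noise}}. $$

The large term $s^T A^T r$ appears once with each sign and drops out; only products and sums of small polynomials remain.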

Because the noise is small and controlled, a decoding step can recover $m$. Then the same KDF is
applied:

$$ ss = \text{KDF}(m \parallel \text{hash}(u, v)). $$

Thus both encapsulator and decapsulator derive the same shared secret `ss`.
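
To tie the appendix together, here is a toy end-to-end sketch of the underlying encryption scheme in pure Python. It is a simplification under several assumptions: schoolbook multiplication instead of the NTT, no ciphertext compression, no hashing or KDF, no re-encryption check, and no constant-time guarantees. It uses $u = A^T r + e_1$, the standard Kyber convention, and encodes each message bit as $0$ or $\lceil q/2 \rfloor = 1665$. All function names are illustrative, and the code must not be used for real security.

```python
import secrets

Q, N, K, ETA = 3329, 256, 2, 2  # modulus, degree, module rank, noise bound

def poly_add(a, b):
    return [(x + y) % Q for x, y in zip(a, b)]

def poly_sub(a, b):
    return [(x - y) % Q for x, y in zip(a, b)]

def poly_mul(a, b):
    """Multiply in Z_q[x] / (x^N + 1): x^N wraps around to -1."""
    res = [0] * N
    for i, ai in enumerate(a):
        if ai == 0:
            continue
        for j, bj in enumerate(b):
            if i + j < N:
                res[i + j] += ai * bj
            else:
                res[i + j - N] -= ai * bj
    return [c % Q for c in res]

def dot(u, v):
    """Inner product of two length-K vectors of polynomials."""
    acc = [0] * N
    for a, b in zip(u, v):
        acc = poly_add(acc, poly_mul(a, b))
    return acc

def small_poly():
    """Centered binomial noise with coefficients in [-ETA, ETA], stored modulo Q."""
    return [(sum(secrets.randbits(1) for _ in range(ETA))
             - sum(secrets.randbits(1) for _ in range(ETA))) % Q
            for _ in range(N)]

def small_vec():
    return [small_poly() for _ in range(K)]

def keygen():
    A = [[[secrets.randbelow(Q) for _ in range(N)] for _ in range(K)] for _ in range(K)]
    s, e = small_vec(), small_vec()
    t = [poly_add(dot(A[i], s), e[i]) for i in range(K)]  # t = A s + e
    return (A, t), s

def encrypt(pk, m_bits):
    A, t = pk
    r, e1, e2 = small_vec(), small_vec(), small_poly()
    # u = A^T r + e1: entry j uses column j of A
    u = [poly_add(dot([A[i][j] for i in range(K)], r), e1[j]) for j in range(K)]
    encoded = [bit * ((Q + 1) // 2) for bit in m_bits]     # 0 -> 0, 1 -> 1665
    v = poly_add(poly_add(dot(t, r), e2), encoded)         # v = t^T r + e2 + encode(m)
    return u, v

def decrypt(sk, ct):
    u, v = ct
    w = poly_sub(v, dot(sk, u))                            # encode(m) + small noise
    # A coefficient nearer to q/2 than to 0 (mod q) decodes to bit 1.
    return [1 if Q // 4 < c < 3 * Q // 4 else 0 for c in w]

m = [secrets.randbelow(2) for _ in range(N)]  # 256 random message bits
pk, sk = keygen()
ct = encrypt(pk, m)
recovered = decrypt(sk, ct)
print("Message recovered:", recovered == m)
```

In the real scheme, the recovered $m$ is then hashed to obtain the shared secret, and the randomness used during encryption is itself derived from $m$ so that the Receiver can re-encrypt and check the ciphertext.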


## 8. Next Steps

From here you can:

- Integrate ML-KEM key establishment with a symmetric cipher (e.g., AES-GCM)
- Explore hybrid key exchange (ECDHE + ML-KEM)
- Experiment with different ML-KEM parameter sets (Kyber512, Kyber1024) available in pyoqs
- Try other post-quantum KEMs available in pyoqs (e.g., FrodoKEM, Classic McEliece); note that SIKE has been broken by a classical attack and should not be used
- Benchmark the algorithm on different hardware and platforms
- Explore the post-quantum signature schemes available in pyoqs (Dilithium, Falcon, etc.)

### Using pyoqs for Other Post-Quantum Algorithms

The `pyoqs` library provides access to many post-quantum algorithms. You can explore them:

```python
import oqs

# Try different methods to list available algorithms
if hasattr(oqs, 'get_enabled_KEM_mechanisms'):
    print("Available KEMs:", oqs.get_enabled_KEM_mechanisms())
elif hasattr(oqs, 'get_supported_kem_mechanisms'):
    print("Available KEMs:", oqs.get_supported_kem_mechanisms())

if hasattr(oqs, 'get_enabled_sig_mechanisms'):
    print("Available Signature schemes:", oqs.get_enabled_sig_mechanisms())
elif hasattr(oqs, 'get_supported_sig_mechanisms'):
    print("Available Signature schemes:", oqs.get_supported_sig_mechanisms())
```

Note: The exact method names may vary depending on your pyoqs version. Check the documentation for your specific version.