## Setup

### Parameters

In [82]:
import math
from os import urandom
from hashlib import sha3_256, sha3_512, shake_256

from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing
from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.rings.integer import Integer

# Kyber parameters
q = 3329        # prime modulus of the coefficient field Z_q
q_bytes = 16    # bytes written per coefficient by Poly2Bytes (2 would already hold any value < q)
k = 2           # module rank: A is k x k, while s, e and t are k x 1
n = 256         # number of coefficients per ring element

# Z_q[x], then the quotient ring RQ = Z_q[x] / (x^n + 1) that every polynomial lives in
rQ = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = rQ.gen()
f = x**n + 1
RQ = rQ.quotient(f)

# When True, dbg() prints intermediate values
DEBUG = True

### Helper Functions

In [83]:
def RandomList(length, cbd=False):
    """Return `length` values drawn with randrange(q), i.e. each in [0, q).

    When `cbd` is true the draws are passed through FauxCbd, mapping them into [-2, 2].
    """
    values = [randrange(q) for _ in range(length)]
    return FauxCbd(values) if cbd else values

def FauxCbd(r: list):
    """Map every integer in `r` into the range [-2, 2] via (value % 5) - 2.

    This stands in for Kyber's centered binomial distribution; it is not one.
    Python's % is non-negative for a positive modulus, so negative inputs also land in [-2, 2].
    """
    return [(value % 5) - 2 for value in r]

def RandPolyUniform(length):
    """Return an element of RQ built from `length` coefficients drawn uniformly from [0, q)."""
    coeffs = RandomList(length)
    return RQ(coeffs)

def RandPolyCbd(length) -> polynomial:
    """Return an element of RQ built from `length` small coefficients in [-2, 2] (see FauxCbd)."""
    small_coeffs = RandomList(length, cbd=True)
    return RQ(small_coeffs)

def RandListCbd(length) -> list:
    """Return `length` random integers in [-2, 2]: randrange(q) draws folded by FauxCbd."""
    uniform_draws = RandomList(length)
    return FauxCbd(uniform_draws)

def BytesNeed4Bits(bits: int) -> int:
    """Return how many whole bytes are needed to hold `bits` bits, i.e. ceil(bits / 8)."""
    return (bits + 7) // 8

def RandInt(bits: int) -> Integer:
    """Return a random Sage Integer in [0, 2**bits), sourced from os.urandom.

    Whole random bytes are drawn (BytesNeed4Bits(bits) of them) and then masked
    down to exactly `bits` bits.
    """
    m = Integer(int.from_bytes(urandom(BytesNeed4Bits(bits)), 'big'))
    # Mask with the requested width; the global n must not be used here.
    m &= 2**bits - 1
    return m

def Poly2Bytes(poly: polynomial) -> bytes:
    """Serialize ring elements to bytes, `q_bytes` big-endian bytes per coefficient.

    poly: either a single RQ element, or a Sage matrix of RQ elements (the MLKEM
          functions pass the k x 1 public matrix t).

    For a matrix, every entry is serialized in row-major order, so a hash over the
    result binds all of t and not just its first entry. Each entry contributes its
    coefficient list as returned by `.list()`.
    """
    # Sage matrices expose nrows(); a single quotient-ring element does not.
    entries = poly.list() if hasattr(poly, 'nrows') else [poly]
    return b''.join(
        int(coeff).to_bytes(q_bytes, 'big')
        for entry in entries
        for coeff in entry.list()
    )

def Bytes2ListBit(b: bytes) -> list:
    """Expand `b` into a flat list of 0/1 ints: 8 per byte, least-significant bit first."""
    return [(octet >> shift) & 1 for octet in b for shift in range(8)]

def Compress(poly: polynomial):
    """Scale `poly` by ceil(q/2), lifting 0/1 message coefficients to 0 or ~q/2 in Z_q.

    Despite its name, this is the bit -> Z_q lift applied before encryption. FIPS 203
    calls the analogous step Decompress_1. Decompress below performs the inverse rounding.
    """
    half_q = math.ceil(q / 2)
    return poly * half_q

def Decompress(poly: list) -> list:
    """Round each coefficient in `poly` to a single message bit.

    poly: an iterable of Z_q coefficients. KPKE_Decrypt passes the coefficient
          list of the noisy decrypted polynomial.

    Returns a list of 0/1 ints, one per coefficient. An entry is 1 when its coefficient,
    lifted to an integer, lies strictly between q/4 and 3q/4 (near ceil(q/2), the value
    Compress maps a 1-bit to). Otherwise the entry is 0.
    """
    lower = q / 4
    upper = 3 * (q / 4)
    return [1 if upper > Integer(coeff) > lower else 0 for coeff in poly]

def dbg(label: str, *args: str):
    """Print a labelled debug message, but only while the global DEBUG flag is set.

    With extra `args`, prints "label:" followed by each arg on its own line.
    With no `args`, prints just the label (no trailing colon).
    """
    if not DEBUG:
        return
    if not args:
        print(f'{label}')
        return
    body = ''.join(f'{arg}\n' for arg in args)
    print(f'{label}:\n{body}')

## IND-CPA Public Key Encryption (K-PKE)

### K-PKE KeyGen

In [84]:
def KPKE_Keygen() -> (matrix, matrix, matrix):
    """Generate a toy K-PKE key pair.

    Returns (A, t, s):
      A -- k x k matrix whose entries are uniformly random elements of RQ,
      t -- k x 1 matrix equal to A*s + e,
      s -- k x 1 secret matrix with coefficients in [-2, 2].
    (A, t) is the public key and s is the private key. The error e is discarded.
    """
    dbg('===== kpke_keygen =====')

    # A: k x k, filled row by row with uniform ring elements
    A = matrix([[RandPolyUniform(n) for _col in range(k)] for _row in range(k)])
    dbg('A', A)

    # s: k x 1 secret with small coefficients
    s = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('s', s)

    # e: k x 1 error, drawn from the same small distribution as s
    e = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('e', e)

    # Compute t = A*s + e (a k x 1 matrix).
    #
    # Example when k=2:
    #   |     A     |   |  s  |   |  e  |
    #   | :-- | :-- |   | :-- |   | :-- |
    #   | 0,0 | 0,1 |   |  0  |   |  0  |
    #   | 1,0 | 1,1 |   |  1  |   |  1  |
    #
    #   |             A * s             |
    #   | :---------------------------- |
    #   | A[0,0] * s[0] + A[0,1] * s[1] |
    #   | A[1,0] * s[0] + A[1,1] * s[1] |
    #
    #   |     As+e     |
    #   | :----------- |
    #   | As[0] + e[0] |
    #   | As[1] + e[1] |
    t = A * s + e
    dbg('t', t, '\n')

    return (A, t, s)

### K-PKE Encrypt

In [85]:
def _kpke_noise_poly(r: list, N: int) -> polynomial:
    """Derive one small-coefficient RQ element from the first n entries of r, shifted by nonce N."""
    return RQ(FauxCbd([r[j] + N for j in range(n)]))


def KPKE_Encrypt(A: matrix, t: matrix, m: Integer, r: list) -> (matrix, matrix):
    """Encrypt the message m under the public key (A, t), using the seed list r.

    All "randomness" is derived deterministically from r. Encrypting the same
    (A, t, m, r) twice therefore yields the same ciphertext, which lets MLKEM_Decaps
    re-encrypt.

    A: k x k matrix over RQ.  t: k x 1 matrix over RQ.
    m: Sage Integer with at most n bits. Its .bits() (LSB first) become the
       message polynomial's coefficients.
    r: list with at least n small integers (FauxCbd output in the callers).

    Returns (u, v): u = A^T*rr + e1 (k x 1) and v = t^T*rr + e2 + Compress(m) (1 x 1).

    Raises ValueError if m has more than n bits or r has fewer than n entries.
    """
    dbg('===== kpke_encrypt =====')

    mb = m.bits()
    if len(mb) > n:
        raise ValueError('m has more bits than n!')
    if len(r) < n:
        raise ValueError('r must have at least n entries')

    dbg('Bits of m', mb)

    # Zero-pad the bit list so the message polynomial has exactly n coefficients
    mbp = RQ(mb + [0] * (n - len(mb)))
    dbg('Polynomial m', mbp)

    # Lift message bits to 0 / ceil(q/2)
    mbpc = Compress(mbp)
    dbg('Compressed m', mbpc)

    # Every noise polynomial gets its own nonce N:
    #   rr uses 0..k-1, e1 uses k..2k-1, and e2 uses 2k.
    # NOTE: FauxCbd reduces mod 5, so nonces N and N+5 give identical polynomials.
    # With k = 2 only nonces 0..4 are used.
    rr = matrix([[_kpke_noise_poly(r, N)] for N in range(k)])
    dbg('rr', rr)

    e1 = matrix([[_kpke_noise_poly(r, N)] for N in range(k, 2 * k)])
    dbg('e1', e1)

    e2 = _kpke_noise_poly(r, 2 * k)
    dbg('e2', e2)

    u = A.transpose() * rr + e1
    v = t.transpose() * rr + e2 + mbpc

    dbg('u', u)
    dbg('v', v, '\n')

    return (u, v)

    

### K-PKE Decrypt

In [86]:
def KPKE_Decrypt(u: matrix, v: matrix, s: matrix) -> int:
    """Recover the integer message from ciphertext (u, v) using secret key s.

    Computes the noisy message polynomial v - s^T*u and rounds each coefficient
    to a bit with Decompress. Coefficient i becomes bit i of the result, matching
    KPKE_Encrypt, which places m.bits() (LSB first) into the coefficients.

    Returns the recovered message as a Python int.
    """
    dbg('===== kpke_decrypt =====')

    # v - s^T*u is a 1 x 1 matrix, so take its single entry directly.
    # matrix.coefficients() lists only nonzero entries and would fail on a zero result.
    mn = (v - s.transpose() * u)[0, 0]

    dbg('Noisy recovered m', mn)

    # Decompress and remove the noise; bits[i] is the bit carried by coefficient i
    bits = Decompress(mn.list())

    dbg('Decompressed m', bits, '\n')

    # Reassemble the integer, LSB first
    return sum(bit << i for i, bit in enumerate(bits))

### K-PKE Test

In [87]:
# Alice creates her public key (A, t) and her private key s
A, t, s = KPKE_Keygen()

# Alice publishes (A, t). Bob picks a random message m plus the
# randomness r, and encrypts m under Alice's key into the ciphertext (u, v)
m = RandInt(n)
r = RandListCbd(n)
u, v = KPKE_Encrypt(A, t, m, r)

# Bob sends (u, v) to Alice, who recovers the message using s
mr = KPKE_Decrypt(u, v, s)

dbg("Alice's m", mr)
dbg("Bob's m", m)
dbg(m.bits())
if m != mr:
    raise ValueError("Alice and Bob's messages do not match, final decompression likely failed. Try increasing the value of the prime q")


===== kpke_keygen =====
A:
[3141*xbar^255 + 652*xbar^254 + 3187*xbar^253 + 566*xbar^252 + 2253*xbar^251 + 1661*xbar^250 + 1358*xbar^249 + 2880*xbar^248 + 654*xbar^247 + 1874*xbar^246 + 135*xbar^245 + 3024*xbar^244 + 2015*xbar^243 + 2395*xbar^242 + 665*xbar^241 + 524*xbar^240 + 1546*xbar^239 + 2944*xbar^238 + 1230*xbar^237 + 993*xbar^236 + 2479*xbar^235 + 3218*xbar^234 + 2404*xbar^233 + 2955*xbar^232 + 3214*xbar^231 + 2540*xbar^230 + 1526*xbar^229 + 2009*xbar^228 + 2950*xbar^227 + 2473*xbar^226 + 502*xbar^225 + 593*xbar^224 + 1406*xbar^223 + 90*xbar^222 + 2974*xbar^221 + 1336*xbar^220 + 2239*xbar^219 + 127*xbar^218 + 2372*xbar^217 + 160*xbar^216 + 468*xbar^215 + 1505*xbar^214 + 1672*xbar^213 + 1979*xbar^212 + 656*xbar^211 + 3104*xbar^210 + 126*xbar^209 + 2133*xbar^208 + 2204*xbar^207 + 2914*xbar^206 + 1101*xbar^205 + 2338*xbar^204 + 2119*xbar^203 + 2369*xbar^202 + 1312*xbar^201 + 2809*xbar^200 + 3322*xbar^199 + 1317*xbar^198 + 182*xbar^197 + 2303*xbar^196 + 2179*xbar^195 + 1214*xbar^194

# IND-CCA Key Encapsulation Mechanism (ML-KEM)

## ML-KEM KeyGen

In [88]:
def MLKEM_KeyGen() -> ((matrix, matrix), (matrix, (matrix, matrix), bytes, Integer)):
    """Generate a toy ML-KEM key pair.

    Returns (ek, dk):
      ek = (A, t)         -- the encapsulation (public) key from KPKE_Keygen,
      dk = (s, ek, h, z)  -- s is the K-PKE secret, h = SHA3-256(Poly2Bytes(t)),
                             and z is a random n-bit Integer carried for decapsulation
                             (implicit rejection).
    """
    dbg('===== MLKEM_KeyGen =====')

    z = RandInt(n)
    A, t, s = KPKE_Keygen()
    ek = (A, t)
    # Cache H(t) in dk so decapsulation does not need to recompute it
    Ht = sha3_256(Poly2Bytes(t)).digest()

    dbg('SHA3-256(t)', Ht.hex(), '\n')

    dk = (s, ek, Ht, z)

    return ek, dk

## ML-KEM Encaps

In [89]:
def MLKEM_Encaps(ek: (matrix, matrix)) -> (bytes, (matrix, matrix)):
    """Produce a shared secret K and a ciphertext c for the encapsulation key ek = (A, t).

    A random n-bit m is drawn, and (K, r) = SHA3-512(m || SHA3-256(Poly2Bytes(t))) is
    split into two halves of BytesNeed4Bits(n) bytes each. The second half is expanded
    to bits, passed through FauxCbd, and used as the encryption seed, so c = KPKE_Encrypt(A, t, m, r).

    NOTE: SHA3-512 yields 64 bytes, so r has 8 * (64 - BytesNeed4Bits(n)) entries.
    That equals n only for n = 256, which is what KPKE_Encrypt needs.

    Returns (K, c) with K as bytes and c = (u, v) as matrices.
    """
    dbg('===== MLKEM_Encaps =====')

    m = RandInt(n)
    dbg('m', m)

    A, t = ek
    Ht = sha3_256(Poly2Bytes(t)).digest()

    dbg('SHA3-256(t)', Ht.hex())

    seed_len = BytesNeed4Bits(n)
    Kr = sha3_512(int(m).to_bytes(seed_len, 'big') + Ht).digest()
    K, r = Kr[:seed_len], Kr[seed_len:]
    r = FauxCbd(Bytes2ListBit(r))

    dbg('r', r, '\n')

    c = KPKE_Encrypt(A, t, m, r)

    return K, c

## ML-KEM Decaps

In [90]:
def MLKEM_Decaps(c: (matrix, matrix), dk: (matrix, (matrix, matrix), bytes, Integer)) -> bytes:
    """Recover the shared secret from ciphertext c = (u, v) with decapsulation key dk.

    Steps (Fujisaki-Okamoto style):
      1. m' = KPKE_Decrypt(u, v, s).
      2. (K', r') = SHA3-512(m' || h), with r' expanded exactly as in MLKEM_Encaps.
      3. Re-encrypt m' with r'. If the result equals c, K' is returned.
         Otherwise the implicit-rejection key SHAKE-256(z || c) is returned instead,
         truncated to the same length as K'.
    """
    dbg('===== MLKEM_Decaps =====')
    s, ek, h, z = dk
    A, t = ek
    u, v = c

    mprime = KPKE_Decrypt(u, v, s)

    dbg('m\'', mprime)

    seed_len = BytesNeed4Bits(n)
    Krprime = sha3_512(int(mprime).to_bytes(seed_len, 'big') + h).digest()
    Kprime, rprime = Krprime[:seed_len], Krprime[seed_len:]
    rprime = FauxCbd(Bytes2ListBit(rprime))

    dbg('r\'', rprime)

    # KPKE_Encrypt is deterministic in (A, t, m, r), so an honest c re-encrypts exactly.
    uprime, vprime = KPKE_Encrypt(A, t, Integer(mprime), rprime)

    dbg('u\'', uprime)
    dbg('v\'', vprime)

    reencryption_matches = (uprime == u) and (vprime == v)
    dbg('c\' == c?', reencryption_matches, '\n')

    if reencryption_matches:
        return Kprime

    # Implicit rejection: on a tampered ciphertext or a decryption failure, return a key
    # derived from the secret z and c, so the caller learns nothing about m'.
    def _ciphertext_bytes(mat) -> bytes:
        # Serialize every RQ entry of a Sage matrix, q_bytes big-endian bytes per coefficient
        return b''.join(
            int(coeff).to_bytes(q_bytes, 'big')
            for entry in mat.list()
            for coeff in entry.list()
        )

    rejection_input = int(z).to_bytes(seed_len, 'big') + _ciphertext_bytes(u) + _ciphertext_bytes(v)
    return shake_256(rejection_input).digest(seed_len)



## ML-KEM Test

In [91]:
# Alice runs KeyGen() to obtain her key pair
pkA, skA = MLKEM_KeyGen()

# Bob takes Alice's public key pkA and runs Encaps(), which gives him
# his copy of the shared secret (ssB) and a ciphertext c for Alice
ssB, c = MLKEM_Encaps(pkA)

# Bob sends c to Alice, who decapsulates it with her secret key
# to obtain her copy of the shared secret (ssA)
ssA = MLKEM_Decaps(c, skA)

dbg("Bob's Shared Secret", ssB.hex())
dbg("Alice's Shared Secret", ssA.hex())
dbg('ssA == ssB?', ssA == ssB)
if ssA != ssB:
    raise ValueError('The shared keys do not match!')

===== MLKEM_KeyGen =====
===== kpke_keygen =====


A:
[                    413*xbar^255 + 1099*xbar^254 + 3100*xbar^253 + 1426*xbar^252 + 1317*xbar^251 + 1979*xbar^250 + 3280*xbar^249 + 2183*xbar^248 + 2512*xbar^247 + 2606*xbar^246 + 3121*xbar^245 + 400*xbar^244 + 3166*xbar^243 + 17*xbar^242 + 2553*xbar^241 + 2316*xbar^240 + 657*xbar^239 + 664*xbar^238 + 778*xbar^237 + 1731*xbar^236 + 1039*xbar^235 + 1342*xbar^234 + 2018*xbar^233 + 2783*xbar^232 + 3288*xbar^231 + 1362*xbar^230 + 1739*xbar^229 + 2888*xbar^228 + 2752*xbar^227 + 670*xbar^226 + 1760*xbar^225 + 1476*xbar^224 + 2875*xbar^223 + 1502*xbar^222 + 3070*xbar^221 + 874*xbar^220 + 1645*xbar^219 + 2894*xbar^218 + 1534*xbar^217 + 1388*xbar^216 + 1524*xbar^215 + 1220*xbar^214 + 1079*xbar^213 + 1296*xbar^212 + 2647*xbar^211 + 3307*xbar^210 + 1311*xbar^209 + 2574*xbar^208 + 47*xbar^207 + 1174*xbar^206 + 2425*xbar^205 + 2895*xbar^204 + 1120*xbar^203 + 1898*xbar^202 + 1776*xbar^201 + 2318*xbar^200 + 525*xbar^199 + 99*xbar^198 + 1420*xbar^197 + 3039*xbar^196 + 2658*xbar^195 + xbar^194 + 611

# Key Exchange Tests

## Init

In [92]:
# Silence dbg() tracing for the key-exchange demos below
DEBUG = False

# Long-lived (static) key pairs: Alice's is created first, then Bob's
(pkA, skA), (pkB, skB) = MLKEM_KeyGen(), MLKEM_KeyGen()

## Unilaterally Authenticated Key Exchange (UAKE)

In [93]:
# Alice generates a fresh temporary (ephemeral) key pair
tpkA, tskA = MLKEM_KeyGen()
# Alice encapsulates against Bob's static public key (pkB),
# producing a shared secret and ciphertext (tssA, cA) for Bob
tssA, cA = MLKEM_Encaps(pkB)

# Alice sends (tpkA, cA) to Bob.
# Bob encapsulates against Alice's temporary public key (tpkA),
# producing a shared secret and ciphertext (tssB, cB) for Alice
tssB, cB = MLKEM_Encaps(tpkA)
# Bob decapsulates cA with his static secret key (skB) to obtain tssAprime
tssAprime = MLKEM_Decaps(cA, skB)
# Bob hashes tssB and tssAprime into his copy of the final shared key (ssB)
ssB = shake_256(tssB + tssAprime).digest(32)

# Bob sends cB to Alice.
# Alice decapsulates cB with her temporary secret key (tskA)
# to recover Bob's temporary shared secret
tssBprime = MLKEM_Decaps(cB, tskA)
# Alice hashes tssBprime and tssA (same order as Bob) into her copy (ssA)
ssA = shake_256(tssBprime + tssA).digest(32)

print('Alice\'s SS:')
print(ssA.hex())
print()
print('Bob\'s SS:')
print(ssB.hex())
print()
print('Length of shared secret:', len(ssA), 'bytes')
print()
print('ssA == ssB?', ssA == ssB)

# Fail loudly, as the other test cells do, instead of only printing the comparison
if ssA != ssB:
    raise ValueError('The shared keys do not match!')

Alice's SS:
cb6638655ba084af69a820a10399ef33748337df66e54c0ce222c698d9c6004d

Bob's SS:
cb6638655ba084af69a820a10399ef33748337df66e54c0ce222c698d9c6004d

Length of shared secret: 32 bytes

ssA == ssB? True


## Mutually Authenticated Key Exchange (AKE)

In [94]:
# Alice generates a fresh temporary (ephemeral) key pair
tpkA, tskA = MLKEM_KeyGen()
# Alice encapsulates against Bob's static public key (pkB),
# producing a shared secret and ciphertext (tssA, cA) for Bob
tssA, cA = MLKEM_Encaps(pkB)

# Alice sends (tpkA, cA) to Bob.
# Bob encapsulates against Alice's temporary public key (tpkA),
# producing a shared secret and ciphertext (tssB, cB)
tssB, cB = MLKEM_Encaps(tpkA)
# Bob encapsulates against Alice's static public key (pkA),
# producing a second shared secret and ciphertext (tssB2, cB2)
tssB2, cB2 = MLKEM_Encaps(pkA)
# Bob decapsulates cA with his static secret key (skB)
# to recover Alice's temporary shared secret
tssAprime = MLKEM_Decaps(cA, skB)
# Bob hashes tssB, tssB2 and tssAprime into his copy of the final shared key (ssB)
ssB = shake_256(tssB + tssB2 + tssAprime).digest(32)

# Bob sends cB and cB2 to Alice.
# Alice decapsulates cB with her temporary secret key (tskA)
tssBprime = MLKEM_Decaps(cB, tskA)
# Alice decapsulates cB2 with her static secret key (skA)
tssB2prime = MLKEM_Decaps(cB2, skA)
# Alice hashes tssBprime, tssB2prime and tssA (same order as Bob) into her copy (ssA)
ssA = shake_256(tssBprime + tssB2prime + tssA).digest(32)

print('Alice\'s SS:')
print(ssA.hex())
print()
print('Bob\'s SS:')
print(ssB.hex())
print()
print('Length of shared secret:', len(ssA), 'bytes')
print()
print('ssA == ssB?', ssA == ssB)

# Fail loudly, as the other test cells do, instead of only printing the comparison
if ssA != ssB:
    raise ValueError('The shared keys do not match!')


Alice's SS:
7e01cad2813afe05f3c4dc0f542986bb3bcc131953674be7d82dc303d17cac42

Bob's SS:
7e01cad2813afe05f3c4dc0f542986bb3bcc131953674be7d82dc303d17cac42

Length of shared secret: 32 bytes

ssA == ssB? True
