## Setup

### Parameters

In [737]:
import math
from os import urandom
from hashlib import sha3_256, sha3_512

from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing
from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.rings.integer import Integer

# Kyber Parameters
q = 3329
q_bytes = 16
k = 2
n = 4


rQ = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = rQ.gen()
f = x^n + 1
RQ = rQ.quotient(f)


DEBUG = True

### Helper Functions

In [738]:
def RandomList(length, cbd=False):
    out = [randrange(q) for i in range(length)]
    if cbd:
        out = FauxCbd(out)
    return out

def FauxCbd(r: list):
    out = []
    for i in r:
        out.append((i % 5) - 2) # Restrict to -2 <= n <= 2
    return out

def RandPolyUniform(length):
    return RQ(RandomList(length))

def RandPolyCbd(length) -> polynomial:
    return RQ(RandomList(length, cbd=True))

def RandListCbd(length) -> list:
    return RandomList(length, cbd=True)

def BytesNeed4Bits(bits: int) -> int:
    return ((bits+7) & (-8))//8

def RandInt(bits: int) -> int:
    m = Integer(int.from_bytes(urandom(BytesNeed4Bits(bits)), 'big'))
    m &= 2**n-1
    return m

def Poly2Bytes(poly: polynomial) -> bytes:
    out = b''
    p = poly.coefficients()[0].list()
    for c in p:
        c = int(c)
        cb = c.to_bytes(q_bytes, 'big')
        out += cb
    return out

def Bytes2ListInt(b: bytes) -> list:
    out = []
    for byte in b:
        out.append(Integer(byte))
    return out

def Compress(poly: polynomial):
     q2 = math.ceil(q/2)
     return poly * q2

def Decompress(poly: polynomial) -> int:
    return [(1 if 3*(q/4) > Integer(i) > q/4 else 0) for i in poly]

def dbg(label: str, *args: str):
    if DEBUG:
        s = f'{label}:\n'
        for arg in args:
            s += f'{arg}\n'
        if len(args) == 0:
            s = s[:-2]
        print(s)

## INDCPA Key Exchange (K-PKE)

### K-PKE KeyGen

In [739]:
def KPKE_Keygen() -> (matrix, matrix, matrix):
    dbg('===== kpke_keygen =====')

    # Initialize
    A = [[[None] for _ in range(0, k)] for _ in range(0, k)]
    s = [[None] for _ in range(0, k)]
    e = [[None] for _ in range(0, k)]

    # A is a k*k dimension matrix of polynomials with n terms
    for i in range(0, k):
        for j in range(0, k):
            A[i][j] = RandPolyUniform(n)
    A = matrix(A)

    dbg('A', A)

    # s is a k*1 dimension matrix of polynomials with n terms
    for i in range(0, k):
        s[i] = [RandPolyCbd(n)]
    s = matrix(s)

    dbg('s', s)

    # e is a k*1 dimension matrix of polynomials with n terms
    for i in range(0, k):
        e[i] = [RandPolyCbd(n)]
    e = matrix(e)

    dbg('e', e)

#   Compute t = A*s*e:
#   A*s is a k * 1 matrix of polynomials with n terms
#   A*s+e is a k * 1 matrix polynomials with n terms
#   t is a k*1 dimension matrix
#
#   Example when k=2:
#   |     A     |   |  s  |   |  e  |
#   | :-- | :-- |   | :-- |   | :-- |
#   | 0,0 | 0,1 |   |  0  |   |  0  |
#   | 1,0 | 1,1 |   |  1  |   |  1  |
#
#   |             A * s             |
#   | :---------------------------- |
#   | A[0,0] * s[0] + A[0,1] * s[1] |
#   | A[1,0] * s[0] + A[1,1] * s[1] |
#
#   |     As+e     |
#   | :----------- |
#   | As[0] + e[0] |
#   | As[1] + e[1] |

    t = A*s+e

    dbg('t', t, '\n')

    return (A, t, s)

### K-PKE Encrypt

In [740]:
def KPKE_Encrypt(A: matrix, t: matrix, m: int, r: polynomial) -> (polynomial, polynomial):
    dbg('===== kpke_encrypt =====')

    # Initialize
    rr = [[None] for _ in range(0, k)]
    e1 = [[None] for _ in range(0, k)]
    e2 = [None] * n
    
    # Ensure that m does not have more bits than n bits
    if len(m.bits()) > n:
        raise ValueError('m has more bits than n!')
    mb = m.bits()

    dbg('Bits of m', mb)

    # N is nonce used to deterministicly modify r
    N = 0

    # We need m to be at least n bits long.
    # Pad mm with 0s until desired length is reached
    pad = [0 for _ in range(0, n - len(mb))]
    mbp = RQ(mb + pad)

    dbg('Polynomial m', mbp)

    # Compress m
    mbpc = Compress(mbp)

    dbg('Compressed m', mbpc)

    # Generate r, e1, e2
    # r is a k*1 matrix of polynomials with n terms
    for i in range(0, k):
        tpoly = [None] * n
        for j in range(0, n):
            tpoly[j] = r[j] + N
        tpoly = FauxCbd(tpoly)
        tpoly = RQ(tpoly)
        rr[i] = [tpoly]
        N += 1
    rr = matrix(rr)

    dbg('rr', rr)

    # e1 is a k*1 matrix of polynomials with n terms
    for i in range(0, k):
        tpoly = [None] * n
        for j in range(0, n):
            tpoly[j] = r[j] + N
        tpoly = FauxCbd(tpoly)
        tpoly = RQ(tpoly)
        e1[i] = [tpoly]
        N += 1
    e1 = matrix(e1)

    dbg('e1', e1)

    # e2 is an n-length polynomial with n terms
    for i in range(0, n):
        e2[i] = r[i] + N
    e2 = FauxCbd(e2)
    e2 = RQ(e2)

    dbg('e2', e2)

    u = A.transpose() * rr + e1
    v = t.transpose() * rr + e2 + mbpc

    dbg('u', u)
    dbg('v', v, '\n')

    return (u, v)

    

### K-PKE Decrypt

In [741]:
def KPKE_Decrypt(u: matrix, v: matrix, s: matrix) -> int:
    dbg('===== kpke_decrypt =====')
    
    # Compute a noisy result mn
    mn = v - s.transpose() * u
    mn = mn.coefficients()[0]

    dbg('Noisy recovered m', mn)
   
    mn_c = mn.list()
    mn_c.reverse()

    # Decompress and remove the noise
    m_rec = Decompress(mn_c)

    dbg('Decompressed m', list(reversed(m_rec)), '\n')

    # Convert to integer
    m_rec = int(''.join([str(x) for x in m_rec]), 2)

    return m_rec

### INDCPA Key Exchange Test

In [742]:
# Alice generates a public key (A, t),
# and a private key s
A, t, s = KPKE_Keygen()

# Alice sends Bob her pk
# Bob chooses a random message m
# and encrypts it using Alice's pk
# and some randomness r to produce the ciphertext (u, v)
m = RandInt(n)
r = RandListCbd(n)
u, v = KPKE_Encrypt(A, t, m, r)

# Bob sends Alice (u, v).
# Alice can then recover the message m
mr = KPKE_Decrypt(u, v, s)

dbg('Alice\'s m', mr)
dbg('Bob\'s m', m)
dbg(m.bits())
if m != mr:
    raise ValueError('Alice and Bob\'s messages do not match, final decompression likely failed. Try increasing the value of the prime q')


===== kpke_keygen =====
A:
[2556*xbar^3 + 1001*xbar^2 + 1202*xbar + 2075   1232*xbar^3 + 331*xbar^2 + 2782*xbar + 646]
[ 1989*xbar^3 + 2319*xbar^2 + 173*xbar + 1312 1353*xbar^3 + 1028*xbar^2 + 1229*xbar + 3171]

s:
[          xbar^3 + xbar^2 + 3328*xbar + 1]
[2*xbar^3 + 3327*xbar^2 + 3328*xbar + 3327]

e:
[2*xbar^2 + 1]
[           2]

t:
[ 1094*xbar^3 + 2107*xbar^2 + 1149*xbar + 796]
[1305*xbar^3 + 2452*xbar^2 + 2890*xbar + 2078]



===== kpke_encrypt =====
Bits of m:
[0, 1]

Polynomial m:
xbar

Compressed m:
1665*xbar

rr:
[xbar^2 + 3328*xbar + 1]
[ xbar^3 + 2*xbar^2 + 2]

e1:
[     2*xbar^3 + 3327*xbar^2 + xbar + 3327]
[3327*xbar^3 + 3328*xbar^2 + 2*xbar + 3328]

e2:
3328*xbar^3 + 3327*xbar

u:
[  1737*xbar^3 + 487*xbar^2 + 608*xbar + 1441]
[2029*xbar^3 + 1910*xbar^2 + 2959*xbar + 1274]

v:
[616*xbar^3 + 2851*xbar^2 + 1640*xbar + 2803]



===== kpke_decrypt =====
Noisy recovered m:
3326*xbar^3 + 8*xbar^2 + 1664*xbar + 8

Decompressed m:
[0, 1, 0, 0]



Alice's m:
2

Bob's m:
2

[0, 

# INDCCA Encapsulation (ML-KEM)

## ML-KEM KeyGen

In [743]:
def MLKEM_KeyGen() -> ((polynomial, polynomial), (polynomial, polynomial, int)):
    dbg('===== MLKEM_KeyGen =====')
    
    z = RandInt(n)
    A, t, s = KPKE_Keygen()
    ek = (A, t)
    Ht = sha3_256(Poly2Bytes(t)).digest()

    dbg('SHA3-256(t)', Ht.hex(), '\n')

    dk = (s, ek, Ht, z)

    return ek, dk

## ML-KEM Encaps

In [744]:
def MLKEM_Encaps(ek: (polynomial, polynomial)) -> (bytes, bytes):
    dbg('===== MLKEM_Encaps =====')

    m = RandInt(n)
    dbg('m', m)

    A, t = ek
    Ht = sha3_256(Poly2Bytes(t)).digest()

    dbg('SHA3-256(t)', Ht.hex())

    Kr = sha3_512(int(m).to_bytes(BytesNeed4Bits(n), 'big') + Ht).digest()
    K, r = (Kr[:BytesNeed4Bits(n)], Kr[BytesNeed4Bits(n):])
    r = FauxCbd(Bytes2ListInt(r))

    dbg('r', r, '\n')
    
    c = KPKE_Encrypt(A, t, m, r)

    return K, c

## ML-KEM Decaps

In [745]:
def MLKEM_Decaps(c: (polynomial, polynomial), dk: (polynomial, (polynomial, polynomial), bytes, int)) -> bytes:
    dbg('===== MLKEM_Decaps =====')
    s, ek, h, _ = dk
    A, t = ek
    u, v = c

    mprime = KPKE_Decrypt(u, v, s)

    dbg('m\'', mprime)

    Krprime = sha3_512(int(mprime).to_bytes(BytesNeed4Bits(n), 'big') + h).digest()
    Kprime, rprime = (Krprime[:BytesNeed4Bits(n)], Krprime[BytesNeed4Bits(n):])
    rprime = FauxCbd(Bytes2ListInt(rprime))

    dbg('r\'', rprime)

    uprime, vprime = KPKE_Encrypt(A, t, Integer(mprime), rprime)

    dbg('u\'', uprime)
    dbg('v\'', vprime, '\n')

    return Kprime



## ML-KEM Scheme Test

In [746]:
# Alice runs KeyGen()
pkA, skA = MLKEM_KeyGen()

# Bob receives pkA from Alice,
# then runs Encaps() to generate
# his copy of the shared secret kB, 
# and a ciphertext c
ssB, c = MLKEM_Encaps(pkA)

# Bob sends c to Alice,
# who then uses her secret key
# to generate her copy of
# the shared secret kA
ssA = MLKEM_Decaps(c, skA)

dbg('Bob\'s Shared Secret', ssB.hex())
dbg('Alice\'s Shared Secret', ssA.hex())
dbg('ssA == ssB?', ssA == ssB)
if(ssA != ssB):
    raise ValueError('The shared keys do not match!')

===== MLKEM_KeyGen =====
===== kpke_keygen =====
A:
[1283*xbar^3 + 2190*xbar^2 + 2391*xbar + 1905    325*xbar^3 + 2639*xbar^2 + 2578*xbar + 57]
[ 950*xbar^3 + 1004*xbar^2 + 1001*xbar + 2515  3323*xbar^3 + 1367*xbar^2 + 2018*xbar + 907]

s:
[3328*xbar^3 + 2*xbar^2 + 1]
[  3327*xbar^3 + 2*xbar + 2]

e:
[3327*xbar^3 + 3328*xbar + 2]
[            2*xbar^3 + 3328]

t:
[3314*xbar^3 + 1722*xbar^2 + 2575*xbar + 1209]
[  1347*xbar^3 + 426*xbar^2 + 2031*xbar + 711]



SHA3-256(t):
5879da4086343fb73b8fb59320cebdb7259cc22c26bb0f40c80c3c4c1af41fb9



===== MLKEM_Encaps =====
m:
14

SHA3-256(t):
5879da4086343fb73b8fb59320cebdb7259cc22c26bb0f40c80c3c4c1af41fb9

r:
[-1, -2, 2, 1, 1, -1, 0, 0, 1, 1, 0, -1, 1, 2, -2, -2, 2, -1, -2, -1, 2, -2, 0, -2, -2, -1, 1, 0, 2, 1, -2, 2, -2, 2, 1, 0, -2, -2, 0, 2, 0, 1, 0, 2, 0, 1, 0, 1, 0, 0, 0, -1, -2, 1, 1, -2, -1, 1, -2, 2, -2, -1, -2]



===== kpke_encrypt =====
Bits of m:
[0, 1, 1, 1]

Polynomial m:
xbar^3 + xbar^2 + xbar

Compressed m:
1665*xbar^3 + 1665*xba