## Setup

### Parameters

In [190]:
import math
from os import urandom
from hashlib import sha3_256, sha3_512

from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing
from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.rings.integer import Integer

# Kyber Parameters
q = 3329      # coefficient modulus
q_bytes = 16  # bytes per coefficient in Poly2Bytes (2 would suffice, since q < 2**12)
k = 2         # module rank: s, e, t are k*1 and A is k*k
n = 256       # number of coefficients per polynomial; RQ = GF(q)[x] / (x^n + 1)


rQ = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = rQ.gen()
# Use ** rather than ^: without the Sage preparser, ^ is XOR in Python and binds
# looser than +, so `x^n + 1` would mean x XOR (n + 1).
f = x**n + 1
RQ = rQ.quotient(f)


# When True, dbg() prints intermediate values.
DEBUG = True

### Helper Functions

In [191]:
def RandomList(length, cbd=False):
    """Return `length` integers drawn with randrange(q), i.e. from [0, q).

    When cbd is True the values are passed through FauxCbd, which maps them
    to small values in [-2, 2].
    """
    values = [randrange(q) for _ in range(length)]
    return FauxCbd(values) if cbd else values

def FauxCbd(r: list):
    """Map every integer in r to a small value in [-2, 2] via (value % 5) - 2.

    A cheap stand-in for a centered binomial distribution: the result is
    roughly, not exactly, uniform over -2..2, because input ranges such as
    [0, q) with q = 3329 are not multiples of 5.
    """
    return [(value % 5) - 2 for value in r]

def RandPolyUniform(length):
    """Return an element of RQ built from `length` coefficients drawn from [0, q), lowest degree first."""
    coeffs = RandomList(length)
    return RQ(coeffs)

def RandPolyCbd(length):
    """Return an element of RQ built from `length` small FauxCbd coefficients in [-2, 2]."""
    # Equivalent to RandomList(length, cbd=True).
    return RQ(FauxCbd(RandomList(length)))

def BytesNeed4Bits(bits: int) -> int:
    """Return the number of whole bytes needed to hold `bits` bits, i.e. ceil(bits / 8)."""
    return (bits + 7) // 8

def RandInt(bits: int) -> int:
    """Return a random non-negative Sage Integer strictly below 2**bits.

    Reads ceil(bits / 8) bytes from os.urandom and masks off any excess
    high-order bits.
    """
    m = Integer(int.from_bytes(urandom(BytesNeed4Bits(bits)), 'big'))
    # Mask to the requested width (the `bits` argument, not the global n) so the
    # bound holds for every caller, not only those passing bits == n.
    m &= 2**bits - 1
    return m

def Poly2Bytes(poly: polynomial) -> bytes:
    """Serialize a polynomial in RQ, or every polynomial in a matrix of them, to bytes.

    Each coefficient is converted with int() and written as a q_bytes-long
    big-endian integer. For a Sage matrix (e.g. the k*1 public vector t) the
    entries are taken in Matrix.list() order, so every entry contributes to
    the output, including zero entries.

    Args:
        poly: an element of RQ, or a Sage matrix whose entries are elements of RQ.

    Returns:
        The concatenated coefficient encodings.
    """
    # Walk all matrix entries explicitly. Matrix.coefficients() returns only the
    # nonzero entries, so taking its first element hashed just one row of t.
    entries = poly.list() if hasattr(poly, 'nrows') else [poly]
    return b''.join(
        int(c).to_bytes(q_bytes, 'big')
        for entry in entries
        for c in entry.list()
    )

def Bytes2ListInt(b: bytes) -> list:
    """Convert each byte of b, in order, to a Sage Integer in [0, 255]."""
    return [Integer(byte) for byte in b]

def Compress(poly: polynomial):
    """Scale poly by ceil(q/2), turning 0/1 message coefficients into 0 or ceil(q/2).

    NOTE: FIPS 203 calls this message-embedding step Decompress_1.
    """
    # ceil(q/2) for a positive integer q, computed without division to a fraction.
    half_q = (q + 1) // 2
    return poly * half_q

def Decompress(poly: list) -> list:
    """Round each coefficient to a single message bit.

    NOTE: FIPS 203 calls this rounding step Compress_1.

    Args:
        poly: iterable of coefficients (GF(q) elements or integers in [0, q)),
            e.g. the .list() of a polynomial in RQ.

    Returns:
        A list with one 0/1 int per input coefficient, in input order: 1 when
        the coefficient lies strictly between q/4 and 3q/4 (nearer q/2 than
        0 mod q), otherwise 0.
    """
    # q/4 < c < 3q/4  <=>  q < 4c < 3q; the latter stays in exact integer arithmetic.
    return [1 if q < 4 * int(c) < 3 * q else 0 for c in poly]

def dbg(s: str = ''):
    """Print s (a blank line by default), but only while the global DEBUG flag is set."""
    if not DEBUG:
        return
    print(s)

## IND-CPA Public-Key Encryption (K-PKE)

### K-PKE KeyGen

In [192]:
def KPKE_Keygen() -> (matrix, matrix, matrix):
    """Generate a K-PKE key pair and return (A, t, s).

    A is a k*k matrix of polynomials in RQ with coefficients from [0, q);
    s is the k*1 secret vector of small (FauxCbd) polynomials; t = A*s + e is
    the k*1 public vector, where e is a k*1 vector of small polynomials that
    is shown via dbg() and then discarded.

    Shapes for k = 2:

        |     A     |   |  s  |   |  e  |
        | 0,0 | 0,1 |   |  0  |   |  0  |
        | 1,0 | 1,1 |   |  1  |   |  1  |

        A*s + e = | A[0,0]*s[0] + A[0,1]*s[1] + e[0] |
                  | A[1,0]*s[0] + A[1,1]*s[1] + e[1] |
    """
    dbg('===== kpke_keygen =====')
    # Built row by row: each inner list is one row of k uniform polynomials.
    A = matrix([[RandPolyUniform(n) for _col in range(k)] for _row in range(k)])
    dbg('A:')
    dbg(A)

    # Secret vector s: a k*1 column of small polynomials.
    s = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('s:')
    dbg(s)

    # Error vector e: a k*1 column of small polynomials.
    e = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('e:')
    dbg(e)

    # Public vector t = A*s + e (k*1).
    t = A * s + e
    dbg()

    return A, t, s

### K-PKE Encrypt

In [193]:
def _faux_noise_poly(r: list, offset: int) -> polynomial:
    """Derive one small noise polynomial in RQ from seed list r shifted by offset.

    Distinct offsets give distinct polynomials only modulo 5, because FauxCbd
    reduces every value mod 5 (offsets 0 and 5 give the same polynomial).
    """
    return RQ(FauxCbd([value + offset for value in r]))


def KPKE_Encrypt(A: matrix, t: matrix, m: int, r: list) -> (matrix, matrix):
    """Encrypt the message integer m under the K-PKE public key (A, t).

    Args:
        A: k*k public matrix from KPKE_Keygen.
        t: k*1 public vector from KPKE_Keygen.
        m: non-negative integer with at most n bits (Sage Integer or Python int).
        r: list of integers seeding all encryption noise. rr, e1 and e2 are
            derived from it deterministically, so the same (m, r) always gives
            the same ciphertext.

    Returns:
        (u, v) with u = A^T*rr + e1 (k*1) and v = t^T*rr + e2 + Compress(m) (1*1).

    Raises:
        ValueError: if m is negative or has more than n bits.
    """
    dbg('===== kpke_encrypt =====')
    # Coerce so plain Python ints work too; .bits() is a Sage Integer method.
    m = Integer(m)
    # Integer.bits() returns negative "bits" for negative numbers, which would
    # encode a message that decryption can never recover.
    if m < 0:
        raise ValueError('m must be non-negative!')
    mm = m.bits()
    # Ensure that m does not have more bits than n bits
    if len(mm) > n:
        raise ValueError('m has more bits than n!')
    dbg('Bits of m:')
    dbg(mm)

    # bits() is least-significant first, so bit i becomes the coefficient of x^i.
    # Pad with 0s up to exactly n coefficients.
    mm = RQ(mm + [0] * (n - len(mm)))
    dbg('Polynomial m:')
    dbg(mm)
    mm = Compress(mm)
    dbg('Compressed m:')
    dbg(mm)

    # rr, e1 and e2 use consecutive offsets 0 .. 2k of the seed r.
    # NOTE(review): FauxCbd works mod 5, so for k >= 3 some of these offsets
    # coincide mod 5 and the corresponding noise polynomials are identical.
    offset = 0

    # rr is a k*1 matrix
    rr = matrix([[_faux_noise_poly(r, offset + i)] for i in range(k)])
    offset += k
    dbg('rr:')
    dbg(rr)

    # e1 is a k*1 matrix
    e1 = matrix([[_faux_noise_poly(r, offset + i)] for i in range(k)])
    offset += k
    dbg('e1:')
    dbg(e1)

    # e2 is a single polynomial in RQ
    e2 = _faux_noise_poly(r, offset)
    dbg('e2:')
    dbg(e2)

    u = A.transpose() * rr + e1
    # t^T*rr is 1*1; adding the ring elements e2 and mm adds them to its single entry.
    v = t.transpose() * rr + e2 + mm

    dbg('u:')
    dbg(u)
    dbg('v:')
    dbg(v)
    dbg()

    return (u, v)

    

### K-PKE Decrypt

In [194]:
def KPKE_Decrypt(u: matrix, v: matrix, s: matrix) -> int:
    """Recover the message integer from the K-PKE ciphertext (u, v) using secret s.

    Computes v - s^T*u, which is Compress(m) plus a noise term, and rounds
    each coefficient back to one bit with Decompress.

    Returns:
        The recovered message as a Python int (coefficient of x^i -> bit i).
    """
    dbg('===== kpke_decrypt =====')
    # Compute a noisy result mn. v - s^T*u is a 1*1 matrix, so index its entry
    # directly: .coefficients() omits zero entries and would raise IndexError
    # if the noisy polynomial happened to be 0.
    mn = (v - s.transpose() * u)[0, 0]
    dbg('Noisy recovered m:')
    dbg(mn)

    # Reverse so the highest-degree coefficient (most significant bit) comes first.
    mn_c = mn.list()
    mn_c.reverse()

    m_rec = Decompress(mn_c)
    dbg('Decompressed m:')
    dbg(m_rec)

    # Parse the MSB-first bit list as a base-2 number.
    m_rec = int(''.join(str(bit) for bit in m_rec), 2)
    dbg('Recovered m:')
    dbg(m_rec)
    dbg()

    return m_rec

### IND-CPA K-PKE Round-Trip Test

In [195]:
from os import urandom
# Round trip: encrypt a random n-bit message under fresh keys and decrypt it.
m = RandInt(n)
A, t, s = KPKE_Keygen()
# Small seed values in [-2, 2]; same as FauxCbd(RandomList(n)).
r = RandomList(n, cbd=True)
u, v = KPKE_Encrypt(A, t, m, r)
mr = KPKE_Decrypt(u, v, s)
dbg('Original m:')
dbg(m)
dbg(m.bits())
if mr != m:
    raise ValueError('decrypted m does not match, final decompression likely failed')

===== kpke_keygen =====
A:
[                  1221*xbar^255 + 2697*xbar^254 + 1728*xbar^253 + 1003*xbar^252 + 2*xbar^251 + 722*xbar^250 + 900*xbar^249 + 2089*xbar^248 + 160*xbar^247 + 2794*xbar^246 + 1087*xbar^245 + 2000*xbar^244 + 88*xbar^243 + 1185*xbar^242 + 1845*xbar^241 + 3170*xbar^240 + 2698*xbar^239 + 712*xbar^238 + 1612*xbar^237 + 914*xbar^236 + 2972*xbar^235 + 1641*xbar^234 + 520*xbar^233 + 3175*xbar^232 + 555*xbar^231 + 2653*xbar^230 + 31*xbar^229 + 2807*xbar^228 + 1221*xbar^227 + 880*xbar^226 + 2164*xbar^225 + 2722*xbar^224 + 2882*xbar^223 + 383*xbar^222 + 489*xbar^221 + 473*xbar^220 + 1837*xbar^219 + 1191*xbar^218 + 3227*xbar^217 + 2294*xbar^216 + 159*xbar^215 + 1051*xbar^214 + 160*xbar^213 + 1710*xbar^212 + 1222*xbar^211 + 2243*xbar^210 + 340*xbar^209 + 2432*xbar^208 + 1018*xbar^207 + 2273*xbar^206 + 2793*xbar^205 + 42*xbar^204 + 1352*xbar^203 + 1760*xbar^202 + 3129*xbar^201 + 2899*xbar^200 + 867*xbar^199 + 1837*xbar^198 + 1570*xbar^197 + 696*xbar^196 + 2706*xbar^195 + 785

# IND-CCA Key Encapsulation (ML-KEM)

## ML-KEM KeyGen

In [196]:
def MLKEM_KeyGen() -> ((matrix, matrix), (matrix, (matrix, matrix), bytes, Integer)):
    """Generate an ML-KEM-style key pair (ek, dk).

    Returns:
        ek: (A, t), the K-PKE public key from KPKE_Keygen.
        dk: (s, ek, Ht, z) where s is the K-PKE secret vector, Ht is
            SHA3-256(Poly2Bytes(t)), and z is a random n-bit Integer. In
            FIPS 203, z seeds implicit rejection; the decapsulation in this
            notebook does not read it.
    """
    dbg('===== MLKEM_KeyGen =====')
    z = RandInt(n)
    A, t, s = KPKE_Keygen()
    ek = (A, t)
    # Must be computed exactly as in MLKEM_Encaps so decapsulation derives the same r.
    Ht = sha3_256(Poly2Bytes(t)).digest()
    dbg('SHA3-256(t):')
    dbg(Ht.hex())
    dk = (s, ek, Ht, z)
    dbg()
    return ek, dk

## ML-KEM Encaps

In [197]:
def MLKEM_Encaps(ek: (matrix, matrix)) -> (bytes, (matrix, matrix)):
    """Encapsulate a fresh shared key under the encapsulation key ek = (A, t).

    A random n-bit message m, encoded as BytesNeed4Bits(n) big-endian bytes,
    is hashed with SHA3-256(Poly2Bytes(t)) using SHA3-512. The first
    BytesNeed4Bits(n) digest bytes are the shared key K. The remaining bytes,
    mapped through FauxCbd, seed the K-PKE encryption of m.

    Returns:
        (K, c): the shared key bytes and the ciphertext c = (u, v).
    """
    dbg('===== MLKEM_Encaps =====')
    m = RandInt(n)
    dbg('m:')
    dbg(m)
    A, t = ek
    Ht = sha3_256(Poly2Bytes(t)).digest()
    dbg('SHA3-256(t):')
    dbg(Ht.hex())
    key_len = BytesNeed4Bits(n)
    Kr = sha3_512(int(m).to_bytes(key_len, 'big') + Ht).digest()
    K, r = Kr[:key_len], Kr[key_len:]
    # Each remaining digest byte becomes one small seed value in [-2, 2].
    r = FauxCbd(Bytes2ListInt(r))
    dbg('r')
    dbg(r)
    dbg()

    c = KPKE_Encrypt(A, t, m, r)
    dbg()
    return K, c

## ML-KEM Decaps

In [198]:
def MLKEM_Decaps(c: (matrix, matrix), dk: (matrix, (matrix, matrix), bytes, Integer)) -> bytes:
    """Decapsulate ciphertext c with decapsulation key dk and return the shared key.

    Decrypts c to m', re-derives (K', r') = SHA3-512(m' || h), re-encrypts m'
    with r', and checks that the result equals c.

    Returns:
        K', the shared key bytes.

    Raises:
        ValueError: if re-encryption of m' does not reproduce c. FIPS 203
            ML-KEM instead returns an implicit-rejection key derived from z;
            this notebook fails loudly so mismatches are visible.
    """
    dbg('===== MLKEM_Decaps =====')
    s, ek, h, _ = dk
    A, t = ek
    u, v = c

    mprime = KPKE_Decrypt(u, v, s)
    dbg('m\':')
    dbg(mprime)
    key_len = BytesNeed4Bits(n)
    Krprime = sha3_512(int(mprime).to_bytes(key_len, 'big') + h).digest()
    Kprime, rprime = Krprime[:key_len], Krprime[key_len:]
    rprime = FauxCbd(Bytes2ListInt(rprime))
    dbg('r\':')
    dbg(rprime)

    dbg()
    uprime, vprime = KPKE_Encrypt(A, t, Integer(mprime), rprime)
    dbg('u\':')
    dbg(uprime)
    dbg('v\':')
    dbg(vprime)
    dbg()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if u != uprime or v != vprime:
        raise ValueError('re-encryption of m\' does not match the ciphertext')
    return Kprime



## ML-KEM Scheme Test

In [199]:
# End-to-end ML-KEM run: key generation, then encapsulation under ek, then
# decapsulation with dk, which re-encrypts the recovered message and checks it
# against the ciphertext c.
ek, dk = MLKEM_KeyGen()
K, c = MLKEM_Encaps(ek)
MLKEM_Decaps(c, dk)

===== MLKEM_KeyGen =====
===== kpke_keygen =====
A:
[     2465*xbar^255 + 3163*xbar^254 + 1879*xbar^253 + 1873*xbar^252 + 1907*xbar^251 + 137*xbar^250 + 830*xbar^249 + 1036*xbar^248 + 2291*xbar^247 + 2096*xbar^246 + 417*xbar^245 + 3007*xbar^244 + 800*xbar^243 + 2927*xbar^242 + 790*xbar^241 + 1465*xbar^240 + 1687*xbar^239 + 2745*xbar^238 + 2230*xbar^237 + 660*xbar^236 + 3176*xbar^235 + 368*xbar^234 + 2942*xbar^233 + 1300*xbar^232 + 1948*xbar^231 + 2716*xbar^230 + 2904*xbar^229 + 2683*xbar^228 + 1238*xbar^227 + 1074*xbar^226 + 2642*xbar^225 + 1827*xbar^224 + 315*xbar^223 + 1554*xbar^222 + 532*xbar^221 + 1723*xbar^220 + 400*xbar^219 + 3173*xbar^218 + 1066*xbar^217 + 1669*xbar^216 + 3315*xbar^215 + 2983*xbar^214 + 1784*xbar^213 + 2843*xbar^212 + 1618*xbar^211 + 2752*xbar^210 + 2661*xbar^209 + 3180*xbar^208 + 1997*xbar^207 + 573*xbar^206 + 1628*xbar^205 + 716*xbar^204 + 1361*xbar^203 + 1836*xbar^202 + 2508*xbar^201 + 1538*xbar^200 + 1533*xbar^199 + 1320*xbar^198 + 2160*xbar^197 + 3108*xbar^