## Setup

### Parameters

In [180]:
import math
from os import urandom
from hashlib import sha3_256, sha3_512

from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing
from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.rings.integer import Integer

# Kyber Parameters
q = 3329    # coefficient modulus
# Bytes per serialized coefficient: just enough to hold any value in [0, q)
# (2 bytes for q = 3329; the previous hard-coded 16 was 8x oversized).
q_bytes = (int(q).bit_length() + 7) // 8
k = 2       # module rank (k = 2 matches the ML-KEM-512 parameter set)
n = 256     # number of coefficients per polynomial: RQ = GF(q)[x] / (x^n + 1)


rQ = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = rQ.gen()
f = x**n + 1
RQ = rQ.quotient(f)
print(type(RQ))


DEBUG = True

<class 'sage.rings.polynomial.polynomial_quotient_ring.PolynomialQuotientRing_generic_with_category'>


### Helper Functions

In [181]:
def RandomList(length, cbd=False):
    """Draw `length` integers from [0, q) with Sage's randrange.

    With cbd=True the draws are passed through FauxCbd, giving small values
    in [-2, 2] instead.
    """
    draws = [randrange(q) for _ in range(length)]
    return FauxCbd(draws) if cbd else draws

def FauxCbd(r: list):
    """Map every integer in `r` to (value mod 5) - 2, a small value in [-2, 2].

    This is a stand-in for ML-KEM's centered binomial sampling, not a real CBD:
    for uniform inputs the outputs are roughly uniform over [-2, 2].
    """
    return [(value % 5) - 2 for value in r]

def RandPolyUniform(length):
    """Return an RQ element whose first `length` coefficients are uniform in [0, q)."""
    coeffs = RandomList(length)
    return RQ(coeffs)

def RandPolyCbd(length):
    """Return an RQ element whose first `length` coefficients are FauxCbd values in [-2, 2]."""
    small_coeffs = RandomList(length, cbd=True)
    return RQ(small_coeffs)

def BytesNeed4Bits(bits: int) -> int:
    """Return how many whole bytes are needed to hold `bits` bits, i.e. ceil(bits / 8)."""
    return (bits + 7) // 8

def RandInt(bits: int) -> int:
    """Return a uniformly random Sage Integer in [0, 2**bits), drawn from os.urandom.

    Enough random bytes are read to cover `bits`, then any excess high bits are
    masked off so the result never exceeds `bits` bits. The mask previously used
    the global `n`, so non-byte-aligned requests could return too-large values.
    """
    m = Integer(int.from_bytes(urandom(BytesNeed4Bits(bits)), 'big'))
    m &= (1 << bits) - 1
    return m

def Poly2Bytes(poly: polynomial) -> bytes:
    """Serialize every coefficient of `poly` as a q_bytes-wide big-endian integer.

    `poly` may be a single RQ element or a matrix of RQ elements (MLKEM_KeyGen
    and MLKEM_Encaps pass the k x 1 matrix t). Matrix entries are serialized in
    row-major order, each contributing its full coefficient list. All entries are
    included, even zero ones. The previous `coefficients()[0]` form serialized only
    the first nonzero entry, so the hash did not bind the rest of t.
    """
    # Matrices expose nrows(); RQ elements do not.
    entries = poly.list() if hasattr(poly, 'nrows') else [poly]
    return b''.join(
        int(coeff).to_bytes(q_bytes, 'big')
        for entry in entries
        for coeff in entry.list()
    )

def Bytes2ListInt(b: bytes) -> list:
    """Convert each byte of `b` into a Sage Integer in [0, 255], preserving order."""
    return [Integer(octet) for octet in b]

def Compress(poly: polynomial):
    """Scale a 0/1-coefficient polynomial so that bits become well-separated ring values.

    Every coefficient is multiplied by ceil(q/2) (1665 for q = 3329). A 1 bit
    therefore lands near q/2 and a 0 bit stays at 0, which leaves a margin for
    the encryption noise. FIPS 203 calls this step Decompress_1.
    """
    half_q_rounded_up = math.ceil(q / 2)
    return poly * half_q_rounded_up

def Decompress(poly: list, modulus: int = None) -> list:
    """Round each coefficient to one bit: 1 if it lies strictly between q/4 and 3q/4.

    This undoes Compress up to small noise. FIPS 203 calls this step Compress_1.

    Args:
        poly: iterable of coefficients (KPKE_Decrypt passes a list of RQ
            coefficients); each must convert via int() to a value in [0, modulus).
        modulus: the coefficient modulus; defaults to the global q.

    Returns:
        list of 0/1 ints, one per input coefficient, in input order.
    """
    if modulus is None:
        modulus = q
    # q/4 < x < 3q/4, rewritten as exact integer comparisons (no fractional bounds).
    return [1 if modulus < 4 * int(c) < 3 * modulus else 0 for c in poly]

def dbg(s: str = ''):
    """Print `s` (an empty line by default) only when the global DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(s)

## IND-CPA Public-Key Encryption (K-PKE)

### K-PKE KeyGen

In [182]:
def KPKE_Keygen() -> (matrix, matrix, matrix):
    """Generate a toy K-PKE key pair.

    Returns:
        (A, t, s):
          A -- k x k matrix of RQ elements with coefficients drawn from [0, q).
          t -- k x 1 matrix, t = A*s + e.
          s -- k x 1 secret matrix of RQ elements with FauxCbd coefficients in [-2, 2].
        The error e (same shape and distribution as s) is not returned.

    NOTE(review): A is sampled directly rather than expanded from a public seed
    as in FIPS 203, so the whole matrix is part of the returned public data.
    """
    dbg('===== kpke_keygen =====')
    # A is a k*k matrix of uniformly random polynomials, sampled row by row.
    A = matrix([[RandPolyUniform(n) for _ in range(k)] for _ in range(k)])
    dbg('A:')
    dbg(A)

    # s is a k*1 matrix of small-coefficient polynomials (the secret key).
    s = matrix([[RandPolyCbd(n)] for _ in range(k)])
    dbg('s:')
    dbg(s)

    # e is a k*1 matrix of small-coefficient polynomials (the error term).
    e = matrix([[RandPolyCbd(n)] for _ in range(k)])
    dbg('e:')
    dbg(e)

    # t = A*s + e, a k*1 matrix. Example for k = 2:
    #   t[0] = A[0,0]*s[0] + A[0,1]*s[1] + e[0]
    #   t[1] = A[1,0]*s[0] + A[1,1]*s[1] + e[1]
    t = A * s + e
    dbg()

    return (A, t, s)

### K-PKE Encrypt

In [183]:
def _kpke_noise(r: list, offset: int):
    """Derive one small-coefficient RQ element from seed list `r` shifted by `offset` (toy PRF)."""
    return RQ(FauxCbd([c + offset for c in r]))


def KPKE_Encrypt(A: matrix, t: matrix, m: int, r: list) -> (matrix, matrix):
    """Encrypt an n-bit message `m` under the public key (A, t).

    Args:
        A: k x k public matrix from KPKE_Keygen.
        t: k x 1 public matrix from KPKE_Keygen.
        m: non-negative integer with at most n bits (Sage Integer or Python int).
        r: list of small integers used as the encryption randomness. rr, e1 and
           e2 are derived from it with offsets 0..k-1, k..2k-1 and 2k
           respectively. No other randomness is used, so equal (m, r) under the
           same key give equal ciphertexts (MLKEM_Decaps relies on this).

    Returns:
        (u, v): u = A^T*rr + e1 (k x 1 matrix), v = t^T*rr + e2 + Compress(m) (1 x 1 matrix).

    Raises:
        ValueError: if m is negative or has more than n bits.
    """
    dbg('===== kpke_encrypt =====')
    # Accept plain ints as well as Sage Integers (only the latter have .bits()).
    m = Integer(m)
    # Sage returns negative "bits" for negative numbers, which can never round-trip.
    if m < 0:
        raise ValueError('m must be non-negative!')
    mm = m.bits()
    # Ensure that m does not have more bits than n bits
    if len(mm) > n:
        raise ValueError('m has more bits than n!')
    dbg('Bits of m:')
    dbg(mm)

    # bits() is LSB-first, so bit i becomes the coefficient of x^i; pad to n bits.
    mm = RQ(mm + [0] * (n - len(mm)))
    dbg('Polynomial m:')
    dbg(mm)
    mm = Compress(mm)
    dbg('Compressed m:')
    dbg(mm)

    # Offset counter so rr, e1 and e2 are each derived from a distinct shift of r.
    N = 0

    # rr is a k*1 matrix
    rr = matrix([[_kpke_noise(r, i)] for i in range(N, N + k)])
    N += k
    dbg('rr:')
    dbg(rr)

    # e1 is a k*1 matrix
    e1 = matrix([[_kpke_noise(r, i)] for i in range(N, N + k)])
    N += k
    dbg('e1:')
    dbg(e1)

    # e2 is a single RQ element
    e2 = _kpke_noise(r, N)
    dbg('e2:')
    dbg(e2)

    u = A.transpose() * rr + e1
    # t^T*rr is 1x1; adding the ring elements e2 and mm adds them to its single entry.
    v = t.transpose() * rr + e2 + mm

    dbg('u:')
    dbg(u)
    dbg('v:')
    dbg(v)
    dbg()

    return (u, v)

    

### K-PKE Decrypt

In [184]:
def KPKE_Decrypt(u: matrix, v: matrix, s: matrix) -> int:
    """Recover the message from ciphertext (u, v) with secret key `s`.

    Args:
        u: k x 1 ciphertext matrix from KPKE_Encrypt.
        v: 1 x 1 ciphertext matrix from KPKE_Encrypt.
        s: k x 1 secret matrix from KPKE_Keygen.

    Returns:
        the message as a Python int whose bit i is the rounded coefficient i of
        the noisy polynomial v - s^T*u.
    """
    dbg('===== kpke_decrypt =====')
    # v - s^T*u is a 1x1 matrix. Take its entry directly: Matrix.coefficients()
    # returns only nonzero entries and would raise IndexError on a zero result.
    mn = (v - s.transpose() * u)[0, 0]
    dbg('Noisy recovered m:')
    dbg(mn)

    # Coefficient i carries bit i, so reverse to read the bit string MSB first.
    mn_c = mn.list()[::-1]

    m_rec = Decompress(mn_c)
    dbg('Decompressed m:')
    dbg(m_rec)

    m_rec = int(''.join(str(bit) for bit in m_rec), 2)
    dbg('Recovered m:')
    dbg(m_rec)
    dbg()

    return m_rec

### IND-CPA Encryption Round-Trip Test

In [185]:
# Round-trip check: a random n-bit message must survive K-PKE encryption and decryption.
# (urandom is already imported in the setup cell, so it is not re-imported here.)
m = RandInt(n)
A, t, s = KPKE_Keygen()
r = FauxCbd(RandomList(n))
u, v = KPKE_Encrypt(A, t, m, r)
mr = KPKE_Decrypt(u, v, s)
dbg('Original m:')
dbg(m)
dbg(m.bits())
if m != mr:
    raise ValueError('decrypted m does not match, final decompression likely failed')

===== kpke_keygen =====
A:
[       1999*xbar^255 + 2290*xbar^254 + 1426*xbar^253 + 252*xbar^252 + 148*xbar^251 + 3233*xbar^250 + 1675*xbar^249 + 1526*xbar^248 + 909*xbar^247 + 2532*xbar^246 + 1589*xbar^245 + 3233*xbar^244 + 2493*xbar^243 + 1040*xbar^242 + 1041*xbar^241 + 1540*xbar^240 + 673*xbar^239 + 18*xbar^238 + 443*xbar^237 + 2843*xbar^236 + 1004*xbar^235 + 3327*xbar^234 + 1009*xbar^233 + 1430*xbar^232 + 1735*xbar^231 + 3031*xbar^230 + 387*xbar^229 + 776*xbar^228 + 696*xbar^227 + 3266*xbar^226 + 2788*xbar^225 + 844*xbar^224 + 2303*xbar^223 + 1249*xbar^222 + 1068*xbar^221 + 2724*xbar^220 + 1482*xbar^219 + 2789*xbar^218 + 628*xbar^217 + 1945*xbar^216 + 1916*xbar^215 + 2040*xbar^214 + 720*xbar^213 + 2078*xbar^212 + 2274*xbar^211 + 458*xbar^210 + 1274*xbar^209 + 3226*xbar^208 + 254*xbar^207 + 2980*xbar^206 + 625*xbar^205 + 3320*xbar^204 + 2746*xbar^203 + 143*xbar^202 + 2150*xbar^201 + 1261*xbar^200 + 1632*xbar^199 + 40*xbar^198 + 1252*xbar^197 + 2156*xbar^196 + 111*xbar^195 + 2366*xbar

# IND-CCA Key Encapsulation (ML-KEM)

## ML-KEM KeyGen

In [186]:
def MLKEM_KeyGen() -> ((matrix, matrix), (matrix, (matrix, matrix), bytes, Integer)):
    """Generate a toy ML-KEM key pair.

    Returns:
        ek -- encapsulation key (A, t) from KPKE_Keygen.
        dk -- decapsulation key (s, ek, Ht, z), where Ht = SHA3-256(Poly2Bytes(t))
              and z is a random Integer below 2**n (in FIPS 203, z is the
              implicit-rejection secret used by Decaps).
    """
    dbg('===== MLKEM_KeyGen =====')
    z = RandInt(n)
    A, t, s = KPKE_Keygen()
    ek = (A, t)
    # NOTE(review): only t is hashed, so A is not bound into H(ek). MLKEM_Encaps
    # hashes the same bytes, so both sides must change together.
    Ht = sha3_256(Poly2Bytes(t)).digest()
    dbg('SHA3-256(t):')
    dbg(Ht.hex())
    dk = (s, ek, Ht, z)
    dbg()
    return ek, dk

## ML-KEM Encaps

In [187]:
def MLKEM_Encaps(ek: (matrix, matrix)) -> (bytes, (matrix, matrix)):
    """Encapsulate a fresh shared key under the encapsulation key `ek` = (A, t).

    Returns:
        (K, c): K is the BytesNeed4Bits(n)-byte shared key, and c = (u, v) is the
        K-PKE ciphertext of a random n-bit message m.
    """
    dbg('===== MLKEM_Encaps =====')
    nb = BytesNeed4Bits(n)
    m = RandInt(n)
    dbg('m:')
    dbg(m)
    A, t = ek
    Ht = sha3_256(Poly2Bytes(t)).digest()
    dbg('SHA3-256(t):')
    dbg(Ht.hex())
    # (K, r) = SHA3-512(m || H(ek)): the first nb bytes are the shared key, the rest seed the noise.
    Kr = sha3_512(int(m).to_bytes(nb, 'big') + Ht).digest()
    K, r = Kr[:nb], Kr[nb:]
    # NOTE(review): r has only 64 - nb entries (32 for n = 256), so the noise
    # polynomials KPKE_Encrypt derives from it have at most that many nonzero
    # coefficients, not n. MLKEM_Decaps derives r' identically; change both together.
    r = FauxCbd(Bytes2ListInt(r))
    dbg('r')
    dbg(r)
    dbg()

    c = KPKE_Encrypt(A, t, m, r)
    dbg()
    return K, c

## ML-KEM Decaps

In [188]:
def _ciphertext_bytes(u: matrix, v: matrix) -> bytes:
    """Serialize ciphertext (u, v): every coefficient of every entry, row-major, big-endian."""
    coeff_width = (int(q).bit_length() + 7) // 8
    out = bytearray()
    for entry in list(u.list()) + list(v.list()):
        for coeff in entry.list():
            out += int(coeff).to_bytes(coeff_width, 'big')
    return bytes(out)


def MLKEM_Decaps(c: (matrix, matrix), dk: (matrix, (matrix, matrix), bytes, int)) -> bytes:
    """Decapsulate ciphertext `c` = (u, v) with decapsulation key `dk` = (s, ek, h, z).

    Decrypts c to m', re-derives (K', r') = SHA3-512(m' || h), and re-encrypts
    m' with r'. If the re-encryption reproduces (u, v), K' is returned. Otherwise
    the implicit-rejection key SHAKE256(z || c) is returned instead of raising,
    as in FIPS 203. A raised error (or an `assert`, which `python -O` strips)
    would expose or skip the check.

    Returns:
        a BytesNeed4Bits(n)-byte shared key.
    """
    from hashlib import shake_256

    dbg('===== MLKEM_Decaps =====')
    s, ek, h, z = dk
    A, t = ek
    u, v = c
    nb = BytesNeed4Bits(n)

    mprime = KPKE_Decrypt(u, v, s)
    dbg('m\':')
    dbg(mprime)
    Krprime = sha3_512(int(mprime).to_bytes(nb, 'big') + h).digest()
    Kprime, rprime = Krprime[:nb], Krprime[nb:]
    rprime = FauxCbd(Bytes2ListInt(rprime))
    dbg('r\':')
    dbg(rprime)

    dbg()
    uprime, vprime = KPKE_Encrypt(A, t, Integer(mprime), rprime)
    dbg('u\':')
    dbg(uprime)
    dbg('v\':')
    dbg(vprime)
    dbg()

    if u == uprime and v == vprime:
        return Kprime
    # Re-encryption mismatch: return a key derived from the secret z and the
    # ciphertext, so an attacker cannot tell rejection apart from success.
    dbg('Re-encryption check failed; returning implicit-rejection key')
    return shake_256(int(z).to_bytes(nb, 'big') + _ciphertext_bytes(u, v)).digest(nb)



## ML-KEM Scheme Test

In [189]:
# End-to-end ML-KEM run: key generation, encapsulation of a fresh shared key K, then decapsulation of c.
ek, dk = MLKEM_KeyGen()
K, c = MLKEM_Encaps(ek)
# MLKEM_Decaps re-encrypts the decrypted message and checks it against c; K itself is not compared here.
MLKEM_Decaps(c, dk)

===== MLKEM_KeyGen =====
===== kpke_keygen =====
A:
[                                         2885*xbar^255 + 2772*xbar^254 + 197*xbar^253 + 2221*xbar^252 + 1231*xbar^251 + 2197*xbar^250 + 2392*xbar^249 + 180*xbar^248 + 1859*xbar^247 + 2900*xbar^246 + 2570*xbar^245 + 1371*xbar^244 + 2405*xbar^243 + 2427*xbar^242 + 1503*xbar^241 + 2289*xbar^240 + 2899*xbar^239 + 947*xbar^238 + 861*xbar^237 + 1967*xbar^236 + 876*xbar^235 + 1203*xbar^234 + 3012*xbar^233 + 2596*xbar^232 + 216*xbar^231 + 2107*xbar^230 + 2383*xbar^229 + 1079*xbar^228 + 2183*xbar^227 + 11*xbar^226 + 2656*xbar^225 + 3249*xbar^224 + 2743*xbar^223 + 2904*xbar^222 + 1117*xbar^221 + 2005*xbar^220 + 169*xbar^219 + 599*xbar^218 + 599*xbar^217 + 339*xbar^216 + 887*xbar^215 + 823*xbar^214 + 2623*xbar^213 + 1107*xbar^212 + 137*xbar^211 + 1855*xbar^210 + 3191*xbar^209 + 2084*xbar^208 + 1318*xbar^207 + 2536*xbar^206 + 2212*xbar^205 + 762*xbar^204 + 1632*xbar^203 + 990*xbar^202 + 2004*xbar^201 + 2815*xbar^200 + 3129*xbar^199 + 1290*xbar^1