## Setup

### Parameters

In [23]:
import math
from os import urandom
from hashlib import sha3_256, sha3_512

from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing
from sage.rings.polynomial.polynomial_quotient_ring import PolynomialQuotientRing as polynomial
from sage.rings.integer import Integer

# Kyber Parameters
# q: order of the coefficient field (passed to FiniteField below).
q = 3329
# q_bytes: number of bytes Poly2Bytes uses to encode each coefficient.
q_bytes = 16
# k: module dimension; A is k x k and s, e, t are k x 1 (see KPKE_Keygen).
k = 2
# n: degree of the modulus polynomial x^n + 1 and the maximum message bit length.
n = 256


# Polynomial ring over the finite field of order q, in the variable x.
rQ = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = rQ.gen()
# NOTE: relies on the Sage preparser, where ^ is exponentiation
# (in plain Python ^ would be bitwise XOR).
f = x^n + 1
# Quotient ring rQ / (x^n + 1); every polynomial in this notebook is built with RQ(...).
RQ = rQ.quotient(f)
print(type(RQ))


# When True, dbg() prints its argument; when False, dbg() is silent.
DEBUG = True

<class 'sage.rings.polynomial.polynomial_quotient_ring.PolynomialQuotientRing_generic_with_category'>


### Helper Functions

In [67]:
def RandomList(length, cbd=False):
    """Return a list of `length` random integers drawn with randrange(q).

    When `cbd` is true the drawn values are passed through FauxCbd
    before being returned, which maps each value into -2..2.
    """
    values = []
    for _ in range(length):
        values.append(randrange(q))
    return FauxCbd(values) if cbd else values

def FauxCbd(r: list):
    """Map every element of `r` into the range -2..2.

    Each value is reduced modulo 5 and shifted down by 2. This is a
    stand-in for a centered binomial distribution sampler, not a real one.
    """
    return [(value % 5) - 2 for value in r]

def RandPolyUniform(length):
    """Build an RQ element from `length` coefficients produced by RandomList."""
    coefficients = RandomList(length)
    return RQ(coefficients)

def RandPolyCbd(length):
    """Build an RQ element from `length` small coefficients (RandomList with cbd=True)."""
    small_coefficients = RandomList(length, cbd=True)
    return RQ(small_coefficients)

def BytesNeed4Bits(bits: int) -> int:
    """Return the number of whole bytes needed to hold `bits` bits (rounds up)."""
    # Adding 7 before the floor division rounds up to the next multiple of 8.
    return (bits + 7) // 8

def RandInt(bits: int) -> int:
    """Return a random non-negative integer of at most `bits` bits.

    The value is read from os.urandom, wrapped in a Sage Integer (so the
    result offers Integer methods such as .bits()), and masked down to
    `bits` bits.

    bits: maximum bit length of the result.
    """
    raw = urandom(BytesNeed4Bits(bits))
    m = Integer(int.from_bytes(raw, 'big'))
    # Mask with the requested width. Masking with the global n instead would
    # leave up to 7 surplus bits whenever bits is not a multiple of 8, and
    # would over-truncate whenever bits > n.
    m &= 2**bits - 1
    return m

def Poly2Bytes(poly: polynomial) -> bytes:
    """Serialise every coefficient of every entry of `poly` to bytes.

    poly: a Sage matrix whose entries are RQ elements (the callers in this
        notebook pass the k x 1 matrix t).

    Returns the concatenation, entry by entry and coefficient by coefficient,
    of each coefficient encoded as a big-endian integer of q_bytes bytes.
    """
    chunks = []
    # Walk all entries of the matrix. Reading only `coefficients()[0]` would
    # serialise just the first non-zero entry, so a hash of t would ignore
    # the remaining polynomials.
    for entry in poly.list():
        for c in entry.list():
            chunks.append(int(c).to_bytes(q_bytes, 'big'))
    # Join once instead of repeatedly concatenating bytes objects.
    return b''.join(chunks)

def Bytes2ListInt(b: bytes) -> list:
    """Convert each byte of `b` into an Integer, preserving order."""
    return [Integer(octet) for octet in b]

def Compress(poly: polynomial):
    """Scale `poly` by ceil(q/2).

    KPKE_Encrypt applies this to the 0/1 message polynomial, so every
    1-coefficient becomes ceil(q/2) and every 0-coefficient stays 0.
    """
    half_q = math.ceil(q/2)
    scaled = poly * half_q
    return scaled

def Decompress(poly: polynomial) -> list:
    """Round each coefficient to a bit.

    poly: an iterable of coefficients (KPKE_Decrypt passes a plain list of
        field elements).

    Returns a list with one entry per coefficient: 1 when the coefficient,
    converted to an Integer, lies strictly between q/4 and 3*(q/4),
    otherwise 0.
    """
    # The bounds do not depend on the coefficient; compute them once.
    lower = q/4
    upper = 3*(q/4)
    return [(1 if upper > Integer(c) > lower else 0) for c in poly]

def dbg(s: str = ''):
    """Print `s` only while the module-level DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(s)

## IND-CPA Public-Key Encryption (K-PKE)

### K-PKE KeyGen

In [25]:
def KPKE_Keygen() -> (matrix, matrix, matrix):
    """Generate a K-PKE key pair.

    Returns the tuple (A, t, s):
      A: k x k matrix of polynomials built with RandPolyUniform(n)
      t: k x 1 matrix computed as A*s + e
      s: k x 1 matrix of polynomials built with RandPolyCbd(n)
    The error term e is generated here but not returned.
    """
    dbg('===== kpke_keygen =====')

    # A: k rows of k polynomials each, generated row by row.
    A = matrix([[RandPolyUniform(n) for _col in range(k)] for _row in range(k)])
    dbg('A:')
    dbg(A)

    # s: column vector (k x 1) of small-coefficient polynomials.
    s = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('s:')
    dbg(s)

    # e: column vector (k x 1) of small-coefficient polynomials.
    e = matrix([[RandPolyCbd(n)] for _row in range(k)])
    dbg('e:')
    dbg(e)

    # t = A*s + e is again a k x 1 column vector. For k = 2:
    #   (A*s)[0] = A[0,0]*s[0] + A[0,1]*s[1]
    #   (A*s)[1] = A[1,0]*s[0] + A[1,1]*s[1]
    #   t[i]     = (A*s)[i] + e[i]
    t = A*s + e
    dbg()

    return A, t, s

### K-PKE Encrypt

In [51]:
def KPKE_Encrypt(A: matrix, t: matrix, m: int, r: polynomial) -> (polynomial, polynomial):
    """Encrypt the integer message `m` under the public key (A, t).

    A: k x k public matrix.
    t: k x 1 public vector.
    m: message; must provide .bits() and have at most n bits.
    r: list of integers from which rr, e1 and e2 are derived.

    Returns (u, v) with u = A^T * rr + e1 and v = t^T * rr + e2 + mm,
    where mm is the message polynomial after Compress().

    Raises ValueError if m has more than n bits.
    """
    dbg('===== kpke_encrypt =====')

    message_bits = m.bits()
    # Reject messages that do not fit in n coefficients.
    if len(message_bits) > n:
        raise ValueError('m has more bits than n!')
    dbg('Bits of m:')
    dbg(message_bits)

    # Zero-pad the bit list up to n entries before turning it into a polynomial.
    padding = [0 for _ in range(0, n - len(message_bits))]
    mm = RQ(message_bits + padding)
    dbg('Polynomial m:')
    dbg(mm)
    mm = Compress(mm)
    dbg('Compressed m:')
    dbg(mm)

    def derive(offset):
        # Shift every entry of r by `offset`, fold back into -2..2, build a polynomial.
        return RQ(FauxCbd([coeff + offset for coeff in r]))

    # rr: k x 1, derived with offsets 0 .. k-1.
    rr = matrix([[derive(i)] for i in range(0, k)])
    dbg('rr:')
    dbg(rr)

    # e1: k x 1, derived with offsets k .. 2k-1.
    e1 = matrix([[derive(i)] for i in range(k, 2*k)])
    dbg('e1:')
    dbg(e1)

    # e2: a single polynomial, derived with offset 2k.
    e2 = derive(2*k)
    dbg('e2:')
    dbg(e2)

    u = A.transpose() * rr + e1
    v = t.transpose() * rr + e2 + mm

    dbg('u:')
    dbg(u)
    dbg('v:')
    dbg(v)
    dbg()

    return (u, v)

    

### K-PKE Decrypt

In [68]:
def KPKE_Decrypt(u: matrix, v: matrix, s: matrix) -> int:
    """Recover the integer message from the ciphertext (u, v) with secret s.

    Computes v - s^T * u, rounds every coefficient to a bit with
    Decompress(), and reads the bits (highest-degree coefficient first)
    as a base-2 integer.
    """
    dbg('===== kpke_decrypt =====')

    # NOTE(review): assumes coefficients()[0] yields the single entry of the
    # 1 x 1 matrix v - s^T * u -- confirm behaviour for a zero entry.
    noisy = (v - s.transpose() * u).coefficients()[0]
    dbg('Noisy recovered m:')
    dbg(noisy)

    # Reverse so the highest-degree coefficient comes first.
    high_to_low = noisy.list()[::-1]

    bits = Decompress(high_to_low)
    dbg('Decompressed m:')
    dbg(bits)

    recovered = int(''.join(map(str, bits)), 2)
    dbg('Recovered m:')
    dbg(recovered)
    dbg()

    return recovered

### INDCPA Key Test

In [52]:
# Round-trip check for K-PKE: encrypt a random n-bit message and verify
# that decryption returns the same integer.
# NOTE: urandom is already imported in the first cell; this import is redundant.
from os import urandom
m = RandInt(n)
A, t, s = KPKE_Keygen()
# r is a list of n values in -2..2 from which KPKE_Encrypt derives rr, e1 and e2.
r = FauxCbd(RandomList(n))
u, v = KPKE_Encrypt(A, t, m, r)
mr = KPKE_Decrypt(u, v, s)
dbg('Original m:')
dbg(m)
dbg(m.bits())
# Fail loudly if the recovered message differs from the original.
if m != mr:
    raise ValueError('decrypted m does not match, final decompression likely failed')

===== kpke_keygen =====
A:
[          2954*xbar^255 + 1144*xbar^254 + 220*xbar^253 + 1296*xbar^252 + 420*xbar^251 + 812*xbar^250 + 1427*xbar^249 + 2112*xbar^248 + 1714*xbar^247 + 181*xbar^246 + 1521*xbar^245 + 1174*xbar^244 + 291*xbar^243 + 2740*xbar^242 + 2581*xbar^241 + 731*xbar^240 + 3082*xbar^239 + 2009*xbar^238 + 1433*xbar^237 + 3015*xbar^236 + 1994*xbar^235 + 618*xbar^234 + 2271*xbar^233 + 1194*xbar^232 + 41*xbar^231 + 3218*xbar^230 + 1095*xbar^229 + 1845*xbar^228 + 467*xbar^227 + 803*xbar^226 + 90*xbar^225 + 2457*xbar^224 + 3030*xbar^223 + 1609*xbar^222 + 1636*xbar^221 + 2345*xbar^220 + 1023*xbar^219 + 1340*xbar^218 + 2773*xbar^217 + 627*xbar^216 + 2663*xbar^215 + 1297*xbar^214 + 483*xbar^213 + 1459*xbar^212 + 1704*xbar^211 + 2333*xbar^210 + 1457*xbar^209 + 568*xbar^208 + 912*xbar^207 + 863*xbar^206 + 1235*xbar^205 + 1531*xbar^204 + 2432*xbar^203 + 1295*xbar^202 + 2074*xbar^201 + 2807*xbar^200 + 3054*xbar^199 + 3209*xbar^198 + 820*xbar^197 + 1911*xbar^196 + 130*xbar^195 + 2118*x

# INDCCA Encapsulation (ML-KEM)

## ML-KEM KeyGen

In [None]:
def mlkem_keygen() -> ((polynomial, polynomial), (polynomial, (polynomial, polynomial), bytes, int)):
    """Generate an ML-KEM key pair.

    Returns (ek, dk):
      ek = (A, t)          -- the K-PKE public key from KPKE_Keygen()
      dk = (s, ek, Ht, z)  -- the K-PKE secret s, the encapsulation key,
                              the SHA3-256 digest of Poly2Bytes(t), and a
                              random n-bit integer z
    """
    # z is drawn before the K-PKE key, matching the original call order.
    z = RandInt(n)
    A, t, s = KPKE_Keygen()
    ek = (A, t)
    Ht = sha3_256(Poly2Bytes(t)).digest()
    dk = (s, ek, Ht, z)
    return ek, dk

## ML-KEM Encaps

In [61]:
def mlkem_encaps(ek: (polynomial, polynomial)) -> (bytes, (polynomial, polynomial)):
    """Encapsulate a fresh shared key under the encapsulation key `ek`.

    ek: the pair (A, t) produced by mlkem_keygen().

    Returns (K, c): K is the first BytesNeed4Bits(n) bytes of
    SHA3-512(m || H(t)) and c is the K-PKE ciphertext (u, v) of the random
    message m, encrypted with randomness derived from the remaining bytes.
    """
    # Draw a fresh message for every encapsulation and bind it locally;
    # without this assignment the function would depend on a module-level
    # `m` left behind by another cell.
    m = RandInt(n)
    A, t = ek
    Ht = sha3_256(Poly2Bytes(t)).digest()
    seed_len = BytesNeed4Bits(n)
    Kr = sha3_512(int(m).to_bytes(seed_len, 'big') + Ht).digest()
    # First half is the shared key, second half seeds the encryption randomness.
    K, r = (Kr[:seed_len], Kr[seed_len:])
    r = FauxCbd(Bytes2ListInt(r))
    c = KPKE_Encrypt(A, t, m, r)
    return K, c

## ML-KEM Decaps

In [58]:
def mlkem_decaps(c: (polynomial, polynomial), dk: (polynomial, (polynomial, polynomial), bytes, int)) -> bytes:
    """Decapsulate the ciphertext `c` with the decapsulation key `dk`.

    c:  the K-PKE ciphertext (u, v).
    dk: the tuple (s, ek, h, z) where ek = (A, t) and h is the hash of t.

    Decrypts c to m', re-derives (K', r') from SHA3-512(m' || h),
    re-encrypts m' with r' and checks that the result equals c.

    Returns the shared key K' (bytes).
    Raises AssertionError if the re-encrypted ciphertext differs from c.
    """
    # NOTE(review): z is unpacked but unused; a mismatch raises instead of
    # deriving a rejection key from z -- confirm this is intended.
    s, ek, h, z = dk
    A, t = ek
    u, v = c

    mprime = KPKE_Decrypt(u, v, s)
    # Route the diagnostic through dbg() so it honours the DEBUG flag.
    dbg(Integer(mprime))
    seed_len = BytesNeed4Bits(n)
    Krprime = sha3_512(int(mprime).to_bytes(seed_len, 'big') + h).digest()
    Kprime, rprime = (Krprime[:seed_len], Krprime[seed_len:])
    rprime = FauxCbd(Bytes2ListInt(rprime))

    # Re-encrypt and compare with the received ciphertext.
    uprime, vprime = KPKE_Encrypt(A, t, Integer(mprime), rprime)
    assert u == uprime, 're-encrypted u does not match ciphertext'
    assert v == vprime, 're-encrypted v does not match ciphertext'

    # Return the derived key, as the -> bytes annotation promises.
    return Kprime



## ML-KEM Scheme Test

In [62]:
# End-to-end run of the KEM: generate a key pair, encapsulate against the
# encapsulation key, then decapsulate the resulting ciphertext.
ek, dk = mlkem_keygen()
K, c = mlkem_encaps(ek)
# mlkem_decaps re-encrypts the decrypted message and asserts it matches c.
mlkem_decaps(c, dk)

===== kpke_keygen =====
A:
[751*xbar^255 + 3176*xbar^254 + 1235*xbar^253 + 483*xbar^252 + 947*xbar^251 + 2964*xbar^250 + 1746*xbar^249 + 661*xbar^248 + 3110*xbar^247 + 691*xbar^246 + 1977*xbar^245 + 2846*xbar^244 + 3156*xbar^243 + 2319*xbar^242 + 3255*xbar^241 + 1726*xbar^240 + 58*xbar^239 + 2880*xbar^238 + 1424*xbar^237 + 3122*xbar^236 + 1802*xbar^235 + 416*xbar^234 + 1923*xbar^233 + 1221*xbar^232 + 311*xbar^231 + 882*xbar^230 + 1651*xbar^229 + 1782*xbar^228 + 3186*xbar^227 + 2775*xbar^226 + 1306*xbar^225 + 1978*xbar^224 + 937*xbar^223 + 2818*xbar^222 + 875*xbar^221 + 1530*xbar^220 + 551*xbar^219 + 3182*xbar^218 + 2667*xbar^217 + 683*xbar^216 + 866*xbar^215 + 2576*xbar^214 + 1955*xbar^213 + 3225*xbar^212 + 2492*xbar^211 + 2785*xbar^210 + 2830*xbar^209 + 2603*xbar^208 + 755*xbar^207 + 1888*xbar^206 + 1932*xbar^205 + 2003*xbar^204 + 3286*xbar^203 + 2283*xbar^202 + 1558*xbar^201 + 1664*xbar^200 + 1225*xbar^199 + 769*xbar^198 + 719*xbar^197 + 1795*xbar^196 + 2299*xbar^195 + 3289*xbar^194 