## Parameters

In [9]:
import math
from sage.rings.polynomial.polynomial_zmod_flint import Polynomial_zmod_flint as polynomial
from sage.matrix.constructor import matrix
from sage.misc.prandom import randrange
from sage.rings.finite_rings.finite_field_constructor import FiniteField 
from sage.rings.polynomial.polynomial_ring_constructor import PolynomialRing

# Kyber Parameters
# q: coefficient modulus, k: module rank (k = 2 matches the Kyber512 set),
# n: number of coefficients in each ring element.
q = 3329
k = 2
n = 256

# Kyber works in R_q = Z_q[X] / (X^n + 1): polynomials with n coefficients
# (degree < n) over GF(q), reduced modulo X^n + 1. A modulus of X^(n+1) + 1
# would instead give elements with n + 1 coefficients in a different ring.
RR = PolynomialRing(FiniteField(q, 'x'), 'x', sparse=True)
x = RR.gen()
f = x**n + 1  # `**` rather than Sage's preparsed `^`, so it also reads as plain Python
RQ = RR.quotient(f)

# Set to False to silence the tracing printed by dbg().
DEBUG = True


In [10]:
# Helper Functions
def reducePolynomials(matrx):
    """Return a matrix over RQ whose entries are the entries of `matrx` reduced mod f.

    Every entry is read with `matrx.coefficient((i, j))`, divided by the
    modulus polynomial `f` with `quo_rem`, and only the remainder is kept.
    In the visible notebook all call sites of this helper are commented out.
    """
    n_rows = len(matrx.rows())
    n_cols = len(matrx.columns())
    # First collect every entry, then reduce them (same order as before).
    entries = [[matrx.coefficient((r, c)) for c in range(n_cols)] for r in range(n_rows)]
    remainders = [[entry.quo_rem(f)[1] for entry in row] for row in entries]
    return matrix(RQ, remainders)

def randomList(length, cbd=False, eta=1):
    """Return `length` random integers, either uniform mod q or small centered noise.

    Args:
        length: number of values to draw.
        cbd: if False, each value is uniform in [0, q). If True, each value is
            a centered binomial sample CBD_eta:
            (sum of eta random bits) - (sum of eta random bits), in [-eta, eta].
        eta: CBD width, used only when cbd is True. The default 1 reproduces
            the previous behaviour (values in {-1, 0, 1}, from exactly the same
            random calls). Kyber itself uses eta = 2 or 3 depending on the set.

    Returns:
        A list of `length` integers.
    """
    if cbd:
        # eta = 1 is exactly randrange(2) - randrange(2): a genuine, if narrow,
        # centered binomial distribution.
        return [
            sum(randrange(2) for _ in range(eta)) - sum(randrange(2) for _ in range(eta))
            for _ in range(length)
        ]
    return [randrange(q) for _ in range(length)]

def randomPolyUniform(length):
    """Build an RQ element from `length` coefficients drawn uniformly mod q."""
    coeffs = randomList(length)
    return RQ(coeffs)

def randomPolyCbd(length):
    """Build an RQ element from `length` small noise coefficients (randomList with cbd=True)."""
    noise = randomList(length, cbd=True)
    return RQ(noise)

def compress(poly: polynomial):
    """Scale `poly` by ceil(q/2).

    kpke_encrypt applies this to the 0/1 message polynomial, so every 1-bit
    becomes a coefficient near q/2 and every 0-bit stays 0.
    """
    half_q = math.ceil(q / 2)
    return poly * half_q

def decompress(poly: polynomial) -> list:
    """Round each coefficient to a single bit, undoing `compress`.

    Each coefficient c is lifted with Integer() (for a GF(q) element, its
    representative in [0, q)). It decodes to 1 when q/4 < c < 3q/4, i.e. when
    c is closer to q/2 than to 0 mod q, and to 0 otherwise.

    Args:
        poly: an iterable of coefficients (kpke_decrypt passes a plain list).

    Returns:
        A list of 0/1 ints, one per coefficient, in the input order.
    """
    lower = q / 4        # 832.25 for q = 3329
    upper = 3 * (q / 4)  # 2496.75 for q = 3329
    return [1 if lower < Integer(c) < upper else 0 for c in poly]

def dbg(s: object = ''):
    """Print `s` when the global DEBUG flag is true.

    Callers pass matrices, ring elements, lists and integers as well as
    strings, so the argument is typed as `object`; print() renders any of them.
    dbg() with no argument prints a blank separator line.
    """
    if DEBUG:
        print(s)

## K-PKE KeyGen (IND-CPA)

In [11]:
def kpke_keygen() -> (matrix, matrix, matrix):
    """Generate a toy K-PKE key pair and return (A, t, s).

    A is a k x k matrix of uniformly random RQ elements. s and e are k x 1
    column vectors of small (randomPolyCbd) elements. The public vector is
    t = A*s + e, also k x 1. A itself is returned, not a seed for it.

    Shape example for k = 2:

        |    A    |   | s  |   | e  |     |       A*s + e        |
        | A00 A01 | * | s0 | + | e0 |  =  | A00*s0 + A01*s1 + e0 |
        | A10 A11 |   | s1 |   | e1 |     | A10*s0 + A11*s1 + e1 |
    """
    dbg('===== kpke_keygen =====')

    # Public matrix, filled row by row with uniform ring elements.
    A = matrix([[randomPolyUniform(n) for _col in range(k)] for _row in range(k)])
    dbg('A:')
    dbg(A)

    # Secret vector: a column of k small polynomials.
    s = matrix([[randomPolyCbd(n)] for _row in range(k)])
    dbg('s:')
    dbg(s)

    # Error vector: another column of k small polynomials.
    e = matrix([[randomPolyCbd(n)] for _row in range(k)])
    dbg('e:')
    dbg(e)

    # The entries live in RQ = RR/(f), so the products are already reduced
    # modulo f; no separate reducePolynomials() pass is needed.
    t = A * s + e
    dbg()

    return (A, t, s)

## K-PKE Encrypt (IND-CPA)

In [12]:
def kpke_encrypt(A: matrix, t: matrix, m: int) -> (matrix, matrix):
    """Encrypt a non-negative integer m of at most n bits under public key (A, t).

    Args:
        A: k x k public matrix from kpke_keygen.
        t: k x 1 public vector from kpke_keygen.
        m: the message, as a Sage Integer or a plain Python int.

    Returns:
        (u, v): u = A^T*r + e1 is a k x 1 matrix, and
        v = t^T*r + e2 + compress(m) is a 1 x 1 matrix over RQ.

    Raises:
        ValueError: if m is negative or has more than n bits.
    """
    dbg('===== kpke_encrypt =====')
    # .bits() only exists on Sage Integers; convert so plain ints work too.
    m = Integer(m)
    # For negative numbers Integer.bits() returns -1 "bits". Those cannot be
    # encoded as 0/1 coefficients and would never decrypt back to m.
    if m < 0:
        raise ValueError('m must be non-negative!')
    bits = m.bits()  # least significant bit first
    # Ensure that m does not have more bits than n bits
    if len(bits) > n:
        raise ValueError('m has more bits than n!')
    dbg('Bits of m:')
    dbg(bits)

    # Bit i of m becomes the coefficient of x^i; zero-pad up to n coefficients.
    m_poly = RQ(bits + [0] * (n - len(bits)))
    dbg('Polynomial m:')
    dbg(m_poly)
    m_scaled = compress(m_poly)
    dbg('Compressed m:')
    dbg(m_scaled)

    # Fresh randomness, drawn in the order r, e1, e2.
    r = matrix([[randomPolyCbd(n)] for _ in range(k)])   # k x 1
    dbg('r:')
    dbg(r)

    e1 = matrix([[randomPolyCbd(n)] for _ in range(k)])  # k x 1
    dbg('e1:')
    dbg(e1)

    e2 = randomPolyCbd(n)  # a single ring element
    dbg('e2:')
    dbg(e2)

    # Arithmetic is in RQ, so results are already reduced modulo f.
    u = A.transpose() * r + e1
    # t^T * r is 1 x 1; adding the ring elements e2 and m_scaled adds them to
    # its single entry (Sage treats scalar + matrix as scalar*identity + matrix).
    v = t.transpose() * r + e2 + m_scaled

    dbg('u:')
    dbg(u)
    dbg('v:')
    dbg(v)
    dbg()

    return (u, v)

    

## K-PKE Decryption

In [13]:
def kpke_decrypt(u: matrix, v: matrix, s: matrix) -> int:
    """Recover the integer message from ciphertext (u, v) with secret vector s.

    v - s^T*u expands to compress(m) + e^T*r + e2 - s^T*e1, i.e. the scaled
    message plus small noise. Each coefficient is rounded back to a bit with
    decompress. The bits are then reassembled so that the coefficient of x^i
    becomes bit i of the result.
    """
    dbg('===== kpke_decrypt =====')
    # v and s^T*u are both 1 x 1, so the difference is a 1 x 1 matrix over RQ.
    mn_mat = v - s.transpose() * u
    # Index the single entry directly. Sage's Matrix.coefficients() lists only
    # nonzero entries, so it would come back empty (IndexError) on a zero result.
    mn = mn_mat[0, 0]
    dbg('Noisy recovered m:')
    dbg(mn)

    # mn.list() is lowest degree first; reverse it so the most significant bit
    # (highest-degree coefficient) comes first for the base-2 parse below.
    coeffs = mn.list()
    coeffs.reverse()

    bits = decompress(coeffs)
    dbg('Decompressed m:')
    dbg(bits)

    m_rec = int(''.join(str(b) for b in bits), 2)
    dbg('Recovered m:')
    dbg(m_rec)
    dbg()

    return m_rec

In [14]:
from os import urandom

# Random n-bit message: ceil(n/8) random bytes, masked down to exactly n bits.
m = Integer(int.from_bytes(urandom((n + 7) // 8), 'big')) & (2**n - 1)

# Round trip: key generation -> encryption -> decryption.
A, t, s = kpke_keygen()
u, v = kpke_encrypt(A, t, m)
mr = kpke_decrypt(u, v, s)

dbg('Original m:')
dbg(m)
dbg(m.bits())

# Fail loudly if the recovered message differs from the original.
if m != mr:
    raise ValueError('decrypted m does not match, final decompression likely failed')

===== kpke_keygen =====
A:
[           2304*xbar^255 + 908*xbar^254 + 3259*xbar^253 + 1933*xbar^252 + 236*xbar^251 + 2094*xbar^250 + 1057*xbar^249 + 1047*xbar^248 + 2312*xbar^247 + 1673*xbar^246 + 2178*xbar^245 + 2648*xbar^244 + 1853*xbar^243 + 1859*xbar^242 + 755*xbar^241 + 3287*xbar^240 + 2663*xbar^239 + 956*xbar^238 + 490*xbar^237 + 3044*xbar^236 + 1405*xbar^235 + 760*xbar^234 + 2699*xbar^233 + 477*xbar^232 + 2317*xbar^231 + 245*xbar^230 + 2714*xbar^229 + 2554*xbar^228 + 1874*xbar^227 + 2210*xbar^226 + 1249*xbar^225 + 1678*xbar^224 + 1440*xbar^223 + 1707*xbar^222 + 1455*xbar^221 + 1219*xbar^220 + 2044*xbar^219 + 692*xbar^218 + 1207*xbar^217 + 1294*xbar^216 + 2030*xbar^215 + 2304*xbar^214 + 954*xbar^213 + 2725*xbar^212 + 165*xbar^211 + 1541*xbar^210 + 235*xbar^209 + 3201*xbar^208 + 2733*xbar^207 + 2447*xbar^206 + 2532*xbar^205 + 3270*xbar^204 + 2060*xbar^203 + 36*xbar^202 + 533*xbar^201 + 1327*xbar^200 + 1560*xbar^199 + 1271*xbar^198 + 758*xbar^197 + 571*xbar^196 + 98*xbar^195 + 1565