In [None]:
pip install pycryptodome

Collecting pycryptodome
  Downloading pycryptodome-3.20.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pycryptodome
Successfully installed pycryptodome-3.20.0


In [None]:
import os
from utils import xor_bytes
from Crypto.Cipher import AES

class AES256_CTR_DRBG:
    """
    Deterministic random byte generator driven by AES-256 (ECB) over an
    incrementing 16-byte counter ``V``.

    State kept on the instance:
        key         -- 32-byte AES key
        V           -- 16-byte big-endian counter
        reseed_ctr  -- set to 1 on (re)seed, incremented by each
                       ``random_bytes`` call
        entropy_input -- the 48-byte seed captured at construction

    NOTE(review): the structure looks like CTR_DRBG (NIST SP 800-90A) used
    without a derivation function -- confirm against the specification
    before relying on it for anything other than reproducible test vectors.
    """
    def __init__(self, seed=None, personalization=b""):
        """
        Input:
            seed:            exactly 48 bytes of entropy, or None to draw
                             them from ``os.urandom``
            personalization: at most 48 bytes, zero-padded and XORed into
                             the seed
        Raises:
            ValueError if either input has an invalid length.
        """
        self.seed_length = 48
        self.reseed_interval = 2**48
        self.key = bytes([0])*32
        self.V   = bytes([0])*16
        self.entropy_input = self.__check_entropy_input(seed)

        seed_material = self.__instantiate(personalization=personalization)
        self.ctr_drbg_update(seed_material)
        self.reseed_ctr = 1

    def __check_entropy_input(self, entropy_input):
        """
        Return ``entropy_input`` if it is exactly ``seed_length`` bytes long,
        or fresh bytes from ``os.urandom`` when it is None.
        """
        if entropy_input is None:
            return os.urandom(self.seed_length)
        if len(entropy_input) != self.seed_length:
            raise ValueError(f"The entropy input must be of length: {self.seed_length}. Input has length {len(entropy_input)}")
        return entropy_input

    def __instantiate(self, personalization=b""):
        """
        Build the seed material: the stored entropy input XORed with the
        zero-padded personalization string.
        """
        if len(personalization) > self.seed_length:
            raise ValueError(f"The Personalization String must be at most length: {self.seed_length}. Input has length {len(personalization)}")
        # Concatenate instead of `+=` so a caller-supplied bytearray is
        # never extended in place.
        personalization = personalization + bytes(self.seed_length - len(personalization))
        return xor_bytes(self.entropy_input, personalization)

    def __increment_counter(self):
        """Advance ``V`` by one, wrapping around at 2**128."""
        int_V = int.from_bytes(self.V, 'big')
        new_V = (int_V + 1) % 2**(8*16)
        self.V = new_V.to_bytes(16, byteorder='big')

    def ctr_drbg_update(self, provided_data):
        """
        Replace ``key`` and ``V`` with fresh AES output XORed with
        ``provided_data``.

        NOTE(review): ``provided_data`` is presumably ``seed_length`` bytes;
        the exact requirement depends on ``xor_bytes`` -- confirm.
        """
        tmp = b""
        cipher = AES.new(self.key, AES.MODE_ECB)
        # Collect 48 bytes (three 16-byte blocks) from AES ECB
        while len(tmp) < self.seed_length:
            self.__increment_counter()
            tmp  += cipher.encrypt(self.V)

        # Take the first 48 bytes
        tmp = tmp[:self.seed_length]
        tmp = xor_bytes(tmp, provided_data)

        # Set the new values of key and V
        self.key = tmp[:32]
        self.V = tmp[32:]

    def reseed(self, additional_information=b""):
        """
        Mix ``additional_information`` (at most 48 bytes) together with the
        stored entropy input into the state and reset ``reseed_ctr`` to 1.
        """
        seed_material = self.__instantiate(additional_information)
        self.ctr_drbg_update(seed_material)
        self.reseed_ctr = 1

    def random_bytes(self, num_bytes, additional=None):
        """
        Return ``num_bytes`` pseudo-random bytes and advance the state.

        Input:
            num_bytes:  number of bytes to produce
            additional: optional extra input of at most 48 bytes; it is
                        zero-padded and folded into the state before and
                        after generation
        Raises:
            Warning    if ``reseed_ctr`` has reached ``reseed_interval``
            ValueError if ``additional`` is longer than 48 bytes
        """
        if self.reseed_ctr >= self.reseed_interval:
            raise Warning("The DRBG has been exhausted! Reseed!")

        # Set the optional additional information
        if additional is None:
            additional = bytes(self.seed_length)
        else:
            if len(additional) > self.seed_length:
                # The message previously referenced an undefined name and
                # crashed with NameError instead of reporting the length.
                raise ValueError(f"The additional input must be of length at most: {self.seed_length}. Input has length {len(additional)}")
            additional = additional + bytes(self.seed_length - len(additional))
            self.ctr_drbg_update(additional)

        # Collect bytes!
        tmp = b""
        cipher = AES.new(self.key, AES.MODE_ECB)
        while len(tmp) < num_bytes:
            self.__increment_counter()
            tmp += cipher.encrypt(self.V)

        # Keep only the requested number of bytes. The output is returned to
        # the caller only -- it is key material and must not be printed.
        output_bytes = tmp[:num_bytes]
        self.ctr_drbg_update(additional)
        self.reseed_ctr += 1
        return output_bytes


In [None]:

from polynomials import *
from modules import *

def keygen():
    """
    Toy key generation with fixed (non-random) values over the module-level
    ring ``R`` and module ``M``.

    Output:
        ((A, t), s): public key pair (matrix A, vector t = A s + e) and the
                     secret column vector s
    """
    # Secret vector s (column)
    s = M([R([0,1,-1,-1]), R([0,-1,0,-1])]).transpose()

    # Public 2x2 matrix A
    first_row = [R([11,16,16,6]), R([3,6,4,9])]
    second_row = [R([1,10,3,5]), R([15,9,1,6])]
    A = M([first_row, second_row])

    # Error vector e (column)
    e = M([R([0,0,1,0]), R([0,-1,1,0])]).transpose()

    t = A @ s + e

    # Worked-example check: t must match the hand-computed value
    expected_t = M([R([7,0,15,16]), R([6,11,12,10])]).transpose()
    assert t == expected_t
    return (A, t), s

def enc(m, public_key):
    """
    Toy encryption with fixed randomness and error terms over the
    module-level ring ``R`` and module ``M``.

    Input:
        m:          message bytes, decoded with ``R.decode``
        public_key: pair (A, t) as returned by ``keygen``
    Output:
        (u, v): ciphertext vector u and ciphertext polynomial v
    """
    # Fixed "randomness" vector r (column)
    r = M([R([0,0,1,-1]), R([-1,0,1,1])]).transpose()

    # Fixed error vector e_1 (column) and error polynomial e_2
    e_1 = M([R([0,1,1,0]), R([0,0,1,0])]).transpose()
    e_2 = R([0,0,-1,-1])

    A, t = public_key

    # Message as a polynomial, scaled up by decompress(1)
    poly_m = R.decode(m).decompress(1)
    assert poly_m == R([9,9,0,9])

    u = A.transpose() @ r + e_1
    expected_u = M([R([3,10,11,11]), R([11,13,4,4])]).transpose()
    assert u == expected_u

    # (t^T r) is a 1x1 matrix; take its single polynomial entry
    t_dot_r = (t.transpose() @ r)[0][0]
    v = t_dot_r + e_2 - poly_m
    assert v == R([15, 8 , 6, 7])
    return u, v

def dec(u, v, s):
    """
    Toy decryption over the module-level ring ``R``.

    Input:
        u, v: ciphertext as returned by ``enc``
        s:    secret column vector as returned by ``keygen``
    Output:
        the recovered message, encoded with 2 bits per coefficient
    """
    # (s^T u) is a 1x1 matrix; take its single polynomial entry
    s_dot_u = (s.transpose() @ u)[0][0]
    noisy_message = v - s_dot_u
    assert noisy_message == R([5,7,14,7])

    # compress(1) rounds each coefficient down to a single bit
    rounded_message = noisy_message.compress(1)
    assert rounded_message == R([1,1,0,1])
    return rounded_message.encode(l=2)

if __name__ == '__main__':
    # Worked example: round-trip one byte through the toy scheme above.
    # R and M are module-level names that keygen/enc/dec read directly,
    # so they must be defined before those functions are called.
    R = PolynomialRing(17, 4)
    M = Module(R)

    # One message byte; the assert pins how it decodes into the ring
    m = bytes([69])
    assert R.decode(m) == R([1,1,0,1])
    # Generate keypair
    pub, priv = keygen()
    # Encrypt message
    u, v = enc(m, pub)
    # Decrypt message and check the round trip
    n = dec(u, v, priv)
    assert n == m

In [None]:
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from polynomials import *
from modules import *
from ntt_helper import NTTHelperKyber
try:
    from aes256_ctr_drbg import AES256_CTR_DRBG
except ImportError as e:
    print("Error importing AES CTR DRBG. Have you tried installing requirements?")
    print(f"ImportError: {e}\n")
    print("Kyber will work perfectly fine with system randomness")


# Parameter sets consumed by `Kyber.__init__`, keyed by scheme name.
#   n, q         -- passed to PolynomialRing(q, n)
#   k            -- number of polynomials per vector (matrices are k x k)
#   eta_1, eta_2 -- `eta` arguments handed to the ring's `cbd` sampler
#   du, dv       -- bits per coefficient when compressing/encoding the
#                   ciphertext parts u and v
DEFAULT_PARAMETERS = {
    "kyber_512" : {
        "n" : 256,
        "k" : 2,
        "q" : 3329,
        "eta_1" : 3,
        "eta_2" : 2,
        "du" : 10,
        "dv" : 4,
    },
    "kyber_768" : {
        "n" : 256,
        "k" : 3,
        "q" : 3329,
        "eta_1" : 2,
        "eta_2" : 2,
        "du" : 10,
        "dv" : 4,
    },
    "kyber_1024" : {
        "n" : 256,
        "k" : 4,
        "q" : 3329,
        "eta_1" : 2,
        "eta_2" : 2,
        "du" : 11,
        "dv" : 5,
    }
}

class Kyber:
    """
    Kyber key-encapsulation mechanism for one parameter set.

    The `_cpapke_*` methods implement the inner public-key encryption
    scheme; `keygen`, `enc` and `dec` wrap it into the key-encapsulation
    interface. Randomness comes from `os.urandom` unless a DRBG seed has
    been installed with `set_drbg_seed`.
    """
    def __init__(self, parameter_set):
        """
        Input:
            parameter_set: dict with keys n, k, q, eta_1, eta_2, du, dv
        """
        for name in ("n", "k", "q", "eta_1", "eta_2", "du", "dv"):
            setattr(self, name, parameter_set[name])

        self.R = PolynomialRing(self.q, self.n, ntt_helper=NTTHelperKyber)
        self.M = Module(self.R)

        # No DRBG by default: use system randomness
        self.drbg = None
        self.random_bytes = os.urandom

    def set_drbg_seed(self, seed):
        """
        Switch to deterministic randomness from an AES256_CTR_DRBG
        instantiated with `seed`.
        """
        drbg = AES256_CTR_DRBG(seed)
        self.drbg = drbg
        self.random_bytes = drbg.random_bytes

    def reseed_drbg(self, seed):
        """
        Reseed the installed DRBG, forwarding `seed` to its `reseed`.

        Raises:
            Warning if `set_drbg_seed` has not been called first.
        """
        if self.drbg is None:
            raise Warning("Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`")
        self.drbg.reseed(seed)

    @staticmethod
    def _xof(bytes32, a, b, length):
        """
        XOF: B^32 x B x B -> B^length, SHAKE-128 of bytes32 || a || b
        """
        xof_input = bytes32 + a + b
        if len(xof_input) != 34:
            raise ValueError("Input bytes should be one 32 byte array and 2 single bytes.")
        return shake_128(xof_input).digest(length)

    @staticmethod
    def _h(input_bytes):
        """
        H: B* -> B^32, the SHA3-256 digest
        """
        return sha3_256(input_bytes).digest()

    @staticmethod
    def _g(input_bytes):
        """
        G: B* -> B^32 x B^32, the SHA3-512 digest split in two halves
        """
        digest = sha3_512(input_bytes).digest()
        first_half, second_half = digest[:32], digest[32:]
        return first_half, second_half

    @staticmethod
    def _prf(s, b, length):
        """
        PRF: B^32 x B -> B^length, SHAKE-256 of s || b
        """
        prf_input = s + b
        if len(prf_input) != 33:
            raise ValueError("Input bytes should be one 32 byte array and one single byte.")
        return shake_256(prf_input).digest(length)

    @staticmethod
    def _kdf(input_bytes, length):
        """
        KDF: B* -> B^length, SHAKE-256 output
        """
        return shake_256(input_bytes).digest(length)

    def _generate_error_vector(self, sigma, eta, N, is_ntt=False):
        """
        Sample k polynomials with the ring's `cbd` sampler, one PRF call
        per polynomial using the counter values N, N+1, ..., N+k-1.

        Output:
            (v, N + k): the polynomials as a transposed (column) matrix and
                        the next unused counter value
        """
        polynomials = []
        for counter in range(N, N + self.k):
            prf_output = self._prf(sigma, bytes([counter]), 64*eta)
            polynomials.append(self.R.cbd(prf_output, eta, is_ntt=is_ntt))
        return self.M(polynomials).transpose(), N + self.k

    def _generate_matrix_from_seed(self, rho, transpose=False, is_ntt=False):
        """
        Expand the 32-byte seed `rho` into a k x k matrix by parsing XOF
        output for every (i, j) position. `transpose=True` swaps the two
        index bytes fed to the XOF, which yields the transposed matrix.
        """
        rows = []
        for i in range(self.k):
            row = []
            for j in range(self.k):
                first, second = (i, j) if transpose else (j, i)
                xof_output = self._xof(rho, bytes([first]), bytes([second]), 3*self.R.n)
                row.append(self.R.parse(xof_output, is_ntt=is_ntt))
            rows.append(row)
        return self.M(rows)

    def _cpapke_keygen(self):
        """
        Output:
            pk: encoded vector t followed by the 32-byte seed rho
            sk: encoded secret vector s
        """
        # One random seed, hashed and split into the matrix seed (rho) and
        # the noise seed (sigma)
        rho, sigma = self._g(self.random_bytes(32))

        # Matrix A, flagged as being in NTT form
        A_hat = self._generate_matrix_from_seed(rho, is_ntt=True)

        # Secret s and error e share one PRF counter, starting at zero
        counter = 0
        s_hat, counter = self._generate_error_vector(sigma, self.eta_1, counter)
        s_hat.to_ntt()
        e_hat, counter = self._generate_error_vector(sigma, self.eta_1, counter)
        e_hat.to_ntt()

        # t = A s + e, all in NTT form
        t_hat = (A_hat @ s_hat).to_montgomery() + e_hat

        # Bring every coefficient into [0, q) before encoding
        t_hat.reduce_coefficents()
        s_hat.reduce_coefficents()

        return t_hat.encode(l=12) + rho, s_hat.encode(l=12)

    def _cpapke_enc(self, pk, m, coins):
        """
        Input:
            pk:     public key (encoded t followed by rho)
            m:      message ∈ B^32
            coins:  random coins ∈ B^32
        Output:
            c:      ciphertext, encoded u followed by encoded v
        """
        # The matrix seed is the last 32 bytes of the public key
        rho = pk[-32:]

        # t^T as a 1 x k row, already in NTT form
        t_hat_T = self.M.decode(pk, 1, self.k, l=12, is_ntt=True)

        # One message bit per coefficient, scaled up by decompress(1)
        m_poly = self.R.decode(m, l=1).decompress(1)

        A_hat_T = self._generate_matrix_from_seed(rho, transpose=True, is_ntt=True)

        # r, e1 and e2 share one PRF counter, starting at zero
        counter = 0
        r_hat, counter = self._generate_error_vector(coins, self.eta_1, counter)
        r_hat.to_ntt()
        e1, counter = self._generate_error_vector(coins, self.eta_2, counter)
        e2_bytes = self._prf(coins, bytes([counter]), 64*self.eta_2)
        e2 = self.R.cbd(e2_bytes, self.eta_2)

        # u = A^T r + e1 and v = t^T r + e2 + m
        u = (A_hat_T @ r_hat).from_ntt() + e1
        v = (t_hat_T @ r_hat)[0][0].from_ntt() + e2 + m_poly

        # Compress and serialise both parts
        encoded_u = u.compress(self.du).encode(l=self.du)
        encoded_v = v.compress(self.dv).encode(l=self.dv)
        return encoded_u + encoded_v

    def _cpapke_dec(self, sk, c):
        """
        Input:
            sk: secret key (encoded s)
            c:  ciphertext (encoded u followed by encoded v)
        Output:
            m:  message, encoded with one bit per coefficient
        """
        # Number of ciphertext bytes occupied by the encoded vector u
        u_length = self.du * self.k * self.R.n // 8
        encoded_v = c[u_length:]

        # Recover u and convert it to NTT form
        u_hat = self.M.decode(c, self.k, 1, l=self.du).decompress(self.du)
        u_hat.to_ntt()

        # Recover v
        v = self.R.decode(encoded_v, l=self.dv).decompress(self.dv)

        # s^T as a 1 x k row, already in NTT form
        s_hat_T = self.M.decode(sk, 1, self.k, l=12, is_ntt=True)

        # m = v - s^T u, then round every coefficient to one bit
        s_dot_u = (s_hat_T @ u_hat)[0][0].from_ntt()
        return (v - s_dot_u).compress(1).encode(l=1)

    def keygen(self):
        """
        Output:
            pk: Public key
            sk: Secret key, laid out as sk' || pk || H(pk) || z
        """
        # The inner key pair is drawn before z; the order of the two
        # random_bytes calls matters when a deterministic DRBG is in use.
        pk, inner_sk = self._cpapke_keygen()
        z = self.random_bytes(32)
        return pk, inner_sk + pk + self._h(pk) + z

    def enc(self, pk, key_length=32):
        """
        Input:
            pk: Public Key
        Output:
            c:  Ciphertext
            K:  Shared key of `key_length` bytes
        """
        # The hash of 32 random bytes is what actually gets encrypted
        m_hash = self._h(self.random_bytes(32))
        Kbar, coins = self._g(m_hash + self._h(pk))
        c = self._cpapke_enc(pk, m_hash, coins)
        return c, self._kdf(Kbar + self._h(c), key_length)

    def dec(self, c, sk, key_length=32):
        """
        Input:
            c:  ciphertext
            sk: Secret Key
        Output:
            K:  Shared key of `key_length` bytes
        """
        # sk = sk' || pk || H(pk) || z
        split = 12 * self.k * self.R.n // 8
        inner_sk, pk = sk[:split], sk[split:-64]
        pk_hash, z = sk[-64:-32], sk[-32:]

        # Decrypt, then re-encrypt with the coins derived from the result
        m_prime = self._cpapke_dec(inner_sk, c)
        Kbar_prime, coins_prime = self._g(m_prime + pk_hash)
        c_prime = self._cpapke_enc(pk, m_prime, coins_prime)

        # Matching ciphertexts yield the real key; otherwise the key is
        # derived from the secret value z instead.
        key_seed = Kbar_prime if c == c_prime else z
        return self._kdf(key_seed + self._h(c), key_length)

# Ready-made instances, one per entry of DEFAULT_PARAMETERS, so callers can
# import and use them directly (e.g. Kyber512.keygen()).
Kyber512 = Kyber(DEFAULT_PARAMETERS["kyber_512"])
Kyber768 = Kyber(DEFAULT_PARAMETERS["kyber_768"])
Kyber1024 = Kyber(DEFAULT_PARAMETERS["kyber_1024"])

In [None]:
class Module:
    """
    Factory for matrices whose entries are elements of a given ring.

    Calling the module with a list of ring elements builds a 1 x n row
    matrix; calling it with a list of lists builds an m x n matrix.
    """
    def __init__(self, ring):
        self.ring = ring

    def decode(self, input_bytes, m, n, l=None, is_ntt=False):
        """
        Decode `input_bytes` into an m x n matrix, row by row, reading
        `l` bits per coefficient for every polynomial.

        Input:
            input_bytes: concatenated polynomial encodings; bytes beyond the
                         first m*n polynomials are ignored when `l` is given
            m, n:        matrix dimensions
            l:           bits per coefficient; derived from the input length
                         when None
            is_ntt:      forwarded to the ring's `decode`
        Raises:
            ValueError if the input length does not fit the requested shape.
        """
        if l is None:
            # The whole input must split evenly into m*n polynomials
            l, check = divmod(8*len(input_bytes), self.ring.n*m*n)
            if check != 0:
                raise ValueError("input bytes must be a multiple of (polynomial degree) / 8")
        elif self.ring.n*l*m*n > len(input_bytes)*8:
            raise ValueError("Byte length is too short for given l")

        # One polynomial occupies ring.n * l bits. This used to be the
        # constant 32*l, which is only correct for ring.n == 256.
        chunk_length, leftover_bits = divmod(self.ring.n*l, 8)
        if chunk_length < 1 or leftover_bits:
            raise ValueError("Each polynomial must occupy a positive whole number of bytes")

        matrix = []
        for i in range(m):
            row = []
            for j in range(n):
                start = (n*i + j) * chunk_length
                chunk = input_bytes[start:start + chunk_length]
                row.append(self.ring.decode(chunk, l=l, is_ntt=is_ntt))
            matrix.append(row)
        return self(matrix)

    def __repr__(self):
        return f"Module over the commutative ring: {self.ring}"

    def __str__(self):
        return f"Module over the commutative ring: {self.ring}"

    def __call__(self, matrix_elements):
        """
        Build a matrix from a list of ring elements (one row) or a list of
        such lists (several rows).

        Raises:
            TypeError if the input is not a list or holds non-ring elements.
        """
        if not isinstance(matrix_elements, list):
            raise TypeError(f"Elements of a module are matrices, with elements .")

        if isinstance(matrix_elements[0], list):
            for element_list in matrix_elements:
                if not all(isinstance(aij, self.ring.element) for aij in element_list):
                    raise TypeError(f"All elements of the matrix must be elements of the ring: {self.ring}")
            return Module.Matrix(self, matrix_elements)

        elif isinstance(matrix_elements[0], self.ring.element):
            if not all(isinstance(aij, self.ring.element) for aij in matrix_elements):
                raise TypeError(f"All elements of the matrix must be elements of the ring: {self.ring}")
            return Module.Matrix(self, [matrix_elements])

        else:
            raise TypeError(f"Elements of a module are matrices, built from elements of the base ring.")


    class Matrix:
        """
        An m x n matrix of ring elements stored as a list of rows.

        The element-wise transforms (`compress`, `to_ntt`, ...) call the
        same-named method on every entry and return this matrix; whether the
        entries are changed in place is decided by the ring element class.
        """
        def __init__(self, parent, matrix_elements):
            self.parent = parent
            self.rows = matrix_elements
            self.m = len(matrix_elements)
            self.n = len(matrix_elements[0])
            if not self.check_dimensions():
                raise ValueError("Inconsistent row lengths in matrix")

        def get_dim(self):
            """Return (number of rows, number of columns)."""
            return self.m, self.n

        def check_dimensions(self):
            """Return True when every row has exactly `n` entries."""
            return all(len(row) == self.n for row in self.rows)

        def transpose(self):
            """Return a new matrix that is the transpose of this one."""
            new_rows = [list(item) for item in zip(*self.rows)]
            return self.parent(new_rows)

        def transpose_self(self):
            """Transpose this matrix in place and return it."""
            self.m, self.n = self.n, self.m
            self.rows = [list(item) for item in zip(*self.rows)]
            return self

        def _apply_to_elements(self, method_name, *args):
            """Call `method_name(*args)` on every entry, row by row; return self."""
            for row in self.rows:
                for ele in row:
                    getattr(ele, method_name)(*args)
            return self

        def reduce_coefficents(self):
            return self._apply_to_elements("reduce_coefficents")

        def to_montgomery(self):
            return self._apply_to_elements("to_montgomery")

        def encode(self, l=None):
            """Concatenate the encodings of all entries, row by row."""
            output = b""
            for row in self.rows:
                for j in range(self.n):
                    output += row[j].encode(l=l)
            return output

        def compress(self, d):
            return self._apply_to_elements("compress", d)

        def decompress(self, d):
            return self._apply_to_elements("decompress", d)

        def to_ntt(self):
            return self._apply_to_elements("to_ntt")

        def from_ntt(self):
            return self._apply_to_elements("from_ntt")

        def __getitem__(self, i):
            return self.rows[i]

        def __eq__(self, other):
            # Comparing with a non-matrix used to raise AttributeError;
            # NotImplemented lets Python fall back to "not equal".
            if not isinstance(other, Module.Matrix):
                return NotImplemented
            return other.rows == self.rows

        def __add__(self, other):
            if not isinstance(other, Module.Matrix):
                raise TypeError("Can only add matrcies to other matrices")
            if self.parent != other.parent:
                raise TypeError("Matricies must have the same base ring")
            if self.get_dim() != other.get_dim():
                raise ValueError("Matrices are not of the same dimensions")

            new_elements = []
            for i in range(self.m):
                new_elements.append([a+b for a,b in zip(self.rows[i], other.rows[i])])
            return self.parent(new_elements)

        def __radd__(self, other):
            return self.__add__(other)

        def __iadd__(self, other):
            return self + other

        def __sub__(self, other):
            if not isinstance(other, Module.Matrix):
                raise TypeError("Can only subtract matrcies from other matrices")
            if self.parent != other.parent:
                raise TypeError("Matricies must have the same base ring")
            if self.get_dim() != other.get_dim():
                raise ValueError("Matrices are not of the same dimensions")

            new_elements = []
            for i in range(self.m):
                new_elements.append([a-b for a,b in zip(self.rows[i], other.rows[i])])
            return self.parent(new_elements)

        def __rsub__(self, other):
            # Reflected subtraction is `other - self`; it used to compute
            # `self - other`, i.e. the result with the wrong sign.
            if not isinstance(other, Module.Matrix):
                raise TypeError("Can only subtract matrcies from other matrices")
            return other.__sub__(self)

        def __isub__(self, other):
            return self - other

        def __matmul__(self, other):
            """
            Denoted A @ B
            """
            if not isinstance(other, Module.Matrix):
                raise TypeError("Can only multiply matrcies with other matrices")
            if self.parent != other.parent:
                raise TypeError("Matricies must have the same base ring")
            if self.n != other.m:
                raise ValueError("Matrices are of incompatible dimensions")

            # Transpose once instead of once per row of `self`
            other_columns = other.transpose().rows
            new_elements = [[sum(a*b for a,b in zip(A_row, B_col)) for B_col in other_columns] for A_row in self.rows]
            return self.parent(new_elements)

        def __repr__(self):
            if len(self.rows) == 1:
                return str(self.rows[0])
            # Right-align every column to its widest entry
            max_col_width = []
            for n_col in range(self.n):
                max_col_width.append(max(len(str(row[n_col])) for row in self.rows))
            info = ']\n['.join([', '.join([f'{str(x):>{max_col_width[i]}}' for i,x in enumerate(r)]) for r in self.rows])
            return f"[{info}]"

In [None]:
"""
The class `NTTHelper` has been defined to allow for the
`Polynomial` class to have some `n=256` NTT help for
Kyber. This is OK code, but it doesn't generalise nicely.

TODOs:

- Build structure to allow this to generalise away from n=256.
- Allow for kyber and dilithium NTT in one file.

"""

# Constants for Montgomery arithmetic and the NTT, keyed by scheme name.
# Montgomery radix is R = 2^16; all values are given modulo q unless noted.
NTT_PARAMETERS = {
    "kyber" : {
        "q" : 3329,
        "mont_r"        : 2285,  # 2^16 % q
        "mont_r2"       : 1353,  # 2^32 % q
        "mont_r_inv"    : 169,   # (1 / 2^16) % q
        "mont_mask"     : 65535, # 2^16 - 1
        "q_inv"         : 3327,  # -(1 / q) mod 2^16, i.e. q * q_inv = -1 (mod 2^16)
        "root_of_unity" : 17,
        # 128 precomputed twiddle factors in Montgomery form:
        # zetas[i] = (mont_r * pow(root_of_unity, br(i, 7), q)) % q for i in range(128)
        "zetas" : [2285, 2571, 2970, 1812, 1493, 1422, 287, 202, 3158, 622, 1577, 182, 962, 2127, 1855, 1468,
                     573, 2004, 264, 383, 2500, 1458, 1727, 3199, 2648, 1017, 732, 608, 1787, 411, 3124, 1758,
                     1223, 652, 2777, 1015, 2036, 1491, 3047, 1785, 516, 3321, 3009, 2663, 1711, 2167, 126, 1469,
                     2476, 3239, 3058, 830, 107, 1908, 3082, 2378, 2931, 961, 1821, 2604, 448, 2264, 677, 2054,
                     2226, 430, 555, 843, 2078, 871, 1550, 105, 422, 587, 177, 3094, 3038, 2869, 1574, 1653, 3083,
                     778, 1159, 3182, 2552, 1483, 2727, 1119, 1739, 644, 2457, 349, 418, 329, 3173, 3254, 817,
                     1097, 603, 610, 1322, 2044, 1864, 384, 2114, 3193, 1218, 1994, 2455, 220, 2142, 1670, 2144,
                     1799, 2051, 794, 1819, 2475, 2459, 478, 3221, 3021, 996, 991, 958, 1869, 1522, 1628],
        "f" : 1441,              # (2^32 / 128) % q, the final scaling factor used by from_ntt
    },
}


class NTTHelper:
    """
    Montgomery arithmetic and number-theoretic transform for polynomials
    with 256 coefficients, configured from one NTT_PARAMETERS entry.

    The transforms work on any object exposing `coeffs` (a mutable list),
    `is_ntt` and `parent.n`, and modify that object in place.
    """
    def __init__(self, parameter_set):
        for field in ("q", "mont_r", "mont_r2", "mont_r_inv", "q_inv", "zetas", "f"):
            setattr(self, field, parameter_set[field])

    @staticmethod
    def br(i, k):
        """
        bit reversal of an unsigned k-bit integer
        """
        low_bits = i & ((1 << k) - 1)
        return int(format(low_bits, f"0{k}b")[::-1], 2)

    def montgomery_reduce(self, a):
        """
        a -> R^(-1) a mod q
        """
        return (a * self.mont_r_inv) % self.q

    def to_montgomery(self, poly):
        """
        Multiply every coefficient by R mod q (in place) and return `poly`.
        """
        r_squared = self.mont_r2
        poly.coeffs = [self.ntt_mul(r_squared, coeff) for coeff in poly.coeffs]
        return poly

    def reduce_mod_q(self, a):
        """
        return a mod q
        """
        return a % self.q

    def barrett_reduce(self,a):
        """
        a mod q in -(q-1)/2, ... ,(q-1)/2
        """
        # v approximates 2^26 / q, rounded to the nearest integer
        v = ((1 << 26) + self.q // 2) // self.q
        quotient = (v * a + (1 << 25)) >> 26
        return a - quotient * self.q

    def ntt_mul(self, a, b):
        """
        Ra * Rb -> Rab
        """
        return self.montgomery_reduce(a * b)

    def ntt_base_multiplication(self, a0, a1, b0, b1, zeta):
        """
        Multiply (a0 + a1 X) by (b0 + b1 X) modulo (X^2 - zeta), with every
        product taken through `ntt_mul`. Returns the pair (r0, r1).
        """
        high_product = self.ntt_mul(self.ntt_mul(a1, b1), zeta)
        r0 = high_product + self.ntt_mul(a0, b0)
        r1 = self.ntt_mul(a0, b1) + self.ntt_mul(a1, b0)
        return r0, r1

    def ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
        """
        Multiply two NTT-form coefficient lists pair by pair: each group of
        four coefficients is two base multiplications, using zetas[64+i] and
        its negation.
        """
        product = []
        for i in range(64):
            base = 4*i
            zeta = self.zetas[64+i]
            product.extend(self.ntt_base_multiplication(
                                f_coeffs[base], f_coeffs[base+1],
                                g_coeffs[base], g_coeffs[base+1],
                                zeta))
            product.extend(self.ntt_base_multiplication(
                                f_coeffs[base+2], f_coeffs[base+3],
                                g_coeffs[base+2], g_coeffs[base+3],
                                -zeta))
        return product

    def to_ntt(self, poly):
        """
        Transform `poly` to NTT form in place and return it.

        Raises:
            ValueError if `poly` is already flagged as NTT form.
        """
        if poly.is_ntt:
            raise ValueError("Cannot convert NTT form polynomial to NTT form")

        coeffs = poly.coeffs
        zeta_index = 1
        half = 128
        # Seven layers of butterflies: half-block sizes 128, 64, ..., 2
        while half >= 2:
            for block_start in range(0, 256, 2*half):
                zeta = self.zetas[zeta_index]
                zeta_index += 1
                for lo in range(block_start, block_start + half):
                    hi = lo + half
                    t = self.ntt_mul(zeta, coeffs[hi])
                    coeffs[hi] = coeffs[lo] - t
                    coeffs[lo] = coeffs[lo] + t
            half >>= 1

        poly.is_ntt = True
        return poly

    def from_ntt(self, poly):
        """
        Transform `poly` back from NTT form in place and return it. Every
        coefficient is finally multiplied by `f` through `ntt_mul`.

        Raises:
            ValueError if `poly` is not flagged as NTT form.
        """
        if not poly.is_ntt:
            raise ValueError("Can only convert from a polynomial in NTT form")

        coeffs = poly.coeffs
        zeta_index = 128 - 1
        half = 2
        # Same layers as to_ntt in reverse: half-block sizes 2, 4, ..., 128,
        # walking the zetas table backwards
        while half <= 128:
            for block_start in range(0, poly.parent.n, 2*half):
                zeta = self.zetas[zeta_index]
                zeta_index -= 1
                for lo in range(block_start, block_start + half):
                    hi = lo + half
                    t = coeffs[lo]
                    coeffs[lo] = self.reduce_mod_q(t + coeffs[hi])
                    coeffs[hi] = self.ntt_mul(zeta, coeffs[hi] - t)
            half <<= 1

        for index in range(poly.parent.n):
            coeffs[index] = self.ntt_mul(coeffs[index], self.f)

        poly.is_ntt = False
        return poly

NTTHelperKyber = NTTHelper(NTT_PARAMETERS["kyber"])

In [None]:
import random
from utils import *

class PolynomialRing:
    """
    Initialise the polynomial ring:

        R = Z(q) / (X^n + 1)
    """
    def __init__(self, q, n, ntt_helper=None):
        self.q = q
        self.n = n
        self.element = PolynomialRing.Polynomial
        self.ntt_helper = ntt_helper

    def gen(self, is_ntt=False):
        """
        Return the ring element built from the coefficient list [0, 1],
        i.e. the polynomial x.
        """
        generator_coefficients = [0, 1]
        return self(generator_coefficients, is_ntt=is_ntt)

    def random_element(self, is_ntt=False):
        """
        Return a ring element whose n coefficients are drawn independently
        from [0, q - 1] with the `random` module (not a cryptographic source).
        """
        coefficients = []
        for _ in range(self.n):
            coefficients.append(random.randint(0, self.q - 1))
        return self(coefficients, is_ntt=is_ntt)

    def parse(self, input_bytes, is_ntt=False):
        """
        Parse: B^* -> R

        Rejection-sample n coefficients from a byte stream: every three
        bytes yield two 12-bit candidates, and a candidate is kept only if
        it is smaller than q. Indexing past the end of `input_bytes` raises
        IndexError when the stream runs out before n coefficients are found.
        """
        coefficients = [0] * self.n
        filled, offset = 0, 0
        while filled < self.n:
            b0 = input_bytes[offset]
            b1 = input_bytes[offset+1]
            b2 = input_bytes[offset+2]
            offset += 3

            # low 12 bits and high 12 bits of the 24-bit group
            candidates = (b0 + 256*(b1 % 16), (b1 // 16) + 16*b2)
            for candidate in candidates:
                if candidate < self.q and filled < self.n:
                    coefficients[filled] = candidate
                    filled += 1
        return self(coefficients, is_ntt=is_ntt)

    def cbd(self, input_bytes, eta, is_ntt=False):
        """
        Build a ring element whose i-th coefficient is the difference of
        two sums of `eta` consecutive bits of the input.

        Expects a byte array of length (eta * deg / 4)
        For Kyber, this is 64 eta.
        """
        assert (self.n >> 2)*eta == len(input_bytes)
        bits = bytes_to_bits(input_bytes)
        coefficients = []
        for i in range(self.n):
            base = 2*i*eta
            positive = sum(bits[base + j] for j in range(eta))
            negative = sum(bits[base + eta + j] for j in range(eta))
            coefficients.append(positive - negative)
        return self(coefficients, is_ntt=is_ntt)

    def decode(self, input_bytes, l=None, is_ntt=False):
        """
        Decode (Algorithm 3)

        decode: B^32l -> R_q

        Reads n coefficients of `l` bits each, least significant bit first.
        When `l` is None it is derived from the input length. A ValueError
        is raised when the input length does not equal n*l bits.
        """
        total_bits = 8*len(input_bytes)
        if l is None:
            l, remainder = divmod(total_bits, self.n)
            if remainder != 0:
                raise ValueError("input bytes must be a multiple of (polynomial degree) / 8")
        elif self.n*l != total_bits:
            raise ValueError("input bytes must be a multiple of (polynomial degree) / 8")

        bits = bytes_to_bits(input_bytes)
        coefficients = []
        for i in range(self.n):
            coefficients.append(sum(bits[i*l + j] << j for j in range(l)))
        return self(coefficients, is_ntt=is_ntt)

    def __call__(self, coefficients, is_ntt=False):
        if isinstance(coefficients, int):
            return self.element(self, [coefficients], is_ntt)
        if not isinstance(coefficients, list):
            raise TypeError(f"Polynomials should be constructed from a list of integers, of length at most d = {self.n}")
        return self.element(self, coefficients, is_ntt)

    def __repr__(self):
        return f"Univariate Polynomial Ring in x over Finite Field of size {self.q} with modulus x^{self.n} + 1"

    class Polynomial:
        def __init__(self, parent, coefficients, is_ntt=False):
            self.parent = parent
            self.coeffs = self.parse_coefficients(coefficients)
            self.is_ntt = is_ntt

        def is_zero(self):
            """
            Return True when every coefficient equals zero: f = 0
            """
            for coefficient in self.coeffs:
                if not coefficient == 0:
                    return False
            return True

        def is_constant(self):
            """
            Return True when every coefficient after the first is zero: f = c
            """
            higher_coefficients = self.coeffs[1:]
            for coefficient in higher_coefficients:
                if not coefficient == 0:
                    return False
            return True

        def parse_coefficients(self, coefficients):
            """
            Right-pad the coefficient list with zeros up to parent.n entries,
            so that polynomials can be written as f = R([1,1,1]).

            A list that already has parent.n entries is returned as the same
            object (no copy). Longer lists raise ValueError.
            """
            missing = self.parent.n - len(coefficients)
            if missing < 0:
                raise ValueError(f"Coefficients describe polynomial of degree greater than maximum degree {self.parent.n}")
            if missing > 0:
                return coefficients + [0] * missing
            return coefficients

        def reduce_coefficents(self):
            """
            Replace every coefficient by its residue modulo q (in [0, q))
            and return this polynomial.
            """
            reduced = []
            for coefficient in self.coeffs:
                reduced.append(coefficient % self.parent.q)
            self.coeffs = reduced
            return self

        def encode(self, l=None):
            """
            Encode (Inverse of Algorithm 3)

            Write every coefficient as `l` binary digits, least significant
            bit first, and pack the resulting bit string into bytes. When
            `l` is None the largest coefficient bit length is used.
            """
            if l is None:
                l = max(x.bit_length() for x in self.coeffs)
            # format() yields most-significant-bit first; reverse each chunk
            reversed_chunks = [format(c, f'0{l}b')[::-1] for c in self.coeffs]
            bit_string = ''.join(reversed_chunks)
            return bitstring_to_bytes(bit_string)

        def compress(self, d):
            """
            Compress the polynomial by compressing each coefficient to the
            range [0, 2^d): scale by 2^d / q, round with `round_up`, then
            reduce modulo 2^d. Modifies this polynomial and returns it.
            NOTE: This is lossy compression
            """
            modulus = 2**d
            scale = modulus / self.parent.q
            compressed = []
            for coefficient in self.coeffs:
                compressed.append(round_up(scale * coefficient) % modulus)
            self.coeffs = compressed
            return self

        def decompress(self, d):
            """
            Decompress the polynomial by scaling each coefficient by q / 2^d
            and rounding with `round_up`. Modifies this polynomial and
            returns it.
            NOTE: As compression is lossy, x' = decompress(compress(x))
            generally differs from x, but is close in magnitude.
            """
            scale = self.parent.q / 2**d
            decompressed = []
            for coefficient in self.coeffs:
                decompressed.append(round_up(scale * coefficient))
            self.coeffs = decompressed
            return self

        def add_mod_q(self, x, y):
            """
            Add two coefficients, subtracting q once if the sum reaches q.
            """
            q = self.parent.q
            total = x + y
            return total - q if total >= q else total

        def sub_mod_q(self, x, y):
            """
            Subtract two coefficients, adding q once if the difference is
            negative.
            """
            difference = x - y
            return difference + self.parent.q if difference < 0 else difference

        def schoolbook_multiplication(self, other):
            """
            Naive implementation of polynomial multiplication,
            suitable for all R_q = F_q[X]/(X^n + 1).

            Returns the list of product coefficients reduced modulo q.
            """
            n = self.parent.n
            a, b = self.coeffs, other.coeffs
            accumulator = [0] * n
            for i in range(n):
                for j in range(n):
                    term = a[i] * b[j]
                    degree = i + j
                    if degree < n:
                        accumulator[degree] += term
                    else:
                        # X^n = -1, so terms of degree >= n wrap around
                        # with a sign flip
                        accumulator[degree - n] -= term
            return [c % self.parent.q for c in accumulator]


        def to_ntt(self):
            """
            Delegate the forward NTT to the parent's NTT helper.
            Raises ValueError when the parent has no helper.
            """
            helper = self.parent.ntt_helper
            if helper is None:
                raise ValueError("Can only perform NTT transform when parent element has an NTT Helper")
            return helper.to_ntt(self)

        def from_ntt(self):
            """
            Delegate the inverse NTT to the parent's NTT helper.
            Raises ValueError when the parent has no helper.
            """
            helper = self.parent.ntt_helper
            if helper is None:
                raise ValueError("Can only perform NTT transform when parent element has an NTT Helper")
            return helper.from_ntt(self)

        def to_montgomery(self):
            """
            Delegate conversion to Montgomery form (multiply every
            element by 2^16 mod q) to the parent's NTT helper.

            Only implemented (currently) for n = 256.
            Raises ValueError when the parent has no helper.
            """
            helper = self.parent.ntt_helper
            if helper is None:
                raise ValueError("Can only perform Mont. reduction when parent element has an NTT Helper")
            return helper.to_montgomery(self)

        def ntt_multiplication(self, other):
            """
            Multiply two polynomials that are both in NTT form.
            Only implemented (currently) for n = 256

            The coefficient arithmetic is delegated to the parent's
            NTT helper; the product is returned as a new element of
            the parent ring flagged as NTT form. Raises ValueError if
            the parent has no helper or either operand is not in NTT
            form.
            """
            helper = self.parent.ntt_helper
            if helper is None:
                raise ValueError("Can only perform ntt reduction when parent element has an NTT Helper")
            if not self.is_ntt or not other.is_ntt:
                raise ValueError("Can only multiply using NTT if both polynomials are in NTT form")
            # the pointwise product is implemented in ntt_helper.py
            product = helper.ntt_coefficient_multiplication(self.coeffs, other.coeffs)
            return self.parent(product, is_ntt=True)

        def __neg__(self):
            """
            Return -f as a new polynomial: every coefficient is
            negated and reduced modulo q; the NTT flag is carried over.
            """
            modulus = self.parent.q
            negated = []
            for coeff in self.coeffs:
                negated.append((-coeff) % modulus)
            return self.parent(negated, is_ntt=self.is_ntt)

        def __add__(self, other):
            """
            Return self + other as a new polynomial.

            other may be a polynomial of the same ring (coefficient-wise
            addition mod q; both operands must agree on NTT form) or an
            int, which is added to the constant coefficient.

            Raises ValueError when exactly one operand is in NTT form,
            and NotImplementedError for unsupported operand types.
            """
            if isinstance(other, PolynomialRing.Polynomial):
                if self.is_ntt ^ other.is_ntt:
                    raise ValueError("Both or neither polynomials must be in NTT form before addition")
                new_coeffs = [self.add_mod_q(x, y) for x, y in zip(self.coeffs, other.coeffs)]
            elif isinstance(other, int):
                new_coeffs = self.coeffs.copy()
                # add_mod_q only corrects by a single q, so bring the integer
                # into [0, q) first; otherwise negative or large values would
                # leave the constant coefficient unreduced.
                new_coeffs[0] = self.add_mod_q(new_coeffs[0], other % self.parent.q)
            else:
                raise NotImplementedError("Polynomials can only be added to each other")
            return self.parent(new_coeffs, is_ntt=self.is_ntt)

        def __radd__(self, other):
            """
            Reflected addition (other + self); addition is commutative,
            so this simply forwards to __add__.
            """
            total = self.__add__(other)
            return total

        def __iadd__(self, other):
            """
            Augmented addition (self += other). Does not modify this
            object; the name is rebound to the result of self + other.
            """
            return self + other

        def __sub__(self, other):
            """
            Return self - other as a new polynomial.

            other may be a polynomial of the same ring (coefficient-wise
            subtraction mod q; both operands must agree on NTT form) or
            an int, which is subtracted from the constant coefficient.

            Raises ValueError when exactly one operand is in NTT form,
            and NotImplementedError for unsupported operand types.
            """
            if isinstance(other, PolynomialRing.Polynomial):
                if self.is_ntt ^ other.is_ntt:
                    raise ValueError("Both or neither polynomials must be in NTT form before subtraction")
                new_coeffs = [self.sub_mod_q(x, y) for x, y in zip(self.coeffs, other.coeffs)]
            elif isinstance(other, int):
                new_coeffs = self.coeffs.copy()
                # sub_mod_q only corrects by a single q, so bring the integer
                # into [0, q) first; otherwise negative or large values would
                # leave the constant coefficient unreduced.
                new_coeffs[0] = self.sub_mod_q(new_coeffs[0], other % self.parent.q)
            else:
                raise NotImplementedError("Polynomials can only be subracted from each other")
            return self.parent(new_coeffs, is_ntt=self.is_ntt)

        def __rsub__(self, other):
            """
            Reflected subtraction: evaluates other - self.

            Subtraction is not commutative, so the result is the
            negation of self - other (previously self - other was
            returned unchanged, giving the wrong sign).
            """
            return -(self.__sub__(other))

        def __isub__(self, other):
            """
            Augmented subtraction (self -= other). Does not modify this
            object; the name is rebound to the result of self - other.
            """
            return self - other

        def __mul__(self, other):
            """
            Return self * other as a new polynomial.

            Two polynomials in NTT form are multiplied with
            ntt_multiplication, two in normal form with
            schoolbook_multiplication; mixing forms raises ValueError.
            An int scales every coefficient modulo q. Any other operand
            type raises NotImplementedError.
            """
            if isinstance(other, PolynomialRing.Polynomial):
                if self.is_ntt and other.is_ntt:
                    return self.ntt_multiplication(other)
                if self.is_ntt ^ other.is_ntt:
                    raise ValueError("Both or neither polynomials must be in NTT form before multiplication")
                product = self.schoolbook_multiplication(other)
                return self.parent(product, is_ntt=self.is_ntt)

            if isinstance(other, int):
                scaled = [(coeff * other) % self.parent.q for coeff in self.coeffs]
                return self.parent(scaled, is_ntt=self.is_ntt)

            raise NotImplementedError("Polynomials can only be multiplied by each other, or scaled by integers")

        def __rmul__(self, other):
            """
            Reflected multiplication (other * self); forwards to
            __mul__ with the operands swapped.
            """
            product = self.__mul__(other)
            return product

        def __imul__(self, other):
            """
            Augmented multiplication (self *= other). Does not modify
            this object; the name is rebound to the result of
            self * other.
            """
            return self * other

        def __pow__(self, n):
            """
            Raise the polynomial to a non-negative integer power using
            square-and-multiply.

            Raises TypeError when n is not an int and ValueError when
            n is negative.
            """
            if not isinstance(n, int):
                raise TypeError("Exponentiation of a polynomial must be done using an integer.")

            # Negative exponents would need inverses, which are not provided
            if n < 0:
                raise ValueError("Negative powers are not supported for elements of a Polynomial Ring")

            base = self
            # start from the ring element built from the integer 1
            result = self.parent(1, is_ntt=self.is_ntt)
            remaining = n
            while remaining > 0:
                remaining, low_bit = divmod(remaining, 2)
                if low_bit == 1:
                    result = result * base
                base = base * base
            return result

        def __eq__(self, other):
            """
            Two polynomials are equal when their coefficient lists and
            NTT flags match. An int is equal to the polynomial when the
            polynomial is constant and its constant coefficient equals
            the int reduced modulo q. Anything else compares unequal.
            """
            if isinstance(other, PolynomialRing.Polynomial):
                same_coeffs = self.coeffs == other.coeffs
                return same_coeffs and self.is_ntt == other.is_ntt
            if isinstance(other, int):
                matches = self.is_constant() and (other % self.parent.q) == self.coeffs[0]
                return True if matches else False
            return False

        def __getitem__(self, idx):
            """
            Index (or slice) directly into the coefficient list.
            """
            coefficients = self.coeffs
            return coefficients[idx]

        def __repr__(self):
            """
            Human-readable form such as "3 + x + 2*x^5", listing only
            the non-zero terms in increasing degree. The zero
            polynomial prints as "0", and " (NTT form)" is appended
            when the polynomial is in NTT form.
            """
            suffix = " (NTT form)" if self.is_ntt else ""
            if self.is_zero():
                return "0" + suffix

            terms = []
            for power, coeff in enumerate(self.coeffs):
                if coeff == 0:
                    continue
                if power == 0:
                    terms.append(f"{coeff}")
                    continue
                monomial = "x" if power == 1 else f"x^{power}"
                # a coefficient of 1 is left implicit
                terms.append(monomial if coeff == 1 else f"{coeff}*{monomial}")
            return " + ".join(terms) + suffix

        def __str__(self):
            """
            str() output is the same text as repr().
            """
            text = self.__repr__()
            return text

In [None]:
def bytes_to_bits(input_bytes):
    """
    Expand a byte sequence into a flat list of 0/1 integers,
    emitting the least-significant bit of each byte first.
    """
    bits = []
    for byte in input_bytes:
        little_endian = format(byte, '08b')[::-1]
        bits.extend(int(ch) for ch in little_endian)
    return bits

def bitstring_to_bytes(s):
    """
    Pack a string of '0'/'1' characters into bytes, eight
    characters per byte, with the first character of each group
    being the least-significant bit. A trailing group shorter than
    eight characters is packed as-is.
    """
    packed = bytearray()
    for start in range(0, len(s), 8):
        group = s[start:start + 8]
        # reverse so int() sees the most-significant bit first
        packed.append(int(group[::-1], 2))
    return bytes(packed)

def round_up(x):
    """
    Round x to the nearest integer, with exact halves rounded
    towards positive infinity (e.g. 2.5 -> 3, -0.5 -> 0).

    Python's built-in round() rounds halves to the nearest even
    integer, which is why it is not used directly. This computes
    floor(x + 0.5) rather than nudging x by a small epsilon: an
    epsilon misrounds values lying just below a half and is absorbed
    by floating point for large magnitudes.
    """
    return int((x + 0.5) // 1)

def xor_bytes(a, b):
    """
    XOR two byte sequences position by position. The result is as
    long as the shorter input (pairs are formed with zip).
    """
    mixed = bytearray()
    for left, right in zip(a, b):
        mixed.append(left ^ right)
    return bytes(mixed)

In [None]:
import unittest
import os
from kyber import Kyber512, Kyber768, Kyber1024
from aes256_ctr_drbg import AES256_CTR_DRBG

def parse_kat_data(data):
    """
    Parse the text of a known-answer-test (.rsp) file.

    The text is split on blank lines; the first and last pieces are
    skipped. Every remaining piece must consist of six
    "name = value" lines, read in the order count, seed, pk, sk, ct,
    ss. Returns a dict keyed by the count string whose values are
    dicts of the five hex-decoded byte strings.
    """
    field_names = ("seed", "pk", "sk", "ct", "ss")
    parsed_data = {}
    for block in data.split('\n\n')[1:-1]:
        # keep only the text to the right of " = " on each line
        count, seed, pk, sk, ct, ss = (
            line.split(" = ")[-1] for line in block.split('\n')
        )
        decoded = map(bytes.fromhex, (seed, pk, sk, ct, ss))
        parsed_data[count] = dict(zip(field_names, decoded))
    return parsed_data

class TestKyber(unittest.TestCase):


    def generic_test_kyber(self, Kyber, count):
        for _ in range(count):
            pk, sk = Kyber.keygen()
            for _ in range(count):
                c, key = Kyber.enc(pk)
                _key = Kyber.dec(c, sk)
                self.assertEqual(key, _key)

    def test_kyber512(self):
        self.generic_test_kyber(Kyber512, 5)

    def test_kyber768(self):
        self.generic_test_kyber(Kyber768, 5)

    def test_kyber1024(self):
        self.generic_test_kyber(Kyber1024, 5)

class TestKyberDeterministic(unittest.TestCase):


    def generic_test_kyber_deterministic(self, Kyber, count):
        """
        First we generate five pk,sk pairs
        from the same seed and make sure
        they're all the same
        """
        seed = os.urandom(48)
        pk_output = []
        for _ in range(count):
            Kyber.set_drbg_seed(seed)
            pk, sk = Kyber.keygen()
            pk_output.append(pk + sk)
        self.assertEqual(len(pk_output), 5)
        self.assertEqual(len(set(pk_output)), 1)

        """
        Now given a fixed keypair make sure
        that c,key are the same for a fixed seed
        """
        key_output = []
        seed = os.urandom(48)
        pk, sk = Kyber.keygen()
        for _ in range(count):
            Kyber.set_drbg_seed(seed)
            c, key = Kyber.enc(pk)
            _key = Kyber.dec(c, sk)
            # Check key derivation works
            self.assertEqual(key, _key)
            key_output.append(c + key)
        self.assertEqual(len(key_output), count)
        self.assertEqual(len(set(key_output)), 1)

    def test_kyber512_deterministic(self):
        self.generic_test_kyber_deterministic(Kyber512, 5)

    def test_kyber768_deterministic(self):
        self.generic_test_kyber_deterministic(Kyber768, 5)

    def test_kyber1024_deterministic(self):
        self.generic_test_kyber_deterministic(Kyber1024, 5)


class TestKnownTestValuesDRBG(unittest.TestCase):
    """
    Checks the AES256 CTR DRBG against the seeds recorded in the
    Kyber512 known-answer-test file.
    """

    def test_kyber512_known_answer_seed(self):
        """
        A DRBG instantiated with the bytes 0..47 must output, 48
        bytes at a time and in file order, every seed in the KAT file.
        """
        # Set DRBG to generate seeds
        rng = AES256_CTR_DRBG(bytes(range(48)))

        # extract data from KAT
        with open("assets/PQCkemKAT_1632.rsp") as kat_file:
            kat_entries = parse_kat_data(kat_file.read())

        # Check all seeds match
        for entry in kat_entries.values():
            self.assertEqual(entry["seed"], rng.random_bytes(48))

class TestKnownTestValues(unittest.TestCase):
    """
    Known-answer tests: keygen, enc and dec must reproduce the
    values recorded in the KAT response files.
    """

    # NOTE(review): the DRBG test in this file opens its KAT file under
    # "assets/", while the paths below have no directory prefix - confirm
    # where the .rsp files actually live.

    def generic_test_kyber_known_answer(self, Kyber, filename):
        """
        For every entry of the KAT file, seed the DRBG with the
        entry's seed and compare the generated key pair, ciphertext
        and shared secret with the recorded ones.
        """
        with open(filename) as kat_file:
            kat_entries = parse_kat_data(kat_file.read())

        for entry in kat_entries.values():
            seed, pk, sk, ct, ss = entry.values()

            # Seed DRBG with KAT seed
            Kyber.set_drbg_seed(seed)

            # Assert keygen matches
            derived_pk, derived_sk = Kyber.keygen()
            self.assertEqual(pk, derived_pk)
            self.assertEqual(sk, derived_sk)

            # Assert encapsulation matches
            derived_ct, derived_ss = Kyber.enc(derived_pk)
            self.assertEqual(ct, derived_ct)
            self.assertEqual(ss, derived_ss)

            # Assert decapsulation matches
            decapsulated_ss = Kyber.dec(ct, sk)
            self.assertEqual(ss, decapsulated_ss)

    def test_kyber512_known_answer(self):
        self.generic_test_kyber_known_answer(Kyber512, "PQCkemKAT_1632.rsp")

    def test_kyber768_known_answer(self):
        self.generic_test_kyber_known_answer(Kyber768, "PQCkemKAT_2400.rsp")

    def test_kyber1024_known_answer(self):
        self.generic_test_kyber_known_answer(Kyber1024, "PQCkemKAT_3168.rsp")

"""if __name__ == '__main__':
    unittest.main()"""

In [None]:
from kyber import Kyber512, Kyber768, Kyber1024
import cProfile
from time import time

def profile_kyber(Kyber):
    """
    Profile one keygen, one enc and one dec call with cProfile,
    printing the statistics of each (sort=1).

    A key pair and ciphertext are produced up front so the
    profiled enc/dec statements have inputs to work with.
    """
    pk, sk = Kyber.keygen()
    c, key = Kyber.enc(pk)

    gvars = {}
    lvars = {"Kyber": Kyber, "c": c, "pk": pk, "sk": sk}

    statements = ("Kyber.keygen()", "Kyber.enc(pk)", "Kyber.dec(c, sk)")
    for statement in statements:
        cProfile.runctx(statement, globals=gvars, locals=lvars, sort=1)

def benchmark_kyber(Kyber, name, count):
    """
    Time `count` rounds of keygen, enc and dec and print the total
    seconds spent in each operation, rounded to three decimals.

    Kyber: object providing keygen(), enc(pk) and dec(c, sk)
    name:  label printed in the banner
    count: number of rounds to run
    """
    # perf_counter is the clock intended for measuring short durations;
    # time() follows the wall clock and can jump if the system clock changes.
    from time import perf_counter

    # Banner
    print("-"*27)
    print(f"  {name} | ({count} calls)")
    print("-"*27)

    keygen_times = []
    enc_times = []
    dec_times = []

    for _ in range(count):
        t0 = perf_counter()
        pk, sk = Kyber.keygen()
        keygen_times.append(perf_counter() - t0)

        t1 = perf_counter()
        c, _ = Kyber.enc(pk)
        enc_times.append(perf_counter() - t1)

        t2 = perf_counter()
        Kyber.dec(c, sk)
        dec_times.append(perf_counter() - t2)

    print(f"Keygen: {round(sum(keygen_times),3)}")
    print(f"Enc: {round(sum(enc_times), 3)}")
    print(f"Dec: {round(sum(dec_times),3)}")


if __name__ == '__main__':
    # Profiling runs are disabled; uncomment to print cProfile statistics
    # profile_kyber(Kyber512)
    # profile_kyber(Kyber768)
    # profile_kyber(Kyber1024)

    # Number of keygen/enc/dec rounds timed for each parameter set
    count = 1000
    benchmark_kyber(Kyber512, name="Kyber512", count=count)
    benchmark_kyber(Kyber768, name="Kyber768", count=count)
    benchmark_kyber(Kyber1024, name="Kyber1024", count=count)

---------------------------
  Kyber512 | (1000 calls)
---------------------------
Keygen: 11.234
Enc: 16.868
Dec: 26.685
---------------------------
  Kyber768 | (1000 calls)
---------------------------
Keygen: 16.934
Enc: 24.533
Dec: 37.795
---------------------------
  Kyber1024 | (1000 calls)
---------------------------
Keygen: 25.003
Enc: 34.471
Dec: 52.205
