# Estruturas Criptográficas - Criptografia e Segurança da Informação

## TP3 - Exercício 2 - KEM

Trabalho realizado por:

| Número     | Nome                          |
|--------|-------------------------------|
| PG54177| Ricardo Alves Oliveira        |
| PG54236| Simão Oliveira Alvim Barroso  |

### Enunciado

2. Em Agosto de 2023 a NIST publicou um draft da [norma FIPS203](https://www.dropbox.com/scl/fo/vllnz60fnd4payrkllm1d/h?rlkey=4z7418pn17qcgmx97etxepvzo&e=1&dl=0)  para um Key Encapsulation Mechanism (KEM) derivado dos algoritmos [KYBER](https://www.dropbox.com/scl/fo/y2i97mufz371tmz0orc40/h?rlkey=ffhdgacrx5wb4jsjb84kugclq&e=1&dl=0). 
    O preâmbulo do “draft” 
> A key-encapsulation mechanism (or KEM) is a set of algorithms that, under certain conditions, can be used by two parties to establish a shared secret key over a public channel. A shared secret key that is securely established using a KEM can then be used with symmetric-key cryptographic algorithms to perform basic tasks in secure communications, such as encryption and authentication. This standard specifies a key-encapsulation mechanism called ML-KEM. The security of ML-KEM is related to the computational difficulty of the so-called Module Learning with Errors problem. At present, ML-KEM is believed to be secure even against adversaries who possess a quantum computer.


Neste trabalho pretende-se implementar em Sagemath um protótipo deste standard parametrizado de acordo com as variantes sugeridas na norma (512, 768 e 1024 bits de segurança)


# Imports + Explicação

In [109]:
import hashlib, os
from functools import reduce

In [110]:
# Read the requested ML-KEM variant and derive its parameter set.
sec_bits = int(input("n value: "))

# Explicit check rather than `assert`: assertions are stripped when Python
# runs with -O, which would leave the parameters below undefined.
if sec_bits not in (512, 768, 1024):
    raise ValueError("n must be 512, 768 or 1024")

print(f"sec_bits -> {sec_bits}")

# Ring dimension and prime modulus shared by every variant.
N = 256
Q = 3329

# (k, n1, n2, DU, DV, rbg) for each variant, following the FIPS 203
# parameter-set table. ML-KEM-512 uses DV = 4 (it was previously set to 5).
_PARAMETER_SETS = {
    512: (2, 3, 2, 10, 4, 128),
    768: (3, 2, 2, 10, 4, 192),
    1024: (4, 2, 2, 11, 5, 256),
}
k, n1, n2, DU, DV, rbg = _PARAMETER_SETS[sec_bits]

print(f"""
    security level -> {sec_bits}
    n -> {N}\n
    q -> {Q}\n
    k -> {k}\n
    n1 -> {n1}\n
    n2 -> {n2}\n
    DU -> {DU}\n
    DV -> {DV}\n
    rbg -> {rbg}\n""")

sec_bits -> 1024

    security level -> 1024
    n -> 256

    q -> 3329

    k -> 4

    n1 -> 2

    n2 -> 2

    DU -> 11

    DV -> 5

    rbg -> 256



In [111]:

def bit_rev_7(r):
    """Return the integer whose binary digits are those of r, zero-padded to 7 bits, reversed."""
    digits = format(r, "07b")
    return int("".join(reversed(digits)), 2)

def G(c):
    """Hash c with SHA3-512 and return the digest split into two 32-byte halves."""
    digest = hashlib.sha3_512(c).digest()
    first_half = digest[:32]
    second_half = digest[32:]
    return first_half, second_half


def H(c):
    """Return the 32-byte SHA3-256 digest of c."""
    hasher = hashlib.sha3_256()
    hasher.update(c)
    return hasher.digest()


def J(s, l):
    """Return l bytes of SHAKE-256 output computed over s."""
    hasher = hashlib.shake_256()
    hasher.update(s)
    return hasher.digest(l)


def XOF(rho, i, j):
    """Return 1536 bytes of SHAKE-128 output over rho followed by the single bytes i and j."""
    seed = rho + bytes([i]) + bytes([j])
    stream = hashlib.shake_128(seed)
    return stream.digest(1536)

def PRF(eta, s, b):
    """Return 64 * eta bytes of SHAKE-256 output over the concatenation s + b."""
    output_length = 64 * eta
    seed = s + b
    return hashlib.shake_256(seed).digest(output_length)

# Precomputed powers of 17 modulo Q, indexed through the 7-bit bit reversal.
# ZETA feeds the NTT butterflies; GAMMA holds the odd powers used by the
# base-case multiplication. The index is named `idx` so it is not confused
# with the module-level parameter `k`.
ZETA = [pow(17, bit_rev_7(idx), Q) for idx in range(128)]
GAMMA = [pow(17, 2 * bit_rev_7(idx) + 1, Q) for idx in range(128)]

In [112]:
def bits_to_bytes(b):
    """Pack a list of bits (least-significant bit first within each byte) into bytes."""
    packed = bytearray(len(b) // 8)

    for position, bit in enumerate(b):
        packed[position // 8] += bit * 2 ** (position % 8)

    return bytes(packed)


def bytes_to_bits(B):
    """Expand each byte of B into 8 bits, least-significant bit first."""
    bits = []

    for value in B:
        for _ in range(8):
            bits.append(value % 2)
            value //= 2

    return bits


def byte_encode(d, F):
    """Serialize the first 256 entries of F into bytes, using d bits per entry.

    Each entry contributes its d low-order bits, least-significant first;
    the resulting bit string is packed with bits_to_bytes.
    """
    bits = []
    for i in range(256):
        remaining = F[i]
        for _ in range(d):
            remaining, bit = divmod(remaining, 2)
            bits.append(bit)

    return bits_to_bytes(bits)


def byte_decode(d, B):
    """Deserialize B into 256 integers of d bits each.

    Args:
        d: number of bits per coefficient.
        B: byte string holding at least 256 * d bits.

    Returns:
        A list of 256 integers reduced modulo 2**d when d < 12, otherwise
        modulo Q.
    """
    m = 2 ** d if d < 12 else Q
    b = bytes_to_bits(B)

    # The reduction applies to the whole sum. Because `%` binds tighter than
    # `+`, reducing each term instead would let 12-bit values in [Q, 4095]
    # pass through unreduced.
    return [
        sum(b[i * d + j] * (2 ** j) for j in range(d)) % m
        for i in range(256)
    ]

In [113]:
def sample_ntt(B):
    """Rejection-sample 256 coefficients below Q from the byte stream B.

    Every 3 bytes yield two 12-bit candidates; a candidate is kept only if
    it is smaller than Q. Raises IndexError if B runs out before 256
    coefficients have been accepted.
    """
    coeffs = []
    offset = 0

    while len(coeffs) < 256:
        b0, b1, b2 = B[offset], B[offset + 1], B[offset + 2]
        offset += 3

        low = b0 + 256 * (b1 % 16)
        high = (b1 // 16) + 16 * b2

        for candidate in (low, high):
            if candidate < Q and len(coeffs) < 256:
                coeffs.append(candidate)

    return coeffs


def sample_poly_cbd(B, eta):
    """Derive 256 coefficients from the bits of B.

    Coefficient i is (x - y) mod Q, where x and y are the sums of two
    consecutive groups of eta bits starting at bit 2 * i * eta.
    """
    bits = bytes_to_bits(B)
    coeffs = []

    for i in range(256):
        start = 2 * i * eta
        x = sum(bits[start + j] for j in range(eta))
        y = sum(bits[start + eta + j] for j in range(eta))
        coeffs.append((x - y) % Q)

    return coeffs


In [114]:
def ntt(f):
    """Return the number-theoretic transform of the 256-coefficient list f.

    The transform runs on a copy, so the caller's list is left untouched
    (previously the input was aliased and overwritten in place).

    Args:
        f: list of 256 integers.

    Returns:
        A new list of 256 integers reduced modulo Q.
    """
    fc = list(f)
    k = 1
    # Named `length` so the builtin `len` is not shadowed.
    length = 128

    while length >= 2:
        for start in range(0, 256, 2 * length):
            zeta = ZETA[k]
            k += 1
            for j in range(start, start + length):
                t = (zeta * fc[j + length]) % Q
                fc[j + length] = (fc[j] - t) % Q
                fc[j] = (fc[j] + t) % Q

        length //= 2

    return fc


def ntt_inv(fc):
    """Return the inverse number-theoretic transform of the 256-entry list fc.

    The transform runs on a copy; previously the caller's list was
    overwritten with the intermediate, unscaled values.

    Args:
        fc: list of 256 integers.

    Returns:
        A new list of 256 integers reduced modulo Q.
    """
    f = list(fc)
    k = 127
    # Named `length` so the builtin `len` is not shadowed.
    length = 2

    while length <= 128:
        for start in range(0, 256, 2 * length):
            zeta = ZETA[k]
            k -= 1
            for j in range(start, start + length):
                t = f[j]
                f[j] = (t + f[j + length]) % Q
                f[j + length] = (zeta * (f[j + length] - t)) % Q

        length *= 2

    # 3303 is the inverse of 128 modulo 3329 (128 * 3303 = 127 * 3329 + 1).
    return [(felem * 3303) % Q for felem in f]


def base_case_multiply(a0, a1, b0, b1, gamma):
    """Multiply a0 + a1*X by b0 + b1*X with X**2 replaced by gamma.

    Returns the pair (c0, c1) of resulting coefficients. The values are
    not reduced modulo Q here.
    """
    constant_term = a0 * b0 + a1 * b1 * gamma
    linear_term = a0 * b1 + a1 * b0
    return constant_term, linear_term


def multiply_ntt_s(fc, gc):
    """Multiply two 256-entry NTT representations pair by pair.

    Entries (2i, 2i+1) of fc and gc are combined with base_case_multiply
    using GAMMA[i]; the 128 resulting pairs are returned as a flat list of
    256 values.
    """
    product = []
    for i in range(128):
        even, odd = 2 * i, 2 * i + 1
        product.extend(
            base_case_multiply(fc[even], fc[odd], gc[even], gc[odd], GAMMA[i])
        )

    return product

In [115]:
def vector_add(ac, bc):
    """Return the element-wise sum of ac and bc, each entry reduced modulo Q."""
    total = []
    for left, right in zip(ac, bc):
        total.append((left + right) % Q)
    return total

def vector_sub(ac, bc):
    """Return the element-wise difference ac - bc, each entry reduced modulo Q."""
    difference = []
    for left, right in zip(ac, bc):
        difference.append((left - right) % Q)
    return difference




In [116]:
def k_pke_keygen(k, eta1):
    """Generate a K-PKE key pair.

    Args:
        k: dimension of the matrix and of the secret/error vectors.
        eta1: parameter passed to PRF and sample_poly_cbd for s and e.

    Returns:
        (ek_PKE, dk_PKE): the encoded vector tc followed by the seed rho,
        and the encoded secret vector sc.
    """
    seed = os.urandom(32)
    rho, sigma = G(seed)

    # Matrix entries are derived from rho and their (row, column) position.
    A_hat = [[sample_ntt(XOF(rho, row, col)) for col in range(k)] for row in range(k)]

    # The PRF counter runs over 0..k-1 for s and k..2k-1 for e.
    s = [sample_poly_cbd(PRF(eta1, sigma, bytes([counter])), eta1) for counter in range(k)]
    e = [sample_poly_cbd(PRF(eta1, sigma, bytes([counter])), eta1) for counter in range(k, 2 * k)]

    s_hat = [ntt(poly) for poly in s]
    e_hat = [ntt(poly) for poly in e]

    # t_hat[row] = sum over col of A_hat[row][col] * s_hat[col], plus e_hat[row].
    t_hat = []
    for row in range(k):
        terms = [multiply_ntt_s(A_hat[row][col], s_hat[col]) for col in range(k)]
        terms.append(e_hat[row])
        t_hat.append(reduce(vector_add, terms))

    ek_PKE = b"".join(byte_encode(12, poly) for poly in t_hat) + rho
    dk_PKE = b"".join(byte_encode(12, poly) for poly in s_hat)

    return ek_PKE, dk_PKE

In [117]:




def transparray_array_mult(uct, vc, k):
    """Return the sum over i < k of multiply_ntt_s(uct[i], vc[i]), added with vector_add."""
    products = [multiply_ntt_s(left, right) for left, right in ((uct[i], vc[i]) for i in range(k))]
    return reduce(vector_add, products)


def compress(d, x):
    """Map each entry n of x to ((n * 2**d + Q // 2) // Q) mod 2**d."""
    scale = 2**d
    compressed = []
    for value in x:
        compressed.append((((value * scale) + Q // 2) // Q) % scale)
    return compressed

def decompress(d, x):
    """Map each entry n of x to ((n * Q + 2**(d-1)) // 2**d) mod Q."""
    scale = 2**d
    rounding = 2**(d-1)
    restored = []
    for value in x:
        restored.append((((value * Q) + rounding) // scale) % Q)
    return restored


In [118]:
def k_pke_encrypt(ek_PKE, m, rand, k, eta1, eta2):
    """Encrypt the 32-byte message m under the K-PKE encryption key ek_PKE.

    Args:
        ek_PKE: k polynomials encoded with 12 bits per coefficient,
            followed by the 32-byte seed rho.
        m: 32-byte message.
        rand: seed bytes fed to PRF to derive r, e1 and e2.
        k: dimension of the matrix and vectors.
        eta1: PRF/sampling parameter for r.
        eta2: PRF/sampling parameter for e1 and e2.

    Returns:
        The ciphertext c1 + c2, compressed with the module-level DU and DV.
    """
    # A polynomial encoded with 12 bits per coefficient occupies
    # 256 * 12 / 8 = 384 bytes regardless of k. The previous stride of
    # 128 * k matched that size only for k = 3.
    tc = [byte_decode(12, ek_PKE[384 * i : 384 * (i + 1)]) for i in range(k)]
    rho = ek_PKE[384 * k : 384 * k + 32]

    Ac = [[sample_ntt(XOF(rho, i, j)) for j in range(k)] for i in range(k)]

    # N is the PRF counter; it keeps increasing across r, e1 and e2.
    N = 0
    r = []
    for _ in range(k):
        r.append(sample_poly_cbd(PRF(eta1, rand, bytes([N])), eta1))
        N += 1

    e1 = []
    for _ in range(k):
        e1.append(sample_poly_cbd(PRF(eta2, rand, bytes([N])), eta2))
        N += 1

    e2 = sample_poly_cbd(PRF(eta2, rand, bytes([N])), eta2)
    rc = [ntt(poly) for poly in r]

    # u[i] uses the transposed matrix: sum over j of Ac[j][i] * rc[j].
    u = [
        vector_add(
            ntt_inv(reduce(vector_add, [multiply_ntt_s(Ac[j][i], rc[j]) for j in range(k)])),
            e1[i],
        )
        for i in range(k)
    ]

    mu = decompress(1, byte_decode(1, m))

    v = vector_add(
        ntt_inv(reduce(vector_add, [multiply_ntt_s(tc[i], rc[i]) for i in range(k)])),
        vector_add(e2, mu),
    )

    c1 = b"".join(byte_encode(DU, compress(DU, u[i])) for i in range(k))
    c2 = byte_encode(DV, compress(DV, v))

    return c1 + c2


In [119]:
def k_pke_decrypt(dk_PKE, c, k):
    """Decrypt ciphertext c with the K-PKE decryption key dk_PKE.

    Args:
        dk_PKE: k polynomials encoded with 12 bits per coefficient.
        c: ciphertext laid out as k blocks of 32 * DU bytes followed by
            32 * DV bytes (DU and DV are module-level parameters).
        k: dimension of the vectors.

    Returns:
        The recovered message, encoded with 1 bit per coefficient.
    """
    u_block = 32 * DU
    split = u_block * k
    c1 = c[:split]
    c2 = c[split : 32 * (DU * k + DV)]

    u = []
    for i in range(k):
        chunk = c1[i * u_block : (i + 1) * u_block]
        u.append(decompress(DU, byte_decode(DU, chunk)))

    v = decompress(DV, byte_decode(DV, c2))

    sc = [byte_decode(12, dk_PKE[i * 384 : (i + 1) * 384]) for i in range(k)]

    # w = v - NTT^-1(sum over i of sc[i] * NTT(u[i]))
    inner_product = reduce(vector_add, [multiply_ntt_s(sc[i], ntt(u[i])) for i in range(k)])
    w = vector_sub(v, ntt_inv(inner_product))

    return byte_encode(1, compress(1, w))

In [120]:
def ml_kem_keygen(k=3, eta1=2):
    """Generate an ML-KEM key pair.

    Args:
        k: vector/matrix dimension forwarded to k_pke_keygen.
        eta1: sampling parameter forwarded to k_pke_keygen.
        The defaults are the values that were previously hard-coded here.

    Returns:
        (ek, dk): ek is the K-PKE encryption key; dk is the K-PKE
        decryption key followed by ek, H(ek) and a 32-byte random value z.
    """
    z = os.urandom(32)
    ek_PKE, dk_PKE = k_pke_keygen(k, eta1)
    ek = ek_PKE
    dk = dk_PKE + ek + H(ek) + z

    return ek, dk


def ml_kem_encaps(ek, k=3, eta1=2, eta2=2):
    """Encapsulate a fresh shared secret under the encapsulation key ek.

    Args:
        ek: encapsulation key.
        k, eta1, eta2: parameters forwarded to k_pke_encrypt; the defaults
            are the values that were previously hard-coded here.

    Returns:
        (K, c): the 32-byte shared secret and the ciphertext.
    """
    m = os.urandom(32)
    # G yields the shared secret K and the encryption randomness r.
    K, r = G(m + H(ek))
    c = k_pke_encrypt(ek, m, r, k, eta1, eta2)

    return K, c


def ml_kem_decaps(c, dk, k=3, eta1=2, eta2=2):
    """Recover the shared secret from ciphertext c using decapsulation key dk.

    Args:
        c: ciphertext.
        dk: decapsulation key laid out as dk_PKE (384 * k bytes), ek_PKE
            (384 * k + 32 bytes), a 32-byte hash h and a 32-byte value z.
        k, eta1, eta2: parameters forwarded to the K-PKE routines; the
            defaults are the values that were previously hard-coded here.

    Returns:
        The shared secret derived from the decrypted message when
        re-encryption reproduces c; otherwise J(z + c, 32).
    """
    dk_len = 384 * k
    ek_end = 768 * k + 32

    dk_PKE = dk[0:dk_len]
    ek_PKE = dk[dk_len:ek_end]
    h = dk[ek_end : ek_end + 32]
    z = dk[ek_end + 32 : ek_end + 64]

    ml = k_pke_decrypt(dk_PKE, c, k)
    Kl, rl = G(ml + h)
    # Fallback secret, computed unconditionally and used only on mismatch.
    Kb = J((z + c), 32)
    cl = k_pke_encrypt(ek_PKE, ml, rl, k, eta1, eta2)
    # NOTE(review): `!=` on bytes is not guaranteed to run in constant time.
    if c != cl:
        Kl = Kb

    return Kl


In [121]:
# Demo 1: K-PKE round trip on a fixed 32-byte message.
# NOTE(review): the calls below pass k=3, eta1=2, eta2=2 literally, so they do
# not follow the k / n1 / n2 values derived from sec_bits.
rand = os.urandom(32)
message = b'Este e um exemplo de mensagem !!'
ek_PKE, dk_PKE = k_pke_keygen(3, 2)
# Encrypt the message and show the ciphertext.
print('message:', message)
cipher = k_pke_encrypt(ek_PKE, message, rand, 3, 2, 2)
print(cipher)
# Decrypt and show the recovered message.
message2 = k_pke_decrypt(dk_PKE, cipher, 3)
print('message:', message2)

# Demo 2: ML-KEM key generation, encapsulation and decapsulation.
ek, dk = ml_kem_keygen()
#print('ek:', ek)
#print('dk:', dk)

K, c = ml_kem_encaps(ek)
#print('K:', K)
#print('c:', c)

# NOTE(review): K and Kl are not compared or printed in the visible code.
Kl = ml_kem_decaps(c, dk)

message: b'Este e um exemplo de mensagem !!7'
b'\x1a\xf1xZ\xf9\xcc\x8d6JC\x1f^\xc8\x9b{\xd6\xc6\xf4\xe0\xde\xf5\x19D\xc3\xc0v\xd2\xee\x9b\xb7v\xc2\x1fv>-\xfc\xec\xbfz\xef#\x06C($\xcfv\xda\xc6j\x80\xe6a\xd4\x9cm\xdb\x84QI\x9fT\x8dwm\x9b"\x04w\xc2^\x86\x04]<[7\xb8\xd0\x1b\x03A\x8d\xd4R\x84U\xb8:\x9b\x8c\xf2\x0c<\xf1"\x03\x95\x96\r>U\xb0[\x8e(\xe4\xd6\x04\xf9_\x18\xd1XZ\xab\x87\x9b\xa0\xb7\x88\x91\x0bW \xa4\xaf\xa9\x8cL\xf3\x18\xa4r2Xd\xd1\xd0[\xeb.\xddk\x19Tc\x1fy0\xa8\x08\xc1\x83=\x82\x03\x0e\x1fV\xf6\xebFX\xdcn\x12\xd7E\x94\xf1y-\xf1\xff\xc9\xea\x8e\x073\xcfF\xc6\x04\x1b^\xe0\x9cBh[Ev\xd8\x8a-\xe5\xb5M\x06\x13\xc4\x96?\x01\x8d\xa6\x94\xf4\xa13\x97o^\xdb&\xdc\n\xcf\xfa\x8e\xc5f\x03\xf7\xaeP\xde\xa6\x94 \\,\x8e\x95&\x9c\x7f\xd97t_1\xd5\x07H\xa2\xce\xb1\xc0\xf3\xb4iE\xf0\x97\x94\xd3\xad\xfe$D~\xe6\xc6\xdalP@\x12\xcc\x16\xce\xd5[u\x90X\xa6<n8o2\xe2\xaax(\xfc\x0c\xa29+\xef\xce\xb0\x0e;\x92G*!k\x0b8!\x859\x14\xf0\xc1N\xf8ViU\xc4\t\xb3\x15\xcd\xd6\xc7\x1d\xe9\xbe\xef\x82.\x9b\x8d\xe8K:~}\x10[