# Estruturas Criptográficas - Criptografia e Segurança da Informação

[Grupo 03](https://paper.dropbox.com/doc/Estruturas-Criptograficas-2023-2024-Trabalhos-Praticos-8WcsdZARGLv0nXS9KasmK)

(PG54177) Ricardo Alves Oliveira 

(PG54236) Simão Oliveira Alvim Barroso

## TP4 - Exercício 1

In [865]:
import time
from sage.all import *
import hashlib


In [866]:
def bits_to_integer(y):
    """Interpret a little-endian bit list as a non-negative integer.

    ``y[0]`` is the least significant bit, so ``[1, 0, 1]`` yields 5 and an
    empty list yields 0.
    """
    value = 0
    for bit in reversed(y):
        value = 2 * value + bit
    return value

def bits_to_bytes(y):
    """Pack a little-endian bit list into a list of byte values.

    Bit ``i`` of the input becomes bit ``i % 8`` of output byte ``i // 8``.
    A trailing partial byte is zero-padded in its high bits.
    """
    out = [0] * ((len(y) + 7) // 8)
    for pos, bit in enumerate(y):
        byte_index, shift = divmod(pos, 8)
        out[byte_index] += bit * 2**shift
    return out

def bit_reverse(x, bits):
    """Return the lowest ``bits`` bits of ``x`` in reversed order.

    Bits of ``x`` at positions ``bits`` and above are ignored.
    """
    result = 0
    for shift in range(bits):
        if (x >> shift) & 1:
            result |= 1 << (bits - 1 - shift)
    return result

def bytes_to_bits(z):
    """Unpack byte values into a little-endian bit list (8 bits per byte).

    Byte ``i`` contributes bits ``8*i`` .. ``8*i + 7``, least significant
    first.

    Each byte is shifted in a local variable, so the caller's sequence is
    never modified. Writing back into ``z`` would zero a caller-owned
    buffer, such as the message bytes. It would also fail on immutable
    ``bytes`` input.
    """
    bits = [0] * (len(z) * 8)
    for i, byte in enumerate(z):
        for j in range(8):
            bits[8 * i + j] = byte % 2
            byte = byte // 2
    return bits

def bitlen(x):
    """Return the number of bits needed to write ``abs(x)`` in binary.

    Works for plain Python ``int`` as well as Sage ``Integer``. The previous
    body used ``x.nbits()``, which exists only on Sage integers, so it
    crashed whenever a built-in ``int`` reached it. ``bitlen(0) == 0``.
    """
    return int(x).bit_length()

def simple_bit_pack(w, b):
    """Encode the 256 coefficients of ``w`` (each in [0, b]) as byte values.

    Every coefficient is written with ``bitlen(b)`` bits via
    ``integer_to_bits`` (least significant bit first). The concatenated bit
    string is then packed with ``bits_to_bytes``.

    Bits are accumulated with ``extend``. Rebuilding the list with
    ``z = z + ...`` re-copied the growing list for every coefficient, which
    made packing quadratic in the output size.
    """
    width = bitlen(b)
    z = []
    for i in range(256):
        z.extend(integer_to_bits(w[i], width))
    return bits_to_bytes(z)

def simple_bit_unpack(v, b):
    """Decode 256 coefficients of ``bitlen(b)`` bits each from byte list ``v``.

    This is the inverse of ``simple_bit_pack``. The bytes are expanded to
    little-endian bits, and consecutive ``bitlen(b)``-bit chunks are read as
    integers.
    """
    width = bitlen(b)
    bits = bytes_to_bits(v)
    return [bits_to_integer(bits[width * i:width * (i + 1)]) for i in range(256)]

def bit_pack(w, a, b):
    """Encode the 256 coefficients of ``w`` (each in [-a, b]) as byte values.

    Each coefficient ``w[i]`` is stored as ``b - w[i]``, a value in
    [0, a + b], using ``bitlen(a + b)`` bits. This way only non-negative
    integers are serialized.

    Bits are collected with ``extend`` so packing stays linear.
    ``z = z + ...`` re-copied the whole list on every coefficient.
    """
    width = bitlen(a + b)
    z = []
    for i in range(256):
        z.extend(integer_to_bits(b - w[i], width))
    return bits_to_bytes(z)

def bit_unpack(v, a, b):
    """Decode 256 coefficients that were packed by ``bit_pack(w, a, b)``.

    Each ``bitlen(a + b)``-bit chunk holds ``b - w[i]``. The coefficient is
    recovered as ``b`` minus that stored value.
    """
    width = bitlen(a + b)
    bits = bytes_to_bits(v)
    return [b - bits_to_integer(bits[i * width:(i + 1) * width]) for i in range(256)]

def H1024(input_bytes):
    """Derive 1024 pseudo-random bits from ``input_bytes`` with chained SHA-256.

    The first 32-byte block is SHA-256(input), and each later block is
    SHA-256 of the previous block. Four blocks (128 bytes) are concatenated
    and returned as a list of 0/1 ints.

    Within each byte the most significant bit comes first. This is the
    opposite of the least-significant-first order used by ``bytes_to_bits``.
    """
    block = hashlib.sha256(input_bytes).digest()
    stream = bytearray(block)
    while len(stream) < 128:
        block = hashlib.sha256(block).digest()
        stream.extend(block)
    return [(byte >> shift) & 1 for byte in stream[:128] for shift in range(7, -1, -1)]

def integer_to_bits(x, alpha):
    """Return the ``alpha`` lowest bits of ``x`` as a list, least significant first.

    Each bit is produced as ``Integer(x) % Integer(2)``. Bits of ``x`` above
    position ``alpha - 1`` are dropped.
    """
    bits = []
    for _ in range(alpha):
        bits.append(Integer(x) % Integer(2))
        x = x // 2
    return bits

def coef_from_three_bytes(b0, b1, b2, q):
    """Combine three bytes into a candidate coefficient for rejection sampling.

    ``b2`` has 128 subtracted when it exceeds 127, which for byte inputs
    clears its top bit. The value ``b0 + 2**8 * b1 + 2**16 * b2`` is returned
    if it is below ``q``. Otherwise ``None`` signals that the sample is
    rejected.
    """
    top = b2 - 128 if b2 > 127 else b2
    candidate = 2 ** 16 * top + 2 ** 8 * b1 + b0
    return candidate if candidate < q else None

def coef_from_half_byte(b, n):
    """Map a 4-bit value to a coefficient in [-n, n], or ``None`` to reject it.

    Only ``n`` in {2, 4} is supported:

    * ``n == 2``: values below 15 map to ``2 - (b % 5)``.
    * ``n == 4``: values below 9 map to ``4 - b``.

    Anything else, including other values of ``n``, yields ``None``.
    """
    if n == 2:
        return 2 - (b % 5) if b < 15 else None
    if n == 4:
        return 4 - b if b < 9 else None
    return None
    
def rej_ntt_poly(seed, q):
    """Sample 256 coefficients in [0, q) by rejection from ``H1024(seed)``.

    Consecutive byte triples of the 128-byte hash output are passed to
    ``coef_from_three_bytes``. Rejected triples are skipped.

    NOTE(review): byte offsets wrap modulo 128. Once the 128 hash bytes are
    used up, the same triples are read again, so coefficients repeat. The
    loop would also never end if every wrapped triple were rejected. The
    standard's RejNTTPoly reads an unbounded XOF stream instead (confirm
    against FIPS 204).
    """
    hsh = bits_to_bytes(H1024(bytearray(seed)))
    coeffs = []
    offset = 0
    while len(coeffs) < 256:
        candidate = coef_from_three_bytes(
            hsh[offset % 128], hsh[(offset + 1) % 128], hsh[(offset + 2) % 128], q
        )
        offset += 3
        if candidate is not None:
            coeffs.append(candidate)
    return coeffs

def rej_bounded_poly(seed, q, n):
    """Sample 256 coefficients in [-n, n] by rejection from ``H1024(seed)``.

    Each hash byte supplies two 4-bit candidates, low nibble first. Both are
    mapped through ``coef_from_half_byte``. ``q`` is accepted but unused.

    NOTE(review): byte offsets wrap modulo 128, so hash bytes are reused once
    all 128 of them have been consumed.
    """
    hsh = bits_to_bytes(H1024(bytearray(seed)))
    coeffs = []
    pos = 0
    while len(coeffs) < 256:
        byte = hsh[pos % 128]
        low = coef_from_half_byte(byte % 16, n)
        high = coef_from_half_byte(byte // 16, n)
        if low is not None:
            coeffs.append(low)
        if high is not None and len(coeffs) < 256:
            coeffs.append(high)
        pos += 1
    return coeffs

def expand_a(p, q, k, l):
    """Build the k x l public matrix entries from the seed bits ``p``.

    Entry (r, s) is taken from ``rej_ntt_poly`` applied to ``p``, followed by
    the 8-bit little-endian encodings of ``s`` and then ``r``.

    NOTE(review): only coefficient ``s`` of each sampled polynomial is kept,
    so every entry is a single scalar rather than a 256-coefficient NTT
    polynomial. FIPS 204 ExpandA keeps the whole polynomial. The
    ``polynomial_mul(A_hat[i], ...)`` calls in key generation, signing and
    verification are written around this scalar shape, so a fix has to
    change them together.
    """
    def entry(r, s):
        suffix = bytearray(integer_to_bits(s, 8)) + bytearray(integer_to_bits(r, 8))
        return rej_ntt_poly(bytearray(p) + suffix, q)[s]

    return [[entry(r, s) for s in range(l)] for r in range(k)]

def expand_s(p, l, k, q, n):
    """Sample the secret vectors ``s1`` (``l`` polynomials) and ``s2`` (``k``).

    Polynomial ``r`` of ``s1`` is seeded with ``p`` plus the 16-bit encoding
    of ``r``. Polynomial ``r`` of ``s2`` uses the counter ``l + r``, so the
    two vectors use disjoint counters. Coefficients come from
    ``rej_bounded_poly`` and lie in [-n, n].
    """
    def sample(counter):
        return rej_bounded_poly(p + integer_to_bits(counter, 16), q, n)

    s1 = [sample(r) for r in range(l)]
    s2 = [sample(l + r) for r in range(k)]
    return (s1, s2)

def polynomial_mul(a, b, q):
    """Multiply two coefficient lists entry by entry modulo ``q``.

    This is the pointwise product used on NTT outputs. The result has
    ``len(a)`` entries, and ``b`` must have at least that many.
    """
    return [(coef * b[idx]) % q for idx, coef in enumerate(a)]

def ntt(w, q):
    """Forward number-theoretic transform of a 256-coefficient polynomial mod ``q``.

    Args:
        w: input coefficients; anything indexable by 0..255. Callers pass
            plain lists and Sage polynomials. For polynomials, indices above
            the degree are assumed to read as 0 (confirm).
        q: modulus; every output value is reduced mod ``q``.

    Returns:
        A new list of 256 values in the layout ``ntt_inverse`` is written to
        invert. ``w`` itself is only read.
    """
    # Copy the 256 input coefficients so the butterflies below work in place
    # on a private list.
    w_hat = [0 for _ in range(256)]
    for j in range(256):
        w_hat[j] = w[j]
    size = 256
    # NOTE(review): hard-coded root of unity. Presumably the primitive 512th
    # root of unity for q = 8380417 used by ML-DSA; another q would need
    # another constant -- confirm.
    z=1753
    # zeta_brvs[k] = z ** bit_reverse(k, 8) mod q (rebuilt on every call).
    zeta_brvs = [pow(z, bit_reverse(k, 8), q) for k in range(size)]
    # k is pre-incremented, so twiddles come from index 1 upward. The layers
    # perform 1 + 2 + ... + 128 = 255 increments, so k % size never wraps.
    k = 0
    # Half-block length of the current layer: 128, 64, ..., 1.
    length = size // 2
    while length >= 1:
        start = 0
        while start < size:
            k += 1
            zeta = zeta_brvs[k % size] % q
            # Cooley-Tukey butterfly: (a, b) -> (a + zeta*b, a - zeta*b) mod q.
            for j in range(start, start + length):
                t = zeta * w_hat[j + length] % q
                w_hat[j + length] = (w_hat[j] - t) % q
                w_hat[j] = (w_hat[j] + t) % q
            start += 2 * length
        length //= 2
    return w_hat

def ntt_inverse(w_hat, q):
    """Inverse number-theoretic transform of 256 NTT-domain values mod ``q``.

    Args:
        w_hat: 256 values indexable by 0..255. Callers pass lists and Sage
            polynomials; for polynomials, indices above the degree are
            assumed to read as 0 (confirm).
        q: modulus; every output value is reduced mod ``q``.

    Returns:
        A new list of 256 coefficients. ``w_hat`` itself is only read.
        Uses Sage's ``inverse_mod`` for the final scaling by 256^-1 mod q.
    """
    size = 256
    # Private working copy of the input.
    w = [0 for _ in range(256)]
    for j in range(256):
        w[j] = w_hat[j]
    # NOTE(review): same hard-coded root of unity as ntt(); only meaningful
    # for q = 8380417 -- confirm before using another modulus.
    z=1753
    # zeta_brvs[k] = z ** bit_reverse(k, 8) mod q (rebuilt on every call).
    zeta_brvs = [pow(z, bit_reverse(k, 8), q) for k in range(size)]
    # Walk the twiddle table backwards. k is pre-decremented from 256, so
    # indices 255 down to 1 are used, mirroring the forward transform.
    k = size
    # Half-block length of the current layer: 1, 2, ..., 128.
    length = 1
    while length < size:
        start = 0
        while start < size:
            k -= 1
            # Inverse twiddle: the negated forward twiddle, reduced mod q.
            zeta = (-zeta_brvs[k % size]) % q
            # Gentleman-Sande butterfly: (a, b) -> (a + b, zeta*(a - b)) mod q.
            for j in range(start, start + length):
                t = w[j]
                w[j] = (t + w[j + length]) % q
                w[j + length] = (t - w[j + length]) % q
                w[j + length] = zeta * w[j + length] % q
            start += 2 * length
        length *= 2
    # Scale every coefficient by 256^-1 mod q.
    f = inverse_mod(size, q)
    for j in range(size):
        w[j] = f * w[j] % q
    return w

def mod_plus_minus(m, alpha):
    """Return the representative of ``m`` modulo ``alpha`` that is closest to zero.

    The result ``r`` satisfies ``r == m (mod alpha)``. For even ``alpha`` it
    lies in ``(-alpha/2, alpha/2]``; for odd ``alpha`` it lies in
    ``[-(alpha-1)/2, (alpha-1)/2]``.
    """
    residue = Integer(m) % Integer(alpha)
    return residue - alpha if residue > alpha // 2 else residue

def power2_round(r, q, d):
    """Split ``r mod q`` into a high part and a centered low part.

    Returns ``(r1, r0)`` such that ``r mod q == r1 * 2**d + r0``, where
    ``r0 = mod_plus_minus(r mod q, 2**d)``.
    """
    reduced = Integer(r) % Integer(q)
    scale = 2**d
    low = mod_plus_minus(reduced, scale)
    return ((reduced - low) // scale, low)

def pk_encode(p, t1, q, d, k):
    """Serialize the public key as the packed seed bits ``p`` followed by ``t1``.

    Each of the first ``k`` polynomials of ``t1`` is packed by
    ``simple_bit_pack`` with ``bitlen(q - 1) - d`` bits per coefficient.
    """
    encoded = bits_to_bytes(p)
    bound = (2 ** (bitlen(q - 1) - d)) - 1
    for idx in range(k):
        encoded += simple_bit_pack(t1[idx], bound)
    return encoded

def pk_decode(pk, q, d, k):
    """Parse a public key produced by ``pk_encode``.

    Returns ``(p, t)``:

    * ``p`` is the 256-bit seed stored in the first 32 bytes.
    * ``t`` holds ``k`` polynomials whose coefficients are
      ``bitlen(q - 1) - d`` bits wide.
    """
    head = pk[:32]
    width = bitlen(q - 1) - d
    chunk = 32 * width
    bodies = [pk[32 + i * chunk:32 + (i + 1) * chunk] for i in range(k)]
    p = bytes_to_bits(head)
    bound = (2 ** width) - 1
    t = [simple_bit_unpack(body, bound) for body in bodies]
    return p, t

def sk_encode(p, K, tr, s1, s2, t0, d, n):
    """Serialize the secret key as a list of byte values.

    Layout, in order:

    * the packed bits of ``p``, ``K`` and ``tr``;
    * every polynomial of ``s1`` and then ``s2``, packed with coefficients
      in [-n, n];
    * every polynomial of ``t0``, packed with coefficients in
      [-(2**(d-1) - 1), 2**(d-1)].
    """
    half = 2**(d - 1)
    encoded = bits_to_bytes(p) + bits_to_bytes(K) + bits_to_bytes(tr)
    for poly in s1:
        encoded += bit_pack(poly, n, n)
    for poly in s2:
        encoded += bit_pack(poly, n, n)
    for poly in t0:
        encoded += bit_pack(poly, half - 1, half)
    return encoded

def sk_decode(sk, d, n, l, k):
    """Parse a secret key produced by ``sk_encode``.

    Returns ``(p, K, tr, s1, s2, t0)``:

    * ``p``, ``K`` and ``tr`` are bit lists from bytes 0-31, 32-63 and
      64-127.
    * ``s1`` (``l`` polynomials) and ``s2`` (``k`` polynomials) follow, with
      ``bitlen(2*n)``-bit coefficients.
    * ``t0`` (``k`` polynomials) comes last, with ``d``-bit coefficients.
    """
    s_len = 32 * bitlen(2 * n)
    t_len = 32 * d
    s1_start = 128
    s2_start = s1_start + l * s_len
    t0_start = s1_start + (l + k) * s_len

    def chunks(start, size, count):
        # Consecutive byte slices of equal size starting at ``start``.
        return [sk[start + i * size:start + (i + 1) * size] for i in range(count)]

    s1_bytes = chunks(s1_start, s_len, l)
    s2_bytes = chunks(s2_start, s_len, k)
    t0_bytes = chunks(t0_start, t_len, k)
    half = 2**(d - 1)
    p = bytes_to_bits(sk[:32])
    K = bytes_to_bits(sk[32:64])
    tr = bytes_to_bits(sk[64:128])
    s1 = [bit_unpack(chunk, n, n) for chunk in s1_bytes]
    s2 = [bit_unpack(chunk, n, n) for chunk in s2_bytes]
    t0 = [bit_unpack(chunk, half - 1, half) for chunk in t0_bytes]
    return p, K, tr, s1, s2, t0

In [867]:
def ML_DSA_KeyGen(k, l, q, d, n):
    """Generate a key pair and print every intermediate value.

    Args:
        k, l: dimensions of the public matrix (k rows, l columns).
        q: modulus.
        d: number of low-order bits of ``t`` split off by ``power2_round``
            into ``t0``, which goes into the secret key.
        n: secret coefficient bound. Secrets lie in [-n, n]; only 2 and 4 are
            supported by ``coef_from_half_byte``.

    Returns:
        ``(pk, sk)`` as lists of byte values from ``pk_encode`` and
        ``sk_encode``.

    The 256-bit seed ``eps`` is drawn from ``secrets`` (the OS CSPRNG). Key
    material must be unpredictable, and ``randint`` (Python/Sage
    pseudo-random generator) is not intended for cryptographic use.

    NOTE(review): the product ``A_hat * s1`` below pairs row ``i`` of
    ``A_hat`` with ``s1[i]`` and keeps only ``len(A_hat[i]) == l`` entries,
    because ``expand_a`` returns scalar entries. It also assumes ``k == l``:
    ``t`` gets ``l`` rows but is indexed up to ``k``. This is not the FIPS
    204 matrix-vector product. Note also that secret values such as
    ``eps``, ``K``, ``s1`` and ``s2`` are printed.
    """
    import secrets

    eps = [secrets.randbits(1) for _ in range(256)]
    print('eps:',eps)

    H_output = H1024(bytearray(eps))
    p = H_output[:256]
    p_ = H_output[256:768]
    K = H_output[768:]
    print('p:',p)
    print('p_:',p_)
    print('k:',K)

    A_hat = expand_a(p, q, k, l)
    print('A_hat:',A_hat)

    s1, s2 = expand_s(p_, l, k, q, n)
    print('s1:',s1)
    print('s2:',s2)

    ring = PolynomialRing(Zmod(q), 'x')
    A_NTT_s1 = []
    for i in range(l):
        ntt_s1 = ntt(ring(s1[i]), q)
        A_NTT_s1.append(polynomial_mul(A_hat[i], ntt_s1, q))
    t = [ntt_inverse(ring(A_NTT_s1_row), q) for A_NTT_s1_row in A_NTT_s1]
    for i in range(k):
        for j in range(256):
            t[i][j] = (t[i][j] + s2[i][j])
    print('t:',t)

    # Split every coefficient of t into high (t1, public) and low (t0,
    # secret) parts.
    t1, t0 = [], []
    for tt in t:
        tt1, tt0 = [], []
        for ti in tt:
            t1i, t0i = power2_round(ti, q, d)
            tt1.append(t1i)
            tt0.append(t0i)
        t1.append(tt1)
        t0.append(tt0)
    print('t1:',t1)
    print('t0:',t0)

    pk = pk_encode(p, t1, q, d, k)
    print('pk:',pk)

    tr = H1024(bytearray(pk))[:512]
    print('tr:',tr)

    sk = sk_encode(p, K, tr, s1, s2, t0, d, n)
    print('sk:',sk)

    return pk, sk

In [868]:
# Parameters (values follow the ML-DSA-44 parameter set of FIPS 204 Table 1).
Tq = 128            # security parameter (lambda); the challenge hash c_til is 2*Tq bits
tau = 39            # number of +/-1 coefficients in the challenge polynomial c
k = 4               # rows of the public matrix; length of s2, t, t1, t0 and h
l = 4               # columns of the public matrix; length of s1, y and z
q = 8380417         # modulus
d = 13              # bits of t moved into t0 by power2_round
n = 2               # secret coefficient bound (eta): s1, s2 lie in [-n, n]
gamma_1 = 2**17     # masking vector y has coefficients in [-(gamma_1 - 1), gamma_1]
gamma_2 = (q - 1) // 88  # low-order rounding range used by decompose/high_bits
omega = 80          # maximum number of 1s allowed in the hint h
# beta = tau * eta bounds every coefficient of c*s1 and c*s2 (c has tau
# coefficients of +/-1, and s1, s2 have coefficients in [-n, n]). For these
# parameters FIPS 204 sets it to 78. The previously hard-coded 79 was off by
# one and made both signing and verification stricter than the standard.
beta = tau * n

pk, sk = ML_DSA_KeyGen(k, l, q, d, n)
print('pk:',pk)
print('sk:',sk)

[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1]
p: [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0,

In [869]:
def decompose(r, q, y2):
    """Split ``r mod q`` into high bits ``r1`` and centered low bits ``r0``.

    Normally ``r mod q == r1 * (2*y2) + r0`` with
    ``r0 = mod_plus_minus(r mod q, 2*y2)``. In the corner case
    ``r mod q - r0 == q - 1``, the result is folded to ``(0, r0 - 1)``.
    Otherwise ``r1`` would equal ``(q - 1) / (2*y2)``.
    """
    reduced = Integer(r) % Integer(q)
    alpha = 2 * y2
    low = mod_plus_minus(reduced, alpha)
    if reduced - low != q - 1:
        return ((reduced - low) // alpha, low)
    return (0, low - 1)

def high_bits(r, q, y2):
    """Return the high part ``r1`` of ``decompose(r, q, y2)``."""
    return decompose(r, q, y2)[0]

def low_bits(r, q, y2):
    """Return the centered low part ``r0`` of ``decompose(r, q, y2)``."""
    return decompose(r, q, y2)[1]

def expand_mask(seed, mu, l, gamma_1):
    """Derive the masking vector ``y`` (``l`` polynomials) from ``seed`` and counter ``mu``.

    Polynomial ``r`` is read from ``H1024(seed || 16-bit(mu + r))``. Let
    ``c = 1 + bitlen(gamma_1 - 1)`` be the bits per coefficient. Then
    ``32*c`` bytes are taken starting at byte offset ``32*r*c`` and unpacked
    into coefficients in [-(gamma_1 - 1), gamma_1].

    NOTE(review): the hash yields only 128 bytes, so offsets wrap modulo
    128. Whenever ``32*c`` exceeds 128 (e.g. ``c = 18`` for
    ``gamma_1 = 2**17``), the same bytes are reused within one polynomial.
    """
    c = 1 + bitlen(gamma_1 - 1)
    width = 32 * c
    vectors = []
    for r in range(l):
        counter_bits = integer_to_bits(mu + r, 16)
        stream = bits_to_bytes(H1024(bytearray(seed + counter_bits)))
        start = width * r
        window = [stream[(start + i) % 128] for i in range(width)]
        vectors.append(bit_unpack(window, gamma_1 - 1, gamma_1))
    return vectors

def w1_encode(w1, k, q, y2):
    """Serialize the first ``k`` polynomials of ``w1`` into one bit list.

    Each polynomial is packed by ``simple_bit_pack`` with bound
    ``(q - 1) / (2*y2) - 1``, and the bytes are expanded back to bits with
    ``bytes_to_bits``. The per-polynomial bit strings are concatenated.
    """
    bits = []
    for idx in range(k):
        packed = simple_bit_pack(w1[idx], Integer(((q - 1) / (2 * y2)) - 1))
        bits.extend(bytes_to_bits(packed))
    return bits

def sample_in_ball(seed, tau):
    """Sample a challenge polynomial with exactly ``tau`` coefficients equal to +/-1.

    Works on ``H1024(seed)``:

    * Bytes from offset 8 onward choose positions by rejection: a byte is
      skipped while it exceeds the current index ``i``.
    * The chosen slot is swapped into position ``i``.
    * Bit ``i + tau - 256`` of the hash output picks the sign.

    The hash output does not depend on the loop, so it and its byte packing
    are computed once. Recomputing them on every probe ran several SHA-256
    chains per coefficient and produced the same values.

    NOTE(review): byte offsets wrap modulo 128. If more than 120 probes are
    needed, the bytes 0-7 that also provide the sign bits get reused.
    """
    hash_bits = H1024(bytearray(seed))
    hash_bytes = bits_to_bytes(hash_bits)
    c = [0 for _ in range(256)]
    k = 8
    for i in range(256 - tau, 256):
        while hash_bytes[k % 128] > i:
            k += 1
        j = hash_bytes[k % 128]
        c[i] = c[j]
        c[j] = (-1) ** (hash_bits[i + tau - 256])
        k += 1
    return c

def make_hint(z, r, q, y2):
    """Return 1 when adding ``z`` to ``r`` changes its high bits, otherwise 0."""
    changed = high_bits(r + z, q, y2) != high_bits(r, q, y2)
    return 1 if changed else 0

def use_hint(h, r, q, gamma_2):
    """Recover the high bits of ``r``, corrected by hint bit ``h``.

    If ``h == 1``, the high part moves one step modulo
    ``m = (q - 1) // (2*gamma_2)``: up when the low part is positive, down
    otherwise. For any other ``h`` it is returned unchanged.
    """
    m = (q - 1) // (2 * gamma_2)
    r1, r0 = decompose(r, q, gamma_2)
    if h != 1:
        return r1
    step = 1 if r0 > 0 else -1
    return (r1 + step) % m

def hint_bit_pack(h, omega, k):
    """Encode the hint ``h`` (``k`` rows of 256 entries) into ``omega + k`` slots.

    The first ``omega`` slots list, row by row, the column indices of the
    non-zero entries of ``h``. Slot ``omega + i`` stores how many indices
    have been written once row ``i`` is done. Unused index slots stay 0.
    """
    packed = [0] * (omega + k)
    count = 0
    for row in range(k):
        nonzero = [col for col in range(256) if h[row][col] != 0]
        for col in nonzero:
            packed[count] = col
            count += 1
        packed[omega + row] = count
    return packed

def sig_encode(c_til, z, h, l, y1, omega, K):
    """Serialize a signature ``(c_til, z, h)`` as a list of byte values.

    Layout: the packed bits of ``c_til``, then each of the ``l`` polynomials
    of ``z`` packed with coefficients in [-(y1 - 1), y1], then the packed
    hint (``omega + K`` entries).

    Args:
        K: number of hint rows. It is now passed to ``hint_bit_pack``. The
            old body ignored it and read a global ``k``, so it only worked
            when such a global existed and happened to match.
    """
    o = bits_to_bytes(c_til)
    for i in range(l):
        o = o + bit_pack(z[i], y1 - 1, y1)
    return o + hint_bit_pack(h, omega, K)


def hint_bit_unpack(y, k, omega):
    """Decode the hint vector written by ``hint_bit_pack``.

    Returns ``k`` rows of 256 entries (1 at each listed index, 0 elsewhere).
    Returns ``None`` when ``y`` is malformed:

    * ``y`` has fewer than ``omega + k`` entries (previously an IndexError);
    * a row counter ``y[omega + i]`` decreases or exceeds ``omega``;
    * a listed index is outside 0..255 (negative values previously wrapped
      silently);
    * the indices within one row are not strictly increasing. FIPS 204
      requires this so each hint has exactly one encoding; otherwise
      reordered or duplicated indices decode to the same ``h``, making
      signatures malleable;
    * an unused index slot is non-zero.
    """
    if len(y) < omega + k:
        return None
    h = [[0] * 256 for _ in range(k)]
    index = 0
    for i in range(k):
        end = y[omega + i]
        if end < index or end > omega:
            return None
        first = index
        while index < end:
            col = y[index]
            if not 0 <= col < 256:
                return None
            if index > first and y[index - 1] >= col:
                return None
            h[i][col] = 1
            index += 1
    if any(y[pos] != 0 for pos in range(index, omega)):
        return None
    return h

def sig_decode(o, y1, lbd, l, k, omega):
    """Parse a signature produced by ``sig_encode``.

    The layout is:

    * the first ``lbd // 4`` bytes hold ``c_til``;
    * ``l`` chunks of ``32 * (1 + bitlen(y1 - 1))`` bytes hold ``z``;
    * the remaining bytes are the packed hint.

    Returns ``(c_til, z, h)``. ``h`` is ``None`` when ``hint_bit_unpack``
    rejects the hint encoding.
    """
    c_len = lbd // 4
    z_len = 32 * (1 + bitlen(y1 - 1))
    c_part = o[:c_len]
    z_parts = [o[c_len + i * z_len:c_len + (i + 1) * z_len] for i in range(l)]
    hint_part = o[c_len + l * z_len:]
    c_til = bytes_to_bits(c_part)
    z = [bit_unpack(part, y1 - 1, y1) for part in z_parts]
    h = hint_bit_unpack(hint_part, k, omega)
    return c_til, z, h


In [870]:
def ML_DSA_Sign(sk, M, Tq, q, d, n, l, k, tau, gamma_1, gamma_2, omega, beta):
    """Sign the bit list ``M`` with the encoded secret key ``sk``.

    Steps:

    1. Decode ``sk`` and derive ``mu = H1024(tr || M)[:512]``.
    2. Derive the per-signature seed ``rho_prime = H1024(K || rnd || mu)[:512]``.
    3. Repeat the mask/challenge/rejection loop, advancing the mask counter
       ``kappa`` by ``l`` each time, until a candidate ``(z, h)`` passes
       every bound.

    Intermediate values are printed. The encoded signature is decoded again
    at the end as a printed round-trip self-check.

    Returns:
        The signature as a list of byte values from ``sig_encode``.

    Rejection compares infinity norms (largest absolute coefficient) of
    ``z``, ``r0`` and ``c*t0`` against their bounds. A plain ``max`` only
    looked at the most positive coefficient, so candidates with large
    negative coefficients slipped past the checks meant to exclude them.
    The 256 random bits ``rnd`` come from ``secrets`` (OS CSPRNG), because
    ``randint`` is not intended for cryptographic use.

    NOTE(review): ``A_hat * y`` pairs row ``i`` of ``A_hat`` with ``y[i]``
    and keeps only ``len(A_hat[i]) == l`` NTT entries, because ``expand_a``
    returns scalar entries. This is not the FIPS 204 matrix-vector product,
    and it assumes ``k == l``.
    """
    import secrets

    def inf_norm(polys):
        # Largest absolute coefficient over all polynomials (||.||_inf).
        return max(abs(coef) for poly in polys for coef in poly)

    rho, K, tr, s1, s2, t0 = sk_decode(sk, d, n, l, k)
    print('M:',M)
    print('rho:',rho)
    print('K:',K)
    print('tr:',tr)
    print('s1:',s1)
    print('s2:',s2)
    print('t0:',t0)

    mu = H1024(bytearray(tr+M))[:512]
    print('mu:',mu)

    ring = PolynomialRing(Zmod(q), 'x')
    s1_hat = [ntt(poly, q) for poly in s1]
    s2_hat = [ntt(poly, q) for poly in s2]
    t0_hat = [ntt(poly, q) for poly in t0]
    A_hat = expand_a(rho, q, k, l)
    rnd = [secrets.randbits(1) for _ in range(256)]
    rho_prime = H1024(bytearray(K + rnd + mu))[:512]
    kappa = 0
    z, zz, h = None, None, None
    while z is None or h is None:
        y = expand_mask(rho_prime, kappa, l, gamma_1)
        A_NTT = []
        for i in range(l):
            ntt_y = ntt(ring(y[i]), q)
            A_NTT.append(polynomial_mul(A_hat[i], ntt_y, q))
        w = [ntt_inverse(ring(A_NTT_row), q) for A_NTT_row in A_NTT]
        w1 = [[high_bits(coef, q, gamma_2) for coef in poly] for poly in w]
        c_til = H1024(bytearray(mu + w1_encode(w1, k, q, gamma_2)))[:(2*Tq)]
        c_til_1 = c_til[:256]

        c = sample_in_ball(c_til_1, tau)
        print('c:',c)

        c_hat = ntt(ring(c), q)
        cs1_hat = [ntt_inverse(ring(polynomial_mul(c_hat, s_hat, q)), q) for s_hat in s1_hat]
        cs2_hat = [ntt_inverse(ring(polynomial_mul(c_hat, s_hat, q)), q) for s_hat in s2_hat]
        z = [[Integer(y[i][j]) + Integer(cs1_hat[i][j]) for j in range(len(cs1_hat[i]))]
             for i in range(len(cs1_hat))]
        r0 = [[low_bits(Integer(w[i][j]) - Integer(cs2_hat[i][j]), q, gamma_2) for j in range(len(cs2_hat[i]))]
              for i in range(len(cs2_hat))]
        zz = [[mod_plus_minus(ze, q) for ze in zl] for zl in z]

        if inf_norm(zz) >= gamma_1 - beta or inf_norm(r0) >= gamma_2 - beta:
            z, h = None, None
        else:
            ct0_hat = [ntt_inverse(ring(polynomial_mul(c_hat, t_hat, q)), q) for t_hat in t0_hat]
            h = [[make_hint(Integer(-1) * Integer(ct0_hat[row][col]),
                            Integer(w[row][col]) - Integer(cs2_hat[row][col]) + Integer(ct0_hat[row][col]),
                            q, gamma_2)
                  for col in range(len(w[row]))]
                 for row in range(len(w))]
            ct0_pm = [[mod_plus_minus(coef, q) for coef in poly] for poly in ct0_hat]
            if inf_norm(ct0_pm) >= gamma_2 or sum(map(sum, h)) > omega:
                z, h = None, None
        kappa += l
    sigma = sig_encode(c_til, zz, h, l, gamma_1, omega, k)

    # Round-trip self-check: decode the signature again and print comparisons.
    ctilt,zt,ht = sig_decode(sigma, gamma_1, Tq, l, k, omega)
    print('ctilt:',ctilt)
    print('ctil:',c_til)
    print('ctil==ctilt:',c_til==ctilt)
    print('zt:',zt)
    print('z:',zz)
    print('z==zt:',zz==zt)
    print('ht:',ht)
    print('h:',h)
    print('h==ht:',h==ht)

    return sigma

In [871]:
# Step 2: Choose a message
M = "This is a test message."
# The message is signed as a bit list: its UTF-8 bytes (str.encode default)
# expanded by bytes_to_bits, least significant bit of each byte first.
M_bytes = bytearray(M.encode())
M_bits = bytes_to_bits(M_bytes)

# Step 3: Sign the message
sign = ML_DSA_Sign(sk, M_bits, Tq, q, d, n, l, k, tau, gamma_1, gamma_2, omega, beta)
print('sign:',sign)
# Expected size in bytes:
#   - 32 for c_til (2*Tq bits with Tq = 128);
#   - l polynomials of 32*(1 + bitlen(gamma_1 - 1)) bytes each for z;
#   - omega + k entries for the packed hint.
# The leading 32 is hard-coded and only matches Tq = 128 (sig_decode reads
# Tq // 4 bytes).
print('sign_expected_len:',32+32*l*(1+bitlen(gamma_1-1))+k+omega)
print('sign_len:',len(sign))


M: [0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0]
rho: [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 

In [876]:
def ML_DSA_Verify(pk, M, sigma, q, tau, gamma_1, gamma_2, omega, beta, Tq, k, l, d):
    """Check the signature ``sigma`` on the bit list ``M`` under public key ``pk``.

    Steps:

    1. Decode ``pk`` and ``sigma``.
    2. Recompute ``mu = H1024(H1024(pk)[:512] || M)[:512]`` and the challenge
       ``c`` from the signature's ``c_til``.
    3. Compute ``w1' = use_hint(h, A*z - c*t1*2**d)`` and hash
       ``mu || w1_encode(w1')`` into ``c_til_``.

    Intermediate values are printed.

    Returns:
        ``True`` only if all of the following hold:

        * the recomputed ``c_til_`` equals the signature's ``c_til``;
        * every coefficient of ``z`` has absolute value below
          ``gamma_1 - beta``;
        * ``h`` holds at most ``omega`` ones.

        ``False`` otherwise, including when the hint encoding is rejected.

    Bug fixes:

    * The final check used to compare ``c_til`` with its own prefix
      ``c_til[:256]``, which is always true for a 256-bit challenge. The
      recomputed challenge was therefore never used, and any in-bounds ``z``
      was accepted for any message or key.
    * The ``z`` bound used a plain ``max`` instead of the infinity norm.

    NOTE(review): ``A*z`` pairs row ``i`` of ``A_hat`` with ``z[i]`` and keeps
    only ``len(A_hat[i]) == l`` NTT entries (``expand_a`` returns scalars).
    This is not the FIPS 204 matrix-vector product, and it assumes
    ``k == l``. With the correct challenge check, signatures from the
    matching signer may therefore fail to verify until that is reworked
    together with key generation and signing.
    """
    def inf_norm(polys):
        # Largest absolute coefficient over all polynomials (||.||_inf).
        return max(abs(coef) for poly in polys for coef in poly)

    rho, t1 = pk_decode(pk, q, d, k)
    print('rho:',rho)
    print('t1:',t1)

    c_til, z, h = sig_decode(sigma, gamma_1, Tq, l, k, omega)
    ctil_ = c_til[:256]
    print('c_til:',c_til)
    print('z:',z)
    print('h:',h)

    if h is None:
        return False

    A_hat = expand_a(rho, q, k, l)
    print('A_hat:',A_hat)

    tr = H1024(bytearray(pk))[:512]
    print('tr:',tr)

    mu = H1024(bytearray(tr + M))[:512]
    print('mu:',mu)

    c = sample_in_ball(ctil_, tau)
    print('c:',c)

    ring = PolynomialRing(Zmod(q), 'x')
    ntt_z = [ntt(ring(poly), q) for poly in z]
    A_NTT = [polynomial_mul(A_hat[i], ntt_z[i], q) for i in range(l)]
    ntt_c = ntt(ring(c), q)
    ntt_t1 = [ntt(ring([x*(2**d) for x in poly]), q) for poly in t1]
    t1_c = [polynomial_mul(ntt_t1[i], ntt_c, q) for i in range(len(t1))]
    fntt = [[(Integer(A_NTT[i][j]) - Integer(t1_c[i][j])) % q for j in range(len(A_NTT[i]))]
            for i in range(len(A_NTT))]
    waprox = [ntt_inverse(ring(poly), q) for poly in fntt]
    w_prime = [[use_hint(Integer(h[i][j]) % q, Integer(waprox[i][j]), q, gamma_2) for j in range(len(waprox[i]))]
               for i in range(len(waprox))]
    c_til_ = H1024(bytearray(mu + w1_encode(w_prime, k, q, gamma_2)))[:(2*Tq)]
    if len(c_til_) != len(c_til):
        return False
    zz = [[mod_plus_minus(ze, q) for ze in zl] for zl in z]

    return c_til_ == c_til and inf_norm(zz) < gamma_1 - beta and sum(map(sum, h)) <= omega


In [877]:
# Step 4: Verify the signature
is_valid = ML_DSA_Verify(pk, M_bits, sign, q, tau, gamma_1, gamma_2, omega, beta, Tq, k, l, d)

# Step 5: Report the outcome
print("Signature is valid." if is_valid else "Signature is invalid.")


rho: [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1]
t1: [[448, 670, 229, 303, 598, 758, 778, 653, 959, 841, 247, 57, 477, 25, 1016, 286, 328, 905, 1017, 762, 656, 118, 829, 623, 40, 990, 800, 6, 144, 245, 348, 817, 420, 437, 35, 918, 189, 152, 152, 776, 219, 400, 524, 479, 843,