In [1]:
# Prerequisites

import hashlib
import random
import secrets
import time

# Bit length of the random integers alpha, sd and z.
security_bits = 128

# Algebraic setting. 'a' is only the printed name of F's generator;
# the Python name `a` below refers to the polynomial-ring variables.
F = GF(2**8, name = 'a')           # base field GF(2^8)
v = 2                              # number of polynomials in the central map
l = 3 * v                          # number of variables over GF(256)
P = PolynomialRing(F, names=[f'x{i + 1}' for i in range(l)])
a = P.gens()                       # a[0], ..., a[l-1] are x1, ..., xl

# Output lengths, in bytes, requested from hashFunction.
hash_len_1 = l                     # one coefficient per variable (the secret s)
hash_len_2 = v*l*(l + 3)//2        # v polynomials with l*(l+3)/2 coefficients each
hash_len_3 = v                     # one field element per central-map polynomial

Message = b'Hello!'  # byte string of the message to be signed
pretty_print(Message)

In [2]:
# Helper Functions

def bytes_to_vector(byte_string):
    """Decode a hexadecimal string into a vector over F = GF(2^8).

    Despite its name, ``byte_string`` is a hex *string* (for example, the
    output of ``hashFunction``). Each pair of hex characters is read as an
    integer in [0, 255] and mapped to a field element with
    ``F.from_integer``. The result has ``len(byte_string) // 2`` entries.
    """
    out = vector(F, [0] * (len(byte_string) // 2))
    for pos in range(0, len(byte_string), 2):
        out[pos // 2] = F.from_integer(int(byte_string[pos:pos + 2], 16))
    return out

def bytes_to_list(byte_string):
    """Decode a hex string into a plain Python list of GF(2^8) elements.

    The decoding is the same as ``bytes_to_vector``; only the container
    type differs.
    """
    decoded = bytes_to_vector(byte_string)
    return decoded.list()

def bits_to_bytes(bits, bits_length):
    """Encode a non-negative integer as a big-endian byte string.

    Args:
        bits: the integer to encode.
        bits_length: nominal bit width of ``bits``. The output is
            ``ceil(bits_length / 8)`` bytes long (16 bytes for 128 bits).

    Returns:
        The big-endian encoding of ``bits``, zero-padded on the left.

    Raises:
        OverflowError: if ``bits`` is negative or does not fit in that many bytes.
    """
    # Round up so that widths which are not a multiple of 8 (e.g. 12 bits)
    # still get enough bytes instead of overflowing.
    return int(bits).to_bytes((bits_length + 7) // 8, byteorder = 'big')

def list_to_bytes(in_list):
    """Serialize a sequence of GF(2^8) elements to bytes, one byte per element.

    Each element's ``to_integer()`` value is written as a single big-endian
    byte, and the bytes are concatenated in input order.
    """
    return b''.join(elem.to_integer().to_bytes(1, byteorder = 'big') for elem in in_list)

def vector_to_bytes(vector):
    """Serialize a vector of GF(2^8) elements to bytes, one byte per entry.

    Each entry's ``to_integer()`` value becomes one big-endian byte, in
    order. (The parameter name shadows Sage's ``vector`` only inside
    this function.)
    """
    out = bytearray()
    for elem in vector:
        out += elem.to_integer().to_bytes(1, byteorder = 'big')
    return bytes(out)

def random_bits(bit_length):
    """Return a uniformly random non-negative integer below ``2**bit_length``.

    The results are used as secret key material (``z``) and signing
    randomness (``alpha``, ``sd``). They are therefore drawn from the OS
    CSPRNG via ``secrets``. ``random.getrandbits`` is a Mersenne Twister:
    it is seedable, and its output is predictable from observed samples.
    """
    # int() guards against Sage Integer arguments.
    return secrets.randbits(int(bit_length))

def hashFunction(M, output_length):
    """Hash ``M`` with SHAKE128 and return ``output_length`` output bytes as hex.

    Args:
        M: bytes-like input message.
        output_length: number of output *bytes*. The returned string has
            ``2 * output_length`` hexadecimal characters.
    """
    return hashlib.shake_128(M).hexdigest(output_length)

def computePolarForm(poly_system, u, v):
    """Evaluate the polar form of a polynomial system at the pair (u, v).

    Returns a vector over F whose i-th entry is
    ``f_i(u + v) - f_i(u) - f_i(v)`` for the i-th polynomial ``f_i`` of
    ``poly_system``. The quadratic systems in this notebook have no
    constant term, so this is the bilinear form G(u, v) used by the
    identification protocol.

    Note: the parameter ``v`` shadows the global polynomial count ``v``
    inside this function only.
    """
    def evaluate_at(point):
        # Substitute point[j] for the ring variable a[j].
        assignment = {a[j]: point[j] for j in range(len(point))}
        return [poly.subs(assignment) for poly in poly_system]

    u_plus_v = vector(F, [F(u[i] + v[i]) for i in range(len(v))])

    at_sum = evaluate_at(u_plus_v)
    at_u = evaluate_at(u)
    at_v = evaluate_at(v)

    return vector(F, [F(fs - fu - fv) for fs, fu, fv in zip(at_sum, at_u, at_v)])

In [3]:
# PRIVATE KEYGEN
#
# The private key consists of:
#   * Nu          - a random invertible l x l matrix over F (secret linear transformation),
#   * central_map - v quadratic polynomials in which the last v variables ("oil") never
#                   multiply each other. Fixing the first l - v ("vinegar") variables
#                   therefore leaves a system that is linear in the oil variables,
#   * z           - a random security_bits-bit integer.

keygen_s_time = time.time()

# Generate the Nu linear transformation: resample until invertible.
while True:
    Nu = random_matrix(F, l, l)
    if Nu.is_invertible():
        break

# Generate the central map of multivariate quadratic polynomials.
central_map = []
for _ in range(v):

    poly = P.zero()

    # Quadratic terms a[j]*a[k] with j < l - v: every monomial contains at least
    # one vinegar variable, so no oil*oil products appear.
    for j in range(l - v):
        for k in range(j, l):
            poly += F.random_element() * a[j] * a[k]

    # Linear terms in every variable.
    for j in range(l):
        poly += F.random_element() * a[j]

    # No constant term is added, so computePolarForm yields the bilinear form directly.
    central_map.append(poly)

# Generate z.
z = random_bits(security_bits)

In [4]:
# PUBLIC KEYGEN
#
# The public key is the central map composed with Nu:
#     trans_central_map[i](x) = central_map[i](Nu * x)

# transformation_nu[row] = sum over col of Nu[row, col] * a[col], i.e. the row-th entry of Nu * a.
transformation_nu = []
for row in range(l):
    entry = P(0)
    for col in range(l):
        entry += Nu[row, col] * a[col]
    transformation_nu.append(entry)

# Substitute every variable by the corresponding linear form.
substitutions = {a[j]: transformation_nu[j] for j in range(l)}
trans_central_map = [poly.subs(substitutions) for poly in central_map]


keygen_e_time = time.time()

pretty_print("Keygen Time (in seconds): ", keygen_e_time - keygen_s_time)

In [5]:
# SIGNING

sign_s_time = time.time()

alpha = random_bits(security_bits)
sd = random_bits(security_bits)
alpha_bytes = bits_to_bytes(alpha, security_bits)
sd_bytes = bits_to_bytes(sd, security_bits)

s = hashFunction(alpha_bytes, hash_len_1)
temp_hash_1 = hashFunction(sd_bytes, hash_len_2)
temp_hash_2 = hashFunction(Message + sd_bytes, hash_len_2)

d_prime = []
poly_size = l*(l+3)//2

for i in range(v):
    counter = 0
    poly = P.zero()
    for j in range(l):
        for k in range(j, l):
            byte_val = temp_hash_1[counter + i * poly_size]
            elem = int(byte_val, 16)
            poly += F.from_integer(elem) * a[j] * a[k]
            counter += 1

    for j in range(l):
        byte_val = temp_hash_1[counter + i * poly_size]
        elem = int(byte_val, 16)
        poly += F.from_integer(elem) * a[j]
        counter += 1

    d_prime.append(poly)

d_prime_m = []

for i in range(v):
    counter = 0
    poly = P.zero()
    for j in range(l):
        for k in range(j, l):
            byte_val = temp_hash_2[counter + i * poly_size]
            elem = int(byte_val, 16)
            poly += F.from_integer(elem) * a[j] * a[k]
            counter += 1

    for j in range(l):
        byte_val = temp_hash_2[counter + i * poly_size]
        elem = int(byte_val, 16)
        poly += F.from_integer(elem) * a[j]
        counter += 1

    d_prime_m.append(poly)


temp_priv_key_s = bytes_to_list(s)


y = [poly(*temp_priv_key_s) for poly in d_prime]
sigma_1 = [poly(*temp_priv_key_s) for poly in d_prime_m]
sigma_2 = (y, alpha ^^ z, sd)

sigma_1_bytes = list_to_bytes(sigma_1)

y_bytes = list_to_bytes(y)
alpha_z_bytes = bits_to_bytes(alpha ^^ z, security_bits)
sigma_2_bytes = b''.join([y_bytes, alpha_z_bytes, sd_bytes])

hashed_vector = bytes_to_vector(hashFunction(sigma_1_bytes + sigma_2_bytes, hash_len_3))

solved = False

while not solved:
    # Generate random elements for l - v variables
    l_minus_v = [F.random_element() for i in range(l - v)]

    # Substitute to form a linear system
    linear_eqs = [poly.subs({a[j]: l_minus_v[j] for j in range(l - v)}) for poly in central_map]
    
    # Construct matrix C with coefficients of a[l - v + i] for i in range(v)
    C = Matrix(F, v, v, [[eq.coefficient(a[l - v + i]) for i in range(v)] for eq in linear_eqs])
    d = vector(F, [hashed_vector[i] - eq.constant_coefficient() for i, eq in enumerate(linear_eqs)])
    
    # Solve the linear system Cx = d if rank condition is met
    if C.rank() == v:
        try:
            solution = C.solve_right(d)
            solved = True
        except Exception:
            solved = False

sol_vector = vector(F, l_minus_v + list(solution))
sigma_3 = Nu.inverse() * sol_vector

sign_e_time = time.time()

pretty_print("Sigma1:", sigma_1)
pretty_print("Sigma2:", sigma_2)
pretty_print("Sigma3:", sigma_3)
pretty_print("Signing Time (in seconds): ", sign_e_time - sign_s_time)

In [6]:
# VERIFICATION STEP 1
#
# The public map evaluated at sigma_3 must equal H(sigma_1 || sigma_2). The digest is
# recomputed here from the published signature components, rather than reusing the
# signer's internal `hashed_vector`.

verification_s_time = time.time()

sig2_y, sig2_alpha_xor_z, sig2_sd = sigma_2
published_bytes = b''.join([
    list_to_bytes(sigma_1),
    list_to_bytes(sig2_y),
    bits_to_bytes(sig2_alpha_xor_z, security_bits),
    bits_to_bytes(sig2_sd, security_bits),
])
expected_digest = bytes_to_vector(hashFunction(published_bytes, hash_len_3))

result = vector(F, [poly.subs({a[j]: sigma_3[j] for j in range(len(sigma_3))}) for poly in trans_central_map])

ver_true = (result == expected_digest)

if ver_true:
    pretty_print("Verification Step 1 Completed Successfully!")
else:
    pretty_print("Verification Step 1 failed!")

In [7]:
# Verification Table
#
# Prover-side randomness for the identification protocol. The secret s is split
# additively as s = u0 + u1, with u0 = v0 + v1 and u1 = p0 + p1. The masks w0 / q0 split
# the evaluations: d_prime(u0) = w0 + w1 and d_prime(u1) = q0 + q1, and likewise
# d_prime_m(u0) = w0 + w1_star and d_prime_m(u1) = q0 + q1_star.

u0 = vector(F, [F.random_element() for _ in range(l)])
v0 = vector(F, [F.random_element() for _ in range(l)])
p0 = vector(F, [F.random_element() for _ in range(l)])

w0 = vector(F, [F.random_element() for _ in range(v)])
q0 = vector(F, [F.random_element() for _ in range(v)])

# Additive shares, using vector arithmetic over F.
u1 = vector(F, temp_priv_key_s) - u0
v1 = u0 - v0
p1 = u1 - p0

# Masked evaluations of both hash-defined systems at u0 and u1.
w1 = vector(F, [poly.subs({a[j]: u0[j] for j in range(l)}) for poly in d_prime]) - w0
w1_star = vector(F, [poly.subs({a[j]: u0[j] for j in range(l)}) for poly in d_prime_m]) - w0

q1 = vector(F, [poly.subs({a[j]: u1[j] for j in range(l)}) for poly in d_prime]) - q0
q1_star = vector(F, [poly.subs({a[j]: u1[j] for j in range(l)}) for poly in d_prime_m]) - q0

In [8]:
# COMPUTE COMMITMENTS
#
# Each commitment is hashFunction applied to the concatenated byte encodings of its parts.

commitment_length = (security_bits * 2) // 8 # security_bits is 128, hence 32 byte output from commitment


def commit(*parts):
    """Hash the concatenated ``vector_to_bytes`` encodings of ``parts`` into a commitment."""
    return hashFunction(b''.join(vector_to_bytes(part) for part in parts), commitment_length)


c0 = commit(u1, computePolarForm(d_prime, v0, u1) + w0)
c0_star = commit(u1, computePolarForm(d_prime_m, v0, u1) + w0)

c1 = commit(u0, computePolarForm(d_prime, u0, p0) + q0)
c1_star = commit(u0, computePolarForm(d_prime_m, u0, p0) + q0)

c2 = commit(v0, w0)
c3 = commit(v1, w1, w1_star)
c4 = commit(p0, q0)
c5 = commit(p1, q1, q1_star)

In [9]:
# VERIFIER CHOOSES A CHALLENGE FROM {0, 1, 2, 3}
#
# The challenge must be unpredictable to the prover, so it is drawn from the OS CSPRNG
# rather than from the default pseudo-random generator.

chosen_challenge = secrets.randbelow(4)

pretty_print("Challenge chosen:", chosen_challenge)

In [10]:
# SIGNER GENERATES RESPONSE AS PER CHALLENGE
#
# Challenges 0 and 2 reveal the "1"-side values (v1, w1, w1_star, p1, q1, q1_star) together
# with u0 or u1; challenges 1 and 3 reveal the "0"-side masks (v0, w0, p0, q0) with u0 or u1.

responses_by_challenge = {
    0: (u0, v1, w1, w1_star, p1, q1, q1_star),
    1: (u0, v0, w0, p0, q0),
    2: (u1, v1, w1, w1_star, p1, q1, q1_star),
    3: (u1, v0, w0, p0, q0),
}
if chosen_challenge in responses_by_challenge:
    response = responses_by_challenge[chosen_challenge]

In [11]:
# VERIFIER VERIFIES COMMITMENTS

# CHALLENGE 0
#
# Revealed: (u0, v1, w1, w1_star, p1, q1, q1_star). With G / G_m the polar forms of
# d_prime / d_prime_m, the verifier recomputes:
#   c1      = H(u0, y       - d_prime(u0)   - G(u0, p1)   - q1)
#   c1_star = H(u0, sigma_1 - d_prime_m(u0) - G_m(u0, p1) - q1_star)
#   c2      = H(u0 - v1, d_prime(u0) - w1)
#   c3      = H(v1, w1, w1_star)
#   c5      = H(p1, q1, q1_star)
# Everything is computed from the revealed response, not from the prover's own variables.

if (chosen_challenge == 0):
    r_u0, r_v1, r_w1, r_w1_star, r_p1, r_q1, r_q1_star = response

    d_at_u0 = vector(F, [poly.subs({a[j]: r_u0[j] for j in range(l)}) for poly in d_prime])     # d_prime(u0)
    dm_at_u0 = vector(F, [poly.subs({a[j]: r_u0[j] for j in range(l)}) for poly in d_prime_m])  # d_prime_m(u0)

    expected_c1 = hashFunction(vector_to_bytes(r_u0) + vector_to_bytes(vector(F, y) - d_at_u0 - computePolarForm(d_prime, r_u0, r_p1) - r_q1), commitment_length)
    expected_c1_star = hashFunction(vector_to_bytes(r_u0) + vector_to_bytes(vector(F, sigma_1) - dm_at_u0 - computePolarForm(d_prime_m, r_u0, r_p1) - r_q1_star), commitment_length)

    expected_c2 = hashFunction(vector_to_bytes(r_u0 - r_v1) + vector_to_bytes(d_at_u0 - r_w1), commitment_length)
    expected_c3 = hashFunction(vector_to_bytes(r_v1) + vector_to_bytes(r_w1) + vector_to_bytes(r_w1_star), commitment_length)
    expected_c5 = hashFunction(vector_to_bytes(r_p1) + vector_to_bytes(r_q1) + vector_to_bytes(r_q1_star), commitment_length)

    # c1 involves only the public y (= d_prime(s)), so an honest prover always matches it, and it
    # belongs in the pass/fail verdict. c1_star involves sigma_1 and decides confirmation vs.
    # disavowal, which is only meaningful once c1 has matched.
    if c1 == expected_c1:
        if c1_star == expected_c1_star:
            pretty_print("Confirmation in Challenge 0!")
        else:
            pretty_print("Disavowal in Challenge 0!")

    if (not (c1 == expected_c1 and c2 == expected_c2 and c3 == expected_c3 and c5 == expected_c5)):
        pretty_print("Challenge 0 Failed!")
    else:
        pretty_print("Challenge 0 Succeeded!")

In [12]:
# CHALLENGE 1
#
# Revealed: (u0, v0, w0, p0, q0). The verifier recomputes:
#   c1      = H(u0, G(u0, p0)   + q0)
#   c1_star = H(u0, G_m(u0, p0) + q0)
#   c2      = H(v0, w0)
#   c3      = H(u0 - v0, d_prime(u0) - w0, d_prime_m(u0) - w0)
#   c4      = H(p0, q0)
# d_prime(u0) is evaluated at the revealed response[0], not at the prover's own u0.

if (chosen_challenge == 1):
    r_u0, r_v0, r_w0, r_p0, r_q0 = response

    expected_c1 = hashFunction(vector_to_bytes(r_u0) + vector_to_bytes(computePolarForm(d_prime, r_u0, r_p0) + r_q0), commitment_length)
    expected_c1_star = hashFunction(vector_to_bytes(r_u0) + vector_to_bytes(computePolarForm(d_prime_m, r_u0, r_p0) + r_q0), commitment_length)
    expected_c2 = hashFunction(vector_to_bytes(r_v0) + vector_to_bytes(r_w0), commitment_length)

    d_at_u0 = vector(F, [poly.subs({a[j]: r_u0[j] for j in range(l)}) for poly in d_prime])
    dm_at_u0 = vector(F, [poly.subs({a[j]: r_u0[j] for j in range(l)}) for poly in d_prime_m])

    expected_c3 = hashFunction(vector_to_bytes(r_u0 - r_v0) + vector_to_bytes(d_at_u0 - r_w0) + vector_to_bytes(dm_at_u0 - r_w0), commitment_length)
    expected_c4 = hashFunction(vector_to_bytes(r_p0) + vector_to_bytes(r_q0), commitment_length)

    if (not (c1 == expected_c1 and c1_star == expected_c1_star and c2 == expected_c2 and c3 == expected_c3 and c4 == expected_c4)):
        pretty_print("Challenge 1 Failed!")
    else:
        pretty_print("Challenge 1 Succeeded!")

In [13]:
# CHALLENGE 2
#
# Revealed: (u1, v1, w1, w1_star, p1, q1, q1_star). The verifier recomputes:
#   c0      = H(u1, y       - d_prime(u1)   - G(v1, u1)   - w1)
#   c0_star = H(u1, sigma_1 - d_prime_m(u1) - G_m(v1, u1) - w1_star)
#   c3      = H(v1, w1, w1_star)
#   c4      = H(u1 - p1, d_prime(u1) - q1)
#   c5      = H(p1, q1, q1_star)
# Everything is computed from the revealed response, not from the prover's own variables.

if (chosen_challenge == 2):
    r_u1, r_v1, r_w1, r_w1_star, r_p1, r_q1, r_q1_star = response

    d_at_u1 = vector(F, [poly.subs({a[j]: r_u1[j] for j in range(l)}) for poly in d_prime])     # d_prime(u1)
    dm_at_u1 = vector(F, [poly.subs({a[j]: r_u1[j] for j in range(l)}) for poly in d_prime_m])  # d_prime_m(u1)

    expected_c0 = hashFunction(vector_to_bytes(r_u1) + vector_to_bytes(vector(F, y) - d_at_u1 - computePolarForm(d_prime, r_v1, r_u1) - r_w1), commitment_length)
    expected_c0_star = hashFunction(vector_to_bytes(r_u1) + vector_to_bytes(vector(F, sigma_1) - dm_at_u1 - computePolarForm(d_prime_m, r_v1, r_u1) - r_w1_star), commitment_length)

    expected_c3 = hashFunction(vector_to_bytes(r_v1) + vector_to_bytes(r_w1) + vector_to_bytes(r_w1_star), commitment_length)
    expected_c4 = hashFunction(vector_to_bytes(r_u1 - r_p1) + vector_to_bytes(d_at_u1 - r_q1), commitment_length)
    expected_c5 = hashFunction(vector_to_bytes(r_p1) + vector_to_bytes(r_q1) + vector_to_bytes(r_q1_star), commitment_length)

    # c0 involves only the public y, so an honest prover always matches it, and it belongs in
    # the pass/fail verdict. c0_star involves sigma_1 and decides confirmation vs. disavowal.
    if c0 == expected_c0:
        if c0_star == expected_c0_star:
            pretty_print("Confirmation in challenge 2!")
        else:
            pretty_print("Disavowal in challenge 2!")

    if (not (c0 == expected_c0 and c3 == expected_c3 and c4 == expected_c4 and c5 == expected_c5)):
        pretty_print("Challenge 2 Failed!")
    else:
        pretty_print("Challenge 2 Succeeded!")

In [14]:
# CHALLENGE 3
#
# Revealed: (u1, v0, w0, p0, q0). The verifier recomputes:
#   c0      = H(u1, G(v0, u1)   + w0)
#   c0_star = H(u1, G_m(v0, u1) + w0)
#   c2      = H(v0, w0)
#   c4      = H(p0, q0)
#   c5      = H(u1 - p0, d_prime(u1) - q0, d_prime_m(u1) - q0)
# d_prime(u1) is evaluated at the revealed response[0], not at the prover's own u1.

if (chosen_challenge == 3):
    r_u1, r_v0, r_w0, r_p0, r_q0 = response

    expected_c0 = hashFunction(vector_to_bytes(r_u1) + vector_to_bytes(computePolarForm(d_prime, r_v0, r_u1) + r_w0), commitment_length)
    expected_c0_star = hashFunction(vector_to_bytes(r_u1) + vector_to_bytes(computePolarForm(d_prime_m, r_v0, r_u1) + r_w0), commitment_length)
    expected_c2 = hashFunction(vector_to_bytes(r_v0) + vector_to_bytes(r_w0), commitment_length)
    expected_c4 = hashFunction(vector_to_bytes(r_p0) + vector_to_bytes(r_q0), commitment_length)

    d_at_u1 = vector(F, [poly.subs({a[j]: r_u1[j] for j in range(l)}) for poly in d_prime])
    dm_at_u1 = vector(F, [poly.subs({a[j]: r_u1[j] for j in range(l)}) for poly in d_prime_m])

    expected_c5 = hashFunction(vector_to_bytes(r_u1 - r_p0) + vector_to_bytes(d_at_u1 - r_q0) + vector_to_bytes(dm_at_u1 - r_q0), commitment_length)

    if (not (c0 == expected_c0 and c0_star == expected_c0_star and c2 == expected_c2 and c4 == expected_c4 and c5 == expected_c5)):
        pretty_print("Challenge 3 Failed!")
    else:
        pretty_print("Challenge 3 Succeeded!")

verification_e_time = time.time()
pretty_print("Verification Time MCUDS (in seconds): ", verification_e_time - verification_s_time)

In [15]:
# ConMCUDS VERIFICATION PART 1
# Recompute H(sigma_1 || sigma_2) and compare it with the public map evaluated at sigma_3.
conmcuds_s_time = time.time()

sigma_3_assignment = {a[j]: sigma_3[j] for j in range(len(sigma_3))}
check1 = vector(F, [poly.subs(sigma_3_assignment) for poly in trans_central_map])
check2 = bytes_to_vector(hashFunction(sigma_1_bytes + sigma_2_bytes, hash_len_3))

pretty_print("Accepted!" if check1 == check2 else "Rejected!")

In [16]:
# ConMCUDS VERIFICATION PART 2
# Re-evaluate both hash-defined systems at the secret s and compare them with y and sigma_1.

check3 = [poly(*temp_priv_key_s) for poly in d_prime]
check4 = [poly(*temp_priv_key_s) for poly in d_prime_m]

both_match = check3 == y and check4 == sigma_1
pretty_print("Accepted!" if both_match else "Rejected")

conmcuds_e_time = time.time()
pretty_print("Verification Time ConMCUDS (in seconds): ", conmcuds_e_time - conmcuds_s_time) 