In [1]:
from CompactFIPS202 import SHAKE256
from sage.functions.log import logb
import random
from time import time

# prime=2**64-59=18446744073709551557
# output-length l=1; c=1
# alpha=3,m=3, security_level=32

#definition of the Grendel_permutation

def legendre_symbol(value,p):
    return value^((p-1)/2)

def get_mds_matrix(p,m):
    #get a primitive element
    Fp=FiniteField(p)
    g=Fp(2)
    while g.multiplicative_order()!=p-1:
        g=g+1
    V=matrix([[g^(i*j) for j in range(0,2*m)] for i in range(0,m)])
    V_ech=V.echelon_form()
    MDS=V_ech[:,m:].transpose()
    return MDS
    
def get_round_constants(p,m,security_level,N):
    bytes_per_int=1+ceil(logb(p,2)/8)
    num_bytes=m*N*bytes_per_int
    seed_string="grendel-%i-%i-%i"%(p,m,security_level)
    byte_string=SHAKE256(bytes(seed_string, "ascii"), num_bytes)
    round_constants=[]
    Fp=FiniteField(p)
    for i in range(m*N):
        acc=0
        for j in range(bytes_per_int):
            acc=256*acc+ZZ(byte_string[i*bytes_per_int+j])
        round_constants.append(Fp(acc %p))
    return round_constants

def Grendel_permutation(p,alpha,m,N,state,linear_layer,round_constants):
    new_state=state
    for i in range(0,N):
        for j in range(0,m):
            new_state[j]=(new_state[j])^alpha*(legendre_symbol(new_state[j],p))
        new_state=linear_layer*vector(new_state)
        for j in range(0,m):
            new_state[j]=new_state[j]+round_constants[i*m+j]
    return new_state


# definition of extract_guesses
def extract_guesses(permutation_guess,number_single_guesses):
    single_guesses=[]
    for i in range(0,number_single_guesses):
        guess_intermediate=(permutation_guess>>i)&0b1
        guess_converted=(guess_intermediate*1)+((guess_intermediate-1)*1)
        single_guesses.append(guess_converted)
    return single_guesses

#Finding a preimage for a sponge hash function instantiated with N-round Grendel over F_p^m
p=18446744073709551557
Fp=GF(p)
security_level=32
alpha=3
alpha_1=inverse_mod(alpha,p-1)
m=3
N=6

M=get_mds_matrix(p,m)
M_inv=M.inverse()

round_constants=get_round_constants(p,m,security_level,N)


print("p,alpha,m,N",p,alpha,m,N)

p,alpha,m,N 18446744073709551557 3 3 6


In [2]:
#build the equation for each guess
L=m*(N-2)
guess_limit=2^L

a=M_inv[2][0]
b=M_inv[2][1]
c=M_inv[2][2]
c_1=c^(-1)

d=round_constants[0]
e=round_constants[1]
f=round_constants[2]


A1=b
A2=-a
g=(((a*d+b*e+c*f)*c_1)^alpha)*legendre_symbol((a*d+b*e+c*f)*c_1,p)

s_1=(A1^alpha)*legendre_symbol(A1,p)
s_2=(A2^alpha)*legendre_symbol(A2,p)

R.<x>=PolynomialRing(Fp)

load("ntl_power_mod.spyx")
print("load: ntl_power_mod.spyx")

Compiling ./ntl_power_mod.spyx...


load: ntl_power_mod.spyx


In [3]:
sum_verify=0
solution=[ ]
sumtime_step1=0
sumtime_step2=0
sumtime_step3=0
for permutation_guess in range(0,guess_limit):
    state=[s_1*x,s_2*x,g]
    single_guesses=extract_guesses(permutation_guess,L)
    new_state=M*vector(state)
    for j in range(m):
        new_state[j]=new_state[j]+round_constants[m+j]
    for i in range(2,N):
        for j in range(m):
            new_state[j]=((new_state[j])^alpha)*single_guesses[(i-2)*m+j]
        new_state=M*vector(new_state)
        for j in range(m):
            new_state[j]=new_state[j]+round_constants[i*m+j]
    equation=new_state[2]
    
    start1=time()
    Q=ntl_power_mod(x,Fp.order(),equation)
    end1=time()
    sumtime_step1=sumtime_step1+(end1-start1)
    
    start2=time()
    R=equation.gcd(Q-x)
    end2=time()
    sumtime_step2=sumtime_step2+(end2-start2)
    
    start3=time()
    roots=R.roots()
    end3=time()
    sumtime_step3=sumtime_step3+(end3-start3)
    
    S=[S[0] for S in roots]
    for r in S:
        flag=True
        state=[s_1*r,s_2*r,g]
        new_state=M*vector(state)
        for j in range(m):
            new_state[j]=new_state[j]+round_constants[m+j]
        for i in range(2,N):
            for j in range(m):
                le=legendre_symbol(new_state[j],p)
                sum_verify=sum_verify+1
                if le!= single_guesses[(i-2)*m+j]:
                    flag=False
                    break
                new_state[j]=((new_state[j])^alpha)*le   
            if flag==True:
                new_state=M*vector(new_state)
                for j in range(m):
                    new_state[j]=new_state[j]+round_constants[i*m+j]
            else:
                break
        if flag==True:
            forward_state=[0]*m
            for j in range(m):
                forward_state[j]=(state[j]^alpha_1)*(legendre_symbol(state[j],p))
            for j in range(m):
                forward_state[j]=forward_state[j]-round_constants[j]
            forward_state=M_inv*vector(forward_state)
            for j in range(m):
                forward_state[j]=(forward_state[j]^alpha_1)*(legendre_symbol(forward_state[j],p))
            input_state=forward_state
            output_state=Grendel_permutation(p,alpha,m,N,input_state,M,round_constants)
            solution.append([input_state,output_state])
            break
    if len(solution)>0:
        break
            
            
log_sum_verify=logb(sum_verify,2)
print('log_sum_verify:%f,L+1:%i' % (log_sum_verify,L+1))
print(solution)
print(sumtime_step1,sumtime_step2,sumtime_step3)

log_sum_verify:11.433064,L+1:13
[[(3978053920827369818, 11054388833269671370, 0), (12866554063376284852, 7184824262592905636, 0)]]
3.1884052753448486 0.33231115341186523 0.2979919910430908
