In [3]:
# Large prime numbers given by the Anemoi designer
# BLS12-381 Base field
BLS12_381_BASEFIELD = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab
# BLS12-381 Scalar field
BLS12_381_SCALARFIELD = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001

# BLS12-377 Base field = BW6_761 Scalar field
BLS12_377_BASEFIELD = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001
# BLS12-377 Scalar field = Ed_on_bls_12_377 Base field
BLS12_377_SCALARFIELD = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001

# BN-254 Base field
BN_254_BASEFIELD = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47
# BN-254 Scalar field
BN_254_SCALARFIELD = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001

# Pallas Base field = Vesta Scalar field
PALLAS_BASEFIELD = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001

# Vesta Base field = Pallas Scalar field
VESTA_BASEFIELD = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001

# Small Goldilocks field
GOLDILOCKS_64_FIELD = 0xffffffff00000001

In [82]:
#!/usr/bin/sage
# -*- mode: python ; -*-
#An implementation of the Anemoi algorithm given by the Anemoi designers

from sage.all import *
import hashlib
import itertools

COST_ALPHA = {
    3   : 2, 5   : 3, 7   : 4, 9   : 4,
    11  : 5, 13  : 5, 15  : 5, 17  : 5,
    19  : 6, 21  : 6, 23  : 6, 25  : 6,
    27  : 6, 29  : 7, 31  : 7, 33  : 6,
    35  : 7, 37  : 7, 39  : 7, 41  : 7,
    43  : 7, 45  : 7, 47  : 8, 49  : 7,
    51  : 7, 53  : 8, 55  : 8, 57  : 8,
    59  : 8, 61  : 8, 63  : 8, 65  : 7,
    67  : 8, 69  : 8, 71  : 9, 73  : 8,
    75  : 8, 77  : 8, 79  : 9, 81  : 8,
    83  : 8, 85  : 8, 87  : 9, 89  : 9,
    91  : 9, 93  : 9, 95  : 9, 97  : 8,
    99  : 8, 101 : 9, 103 : 9, 105 : 9,
    107 : 9, 109 : 9, 111 : 9, 113 : 9,
    115 : 9, 117 : 9, 119 : 9, 121 : 9,
    123 : 9, 125 : 9, 127 : 10,
}

ALPHA_BY_COST = {
    c : [x for x in range(3, 128, 2) if COST_ALPHA[x] == c]
    for c in range(2, 11)
}

PI_0 = 1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679
PI_1 = 8214808651328230664709384460955058223172535940812848111745028410270193852110555964462294895493038196

def get_prime(N):
    """Returns the highest prime number that is strictly smaller than
    2**N.

    """
    result = (1 << N) - 1
    while not is_prime(result):
        result -= 2
    return result


def get_n_rounds(s, l, alpha):
    """Returns the number of rounds needed in Anemoi (based on the
    complexity of algebraic attacks).

    """
    r = 0
    complexity = 0
    kappa = {3:1, 5:2, 7:4, 9:7, 11:9}
    assert alpha in kappa
    while complexity < 2**s:
        r += 1
        complexity = binomial(
            4*l*r + kappa[alpha],
            2*l*r
        )**2
    r += 2 # considering the second model
    r += min(5,l+1) # security margin
    
    return max(8, r)


# Linear layer generation

def is_mds(m):
    # Uses the Laplace expansion of the determinant to calculate the (m+1)x(m+1) minors in terms of the mxm minors.
    # Taken from https://github.com/mir-protocol/hash-constants/blob/master/mds_search.sage.

    # 1-minors are just the elements themselves
    if any(any(r == 0 for r in row) for row in m):
        return False

    N = m.nrows()
    assert m.is_square() and N >= 2

    det_cache = m

    # Calculate all the nxn minors of m:
    for n in range(2, N+1):
        new_det_cache = dict()
        for rows in itertools.combinations(range(N), n):
            for cols in itertools.combinations(range(N), n):
                i, *rs = rows

                # Laplace expansion along row i
                det = 0
                for j in range(n):
                    # pick out c = column j; the remaining columns are in cs
                    c = cols[j]
                    cs = cols[:j] + cols[j+1:]

                    # Look up the determinant from the previous iteration
                    # and multiply by -1 if j is odd
                    cofactor = det_cache[(*rs, *cs)]
                    if j % 2 == 1:
                        cofactor = -cofactor

                    # update the determinant with the j-th term
                    det += m[i, c] * cofactor

                if det == 0:
                    return False
                new_det_cache[(*rows, *cols)] = det
        det_cache = new_det_cache
    return True

def M_2(x_input, b):
    """Fast matrix-vector multiplication algorithm for Anemoi MDS layer with \ell = 1,2."""

    x = x_input[:]
    x[0] += b*x[1]
    x[1] += b*x[0]
    return x

def M_3(x_input, b):
    """Fast matrix-vector multiplication algorithm for Anemoi MDS layer with \ell = 3.

    From Figure 6 of [DL18](https://tosc.iacr.org/index.php/ToSC/article/view/888)."""

    x = x_input[:]
    t = x[0] + b*x[2]
    x[2] += x[1]
    x[2] += b*x[0]
    x[0] = t + x[2]
    x[1] += t
    return x


def M_4(x_input, b):
    """Fast matrix-vector multiplication algorithm for Anemoi MDS layer with \ell = 4.

    Figure 8 of [DL18](https://tosc.iacr.org/index.php/ToSC/article/view/888)."""

    x = x_input[:]
    x[0] += x[1]
    x[2] += x[3]
    x[3] += b*x[0]
    x[1]  = b*(x[1] + x[2])
    x[0] += x[1]
    x[2] += b*x[3]
    x[1] += x[2]
    x[3] += x[0]
    return x

def lfsr(x_input, b):
    x = x_input[:]
    l = len(x)
    for r in range(0, l):
        t = sum(b**(2**i) * x[i] for i in range(0, l))
        x = x[1:] + [t]
    return x

def circulant_mds_matrix(field, l, coeff_upper_limit=None):
    if coeff_upper_limit == None:
        coeff_upper_limit = l+1
    assert(coeff_upper_limit > l)
    for v in itertools.combinations_with_replacement(range(1,coeff_upper_limit), l):
        mat = matrix.circulant(list(v)).change_ring(field)
        if is_mds(mat):
            return(mat)
    # In some cases, the method won't return any valid matrix,
    # hence the need to increase the limit further.
    return circulant_mds_matrix(field, l, coeff_upper_limit+1)

def get_mds(field, l):
    if l == 1:
        return identity_matrix(field, 1)
    if l <= 4: # low addition case
        a = field.multiplicative_generator()
        b = field.one()
        t = 0
        while True:
            # we construct the matrix
            mat = []
            b = b*a
            t += 1
            for i in range(0, l):
                x_i = [field.one() * (j == i) for j in range(0, l)]
                if l == 2:
                    mat.append(M_2(x_i, b))
                elif l == 3:
                    mat.append(M_3(x_i, b))
                elif l == 4:
                    mat.append(M_4(x_i, b))
            mat = Matrix(field, l, l, mat).transpose()
            if is_mds(mat):
                return mat
    else: # circulant matrix case
        return circulant_mds_matrix(field, l)

# AnemoiPermutation class

class AnemoiPermutation:
    def __init__(self,
                 q=None,
                 alpha=None,
                 mat=None,
                 n_rounds=None,
                 n_cols=1,
                 security_level=128):
        if q == None:
            raise Exception("The characteristic of the field must be specified!")
        self.q = q
        self.prime_field = is_prime(q)  # if true then we work over a
                                        # prime field with
                                        # characteristic just under
                                        # 2**N, otherwise the
                                        # characteristic is 2**self
        self.n_cols = n_cols # the number of parallel S-boxes in each round
        self.security_level = security_level

        # initializing the other variables in the state:
        # - q     is the characteristic of the field
        # - g     is a generator of the multiplicative subgroup
        # - alpha is the main exponent (in the center of the Flystel)
        # - beta  is the coefficient in the quadratic subfunction
        # - gamma is the constant in the second quadratic subfunction
        # - QUAD  is the secondary (quadratic) exponent
        # - from_field is a function mapping field elements to integers
        # - to_field   is a function mapping integers to field elements
        self.F = GF(self.q)
        if self.prime_field:
            if alpha != None:
                if gcd(alpha, self.q-1) != 1:
                    raise Exception("alpha should be co-prime with the characteristic!")
                else:
                    self.alpha = alpha
            else:
                self.alpha = 3
                while gcd(self.alpha, self.q-1) != 1:
                    self.alpha += 1
            self.QUAD = 2
            self.to_field   = lambda x : self.F(x)
            self.from_field = lambda x : Integer(x)
        else:
            self.alpha = 3
            self.QUAD = 3
            self.to_field   = lambda x : self.F.fetch_int(x)
            self.from_field = lambda x : x.integer_representation()
        self.g = self.F.multiplicative_generator()
        self.beta = self.g
        self.delta = self.g**(-1)
        self.alpha_inv = inverse_mod(self.alpha, self.q-1)
        print("g:",self.g)
        print("g_inv",self.g^(-1))
        print("alpha:",self.alpha)
        print("alpha_inv:",self.alpha_inv)

        # total number of rounds
        if n_rounds != None:
            self.n_rounds = n_rounds
        else:
            self.n_rounds = get_n_rounds(self.security_level,
                                         self.n_cols,
                                         self.alpha)

        # Choosing constants: self.C and self.D are built from the
        # digits of pi using an open butterfly
        self.C = []
        self.D = []
        pi_F_0 = self.to_field(PI_0 % self.q)
        pi_F_1 = self.to_field(PI_1 % self.q)
        for r in range(0, self.n_rounds):
            pi_0_r = pi_F_0**r
            self.C.append([])
            self.D.append([])
            for i in range(0, self.n_cols):
                pi_1_i = pi_F_1**i
                pow_alpha = (pi_0_r + pi_1_i)**self.alpha
                self.C[r].append(self.g * (pi_0_r)**2 + pow_alpha)
                self.D[r].append(self.g * (pi_1_i)**2 + pow_alpha + self.delta)
        self.mat = get_mds(self.F, self.n_cols)


    def __str__(self):
        result = "Anemoi instance over F_{:d} ({}), n_rounds={:d}, n_cols={:d}, s={:d}".format(
            self.q,
            "odd prime field" if self.prime_field else "characteristic 2",
            self.n_rounds,
            self.n_cols,
            self.security_level
        )
        result += "\nalpha={}, beta={}, \delta={}\nM_x=\n{}\n".format(
            self.alpha,
            self.beta,
            self.delta,
            self.mat
        )
        result += "C={}\nD={}".format(
            [[self.from_field(x) for x in self.C[r]] for r in range(0, self.n_rounds)],
            [[self.from_field(x) for x in self.D[r]] for r in range(0, self.n_rounds)],
        )
        return result


    # !SECTION! Sub-components

    def evaluate_sbox(self, _x, _y):
        """Applies an open Flystel to the full state. """
        x, y = _x, _y
#         print(x, y)
        x -= (self.beta*y^2 + self.delta)
#         print(self.beta)
#         print(self.delta)
#         print(x, y)
        y -= x**self.alpha_inv
#         print(x, y)
        x += self.beta*y**self.QUAD
#         print(x, y)
        return x, y

    def linear_layer(self, _x, _y):
        x, y = _x[:], _y[:]
        x = self.mat*vector(x)
        y = self.mat*vector(y[1:] + [y[0]])
        print(x, y)
        # Pseudo-Hadamard transform on each (x,y) pair
        y += x
        x += y
        print(x, y)
        return list(x), list(y)


    # !SECTION! Evaluation

    def eval_with_intermediate_values(self, _x, _y):
        """Returns a list of vectors x_i and y_i such that [x_i, y_i] is the
        internal state of Anemoi at the end of round i.

        The output is of length self.n_rounds+2 since it also returns
        the input values, and since there is a last degenerate round
        consisting only in a linear layer.

        """
        x, y = _x[:], _y[:]
        result = [[x[:], y[:]]]
        for r in range(0, self.n_rounds):
#             for i in range(0, self.n_cols):
#                 x[i] += self.C[r][i]
#                 y[i] += self.D[r][i]
            x, y = self.linear_layer(x, y)
#             print(f"LRound {r}: x = {x}, y = {y}")
            for i in range(0, self.n_cols):
                x[i], y[i] = self.evaluate_sbox(x[i], y[i])
            result.append([x[:], y[:]])
        # final call to the linear layer
        x, y = self.linear_layer(x, y)
        result.append([x[:], y[:]])
        return result


    def input_size(self):
        return 2*self.n_cols


    def __call__(self, _x):
        if len(_x) != self.input_size():
            raise Exception("wrong input size!")
        else:
            x, y = _x[:self.n_cols], _x[self.n_cols:]
            u, v = self.eval_with_intermediate_values(x, y)[-1]
            return u + v # concatenation, not a sum


if __name__ == "__main__":

    # These are the first circulant matrices being found by the circulant_mds_matrix()
    # method above. These are precomputed for faster initiatialization of large Anemoi
    # instances.
    CIRCULANT_FP5_MDS_MATRIX = matrix.circulant([1, 1, 3, 4, 5])
    CIRCULANT_FP6_MDS_MATRIX = matrix.circulant([1, 1, 3, 4, 5, 6])
    CIRCULANT_FP7_MDS_MATRIX = matrix.circulant([1, 2, 3, 5, 5, 6, 7])
    CIRCULANT_FP8_MDS_MATRIX = matrix.circulant([1, 2, 3, 5, 7, 8, 8, 9])
    CIRCULANT_FP9_MDS_MATRIX = matrix.circulant([1, 3, 5, 6, 8, 9, 9, 10, 11])
    CIRCULANT_FP10_MDS_MATRIX = matrix.circulant([1, 2, 5, 6, 8, 11, 11, 12, 13, 14])

    # 128-bit security level instantiations


#     A_GOLDILOCKS_64_FIELD_4_COL_128_BITS = AnemoiPermutation(
#         q=GOLDILOCKS_64_FIELD,
#         n_cols=4,
#         security_level=128
#     )
#     print(A_GOLDILOCKS_64_FIELD_4_COL_128_BITS)
#     x = [1, 2, 3, 4]
#     y = [3, 4, 5, 6]
#     intermediate_values = A_GOLDILOCKS_64_FIELD_4_COL_128_BITS.eval_with_intermediate_values(x, y)
#     for i, (xi, yi) in enumerate(intermediate_values):
#         print(f"Round {i}: x = {xi}, y = {yi}")
    





In [86]:
#Based on the output of the 2-round, 2-branch Anemoi in our work, set the positions where the variables and constants are introduced and solve the corresponding equations.
#prime = 0xffffffff00000001
prime = 0xd1a58d367d1fe3ccb0d67c2fef5038859894ebf64e81f46e847d1964e1380e81
F = GF(prime)
alpha= 3
g = F.primitive_element()
print(g)
print(g^(-1))
M = matrix(F, 2, 2, [
    [2, 1], 
    [1, 1]
])
M2 = matrix(F, 2, 2, [
    [1, g], 
    [g, g^2 + 1]
]) 
print("M =", M)
print("M2 =", M2)
# Building equations to solve for the key

#Introduce variables at the location of the input message block
R.<x0,x1,x2,x3> = PolynomialRing(F, order = 'degrevlex')
b05 = x0
b06 = g * b05 * b05 + g^(-1) 
b08 = F(0)
b07 = b08^alpha
b09 = b05 - b08
b010 = g * b09^2


b04 = b07 + b06
[a03, b03] = list(M^(-1) * vector([b04, b05]))
a00 = F(1)
a01 = (a03 - g * a00) / (g^2 + 1)
a02 = a00 + g * a01


a08 = F(0)
a07 = a08^alpha

a11 = b07 + b010
b11 = b09

b13 = x1
b10 = (b13 - g * b11) / (g^2 + 1)
b12 = b11 + g * b10
a09 = b10
a010 = g * a09^2
a10 = a07 + a010
a05 = a08 + a09
a06 = g * a05^2 + g^(-1)
a01 = (a03 - g * a00) / (g^2 + 1)
a02 = a00 + g * a01
b02 = a05 - a02
a04 = 2 * a02 + b02

[a12, a13] = list(M2 * vector([a10, a11]))
[a14, a15] = list(M * vector([a12, b12]))
[b14, b15] = list(M * vector([a13, b13]))
a16 = g * a15^2 + g^(-1)
a17 = a14 - a16
a18 = x2
a19 = a18^alpha
a20 = a15 - a18
a21 = g * a20^2
b16 = g * b15^2 + g^(-1)
b18 = x3
b17 = b18^alpha
b19 = b15 - b18
b20 = g * b19^2
b21 = b17 + b20
a22 = -g * b21
[b01, b00] = list(M2^(-1) * vector([b02, b03]))
def print_magma(Eqs, Vlist, magma_file):
    f = open(magma_file, 'w')
    f.write("Fp := GF({});\n".format(prime))
    f.write("R<" + ", ".join(Vlist) + "> := PolynomialRing(Fp, {}, \"lex\");\n".format(len(Vlist)))

    s = ""
    for i in range(len(Eqs)-1):
        s += "f{}, ".format(i)
    s += "f{}".format(len(Eqs)-1)    
    for i in range(len(Eqs)):        
        f.write("f{} := ".format(i) + str(Eqs[i]) + ";\n")

    f.write("I := ideal<R|" + s + ">;\n")
    f.write("time gb := GroebnerBasis(I : Al := \"FGLM\");\n")
    f.write("Variety(I);\n")
    f.close()
eqs = []
eqs.append(a04 - a06 - a07)
eqs.append(b14 - b16 - b17)
eqs.append(a17 - a19)
eqs.append(a19 + a21 - a22)
print_magma(eqs, ["x0", "x1", "x2", "x3"], "Anemoi.mag")
print(a00)
print(a01)
print(b00)
print(b01)
# show(a17)
# show(a18)

3
31608629911613427086241436769537060786019065412801113474667809736984617288918
M = [2 1]
[1 1]
M2 = [ 1  3]
[ 3 10]
1
85343300761356253132851879277750064122251476614563006381603086289858466680078*x0^2 + 66378122814388196881107017216027827650640037366882338296802400447667696306727*x0 + 50573807858581483337986298831259297257630504660481781559468495579175387662268
9482588973484028125872431030861118235805719623840334042400342921095385186675*x0^2 + 2*x0 + 9482588973484028125872431030861118235805719623840334042400342921095385186675*x1 + 44252081876258797920738011477351885100426691577921558864534933631778464204485
94825889734840281258724310308611182358057196238403340424003429210953851866747*x0 + x1 + 94825889734840281258724310308611182358057196238403340424003429210953851866752


In [89]:
# Verify that the constructed equations are solved for the variables that match the original image
test = AnemoiPermutation(
        q=prime,
        alpha = alpha,
        #alpha = 11,
        n_cols=2,
        n_rounds=2
    )
x = [1, 4279363971997543644023540507172421857894676838366756130952365409581699374891]
y = [18444772456723481588208702072988572252189246891416593337375724377530032374617, 7833781251070919278166893805636898355471564091000918835880262339932687309168]
intermediate_values = test.eval_with_intermediate_values(x, y)
for i, (xi, yi) in enumerate(intermediate_values):
    print(f"Round {i}: x = {xi}, y = {yi}")

g: 3
g_inv 31608629911613427086241436769537060786019065412801113474667809736984617288918
alpha: 3
alpha_inv: 63217259823226854172482873539074121572038130825602226949335619473969234577835
(12838091915992630932070621521517265573684030515100268392857096228745098124674, 42793639719975436440235405071724218578946768383667561309523654095816993748913) (63168098621241364042793000024602615112039304765250698848007435472522784433019, 18297288850767011199139081529574052872192768710362009033391172373190681940168)
(88844282453226625906934243067637146259407365795451235633721627930012980682367, 9058678555877602820885581364411307672029109239293791228435051353870817571241) (76006190537233994974863621546119880685723335280350967240864531701267882557693, 61090928570742447639374486601298271451139537094029570342914826469007675689081)
(84411688209246007283349550391334008489475628100531495844358972254640816107171, 41033333802321635067244175151653907638322535651280484438837299958900944870330) (463183097792358878

In [88]:
F = GF(prime)
R.<x0,x1,x2,x3> = PolynomialRing(F, order = 'degrevlex')
x0 = 61090928570742447639374486601298271451139537094029570342914826469007675689081
x1 = 89901683471004761338240882487592979988137197939968319621358933521117185843396
x2 = 41073332852664300518550796532643923864618381674914114044279439684447940162718
x3 = 21063671834797071076122378122883469162768188053633187787899289615376630389831
# x0 = 7520286268644501482
# x1 = 4202448682494687111
# x2 = 7723235680324144539
# x3 = 4066956375985008018
# x0 = 16
# x1 = 11
# x2 = 8
# x3 = 4
a00 = F(1)
a01 = F(85343300761356253132851879277750064122251476614563006381603086289858466680078*x0^2 + 66378122814388196881107017216027827650640037366882338296802400447667696306727*x0 + 50573807858581483337986298831259297257630504660481781559468495579175387662268)
b00 = F(9482588973484028125872431030861118235805719623840334042400342921095385186675*x0^2 + 2*x0 + 9482588973484028125872431030861118235805719623840334042400342921095385186675*x1 + 44252081876258797920738011477351885100426691577921558864534933631778464204485)
b01 = F(94825889734840281258724310308611182358057196238403340424003429210953851866747*x0 + x1 + 94825889734840281258724310308611182358057196238403340424003429210953851866752)
print(a00)
print(a01)
print(b00)
print(b01)

1
4279363971997543644023540507172421857894676838366756130952365409581699374891
18444772456723481588208702072988572252189246891416593337375724377530032374617
7833781251070919278166893805636898355471564091000918835880262339932687309168


In [81]:
n = 0xd1a58d367d1fe3ccb0d67c2fef5038859894ebf64e81f46e847d1964e1380e81  

if is_prime(n):
    print(f"{n} is a prime number")
else:
    print(f"{n} not a prime number")

94825889734840281258724310308611182358057196238403340424003429210953851866753 is prime
