In [2]:
import numpy as np
import random

def maxplus_multi(A, B):
    """Max-plus (tropical) matrix product.

    Computes C[r, c] = max_k (A[r, k] + B[k, c]) for 2-D arrays A (m x n) and
    B (n x l).

    Returns:
        Float ndarray of shape (m, l).

    Raises:
        ValueError: if the inner dimensions of A and B differ.
    """
    rows, inner = A.shape
    inner_b, cols = B.shape
    if inner != inner_b:
        raise ValueError("Dimension Error")

    product = np.zeros((rows, cols))
    # np.ndindex walks (r, c) in row-major order, like a nested r/c loop.
    for r, c in np.ndindex(rows, cols):
        product[r, c] = np.max(A[r, :] + B[:, c])
    return product


def maxplus_sum(A, B):
    """Max-plus (tropical) sum: entrywise maximum of two 2-D tables.

    A and B may be lists of lists or 2-D arrays indexed as X[r][c]. The number
    of columns is taken from the first row of A. The result is always a plain
    list of lists.
    """
    return [
        [max(A[r][c], B[r][c]) for c in range(len(A[0]))]
        for r in range(len(A))
    ]

def generate_random_polynomial(D, pm, pM):
    """Draw a random max-plus polynomial.

    The degree d is drawn uniformly from [1, D]. The result is a float array of
    shape (2, d + 1): row 0 holds the exponents 0..d, and row 1 holds integer
    coefficients drawn uniformly from [pm, pM]. All randomness comes from
    numpy's global RNG.
    """
    degree = np.random.randint(1, D + 1)
    poly = np.zeros((2, degree + 1))
    # One randint call per coefficient, in exponent order.
    for power in range(degree + 1):
        poly[0, power] = power
        poly[1, power] = np.random.randint(pm, pM + 1)
    return poly



def maxplus_trop_identity(n):
    """Return the n x n max-plus identity matrix.

    The diagonal is 0 (the multiplicative unit) and every off-diagonal entry is
    -inf (the additive unit).
    """
    identity = np.full((n, n), -np.inf)
    identity[np.diag_indices(n)] = 0
    return identity



def fast_power_max_plus(B, t):
    """Max-plus matrix power B^t by left-to-right binary exponentiation.

    Args:
        B: square 2-D array.
        t: non-negative integer exponent.

    Returns:
        Float ndarray equal to B multiplied by itself t times in the max-plus
        semiring. B^0 is the max-plus identity.

    Raises:
        ValueError: if B is not square or t is negative.
    """
    n, m = B.shape

    if n != m:
        raise ValueError("Dimension Error")

    if t < 0:
        # bin() of a negative number starts with '-0b'; the bit loop below would
        # skip the stray 'b' and silently return B^|t|.
        raise ValueError("Exponent must be non-negative")

    if t == 0:
        return maxplus_trop_identity(n)

    value = maxplus_trop_identity(n)
    # Scan bits from most significant to least: square, then multiply by B on a 1-bit.
    for bit in bin(t)[2:]:
        value = maxplus_multi(value, value)
        if bit == '1':
            value = maxplus_multi(value, B)

    return value


def apply_polynomial_max_plus(V, X):
    """Evaluate a max-plus polynomial at a square matrix.

    Args:
        V: 2 x q array. Row 0 holds exponents and row 1 the matching
            coefficients (the layout produced by generate_random_polynomial).
        X: square 2-D array.

    Returns:
        Float ndarray: the entrywise maximum over all terms of
        coefficient + X^exponent, with X^0 taken as the max-plus identity.
        If V has no columns, every entry is -inf.
    """
    rows, cols = X.shape
    _, n_terms = V.shape
    result = np.full((rows, cols), -np.inf)

    for k in range(n_terms):
        exponent = V[0, k]
        if exponent == 0:
            power = maxplus_trop_identity(rows)
        else:
            power = fast_power_max_plus(X, int(exponent))
        result = np.maximum(result, power + V[1, k])

    return result



def generate_key_stickels(n, mm, mM, D, pm, pM):
    """Run one Stickel-style max-plus key exchange on random parameters.

    Draws n x n integer matrices A, B, Wm with entries in [mm, mM], and four
    random polynomials p1, p2, q1, q2 (degree in [1, D], coefficients in
    [pm, pM]) from numpy's global RNG. Then computes
        U  = p1(A) ⊗ Wm ⊗ p2(B)
        V  = q1(A) ⊗ Wm ⊗ q2(B)
        KA = p1(A) ⊗ V ⊗ p2(B)
        KB = q1(A) ⊗ U ⊗ q2(B)
    where ⊗ is the max-plus matrix product.

    Returns:
        Tuple (key, U, V, A, B, Wm). key is KA when KA equals KB, otherwise [].
    """
    A = np.random.randint(mm, mM + 1, (n, n))
    B = np.random.randint(mm, mM + 1, (n, n))
    Wm = np.random.randint(mm, mM + 1, (n, n))

    p1 = generate_random_polynomial(D, pm, pM)
    p2 = generate_random_polynomial(D, pm, pM)
    q1 = generate_random_polynomial(D, pm, pM)
    q2 = generate_random_polynomial(D, pm, pM)

    # Polynomial evaluation dominates the cost; evaluate each one once and
    # reuse it for both the public values and the key computation.
    p1A = apply_polynomial_max_plus(p1, A)
    p2B = apply_polynomial_max_plus(p2, B)
    q1A = apply_polynomial_max_plus(q1, A)
    q2B = apply_polynomial_max_plus(q2, B)

    U = maxplus_multi(maxplus_multi(p1A, Wm), p2B)
    V = maxplus_multi(maxplus_multi(q1A, Wm), q2B)

    KA = maxplus_multi(p1A, maxplus_multi(V, p2B))
    KB = maxplus_multi(q1A, maxplus_multi(U, q2B))

    key = KA if np.array_equal(KA, KB) else []

    return key, U, V, A, B, Wm


# Part of the following code is taken from https://github.com/mkotov/tropical3/blob/main/attack.py
import scipy.optimize
import heapq
import multiprocessing


def get_maximum_of_matrix(A):
    """Return the largest entry of a 2-D table and every position holding it.

    A may be a list of lists or a 2-D array; rows may differ in length.

    Returns:
        (best, where): best is the maximum value (starting from A[0][0]), and
        where is the set of (row, col) positions whose value equals best.
    """
    best = A[0][0]
    where = {(0, 0)}

    for r, row in enumerate(A):
        for c, val in enumerate(row):
            if val > best:
                # A new strict maximum discards previously recorded positions.
                best, where = val, {(r, c)}
            elif val == best:
                where.add((r, c))
    return best, where

def get_compressed_covers(F):
    """Enumerate candidate selections of covers that together cover positions.

    F is a list of ``Cover`` objects. ``inds`` is a set of matrix positions;
    ``ijs`` is a set of ``(i, j)`` pairs associated with those positions (in
    ``apply_attack`` they are the positions where ``A^i Wm B^j - U`` attains
    its maximum).

    Steps:
      1. Covers with identical ``inds`` are merged (their ``ijs`` are united).
      2. Covers owning a position that no other cover has are always kept (M).
      3. The positions already covered by M are removed from the remaining
         covers (P). The function then recurses, branching on the largest
         remaining cover: include it (and drop its positions from the rest)
         or skip it.

    Returns:
        A list of selections; each selection is a list of ``Cover`` objects.
        An empty F yields ``[[]]``.

    NOTE(review): ``compress`` updates ``ijs`` sets in place, and
    ``get_sets_without_elements`` builds new ``Cover`` objects that share (not
    copy) the original ``ijs`` sets. The input covers' ``ijs`` can therefore be
    mutated by this call — confirm this is intended.
    """

    def compress(G):
        """Merge covers with equal ``inds`` into the first such cover.

        The first cover's ``ijs`` set is updated in place with the ``ijs`` of
        every later cover that has the same ``inds``.
        """

        def find(H, S):
            """Index in H of a cover whose ``inds`` equals ``S.inds``, or None."""
            for i in range(len(H)):
                if S.inds == H[i].inds:
                    return i
            return None

        H = []
        for S in G:
            i = find(H, S)
            if i is None:
                H.append(S)
            else:
                H[i].ijs.update(S.ijs)
        return H

    def unite_sets(ss):
        """Union of a list of sets (empty set for an empty list)."""
        return set() if len(ss) == 0 else set.union(*ss)

    def get_sets_with_unique_elements(Z):
        """Covers in Z having at least one position not in any other cover of Z.

        "Other" is decided with ``!=``; ``Cover`` defines no ``__eq__``, so
        this compares object identity.
        """
        return list(filter(lambda S: len(S.inds.difference(unite_sets([T.inds for T in filter(lambda T: S != T, Z)]))) != 0, Z))

    def get_sets_without_elements(Z, N):
        """New covers built from Z with positions in N removed.

        Covers left with no positions are dropped. Each new cover shares its
        ``ijs`` set with the original.
        """
        return list(filter(lambda S: len(S.inds) != 0, [Cover(S.inds.difference(N), S.ijs) for S in Z]))

    if len(F) == 0:
        return [[]]
    Z = compress(F)
    # M: covers that must be part of every selection (they own a unique position).
    M = get_sets_with_unique_elements(Z)
    # P: the remaining covers, restricted to positions M does not already cover.
    P = get_sets_without_elements(Z, unite_sets([S.inds for S in M]))
    if len(P) == 0:
        return [M]

    # Branch on the cover spanning the most remaining positions first.
    P.sort(key=lambda S: len(S.inds), reverse=True)

    # X: selections that include P[0]; Y: selections that skip P[0].
    X = [[P[0]] + S for S in get_compressed_covers(
        get_sets_without_elements(P[1:], P[0].inds))]
    Y = get_compressed_covers(P[1:])

    return [M + S for S in X + Y]


def compute_preweights(S, R):
    """Count how heavily each row index and column index is used by S, excluding R.

    For every cover in S, each pair p in its ``ijs`` registers p[0] in the row
    table and p[1] in the column table (with weight 0 if unseen). Every cover
    other than R (compared with ``!=``) also adds 1/len(ijs) to both entries
    for each of its pairs.

    Returns:
        (lins, cols): dicts mapping row index / column index to accumulated weight.
    """
    lins = {}
    cols = {}
    for cover in S:
        contributes = cover != R
        for pair in cover.ijs:
            row, col = pair[0], pair[1]
            lins.setdefault(row, 0)
            cols.setdefault(col, 0)
            if contributes:
                share = 1.0 / len(cover.ijs)
                lins[row] += share
                cols[col] += share
    return lins, cols


def get_weighted_sets(S, solve_linprog):
    """For each cover in S, list its usable (i, j) pairs, highest weight first.

    ``mandatory`` collects the single pair of every cover in S with exactly
    one pair. For a cover with several pairs, each pair p is kept only when
    ``solve_linprog(mandatory + [p])`` is truthy. Kept pairs are ordered by
    ``(lins[p[0]] + 1) * (cols[p[1]] + 1)`` descending, where lins/cols come
    from ``compute_preweights(S, cover)``; ties keep iteration order. A cover
    with one pair contributes ``[that pair]``.

    Returns:
        A list with one entry per cover in S (same order); each entry is a
        list of pairs.
    """
    mandatory = [next(iter(T.ijs)) for T in S if len(T.ijs) == 1]

    W = []
    for T in S:
        if len(T.ijs) > 1:
            # Preweights only rank multi-pair covers, so compute them only here.
            lins, cols = compute_preweights(S, T)
            w = [
                (p, (lins[p[0]] + 1) * (cols[p[1]] + 1))
                for p in T.ijs
                if solve_linprog(mandatory + [p])
            ]
            W.append([p for p, _ in sorted(w, reverse=True, key=lambda x: x[1])])
        else:
            W.append([next(iter(T.ijs))])

    return W


def enumerate_with_queue(E, chunk_size=10):
    """Yield the items of E in an approximately heuristic-sorted order.

    Before each yield, up to ``chunk_size`` further items are pulled from E
    and pushed onto a min-heap keyed by ``heuristics_to_sort``. The item with
    the smallest key seen so far (i.e. the largest sum of squared row/column
    cross counts) is then yielded. Every item of a finite E is eventually
    yielded exactly once. If ``chunk_size`` <= 0, all of E is read before the
    first yield.

    Args:
        E: any iterable of sequences of pairs (each pair indexable as s[0], s[1]).
        chunk_size: number of new items read from E per yielded item.
    """

    def heuristics_to_sort(S):
        """Negative sum over pairs of (row count * column count)**2 within S."""
        def sum_of_cross(S, i, j):
            lins = sum(1 if s[0] == i else 0 for s in S)
            cols = sum(1 if s[1] == j else 0 for s in S)
            return lins * cols

        return -sum(sum_of_cross(S, s[0], s[1])**2 for s in S)

    # Work on a single iterator: a re-iterable such as a list would otherwise be
    # re-read from its start on every pass and the loop would never end.
    E = iter(E)
    q = []
    while True:
        k = 0
        for e in E:
            heapq.heappush(q, (heuristics_to_sort(e), e))
            k += 1
            if k == chunk_size:
                break

        if len(q) == 0:
            break

        r = heapq.heappop(q)
        yield r[1]


def enumerate_product_of_sets(W):
    """Yield every way of choosing one element from each list in W.

    Each combination is a list ``[W[0][t0], W[1][t1], ...]``. Combinations are
    yielded in increasing order of the index sum ``t0 + t1 + ...``, so choices
    made from the front of each list come first. An empty W yields one empty
    combination. If any list in W is empty, nothing is yielded.
    """

    def enumerate_product_of_sets_(W, s, i):
        # Yield combinations of W[i:] whose chosen indices sum to exactly s.
        if i == len(W) - 1:
            if s < len(W[i]):
                yield [W[i][s]]
            else:
                return
        else:
            for t in range(min(s + 1, len(W[i]))):
                for q in enumerate_product_of_sets_(W, s - t, i + 1):
                    yield [W[i][t]] + q

    if len(W) == 0:
        # The product of zero sets has exactly one element (the empty choice);
        # the recursion above would otherwise index W[0] and raise IndexError.
        yield []
        return

    # The largest reachable index sum is sum(len(w) - 1); enumerate 0..that.
    l = sum(len(w) - 1 for w in W) + 1
    for s in range(l):
        yield from enumerate_product_of_sets_(W, s, 0)


class Cover:
    """A group of (i, j) pairs together with the matrix positions it covers.

    Attributes (stored by reference, not copied; get_compressed_covers updates
    ``ijs`` in place and builds covers that share another cover's ``ijs``):
        inds: set of (row, col) matrix positions.
        ijs: set of (i, j) pairs, i.e. exponents of A and B in apply_attack.

    No ``__eq__`` is defined, so ``==``/``!=`` compare identity.
    """

    def __init__(self, inds, ijs):
        self.inds, self.ijs = inds, ijs


def apply_attack(A,B,Wm,U,d1,d2,bounds=(None, None)):
    """Kotov–Ushakov style linear-programming attack on the public value U.

    Searches for real vectors x (length d1) and y (length d2) such that
        x[i] + y[j] + (A^i ⊗ Wm ⊗ B^j)[k, l] <= U[k, l]   for all i, j, k, l,
    with equality at the maximizing positions of the (i, j) pairs chosen by a
    candidate cover. Candidate covers come from get_compressed_covers and are
    tried in order of increasing size.

    Args:
        A, B, Wm, U: 2-D max-plus matrices.
        d1, d2: exponents 0..d1-1 of A and 0..d2-1 of B are considered.
        bounds: variable bounds forwarded to scipy.optimize.linprog.

    Returns:
        [x, y] as two lists of floats, or None if no candidate yields a
        feasible linear program.
    """
    # A^i ⊗ Wm depends only on i and B^j only on j, so build them once instead
    # of recomputing both matrix powers for every (i, j) pair.
    left_factors = [maxplus_multi(fast_power_max_plus(A, i), Wm) for i in range(d1)]
    right_factors = [fast_power_max_plus(B, j) for j in range(d2)]

    # M[(i, j)]: largest entry of A^i ⊗ Wm ⊗ B^j - U.
    # I: one Cover per (i, j) holding the positions where that maximum occurs.
    M = dict()
    I = []
    for i in range(d1):
        for j in range(d2):
            diff = maxplus_multi(left_factors[i], right_factors[j]) - U
            m, inds = get_maximum_of_matrix(diff)
            M[(i, j)] = m
            I.append(Cover(inds, {(i, j)}))

    G = get_compressed_covers(I)

    def solve_linprog(S):
        """Solve the zero-objective feasibility LP in which the pairs in S are tight.

        Variables are x[0..d1-1] followed by y[0..d2-1]. Each pair (i, j)
        gives the row x[i] + y[j] <= -M[(i, j)]; the row is an equality when
        (i, j) is in S. Returns [x, y] as lists on success, otherwise None.
        """

        def make_matrices_for_linprog(S):
            """Build (c, A_ub, b_ub, A_eq, b_eq), using None for empty parts."""
            c = [0 for _ in range(d1 + d2)]
            Aub = []
            bub = []
            Aeq = []
            beq = []

            for i in range(d1):
                for j in range(d2):
                    # Selects x[i] and y[j].
                    v = [1 if k == i or k == d1 +
                         j else 0 for k in range(d1 + d2)]
                    m = -M[(i, j)]
                    if (i, j) in S:
                        Aeq.append(v)
                        beq.append(m)
                    else:
                        Aub.append(v)
                        bub.append(m)

            if Aub == []:
                Aub = None
                bub = None
            if Aeq == []:
                Aeq = None
                beq = None

            return c, Aub, bub, Aeq, beq

        c, Aub, bub, Aeq, beq = make_matrices_for_linprog(S)

        res = scipy.optimize.linprog(
            c, A_ub=Aub, b_ub=bub, A_eq=Aeq, b_eq=beq, bounds=bounds)
        if res.success:
            return [[res.x[i] for i in range(d1)], [res.x[d1 + i] for i in range(d2)]]
        return None

    def enumerate_covers(G):
        """Yield candidate pair selections, trying covers with fewer groups first."""
        for S in sorted(G, key=len):
            W = get_weighted_sets(S, solve_linprog)
            yield from enumerate_with_queue(enumerate_product_of_sets(W))

    for S in enumerate_covers(G):
        result = solve_linprog(S)
        if result:
            return result

    return None


In [3]:
def New_attack(A, B, Wm, U, d, V=None):
    """Recover the shared key from the public matrices.

    For every pair 0 <= i, j <= d computes
        c_ij = -max(A^i ⊗ Wm ⊗ B^j - U)
    and returns the max-plus sum (entrywise maximum) over all pairs of
        c_ij + A^i ⊗ V ⊗ B^j.

    Args:
        A, B, Wm, U: matrices as returned by generate_key_stickels.
        d: maximal exponent of A and B to try.
        V: the second public matrix from generate_key_stickels. If omitted,
            the notebook-global ``V`` is used (kept for backward compatibility
            with callers that relied on it).

    Returns:
        The recovered key as a list of lists (n x n).

    Raises:
        NameError: if V is not given and no global ``V`` exists.
    """
    if V is None:
        try:
            V = globals()["V"]
        except KeyError:
            raise NameError("New_attack needs the public matrix V; pass it explicitly") from None

    n = len(A)
    # Each power is needed for every partner index; compute each only once.
    A_pows = [fast_power_max_plus(A, i) for i in range(d + 1)]
    B_pows = [fast_power_max_plus(B, j) for j in range(d + 1)]

    recovered_key = [[-np.inf for _ in range(n)] for _ in range(n)]
    for i in range(d + 1):
        left_w = maxplus_multi(A_pows[i], Wm)
        left_v = maxplus_multi(A_pows[i], V)
        for j in range(d + 1):
            m, _ = get_maximum_of_matrix(maxplus_multi(left_w, B_pows[j]) - U)
            c_ij = -m
            mat = maxplus_multi(left_v, B_pows[j])
            recovered_key = maxplus_sum(recovered_key, c_ij + np.array(mat))

    return recovered_key

In [4]:
def KU_attack(A, B, Wm, U, d, V=None):
    """Recover the shared key with the Kotov–Ushakov linear-programming attack.

    Uses apply_attack to find coefficient vectors x, y for exponents 0..d and
    returns x(A) ⊗ V ⊗ y(B), where x(.) and y(.) are the max-plus polynomials
    with those coefficients.

    Args:
        A, B, Wm, U: matrices as returned by generate_key_stickels.
        d: maximal exponent of A and B to try.
        V: the second public matrix from generate_key_stickels. If omitted,
            the notebook-global ``V`` is used (kept for backward compatibility
            with callers that relied on it).

    Returns:
        The recovered key (float ndarray), or None if apply_attack found no
        feasible solution.

    Raises:
        NameError: if V is not given and no global ``V`` exists.
    """
    if V is None:
        try:
            V = globals()["V"]
        except KeyError:
            raise NameError("KU_attack needs the public matrix V; pass it explicitly") from None

    res = apply_attack(A, B, Wm, U, d + 1, d + 1, bounds=(None, None))
    if res is None:
        # No candidate cover produced a feasible LP: report failure instead of
        # crashing on res[0].
        return None
    x, y = res
    polynomial_x = np.vstack([np.arange(d + 1), x])
    polynomial_y = np.vstack([np.arange(d + 1), y])

    recovered_key = maxplus_multi(
        apply_polynomial_max_plus(polynomial_x, A),
        maxplus_multi(V, apply_polynomial_max_plus(polynomial_y, B)))

    return recovered_key

In [None]:
import time
import matplotlib.pyplot as plt

# Experiment configuration.
SEED = 0
MATRIX_SIZE = 10
ENTRY_MIN, ENTRY_MAX = -100, 100   # range of the entries of A, B and Wm
COEFF_MIN, COEFF_MAX = -100, 100   # range of the polynomial coefficients
d_values = [2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
num_trials = 3

# All key material is drawn from numpy's global RNG; seed it for reproducibility.
np.random.seed(SEED)

success_rate_new = []
success_rate_ku = []
time_taken_new = []
time_taken_ku = []

for d in d_values:
    success_new = 0
    success_ku = 0
    time_new = 0.0
    time_ku = 0.0

    for _ in range(num_trials):
        # `V` stays a notebook global: New_attack and KU_attack are called
        # without it here and read it from the global namespace.
        key, U, V, A, B, Wm = generate_key_stickels(
            MATRIX_SIZE, ENTRY_MIN, ENTRY_MAX, d, COEFF_MIN, COEFF_MAX)

        start_time = time.perf_counter()
        recovered_key_new = New_attack(A, B, Wm, U, d)
        time_new += time.perf_counter() - start_time
        # np.array_equal accepts lists or arrays and returns False (instead of
        # raising) when either side is [] or None.
        success_new += np.array_equal(key, recovered_key_new)

        start_time = time.perf_counter()
        recovered_key_ku = KU_attack(A, B, Wm, U, d)
        time_ku += time.perf_counter() - start_time
        success_ku += np.array_equal(key, recovered_key_ku)

    success_rate_new.append(success_new / num_trials)
    success_rate_ku.append(success_ku / num_trials)
    time_taken_new.append(time_new / num_trials)
    time_taken_ku.append(time_ku / num_trials)

fig, (ax_rate, ax_time) = plt.subplots(1, 2, figsize=(12, 5))

ax_rate.plot(d_values, success_rate_new, label='The new Attack', marker='o')
ax_rate.plot(d_values, success_rate_ku, label='Kotov-Ushakov Attack', marker='o')
ax_rate.set(xlabel='Polynomial degree D', ylabel='Success Rate', title='Success rate')
ax_rate.legend()
ax_rate.grid(True)

ax_time.plot(d_values, time_taken_new, label='The new Attack', marker='o')
ax_time.plot(d_values, time_taken_ku, label='Kotov-Ushakov Attack', marker='o')
ax_time.set(xlabel='Polynomial degree D', ylabel='Average Time (seconds)', title='Average running time')
ax_time.legend()
ax_time.grid(True)

fig.tight_layout()
plt.show()