In [259]:
from sage.modules.free_module_integer import IntegerLattice


p = next_prime(2^16)
n = ceil(log(p, 2))
k = ceil(sqrt(n)) + ceil(log(n, 2))
s = random.randint(1, p - 1)

print(f"p = {p}")
print(f"n = {n}")
print(f"k = {k}")
print(f"s = {s}")

p = 65537
n = 17
k = 10
s = 47728


In [285]:
def represent_in_Zq(x, q):
    return x % q

def msb_k(y, k, p):
    B = p // (2 ** k)
    return y // B

In [286]:
def generate_pairs(N, q, k, s):
    pairs = []
    for i in range(N):
        x_i = random.randint(0, q - 1)
        u_i = msb_k((s * x_i) % q, k, q)
        pairs.append((x_i, u_i))
    return pairs

pairs = generate_pairs(n, p, k, s)
print(pairs)

[(14723, 181), (51525, 630), (30841, 284), (60491, 202), (40345, 680), (30564, 563), (55289, 806), (41187, 880), (54345, 317), (22073, 912), (62215, 736), (36429, 816), (14417, 337), (43728, 378), (8584, 396), (60759, 380), (257, 166)]


In [299]:
def build_integer_lattice(oracle_inputs):
    l = len(oracle_inputs) + 1  # Dimension of the lattice
    basis_vectors = []

    # Add basis vectors for oracle inputs
    for i in range(l-1):
        p_vector = [0] * (l)
        p_vector[i] = p
        basis_vectors.append(p_vector)


    # Add the last basis vector with appropriate values
    A = 1 // (2 ** k)
    last_basis_vector = oracle_inputs + [A]
    basis_vectors.append(last_basis_vector)
    print(basis_vectors)
    # Construct the lattice using IntegerLattice
    return IntegerLattice(basis_vectors)

input_lattice = build_integer_lattice([x_i for x_i, _ in pairs])
print(input_lattice)

[[65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 65537, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [303]:
def find_closest_point(lattice, target_vector, B):
    """
    Find the lattice point closest to the target vector within distance B.

    Parameters:
        lattice (IntegerLattice): The lattice.
        target_vector (list): The target vector.
        B (Integer): The distance threshold.

    Returns:
        list: The lattice point closest to the target vector.
    """
    
    closest_point = lattice.closest_vector(target_vector)
    # Calculate the distance between the closest point and the target vector
    distance = sum(abs(a - b) for a, b in zip(closest_point, target_vector))
    if distance > B:
        return None
    return closest_point

def recover_secret_from_lattice_point(lattice_point, p, k):
    """
    Recover the secret from the lattice point.

    Parameters:
        lattice_point (list): The lattice point.
        p (Integer): The prime modulus.
        k (Integer): The value of k.

    Returns:
        Integer: The recovered secret.
    """
    # Extract the last component of the lattice point
    s = lattice_point[-1]
    # Adjust the secret to be within the range of Z_p
    s = s.mod(p)
    # Adjust the secret to the correct value range
    s = s * (2 ** k) // p
    return s

# Example usage:
# Construct the target vector

B = p // (2 ** k)
oracle_inputs = [x_i for x_i, _ in pairs]
target_vector = [B * ui for ui in oracle_inputs] + [0]
# Find the lattice point closest to the target vector within distance B
closest_point = find_closest_point(input_lattice, target_vector, B)
# Recover the secret from the lattice point
recovered_secret = recover_secret_from_lattice_point(closest_point, p, k)
print("Recovered secret:", recovered_secret)

KeyboardInterrupt: 

In [291]:
def lattice_reduction(G):
    return G.LLL()

reduced_lattice = lattice_reduction(input_lattice)
print(reduced_lattice)

[  4424   8654  -4483 -10410   1097   3050   5899  -7842 -10543  -1687   3511 -12125  -6182   1842  -4890  12586   9336      0]
[ 10395   -198  -7319  -6002  26117 -14610   -968  -9923  12040   4643  -8653  -1010   6836 -18278   9457   6967   5345      0]
[ -2862   1927  -2581  19356   2342  27003   2139  20776   2643   8765   5106 -13340 -11585   7756    971  -8557  -4025      0]
[ 10307   1052  26309  -6506 -10999   2158  12425    425  -5053 -16063  14861  22578  -7855  -6819  -3245  -9875  -7877      0]
[-10714 -11655  30915   4175   3180   9857   8557  14666   -273   2219  11970 -11789  -2746   2472  -6349  13902  -4122      0]
[  7201    487   3769 -18337 -16243  16075  -4901  -3728    838  -5131  13670  15062  -3574  11857  16128 -17467  16974      0]
[  1778  -4403  11679  16141  23047    811  -6962  11366   4266 -19551   4522  -5910    419   5303   3427   9621 -27772      0]
[ -6096  15096 -11955    834  14606 -12143 -13580   7843  13461   1495 -15504   1538 -10799  19268  2570

In [297]:
def get_secret_from_reduced_lattice(reduced_lattice, q):
    shortest_vector = reduced_lattice.row(-1)
    s = (shortest_vector[-2] * (2 ** k)).ceil()
    return represent_in_Zq(s, q)

found_secret = get_secret_from_reduced_lattice(reduced_lattice, p)
print(found_secret)

36998


In [298]:
def verify_secret(s, x_values, u_values, p, k):
    for x, u in zip(x_values, u_values):
        print('expected: ', u)
        print('got: ',msb_k((s * x) % p, k, p))

        if msb_k((s * x) % p, k, p) != u:
            return False
    return True

# Verificar o segredo extraído
t = [x for x, _ in pairs]
B = p // (2 ** k)
u = [B*u for _, u in pairs] + [0]
is_correct = verify_secret(found_secret, t, u, p, k)

if is_correct:
    print("O segredo extraído está correto.")
else:
    print("O segredo extraído está incorreto.")

expected:  11584
got:  680
O segredo extraído está incorreto.
