# Lattice based cryptography

In [13]:
import numpy as np
from SVP_utils import SVPoint, gram_schmidt, generate_basis_vectors, generate_random_points, augment, find_difference, find_modified_average, find_modified_average_random
import random

In [14]:
B = np.array([
 [37, 20, 96, 20, 34, 64, 82, 56, 47, 21, 50, 49],
 [39, 24, 19, 49, 82, 97, 88, 84, 41, 51, 36, 74],
 [19, 56, 37, 73,  4, 12, 72, 18, 46,  8, 54, 94],
 [13, 46, 26,  8, 83, 71, 45, 84, 21, 32, 53, 80],
 [65, 39, 25, 56, 52, 44, 84, 30, 69, 33, 13,  5],
 [59, 56, 90,  1, 42, 58, 90, 92,  2,  6,  7, 80],
 [18, 14, 26, 31, 91, 93, 77, 64, 95, 36, 23,  5],
 [11, 58, 22, 51, 90, 13, 93, 43, 21, 81, 12, 77],
 [42, 65, 99,  6, 23, 43, 94, 30, 37, 66, 34, 66],
 [99, 31, 24, 44, 18, 58, 17, 27, 70, 88, 59, 11],
 [30, 43, 21, 70, 48, 47, 13, 93, 94, 48, 69, 58],
 [ 7, 12, 94, 88, 59, 95, 43, 62, 71, 36, 91, 70]
])

In [35]:
# Upper bound on the length of the shortest vector by minkowski's theorem
np.sqrt([12])[0] * (np.linalg.det(B) ** (1/12))


# Upper bound on the length of the shortest vector by https://www.latticechallenge.org/svp-challenge/#
from math import factorial
def gamma(n):
    return factorial(int(n-1))

n = 12
g_term = ( gamma( (n/2+1) ** 1/n ) ) / np.sqrt(np.pi)
1.05 * g_term * (np.linalg.det(B) ** (1/12))

39.41359572854765

In [16]:
def sieve(B):
    """
    Function to sieve for short vectors on a lattice with basis B
    
    Parameters:
    B (np.array): Input lattice basis vectors
    
    Returns:
    np.array: short vector on the lattice
    int: length of the short vector
    """
    # 0 Orthogonalise the basis with gram-schmidt
    # B_gs = gram_schmidt(B)

    # 1 Generate n points on the lattice
    n = 500

    points = generate_random_points(B, n, l=-4, h=5)

    # 2 Replace the points with new ones until done TODO when is this done?
    # for i in range(40):
    while points[0].norm > 72.1:
        
        print(f"average norm in points: {int(sum([p.norm for p in points])/len(points))}, shortest = {points[0].norm}")
        # print(points[0])
        # print([p.norm for p in points])
        # if not (points[0].p == np.dot(B, points[0].x)).all():
        #     print(points[0].p, np.dot(B, points[0].x))
        assert (points[0].p == np.dot(B, points[0].x)).all()

        f = random.choice([find_difference, find_modified_average])
        points = augment(points, B, n, 0.85, f)
        # points = augment(points, B, n, 0.8, find_modified_average)
        # points = augment(points, B, n, 0.8, find_difference)

    return points[0]

SV = sieve(B)
assert (SV.p == np.dot(B, SV.x)).all()
print(SV)

average norm in points: 1597, shortest = 382.358470548254
average norm in points: 1307, shortest = 382.358470548254
average norm in points: 1118, shortest = 339.12976867270146
average norm in points: 1019, shortest = 339.12976867270146
average norm in points: 943, shortest = 339.12976867270146
average norm in points: 848, shortest = 262.7565413077284
average norm in points: 798, shortest = 262.7565413077284
average norm in points: 730, shortest = 198.91958174096385
average norm in points: 666, shortest = 198.91958174096385
average norm in points: 626, shortest = 198.91958174096385
average norm in points: 570, shortest = 198.91958174096385
average norm in points: 540, shortest = 198.91958174096385
average norm in points: 512, shortest = 198.91958174096385
average norm in points: 472, shortest = 159.70284906663375
average norm in points: 447, shortest = 159.70284906663375
average norm in points: 415, shortest = 159.70284906663375
average norm in points: 382, shortest = 142.70949512909084