# Lattice-based cryptography

In [6]:
import numpy as np
from SVP_utils import SVPoint, gram_schmidt, generate_basis_vectors, generate_random_points, augment, find_difference, find_modified_average
import random

In [7]:
# 12x12 integer lattice basis used throughout this notebook.
# Lattice points are computed as np.dot(B, x) for a coefficient vector x,
# so the basis vectors are the *columns* of B.
B = np.array([
 [37, 20, 96, 20, 34, 64, 82, 56, 47, 21, 50, 49],
 [39, 24, 19, 49, 82, 97, 88, 84, 41, 51, 36, 74],
 [19, 56, 37, 73,  4, 12, 72, 18, 46,  8, 54, 94],
 [13, 46, 26,  8, 83, 71, 45, 84, 21, 32, 53, 80],
 [65, 39, 25, 56, 52, 44, 84, 30, 69, 33, 13,  5],
 [59, 56, 90,  1, 42, 58, 90, 92,  2,  6,  7, 80],
 [18, 14, 26, 31, 91, 93, 77, 64, 95, 36, 23,  5],
 [11, 58, 22, 51, 90, 13, 93, 43, 21, 81, 12, 77],
 [42, 65, 99,  6, 23, 43, 94, 30, 37, 66, 34, 66],
 [99, 31, 24, 44, 18, 58, 17, 27, 70, 88, 59, 11],
 [30, 43, 21, 70, 48, 47, 13, 93, 94, 48, 69, 58],
 [ 7, 12, 94, 88, 59, 95, 43, 62, 71, 36, 91, 70]
])

In [8]:
# Upper bound on the length of the shortest vector, by Minkowski's theorem:
#     lambda_1(L) <= sqrt(n) * |det(B)|^(1/n)
# for a full-rank lattice of dimension n. The dimension is read from B rather
# than hard-coded, and abs() stops a negative determinant from turning the
# fractional power into nan.
np.sqrt(B.shape[0]) * abs(np.linalg.det(B)) ** (1 / B.shape[0])

230.47420094416253

In [9]:
def sieve(B, target_norm=72.1, n=100, max_iterations=None, verbose=True):
    """
    Function to sieve for short vectors on a lattice with basis B

    Starts from ``n`` random lattice points. It then repeatedly replaces the
    population via ``augment``, combining points with a randomly chosen
    ``find_difference`` or ``find_modified_average``. It stops once the
    shortest point has norm at most ``target_norm``.

    Parameters:
    B (np.array): lattice basis; lattice points are ``np.dot(B, x)`` for
        coefficient vectors ``x``, i.e. the basis vectors are B's columns
    target_norm (float): stop once a point of norm <= target_norm is found.
        The default 72.1 is the hand-picked threshold the original notebook
        used for its 12-dimensional basis; other bases need their own value.
    n (int): size of the point population kept by the sieve
    max_iterations (int or None): optional cap on the number of sieve
        rounds. None (the default) sieves until target_norm is reached.
        In that case the loop never terminates if target_norm is below the
        length of the lattice's true shortest vector.
    verbose (bool): print the average and shortest norm every round

    Returns:
    SVPoint: the shortest point found; ``.p`` is the lattice vector,
        ``.x`` its coefficients with respect to B and ``.norm`` its length

    Raises:
    ValueError: if n is not positive
    AssertionError: if a point's vector disagrees with B applied to its
        coefficients (internal consistency check on SVP_utils)
    """
    if n < 1:
        raise ValueError(f"n must be positive, got {n}")

    # 1 Generate n random lattice points (l/h are forwarded to
    # generate_random_points unchanged from the original notebook).
    points = generate_random_points(B, n, l=-4, h=5)

    # Take the minimum explicitly instead of trusting points[0]. This is
    # correct whether or not SVP_utils keeps the list sorted by norm.
    best = min(points, key=lambda pt: pt.norm)
    rounds = 0

    # 2 Replace points with combinations of existing ones until a short
    # enough vector appears (or the optional round budget runs out).
    while best.norm > target_norm:
        if max_iterations is not None and rounds >= max_iterations:
            break
        if verbose:
            average = int(sum(pt.norm for pt in points) / len(points))
            print(f"average norm in points: {average}, shortest = {best.norm}")

        # Explicit raise (not `assert`) so the check survives `python -O`.
        expected = np.dot(B, best.x)
        if not (best.p == expected).all():
            raise AssertionError(
                f"point vector {best.p} does not match B @ x = {expected}"
            )

        combine = random.choice([find_difference, find_modified_average])
        points = augment(points, B, n, 0.8, combine)
        best = min(points, key=lambda pt: pt.norm)
        rounds += 1

    return best

# Run the sieve, then double-check the result really is a lattice point:
# its stored vector must equal the basis applied to its coefficients.
SV = sieve(B)
assert (B @ SV.x == SV.p).all()

average norm in points: 1501, shortest = 459.0849594573972
average norm in points: 1186, shortest = 459.0849594573972
average norm in points: 978, shortest = 403.57651071389176
average norm in points: 815, shortest = 252.10513679812237
average norm in points: 703, shortest = 252.10513679812237
average norm in points: 599, shortest = 227.99122790142607
average norm in points: 515, shortest = 212.28047484401387
average norm in points: 441, shortest = 174.69974241537966
average norm in points: 380, shortest = 157.0764145249057
average norm in points: 331, shortest = 150.56892109595526
average norm in points: 286, shortest = 124.06046912695437
average norm in points: 252, shortest = 124.06046912695437
average norm in points: 222, shortest = 110.08632975987527
average norm in points: 198, shortest = 96.55050491841045
average norm in points: 178, shortest = 96.55050491841045
average norm in points: 161, shortest = 89.12911982062877
average norm in points: 147, shortest = 89.12911982062877
av

In [10]:
# Display the short vector returned by the sieve (lattice vector p,
# coefficients x with respect to B, and its norm).
SV

p=[  8.  14.  32. -16.   1. -23. -24. -30.  37.   9.  14.  -2.],
x=[-2.  2. -1.  1. -3.  1.  2.  2. -2.  1.  2. -3.],
norm=72.0832851637604