# Locality Sensitive Hashing: Exploration

This is a redo of Alex Klibisz's [sample notebook](https://github.com/alexklibisz/elastik-nearest-neighbors/blob/master/scratch/lsh-experiments/lsh-explore.ipynb) using interactive `plotly` figures (a `plotly.graph_objs.FigureWidget`; see also [`plotly express`](https://medium.com/@plotlygraphs/introducing-plotly-express-808df010143d)) instead of `matplotlib`.

In [None]:
import plotly.graph_objs as go
import numpy as np
import pandas as pd

from sklearn.neighbors import NearestNeighbors

In [None]:
# 200 reproducible 2-D points, each coordinate drawn from N(5, 1).
rng = np.random.RandomState(33)
X = rng.normal(5, 1, size=(200, 2))
df = pd.DataFrame(X, columns=['x', 'y'])

# Interactive scatter of the data (trace 0 is the one recoloured on hover).
# Equal x/y scaling so that perpendicular lines drawn later also look
# perpendicular on screen.
fig = go.FigureWidget(
    data=[go.Scatter(x=df.x, y=df.y, mode='markers', showlegend=False)],
    layout=go.Layout(
        width=600,
        height=600,
        yaxis=dict(scaleanchor='x', scaleratio=1),
        hovermode='closest',
    ),
)

In [None]:
# Exact (brute-force, Euclidean) 10-nearest-neighbour search on the raw
# points; this is the ground truth an LSH approximation is compared against.
knn = NearestNeighbors(n_neighbors=10, algorithm='brute', metric='euclidean')
knn.fit(X)
# X is passed explicitly as the query, so each point normally comes back as
# its own nearest neighbour.
nbrs_true = knn.kneighbors(X, return_distance=False)

# TODO: once a hash matrix `H` (one row of hash bits per point) exists:
#   print('Proportion positive at each hash: ', H.mean())  # ~0.5 is good
#   nbrs_hash = knn.fit(H).kneighbors(H, return_distance=False)
#   recalls = np.array([len(np.intersect1d(a, b))
#                       for a, b in zip(nbrs_true, nbrs_hash)])
#   print('Recall @10 min, mean, median, max = %.2lf, %.2lf, %.2lf, %.2lf' % (
#       recalls.min(), recalls.mean(), np.median(recalls), recalls.max()))

In [None]:
def hover_fn(trace, points, state):
    """Highlight the hovered point and its exact nearest neighbours.

    Intended as an ``on_hover`` callback for ``fig.data[0]`` (plotly calls it
    as ``callback(trace, points, state)``). Reads the notebook globals ``X``,
    ``knn`` and ``fig``, and optionally ``C``.

    Colours written to ``fig.data[0].marker.color`` (later steps override
    earlier ones):
      * 'blue'   - every point of ``X`` by default;
      * 'orange' - the ``knn`` neighbours of the hovered point;
      * 'green'  - the hovered point itself;
      * 'red'    - rows of ``X`` that also appear as rows of ``C``, if ``C``
                   is defined.
    """
    # The callback may be invoked with an empty selection; there is then
    # nothing to highlight.
    if not points.point_inds:
        return
    point_of_interest = points.point_inds[0]
    neighbor_colors = ['blue'] * len(X)

    # Refit on X in case `knn` was refit on other data (e.g. hashes), but only
    # query the hovered point instead of recomputing neighbours for all of X.
    query = X[[point_of_interest]]
    for n in knn.fit(X).kneighbors(query, return_distance=False)[0]:
        neighbor_colors[n] = 'orange'
    neighbor_colors[point_of_interest] = 'green'

    # NOTE(review): `C` is not defined anywhere in the visible notebook;
    # presumably a collection of candidate points (e.g. an LSH bucket) --
    # confirm. Skip the red highlighting rather than raising NameError.
    candidates = globals().get('C')
    if candidates is not None and len(candidates) > 0:
        candidates = np.atleast_2d(np.asarray(candidates))
        # Compare whole rows: `x in ndarray` is true when ANY single
        # coordinate matches anywhere in the array, not when the row matches.
        in_c = (X[:, None, :] == candidates[None, :, :]).all(axis=-1).any(axis=-1)
        for n in np.flatnonzero(in_c):
            neighbor_colors[n] = 'red'

    # Push colour and opacity to the widget in a single update.
    with fig.batch_update():
        fig.data[0].marker.color = neighbor_colors
        fig.data[0].marker.opacity = 1

In [None]:
def make_lsh_model(nb_tables, nb_bits, nb_dimensions, vector_sample):
    """Build LSH hyperplanes (one per bit, per table) from pairs of samples.

    Each bit is the perpendicular bisector of a pair of sample points (a, b):
    its normal is ``a - midpoint`` and its threshold is
    ``normal . midpoint``, so a vector falls on a's side of the bisector
    exactly when ``vector . normal > threshold``.

    Args:
        nb_tables: number of hash tables.
        nb_bits: number of bits (hyperplanes) per table.
        nb_dimensions: dimensionality of the vectors.
        vector_sample: array of shape ``(2 * nb_tables * nb_bits, nb_dimensions)``.
            Table ``t`` uses rows ``[2*t*nb_bits, 2*(t+1)*nb_bits)``: the
            first ``nb_bits`` rows are the ``a`` points, the next ``nb_bits``
            rows the ``b`` points, paired row by row. Extra rows are ignored.

    Returns:
        ``(all_normals, all_thresholds)``: lists with one entry per table.
        ``all_normals[t]`` has shape ``(nb_bits, nb_dimensions)`` and
        ``all_thresholds[t]`` has shape ``(nb_bits,)``.

    Raises:
        ValueError: if ``vector_sample`` is not 2-D with ``nb_dimensions``
            columns, or has fewer than ``2 * nb_tables * nb_bits`` rows.
    """
    vector_sample = np.asarray(vector_sample)
    if vector_sample.ndim != 2 or vector_sample.shape[1] != nb_dimensions:
        raise ValueError(
            f"vector_sample must have shape (N, {nb_dimensions}), "
            f"got {vector_sample.shape}")
    needed = 2 * nb_tables * nb_bits
    if len(vector_sample) < needed:
        raise ValueError(
            f"vector_sample needs at least {needed} rows, got {len(vector_sample)}")

    all_normals, all_thresholds = [], []
    for table in range(nb_tables):
        start = 2 * nb_bits * table
        vector_sample_a = vector_sample[start:start + nb_bits]
        vector_sample_b = vector_sample[start + nb_bits:start + 2 * nb_bits]
        midpoints = (vector_sample_a + vector_sample_b) / 2
        normals = vector_sample_a - midpoints
        # Row-wise dot product: thresholds[j] = normals[j] . midpoints[j].
        thresholds = (normals * midpoints).sum(axis=1)
        all_normals.append(normals)
        all_thresholds.append(thresholds)
    return all_normals, all_thresholds

In [None]:
def get_lsh_hashes(vec, all_normals, all_thresholds):
    """Hash ``vec`` once per table using hyperplanes from ``make_lsh_model``.

    Bit ``i`` of a table's hash is set when ``vec . normals[i] > thresholds[i]``.

    Args:
        vec: vector of shape ``(nb_dimensions,)``.
        all_normals: list with one ``(nb_bits, nb_dimensions)`` array per table.
        all_thresholds: list with one ``(nb_bits,)`` array per table.

    Returns:
        dict mapping table index (0, 1, ...) to that table's integer hash,
        in ``[0, 2 ** nb_bits)``. Empty when there are no tables.
    """
    vec = np.asarray(vec)
    hashes = dict()
    for table, (normals, thresholds) in enumerate(zip(all_normals, all_thresholds)):
        dot = vec.dot(np.asarray(normals).T)  # shape (nb_bits,)
        # Pack the boolean bits into an int: bit i contributes 2 ** i.
        hashes[table] = sum(1 << i for i, above in enumerate(dot > thresholds) if above)
    return hashes

In [None]:
# Smoke test: 10 tables x 2 bits over 2-D vectors, built from points drawn
# from N(0, 3), then hash one random query vector. Uses the unseeded global
# NumPy RNG, so the values differ between runs.
nb_tabs, nb_bits, nb_dims = 10, 2, 2
sample_shape = (2 * nb_tabs * nb_bits, nb_dims)
vector_sample = np.random.normal(0, 3, sample_shape)
all_normals, all_thresholds = make_lsh_model(
    nb_tabs, nb_bits, nb_dims, vector_sample)

vec = np.random.normal(0, 3, (nb_dims,))
hashes = get_lsh_hashes(vec, all_normals, all_thresholds)

In [None]:
# Display the random query vector hashed in the previous cell.
vec

In [None]:
# Display the dict returned by get_lsh_hashes for `vec`.
hashes

In [None]:
# Visualise the LSH hyperplanes: sample 2 * nb_tables * bits distinct points
# of X, pair them up, and draw each pair's perpendicular bisector (one bit).
bits = 2
nb_tables = 1
nb_dimensions = X.shape[-1]
# Randomly sample distinct rows of X (same RNG draw as choosing from arange).
vector_sample = X[rng.choice(len(X), size=2 * nb_tables * bits, replace=False)]

# Reset; this cell does not fill these in.
all_normals, all_thresholds = [], []

# Within each table's chunk of 2 * bits rows, the first `bits` rows are the p
# points and the next `bits` rows the q points; row j of p pairs with row j
# of q. Keep every pair (not just the first), so every bit gets drawn.
vector_sample_p, vector_sample_q = [], []
for i in range(0, len(vector_sample), 2 * bits):
    vector_sample_p.extend(vector_sample[i:i + bits])
    vector_sample_q.extend(vector_sample[i + bits:i + 2 * bits])

for p, q in zip(vector_sample_p, vector_sample_q):
    # This is effectively the only information that needs to be stored.
    m = (p + q) / 2  # Midpoint.
    n = m - q        # Normal vector (points from q towards p).

    fig.add_scatter(x=[p[0], q[0]], y=[p[1], q[1]], mode='markers')

    # Endpoints of the boundary line {z : n . z == n . m}.
    if np.isclose(n[1], 0):
        # Vertical boundary x = (n . m) / n[0]; solving for y would divide by 0.
        Z = np.column_stack([np.full(2, n.dot(m) / n[0]), np.linspace(0, 10, 2)])
    else:
        Z = np.vstack([np.linspace(0, 10, 2), np.zeros(2)]).T
        Z[:, 1] = (n[0] * Z[:, 0] - n.dot(m)) / (-1 * n[1])
    fig.add_scatter(x=Z[:, 0], y=Z[:, 1], mode='lines',
                    line={'dash': 'dash', 'color': 'red'}, showlegend=False)

fig.data[0].on_hover(hover_fn)
fig

In [None]:
# Display the points of X sampled in the previous cell.
vector_sample

In [None]:
# Display the `p` endpoints of the sampled point pairs.
vector_sample_p