# Speed Tests for KNN Search

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

from tbp.monty.frameworks.utils.logging_utils import load_stats

In [None]:
# IPython magic: select matplotlib's interactive "notebook" backend for this session.
%matplotlib notebook

In [None]:
# Root folders for pretrained models and for experiment logs.
pretrain_path = os.path.expanduser("~/tbp/results/monty/pretrained_models/")
log_path = os.path.expanduser("~/tbp/results/monty/projects/monty_runs/")

# Pretrained model folder and the experiment run whose logs are loaded.
pretrained_dict = os.path.join(
    pretrain_path, "pretrained_ycb_v3/touch_1lm_10distinctobj/pretrained/"
)
exp_name = "base_config_10distobj_touch/"
exp_path = os.path.join(log_path, exp_name)

# The flags request eval and detailed stats but not training stats.
train_stats, eval_stats, detailed_stats, lm_models = load_stats(
    exp_path,
    load_train=False,
    load_eval=True,
    load_detailed=True,
    pretrained_dict=pretrained_dict,
)

In [None]:
# Point locations (`.pos`) of the pretrained 'mug' model, copied into a float64 array.
model_locs = np.array(
    lm_models["pretrained"][0]["mug"].pos,
    dtype=np.float64,
)

In [None]:
# Query points: every model location shifted by 0.001 along each axis.
search_locations = (model_locs + 0.001).astype(np.float64)

## KD Tree Search

In [None]:
import time
from scipy.spatial import KDTree

In [None]:
# Time the construction of the KD tree over the model locations.
start_time = time.time()
kd_tree = KDTree(model_locs, leafsize=40)
end_time = time.time()
build_seconds = end_time - start_time
print(f"building KD Tree took {build_seconds}s")

In [None]:
# Benchmark: run the same 10-nearest-neighbour query 1000 times.
start_time = time.time()
for i in range(1000):
    _, nearest_node_ids_kd = kd_tree.query(
        search_locations,
        k=10,
        p=2,  # Euclidean distance
        workers=1,  # using more than 1 worker slows down run on lambda.
    )
end_time = time.time()
elapsed = np.round(end_time - start_time, 3)
print(f"querying 1000x KD Tree took {elapsed}s")

In [None]:
# Display the neighbour indices returned by the last KD-tree query.
nearest_node_ids_kd

## Lookup Table (Built from KD Tree)

In [None]:
# Report the per-axis extent of the model points.
mins = model_locs.min(axis=0)
maxs = model_locs.max(axis=0)
print(f"mins: {mins}, maxs: {maxs}")
# Replace the measured extent with fixed, symmetric bounds of +/-0.06 on every axis.
mins = [-0.06] * 3
maxs = [0.06] * 3
print(f"new mins: {mins}, maxs: {maxs}")

In [None]:
# One evenly spaced range per axis, each with num_bins samples between the chosen bounds.
num_bins = 121
x_range, y_range, z_range = (
    np.linspace(mins[axis], maxs[axis], num_bins) for axis in range(3)
)

print(f"using ranges {x_range}")

In [None]:
# Generate grid for lookup table
xs, ys, zs = np.meshgrid(x_range, y_range, z_range) # each of shape (num_bins, num_bins, num_bins)
grid_locs = np.stack([xs, ys, zs], axis=-1) # shape=(num_bins, num_bins, num_bins, 3)
grid_locs = grid_locs.reshape((num_bins * num_bins * num_bins,3))# shape=(num_bins * num_bins * num_bins, 3)
# grid_locs = np.round(grid_locs,2)

In [None]:
# Visual check: model points and the shifted query points in one 3D scatter plot.
plt.figure()
ax = plt.subplot(1, 1, 1, projection='3d')
ax.scatter(model_locs[:, 0], model_locs[:, 1], model_locs[:, 2])
# The query points live in `search_locations`; the former name `search_locs`
# is not defined in the visible notebook and raised a NameError here.
ax.scatter(search_locations[:, 0], search_locations[:, 1], search_locations[:, 2])
plt.show()

### Using Strings as Keys

In [None]:
def coords_to_keys(coords, rounding_factor=3):
    """Convert coordinates into comma-separated string keys.

    Args:
        coords: 2D array-like with one point per row.
        rounding_factor: Number of decimal places each coordinate is
            rounded to before it is turned into text.

    Returns:
        A list with one string per row, the rounded coordinates joined
        by commas (e.g. "0.013,-0.02,0.05").
    """
    # Adding 0.0 turns a rounded negative zero (-0.0) into 0.0. Without it a
    # small negative coordinate would produce "-0.0" and miss the "0.0" key.
    rounded_locs = np.round(coords, rounding_factor) + 0.0
    rounded_locs = rounded_locs.astype(str).tolist()
    search_keys = [','.join(row) for row in rounded_locs]
    return search_keys

In [None]:
# Build the string-keyed lookup table and time each stage separately.
start_time = time.time()
kd_tree = KDTree(model_locs, leafsize=40)
imt = time.time()
print(f"built tree in {imt - start_time}s")

# Ten nearest model points for every grid location.
_, nearest_node_ids = kd_tree.query(
    grid_locs,
    k=10,
    p=2,  # Euclidean distance
    workers=1,  # using more than 1 worker slows down run on lambda.
)
imt2 = time.time()
print(f"queried tree after {imt2 - imt}s")

# Turn the grid coordinates into string keys.
grid_locs_keys = coords_to_keys(grid_locs, rounding_factor=3)
imt3 = time.time()
print(f"turned locs to str in {imt3-imt2}s")

# Map each key to its row of neighbour indices.
hash_table = dict(zip(grid_locs_keys, nearest_node_ids))
end_time = time.time()
imt4 = time.time()
table_seconds = np.round(imt4 - imt3, 3)
total_seconds = np.round(end_time - start_time, 3)
print(f"building hash table took {table_seconds}s")
print(f"overall time: {total_seconds}")

In [None]:
# Benchmark: 1000 rounds of key conversion followed by table lookup.
# Keys missing from the table fall back to a row of ten zeros.
start_time = time.time()
for i in range(1000):
    search_keys = coords_to_keys(search_locations, 3)
    # Test membership on the dict itself rather than `.keys()`, so this loop
    # differs from the "without key conversion" benchmark only by the
    # coords_to_keys call.
    nearest_nodes = [hash_table[key] if key in hash_table else np.zeros(10) for key in search_keys]
end_time = time.time()
print(f"querying 1000x hash table took {np.round(end_time-start_time,3)}s")

In [None]:
# Per query point: share of the KD-tree neighbours that the lookup table also returned.
correct_set = [
    len(set(kd_ids).intersection(table_ids)) / 10
    for kd_ids, table_ids in zip(nearest_node_ids_kd, nearest_nodes)
]
# The mean is a fraction in [0, 1]; scale it by 100 so the printed "%" is accurate.
print(f"{np.round(100 * np.mean(correct_set), 1)}% correct items in set (ignoring order)")

In [None]:
# Benchmark: 1000 rounds of table lookup only, reusing the existing search_keys.
start_time = time.time()
for i in range(1000):
    nearest_nodes = [
        hash_table[key] if key in hash_table else np.zeros(10)
        for key in search_keys
    ]
end_time = time.time()
elapsed = np.round(end_time - start_time, 3)
print(f"querying 1000x hash table without key conversion took {elapsed}s")

### Using Tuples as Keys

In [None]:
def coords_to_keys2(coords, rounding_factor=3):
    """Convert coordinates into 3-tuples usable as dictionary keys.

    Args:
        coords: 2D array-like with one point per row; only the first
            three columns of each row are used.
        rounding_factor: Number of decimal places to round to.

    Returns:
        A list with one (x, y, z) tuple of rounded values per row.
    """
    snapped = np.round(coords, rounding_factor)
    keys = []
    for point in snapped:
        keys.append((point[0], point[1], point[2]))
    return keys

In [None]:
# Tuple keys for every grid location, rounded to three decimals.
grid_locs_keys = coords_to_keys2(grid_locs, rounding_factor=3)

In [None]:
# Lookup table: grid key -> row of nearest model-point indices for that grid location.
hash_table = {key: node_ids for key, node_ids in zip(grid_locs_keys, nearest_node_ids)}

In [None]:
# Display the lookup table built in the previous cell.
hash_table

In [None]:
# Build the tuple-keyed lookup table and time each stage separately.
start_time = time.time()
kd_tree = KDTree(model_locs, leafsize=40)
imt = time.time()
print(f"built tree in {imt - start_time}s")

# Ten nearest model points for every grid location.
_, nearest_node_ids = kd_tree.query(
    grid_locs,
    k=10,
    p=2,  # Euclidean distance
    workers=1,  # using more than 1 worker slows down run on lambda.
)
imt2 = time.time()
print(f"queried tree after {imt2 - imt}s")

# Turn the grid coordinates into tuple keys.
grid_locs_keys = coords_to_keys2(grid_locs, rounding_factor=3)
imt3 = time.time()
print(f"turned locs to tuples in {imt3-imt2}s")

# Map each key to its row of neighbour indices.
hash_table = dict(zip(grid_locs_keys, nearest_node_ids))

end_time = time.time()
imt4 = time.time()
table_seconds = np.round(imt4 - imt3, 3)
total_seconds = np.round(end_time - start_time, 3)
print(f"building hash table took {table_seconds}s")
print(f"overall time: {total_seconds}")

In [None]:
# NOTE(review): this cell performs the same steps as the preceding cell
# (KD tree -> grid query -> tuple keys -> dict); it looks like a repeated run.
start_time = time.time()
kd_tree = KDTree(model_locs, leafsize=40)
imt = time.time()
print(f"built tree in {imt - start_time}s")

# Stage 2: ten nearest model points for each grid location.
_, nearest_node_ids = kd_tree.query(
    grid_locs,
    k=10,
    p=2,  # Euclidean distance
    workers=1,  # using more than 1 worker slows down run on lambda.
)
imt2 = time.time()
print(f"queried tree after {imt2 - imt}s")

# Stage 3: tuple keys for the grid locations.
grid_locs_keys = coords_to_keys2(grid_locs, rounding_factor=3)
imt3 = time.time()
print(f"turned locs to tuples in {imt3-imt2}s")

# Stage 4: key -> neighbour-index row.
hash_table = dict(zip(grid_locs_keys, nearest_node_ids))

end_time = time.time()
imt4 = time.time()
table_seconds = np.round(imt4 - imt3, 3)
total_seconds = np.round(end_time - start_time, 3)
print(f"building hash table took {table_seconds}s")
print(f"overall time: {total_seconds}")

In [None]:
# Benchmark: 1000 rounds of tuple-key conversion followed by table lookup.
# Keys missing from the table fall back to a row of ten zeros.
start_time = time.time()
for i in range(1000):
    search_keys = coords_to_keys2(search_locations, 3)
    # Test membership on the dict itself rather than `.keys()`, so this loop
    # differs from the "without key conversion" benchmark only by the
    # coords_to_keys2 call.
    nearest_nodes = [hash_table[key] if key in hash_table else np.zeros(10) for key in search_keys]
end_time = time.time()
print(f"querying 1000x hash table took {np.round(end_time-start_time,3)}s")

In [None]:
# Per query point: share of the KD-tree neighbours that the lookup table also returned.
correct_set = [
    len(set(kd_ids).intersection(table_ids)) / 10
    for kd_ids, table_ids in zip(nearest_node_ids_kd, nearest_nodes)
]
# The mean is a fraction in [0, 1]; scale it by 100 so the printed "%" is accurate.
print(f"{np.round(100 * np.mean(correct_set), 1)}% correct items in set (ignoring order)")

In [None]:
# Element-wise comparison: a neighbour only counts as correct at the same position.
# Convert the list of rows to an array once instead of once per shape lookup.
table_ids = np.asarray(nearest_nodes)
num_results = table_ids.shape[0] * table_ids.shape[1]
# The ratios are fractions in [0, 1]; scale them by 100 so the printed "%" is accurate.
print(f"{np.round(100 * np.sum(nearest_node_ids_kd == table_ids) / num_results, 1)}% correct")
print(f"{np.round(100 * np.sum(nearest_node_ids_kd != table_ids) / num_results, 1)}% wrong (including order)")

In [None]:
# Benchmark: 1000 rounds of table lookup only, reusing the existing search_keys.
start_time = time.time()
for i in range(1000):
    nearest_nodes = [
        hash_table[key] if key in hash_table else np.zeros(10)
        for key in search_keys
    ]
end_time = time.time()
elapsed = np.round(end_time - start_time, 3)
print(f"querying 1000x hash table without key conversion took {elapsed}s")

## LSH (Locality-Sensitive Hashing)

In [None]:
from lshashpy3 import LSHash

In [None]:
# create 32-bit hashes for input data of 3 dimensions:
lsh = LSHash(32, 3)

# index vector
for i, loc in enumerate(model_locs):
    lsh.index(loc, extra_data=i)

In [None]:
# Display the keys of the first LSH hash table.
lsh.hash_tables[0].keys()

In [None]:
# Time one LSH query per search location and collect the ids of the returned hits.
start_time = time.time()
nn = np.zeros((search_locations.shape[0], 10))
for i, query_point in enumerate(search_locations):
    result = lsh.query(query_point, num_results=10, distance_func="euclidean")
    # hit[0][1] is presumably the `extra_data` id stored at index time -
    # TODO confirm against the lshashpy3 result layout.
    hit_ids = [hit[0][1] for hit in result]
    try:
        nn[i] = hit_ids
    except ValueError:
        # The hits could not be written into the 10-wide row (e.g. fewer than
        # 10 were returned); report them and leave the row as zeros.
        print(hit_ids)
end_time = time.time()
print(f"querying LSH took {end_time-start_time}s")