# KD Tree

A KD-Tree (k-dimensional tree) is a binary space-partitioning data structure that supports efficient nearest-neighbor and range searches in multidimensional space.
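To make the idea concrete, here is a minimal pure-Python sketch of how such a tree can be built and searched. It is an illustration only and not how SciPy implements `KDTree`; the names `Node`, `build`, and `nearest` are hypothetical, and the sketch assumes Python 3.8+ for `math.dist`. Each level splits the points at the median along one axis, and the search skips a subtree whenever the splitting plane is farther away than the best match found so far.

```python
import math
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Node:
    point: Tuple[float, ...]   # the point stored at this node
    axis: int                  # the coordinate axis used to split at this node
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def build(points, depth=0):
    """Recursively split the points at the median, cycling through the axes."""
    if not points:
        return None
    axis = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    return Node(pts[mid], axis,
                build(pts[:mid], depth + 1),
                build(pts[mid + 1:], depth + 1))


def nearest(node, target, best=None):
    """Return (distance, point) of the stored point closest to target."""
    if node is None:
        return best
    d = math.dist(node.point, target)
    if best is None or d < best[0]:
        best = (d, node.point)
    # Descend first into the side of the splitting plane that contains target
    diff = target[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest(near, target, best)
    # Visit the other side only if the plane is closer than the best match so far
    if abs(diff) < best[0]:
        best = nearest(far, target, best)
    return best


root = build([(3, 6), (3, 4), (5, 6), (7, 8)])
best_distance, best_point = nearest(root, (2, 3))
print(best_point, best_distance)
```

The pruning step is what makes the search efficient: whole subtrees are skipped without computing any distances to the points inside them.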

In [2]:
import numpy as np 
from scipy.spatial import KDTree

- **KDTree creation:** The KDTree is built from an array of points of shape (n, k). Each row holds the k coordinates of one point.

- **Query:** The `.query()` method returns the distance to the nearest neighbor of a query point and the index of that neighbor in the original points array. The example below queries the tree for the nearest neighbor of the point [2, 3].

### Search for the nearest neighbor of a given point

In [7]:
points = np.array([[3, 6], [3, 4], [5, 6], [7, 8]])
tree = KDTree(points)

query_point = [2, 3]

distance, index = tree.query(query_point)

print(f"The nearest neighbor to {query_point} is at index {index}, and the distance is {distance}")

The nearest neighbor to [2, 3] is at index 1, and the distance is 1.4142135623730951
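As a sketch of two related points, the block below asks `.query()` for more than one neighbor through its `k` parameter and checks the answer against a brute-force NumPy computation. The variable names `brute_distances` and `brute_order` are introduced here for illustration only.

```python
import numpy as np
from scipy.spatial import KDTree

points = np.array([[3, 6], [3, 4], [5, 6], [7, 8]])
tree = KDTree(points)
query_point = np.array([2, 3])

# With k=2, query returns two arrays: distances and indices, nearest first
distances, indices = tree.query(query_point, k=2)

# Brute-force check: compute the distance to every point and sort
brute_distances = np.linalg.norm(points - query_point, axis=1)
brute_order = np.argsort(brute_distances)[:2]

print("KDTree:     ", indices, distances)
print("Brute force:", brute_order, brute_distances[brute_order])
```

Both approaches agree on indices 1 and 0. The brute-force version computes the distance to every point, which is what the tree avoids on larger datasets.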


### Find all points within a given radius of a specific point

In [8]:
points = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])

tree = KDTree(points)

query_point = [2, 3]
radius = 3.0

indices_within_radius = tree.query_ball_point(query_point, radius)

points_within_radius = points[indices_within_radius]

print(f"Points within radius {radius} of {query_point}:")
print(points_within_radius)

Points within radius 3.0 of [2, 3]:
[[1 2]
 [3 4]]


Additional information:

You can also search with a different distance metric by setting the `p` parameter of `query_ball_point`, which selects the Minkowski p-norm:

- `p=2`: Euclidean distance (default)
- `p=1`: Manhattan distance (L1 norm)
- `p=np.inf`: Maximum coordinate difference (Chebyshev distance)
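A short sketch of how the choice of `p` changes the result. The radius of 1.5 is chosen here purely for illustration, so that the metrics disagree on the same data; the `results` dictionary is likewise just a convenience for this example.

```python
import numpy as np
from scipy.spatial import KDTree

points = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
tree = KDTree(points)

query_point = [2, 3]
radius = 1.5  # illustrative value, smaller than in the example above

results = {}
for name, p in [("Manhattan", 1), ("Euclidean", 2), ("Chebyshev", np.inf)]:
    # query_ball_point returns a list of indices; sort it for stable printing
    results[name] = sorted(tree.query_ball_point(query_point, radius, p=p))
    print(f"{name} (p={p}): indices {results[name]}")
```

The points [1, 2] and [3, 4] are at Manhattan distance 2 from the query point, so they fall outside the radius for `p=1`, while their Euclidean distance (about 1.41) and Chebyshev distance (1) are both within it.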

### Find the indices of points within a particular radius for all points

In [13]:
points = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])

tree = KDTree(points)

radius = 3

families_indices = []
families_points = []

# Loop over every point together with its index
for i, point in enumerate(points):

    # Find the indices of all points within the given radius of this point
    indices_within_radius = tree.query_ball_point(point, radius)

    # Remove the index of the point itself from the list
    indices_within_radius = [idx for idx in indices_within_radius if idx != i]

    # Append the indices of the family (excluding the point itself)
    families_indices.append(indices_within_radius)

    # Append the family points themselves to the list of families
    family_points = points[indices_within_radius]
    families_points.append(family_points)

families_indices = np.array(families_indices, dtype=object)
families_points = np.array(families_points, dtype=object)

print("Families of indices within radius: ")
print(families_indices)

print("\nFamilies of points within radius: ")
for family in families_points:
    print(family)

Families of indices within radius: 
[list([1]) list([0, 2]) list([1, 3]) list([2, 4]) list([3])]

Families of points within radius: 
[[3 4]]
[[1 2]
 [5 6]]
[[3 4]
 [7 8]]
[[ 5  6]
 [ 9 10]]
[[7 8]]
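The explicit loop is not strictly required. As a sketch, `query_ball_point` also accepts an array of query points and returns one list of indices per point, and `query_pairs` lists every pair of points closer than the radius. The names `all_neighbors`, `families`, and `pairs` are chosen here for illustration.

```python
import numpy as np
from scipy.spatial import KDTree

points = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
tree = KDTree(points)
radius = 3

# Querying with all points at once gives an object array holding
# one list of neighbor indices per query point
all_neighbors = tree.query_ball_point(points, radius)

# Each list contains the query point itself, so drop index i from row i
families = [sorted(j for j in neighbors if j != i)
            for i, neighbors in enumerate(all_neighbors)]
print(families)

# query_pairs reports each close pair once, as a set of (i, j) with i < j
pairs = tree.query_pairs(radius)
print(sorted(pairs))
```

Keeping `families` as a plain Python list of lists also avoids the ragged object arrays, whose shape depends on whether every family happens to have the same size.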
