A small library of binary search trees (BSTs) and kd-trees for experimentation. Include it in your notebook as follows:

```py
%run datasets.ipynb
%run trees_lib.ipynb
dataset = datasets_thurgau_7()
#dataset = datasets_cities_200k()
#dataset = datasets_world_12M()

kdtree = build_kd_tree(sorted_tuples(dataset, ['latitude', 'longitude']), len(dataset))
# All towns in Oberthurgau
lower_left = (47.5, 9.2)
upper_right = (47.6, 9.5)
print("Towns in Oberthurgau:")
for node in search_kd_tree(kdtree, [(lower_left[0], upper_right[0]), (lower_left[1], upper_right[1])]):
    print(node.key, node.value, dataset[node.value]['name'])

```

In [None]:
def walk_tree(node):
    """Yield every node of a binary tree in in-order sequence.

    Each node's left subtree comes first, then the node itself, then its
    right subtree. An empty tree (``None``) yields nothing.
    """
    stack = []
    current = node
    while current is not None or stack:
        # Descend along the left spine, remembering the nodes passed on the way.
        while current is not None:
            stack.append(current)
            current = current.left
        current = stack.pop()
        yield current
        current = current.right


%pip install --quiet graphviz
import graphviz

def render_tree(graph, node, nuller="", dataset=None):
    """Render the subtree rooted at `node` into the graphviz `graph`.

    Args:
        graph: A graphviz ``Graph``/``Digraph`` (anything with ``node`` and
            ``edge`` methods) that receives the nodes and edges.
        node: Root of the subtree to render, or ``None`` for an empty subtree.
        nuller: The left/right path ("l"/"r" characters) from the overall
            root down to `node`; it makes every graphviz node name unique.
        dataset: Optional mapping from a node's value (its id) to a record
            with a ``'name'`` entry, used to show the town in each label.
            When omitted, a global ``towns`` mapping is used if one exists;
            otherwise the town line is left out of the label.

    Returns:
        The graphviz node name created for `node`.
    """
    # Labels are HTML-like, so '&', '<' and '>' in the text must be escaped.
    from html import escape

    if node is None:
        # In order to have separate Nil-nodes, we need artificially named
        # nodes with unique names. The path from the root ('nuller') is used.
        graph.node(nuller, "", shape="point")
        return nuller

    if dataset is None:
        # Backward compatibility: callers previously had to define `towns`.
        dataset = globals().get("towns")

    # Name real nodes by their path as well. Naming them by key would merge
    # nodes with equal keys, which the BST invariant allows. The "node_"
    # prefix keeps the name distinct from nil-node names, which contain only
    # 'l'/'r'. No ':' in the name: graphviz reads it as a node:port separator
    # in edge endpoints.
    name = f"node_{nuller}"
    label = (
        f"<B>key:</B> {escape(str(node.key), quote=False)}"
        f"<BR/><B>id:</B> {escape(str(node.value), quote=False)}"
    )
    if dataset is not None:
        town = escape(str(dataset[node.value]['name']), quote=False)
        label += f"<BR/><B>town:</B> {town}"
    graph.node(name, f"< {label} >")

    left_name = render_tree(graph, node.left, nuller + "l", dataset)
    graph.edge(name, left_name)
    right_name = render_tree(graph, node.right, nuller + "r", dataset)
    graph.edge(name, right_name)
    return name


def tree_graph(tree, title, dataset=None):
    """Return a graphviz ``Digraph`` named `title` that renders `tree`.

    `dataset` is forwarded to `render_tree` to label nodes with town names.
    """
    dot = graphviz.Digraph(title)
    render_tree(dot, tree, dataset=dataset)
    return dot


In [None]:
class BstNode:
    """A binary-search-tree node holding a key/value pair and two children.

    Invariant: All keys in left subtree are <= key, all keys in right subtree are >= key.
    """
    def __init__(self, key, value):
        # Both children start out empty.
        self.left = self.right = None
        self.key, self.value = key, value

def key_id_tuples(dataset, attributes):
    """Yield a ``(key, id)`` pair for each element of `dataset`.

    Args:
        dataset: Mapping from element id to a record whose attribute values
            can be converted with ``float``.
        attributes: Either a single attribute name, giving scalar float keys,
            or a list/tuple of names, giving tuple keys with one float per
            name (for multi-dimensional trees such as kd-trees).

    Yields:
        ``(key, id)`` tuples in the dataset's iteration order.
    """
    # Decide once whether keys are multi-dimensional; accept tuples as well
    # as lists of attribute names.
    multi = isinstance(attributes, (list, tuple))
    for element_id, element in dataset.items():
        if multi:
            key = tuple(float(element[name]) for name in attributes)
        else:
            key = float(element[attributes])
        yield key, element_id

def sorted_tuples(dataset, attributes):
    """Return the ``(key, id)`` tuples of `dataset` as a list sorted by key.

    Keys are built by `key_id_tuples` from `attributes`. Whole tuples are
    compared, so ties on the key are broken by the id.
    """
    # Use the builtin instead of a local variable that shadowed it.
    return sorted(key_id_tuples(dataset, attributes))
        
def build_bst(sorted_tuples, lower=None, upper=None):
    """Build a BST from a list of ``(key, value)`` pairs sorted by key.

    Only the entries from index `lower` to index `upper` (both inclusive)
    are used; the defaults cover the whole list. The middle entry of that
    range becomes the root and each half is built recursively, so the
    resulting tree is height-balanced.

    Returns:
        The root ``BstNode``, or ``None`` when the range is empty.
    """
    first = 0 if lower is None else lower
    last = len(sorted_tuples) - 1 if upper is None else upper
    if first > last:
        return None
    middle = (first + last) // 2
    entry = sorted_tuples[middle]
    root = BstNode(entry[0], entry[1])
    root.left = build_bst(sorted_tuples, first, middle - 1)
    root.right = build_bst(sorted_tuples, middle + 1, last)
    return root

def tree_search(node, lower, upper):
    """Yield, in key order, every node whose key lies in ``[lower, upper]``.

    Subtrees that cannot contain keys in the range are skipped. Because the
    BST invariant allows equal keys on both sides, a node is descended into
    on a side whenever the bound equals its key.
    """
    if node is None:
        return  # empty subtree
    key = node.key
    reaches_left = lower <= key   # smaller keys may still be in range
    reaches_right = key <= upper  # larger keys may still be in range
    if reaches_left:
        yield from tree_search(node.left, lower, upper)
    if reaches_left and reaches_right:
        yield node
    if reaches_right:
        yield from tree_search(node.right, lower, upper)

In [None]:
class KdNode:
    """One point of a kd-tree: a per-dimension indexable key, its value, and two children."""
    def __init__(self, key, value):
        self.key, self.value = key, value
        # Children start empty; insert_node attaches nodes here.
        self.left = self.right = None

from tqdm.auto import tqdm
def insert_node(parent, node, depth=0):
    """Insert `node` into the kd-(sub)tree rooted at `parent`.

    At each level the splitting dimension is ``depth % len(key)``. A node
    whose coordinate on that dimension is <= the current node's goes left,
    otherwise it goes right. This repeats until an empty child slot is found.

    The descent is a loop rather than recursion, so degenerate (very deep)
    trees, e.g. from inserting already-sorted keys, do not hit Python's
    recursion limit.

    Args:
        parent: Root of the subtree to insert into (must not be ``None``).
        node: The new node; its ``key`` must be indexable per dimension.
        depth: Depth of `parent` in the full tree.
    """
    current = parent
    while True:
        dimension_index = depth % len(current.key)
        if node.key[dimension_index] <= current.key[dimension_index]:
            if current.left is None:
                current.left = node
                return
            current = current.left
        else:
            if current.right is None:
                current.right = node
                return
            current = current.right
        depth += 1

def build_kd_tree(tuples, max_count=None):
    """Build a kd-tree by inserting ``(key, value)`` pairs one at a time.

    Args:
        tuples: Iterable of ``(key, value)`` pairs; keys are expected to be
            tuples with one coordinate per dimension.
        max_count: Total number of pairs; used only for the progress-bar
            length. If omitted, ``len(tuples)`` is used when available.

    Returns:
        The root ``KdNode``, or ``None`` if `tuples` is empty.

    Note:
        The first pair becomes the root and the rest are inserted in order,
        so the tree's shape depends on input order. Fully sorted input tends
        to give an unbalanced tree.
    """
    if max_count is None:
        try:
            max_count = len(tuples)
        except TypeError:
            max_count = None  # plain iterator: progress bar without a total

    iterator = iter(tuples)
    missing = object()
    first = next(iterator, missing)
    if first is missing:
        # Previously a bare StopIteration leaked out here.
        return None

    key, value = first
    tree = KdNode(key, value)
    total = None if max_count is None else max_count - 1
    for key, value in tqdm(iterator, total=total):
        insert_node(tree, KdNode(key, value))
    return tree

def is_within_bounds(keys, bounds):
    """Return True if every key lies within its corresponding bounds.

    Args:
        keys: One coordinate per dimension, e.g. ``(47.56638, 9.10588)``.
        bounds: One inclusive ``(lower, upper)`` pair per dimension, e.g.
            ``[(-90, 90), (-180, 180)]``.

    The two are paired with ``zip``, so only as many dimensions as the
    shorter of `keys` and `bounds` are checked.
    """
    return all(lower <= key <= upper for (lower, upper), key in zip(bounds, keys))

def search_kd_tree(node, bounds, depth=0):
    """Yield the nodes of a kd-tree whose keys lie within `bounds`.

    Nodes are yielded in in-order (left subtree, node, right subtree), the
    same order a recursive traversal produces. An explicit stack is used
    instead of recursion, so very deep (unbalanced) trees can be searched
    without hitting Python's recursion limit.

    Args:
        node: Root of the (sub)tree to search, or ``None``.
        bounds: One inclusive ``(lower, upper)`` pair per dimension, e.g.
            ``[(-90, 90), (-180, 180)]`` for (latitude, longitude) keys.
        depth: Depth of `node` in the full tree; the splitting dimension is
            ``depth % len(bounds)``.
    """
    # Nodes (with their depth) whose left side has been handled or pruned,
    # and which still need to be checked and have their right side visited.
    pending = []
    current, level = node, depth
    while current is not None or pending:
        # Walk down the left spine while the search interval extends to the
        # left: insert_node puts keys <= the split value on the left, so the
        # left subtree can only match when lower <= key on this dimension.
        while current is not None:
            pending.append((current, level))
            dimension_index = level % len(bounds)
            lower, upper = bounds[dimension_index]
            if lower <= current.key[dimension_index]:
                current, level = current.left, level + 1
            else:
                current = None

        current, level = pending.pop()
        # For inclusion, all dimensions are checked, not just the split one.
        if is_within_bounds(current.key, bounds):
            yield current

        # Keys greater than the split value live on the right, so it can
        # only match when key <= upper on this dimension.
        dimension_index = level % len(bounds)
        lower, upper = bounds[dimension_index]
        if current.key[dimension_index] <= upper:
            current, level = current.right, level + 1
        else:
            current = None