# kd-tree探索法(2D)

実装参考
+ [C言語によるkd-treeの実装 Qiita](https://qiita.com/fj-th/items/1bb2dc39f3088549ad6e)  
kd-tree探索法そのものの説明も、上の記事（およびそのリンク先）を読めば分かると思う。


※後からC#に書き換えるのであまりPythonのライブラリに依存しないようにしたい

### ソート関数
numpy配列の行を、特定の列（x座標またはy座標）の値で並べ替えるための補助関数（numpyの仕様に合わせたコード）。

In [None]:
import numpy as np
import matplotlib.pyplot as plt
"""
特定列を基準にソート
axis_xy : xなら0 yなら1
"""
def sortxy(arr:np.ndarray,axis_xy:int,offset:int = 1):
    axis = axis_xy + offset # offset
    return arr[np.argsort(arr[:,axis])]

# test: sort by x (column 1 with the default offset); x values must come out 0, 1, 2
arr = np.array([[1000000, 2, 50],
                [2000000, 0, 100],
                [3000000, 1, 200]])
r = sortxy(arr, 0)
assert r[:, 1].tolist() == [0, 1, 2], r

### メイン

In [None]:
# Column layout of a point row: `offset` leading non-coordinate columns,
# then x, then y.
offset = 1
x_col, y_col = offset, offset + 1
class Tree:
    """2-D kd-tree supporting orthogonal (axis-aligned rectangle) range search.

    Points are rows of a NumPy array; the x / y coordinates are read from
    columns ``x_col`` / ``y_col``. Any other columns are carried along and
    returned unchanged.
    """

    def build(self, points):
        """Build the tree from ``points`` and return the root node.

        The root is also stored as ``self.top_node``. It is ``None`` when
        ``points`` is empty.
        """
        node = Node().set_node(points, len(points) - 1, 0)
        self.top_node = node
        return node

    def search(self, sx: int, tx: int, sy: int, ty: int):
        """Return every stored point with ``sx <= x <= tx`` and ``sy <= y <= ty``.

        Args:
            sx: inclusive lower bound on x.
            tx: inclusive upper bound on x.
            sy: inclusive lower bound on y.
            ty: inclusive upper bound on y.

        Returns:
            A list of point rows (empty if nothing matches or the tree is empty).
        """
        search_results = []

        def _search(v: "Node"):
            # Prune: this subtree's bounding box does not intersect the query.
            if v.right_most < sx or v.left_most > tx or v.bottom_most > ty or v.top_most < sy:
                return
            if v.left_child is None and v.right_child is None:
                if sx <= v.location[x_col] <= tx and sy <= v.location[y_col] <= ty:
                    # Append to the local accumulator (there is no
                    # self.search_results attribute on the tree).
                    search_results.append(v.location)
                return
            for child in (v.left_child, v.right_child):
                if child is None:
                    continue
                if child.is_contained(sx, tx, sy, ty):
                    # Whole subtree lies inside the query: report it without checks.
                    search_results.extend(self.report_subtree(child))
                else:
                    _search(child)

        if self.top_node is not None:
            _search(self.top_node)
        return search_results

    def report_subtree(self, node: "Node" = None, parent: "Node" = None,
                       parent_border_axis: tuple = None, draw_border: tuple = None):
        """Return the locations of all leaves below ``node`` (the root if omitted).

        Args:
            node: subtree root; defaults to ``self.top_node``.
            parent: accepted for backward compatibility; not used.
            parent_border_axis: accepted for backward compatibility; not used.
            draw_border: optional ``(max_x, min_x, max_y, min_y)`` cell of
                ``node``. When given, each internal node's split line is drawn
                with ``plt.plot``, clipped to that node's own cell. The cell is
                then split between the two children, so lines never cross an
                ancestor's split.

        Returns:
            A list of leaf point rows (empty for an empty tree).
        """
        if node is None:
            node = self.top_node
        if node is None:  # empty tree
            return []
        if node.left_child is None and node.right_child is None:
            return [node.location]

        left_cell = right_cell = draw_border
        if draw_border is not None:
            xmax, xmin, ymax, ymin = draw_border
            axis = node.depth % 2          # split axis used when the node was built
            split = node.border[axis]      # border is [x, y]; pick the split coordinate
            if axis == 0:
                # Vertical split line; left child keeps x <= split.
                plt.plot([split, split], [ymin, ymax])
                left_cell = (split, xmin, ymax, ymin)
                right_cell = (xmax, split, ymax, ymin)
            else:
                # Horizontal split line; left child keeps y <= split.
                plt.plot([xmin, xmax], [split, split])
                left_cell = (xmax, xmin, split, ymin)
                right_cell = (xmax, xmin, ymax, split)

        arr = []
        if node.left_child is not None:
            arr += self.report_subtree(node.left_child, node, None, left_cell)
        if node.right_child is not None:
            arr += self.report_subtree(node.right_child, node, None, right_cell)
        return arr
class Node:
    """A node of the 2-D kd-tree.

    Attributes set by ``set_node`` / ``set_leaf``:
        location: point row stored at this node. For an internal node this is
            the median row along the split axis; searches only return leaf
            locations.
        depth: depth in the tree (root = 0); the split axis is ``depth % 2``
            (0 = x, 1 = y).
        border: (internal nodes only) ``[x, y]`` midpoint between the median
            row and the next row along the split axis.
        left_child / right_child: sub-trees, ``None`` for a leaf.
        left_most / right_most / bottom_most / top_most: bounding box of all
            points in this subtree.
    """

    def set_node(self, points: np.ndarray, right: int, depth: int):
        """Build this node and its subtree from ``points[:right + 1]``.

        Args:
            points: array of point rows; coordinates are read from columns
                ``x_col`` / ``y_col``.
            right: index of the last row to use (inclusive).
            depth: depth of this node; selects the split axis (``depth % 2``).

        Returns:
            ``self``, or ``None`` when ``right < 0`` (no points).
        """
        if right < 0:
            return None
        if right == 0:
            return self.set_leaf(points[right], depth)

        axis = depth % 2
        sorted_points = sortxy(points[:right + 1], axis)
        median = right // 2
        # Take coordinates through the shared column constants instead of a
        # hard-coded slice, so the layout can change via `offset`.
        coords = [x_col, y_col]
        self.border = (sorted_points[median][coords]
                       + sorted_points[median + 1][coords]) / 2.0
        self.location = sorted_points[median]
        self.depth = depth
        # Rows 0..median go left, median+1..right go right. Because right >= 1
        # here, both halves are non-empty and both children always exist.
        self.right_child = Node().set_node(sorted_points[median + 1:], right - (median + 1), depth + 1)
        self.left_child = Node().set_node(sorted_points, median, depth + 1)

        # Bounding box of this subtree = union of the two children's boxes.
        children = (self.left_child, self.right_child)
        self.left_most = min(c.left_most for c in children)
        self.right_most = max(c.right_most for c in children)
        self.bottom_most = min(c.bottom_most for c in children)
        self.top_most = max(c.top_most for c in children)
        return self

    def set_leaf(self, location: np.ndarray, depth: int):
        """Make this node a leaf holding ``location``; its bounding box is that point."""
        self.location = location
        self.left_child = None
        self.right_child = None
        self.depth = depth
        self.left_most = location[x_col]
        self.right_most = location[x_col]
        self.top_most = location[y_col]
        self.bottom_most = location[y_col]
        return self

    def is_contained(self, sx: int, tx: int, sy: int, ty: int):
        """Return whether this whole subtree lies inside the query rectangle.

        Each node stores the min/max coordinates of its subtree, so comparing
        those against the bounds is enough. (Adapted from ``is_contained`` in
        the reference C implementation.)
        """
        return not (self.left_most < sx or self.right_most > tx
                    or self.top_most > ty or self.bottom_most < sy)

In [None]:
# test: build a tree from sample points and visualise its partition
points = [[1, 2],
          [2, 4],
          [5, 6],
          [7, 8],
          [3, 10],
          [11, 11],
          [10, 10],
          [0, 8]]
# Prepend a row index so each row is [index, x, y] (matches offset = 1).
points = np.array([[i, *p] for i, p in enumerate(points)])
tree = Tree()
node = tree.build(points)
xs, ys = points[:, x_col], points[:, y_col]
plt.scatter(xs, ys)
plt.show()
plt.scatter(xs, ys)
tree.report_subtree(draw_border=(xs.max(), xs.min(), ys.max(), ys.min()))
plt.show()

### 探索
`search` の引数は順に、x の最小値・x の最大値・y の最小値・y の最大値（いずれも境界値を含む）。

In [None]:
# Query x in [0, 4], y in [1, 4]. reshape keeps `results` 2-D even when the
# search finds nothing, so the column indexing below cannot raise.
results = np.array(tree.search(0, 4, 1, 4)).reshape(-1, points.shape[1])
print(results.shape)
# All points in grey for context, matches drawn on top.
plt.scatter(points[:, x_col], points[:, y_col], c="lightgray")
plt.scatter(results[:, x_col], results[:, y_col])
plt.xlim([points[:, x_col].min() - 1, points[:, x_col].max() + 1])
plt.ylim([points[:, y_col].min() - 1, points[:, y_col].max() + 1])
print(tree.top_node.right_most)  # debug: max x over the whole tree
plt.show()