# 二分探索木 (Binary Search Tree)

Reference (the binary search tree implementation this notebook follows): https://interactivepython.org/courselib/static/pythonds/Trees/SearchTreeImplementation.html

In [1]:
class TreeNode:
    """A single node of a binary search tree.

    Each node stores one ``key``/``val`` pair plus links to its ``left`` and
    ``right`` children and to its ``parent`` (``None`` for a root or for a
    missing child).
    """

    def __init__(self, key, val, left=None, right=None, parent=None):
        self.key = key
        self.val = val
        self.left = left
        self.right = right
        self.parent = parent

    def __iter__(self):
        """Yield the keys of the subtree rooted here in in-order sequence
        (left subtree, this node, right subtree)."""
        if self.left is not None:
            yield from self.left
        yield self.key
        if self.right is not None:
            yield from self.right

    def has_left_child(self):
        """Return True if this node has a left child."""
        return self.left is not None

    def has_right_child(self):
        """Return True if this node has a right child."""
        return self.right is not None

    def is_left_child(self):
        """Return True if this node is the left child of its parent."""
        return self.parent is not None and self.parent.left is self

    def is_right_child(self):
        """Return True if this node is the right child of its parent."""
        return self.parent is not None and self.parent.right is self

    def is_root(self):
        """Return True if this node has no parent."""
        return self.parent is None

    def is_leaf(self):
        """Return True if this node has no children."""
        return self.left is None and self.right is None

    def has_any_children(self):
        """Return True if this node has at least one child."""
        return self.left is not None or self.right is not None

    def has_both_children(self):
        """Return True if this node has both a left and a right child."""
        return self.left is not None and self.right is not None

    def replace_node(self, key, value, left, right):
        """Overwrite this node's payload and children in place.

        The node object itself (and its parent link) is kept; the new
        children, if any, are re-parented to this node.
        """
        self.key = key
        # The attribute is named ``val`` everywhere else in the class; writing
        # to ``self.value`` would silently leave the old value in place.
        self.val = value
        self.left = left
        self.right = right

        if self.left is not None:
            self.left.parent = self
        if self.right is not None:
            self.right.parent = self

    def find_successor(self):
        """Return the node that follows this one in in-order sequence.

        Returns ``None`` when this node is the last one in the tree.
        The tree is not modified.
        """
        if self.right is not None:
            # The successor is the left-most node of the right subtree.
            return self.right.find_min()

        # Otherwise climb until we leave a left subtree: the first ancestor
        # reached from its left side is the successor.
        node = self
        while node.parent is not None and node.parent.right is node:
            node = node.parent
        return node.parent

    def find_min(self):
        """Return the left-most node of the subtree rooted here."""
        current = self
        while current.left is not None:
            current = current.left
        return current

    def splice_out(self):
        """Unlink this node from the tree, promoting its child (if any).

        Intended for nodes with at most one child. If both children are
        present the left child takes this node's place and the right subtree
        is dropped. When the node has no parent only the promoted child's
        parent link is cleared; the owner of the tree is responsible for
        updating its root reference.
        """
        child = self.left if self.left is not None else self.right
        parent = self.parent

        if parent is not None:
            if parent.left is self:
                parent.left = child
            else:
                parent.right = child
        if child is not None:
            child.parent = parent

In [2]:
class BinarySearchTree:
    """Unbalanced binary search tree with a mapping-style interface.

    Supports ``tree[key] = val``, ``tree[key]``, ``key in tree``,
    ``del tree[key]``, ``len(tree)`` and iteration over the keys in in-order
    sequence. No rebalancing is performed.
    """

    def __init__(self):
        self.root = None
        self.size = 0

    def __len__(self):
        return self.size

    def __iter__(self):
        """Iterate over the keys in in-order sequence (empty tree: nothing)."""
        if self.root is None:
            return iter(())
        return iter(self.root)

    def __setitem__(self, key, val):
        self.put(key, val)

    def __getitem__(self, key):
        return self.get(key)

    def __contains__(self, key):
        return self._get(key, self.root) is not None

    def __delitem__(self, key):
        self.delete(key)

    def put(self, key, val):
        """Store ``val`` under ``key``.

        If ``key`` is already present its value is replaced and the size is
        unchanged; otherwise a new node is added.
        """
        if self.root is None:
            self.root = TreeNode(key, val)
            self.size += 1
        elif self._put(key, val, self.root):
            self.size += 1

    def _put(self, key, val, current_node):
        """Insert or update ``key`` in the subtree rooted at ``current_node``.

        Returns True if a new node was created, False if an existing key had
        its value replaced. Implemented as a loop so that a degenerate
        (list-shaped) tree cannot exhaust the recursion limit.
        """
        node = current_node
        while True:
            if key == node.key:
                node.val = val
                return False
            if key < node.key:
                if node.left is None:
                    node.left = TreeNode(key, val, parent=node)
                    return True
                node = node.left
            else:
                if node.right is None:
                    node.right = TreeNode(key, val, parent=node)
                    return True
                node = node.right

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` if absent."""
        found = self._get(key, self.root)
        return None if found is None else found.val

    def _get(self, key, current_node):
        """Return the node holding ``key`` in the subtree rooted at
        ``current_node``, or ``None`` if there is no such node."""
        node = current_node
        while node is not None:
            if node.key == key:
                return node
            node = node.left if node.key > key else node.right
        return None

    def delete(self, key):
        """Remove ``key`` from the tree.

        Raises:
            KeyError: if ``key`` is not present.
        """
        target = self._get(key, self.root)
        if target is None:
            raise KeyError('Key not in tree')
        self.remove(target)
        self.size -= 1

    def remove(self, node):
        """Unlink ``node`` from the tree while keeping the search order.

        Does not touch ``size``; ``delete`` takes care of that.
        """
        if node.left is not None and node.right is not None:
            # Two children: move the in-order successor's payload into this
            # node and unlink the successor instead. The successor is the
            # minimum of the right subtree, so it has no left child.
            succ = node.find_successor()
            succ.splice_out()
            node.key = succ.key
            node.val = succ.val
            return

        # Zero or one child: the (possibly missing) child takes node's place.
        child = node.left if node.left is not None else node.right
        parent = node.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            # node was the root, so the tree's entry point must move too.
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child

In [3]:
# Fill a tree through the mapping-style interface, then read it back.
bst = BinarySearchTree()
input_ = [
    (3, 'red'),
    (4, 'blue'),
    (1, 'green'),
    (0, 'q'),
    (-1, 'minus'),
    (10, 'ten'),
]

for k, v in input_:
    bst[k] = v

# One direct lookup, followed by every (key, value) pair in iteration order.
print(bst[1])
print([(key, bst[key]) for key in bst])

green
[(-1, 'minus'), (0, 'q'), (1, 'green'), (3, 'red'), (4, 'blue'), (10, 'ten')]


In [4]:
# Delete key 3. It was the first key inserted, so it sits at the root, and by
# now it has both a left (1) and a right (4) child.
del bst[3]

In [5]:
# Print the remaining (key, value) pairs to confirm that key 3 is gone.
print([(k, bst[k]) for k in bst])

[(-1, 'minus'), (0, 'q'), (1, 'green'), (4, 'blue'), (10, 'ten')]
