# 树

> 由一个个节点连接而成，每个节点只能从一条路径访问到

## 核心概念

> 父节点：节点在路径中前一个节点
>
> 子节点：节点在路径中的后续节点
>
> 根节点：入口节点
>
> 叶子节点：没有子节点
>
> 兄弟节点：拥有同一个父节点
>
> 节点高度：节点到叶子节点的最长路径的长度（边数）
>
> 节点深度：根节点到该节点的路径长度（边数）
>
> 层数：深度+1

## 二叉树

> 最多拥有2个子节点
>
> 满二叉树（Full Binary Tree）：每个节点都有0个或者2个子节点，也就是说，除了叶子节点，其他节点都拥有2个子节点
>
> 完全二叉树（Complete Binary Tree）：可能除了最后一层，其他层是完全填满的，并且最后一层节点从左往右是连续的
>
> 完美二叉树（Perfect Binary Tree）：每个非叶子节点都有两个子节点，并且所有叶子节点在同一层

### 二叉树的数据结构

> 链表存储：最常见
>
> 数组存储：适用于完全二叉树。根节点存储在数组下标1的位置，如果父节点的下标为i，左子节点的下标为2i，右子节点的下标为：2i+1

### 二叉树的遍历

In [4]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [50]:
# Tree Node
class Node:
    """Binary tree node holding a value and optional left/right children."""

    def __init__(self, value, left=None, right=None):
        # Children default to None, which marks a missing subtree.
        self.value, self.left, self.right = value, left, right

    def __repr__(self):
        """Return a nested ``Node(value, left, right)`` rendering of the subtree."""
        value, left, right = self.value, self.left, self.right
        return f"Node({value}, {left}, {right})"

node = Node(1, Node(2, Node(4), Node(5)), Node(3, Node(6), Node(7)))

#### 前序遍历

> 根节点 -> 左子节点 -> 右子节点

In [71]:
def pre_order(node):
    """Recursively print the subtree in pre-order: node, then left, then right."""
    if node is not None:
        print(node.value)
        for child in (node.left, node.right):
            pre_order(child)

pre_order(node)

1
2
4
5
3
6
7


In [72]:
def pre_order_stack(node):
    """Print the subtree in pre-order using an explicit stack instead of recursion."""
    pending = [node]
    while pending:
        current = pending.pop()
        if current:
            print(current.value)
            # Push right first so that the left child is popped (visited) first.
            pending.extend((current.right, current.left))

pre_order_stack(node)

1
2
4
5
3
6
7


#### 中序遍历

> 左子节点 -> 根节点 -> 右子节点

In [73]:
def in_order(node):
    """Recursively print the subtree in in-order: left, then node, then right."""
    if node:
        in_order(node.left)
        print(node.value)
        in_order(node.right)

in_order(node)

4
2
5
1
6
3
7


In [74]:
def in_order_stack(node):
    """Print the subtree in in-order using an explicit stack of pending ancestors."""
    ancestors = []
    current = node
    while True:
        if current:
            # Keep descending left, remembering every node passed on the way.
            ancestors.append(current)
            current = current.left
        elif ancestors:
            # Left side exhausted: emit the nearest ancestor, then go right.
            current = ancestors.pop()
            print(current.value)
            current = current.right
        else:
            break

in_order_stack(node)

4
2
5
1
6
3
7


#### 后序遍历

> 左子节点 -> 右子节点 -> 根节点

In [75]:
def post_order(node):
    """Recursively print the subtree in post-order: left, then right, then node."""
    if node:
        for child in (node.left, node.right):
            post_order(child)
        print(node.value)

post_order(node)

4
5
2
6
7
3
1


In [76]:
def post_order_stack(node):
    """Print the post-order sequence of the subtree as a single list.

    Nodes are visited in "node, right, left" order with an explicit stack;
    reversing that sequence yields "left, right, node", i.e. post-order.
    """
    pending = [node]
    visited = []
    while pending:
        current = pending.pop()
        if current:
            visited.append(current.value)
            # Left is pushed first, so the right child is popped first.
            pending += [current.left, current.right]
    visited.reverse()
    print(visited)

post_order_stack(node)

[4, 5, 2, 6, 7, 3, 1]


#### 按层级从左到右遍历

In [51]:
def layer_order(node):
    """Print the tree level by level, left to right, one list per level."""
    if node is None:
        return
    level = [node]
    while level:
        following = []
        for current in level:
            print(current.value)
            # Collect the existing children as the next level, left before right.
            following.extend(
                child for child in (current.left, current.right) if child
            )
        level = following

layer_order(node)

1
2
3
4
5
6
7


In [52]:
from collections import deque


def layer_order_v2(node):
    """Print the tree level by level (left to right) using a FIFO queue.

    Missing children are enqueued as ``None`` and skipped when dequeued, so
    an empty tree (``node is None``) prints nothing.

    :param node: root of the tree, or ``None``
    """
    # The root must be wrapped in a list: deque(node) would try to iterate
    # over the node itself and raise TypeError.
    queue = deque([node])
    while queue:
        current = queue.popleft()
        if not current:
            continue
        print(current.value)
        queue.append(current.left)
        queue.append(current.right)

layer_order(node)

1
2
3
4
5
6
7


#### 序列化 & 反序列化

In [None]:
class Codec:
    """Serialize / deserialize a binary tree to and from a string.

    Two interchangeable encodings are provided:

    * ``serialize`` / ``deserialize`` -- pre-order (DFS); every token,
      including the last one, is followed by a comma, e.g. ``"1,null,null,"``.
    * ``serialize_bfs`` / ``deserialize_bfs`` -- level-order (BFS); tokens are
      comma-separated without a trailing comma, e.g. ``"1,null,null"``.

    Missing children are written as ``"null"``. Decoding builds ``Node``
    objects and parses every value with ``int``.
    """

    @staticmethod
    def _value_of(node):
        """Return the payload of *node*, stored as ``value`` or ``val``."""
        # ``Node`` stores its payload in ``value``; nodes exposing ``val``
        # are still accepted so existing callers keep working.
        return node.value if hasattr(node, "value") else node.val

    def serialize(self, root):
        """Encode a tree to a single string using pre-order traversal.

        :type root: Node
        :rtype: str
        """
        parts = []

        def dfs(node):
            if not node:
                parts.append("null")
                return
            parts.append(str(self._value_of(node)))
            dfs(node.left)
            dfs(node.right)

        dfs(root)
        # Every token is terminated by a comma, including the last one.
        return ",".join(parts) + ","

    def serialize_bfs(self, root):
        """Encode a tree to a single string using level-order traversal.

        :type root: Node
        :rtype: str
        """
        res = []
        queue = deque([root])
        while queue:
            node = queue.popleft()
            if not node:
                res.append("null")
                continue
            res.append(str(self._value_of(node)))
            queue.append(node.left)
            queue.append(node.right)

        return ",".join(res)

    def deserialize(self, data):
        """Decode a string produced by ``serialize`` back into a tree.

        :type data: str
        :rtype: Node, or None for empty input / an empty tree
        """
        if not data:
            return None
        # A deque gives O(1) consumption from the front (list.pop(0) is O(n)).
        tokens = deque(data.split(","))

        def dfs():
            token = tokens.popleft()
            if token == "null":
                return None
            node = Node(int(token))
            node.left = dfs()
            node.right = dfs()
            return node

        return dfs()

    def deserialize_bfs(self, data):
        """Decode a string produced by ``serialize_bfs`` back into a tree.

        :type data: str
        :rtype: Node, or None for empty input / an empty tree
        """
        if not data:
            return None
        vals = data.split(",")
        if vals[0] == "null":
            return None

        root = Node(int(vals[0]))
        queue = deque([root])
        i = 1
        while queue:
            node = queue.popleft()
            # Each dequeued node consumes exactly two tokens: left, then right.
            if vals[i] != "null":
                node.left = Node(int(vals[i]))
                queue.append(node.left)
            i += 1
            if vals[i] != "null":
                node.right = Node(int(vals[i]))
                queue.append(node.right)
            i += 1
        return root


#### 多叉树序列化 & 反序列化

In [25]:
class TreeNode:
    """N-ary tree node: a value plus an ordered list of child nodes."""

    def __init__(self, val, children=None):
        self.val = val
        # A missing (or empty) children argument becomes a fresh empty list.
        self.children = children or []


class TreeCodec:
    """Serialize / deserialize an n-ary tree built from ``TreeNode``.

    The encoding is a pre-order, comma-separated list in which every node
    contributes two tokens: its value and its number of children.
    """

    def serialize(self, root):
        """Encode the tree rooted at *root*; an empty tree yields ``""``."""

        def tokens(node):
            if node:
                yield str(node.val)
                yield str(len(node.children))
                for child in node.children:
                    yield from tokens(child)

        return ",".join(tokens(root))

    def deserialize(self, data):
        """Rebuild a tree from a string produced by ``serialize``.

        Values are parsed with ``int``; empty input returns ``None``.
        """
        if not data:
            return None
        stream = iter(data.split(","))

        def build():
            val = int(next(stream))
            child_count = int(next(stream))
            node = TreeNode(val)
            for _ in range(child_count):
                node.children.append(build())
            return node

        return build()

root = TreeNode(1, [
    TreeNode(2),
    TreeNode(3, [TreeNode(5), TreeNode(6)]),
    TreeNode(4)
])


treecodec = TreeCodec()
data = treecodec.serialize(root)
print("Serialized:", data)

root2 = treecodec.deserialize(data)
print("Re-Serialized:", treecodec.serialize(root2))

Serialized: 1,3,2,0,3,2,5,0,6,0,4,0
Re-Serialized: 1,3,2,0,3,2,5,0,6,0,4,0


#### 求树的高度

In [78]:
# 按层遍历
def tree_height(root):
    """Return the number of levels in the tree (0 for an empty tree).

    Walks the tree level by level and counts how many levels are visited.
    """
    if not root:
        return 0
    levels = 0
    level = [root]
    while level:
        levels += 1
        following = []
        for current in level:
            following.extend(
                child for child in (current.left, current.right) if child
            )
        level = following
    return levels

tree_height(node)

# recursion
def tree_height_v2(root):
    """Return the number of levels in the tree, computed recursively."""
    if root:
        left_height = tree_height_v2(root.left)
        right_height = tree_height_v2(root.right)
        return 1 + max(left_height, right_height)
    return 0

tree_height_v2(node)

3

3

## 二叉查找树

> 节点的值大于左子树所有节点的值，小于等于右子树所有节点的值，按中序遍历可得升序排序的数组

### 常见的操作

In [79]:
bst = Node(4, Node(2, Node(1), Node(3)), Node(6, Node(5), Node(7)))

In [80]:
# search in binary search tree (BST) with different value
def search_bst(root, val):
    """Return the node holding *val* in the BST, or ``None`` if it is absent."""
    current = root
    while current:
        if current.value == val:
            break
        # Smaller keys live in the left subtree, the rest in the right one.
        current = current.left if current.value > val else current.right
    return current

search_bst(bst, 6)

Node(6, Node(5, None, None), Node(7, None, None))

In [81]:
# add node to bst
def insert_into_bst(root, val):
    """Insert *val* into the BST as a new leaf and return the tree's root.

    Values smaller than a node go into its left subtree; all other values
    (including duplicates) go into its right subtree. An empty tree yields a
    new single-node tree.
    """
    if not root:
        return Node(val)
    parent = root
    while True:
        side = "left" if parent.value > val else "right"
        child = getattr(parent, side)
        if not child:
            # Free slot found: attach the new leaf here.
            setattr(parent, side, Node(val))
            return root
        parent = child

insert_into_bst(bst, 8)

Node(4, Node(2, Node(1, None, None), Node(3, None, None)), Node(6, Node(5, None, None), Node(7, None, Node(8, None, None))))

In [82]:
# delete node from bst
def delete_node(root, key):
    """Remove the node holding *key* from the BST and return the new root.

    The tree is modified in place. If *key* is not present (or the tree is
    empty) the root is returned unchanged.
    """
    if not root:
        return root

    # Locate the node to delete together with its parent.
    parent, target = None, root
    while target and target.value != key:
        parent = target
        target = target.left if target.value > key else target.right

    if not target:
        return root

    if target.left and target.right:
        # Two children: copy the smallest value of the right subtree into the
        # target, then delete that (at most one-child) successor node instead.
        succ_parent, succ = target, target.right
        while succ.left:
            succ_parent = succ
            succ = succ.left
        target.value = succ.value
        parent, target = succ_parent, succ

    # From here on the node to unlink has at most one child.
    replacement = target.left or target.right

    if not parent:
        # The node being unlinked is the root itself.
        return replacement

    if parent.left == target:
        parent.left = replacement
    else:
        parent.right = replacement

    return root


delete_node(bst, 8)

Node(4, Node(2, Node(1, None, None), Node(3, None, None)), Node(6, Node(5, None, None), Node(7, None, None)))

In [83]:
# find minimum value
def minimum_in_bst(root):
    """Return the smallest value in the BST, or ``None`` for an empty tree."""
    if not root:
        return None
    # The minimum is the left-most node.
    while root.left:
        root = root.left
    return root.value

minimum_in_bst(bst)

1

In [84]:
# find maximum value
def maximum_in_bst(root):
    """Return the largest value in the BST, or ``None`` for an empty tree."""
    if not root:
        return None
    # The maximum is the right-most node.
    while root.right:
        root = root.right
    return root.value

maximum_in_bst(bst)

7

In [85]:
class Path:
    """Linked record of a value plus a reference to the preceding ``Path`` entry."""

    def __init__(self, value, parent=None):
        # ``parent`` is None for the first entry of a path.
        self.value, self.parent = value, parent

In [86]:
# find the biggest one which smaller than value
def precursor_node(root, val):
    """Return the in-order predecessor of *val* in the BST.

    The predecessor is the largest value in the tree that is smaller than
    *val*.

    :param root: root node of the BST, or ``None``
    :param val: a value expected to be stored in the tree
    :return: the predecessor value, or ``None`` when *val* is not in the tree
        or is the smallest value
    """
    # Deepest ancestor from which the search descended to the right; its
    # value is the closest one below ``val`` among all ancestors.
    last_right_turn = None
    current = root
    while current and current.value != val:
        if current.value > val:
            current = current.left
        else:
            last_right_turn = current
            current = current.right

    if not current:
        return None

    if not current.left:
        # No left subtree: the predecessor, if any, is that ancestor.
        return last_right_turn.value if last_right_turn else None

    # Otherwise it is the right-most node of the left subtree.
    current = current.left
    while current.right:
        current = current.right

    return current.value

precursor_node(bst, 3)
precursor_node(bst, 4)
precursor_node(bst, 5)

2

3

4

In [87]:
# find the smallest one which bigger than value
def successor_node(root, val):
    """Return the in-order successor of *val* in the BST.

    The successor is the smallest value in the tree that is bigger than
    *val*.

    :param root: root node of the BST, or ``None``
    :param val: a value expected to be stored in the tree
    :return: the successor value, or ``None`` when *val* is not in the tree
        or is the largest value
    """
    # Deepest ancestor from which the search descended to the left; its
    # value is the closest one above ``val`` among all ancestors.
    last_left_turn = None
    current = root
    while current and current.value != val:
        if current.value > val:
            last_left_turn = current
            current = current.left
        else:
            current = current.right

    # Value not found (or empty tree): there is nothing to report.
    if not current:
        return None

    if not current.right:
        # No right subtree: the successor, if any, is that ancestor.
        return last_left_turn.value if last_left_turn else None

    # Otherwise it is the left-most node of the right subtree.
    current = current.right
    while current.left:
        current = current.left

    return current.value

successor_node(bst, 3)
successor_node(bst, 4)
successor_node(bst, 5)

4

5

6

### 时间复杂度分析

> 当BST退化成链表时，最差时间复杂度为O(n)，也等于O(height)
>
> 当BST是完全二叉树时，遍历跟height有关，height在[log(n+1)-1, logn]之间，所以最差时间复杂度为O(logn)

> 可见性能和树的高度有关，平衡二叉搜索树高度接近logn

### 平衡二叉搜索树对比hash表的优势

> 1. 数据有序
> 2. hash表扩容、hash冲突导致性能不稳定
> 3. hash表设计相对比较复杂，要考虑hash函数设计，hash冲突、扩缩容解决方法。平衡二叉搜索树只需要考虑平衡问题


## 平衡二叉搜索树

> 严格定义：任一节点的左右子树高度差不超过1
>
> 常见：AVL树

### 红黑树（Red-Black Tree）

> 工程应用中一种流行的“不严格”的平衡二叉搜索树，子节点的高度差可能达到一倍
>
> 定义：节点被标记为黑色或者红色
> 1. 根节点是黑色
> 2. 每个叶子节点是黑色的空节点，也就是叶子节点不存储数据
> 3. 任何相邻的节点不能同时为红色
> 4. 任一节点到其所有可达叶子节点的路径上，黑色节点数量一致
>
> 优点：常见操作保持对数级复杂度，且为了保持相对平衡成本比较低

#### 插入

> A. 按照bst算法找到插入点插入，标红
>
> B. 进入下面调整逻辑
> 1. 如果没有父节点或者父节点黑，结束
> 2. 如果叔节点红，父、叔标黑，祖标红，关注节点变成祖，进入下一轮
> 3. 如果不同边（子在父的左边，父在祖的右边，或者反过来），关注点变成父，旋转，父标黑，祖标红，围绕祖反向旋转，进入下一轮

#### 删除

[参考](image/TreeMap红黑树源码详解.pdf)

#### 复杂度分析

> 如果把红色节点去掉，并把它的子节点接到祖父节点上，剩下的黑色节点构成一棵所有叶子节点深度相同的树（可能是四叉树，其高度不超过节点数相同的完全二叉树；根节点到叶子节点的每条路径上黑节点数量一样），bh（black height，黑色高度）和 n 的关系：n >= 2^bh - 1，即 bh <= log(n+1)
>
> 因为不会有两个相邻的红色节点，所以根节点到叶子节点的一条路径上，红色节点数量最多和黑色节点数量一样，所以树高的上限：h <= 2bh <= 2log(n+1)

## 堆

> 定义
> 1. 完全二叉树
> 2. 每个节点值大于等于子节点的值（大顶堆），每个节点值小于等于子节点的值（小顶堆）

### 常见操作 & 实现
> 使用数组存储
>
> 插入：从下往上构建堆，依次和父节点比较、交换，时间复杂度O(logn)
>
> 删除堆顶：把尾节点覆盖根节点，从上往下构建堆，依次和子节点比较、交换，时间复杂度O(logn)
>
> 建堆：O(n)，从最后一个非叶子节点开始，从右往左、从下往上依次对每个非叶子节点做自上而下的堆化。倒数第二层有2^(h-1)个节点，每个最多移动1次，再上一层有2^(h-2)个，每个最多移动2次，以此类推，一共最多需要移动：2^(h-1)*1 + 2^(h-2)*2 + …… + 1*h ≈ n 次，即 O(n)

### 常见应用
> 1. 排序
> 2. topK
> 3. 求中位数
> 4. 优先队列

In [None]:
# Max Heap
class Heap:
    """Array-backed max heap.

    The heap works directly on the list passed in (it is reordered in place),
    with the children of index ``i`` stored at ``2*i + 1`` and ``2*i + 2``.
    """

    def __init__(self, array):
        self.array = array
        self.count = len(array)
        self._build_heap()

    def push(self, val):
        """Append *val* and sift it up until its parent is not smaller."""
        heap = self.array
        heap.append(val)
        child = self.count
        while child > 0:
            parent = (child - 1) // 2
            if heap[child] > heap[parent]:
                heap[child], heap[parent] = heap[parent], heap[child]
                child = parent
            else:
                break
        self.count += 1

    def pop(self):
        """Remove and return the largest element (``None`` if the heap is empty)."""
        if self.count == 0:
            return
        heap = self.array
        # Move the last element to the root, drop the tail, then sift down.
        top = heap[0]
        heap[0] = heap[-1]
        del heap[-1]
        self.count -= 1
        Heap._heapify(heap, self.count, 0)
        return top

    @staticmethod
    def _heapify(array, n, i):
        """Sift ``array[i]`` down within the first *n* elements of *array*."""
        while True:
            largest = i
            # Left child is examined first, then the right one.
            for child in (2 * i + 1, 2 * i + 2):
                if child < n and array[child] > array[largest]:
                    largest = child

            if largest == i:
                return

            array[i], array[largest] = array[largest], array[i]
            i = largest

    def _build_heap(self):
        """Turn ``self.array`` into a max heap in place; time complexity O(n)."""
        start = (self.count - 1) >> 1
        for i in reversed(range(start + 1)):
            Heap._heapify(self.array, self.count, i)

    def sorted(self):
        """Return a new ascending list of the heap's elements (heap sort).

        The heap itself is left untouched; the sort runs on a copy.
        """
        result = self.array[:self.count]
        for end in reversed(range(len(result))):
            # Move the current maximum behind the shrinking heap prefix.
            result[0], result[end] = result[end], result[0]
            Heap._heapify(result, end, 0)
        return result

    # def topK(self, k):
    #     array = self.array[:self.count]
    #     k = min(k, len(array))
    #     for i in range(k):
    #         max_pos = len(array)-1-i
    #         array[0], array[max_pos] = array[max_pos], array[0]
    #         Heap._heapify(array, max_pos ,0)
    #     return array[-k:][::-1]



array = [7,5,19,8,4,1,20,13,16]

heap = Heap(array)
array
heap.push(22)
array
heap.pop()
array
heap.sorted()
# heap.topK(5)

[20, 16, 19, 13, 4, 1, 7, 5, 8]

[22, 20, 19, 13, 16, 1, 7, 5, 8, 4]

22

[20, 16, 19, 13, 4, 1, 7, 5, 8]

[1, 4, 5, 7, 8, 13, 16, 19, 20]

[20, 19, 16, 13, 8]

In [47]:
# topk
import heapq
import math

def topk(array, k):
    """Return the *k* largest values of *array* using a size-k min heap.

    :param array: iterable of mutually comparable numbers
    :param k: how many of the largest values to keep
    :return: a list holding the k largest values in heap order (smallest of
        them first, the rest not sorted). If *array* has fewer than *k*
        values all of them are returned; ``k <= 0`` yields ``[]``.
    """
    items = list(array)
    # Never keep more slots than there are values; otherwise the -inf
    # placeholders below would leak into the result.
    k = min(k, len(items))
    if k <= 0:
        return []

    # heap[0] is always the smallest of the current top-k candidates.
    heap = [float('-inf')] * k
    for val in items:
        if val > heap[0]:
            heapq.heappushpop(heap, val)
    return heap

topk(array, 5)

# quantile
class QuantileFinder:
    """Track a running quantile of a stream of numbers with two heaps.

    ``max_queue`` is a max heap (values stored negated) holding the lowest
    ``ceil(quantile * n)`` values seen so far; ``min_queue`` is a min heap
    holding the rest. The top of ``max_queue`` is therefore the requested
    quantile.

    :param quantile: fraction in (0, 1], e.g. 0.5 for the (lower) median
    """

    def __init__(self, quantile=0.5):
        self.quantile = quantile
        self.max_queue = []
        self.min_queue = []

    def insert(self, val):
        """Add *val* to the stream and rebalance the two heaps."""
        heapq.heappush(self.max_queue, -val)

        # Keep every value of the low side <= every value of the high side.
        if self.min_queue and -self.max_queue[0] > self.min_queue[0]:
            heapq.heappush(self.min_queue, -heapq.heappop(self.max_queue))

        # Size the low side to ceil(quantile * n). Comparing against a single
        # target keeps the "too big" and "too small" checks consistent.
        total = len(self.max_queue) + len(self.min_queue)
        target = math.ceil(self.quantile * total)
        if len(self.max_queue) > target:
            heapq.heappush(self.min_queue, -heapq.heappop(self.max_queue))
        elif len(self.max_queue) < target and self.min_queue:
            heapq.heappush(self.max_queue, -heapq.heappop(self.min_queue))

    def take(self):
        """Return the current quantile value.

        :raises IndexError: if the low-side heap is empty (e.g. nothing has
            been inserted yet)
        """
        return -self.max_queue[0]
        

median_finder = QuantileFinder(0.3)
median_finder.insert(1)
median_finder.take()
median_finder.insert(2)
median_finder.take()
median_finder.insert(3)
median_finder.take()

[8, 13, 20, 19, 16]

1

1

1

## 线段树（Segment Tree）

> 应用场景

> 给一个数组
> 1. 更新数组的值
> 2. 求一定范围内的值（求和、最大、最小等等）

> 常规解决思路
> 1. 直接用原始数组，更新值的时间复杂度是O(1)，求范围值的时间复杂度是O(n)
> 2. 使用前缀和数组（数组第i-1个位置保存前i个值的和），更新值的时间复杂度是O(n)，求范围值的时间复杂度是O(1)

> 线段树的结构
> 1. 数组的值是树的叶子节点
> 2. 重复把数组一分为二，直到每一份只有一个节点为止
> 3. 父节点是子节点的范围值

> 更新值
> 1. 如果更新的下标i落在节点的范围内，更新节点的值
> 2. 使用二分重复步骤1，直到节点范围不包括i，或者找到i子节点

> 求范围值
> 1. 若范围完全覆盖节点范围，返回节点值
> 2. 若范围完全不覆盖节点范围，返回0
> 3. 否则二分

> [参考](https://www.geeksforgeeks.org/dsa/segment-tree-sum-of-given-range/)

In [6]:
class SegmentTree:
    """Range-sum segment tree stored in a flat list.

    ``st[idx]`` holds the sum of one index range of ``array``; the root
    (``idx == 0``) covers the whole array and the children of ``idx`` are
    ``2*idx + 1`` (left half) and ``2*idx + 2`` (right half).

    :param array: list of numbers; it is kept by reference and updated by
        ``update``
    """

    def __init__(self, array):
        self.array = array
        self.n = len(array)
        self.st = [0] * self.n * 4
        # An empty array has no ranges to build (and would recurse forever).
        if self.n:
            self._buildST(0, self.n - 1, 0)

    def _buildST(self, start, end, idx):
        """Fill the node *idx* covering ``[start, end]`` and return its sum."""
        if start == end:
            self.st[idx] = self.array[start]
            return self.array[start]

        mid = start + (end - start) // 2

        val = self._buildST(start, mid, 2*idx+1) + self._buildST(mid+1, end, 2*idx+2)
        self.st[idx] = val
        return val

    def update(self, i, v):
        """Set ``array[i]`` to *v* and refresh every range sum covering *i*.

        :raises IndexError: if *i* is not in ``[0, len(array))``. Negative
            indices are rejected too, because they would change ``array``
            without updating the tree.
        """
        if not 0 <= i < self.n:
            raise IndexError("out of index")
        diff = v - self.array[i]
        self.array[i] = v
        self._updateST(i, diff, 0, self.n-1, 0)

    def _updateST(self, i, diff, start, end, idx):
        """Add *diff* to node *idx* and its descendants whose range contains *i*."""
        if i < start or end < i:
            return

        self.st[idx] += diff
        if start != end:
            mid = start + (end - start) // 2
            self._updateST(i, diff, start, mid, 2*idx+1)
            self._updateST(i, diff, mid+1, end, 2*idx+2)

    def accumulate(self, left, right):
        """Return the sum of ``array[left..right]`` (both ends inclusive).

        Parts of the requested range outside the array contribute 0; an
        empty tree always yields 0.
        """
        if not self.n:
            return 0
        return self._accumulateST(left, right, 0, self.n-1, 0)

    def _accumulateST(self, left, right, start, end, idx):
        """Sum the part of ``[left, right]`` covered by node *idx* (``[start, end]``)."""
        if left <= start and end <= right:
            # Node range lies completely inside the query.
            return self.st[idx]
        if end < left or right < start:
            # No overlap at all.
            return 0

        mid = start + (end - start) // 2
        return self._accumulateST(left, right, start, mid, 2*idx+1) + self._accumulateST(left, right, mid+1, end, 2*idx+2)



arr = [1,3,5,7,9,11]
st = SegmentTree(arr)
st.st
st.update(1, 2)
st.st
st.accumulate(1,3)

[36, 9, 27, 4, 5, 16, 11, 1, 3, 0, 0, 7, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

[35, 8, 27, 3, 5, 16, 11, 1, 2, 0, 0, 7, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

14

## 树状数组（Binary Indexed Tree, Fenwick Tree）

> 和线段树一样，数据结构上都是数组
> 相比线段树，优点是只需要 N 额外的空间，实现逻辑更简单

> 结构
> 申请一个长度为 n + 1 的数组，从下标 1 开始存数据

> 初始化 & 更新
> 初始化：对每个数据 arr[i]，存储到 BIT[j]（j 初始值 i+1）+= arr[i]，循环更新下标 j += j & (-j)，并更新 BIT[j]，直到 j > n 停止
> 更新：和初始化的逻辑一样

> 前缀和（前 i 个数值之和）查询
> 累加 BIT[i] 值
> 更新下标 i -= i & (-i)，如果 i > 0，则重复上面的步骤

> [参考](https://www.geeksforgeeks.org/dsa/binary-indexed-tree-or-fenwick-tree-2/)

In [23]:
class BIT:
    """Binary indexed (Fenwick) tree for prefix and range sums.

    ``bit`` has one more slot than ``array`` and is used from index 1;
    element ``array[i]`` is tracked starting at ``bit[i + 1]``.
    """

    def __init__(self, array):
        self.array = array
        self.bit = [0] * (len(array) + 1)
        self._buildBIT()

    def _buildBIT(self):
        """Feed every element of ``array`` into the tree."""
        for index, value in enumerate(self.array):
            self._update(index, value)

    def _update(self, i, v):
        """Add *v* to every tree slot that covers the 0-based index *i*."""
        size = len(self.bit)
        pos = i + 1
        while pos < size:
            self.bit[pos] += v
            # Jump to the next slot covering this position (add lowest set bit).
            pos += pos & -pos

    def _getPrefixSum(self, i):
        """Return the sum of ``array[0..i]`` (inclusive); ``i == -1`` gives 0."""
        total = 0
        pos = i + 1
        while pos > 0:
            total += self.bit[pos]
            # Drop the lowest set bit to move to the preceding block.
            pos &= pos - 1
        return total

    def update(self, i, v):
        """Set ``array[i]`` to *v*, raising IndexError for an invalid index."""
        if not 0 <= i < len(self.array):
            raise IndexError("out of index")
        delta = v - self.array[i]
        self.array[i] = v
        self._update(i, delta)

    def getSum(self, left, right):
        """Return the sum of ``array[left..right]`` (both ends inclusive).

        Raises IndexError when a bound is outside the array and TypeError
        when ``right < left``.
        """
        if not (0 <= left and right < len(self.array)):
            raise IndexError("out of index")

        if right < left:
            raise TypeError("left must less than right")

        upper = self._getPrefixSum(right)
        lower = self._getPrefixSum(left - 1)
        return upper - lower
        

bit = BIT([2,1,1,3,2,3,4,5,6,7,8,9])
bit.bit
bit.update(2, 2)
bit.bit
bit.getSum(2, 4)

[0, 2, 3, 1, 7, 2, 5, 4, 21, 6, 13, 8, 30]

[0, 2, 3, 2, 8, 2, 5, 4, 22, 6, 13, 8, 30]

7