In [1]:
class Node:
    """One splay-tree node: a key plus links to its left child, right child and parent.

    ``__slots__`` keeps per-node memory small since a tree may hold many nodes.
    """

    __slots__ = ("key", "left", "right", "parent")

    def __init__(self, key):
        self.key = key
        # A freshly created node is detached: no children, no parent.
        self.parent = None
        self.left = None
        self.right = None

class SplayTree:
    """Bottom-up splay tree over keys that support ``<`` and ``>``.

    Every access (``insert``, ``find``, ``delete``) splays the touched node to the
    root using zig / zig-zig / zig-zag steps. Keys are unique: inserting a key
    that is already present only splays the existing node.
    """

    def __init__(self):
        self.root = None

    # -- basic rotations ---------------------------------------------------

    def _replace_in_parent(self, old, new):
        """Put ``new`` in the slot ``old`` occupies under old's parent (or at the root)."""
        parent = old.parent
        new.parent = parent
        if parent is None:
            self.root = new
        elif old is parent.left:
            parent.left = new
        else:
            parent.right = new

    def _rotate_left(self, x):
        """Lift x's right child ``y`` into x's position; ``x`` becomes y's left child.

        No-op when ``x`` has no right child. Updates ``self.root`` if ``x`` was the root.
        """
        y = x.right
        if y is None:
            return
        # y's left subtree lies between x and y, so it moves to x's right.
        x.right = y.left
        if y.left is not None:
            y.left.parent = x
        self._replace_in_parent(x, y)
        y.left = x
        x.parent = y

    def _rotate_right(self, x):
        """Lift x's left child ``y`` into x's position; ``x`` becomes y's right child.

        No-op when ``x`` has no left child. Updates ``self.root`` if ``x`` was the root.
        """
        y = x.left
        if y is None:
            return
        # y's right subtree lies between y and x, so it moves to x's left.
        x.left = y.right
        if y.right is not None:
            y.right.parent = x
        self._replace_in_parent(x, y)
        y.right = x
        x.parent = y

    # -- splay: rotate x up to the root --------------------------------------

    def _splay(self, x):
        """Rotate ``x`` upward until it becomes the root of the tree."""
        while x.parent is not None:
            p = x.parent
            g = p.parent
            if g is None:
                # Zig: p is the root, one rotation finishes.
                if x is p.left:
                    self._rotate_right(p)
                else:
                    self._rotate_left(p)
            elif x is p.left and p is g.left:
                # Zig-zig (left-left): rotate the grandparent first, then the parent.
                self._rotate_right(g)
                self._rotate_right(p)
            elif x is p.right and p is g.right:
                # Zig-zig (right-right): grandparent first, then parent.
                self._rotate_left(g)
                self._rotate_left(p)
            elif x is p.right and p is g.left:
                # Zig-zag (left-right): rotate x above p, then above g.
                self._rotate_left(p)
                self._rotate_right(g)
            else:
                # Zig-zag (right-left): rotate x above p, then above g.
                self._rotate_right(p)
                self._rotate_left(g)

    # -- public operations ----------------------------------------------------

    def insert(self, key):
        """Insert ``key`` as a BST leaf and splay it to the root.

        If ``key`` is already present no new node is created; the existing node
        is splayed instead. Returns the new root, i.e. the node holding ``key``.
        """
        if self.root is None:
            self.root = Node(key)
            return self.root
        cur = self.root
        while True:
            if key < cur.key:
                if cur.left is None:
                    cur.left = Node(key)
                    cur.left.parent = cur
                    self._splay(cur.left)
                    return self.root
                cur = cur.left
            elif key > cur.key:
                if cur.right is None:
                    cur.right = Node(key)
                    cur.right.parent = cur
                    self._splay(cur.right)
                    return self.root
                cur = cur.right
            else:
                # Key already exists: splay it to the root (policy choice, no duplicate).
                self._splay(cur)
                return self.root

    def find(self, key):
        """Look up ``key`` and return its node, or ``None`` if absent.

        On a hit the matching node is splayed to the root. On a miss the last
        node visited on the search path is splayed instead, so repeated nearby
        lookups stay cheap.
        """
        cur = self.root
        last = None
        while cur is not None:
            last = cur
            if key < cur.key:
                cur = cur.left
            elif key > cur.key:
                cur = cur.right
            else:
                self._splay(cur)
                return cur
        if last is not None:
            self._splay(last)
        return None

    def delete(self, key):
        """Remove ``key`` from the tree; does nothing if ``key`` is not present."""
        node = self.find(key)
        # find() matches on "neither < nor >"; the extra != guard refuses to
        # delete when that disagrees with == (e.g. a NaN key).
        if node is None or node.key != key:
            return
        # find() splayed the hit, so node is now the root: split off its subtrees.
        left, right = node.left, node.right
        # Fully detach the removed node so it keeps no references into the tree.
        node.left = node.right = None
        if left is not None:
            left.parent = None
        if right is not None:
            right.parent = None

        if left is None:
            self.root = right
            return
        # Splay the maximum of the left subtree to its root; afterwards that
        # node has no right child, so the old right subtree can hang there.
        self.root = left
        cur = left
        while cur.right is not None:
            cur = cur.right
        self._splay(cur)
        self.root.right = right
        if right is not None:
            right.parent = self.root

    def inorder(self):
        """Return all keys in ascending order (in-order traversal).

        Iterative on purpose: a splay tree can degenerate into a chain as deep
        as the number of keys (e.g. after inserting keys in sorted order), which
        would exceed Python's recursion limit with a recursive traversal.
        """
        res = []
        stack = []
        cur = self.root
        while stack or cur is not None:
            # Walk as far left as possible, remembering the path.
            while cur is not None:
                stack.append(cur)
                cur = cur.left
            cur = stack.pop()
            res.append(cur.key)
            cur = cur.right
        return res

# --- Simple demo ---
if __name__ == "__main__":
    demo = SplayTree()
    for value in (10, 5, 15, 2, 7, 12, 20):
        demo.insert(value)

    # In-order keys should come out sorted.
    print("中序：", demo.inorder())

    # Accessing 7 splays it to the root, so the root key printed next is 7.
    demo.find(7)
    print("访问7后根：", demo.root.key)

    # Remove 10 and show the remaining keys in order.
    demo.delete(10)
    print("删除10后中序：", demo.inorder())


中序： [2, 5, 7, 10, 12, 15, 20]
访问7后根： 7
删除10后中序： [2, 5, 7, 12, 15, 20]
