In [31]:
class Node:
    """One splay-tree node: a stored value plus child and parent links."""

    def __init__(self, value):
        self.value = value
        # A freshly built node is detached; the tree wires these links up.
        self.left = self.right = self.parent = None

In [32]:
class SplayTree:
    """Bottom-up splay tree over comparable values; duplicates are allowed.

    ``insert`` and a successful ``search`` splay the touched node to the
    root.  On insertion, a value equal to a node's value descends into that
    node's left subtree (``value <= node.value`` goes left).

    Descent, splaying and traversal are all iterative.  A splay tree can
    legitimately degenerate into a chain (e.g. inserting values in ascending
    order builds a left chain), and recursive versions of these routines
    raise ``RecursionError`` once that chain is deeper than Python's
    recursion limit (about 1000 frames by default).
    """

    def __init__(self):
        self.root = None

    def _left_rotate(self, x):
        """Rotate left at ``x`` so that ``x.right`` becomes the subtree root.

        Parent pointers of ``x``, its right child and the transferred
        subtree are updated, but the child link of ``x``'s former parent is
        NOT -- the caller must repair it.  Returns the new subtree root.
        """
        y = x.right
        B = y.left

        x.right = B
        y.left = x

        y.parent = x.parent
        x.parent = y
        if B:
            B.parent = x

        return y

    def _right_rotate(self, x):
        """Rotate right at ``x`` so that ``x.left`` becomes the subtree root.

        Mirror image of ``_left_rotate``; the caller must likewise repair
        the child link of ``x``'s former parent.  Returns the new root.
        """
        y = x.left
        B = y.right

        x.left = B
        y.right = x

        y.parent = x.parent
        x.parent = y
        if B:
            B.parent = x

        return y

    def _left_left_rotate(self, x):
        """Rotate left twice at ``x`` (zig-zig for a right-right grandchild)."""
        return self._left_rotate(self._left_rotate(x))

    def _right_right_rotate(self, x):
        """Rotate right twice at ``x`` (zig-zig for a left-left grandchild)."""
        return self._right_rotate(self._right_rotate(x))

    def _left_right_rotate(self, x):
        """Zig-zag for ``x.left.right``: rotate ``x.left`` left, then ``x`` right."""
        x.left = self._left_rotate(x.left)
        return self._right_rotate(x)

    def _right_left_rotate(self, x):
        """Zig-zag for ``x.right.left``: rotate ``x.right`` right, then ``x`` left."""
        x.right = self._right_rotate(x.right)
        return self._left_rotate(x)

    def _splay(self, node):
        """Move ``node`` to the top of its tree with zig / zig-zig / zig-zag steps.

        When the parent has no parent (zig), the result is stored in
        ``self.root``; ``delete`` relies on this by setting ``self.root`` to
        a detached subtree before splaying inside it.
        """
        while node.parent:
            P = node.parent
            G = P.parent
            if not G:
                # Zig: the parent is the root, a single rotation finishes.
                if node is P.left:
                    self.root = self._right_rotate(P)
                else:
                    self.root = self._left_rotate(P)
                return

            parent_is_left = P is G.left
            node_is_left = node is P.left
            if parent_is_left and node_is_left:            # LL: zig-zig
                new_node = self._right_right_rotate(G)
            elif not parent_is_left and not node_is_left:  # RR: zig-zig
                new_node = self._left_left_rotate(G)
            elif parent_is_left:                           # LR: zig-zag
                new_node = self._left_right_rotate(G)
            else:                                          # RL: zig-zag
                new_node = self._right_left_rotate(G)

            # The rotations leave G's former parent still pointing at G;
            # re-point it at the node that replaced G (which is ``node``).
            above = new_node.parent
            if not above:
                self.root = new_node
            elif G is above.left:
                above.left = new_node
            else:
                above.right = new_node

    def insert(self, value):
        """Insert ``value`` (duplicates allowed) and splay the new node to the root."""
        if not self.root:
            self.root = Node(value)
        else:
            self._insert(self.root, value)

    def _insert(self, node, value):
        """Attach a new leaf for ``value`` below ``node`` and splay it.

        Walks down iteratively; ties go to the left subtree.
        """
        while True:
            if value <= node.value:
                if not node.left:
                    node.left = Node(value)
                    node.left.parent = node
                    self._splay(node.left)
                    return
                node = node.left
            else:
                if not node.right:
                    node.right = Node(value)
                    node.right.parent = node
                    self._splay(node.right)
                    return
                node = node.right

    def delete(self, value):
        """Remove one occurrence of ``value`` if present; otherwise do nothing.

        The matching node is splayed to the root and removed.  If it had a
        left subtree, that subtree's maximum is splayed to the top of the
        left subtree (so it has no right child), and the old right subtree
        is hung off it.
        """
        if self.is_in_Splay_Tree(value):
            # The successful search above left the matching node at the root.
            left_subtree = self.root.left
            right_subtree = self.root.right
            if not left_subtree:
                self.root = right_subtree
                if right_subtree:
                    right_subtree.parent = None
            else:
                # Detach the left subtree and make it the tree, so _splay's
                # zig step correctly updates self.root.
                self.root = left_subtree
                left_subtree.parent = None
                node = self._get_in_order_predecessor(left_subtree)
                self._splay(node)
                self.root.right = right_subtree
                if right_subtree:
                    right_subtree.parent = self.root

    def _get_in_order_predecessor(self, node):
        """Return the rightmost (maximum) node of the subtree rooted at ``node``."""
        temp = node
        while temp.right:
            temp = temp.right
        return temp

    def search(self, value):
        """Return a node holding ``value`` (splayed to the root) or ``None``.

        An unsuccessful search does not restructure the tree.
        """
        node = self._search(self.root, value)
        if node:
            self._splay(node)
        return node

    def _search(self, node, value):
        """Iteratively find a node equal to ``value`` below ``node``, else ``None``."""
        while node:
            if value < node.value:
                node = node.left
            elif value > node.value:
                node = node.right
            else:
                return node
        return None

    def is_in_Splay_Tree(self, value):
        """Return whether ``value`` is stored; splays it to the root if found."""
        return self.search(value) is not None

    def update(self, old_value, value):
        """Replace one occurrence of ``old_value`` with ``value``; no-op if absent."""
        if self.is_in_Splay_Tree(old_value):
            self.delete(old_value)
            self.insert(value)

    def in_order_traversal(self):
        """Return all stored values in non-decreasing (in-order) order."""
        result = []
        self._in_order_traversal(self.root, result)
        return result

    def _in_order_traversal(self, node, result):
        """Append the in-order values of the subtree at ``node`` to ``result``.

        Uses an explicit stack instead of recursion so deep chains are safe.
        """
        stack = []
        while stack or node:
            while node:
                stack.append(node)
                node = node.left
            node = stack.pop()
            result.append(node.value)
            node = node.right
            

In [33]:
# Example usage: build a small tree (10 is inserted twice), delete both 10s,
# add 39, then show that update() on a value no longer present is a no-op.
st = SplayTree()
for value in (10, 5, 20, 3, 7, 15, 30, 11, 29, 10, 2):
    st.insert(value)
for _ in range(2):
    st.delete(10)
st.insert(39)
st.update(10, 99)  # 10 is gone by now, so 99 is never inserted
print(st.in_order_traversal())

[2, 3, 5, 7, 11, 15, 20, 29, 30, 39]


In [34]:
# Randomized check: after 1000 random inserts followed by deleting the last
# 100 inserted values, the in-order traversal must equal the sorted list of
# surviving values.  Prints how many of the n trials passed, then whether
# all of them did.
import random
n = 1 << 14
cnt = 0

for _ in range(n):
    st = SplayTree()
    nums = [random.randint(1, 10000) for _ in range(1000)]
    for num in nums:
        st.insert(num)
    for _ in range(100):
        st.delete(nums.pop())
    cnt += st.in_order_traversal() == sorted(nums)
print(cnt)
print(cnt == n)
print(cnt == n)

16384
True
