In [1]:
from abc import ABC, abstractmethod, abstractproperty

class INode(ABC):
    """Abstract interface of a binary-tree node.

    Concrete nodes expose a stored ``value``, optional ``left`` / ``right``
    children and a writable ``level`` (depth of the node inside its tree).
    """

    @property
    @abstractmethod
    def value(self):
        """Value stored in this node."""

    @property
    @abstractmethod
    def left(self):
        """Left child, or ``None`` when there is none."""

    @property
    @abstractmethod
    def right(self):
        """Right child, or ``None`` when there is none."""

    @property
    @abstractmethod
    def level(self):
        """Depth of this node inside its tree."""

    @level.setter
    @abstractmethod
    def level(self, level):
        """Assign a new depth to this node."""

    

In [2]:

class BasicNode(INode):
    """Concrete node holding a value, two optional children and a level.

    Nodes compare by ``value``.  Because ``__eq__`` is overridden without a
    matching ``__hash__``, instances are unhashable.
    """

    def __init__(self, value:float, left:INode=None, right:INode=None, level:int=0):
        """Create a node.

        Args:
            value: Value stored in the node.
            left: Optional left child.
            right: Optional right child.
            level: Depth of this node; children are placed one level deeper.
        """
        super().__init__()
        self._value = value
        self._left = left
        self._right = right
        self._height = None  # cache slot; BasicNode itself never fills it
        # Assign through the property setter (children are attached by now) so
        # the requested level is honoured and pushed down through the *whole*
        # subtree, not only to the direct children.
        self.level = level

    @property
    def value(self):
        """Value stored in this node."""
        return self._value

    @property
    def left(self):
        """Left child, or ``None``."""
        return self._left

    @property
    def right(self):
        """Right child, or ``None``."""
        return self._right

    @property
    def level(self):
        """Depth of this node inside its tree."""
        return self._level

    @level.setter
    def level(self, level):
        """Set this node's depth and renumber every descendant accordingly."""
        self._level = level
        if self.left:
            self._left.level = level + 1

        if self.right:
            self._right.level = level + 1

    #==============================================
    # Comparison methods
    # =============================================
    # Every comparison against a falsy ``other`` (e.g. ``None``) is False.

    def __eq__(self, other):
        """Return whether both nodes hold equal values."""
        if not other:
            return False
        return self.value == other.value

    def __lt__(self, other):
        """Return whether this node's value is smaller than ``other``'s."""
        if not other:
            return False
        return self.value < other.value

    def __le__(self, other):
        """Return whether this node's value is smaller than or equal to ``other``'s."""
        return self < other or self == other

    def __gt__(self, other):
        """Return whether this node's value is larger than ``other``'s."""
        if not other:
            return False
        return self.value > other.value

    def __ge__(self, other):
        """Return whether this node's value is larger than or equal to ``other``'s."""
        return self > other or self == other
    
    


In [3]:
class Node(BasicNode):
    """Binary-search-tree node with AVL-style rebalancing helpers.

    Smaller values live in the left subtree, larger values in the right one;
    inserting a value equal to an existing node is a no-op.

    ``insert`` and ``drop`` rebalance only the subtrees *below* the receiver:
    a node cannot replace itself, so whoever holds the receiver (a parent node
    or a tree wrapper) should pass it through ``Node.check_balance`` afterwards
    and keep the returned node.
    """

    # Largest absolute balance factor tolerated before a rotation is applied.
    BALANCE_THRESHOLD = 1

    def __init__(self, value:float, left:BasicNode=None, right:BasicNode=None, level:int=0):
        """Create a node; all arguments are forwarded to ``BasicNode.__init__``."""
        # BasicNode.__init__ already stores the value, attaches the children
        # and assigns the levels, so nothing has to be repeated here.
        super().__init__(value, left, right, level)

    # ------------------------------------------------------------------
    # Insert / find
    # ------------------------------------------------------------------

    def insert_value(self, value:float):
        """Wrap ``value`` in a new ``Node`` and insert it below this node."""
        self.insert(Node(value))

    def find_value(self, value):
        """Return the node of this subtree holding ``value``, or ``None``."""
        return self.find(Node(value))

    def find(self, node:BasicNode):
        """Return the node of this subtree that compares equal to ``node``.

        Returns ``None`` when the search runs off the tree.
        """
        if node < self:
            if self.left:
                return self.left.find(node)
            else:
                return None
        elif node > self:
            if self.right:
                return self.right.find(node)
            else:
                return None
        else:
            return self

    def insert(self, node:BasicNode):
        """Insert ``node`` into this subtree and rebalance the touched child.

        A node comparing equal to an existing one is ignored.  The receiver
        itself is not rebalanced (see the class docstring).
        """
        if node < self:
            if self.left:
                self.left.insert(node)
            else:
                self._left = node
                # The setter also renumbers anything hanging below ``node``.
                node.level = self.level + 1
            self._left = Node.check_balance(self.left)

        elif node > self:
            if self.right:
                self.right.insert(node)
            else:
                self._right = node
                node.level = self.level + 1
            self._right = Node.check_balance(self.right)

    # ------------------------------------------------------------------
    # Removal
    # ------------------------------------------------------------------

    def drop_value(self, value:float):
        """Remove ``value`` from this subtree (see ``drop``)."""
        self.drop(Node(value))

    def drop(self, node:BasicNode):
        """Remove the node equal to ``node`` while keeping all its descendants.

        Descendants of the removed node stay in the tree and the subtrees below
        the receiver are rebalanced.  When the receiver itself matches:

        * with two children it takes over its in-order successor's value;
        * with one child it takes over that child's value and links;
        * with no children nothing happens, because a node cannot unlink
          itself; the owner has to discard it.

        Nodes are reused in place, so a reference obtained earlier through
        ``find`` may afterwards hold a different value.
        """
        if node < self:
            if self.left:
                self._left = self.left._pruned(node)
        elif node > self:
            if self.right:
                self._right = self.right._pruned(node)
        elif node == self:
            if self.left and self.right:
                self._adopt_successor_value()
            elif self.left or self.right:
                only_child = self.left or self.right
                self._value = only_child.value
                self._left = only_child.left
                self._right = only_child.right
                # Re-run the setter so the adopted grandchildren move up a level.
                self.level = self._level

    def _pruned(self, node:BasicNode):
        """Return the rebalanced root of this subtree with ``node`` removed."""
        if node < self:
            if self.left:
                self._left = self.left._pruned(node)
        elif node > self:
            if self.right:
                self._right = self.right._pruned(node)
        elif node == self:
            if self.left and self.right:
                self._adopt_successor_value()
            else:
                # Zero or one child: splice this node out and promote the child.
                only_child = self.left or self.right
                if only_child:
                    only_child.level = self.level
                return only_child
        return Node.check_balance(self)

    def _adopt_successor_value(self):
        """Overwrite this node's value with its in-order successor's and remove the successor."""
        successor = self.right
        while successor.left:
            successor = successor.left
        self._value = successor.value
        self._right = self.right._pruned(successor)

    # ------------------------------------------------------------------
    # Presentation
    # ------------------------------------------------------------------

    def __repr__(self):
        """Render the subtree sideways: right subtree first, one node per line, indented by level."""
        lines = []
        if self.right:
            lines.append(self.right.__repr__())

        lines.append(f"{'    ' * self._level}----{self._value:0.2f}")

        if self.left:
            lines.append(self.left.__repr__())

        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Balance bookkeeping
    # ------------------------------------------------------------------

    def _child_heights(self):
        """Return ``(left_height, right_height)``; a missing child counts as -1."""
        left_height = self._left.height if self._left else -1
        right_height = self._right.height if self._right else -1
        return left_height, right_height

    @property
    def balance(self):
        """Left height minus right height (positive means left-heavy)."""
        left_height, right_height = self._child_heights()
        return left_height - right_height

    @property
    def height(self):
        """Height of this subtree; a single node has height 0.

        The value is recomputed recursively on every access and also stored
        in ``_height``.
        """
        self._height = max(self._child_heights()) + 1
        return self._height

    def rotate_right(self):
        """Rotate this subtree to the right and return its new root.

        The left child becomes the root and inherits this node's level.
        Requires a left child.
        """
        level = self.level
        pivot = self.left

        self._left = pivot.right
        pivot._right = self

        # Renumbers the whole rotated subtree.
        pivot.level = level
        return pivot

    def rotate_left(self):
        """Rotate this subtree to the left and return its new root.

        The right child becomes the root and inherits this node's level.
        Requires a right child.
        """
        level = self.level
        pivot = self.right

        self._right = pivot.left
        pivot._left = self

        # Renumbers the whole rotated subtree.
        pivot.level = level
        return pivot

    @staticmethod
    def check_balance(node:BasicNode):
        """Rebalance ``node`` if it is out of balance and return the subtree root.

        Returns ``node`` itself when its balance is within
        ``BALANCE_THRESHOLD``; otherwise applies a single or double rotation
        and returns the node that took its place.
        """
        balance = node.balance

        if balance > Node.BALANCE_THRESHOLD:
            # Left-heavy.  A right-leaning left child needs a preparatory
            # left rotation (left-right case).
            if node.left.balance <= -Node.BALANCE_THRESHOLD:
                node._left = node.left.rotate_left()
            return node.rotate_right()

        if balance < -Node.BALANCE_THRESHOLD:
            # Right-heavy.  A left-leaning right child needs a preparatory
            # right rotation (right-left case); an evenly balanced child,
            # which can occur after a removal, needs only the single rotation.
            if node.right.balance >= Node.BALANCE_THRESHOLD:
                node._right = node.right.rotate_right()
            return node.rotate_left()

        return node
        

In [4]:
# Left-left case: descending inserts build a left-leaning chain (balance 2),
# which check_balance repairs with a single right rotation.
btree = Node(4)
btree.insert_value(3)
btree.insert_value(2)
print(btree, btree.balance, sep="\n")
Node.check_balance(btree)

----4.00
    ----3.00
        ----2.00
2


    ----4.00
----3.00
    ----2.00

In [5]:
# Right-right case: ascending inserts build a right-leaning chain (balance -2),
# which check_balance repairs with a single left rotation.
btree = Node(6)
btree.insert_value(7)
btree.insert_value(8)
print(btree.balance, btree, sep="\n")
Node.check_balance(btree)

-2
        ----8.00
    ----7.00
----6.00


    ----8.00
----7.00
    ----6.00

In [6]:
# Left-right case: the left child leans right, so check_balance rotates the
# child left first and then rotates the root right.
btree = Node(4)
btree.insert_value(2)
btree.insert_value(3)
print(btree.balance, btree, sep="\n")
Node.check_balance(btree)

2
----4.00
        ----3.00
    ----2.00


    ----4.00
----3.00
    ----2.00

In [7]:
# Right-left case: the right child leans left, so check_balance rotates the
# child right first and then rotates the root left.
btree = Node(6)
btree.insert_value(8)
btree.insert_value(7)
print(btree.balance, btree, sep="\n")
Node.check_balance(btree)

-2
    ----8.00
        ----7.00
----6.00


    ----8.00
----7.00
    ----6.00

In [8]:
class BinaryTree:
    """Wrapper that owns the root ``Node`` of a self-balancing search tree.

    The wrapper re-assigns the root after every modification, because
    rebalancing can put a different node at the top.  An empty tree is
    represented by a root of ``None``.
    """

    def __init__(self, root_value:float = None):
        """Create a tree, optionally seeded with ``root_value``."""
        # ``is not None`` rather than truthiness, so that 0 is a valid root value.
        self._root_node = Node(root_value) if root_value is not None else None

    def insert(self, value:float):
        """Insert ``value`` and rebalance; the first value becomes the root."""
        if self._root_node is None:
            self._root_node = Node(value)
            return
        self._root_node.insert_value(value)
        self._root_node = Node.check_balance(self._root_node)

    def find(self, value) -> "Node":
        """Return the node holding ``value``, or ``None`` if there is none."""
        if self._root_node is None:
            return None
        return self._root_node.find_value(value)

    def drop(self, value) -> None:
        """Remove ``value`` from the tree if present, then rebalance the root."""
        root = self._root_node
        if root is None:
            return
        if root.value == value and root.left is None and root.right is None:
            # A lone root cannot unlink itself; dropping it empties the tree.
            self._root_node = None
            return
        root.drop_value(value)
        self._root_node = Node.check_balance(root)

    def __repr__(self):
        """Return the root's sideways rendering, or a marker for an empty tree."""
        if self._root_node is None:
            return "<empty BinaryTree>"
        return self._root_node.__repr__()

    @property
    def balance(self):
        """Balance factor of the root (0 for an empty tree)."""
        if self._root_node is None:
            return 0
        return self._root_node.balance

    @property
    def height(self):
        """Height of the tree (-1 for an empty tree, 0 for a single node)."""
        if self._root_node is None:
            return -1
        return self._root_node.height

In [9]:
# Same kind of check through the BinaryTree wrapper, which rebalances the
# root after every insert.
btree = BinaryTree(4)
btree.insert(3)
btree.insert(1)
btree.insert(2)
btree.insert(0)
print(btree, btree.balance, sep="\n")

    ----4.00
----3.00
        ----2.00
    ----1.00
        ----0.00
1


In [10]:
# Children passed to the constructor must be reachable through the properties.
node = Node(10, Node(9), Node(11))
assert node.value == 10
assert node.left.value == 9
assert node.right.value == 11

In [11]:
# Nodes order and compare by their stored value.
assert Node(9) <= Node(11)
assert Node(9) >= Node(8)
assert Node(9) == Node(9)

In [12]:
# Bulk insert of 0..126 into a tree whose root already holds 10.  Node.insert
# has no branch for equal values, so the second 10 is silently ignored.
btree = BinaryTree(10)
for i in range(127):
    btree.insert(i)

In [13]:
# Bare last expression: the notebook displays the tree through its __repr__.
btree

                        ----126.00
                    ----125.00
                        ----124.00
                ----123.00
                        ----122.00
                    ----121.00
                        ----120.00
            ----119.00
                        ----118.00
                    ----117.00
                        ----116.00
                ----115.00
                        ----114.00
                    ----113.00
                        ----112.00
        ----111.00
                        ----110.00
                    ----109.00
                        ----108.00
                ----107.00
                        ----106.00
                    ----105.00
                        ----104.00
            ----103.00
                        ----102.00
                    ----101.00
                        ----100.00
                ----99.00
                        ----98.00
                    ----97.00
                        ----96.00
    ----95.00
         

In [14]:
# Ascending inserts followed by a value (25) that lands in the middle of the
# tree; the root balance is printed after the rendering.
btree = BinaryTree(10)
btree.insert(20)
btree.insert(30)
btree.insert(40)
btree.insert(50)
btree.insert(25)
print(btree, btree.balance, sep="\n")

        ----50.00
    ----40.00
----30.00
        ----25.00
    ----20.00
        ----10.00
0


In [15]:
# Height of the tree built in the previous cell (a single node has height 0).
btree.height

2