<a href="https://colab.research.google.com/github/shlin101204/colab/blob/main/BST_AVL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from numpy.random import seed, randint
import ipywidgets as widgets 
from ipywidgets import interact, interactive_output, HBox, VBox, Layout
from IPython.display import display, clear_output, SVG, HTML
import graphviz as gv # for visualizing a tree using Digraph
from graphviz import Digraph
import pydot

###############
#Debug function
###############

outputdebug = True  # global switch: set to False to silence debug() output


def debug(msg):
  ''' Print msg to stdout, but only while the outputdebug flag is True. '''
  if not outputdebug:
    return
  print(msg)

In [None]:
class TreeNode:
  ''' A node of a binary search tree, plus fields used for drawing it.

  NOTE: do not define __str__/__repr__ here. BSTViz uses str(node), which by
  default embeds the object id, as a unique graphviz node name. A key-based
  string would merge nodes that have duplicate keys.
  '''
  def __init__(self, key):
    self.key = key
    self.parent = None
    self.left = None
    self.right = None
    self.pos = [0, 0] # Position [x, y] used when drawing
    self.col = "black" # Color
    self.h = 0 # Height field (currently not updated by the tree classes)

  def spos(self):
    ''' Position of a node for neato, as "x,y!" (the "!" pins the node). '''
    return str(self.pos[0]) + "," + str(self.pos[1]) + "!"

  def to_string(self):
    ''' Return the key as a string. '''
    return str(self.key)

  def print(self):
    ''' Print the node's key; works for any key type, not only integers. '''
    # "%d" raised TypeError for non-integer keys and truncated floats
    print("TreeNode: " + str(self.key))

In [None]:
class BST():
  ''' Unbalanced binary search tree.

  Keys smaller than a node's key go into its left subtree. Keys greater than
  or equal to it go into its right subtree, so duplicate keys are allowed.
  '''
  def __init__(self):
    self.root = None # Root of the tree
    self.list = [] # Elements of the tree as strings, filled by the traversals

  def insert(self, k):
    ''' Insert key k as a new leaf and return the new node. '''
    newNode = TreeNode(k)
    # If the tree is empty, the new node is the root
    if self.root is None:
      self.root = newNode
      return newNode
    # Walk down to the empty slot; equal keys go to the right
    x = self.root
    while x is not None:
      p = x
      x = x.left if newNode.key < x.key else x.right
    # Link newNode under its parent p
    newNode.parent = p
    if newNode.key < p.key:
      p.left = newNode
    else:
      p.right = newNode
    return newNode

  def search(self, x, k):
    ''' Search for key k in the subtree rooted at node x.

    Returns the first node with key k on the search path. If there is none,
    prints a debug message and returns None.
    '''
    while x is not None and x.key != k:
      x = x.left if k < x.key else x.right
    if x is None:
      debug("找不到鍵值 " + str(k))
    return x

  def find(self, k):
    ''' Search for key k starting at the root. '''
    return self.search(self.root, k)

  def transplant(self, u, v):
    ''' Replace subtree u with subtree v (v may be None). '''
    if u.parent is None:
      # u was the root, so v becomes the new root
      self.root = v
    elif u is u.parent.left:
      u.parent.left = v
    else:
      u.parent.right = v
    if v is not None:
      v.parent = u.parent

  def tree_minimum(self, x):
    ''' Return the node with the smallest key in the subtree of x. '''
    while x.left is not None:
      x = x.left
    return x

  def delete(self, k):
    ''' Delete one node with key k from the tree, if present. '''
    z = self.search(self.root, k)
    if z is None:
      # search() has already reported the missing key
      return
    debug("刪除鍵值 " + str(k))
    if z.left is None: # case 1: z has no left child
      self.transplant(z, z.right)
    elif z.right is None: # case 2: z has no right child
      self.transplant(z, z.left)
    else: # case 3: replace z by its successor y (smallest key in z.right)
      y = self.tree_minimum(z.right)
      if y.parent is not z:
        self.transplant(y, y.right)
        y.right = z.right
        y.right.parent = y
      # transplant also updates self.root when z is the root
      self.transplant(z, y)
      y.left = z.left
      y.left.parent = y

  def height(self, x):
    ''' Number of nodes on the longest root-to-leaf path of x (0 for None). '''
    if x is None:
      return 0
    return 1 + max(self.height(x.left), self.height(x.right))

  def inorder_traversal(self, x):
    ''' Inorder traversal: store the elements in self.list, sorted by key.

    self.list is reset only when x is the root.
    '''
    if x is not None:
      if x is self.root:
        self.list = []
      self.inorder_traversal(x.left)
      self.list.append(x.to_string())
      self.inorder_traversal(x.right)

  def rotate_left(self, x):
    ''' Rotate left on node x (x.right must not be None). '''
    debug (str(x.key) + ' 向左旋轉')
    y = x.right # set y to be right child
    # turn y's left subtree into x's right subtree
    x.right = y.left
    if y.left is not None:
      y.left.parent = x
    y.parent = x.parent # link x's parent to y
    if x.parent is None: # Case 1: x is root
      self.root = y
    elif x is x.parent.left: # Case 2: x is left child
      x.parent.left = y
    else: # Case 3: x is right child
      x.parent.right = y
    y.left = x
    x.parent = y

  def rotate_right(self, x):
    ''' Rotate right on node x (x.left must not be None). '''
    debug (str(x.key) + '  向右旋轉')
    y = x.left # set y to be left child
    # turn y's right subtree into x's left subtree
    x.left = y.right
    if y.right is not None:
      y.right.parent = x
    y.parent = x.parent # link x's parent to y
    if x.parent is None: # Case 1: x is root
      self.root = y
    elif x is x.parent.right: # Case 2: x is right child
      x.parent.right = y
    else: # Case 3: x is left child
      x.parent.left = y
    y.right = x
    x.parent = y

  def postorder_traversal(self, x):
    ''' Postorder traversal (left child, right child, root) into self.list. '''
    if x is not None:
      if x is self.root:
        self.list = []
      self.postorder_traversal(x.left)
      self.postorder_traversal(x.right)
      self.list.append(x.to_string())

  def preorder_traversal(self, x):
    ''' Preorder traversal (root, left child, right child) into self.list. '''
    if x is not None:
      if x is self.root:
        self.list = []
      self.list.append(x.to_string())
      self.preorder_traversal(x.left)
      self.preorder_traversal(x.right)


In [None]:
class BSTViz:
  ''' Draw a BST/AVLTree with graphviz, pinning every node at a computed
  position for the "neato" layout engine. '''
  def __init__(self, bst):
    self.bst = bst # tree to draw (must provide .root and .height())

  def visualize(self, node):
    ''' Visualize the tree using graphviz.

    node: a node whose key is highlighted among the non-root nodes (green as
          a left child, red as a right child), or None for no highlight.
    Returns a graphviz Digraph; an empty one when the tree has no root.
    '''
    dot = Digraph()
    dot.engine = 'neato'
    tree = self.bst.root # start with root of the tree
    if tree is None:
      # Nothing to draw (e.g. every key has been deleted)
      return dot
    # Place root node at (0, height of the tree)
    tree.pos = [0, self.bst.height(tree)]
    dot.node(name=str(tree), label=str(tree.key), color=tree.col, shape="circle",
        fixedsize="True", width="0.4", pos=tree.spos())

    def add_child(parent, child, direction, hit_col):
      ''' Draw child below parent, shifted by direction * height/4, then recurse. '''
      col = hit_col if node is not None and child.key == node.key else "black"
      h = self.bst.height(child)
      child.pos[0] = parent.pos[0] + direction * h / 4 # x
      child.pos[1] = parent.pos[1] - 0.6 # y
      dot.node(name=str(child), label=str(child.key), color=col,
          shape="circle", fixedsize="True", width="0.4", pos=child.spos())
      dot.edge(str(parent), str(child))
      add_nodes(child)

    def add_nodes(parent):
      ''' Draw both subtrees of parent: left to the left, right to the right. '''
      if parent.left is not None:
        add_child(parent, parent.left, -1, "green")
      if parent.right is not None:
        add_child(parent, parent.right, 1, "red")

    add_nodes(tree)
    return dot

In [None]:
class AVLTree(BST):
  ''' Height-balanced BST: insert() rebalances with rotations.

  Only insert() rebalances; delete() is inherited unchanged from BST, so a
  deletion can leave the tree unbalanced.
  '''
  def __init__(self):
    super().__init__()

  def _balance(self, x):
    ''' Balance factor of x: height(left) - height(right). '''
    return self.height(x.left) - self.height(x.right)

  def _inserted_node(self, k):
    ''' Return the leaf that BST.insert just created for key k.

    Repeats insert's descent (k < key -> left, otherwise right). The last node
    reached is the new leaf, even when k is a duplicate of existing keys;
    find(k) would instead return the highest node with key k.
    '''
    x, last = self.root, None
    while x is not None:
      last = x
      x = x.left if k < x.key else x.right
    return last

  def insert(self, k):
    ''' Insert a key, then rebalance every ancestor of the new node. '''
    super().insert(k)
    p = self._inserted_node(k)
    while p is not None:
      b = self._balance(p)
      debug(str(p.key)+" 的平衡係數:"+str(b))
      if b > 1: # left-heavy
        if self._balance(p.left) < 0: # Case 3 LR: reduce to LL first
          self.rotate_left(p.left)
        self.rotate_right(p) # Case 1 LL
      elif b < -1: # right-heavy
        if self._balance(p.right) > 0: # Case 4 RL: reduce to RR first
          self.rotate_right(p.right)
        self.rotate_left(p) # Case 2 RR
      # After a rotation p.parent is the new subtree root; continue upward
      p = p.parent

In [None]:
from IPython.core.display import DisplayObject
import time
import ipywidgets as widgets
from ipywidgets import BoundedIntText, Button, HTML, Tab, HBox, VBox, Output

# Create a new tree
tree = BST()
################################
# Output objects: a separate drawing area per tab, so the BST view and the
# AVL view do not overwrite each other
out, out_avl, out1, out2 = Output(), Output(), Output(), Output()
tab = Tab(children = [out1, out2], layout=Layout(width='100%', height='auto'))
tab.set_title(0, '二元搜尋樹')
tab.set_title(1, '高度平衡樹 (AVL Tree)')

###############################
# Input field for keys (shared by both tabs)
ui_key = BoundedIntText(value=20, min=0, max=100,
      step=2, description='輸入鍵值(0-100):', disabled=False)
################################
# Messages
msg = "<p>按下列任一按鍵!</p>"  # "Press any of the buttons below!"
msg1 = "尚未建立二元樹!"  # "No tree has been built yet!"
msg2 = "未找到所需鍵值!!"  # "Key not found!!"
################################
# Seed the random number generator once, from the current time (ms % 1000)
seconds = int(time.time() * 1000) % 1000
seed(seconds)

def _show_tree(t, node=None):
  ''' Display the drawing of tree t (highlighting node); draw nothing if t is empty. '''
  if t.root is not None:
    display(BSTViz(t).visualize(node))

################################
# Buttons for the binary search tree
btn_insert = Button(description='新增', button_style='success')
def on_button_insert_clicked(b):
  with out:
    clear_output()
    tree.insert(ui_key.value)
    _show_tree(tree)
btn_insert.on_click(on_button_insert_clicked)

btn_delete = Button(description='刪除', button_style='primary')
def on_button_delete_clicked(b):
  with out:
    clear_output()
    if tree.root is None:
      print(msg1)
    else:
      tree.delete(ui_key.value)
      _show_tree(tree)
btn_delete.on_click(on_button_delete_clicked)

btn_reset = Button(description='清除二元樹', button_style='danger')
def on_button_reset_clicked(b):
  with out:
    clear_output()
    tree.root = None
btn_reset.on_click(on_button_reset_clicked)

btn_search = Button(description='搜尋', button_style='info')
def on_button_search_clicked(b):
  with out:
    clear_output()
    if tree.root is None:
      print(msg1)
      return
    foundnode = tree.find(ui_key.value)
    if foundnode is not None:
      print("找到鍵值：" + foundnode.to_string())
    else:
      clear_output() # drop find()'s debug line; show only msg2
      print(msg2)
    _show_tree(tree, foundnode)
btn_search.on_click(on_button_search_clicked)

btn_rnd = Button(description='隨機產生二元樹', button_style='warning')
def on_button_rnd_clicked(b):
  with out:
    clear_output()
    tree.root = None
    for key in randint(1, 90, 7):
      tree.insert(key)
    _show_tree(tree)
btn_rnd.on_click(on_button_rnd_clicked)

btn_inorder = Button(description='中序走訪', button_style='info')
def on_button_inorder_clicked(b):
  ''' Print the keys of the BST in inorder (sorted) order and redraw it. '''
  with out:
    clear_output()
    if tree.root is None:
      print(msg1)
      return
    tree.inorder_traversal(tree.root)
    print("中序走訪：" + ", ".join(tree.list))
    _show_tree(tree)
btn_inorder.on_click(on_button_inorder_clicked)
################################
# Layout
layout_displ=Layout(height='300px', border='1px dotted blue', overflow ='auto')
layout_ctrl=Layout(height='50px')
################################

# Create a new AVL tree
avl = AVLTree()
# Buttons for the AVL tree
btn_insert_avl = Button(description='新增', button_style='success')
def on_button_insert_avl_clicked(b):
  with out_avl:
    clear_output()
    avl.insert(ui_key.value)
    _show_tree(avl)
btn_insert_avl.on_click(on_button_insert_avl_clicked)

btn_delete_avl = Button(description='刪除', button_style='primary')
def on_button_delete_avl_clicked(b):
  with out_avl:
    clear_output()
    if avl.root is None:
      print(msg1)
    else:
      avl.delete(ui_key.value)
      _show_tree(avl)
btn_delete_avl.on_click(on_button_delete_avl_clicked)

btn_rnd_avl = Button(description='隨機產生AVL', button_style='warning')
def on_button_rnd_avl_clicked(b):
  with out_avl:
    clear_output()
    avl.root = None # start from an empty tree, like the BST random button
    for key in randint(1, 90, 7):
      avl.insert(key)
    _show_tree(avl)
btn_rnd_avl.on_click(on_button_rnd_avl_clicked)

btn_clear_avl = Button(description='清除AVL', button_style='danger')
def on_button_clear_avl_clicked(b):
  with out_avl:
    clear_output()
    avl.root = None
btn_clear_avl.on_click(on_button_clear_avl_clicked)
################################

with out1:
  htm = HTML(msg)
  msgbox = HTML("<p></p>")
  displ = HBox([out], layout=layout_displ)
  ctrl = HBox([ui_key, btn_insert, btn_delete, btn_search, btn_rnd, btn_reset, btn_inorder], layout=layout_ctrl)
  display(VBox([displ, htm, msgbox, ctrl]))

with out2:
  msgbox = HTML("<p></p>")
  displ = HBox([out_avl], layout=layout_displ)
  ctrl = HBox([ui_key, btn_insert_avl, btn_delete_avl, btn_rnd_avl, btn_clear_avl], layout=layout_ctrl)
  display(VBox([displ, msgbox, ctrl]))

display(tab)


Tab(children=(Output(), Output()), layout=Layout(height='auto', width='100%'), _titles={'0': '二元搜尋樹', '1': '高度…