# ECS529U Algorithms and Data Structures
# Lab sheet 7

This lab gets you to work with binary trees and binary search trees in particular.

**Marks (max 5):**  Question 1: 1.5 | Questions 2-4: 1 each | Question 5: 0.5

## Question 1

This question is about understanding binary search trees (BSTs).

a) Draw the binary search tree obtained by starting from the empty tree and adding 
the following numbers in order:

    21, 40, 3, 16, 39, 58, 21, 46, 1, 10

b) Write down the numbers of the tree you constructed, starting from the root, using 
depth-first search (in pre-order), and using breadth-first search.
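
As a rough illustration of the two traversal orders in part b, here is a minimal sketch on a different, smaller set of numbers. The `Node` class and the helpers `insert`, `preorder` and `bfs` are hypothetical names made up for this example; they are not the lab's `BTNode` and `BST` classes. Duplicates are assumed to go to the right.

```python
from collections import deque

class Node:
    """Hypothetical minimal tree node, for illustration only."""
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None

def insert(root, d):
    """Insert d into the BST rooted at root and return the (possibly new) root."""
    if root is None:
        return Node(d)
    if d < root.data:
        root.left = insert(root.left, d)
    else:
        root.right = insert(root.right, d)   # equal values go right (assumption)
    return root

def preorder(node):
    """Depth-first, pre-order: node first, then left subtree, then right subtree."""
    if node is None:
        return []
    return [node.data] + preorder(node.left) + preorder(node.right)

def bfs(root):
    """Breadth-first: level by level, left to right, using a queue."""
    out = []
    queue = deque([root] if root is not None else [])
    while queue:
        node = queue.popleft()
        out.append(node.data)
        if node.left is not None:
            queue.append(node.left)
        if node.right is not None:
            queue.append(node.right)
    return out

root = None
for x in [5, 3, 8, 1, 4]:
    root = insert(root, x)

print(preorder(root))  # [5, 3, 1, 4, 8]
print(bfs(root))       # [5, 3, 8, 1, 4]
```

Both traversals start at the root; they differ only in whether the next node comes from going deeper (pre-order) or from finishing the current level (breadth-first).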

c) Let `t` point to the root node of the BST you constructed in part a. Draw the BST that
results by applying each of the following operations:

    1. t.left = t.left.right
    2. t.left.right.right = t.left.right.left

In each of these cases, is the resulting structure a binary tree? Is it a binary search 
tree?
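
One possible way to check answers of this kind is to test the BST property in code. The sketch below works on a small hypothetical tree (not the one from part a). `Node` and `is_bst` are names invented here, and it assumes that equal values are only allowed in the right subtree. Note that `is_bst` only checks the ordering of values; it does not detect a node that is reachable along two different paths.

```python
class Node:
    """Hypothetical minimal tree node, for illustration only."""
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def is_bst(node, lo=None, hi=None):
    """True if every value v in the subtree satisfies lo <= v < hi."""
    if node is None:
        return True
    if lo is not None and node.data < lo:
        return False
    if hi is not None and node.data >= hi:
        return False
    return is_bst(node.left, lo, node.data) and is_bst(node.right, node.data, hi)

#       10
#      /  \
#     5    15
#    / \
#   2   7
t = Node(10, Node(5, Node(2), Node(7)), Node(15))
print(is_bst(t))        # True

t.left = t.left.right   # nodes 5 and 2 become unreachable; 7 is now the left child of 10
print(is_bst(t))        # True

t.right.left = t.left   # node 7 now also hangs under 15, so it has two parents
print(is_bst(t))        # False: 7 sits in the right subtree of 10
```

After the second assignment the structure is no longer a tree at all, because the node holding 7 has two parents.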

d) Starting each time from the tree you constructed in part a, perform the following removals (using the algorithm we saw in the lectures) and draw the resulting trees:

1. remove the node with value 16
2. remove the node with value 40

The remaining questions ask you to work with the `BST` class and variants thereof. To help you visualise trees, we have implemented the following "pretty printing" function for `BTNode` objects:

In [1]:
class BTNode:
    def __init__(self,d,l,r):
        self.data = d
        self.left = l
        self.right = r
          
    def updateChild(self, oldChild, newChild):
        if self.left == oldChild:
            self.left = newChild
        elif self.right == oldChild:
            self.right = newChild
        else: raise Exception("updateChild error")

    # prints the node and all its children in a string
    def __str__(self):  
        return self.niceStr()
    

    def niceStr(self): # this goes in the BTNode class
        S = ["├","─","└","│"]
        angle = S[2]+S[1]+" "
        vdash = S[0]+S[1]+" "
        
        def niceRec(ptr,acc,pre):
            if ptr == None: return acc+pre+"None"
            if ptr.left==ptr.right==None: return acc+pre+str(ptr.data)
            if pre == vdash: pre2 = S[3]+"  "
            elif pre == angle: pre2 = "   "
            else: pre2 = ""
            upper = niceRec(ptr.right,acc+pre2,vdash)   # right child is printed first
            lower = niceRec(ptr.left,acc+pre2,angle)    # left child is printed last
            return acc+pre+str(ptr.data)+"\n"+upper+"\n"+lower
            
        return niceRec(self,"","")

In [2]:
class ArrayList:
    def __init__(self): 
        self.inArray = [0 for i in range(10)]
        self.count = 0

    def __str__(self):
        return str(self.inArray[:self.count])

    def length(self):
        return self.count

    def append(self, e):
        self.inArray[self.count] = e
        self.count += 1
        if len(self.inArray) == self.count:
            self._resizeUp()     # resize array if reached capacity

    def insert(self, i, e):
        for j in range(self.count,i,-1):
            self.inArray[j] = self.inArray[j-1]
        self.inArray[i] = e
        self.count += 1
        if len(self.inArray) == self.count:
            self._resizeUp()     # resize array if reached capacity

    def remove(self, i):
        self.count -= 1
        val = self.inArray[i]
        for j in range(i,self.count):
            self.inArray[j] = self.inArray[j+1]
        return val
    
    def _resizeUp(self):
        newArray = [0 for i in range(2*len(self.inArray))]
        for j in range(len(self.inArray)):
            newArray[j] = self.inArray[j]
        self.inArray = newArray

class Queue:
    def __init__(self):
        self.inList = ArrayList()

    def __str__(self):
        return str(self.inList)

    def size(self):
        return self.inList.length()

    def enq(self, e):
        self.inList.append(e)

    def deq(self):
        return self.inList.remove(0)
        
class Stack:
    def __init__(self):
        self.inList = ArrayList()

    def __str__(self):
        return str(self.inList)

    def size(self):
        return self.inList.length()

    def push(self, e):
        self.inList.insert(0,e)

    def pop(self):
        return self.inList.remove(0)

For example, the following tree

        22
       /  \
      20   42
     / \   / \
    11 21 22 44

is converted into a string that prints as follows:

    22 
    ├─ 42
    │  ├─ 44
    │  └─ 22
    └─ 20
       ├─ 21
       └─ 11

## Question 2

Add the following functions to `BST`, assuming that we work with BSTs that store integers:

a) `def min(self)`

that returns the smallest element of the tree. If the tree is empty, the function should return `None`.

b) `def max(self)`

that returns the largest element of the tree. If the tree is empty, the function should return `None`.

c) `def removeAll(self, d)`

that removes all occurrences of the element `d` in the tree and returns the number of occurrences removed. For example, if the tree `t` is:

        22
       /  \
      20   42
     / \   / \
    11 21 22 44
    
then `t.removeAll(22)` should change `t` to:

        42
       /  \
      20   44
     / \  
    11 21 
    
and return `2`.

In [3]:
class BST:
    def __init__(self):
        self.root = None
        self.size = 0

    def __str__(self):
        return str(self.root)
        
    def search(self, d):
        ptr = self.root
        while ptr != None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        if self.root == None:
            self.root = BTNode(d,None,None)
        else:
            ptr = self.root
            while True:
                if d < ptr.data:
                    if ptr.left == None:
                        ptr.left = BTNode(d,None,None)
                        break
                    ptr = ptr.left
                else:
                    if ptr.right == None:
                        ptr.right = BTNode(d,None,None)
                        break
                    ptr = ptr.right
        self.size += 1

    def remove(self,d):
        if self.root == None: return
        if self.root.data == d:
            self.size -= 1
            return self._removeRoot()
        parentPtr = None
        ptr = self.root
        while ptr != None and ptr.data != d:
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        if ptr != None:
            self.size -= 1
            self._removeNode(ptr,parentPtr)

    def _removeNode(self, ptr, parentPtr):
        if ptr.left == ptr.right == None:
            parentPtr.updateChild(ptr,None)
        elif ptr.left == None:
            parentPtr.updateChild(ptr,ptr.right)
        elif ptr.right == None:
            parentPtr.updateChild(ptr,ptr.left)
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left != None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            parentMinRNode.updateChild(minRNode,minRNode.right)

    def _removeRoot(self):
        parentRoot = BTNode(None,self.root,None)
        self._removeNode(self.root,parentRoot)
        self.root = parentRoot.left

    def min(self):
        if self.size == 0:
            return None
        ptr = self.root
        while ptr.left != None:
            ptr = ptr.left
        return ptr.data

    def max(self):
        if self.size == 0:
            return None
        ptr = self.root
        while ptr.right != None:
            ptr = ptr.right
        return ptr.data

    def removeAll(self, d):
        # remove occurrences one at a time, counting how many were removed
        removed = 0
        while self.search(d):
            self.remove(d)
            removed += 1
        return removed


tree = BST()
B = [1,2,3,4,4,4,4,4,5,6]
for i in range(len(B)):
    tree.add(B[i])
print(tree.min())
print(tree.max())
print(tree.search(4))
print(tree.removeAll(4))
print(tree.search(4))

1
6
True
5
False


## Question 3

Add the following functions to `BST`, assuming that we work with BSTs that store integers:

a) `def _sumAllRec(self, ptr)`

that <u>uses recursion</u> and returns the sum of all the elements of the subtree starting from the node `ptr`.

_Hint:_ you can simply use depth-first search, ignoring the fact that this is a BST 
rather than a simple binary tree.

b) `def sumAll(self)`

that sums all the elements of the tree (use the function from part a).

c) `def sumAllBFS(self)`

that sums all the elements of the tree using breadth-first search.

_Hint:_ you can adapt the code for breadth-first search that we saw in the lecture 
(week 6). You will need to use a queue (see lecture of week 5).

In [9]:
class BST:
    def __init__(self):
        self.root = None
        self.size = 0

    def __str__(self):
        return str(self.root)
                   
    def search(self, d):
        ptr = self.root
        while ptr != None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        if self.root == None:
            self.root = BTNode(d,None,None)
        else:
            ptr = self.root
            while True:
                if d < ptr.data:
                    if ptr.left == None:
                        ptr.left = BTNode(d,None,None)
                        break
                    ptr = ptr.left
                else:
                    if ptr.right == None:
                        ptr.right = BTNode(d,None,None)
                        break
                    ptr = ptr.right
        self.size += 1

    def remove(self,d):
        if self.root == None: return
        if self.root.data == d:
            self.size -= 1
            return self._removeRoot()
        parentPtr = None
        ptr = self.root
        while ptr != None and ptr.data != d:
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        if ptr != None:
            self.size -= 1
            self._removeNode(ptr,parentPtr)

    def _removeNode(self, ptr, parentPtr):
        if ptr.left == ptr.right == None:
            parentPtr.updateChild(ptr,None)
        elif ptr.left == None:
            parentPtr.updateChild(ptr,ptr.right)
        elif ptr.right == None:
            parentPtr.updateChild(ptr,ptr.left)
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left != None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            parentMinRNode.updateChild(minRNode,minRNode.right)

    def _removeRoot(self):
        parentRoot = BTNode(None,self.root,None)
        self._removeNode(self.root,parentRoot)
        self.root = parentRoot.left

    def min(self):
        if self.size == 0:
            return None
        ptr = self.root
        while ptr.left != None:
            ptr = ptr.left
        return ptr.data

    def max(self):
        if self.size == 0:
            return None
        ptr = self.root
        while ptr.right != None:
            ptr = ptr.right
        return ptr.data

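    def removeAll(self, d):
        # remove occurrences one at a time, counting how many were removed
        removed = 0
        while self.search(d):
            self.remove(d)
            removed += 1
        return removed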

    def _sumAllRec(self, ptr):
        if ptr == None:
            return 0
        else:
            return ptr.data + self._sumAllRec(ptr.left) + self._sumAllRec(ptr.right)

    def sumAll(self):
        return self._sumAllRec(self.root)

    def sumAllBFS(self):
        ptr = self.root
        if ptr == None:
            return
        q = Queue()
        total = 0      # avoid shadowing the built-in sum()
        q.enq(ptr)
        while q.size() > 0:
            ptr = q.deq()
            if ptr.left != None:
                q.enq(ptr.left)
            if ptr.right != None:
                q.enq(ptr.right)
            total += ptr.data
        return total

    def toSortedArray(self):
        ptr = self.root
        B = [0 for i in range(self.size)]
        index = 0
        if ptr == None:
            return
        s = Stack()
        
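        while ptr != None or s.size() != 0:
            # go as far left as possible, stacking the nodes on the way
            while ptr != None:
                s.push(ptr)
                ptr = ptr.left
            # visit the most recently stacked node, then explore its right subtree
            ptr = s.pop()
            B[index] = ptr.data
            index += 1
            ptr = ptr.right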
                
        return B



tree = BST()
B = [6,1,5,3,2,7,8,9,4,5,10]
for i in range(len(B)):
    tree.add(B[i])
print(tree._sumAllRec(tree.root.left))
print(tree.sumAll())
print(tree.sumAllBFS())
print(tree.toSortedArray())


20
60
60
[1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 10]


## Question 4

Add to `BST` a function 

    def toSortedArray(self)

that returns an array containing the elements of the tree in ascending order.

_Hint:_ Use a helper function to do an inorder traversal of the BST.
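
The hint can be read as follows: an in-order traversal visits the left subtree, then the node itself, then the right subtree, which on a BST produces the values in ascending order. Below is a minimal recursive sketch of that idea. It assumes a hypothetical stand-alone `Node` class and the made-up helper names `inorder` and `to_sorted_array`, rather than the lab's `BST` methods, and it uses a Python list as the result array.

```python
class Node:
    """Hypothetical minimal tree node, for illustration only."""
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def inorder(node, out):
    """Append the values of the subtree rooted at node to out, in order."""
    if node is None:
        return
    inorder(node.left, out)     # everything smaller comes first
    out.append(node.data)       # then the node itself
    inorder(node.right, out)    # then everything greater or equal

def to_sorted_array(root):
    out = []
    inorder(root, out)
    return out

#         22
#        /  \
#       20   42
#      / \   / \
#     11 21 22 44
root = Node(22, Node(20, Node(11), Node(21)), Node(42, Node(22), Node(44)))
print(to_sorted_array(root))   # [11, 20, 21, 22, 22, 42, 44]
print(to_sorted_array(None))   # []
```

Passing the output list down as a parameter avoids building and concatenating a new list at every recursive call.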

In [5]:
# Minimal testing Questions 2-4

print("Question 2")
t = BST()
print(t)
print(t.min(),t.max())
A = [22,20,11,21,42,22,44,1]
for x in A: t.add(x)
print(t)
print(t.min(),t.max())
t.removeAll(22)
print(t)

print("\nQuestion 3")
t = BST()
print(t)
print(t.sumAll(), t.sumAllBFS())
A = [22,20,11,21,42,22,44]
for x in A: t.add(x)
print(t)
print(t.sumAll(), t.sumAllBFS())

print("\nQuestion 4")
print(BST().toSortedArray())
print(t.toSortedArray())

Question 2
None
None None
22
├─ 42
│  ├─ 44
│  └─ 22
└─ 20
   ├─ 21
   └─ 11
      ├─ None
      └─ 1
1 44
42
├─ 44
└─ 20
   ├─ 21
   └─ 11
      ├─ None
      └─ 1

Question 3
None
0 None
22
├─ 42
│  ├─ 44
│  └─ 22
└─ 20
   ├─ 21
   └─ 11
182 182

Question 4
None
[11, 20, 21, 22, 22, 42, 44]


## Question 5

You are asked to write a class `BST2` which implements a BST in which each node has a multiplicity counter (`mult`), which counts how many times the node's value is stored in the tree. This way, there is no need to store duplicate nodes in the tree: 
- adding a value that already exists in the tree simply amounts to increasing the counter of the value's node by 1; 
- removing a value from the tree amounts to reducing the counter of its node by 1, and if the counter becomes 0 then the node is removed altogether.

Below we have provided you with a class of nodes `BTNode2` to use, and we made a start in implementing `BST2`. You are asked to implement the following functions:

- `add(self,d)` for adding the value `d` in the BST2. This should use BST 
search and either increase the `mult` counter of the `BTNode2` containing `d` or, if `d` is not in the tree, create a new `BTNode2` for `d`.

- `search(self,d)` for searching the value `d` in the BST2. This should use BST 
search and return `True` if the value is found, and `False` otherwise.

- `count(self,d)` for counting the times the value `d` appears in the BST2. This 
should use BST search and return the number of times that the value appears in 
the BST2.

- `remove(self,d)` for removing one occurrence of the value `d` from the BST2.
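
To make the intended behaviour of the four functions concrete, here is one possible sketch, not a reference solution. The names `CountNode`, `CountingBST` and `_find` are hypothetical and stand in for `BTNode2` and `BST2`. It also assumes that `size` counts stored values including repeats, and that `remove` reports whether anything was removed.

```python
class CountNode:
    """Hypothetical node with a multiplicity counter (same idea as BTNode2)."""
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
        self.mult = 1

class CountingBST:
    def __init__(self):
        self.root = None
        self.size = 0   # assumption: total number of values, counting repeats

    def _find(self, d):
        """BST search. Returns (node, parent); node is None if d is absent,
        in which case parent is the node under which d would be attached."""
        parent, ptr = None, self.root
        while ptr is not None and ptr.data != d:
            parent = ptr
            ptr = ptr.left if d < ptr.data else ptr.right
        return ptr, parent

    def search(self, d):
        return self._find(d)[0] is not None

    def count(self, d):
        node, _ = self._find(d)
        return node.mult if node is not None else 0

    def add(self, d):
        node, parent = self._find(d)
        if node is not None:
            node.mult += 1              # value already present: bump its counter
        elif parent is None:
            self.root = CountNode(d)    # empty tree
        elif d < parent.data:
            parent.left = CountNode(d)
        else:
            parent.right = CountNode(d)
        self.size += 1

    def remove(self, d):
        node, parent = self._find(d)
        if node is None:
            return False
        self.size -= 1
        node.mult -= 1
        if node.mult > 0:
            return True                 # other occurrences remain: keep the node
        if node.left is not None and node.right is not None:
            # Two children: copy the smallest node of the right subtree
            # (value and counter) into node, then unlink that node instead.
            parent, succ = node, node.right
            while succ.left is not None:
                parent, succ = succ, succ.left
            node.data, node.mult = succ.data, succ.mult
            node = succ
        # node now has at most one child; splice it out
        child = node.left if node.left is not None else node.right
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        return True

t = CountingBST()
for x in [22, 20, 11, 21, 42, 11, 22, 44, 1]:
    t.add(x)
print(t.count(22), t.count(11), t.count(7))   # 2 2 0
t.remove(22)
print(t.search(22), t.count(22))              # True 1
t.remove(22)
print(t.search(22), t.size)                   # False 7
```

Routing all four operations through a single search helper means that the BST search loop is written only once, and the counter of the successor node is carried along when a node with two children is removed.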

In [7]:
class BTNode2:
    def __init__(self,d,l,r):
        self.data = d
        self.left = l
        self.right = r
        self.mult = 1
          
    # prints the node and all its children in a string
    def __str__(self):  
        st = str(self.data)+" ("+str(self.mult)+")-> ["
        if self.left != None:
            st += str(self.left)
        else: st += "None"
        if self.right != None:
            st += ", "+str(self.right)
        else: st += ", None"
        return st + "]"
    
class BST2:
    def __init__(self):
        self.root = None
        self.size = 0
        
    def __str__(self):
        if self.root == None: return "None"        
        return str(self.root)

    def search(self, d):
        ptr = self.root
        while ptr != None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        if self.root == None:
            self.root = BTNode2(d,None,None)
        if self.search(d):
            ptr = self.root
            while ptr != None:
                if d == ptr.data:
                    ptr.mult +=1
                if d < ptr.data:
                    ptr = ptr.left
                else:
                    ptr = ptr.right
        else:
            ptr = self.root
            while True:
                if d < ptr.data:
                    if ptr.left == None:
                        ptr.left = BTNode2(d,None,None)
                        break
                    ptr = ptr.left
                else:
                    if ptr.right == None:
                        ptr.right = BTNode2(d,None,None)
                        break
                    ptr = ptr.right
        self.size += 1

    def count(self,d):
        if self.search(d):
            ptr = self.root
            while ptr != None:
                if d == ptr.data:
                    return ptr.mult
                if d < ptr.data:
                    ptr = ptr.left
                else:
                    ptr = ptr.right
        return 0

    def remove(self,d):
        if self.search(d):
            ptr = self.root
            while ptr != None:
                if d == ptr.data:
                    ptr.mult -= 1
                if d < ptr.data:
                    ptr = ptr.left
                else:
                    ptr = ptr.right
        else:
            self.remove2(d)

    def remove2(self,d):
        if self.root == None: return
        if self.root.data == d:
            self.size -= 1
            return self._removeRoot()
        parentPtr = None
        ptr = self.root
        while ptr != None and ptr.data != d:
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        if ptr != None:
            self.size -= 1
            self._removeNode(ptr,parentPtr)

    def _removeNode(self, ptr, parentPtr):
        if ptr.left == ptr.right == None:
            parentPtr.updateChild(ptr,None)
        elif ptr.left == None:
            parentPtr.updateChild(ptr,ptr.right)
        elif ptr.right == None:
            parentPtr.updateChild(ptr,ptr.left)
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left != None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            parentMinRNode.updateChild(minRNode,minRNode.right)
        
print("Question 5")
t = BST2()
A = [22,20,11,21,42,11,22,44,1]
for x in A: t.add(x)
print(t)
for x in A:
    print(x,t.search(x),t.count(x),t.search(-x),t.count(-x))
for x in A:
    t.remove(x); print("take",x,":\n",t)

Question 5
22 (3)-> [20 (1)-> [11 (2)-> [1 (1)-> [None, None], None], 21 (1)-> [None, None]], 42 (1)-> [None, 44 (1)-> [None, None]]]
22 True 3 False 0
20 True 1 False 0
11 True 2 False 0
21 True 1 False 0
42 True 1 False 0
11 True 2 False 0
22 True 3 False 0
44 True 1 False 0
1 True 1 False 0
take 22 :
 22 (2)-> [20 (1)-> [11 (2)-> [1 (1)-> [None, None], None], 21 (1)-> [None, None]], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 20 :
 22 (2)-> [20 (0)-> [11 (2)-> [1 (1)-> [None, None], None], 21 (1)-> [None, None]], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 11 :
 22 (2)-> [20 (0)-> [11 (1)-> [1 (1)-> [None, None], None], 21 (1)-> [None, None]], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 21 :
 22 (2)-> [20 (0)-> [11 (1)-> [1 (1)-> [None, None], None], 21 (0)-> [None, None]], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 42 :
 22 (2)-> [20 (0)-> [11 (1)-> [1 (1)-> [None, None], None], 21 (0)-> [None, None]], 42 (0)-> [None, 44 (1)-> [None, None]]]
take 11 :
 22 (2)-> [20 (0)-> [11 (