# ECS529U Algorithms and Data Structures
# Lab sheet 7

This lab gets you to work with binary trees and binary search trees in particular.

**Marks (max 5):**  Question 1: 1.5 | Questions 2-4: 1 each | Question 5: 0.5

## Question 1

This question is about understanding binary search trees (BSTs).

a) Draw the binary search tree we obtain if we start from the empty tree and add 
consecutively the numbers:

    21, 40, 3, 16, 39, 58, 21, 46, 1, 10

b) Write down the numbers of the tree you constructed, starting from the root, using 
depth-first search (in pre-order) and using breadth-first search.

c) Let `t` point to the root node of the BST you constructed in part a. Draw the BST that
results by applying each of the following operations:

    1. t.left = t.left.right
    2. t.left.right.right = t.left.right.left

In each of these cases, is the resulting structure a binary tree? Is it a binary search 
tree?

d) Starting each time from the tree you constructed in part a, perform the following removals (using the algorithm we saw in the lectures) and draw the resulting trees:

1. remove the node with value 16
2. remove the node with value 40

## Q1 Answers

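#### a)
Duplicates go to the right, as in the `add` method of the `BST` class given in Question 2, so the second 21 ends up as the left child of 39. The value 10 is smaller than 16, so it becomes the left child of 16.
```
        21
       /  \
      3    40
     / \   /  \
    1  16 39   58
       /  /    /
      10 21   46
```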

#### b)

Pre-order traversal (depth-first):
21, 3, 1, 16, 10, 40, 39, 21, 58, 46

Breadth-first traversal:
21, 3, 40, 1, 16, 39, 58, 10, 21, 46
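
The two traversals can be checked mechanically. The following is a minimal sketch, assuming that duplicates are sent to the right subtree (as in the `add` method of the `BST` class in this sheet); the names `Node`, `insert`, `preorder` and `bfs` are illustrative and not part of the lab code.

```python
from collections import deque

class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None

def insert(root, d):
    """Insert d: smaller values go left, all others (including duplicates) go right."""
    if root is None:
        return Node(d)
    if d < root.data:
        root.left = insert(root.left, d)
    else:
        root.right = insert(root.right, d)
    return root

def preorder(node):
    """Depth-first, pre-order: node, then left subtree, then right subtree."""
    if node is None:
        return []
    return [node.data] + preorder(node.left) + preorder(node.right)

def bfs(root):
    """Breadth-first: level by level, left to right, using a queue."""
    order = []
    queue = deque([root] if root is not None else [])
    while queue:
        node = queue.popleft()
        order.append(node.data)
        for child in (node.left, node.right):
            if child is not None:
                queue.append(child)
    return order

root = None
for x in [21, 40, 3, 16, 39, 58, 21, 46, 1, 10]:
    root = insert(root, x)

print(preorder(root))  # [21, 3, 1, 16, 10, 40, 39, 21, 58, 46]
print(bfs(root))       # [21, 3, 40, 1, 16, 39, 58, 10, 21, 46]
```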

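#### c)
1. `t.left` now points to the node 16, so the nodes 3 and 1 are no longer reachable from the root:
```
    21
   /  \
  16   40
  /   /  \
 10  39   58
     /    /
    21   46
```
The result is still a binary tree, and it is still a binary search tree.

2. Applied to the tree from part a, `t.left.right` is the node 16 and `t.left.right.left` is the node 10, so after the assignment both child pointers of 16 lead to the same node 10 (drawn as a double edge):
```
        21
       /  \
      3    40
     / \   /  \
    1  16  39  58
       ||  /   /
       10 21  46
```
Node 10 now has two incoming pointers, so the structure is no longer a binary tree, and therefore not a binary search tree either. If operation 2 were instead applied after operation 1, `t.left.right` would be `None` and the assignment would raise an `AttributeError`.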
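
The two pointer assignments can also be tried directly in Python. This is a small sketch under the same assumption as the `BST.add` method in this sheet (duplicates go right); `Node`, `insert` and `build` are illustrative helpers, not part of the lab code.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None

def insert(root, d):
    if root is None:
        return Node(d)
    if d < root.data:
        root.left = insert(root.left, d)
    else:
        root.right = insert(root.right, d)
    return root

def build():
    """Return a fresh copy of the tree from part a."""
    root = None
    for x in [21, 40, 3, 16, 39, 58, 21, 46, 1, 10]:
        root = insert(root, x)
    return root

# Operation 1 on a fresh tree: the subtree rooted at 16 replaces the one rooted at 3.
t1 = build()
t1.left = t1.left.right
print(t1.left.data, t1.left.left.data, t1.left.right)  # 16 10 None

# Operation 2 on a fresh tree: both children of 16 become the very same node object.
t2 = build()
t2.left.right.right = t2.left.right.left
print(t2.left.right.left is t2.left.right.right)  # True

# Operation 2 applied after operation 1: t.left.right is None, so the access fails.
t3 = build()
t3.left = t3.left.right
raised = False
try:
    t3.left.right.right = t3.left.right.left
except AttributeError:
    raised = True
print(raised)  # True
```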

#### d)

1.
```
        21
       /  \
      3    40
     / \   /  \
    1  10 39   58
          /    /
         21   46
```

2.
```
        21
       /  \
      3    46
     / \   /  \
    1  16 39   58
       /  /
      10 21
```
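
The removals can be cross-checked with a short recursive sketch. It follows the same three cases as the `_removeNode` method of the `BST` class in this sheet (leaf, one child, two children with the minimum of the right subtree as replacement), but it is written recursively for brevity; the function names are illustrative.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None

def insert(root, d):
    if root is None:
        return Node(d)
    if d < root.data:
        root.left = insert(root.left, d)
    else:
        root.right = insert(root.right, d)
    return root

def remove(root, d):
    """Remove one node holding d and return the new subtree root."""
    if root is None:
        return None
    if d < root.data:
        root.left = remove(root.left, d)
    elif d > root.data:
        root.right = remove(root.right, d)
    elif root.left is None:        # leaf, or only a right child
        return root.right
    elif root.right is None:       # only a left child
        return root.left
    else:                          # two children: copy up the minimum of the right subtree
        succ = root.right
        while succ.left is not None:
            succ = succ.left
        root.data = succ.data
        root.right = remove(root.right, succ.data)
    return root

def preorder(node):
    if node is None:
        return []
    return [node.data] + preorder(node.left) + preorder(node.right)

def build():
    root = None
    for x in [21, 40, 3, 16, 39, 58, 21, 46, 1, 10]:
        root = insert(root, x)
    return root

after_16 = preorder(remove(build(), 16))
after_40 = preorder(remove(build(), 40))
print(after_16)  # [21, 3, 1, 10, 40, 39, 21, 58, 46]
print(after_40)  # [21, 3, 1, 16, 10, 46, 39, 21, 58]
```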

The remaining questions ask you to work with the `BST` class and variants of it. To help you visualise trees, we have implemented the following "pretty printing" function for `BTNode` objects:

In [1]:
def niceStr(self): # this goes in the BTNode class
    S = ["├","─","└","│"]
    angle = S[2]+S[1]+" "
    vdash = S[0]+S[1]+" "
    
    def niceRec(ptr,acc,pre):
        if ptr == None: return acc+pre+"None"
        if ptr.left==ptr.right==None: return acc+pre+str(ptr.data)
        if pre == vdash: pre2 = S[3]+"  "
        elif pre == angle: pre2 = "   "
        else: pre2 = ""
        rightStr = niceRec(ptr.right,acc+pre2,vdash)  # the right subtree is printed first
        leftStr = niceRec(ptr.left,acc+pre2,angle)    # the left subtree is printed last
        return acc+pre+str(ptr.data)+"\n"+rightStr+"\n"+leftStr
        
    return niceRec(self,"","")

For example, the following tree

        22
       /  \
      20   42
     / \   / \
    11 21 22 44

is converted into a string that prints as follows:

    22 
    ├─ 42
    │  ├─ 44
    │  └─ 22
    └─ 20
       ├─ 21
       └─ 11

## Question 2

Add in `BST` the following functions, assuming that we work with BSTs that store integers:

a) `def min(self)`

that returns the smallest element of the tree. If the tree is empty, the function should return `None`.

b) `def max(self)`

that returns the largest element of the tree. If the tree is empty, the function should return `None`.

c) `def removeAll(self, d)`

that removes all occurrences of the element `d` in the tree and returns the number of occurrences removed. For example, if the tree `t` is:

        22
       /  \
      20   42
     / \   / \
    11 21 22 44
    
then `t.removeAll(22)` should change `t` to:

        42
       /  \
      20   44
     / \  
    11 21 
    
and return `2`.

In [2]:
class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

    def updateChild(self, oldChild, newChild):
        if self.left == oldChild:
            self.left = newChild
        elif self.right == oldChild:
            self.right = newChild
        else:
            raise Exception("updateChild error")
    
    def niceStr(self):
        S = ["├","─","└","│"]
        angle = S[2]+S[1]+" "
        vdash = S[0]+S[1]+" "
        
        def niceRec(ptr,acc,pre):
            if ptr == None: return acc+pre+"None"
            if ptr.left==ptr.right==None: return acc+pre+str(ptr.data)
            if pre == vdash: pre2 = S[3]+"  "
            elif pre == angle: pre2 = "   "
            else: pre2 = ""
            rightStr = niceRec(ptr.right,acc+pre2,vdash)  # the right subtree is printed first
            leftStr = niceRec(ptr.left,acc+pre2,angle)    # the left subtree is printed last
            return acc+pre+str(ptr.data)+"\n"+rightStr+"\n"+leftStr
            
        return niceRec(self,"","")


class BST:
    def __init__(self):
        self.root = None
        self.size = 0

    def search(self, d):
        ptr = self.root
        while ptr is not None:
            if d == ptr.data:
                return True
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def add(self, d):
        if self.root is None:
            self.root = BTNode(d, None, None)
        else:
            ptr = self.root
            while True:
                if d < ptr.data:
                    if ptr.left is None:
                        ptr.left = BTNode(d, None, None)
                        break
                    ptr = ptr.left
                else:
                    if ptr.right is None:
                        ptr.right = BTNode(d, None, None)
                        break
                    ptr = ptr.right
        self.size += 1

    def remove(self, d):
        if self.root is None:
            return
        if self.root.data == d:
            self.size -= 1
            return self._removeRoot()

        parentPtr = None
        ptr = self.root
        while ptr is not None and ptr.data != d:
            parentPtr = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right

        if ptr is not None:
            self.size -= 1
            self._removeNode(ptr, parentPtr)

    def _removeNode(self, ptr, parentPtr):
        # There are 3 cases to consider:
        # 1. ptr is a leaf
        if ptr.left is None and ptr.right is None:
            parentPtr.updateChild(ptr, None)
        # 2. ptr has exactly one child
        elif ptr.left is None:
            parentPtr.updateChild(ptr, ptr.right)
        elif ptr.right is None:
            parentPtr.updateChild(ptr, ptr.left)
        # 3. ptr has both children
        else:
            parentMinRNode = ptr
            minRNode = ptr.right
            while minRNode.left is not None:
                parentMinRNode = minRNode
                minRNode = minRNode.left
            ptr.data = minRNode.data
            parentMinRNode.updateChild(minRNode, minRNode.right)
    
    def _removeRoot(self):
        parentRoot = BTNode(None,self.root,None)
        self._removeNode(self.root,parentRoot)
        self.root = parentRoot.left

In [3]:
def min(self):
    """
    Return the smallest element in the BST, or None if the tree is empty.
    """
    if self.root is None:
        return None
    ptr = self.root
    while ptr.left is not None:
        ptr = ptr.left
    return ptr.data

def max(self):
    """
    Return the largest element in the BST, or None if the tree is empty.
    """
    if self.root is None:
        return None
    ptr = self.root
    while ptr.right is not None:
        ptr = ptr.right
    return ptr.data

def removeAll(self, d):
    """
    Remove all occurrences of the element `d` from the BST.
    Returns the number of occurrences removed.
    """
    count = 0
    while self.search(d):
        self.remove(d)
        count += 1
    return count

BST.min = min; BST.max = max; BST.removeAll = removeAll


bst = BST()
for value in [22, 20, 42, 11, 21, 22, 44]:
    bst.add(value)


print("Min:", bst.min())
print("Max:", bst.max())

removed_count = bst.removeAll(22)
print("Removed Count for 22:", removed_count)
print("Search for 22 after removal:", bst.search(22))


Min: 11
Max: 44
Removed Count for 22: 2
Search for 22 after removal: False


## Question 3

Add in `BST` the following functions, assuming that we work with BSTs that store integers:

a) `def _sumAllRec(self, ptr)`

that <u>uses recursion</u> and returns the sum of all the elements of the subtree starting from the node `ptr`.

_Hint:_ you can simply use depth-first search, ignoring the fact that this is a BST 
rather than a simple binary tree.

b) `def sumAll(self)`

that sums all the elements of the tree (use the function from part a).

c) `def sumAllBFS(self)`

that sums all the elements of the tree using breadth-first search.

_Hint:_ you can adapt the code for breadth-first search that we saw in the lecture 
(week 6). You will need to use a queue (see lecture of week 5).

In [4]:
def _sumAllRec(self, ptr):
    """
    Recursively compute the sum of all elements in the subtree starting at `ptr`.
    """
    if ptr is None:
        return 0
    return ptr.data + self._sumAllRec(ptr.left) + self._sumAllRec(ptr.right)

def sumAll(self):
    """
    Compute the sum of all elements in the tree using recursive helper.
    """
    return self._sumAllRec(self.root)

def sumAllBFS(self):
    """
    Compute the sum of all elements in the tree using breadth-first search (BFS).
    """
    if self.root is None:
        return 0

    from collections import deque

    queue = deque([self.root])  # Queue for BFS
    total_sum = 0

    while queue:
        node = queue.popleft()
        total_sum += node.data
        if node.left is not None:
            queue.append(node.left)
        if node.right is not None:
            queue.append(node.right)

    return total_sum

BST._sumAllRec = _sumAllRec; BST.sumAll = sumAll; BST.sumAllBFS = sumAllBFS

bst = BST()
bst_list = [22, 20, 42, 11, 21, 22, 44]
for value in bst_list:
    bst.add(value)

tot = sum(bst_list)  # reference value computed directly from the input list

print("expected sum", tot)
print(bst.sumAll())
print(bst.sumAllBFS())

expected sum 182
182
182


## Question 4

Add in `BST` a function 

    def toSortedArray(self)

that returns an array containing the elements of the tree in ascending order.

_Hint:_ Use a helper function to do an inorder traversal of the BST.

In [5]:
def _inOrderTraversal(self, ptr, result):
    """
    Helper function to perform an in-order traversal and collect elements in `result`.
    """
    if ptr is None:
        return
    self._inOrderTraversal(ptr.left, result)
    result.append(ptr.data)
    self._inOrderTraversal(ptr.right, result)

def toSortedArray(self):
    """
    Returns an array containing the elements of the tree in ascending order.
    """
    result = []
    self._inOrderTraversal(self.root, result)
    return result

BST._inOrderTraversal = _inOrderTraversal; BST.toSortedArray = toSortedArray
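
On a very unbalanced tree (for example, one built by adding already sorted values) the recursive helper goes one call deeper per node and can exceed Python's default recursion limit of 1000 frames. As a hedged alternative, here is a sketch of the same in-order traversal with an explicit stack; `Node` and `to_sorted_list` are illustrative names and the function works on bare nodes rather than on the `BST` class.

```python
class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def to_sorted_list(root):
    """Iterative in-order traversal: left subtree, node, right subtree."""
    result, stack, ptr = [], [], root
    while stack or ptr is not None:
        # Walk as far left as possible, remembering the path on the stack.
        while ptr is not None:
            stack.append(ptr)
            ptr = ptr.left
        # Visit the most recently stacked node, then continue in its right subtree.
        ptr = stack.pop()
        result.append(ptr.data)
        ptr = ptr.right
    return result

# Degenerate tree: 5000 nodes chained through their right pointers.
root = Node(0)
ptr = root
for x in range(1, 5000):
    ptr.right = Node(x)
    ptr = ptr.right

print(to_sorted_list(root) == list(range(5000)))  # True
```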

In [6]:
# Minimal testing Questions 2-4

print("Question 2")
t = BST()
print(t)
print(t.min(),t.max())
A = [22,20,11,21,42,22,44,1]
for x in A: t.add(x)
print(t)
print(t.min(),t.max())
t.removeAll(22)
print(t)

print("\nQuestion 3")
t = BST()
print(t)
print(t.sumAll(), t.sumAllBFS())
A = [22,20,11,21,42,22,44]
for x in A: t.add(x)
print(t)
print(t.sumAll(), t.sumAllBFS())

print("\nQuestion 4")
print(BST().toSortedArray())
print(t.toSortedArray())

Question 2
<__main__.BST object at 0x7ef6d5569d60>
None None
<__main__.BST object at 0x7ef6d5569d60>
1 44
<__main__.BST object at 0x7ef6d5569d60>

Question 3
<__main__.BST object at 0x7ef6d5539850>
0 0
<__main__.BST object at 0x7ef6d5539850>
182 182

Question 4
[]
[11, 20, 21, 22, 22, 42, 44]
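
The `print(t)` calls in the test above show Python's default object representation because the `BST` class as written here defines no `__str__` method. A minimal sketch of one way to get readable output, assuming we are free to add such a method that delegates to `niceStr` (the classes are trimmed to what the example needs, and printing `"None"` for an empty tree is an assumption borrowed from `BST2.__str__`):

```python
class BTNode:
    def __init__(self, d, l, r):
        self.data = d
        self.left = l
        self.right = r

    def niceStr(self):  # same logic as the function given in this sheet
        S = ["├", "─", "└", "│"]
        angle = S[2] + S[1] + " "
        vdash = S[0] + S[1] + " "

        def niceRec(ptr, acc, pre):
            if ptr is None:
                return acc + pre + "None"
            if ptr.left is None and ptr.right is None:
                return acc + pre + str(ptr.data)
            if pre == vdash:
                pre2 = S[3] + "  "
            elif pre == angle:
                pre2 = "   "
            else:
                pre2 = ""
            upper = niceRec(ptr.right, acc + pre2, vdash)  # right subtree first
            lower = niceRec(ptr.left, acc + pre2, angle)   # left subtree last
            return acc + pre + str(ptr.data) + "\n" + upper + "\n" + lower

        return niceRec(self, "", "")


class BST:
    def __init__(self):
        self.root = None

    def add(self, d):
        if self.root is None:
            self.root = BTNode(d, None, None)
            return
        ptr = self.root
        while True:
            if d < ptr.data:
                if ptr.left is None:
                    ptr.left = BTNode(d, None, None)
                    return
                ptr = ptr.left
            else:
                if ptr.right is None:
                    ptr.right = BTNode(d, None, None)
                    return
                ptr = ptr.right

    def __str__(self):
        # Assumed addition: delegate to the root node's pretty printer.
        if self.root is None:
            return "None"
        return self.root.niceStr()


t = BST()
for x in [22, 20, 42, 11, 21, 22, 44]:
    t.add(x)
print(t)
```

With this method in place, `print(t)` produces the same drawing as the `niceStr` example earlier in the sheet.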


## Question 5

You are asked to write a class `BST2` which implements a BST in which each node has a multiplicity counter (`mult`), which counts how many times the node's value is stored in the tree. This way, there is no need to store duplicate nodes in the tree: 
- adding a value that already exists in the tree simply amounts to increasing the counter of the value's node by 1; 
- removing a value from the tree amounts to reducing the counter of its node by 1, and if the counter becomes 0 then the node is removed altogether.

Below we have provided you with a class of nodes `BTNode2` to use, and we have made a start on implementing `BST2`. You are asked to implement the following functions:

- `add(self,d)` for adding the value `d` to the BST2. This should use BST 
search and either increase the `mult` counter of the `BTNode2` containing `d` or, if `d` is not in the tree, create a new `BTNode2` for `d`.

- `search(self,d)` for searching the value `d` in the BST2. This should use BST 
search and return `True` if the value is found, and `False` otherwise.

- `count(self,d)` for counting the times the value `d` appears in the BST2. This 
should use BST search and return the number of times that the value appears in 
the BST2.

- `remove(self,d)` for removing one occurrence of the value `d` from the BST2.
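
One way to sanity-check the behaviour described above is to compare a multiplicity tree against `collections.Counter` on a random sequence of additions and removals. The sketch below uses a compact recursive stand-in for the tree, not the `BST2` class required by the lab; `MNode`, `add`, `count` and `remove` are hypothetical names chosen for illustration.

```python
import random
from collections import Counter

class MNode:  # hypothetical stand-in for BTNode2
    def __init__(self, data):
        self.data = data
        self.mult = 1
        self.left = None
        self.right = None

def add(node, d):
    if node is None:
        return MNode(d)
    if d == node.data:
        node.mult += 1                      # existing value: bump the counter
    elif d < node.data:
        node.left = add(node.left, d)
    else:
        node.right = add(node.right, d)
    return node

def count(node, d):
    while node is not None and node.data != d:
        node = node.left if d < node.data else node.right
    return 0 if node is None else node.mult

def remove(node, d):
    if node is None:
        return None                         # value absent: nothing to do
    if d < node.data:
        node.left = remove(node.left, d)
    elif d > node.data:
        node.right = remove(node.right, d)
    elif node.mult > 1:
        node.mult -= 1                      # just drop one occurrence
    elif node.left is None:
        return node.right
    elif node.right is None:
        return node.left
    else:
        succ = node.right
        while succ.left is not None:
            succ = succ.left
        node.data, node.mult = succ.data, succ.mult
        succ.mult = 1                       # so the call below unlinks the successor node
        node.right = remove(node.right, succ.data)
    return node

random.seed(0)
root, reference = None, Counter()
for _ in range(500):
    x = random.randint(0, 20)
    if random.random() < 0.6:
        root = add(root, x)
        reference[x] += 1
    else:
        root = remove(root, x)              # a no-op when x is absent
        if reference[x] > 0:
            reference[x] -= 1

ok = all(count(root, x) == reference[x] for x in range(-1, 22))
print(ok)  # True
```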

In [None]:
class BTNode2:
    def __init__(self,d,l,r):
        self.data = d
        self.left = l
        self.right = r
        self.mult = 1
          
    # prints the node and all its children in a string
    def __str__(self):  
        st = str(self.data)+" ("+str(self.mult)+")-> ["
        if self.left != None:
            st += str(self.left)
        else: st += "None"
        if self.right != None:
            st += ", "+str(self.right)
        else: st += ", None"
        return st + "]"
    
class BST2:
    def __init__(self):
        self.root = None
        self.size = 0
        
    def __str__(self):
        if self.root == None: return "None"        
        return str(self.root)
    
    def add(self, d):
        if self.root is None:
            self.root = BTNode2(d, None, None)
        else:
            ptr = self.root
            while True:
                if d == ptr.data:
                    ptr.mult += 1
                    break
                elif d < ptr.data:
                    if ptr.left is None:
                        ptr.left = BTNode2(d, None, None)
                        break
                    ptr = ptr.left
                else:  # d > ptr.data
                    if ptr.right is None:
                        ptr.right = BTNode2(d, None, None)
                        break
                    ptr = ptr.right
        self.size += 1

    def search(self, d):
        ptr = self.root
        while ptr is not None:
            if d == ptr.data:
                return True
            elif d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return False

    def count(self, d):
        ptr = self.root
        while ptr is not None:
            if d == ptr.data:
                return ptr.mult
            elif d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right
        return 0

    def remove(self, d):
        """Remove one occurrence of `d` from the tree."""
        parent = None
        ptr = self.root

        while ptr is not None and ptr.data != d:
            parent = ptr
            if d < ptr.data:
                ptr = ptr.left
            else:
                ptr = ptr.right

        if ptr is None:
            return

        self.size -= 1  # one occurrence disappears in either case
        if ptr.mult > 1:
            ptr.mult -= 1
        else:
            if ptr.left is None and ptr.right is None:
                if parent is None:
                    self.root = None
                elif parent.left == ptr:
                    parent.left = None
                else:
                    parent.right = None
            elif ptr.left is None:
                if parent is None:
                    self.root = ptr.right
                elif parent.left == ptr:
                    parent.left = ptr.right
                else:
                    parent.right = ptr.right
            elif ptr.right is None:
                if parent is None:
                    self.root = ptr.left
                elif parent.left == ptr:
                    parent.left = ptr.left
                else:
                    parent.right = ptr.left
            else:
                successor_parent = ptr
                successor = ptr.right
                while successor.left is not None:
                    successor_parent = successor
                    successor = successor.left
                ptr.data = successor.data
                ptr.mult = successor.mult
                if successor_parent.left == successor:
                    successor_parent.left = successor.right
                else:
                    successor_parent.right = successor.right
        
print("Question 5")
t = BST2()
A = [22,20,11,21,42,11,22,44,1]
for x in A: t.add(x)
print(t)
for x in A:
    print(x,t.search(x),t.count(x),t.search(-x),t.count(-x))
for x in A:
    t.remove(x); print("take",x,":\n",t)

Question 5
22 (2)-> [20 (1)-> [11 (2)-> [1 (1)-> [None, None], None], 21 (1)-> [None, None]], 42 (1)-> [None, 44 (1)-> [None, None]]]
22 True 2 False 0
20 True 1 False 0
11 True 2 False 0
21 True 1 False 0
42 True 1 False 0
11 True 2 False 0
22 True 2 False 0
44 True 1 False 0
1 True 1 False 0
take 22 :
 22 (1)-> [20 (1)-> [11 (2)-> [1 (1)-> [None, None], None], 21 (1)-> [None, None]], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 20 :
 22 (1)-> [21 (1)-> [11 (2)-> [1 (1)-> [None, None], None], None], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 11 :
 22 (1)-> [21 (1)-> [11 (1)-> [1 (1)-> [None, None], None], None], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 21 :
 22 (1)-> [11 (1)-> [1 (1)-> [None, None], None], 42 (1)-> [None, 44 (1)-> [None, None]]]
take 42 :
 22 (1)-> [11 (1)-> [1 (1)-> [None, None], None], 44 (1)-> [None, None]]
take 11 :
 22 (1)-> [1 (1)-> [None, None], 44 (1)-> [None, None]]
take 22 :
 44 (1)-> [1 (1)-> [None, None], None]
take 44 :
 1 (1)-> [None, None]
take 1 