# Lab 10: BSTree

## Exercise 1: BSTree operations

For this exercise you'll implement three additional methods in the binary search tree data structure completed in class, so that you have an opportunity to practice both using the recursive pattern covered in class and navigating the binary tree structure.

The methods you'll implement are:

1. `count_less_than`: takes an argument `x`, and returns the number of elements in the tree with values less than `x`
2. `successor`: takes an argument `x`, and returns the smallest value from the tree that is larger than `x` (note that `x` itself does not need to be in the tree); if there are no values larger than `x`, returns `None`
3. `descendants`: takes an argument `x`, and returns all descendants of `x` in the tree (i.e., all values in the subtree rooted at `x`), ordered by value; if `x` has no descendants or does not exist in the tree, returns an empty list


The cell below contains the (read-only) BSTree implementation from lecture. Beneath that is the cell containing the methods you will be implementing, followed by unit test cells.

In [7]:
class BSTree:
    """Binary search tree holding unique, mutually comparable values.

    For every node, values in its left subtree are smaller and values in its
    right subtree are larger; iteration yields values in ascending order.
    """

    class Node:
        """One tree node: a value plus optional left/right child nodes."""

        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    def __init__(self):
        self.size = 0     # number of values currently stored
        self.root = None  # root Node, or None when the tree is empty

    def __contains__(self, val):
        """Return True if ``val`` is stored in the tree."""
        def contains_rec(node):
            if not node:
                return False
            elif val < node.val:
                return contains_rec(node.left)
            elif val > node.val:
                return contains_rec(node.right)
            else:
                return True
        return contains_rec(self.root)

    def add(self, val):
        """Insert ``val``; it must not already be present (checked by assert).

        Nodes on the search path are rebuilt as new Node objects rather than
        mutated in place.
        """
        assert val not in self
        def add_rec(node):
            if not node:
                return BSTree.Node(val)
            elif val < node.val:
                return BSTree.Node(node.val, left=add_rec(node.left), right=node.right)
            else:
                return BSTree.Node(node.val, left=node.left, right=add_rec(node.right))
        self.root = add_rec(self.root)
        self.size += 1

    def __delitem__(self, val):
        """Remove ``val``; it must be present (checked by assert).

        A node with two children takes the largest value of its left subtree
        (its in-order predecessor), and that predecessor node is unlinked.

        Raises:
            KeyError: if ``val`` is absent and assertions are disabled (``-O``).
        """
        assert val in self
        def delitem_rec(node):
            if not node:
                # Only reachable under ``python -O``, where the assert above is
                # stripped; fail with a KeyError instead of an AttributeError.
                raise KeyError(val)
            if val < node.val:
                node.left = delitem_rec(node.left)
                return node
            elif val > node.val:
                node.right = delitem_rec(node.right)
                return node
            else:
                if not node.left and not node.right:
                    return None
                elif node.left and not node.right:
                    return node.left
                elif node.right and not node.left:
                    return node.right
                else:
                    # remove the largest value from the left subtree as a replacement
                    # for the root value of this tree
                    t = node.left # refers to the candidate for removal
                    if not t.right:
                        node.val = t.val
                        node.left = t.left
                    else:
                        n = t
                        # stop at the parent of the rightmost node so it can be unlinked
                        while n.right.right:
                            n = n.right
                        t = n.right
                        n.right = t.left
                        node.val = t.val
                    return node

        self.root = delitem_rec(self.root)
        self.size -= 1

    def __iter__(self):
        """Yield the stored values in ascending order (in-order traversal)."""
        def iter_rec(node):
            if node:
                yield from iter_rec(node.left)
                yield node.val
                yield from iter_rec(node.right)

        return iter_rec(self.root)

    def __len__(self):
        return self.size

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's contents.

        Nodes are printed level by level (breadth-first); each level halves the
        column width, and missing children are shown as '-'.
        """
        from collections import deque

        height = self.height()
        nodes = deque([(self.root, 0)])  # deque gives O(1) pops from the front
        prev_level = 0
        parts = []
        while nodes:
            n, level = nodes.popleft()
            if prev_level != level:
                prev_level = level
                parts.append('\n')
            if not n:
                # keep placeholder slots so lower levels stay aligned
                if level < height-1:
                    nodes.extend([(None, level+1), (None, level+1)])
                parts.append('{val:^{width}}'.format(val='-', width=width//2**level))
            else:
                if n.left or level < height-1:
                    nodes.append((n.left, level+1))
                if n.right or level < height-1:
                    nodes.append((n.right, level+1))
                parts.append('{val:^{width}}'.format(val=n.val, width=width//2**level))
        print(''.join(parts))

    def height(self):
        """Returns the height of the longest branch of the tree."""
        def height_rec(t):
            if not t:
                return 0
            else:
                return max(1+height_rec(t.left), 1+height_rec(t.right))
        return height_rec(self.root)

In [74]:
class BSTree(BSTree):
    """Extends the lecture BSTree with three ordered queries."""

    def count_less_than(self, x):
        """Return how many stored values are strictly less than ``x``."""
        def less_rec(t):
            if not t:
                return 0
            elif t.val >= x:
                # t and its whole right subtree are >= x; only the left can count
                return less_rec(t.left)
            else:
                # t counts; the left subtree is entirely < t.val < x, and the
                # right subtree may still contain values below x
                return 1 + less_rec(t.left) + less_rec(t.right)
        return less_rec(self.root)

    def successor(self, x):
        """Return the smallest stored value greater than ``x``, or None.

        ``x`` itself need not be in the tree. Walks a single root-to-leaf path.
        """
        best = None
        node = self.root
        while node:
            if node.val > x:
                # node is a candidate; anything smaller but still > x is to the left
                best = node.val
                node = node.left
            else:
                node = node.right
        return best

    def descendants(self, x):
        """Return all values below ``x``'s node, in ascending order.

        Returns an empty list if ``x`` is not in the tree or is a leaf.
        """
        # locate the node holding x
        node = self.root
        while node:
            if node.val < x:
                node = node.right
            elif node.val > x:
                node = node.left
            else:
                break
        if not node:
            return []

        # a single accumulator keeps the traversal linear (no list concatenation)
        result = []

        def inorder(t):
            if t:
                inorder(t.left)
                result.append(t.val)
                inorder(t.right)

        inorder(node.left)
        inorder(node.right)
        return result


In [75]:
# 3 points

t = BSTree()
for value in [6, 3, 5, 4, 7, 1, 2, 9, 8, 0]:
    t.add(value)

# (query, expected number of stored values strictly below query)
for query, expected in [(6, 6), (0, 0), (9, 9), (100, 10), (-100, 0)]:
    assert t.count_less_than(query) == expected

In [76]:
# 3 points

t = BSTree()
for value in [6, 3, 5, 4, 7, 1, 2, 9, 8, 0]:
    t.add(value)

# (query, expected smallest stored value strictly above query)
for query, expected in [(6, 7), (6.5, 7), (4, 5), (5, 6), (8, 9), (-1, 0)]:
    assert t.successor(query) == expected

# nothing stored is larger than these queries
for query in (9, 10):
    assert t.successor(query) is None

In [77]:
# 3 points

t = BSTree()
for value in [6, 3, 5, 4, 7, 1, 2, 9, 8, 0]:
    t.add(value)

# (query, expected sorted values in the subtree below query's node)
expected_descendants = [
    (6, [0, 1, 2, 3, 4, 5, 7, 8, 9]),
    (3, [0, 1, 2, 4, 5]),
    (7, [8, 9]),
    (1, [0, 2]),
    (0, []),
    (8, []),
    (100, []),
    (-100, []),
]
for query, expected in expected_descendants:
    assert t.descendants(query) == expected

## Exercise 2: BSTree as a mapping structure

For this next exercise you will re-implement the binary search tree so that it can be used as a mapping structure. The `Node` class will be updated so as to hold separate key and value attributes (instead of a single value, as it currently does), and instead of the `add` method, you should implement the [`__getitem__`](https://docs.python.org/3.5/reference/datamodel.html#object.__getitem__) and [`__setitem__`](https://docs.python.org/3.5/reference/datamodel.html#object.__setitem__) methods in order to associate keys and values. `__delitem__`, `__contains__`, and `__iter__` will also need to be updated, to perform key-based removal, search, and iteration. Finally, the `keys`, `values`, and `items` methods will return iterators that allow the keys, values, and key/value tuples of the tree (all sorted in order of their associated keys) to be traversed.

If `__setitem__` is called with an existing key, the method will simply locate the associated node and update its value with the newly provided value (as you would expect a mapping structure to do); the number of entries in the tree does not change in this case. If either `__getitem__` or `__delitem__` is called with a key that does not exist in the tree, a `KeyError` should be raised.

The API described above will allow the tree to be used as follows:

    t = BSTree()
    t[0] = 'zero'
    t[5] = 'five'
    t[2] = 'two'

    print(t[5])
    
    t[5] = 'FIVE!!!'

    for k,v in t.items():
        print(k, '=', v)

    del t[2]

    print('length =', len(t))
    
The expected output of the above follows:

    five
    0 = zero
    2 = two
    5 = FIVE!!!
    length = 2

The following `BSTree` class contains an updated `Node`, and stubs for the methods you are to implement. The first few simple test cases beneath the class definition should help clarify the required behavior.

In [112]:
class BSTree:
    """Binary search tree used as an ordered key -> value mapping.

    Keys are unique and must be mutually comparable; ``keys``, ``values``,
    ``items`` and iteration over the tree itself all proceed in ascending key
    order.
    """

    class Node:
        """One tree node: a key, its associated value, and child links."""

        def __init__(self, key, val, left=None, right=None):
            self.key = key
            self.val = val
            self.left = left
            self.right = right

    def __init__(self):
        self.size = 0     # number of distinct keys stored
        self.root = None  # root Node, or None when the tree is empty

    def _find(self, key):
        """Return the node holding ``key``, or None if the key is absent."""
        node = self.root
        while node:
            if key < node.key:
                node = node.left
            elif key > node.key:
                node = node.right
            else:
                return node
        return None

    def __getitem__(self, key):
        """Return the value associated with ``key``.

        Raises:
            KeyError: if ``key`` is not in the tree.
        """
        node = self._find(key)
        if node is None:
            raise KeyError(key)
        return node.val

    def __setitem__(self, key, val):
        """Associate ``val`` with ``key``, replacing any existing value.

        The size only grows when ``key`` was not already present.
        """
        def set_rec(node):
            if not node:
                # a brand-new key is being inserted
                self.size += 1
                return BSTree.Node(key, val)
            elif key < node.key:
                node.left = set_rec(node.left)
            elif key > node.key:
                node.right = set_rec(node.right)
            else:
                # existing key: update in place, size unchanged
                node.val = val
            return node
        self.root = set_rec(self.root)

    def __delitem__(self, key):
        """Remove ``key`` and its value from the tree.

        A node with two children takes the key/value of the largest key in its
        left subtree (its in-order predecessor), which is then unlinked.

        Raises:
            KeyError: if ``key`` is not in the tree (the tree is left unchanged).
        """
        def delitem_rec(node):
            if not node:
                raise KeyError(key)
            if key < node.key:
                node.left = delitem_rec(node.left)
                return node
            elif key > node.key:
                node.right = delitem_rec(node.right)
                return node
            else:
                if not node.left and not node.right:
                    return None
                elif node.left and not node.right:
                    return node.left
                elif node.right and not node.left:
                    return node.right
                else:
                    # remove the largest key from the left subtree as a replacement
                    # for the root key/value of this tree
                    t = node.left # refers to the candidate for removal
                    if not t.right:
                        node.val = t.val
                        node.key = t.key
                        node.left = t.left
                    else:
                        n = t
                        # stop at the parent of the rightmost node so it can be unlinked
                        while n.right.right:
                            n = n.right
                        t = n.right
                        n.right = t.left
                        node.val = t.val
                        node.key = t.key
                    return node

        self.root = delitem_rec(self.root)
        # only reached when the key existed (a missing key raises above)
        self.size -= 1

    def __contains__(self, key):
        """Return True if ``key`` is in the tree."""
        return self._find(key) is not None

    def __len__(self):
        return self.size

    def _iter_nodes(self):
        """Return an iterator over all nodes in ascending key order."""
        def iter_rec(node):
            if node:
                yield from iter_rec(node.left)
                yield node
                yield from iter_rec(node.right)

        return iter_rec(self.root)

    def __iter__(self):
        """Iterate over the keys in ascending order, as a mapping should."""
        return self.keys()

    def keys(self):
        """Return an iterator over the keys in ascending order."""
        return (node.key for node in self._iter_nodes())

    def values(self):
        """Return an iterator over the values, ordered by their keys."""
        return (node.val for node in self._iter_nodes())

    def items(self):
        """Return an iterator over (key, value) tuples in ascending key order."""
        return ((node.key, node.val) for node in self._iter_nodes())

    def pprint(self, width=64):
        """Attempts to pretty-print this tree's keys.

        Nodes are printed level by level (breadth-first); each level halves the
        column width, and missing children are shown as '-'.
        """
        from collections import deque

        height = self.height()
        nodes = deque([(self.root, 0)])  # deque gives O(1) pops from the front
        prev_level = 0
        parts = []
        while nodes:
            n, level = nodes.popleft()
            if prev_level != level:
                prev_level = level
                parts.append('\n')
            if not n:
                # keep placeholder slots so lower levels stay aligned
                if level < height-1:
                    nodes.extend([(None, level+1), (None, level+1)])
                parts.append('{val:^{width}}'.format(val='-', width=width//2**level))
            else:
                if n.left or level < height-1:
                    nodes.append((n.left, level+1))
                if n.right or level < height-1:
                    nodes.append((n.right, level+1))
                parts.append('{val:^{width}}'.format(val=n.key, width=width//2**level))
        print(''.join(parts))

    def height(self):
        """Returns the height of the longest branch of the tree."""
        def height_rec(t):
            if not t:
                return 0
            else:
                return max(1+height_rec(t.left), 1+height_rec(t.right))
        return height_rec(self.root)


In [113]:
# 2 points

from unittest import TestCase

tc = TestCase()
t = BSTree()
tc.assertEqual(len(t), 0)
tc.assertNotIn(0, t)

# a single insertion makes the key visible and bumps the length
t[0] = 'zero'
tc.assertIn(0, t)
tc.assertEqual(len(t), 1)

In [114]:
# 2 points

from unittest import TestCase

tc = TestCase()
t = BSTree()
tc.assertEqual(0, len(t))

# value stored under a key must come back on lookup
t[0] = 'zero'
tc.assertEqual('zero', t[0])

In [115]:
# 2 points

from unittest import TestCase

tc = TestCase()
t = BSTree()
tc.assertEqual(0, len(t))

# insert then delete: the tree must be empty again
t[0] = 'zero'
del t[0]
tc.assertNotIn(0, t)
tc.assertEqual(0, len(t))

In [116]:
# 2 points

from unittest import TestCase

tc = TestCase()
t = BSTree()
key_vals = [(0, 'zero'), (2, 'two'), (1, 'one')]
sorted_key_vals = sorted(key_vals)

for k, v in key_vals:
    t[k] = v

# Compare whole sequences: an element-by-element enumerate loop would pass
# vacuously if keys() stopped early, and ignore any extra items.
tc.assertEqual(list(t.keys()), [key for key, _ in sorted_key_vals])

In [117]:
# 1 point

from unittest import TestCase

tc = TestCase()
t = BSTree()
key_vals = [(0, 'zero'), (2, 'two'), (1, 'one')]
sorted_key_vals = sorted(key_vals)

for k, v in key_vals:
    t[k] = v

# Compare whole sequences so a values() iterator that yields too few (or too
# many) items fails instead of passing vacuously.
tc.assertEqual(list(t.values()), [val for _, val in sorted_key_vals])

In [118]:
# 1 point

from unittest import TestCase

tc = TestCase()
t = BSTree()
key_vals = [(0, 'zero'), (2, 'two'), (1, 'one')]
sorted_key_vals = sorted(key_vals)

for k, v in key_vals:
    t[k] = v

# Compare whole sequences so an items() iterator that yields too few (or too
# many) pairs fails instead of passing vacuously.
tc.assertEqual(list(t.items()), sorted_key_vals)

In [119]:
# 3 points

from unittest import TestCase
import random

tc = TestCase()
t = BSTree()
keys = list(range(100, 1000, 11))
random.shuffle(keys)
# one random value per key, drawn after the shuffle
vals = [random.randrange(1000) for _ in keys]

for key, val in zip(keys, vals):
    t[key] = val

for key, val in zip(keys, vals):
    tc.assertEqual(t[key], val)

In [120]:
# 3 points

from unittest import TestCase
import random

tc = TestCase()
t = BSTree()
keys = list(range(100, 1000, 11))
shuffled_keys = keys.copy()
random.shuffle(shuffled_keys)

for k in shuffled_keys:
    t[k] = str(k)

# Whole-sequence comparisons: a traversal that drops or adds elements must
# fail rather than pass vacuously as an element-by-element loop would.
expected_vals = [str(key) for key in keys]
tc.assertEqual(list(t.keys()), keys)
tc.assertEqual(list(t.values()), expected_vals)
tc.assertEqual(list(t.items()), list(zip(keys, expected_vals)))

In [121]:
# 3 points

from unittest import TestCase
import random

tc = TestCase()
t = BSTree()
keys = list(range(0, 100, 2))
random.shuffle(keys)

for even in keys:
    t[even] = even * 2

# only even keys were inserted, so every odd lookup must raise KeyError
for odd in range(1, 101, 2):
    with tc.assertRaises(KeyError):
        t[odd]