In [None]:
%load_ext tutormagic

# Tree Mutation

Tree mutation means taking an existing tree and changing it in place. Tree mutations are often recursive.

## Example: Pruning Trees

When we remove subtrees from an existing tree, it's called pruning.

In general, when writing a function that prunes a tree, we want to prune before we recurse. This way, we don't end up processing the pruned subtrees. This is not always possible: sometimes we have to look at the subtrees before we know whether to prune. In other cases, we can tell just by looking at the branches directly.
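As a minimal sketch of pruning before recursing, the example below removes every subtree whose label matches a given value. The helper `prune_label` and the trimmed-down `Tree` class are hypothetical stand-ins for illustration, not definitions from this section.

```python
class Tree:
    """Minimal stand-in for a tree with a label and a list of branches."""
    def __init__(self, label, branches=()):
        self.label = label
        self.branches = list(branches)

    def __repr__(self):
        if self.branches:
            return 'Tree({0}, {1})'.format(repr(self.label), repr(self.branches))
        return 'Tree({0})'.format(repr(self.label))


def prune_label(t, label):
    """Hypothetical helper: remove every subtree whose root has the given label."""
    # Prune first: the decision only needs each branch's own label.
    t.branches = [b for b in t.branches if b.label != label]
    # Recurse only into the branches that survived the pruning.
    for b in t.branches:
        prune_label(b, label)


t = Tree(1, [Tree(2, [Tree(3), Tree(4)]), Tree(3, [Tree(5)]), Tree(6, [Tree(3)])])
prune_label(t, 3)
print(repr(t))  # Tree(1, [Tree(2, [Tree(4)]), Tree(6)])
```

Because the filter runs before the loop, the subtree `Tree(3, [Tree(5)])` is dropped without ever being visited.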

Recall the example of computing the Fibonacci sequence using memoization. We went through every recursive call to `fib` and annotated it according to whether it was:
1. Returned by `fib`
2. Found in cache
3. Skipped entirely

<img src = 'fib.jpg' width = 800/>

We discussed how memoization makes this recursive computation efficient. We can illustrate the process with `fib_tree` by pruning away every call that was either found in the cache or skipped entirely.

Given a tree, we can prune any repeated branches by taking the tree and also keeping track of a list of the subtrees we've seen before. 

In [4]:
# Given a tree, we can prune any repeated branches by taking the tree
# and also keeping track of a list of the subtrees we've seen before.
def prune_repeats(t, seen):
    # Tree mutation means reassigning an attribute, either the branches or the label.
    # Here, the branches of the pruned tree are the branches of the original tree
    # that we have not seen before. This removes repeated branches of t itself,
    # but not repeated subtrees deeper in the tree.
    t.branches = [b for b in t.branches if b not in seen]

    # Record that we have now seen t.
    seen.append(t)

    # To remove repeated subtrees deeper in the tree, recursively call
    # prune_repeats on each remaining branch.
    for b in t.branches:
        prune_repeats(b, seen)

Here is our `Tree` class:

In [3]:
class Tree:
    """A tree is a label and a list of branches."""
    def __init__(self, label, branches=[]):
        self.label = label
        for branch in branches:
            assert isinstance(branch, Tree)
        self.branches = list(branches)

    def __repr__(self):
        if self.branches:
            branch_str = ', ' + repr(self.branches)
        else:
            branch_str = ''
        return 'Tree({0}{1})'.format(repr(self.label), branch_str)

    def __str__(self):
        return '\n'.join(self.indented())

    def indented(self):
        lines = []
        for b in self.branches:
            for line in b.indented():
                lines.append('  ' + line)
        return [str(self.label)] + lines
    
    def is_leaf(self):
        return not self.branches

We also need the `memo` decorator and the `fib_tree` function:

In [5]:
def memo(f):
    cache = {}
    def memoized(n):
        if n not in cache:
            cache[n] = f(n)
        return cache[n]
    return memoized

@memo
def fib_tree(n):
    """A Fibonacci tree.

    >>> print(fib_tree(4))
    3
      1
        0
        1
      2
        1
        1
          0
          1
    """
    if n == 0 or n == 1:
        return Tree(n)
    else:
        left = fib_tree(n-2)
        right = fib_tree(n-1)
        fib_n = left.label + right.label
        return Tree(fib_n, [left, right])
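One consequence of memoizing `fib_tree` is worth making explicit: each `fib_tree(n)` is built once, so every occurrence of it in the larger tree is the same object. The sketch below uses compact copies of `Tree`, `memo`, and `fib_tree` (trimmed for brevity, so treat them as stand-ins) to check this with `is`. Since `Tree` defines no `__eq__`, a membership test such as `b in seen` falls back to this object identity.

```python
class Tree:
    """Compact stand-in for the Tree class above (no printing methods)."""
    def __init__(self, label, branches=()):
        self.label = label
        self.branches = list(branches)


def memo(f):
    cache = {}
    def memoized(n):
        if n not in cache:
            cache[n] = f(n)
        return cache[n]
    return memoized


@memo
def fib_tree(n):
    if n == 0 or n == 1:
        return Tree(n)
    left, right = fib_tree(n - 2), fib_tree(n - 1)
    return Tree(left.label + right.label, [left, right])


t = fib_tree(4)
left, right = t.branches             # fib_tree(2) and fib_tree(3)

# The fib_tree(2) directly under the root is the very same object as the
# fib_tree(2) nested inside fib_tree(3).
print(left is right.branches[1])     # True

# Without __eq__, two separately built trees never compare equal,
# so list membership only finds the identical object.
print(Tree(0) == Tree(0))            # False
print(left in [right.branches[1]])   # True
```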

Now we call `fib_tree(8)`:

In [6]:
fib_tree(8)

Tree(21, [Tree(8, [Tree(3, [Tree(1, [Tree(0), Tree(1)]), Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])])]), Tree(5, [Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])]), Tree(3, [Tree(1, [Tree(0), Tree(1)]), Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])])])])]), Tree(13, [Tree(5, [Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])]), Tree(3, [Tree(1, [Tree(0), Tree(1)]), Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])])])]), Tree(8, [Tree(3, [Tree(1, [Tree(0), Tree(1)]), Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])])]), Tree(5, [Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])]), Tree(3, [Tree(1, [Tree(0), Tree(1)]), Tree(2, [Tree(1), Tree(1, [Tree(0), Tree(1)])])])])])])])

In [7]:
print(fib_tree(8))

21
  8
    3
      1
        0
        1
      2
        1
        1
          0
          1
    5
      2
        1
        1
          0
          1
      3
        1
          0
          1
        2
          1
          1
            0
            1
  13
    5
      2
        1
        1
          0
          1
      3
        1
          0
          1
        2
          1
          1
            0
            1
    8
      3
        1
          0
          1
        2
          1
          1
            0
            1
      5
        2
          1
          1
            0
            1
        3
          1
            0
            1
          2
            1
            1
              0
              1


As we can see above, there are many repeated `fib` calls! Using `prune_repeats`, we can get rid of those repetitions.

In [9]:
t = fib_tree(8)
prune_repeats(t, [])
print(t)

21
  8
    3
      1
        0
        1
      2
    5
  13


Now we have a more compact structure that doesn't include repeated calls!
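Because `prune_repeats` mutates the tree in place and `fib_tree` is memoized, the pruned tree is also what the cache hands back on later calls. The sketch below illustrates this and shows one possible way to keep the original intact, assuming `copy.deepcopy` is acceptable here: a deep copy preserves which subtrees are shared, so pruning the copy behaves the same way. `count_nodes` is a hypothetical helper, and the class and functions are compact copies of the ones above.

```python
import copy


class Tree:
    """Compact stand-in for the Tree class above (no printing methods)."""
    def __init__(self, label, branches=()):
        self.label = label
        self.branches = list(branches)


def memo(f):
    cache = {}
    def memoized(n):
        if n not in cache:
            cache[n] = f(n)
        return cache[n]
    return memoized


@memo
def fib_tree(n):
    if n == 0 or n == 1:
        return Tree(n)
    left, right = fib_tree(n - 2), fib_tree(n - 1)
    return Tree(left.label + right.label, [left, right])


def prune_repeats(t, seen):
    t.branches = [b for b in t.branches if b not in seen]
    seen.append(t)
    for b in t.branches:
        prune_repeats(b, seen)


def count_nodes(t):
    """Hypothetical helper: count nodes, counting a shared subtree at each occurrence."""
    return 1 + sum(count_nodes(b) for b in t.branches)


original = fib_tree(8)
full_size = count_nodes(original)

# Prune a deep copy: the cached tree is left alone.
working_copy = copy.deepcopy(original)
prune_repeats(working_copy, [])
print(full_size, count_nodes(working_copy), count_nodes(fib_tree(8)))  # 67 9 67

# Prune the cached tree itself: later calls now return the pruned version.
prune_repeats(original, [])
print(count_nodes(fib_tree(8)))  # 9
```

The counts match the printed trees above: 67 lines before pruning and 9 after.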