# Classes, Trees and Recursion – A Primer for the First Project

In this notebook, we will introduce the concept of classes in Python, and how they can be used to create tree data structures. We will also introduce the concept of recursion, and how it can be used to traverse trees. These concepts are fundamental to the upcoming first project where you will implement a decision tree from scratch.

---

## Object-Oriented Programming in Python

Python supports object-oriented programming. This means that it allows you to define classes, which are templates for creating objects. An object is an instance of a class. We will cover only the most basic aspects of object-oriented programming in Python, which will be sufficient for the first project.

A class can have attributes (variables) and methods (functions), combined together in a single unit.

For example, you can define a class `Person` that has attributes like `name` and `age`, and methods like `say_hello`. You can then create objects of this class, like `person_1` and `person_2`, and call methods on them.

Let's see an example of how to define such a class in Python.

In [152]:
class Person():
    def __init__(self, name, age):
        # The __init__ method is a special method that is called when an object is created
        self.name = name
        self.age = age

    def say_hello(self):
        # The self parameter is a reference to the current instance of the class and needs to be included in the method definition
        print(f"Hello, my name is {self.name} and I am {self.age} years old")

In [153]:
# Create two instances of the Person class
person_1 = Person("John", 36)
person_2 = Person("Katy", 25)

person_1.say_hello()
person_2.say_hello()

Hello, my name is John and I am 36 years old
Hello, my name is Katy and I am 25 years old


Let us expand the `Person` class to include the possibility of adding friends to a person. We will also add a method to print the friends of a person, and a method to check if two people are friends.

In [154]:
class Person():
    def __init__(self, name, age):
        self.name = name
        self.age = age
        self.friends = []

    def say_hello(self):
        print(f"Hello, my name is {self.name} and I am {self.age} years old")

    def add_friend(self, friend):
        self.friends.append(friend)
    
    def list_friends(self):
        print("My friends are:")
        for friend in self.friends:
            print(friend.name)
    
    def is_friend(self, friend):
        return friend in self.friends

We now test the functionality of the class.

In [155]:
person_1 = Person("John", 36)
person_2 = Person("Katy", 25)
person_3 = Person("Eric", 45)
person_4 = Person("Jessica", 23)

# Add friends
person_1.add_friend(person_2)
person_1.add_friend(person_3)

# Greet and print friends of person_1
person_1.say_hello()
person_1.list_friends()

# Check if person_2 and person_4 are friends of person_1
print(person_1.is_friend(person_2)) # True since person_2 is a friend of person_1
print(person_1.is_friend(person_4)) # False since person_4 is not a friend of person_1

Hello, my name is John and I am 36 years old
My friends are:
Katy
Eric
True
False


Friendship is usually a symmetric relationship: if `person_1` is a friend of `person_2`, then `person_2` is also a friend of `person_1`. Can you modify the `add_friend` method so that adding a friend updates both people? Be careful to avoid an endless chain of calls when the two `add_friend` calls trigger each other.
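
One possible approach to the question above is sketched here. It is a minimal illustration rather than the only solution, and the class is trimmed down to the parts that matter. The early return when the friend is already in the list is what stops the two `add_friend` calls from triggering each other forever. The variable names `john` and `katy` are chosen only for this demonstration.

```python
class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age
        self.friends = []

    def add_friend(self, friend):
        # Guard: if the friendship is already recorded, stop here.
        # Without this check, the two add_friend calls would call
        # each other without end.
        if friend in self.friends:
            return
        self.friends.append(friend)
        # Record the friendship in the other direction as well
        friend.add_friend(self)

    def is_friend(self, friend):
        return friend in self.friends


john = Person("John", 36)
katy = Person("Katy", 25)

john.add_friend(katy)

print(john.is_friend(katy))  # True
print(katy.is_friend(john))  # True
```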

---

## Recursion

Recursion is a technique in programming where a function calls itself. It is a powerful tool for problems that can be broken down into smaller, similar subproblems. Every recursive function needs a base case, an input it can answer directly without calling itself, so that the chain of calls eventually stops. Trees are a common data structure where recursion is used. In particular, when learning decision trees and making predictions with them, we will use recursion to traverse the tree.

Let us see an example of recursion by writing a function to calculate the factorial of a number. The factorial $n!$ of a non-negative integer $n$ is defined as $n! = n \times (n-1) \times (n-2) \times \ldots \times 1$. The factorial of 0 is defined as 1. The definition of the factorial function is recursive, as the factorial of $n$ depends on the factorial of $n-1$ since $n! = n \times (n-1)!$.

In [156]:
# Recursive functions
def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n - 1)

print(factorial(5)) # 5 * 4 * 3 * 2 * 1 = 120
print(factorial(10)) # 10 * 9 * 8 * 7 * 6 * 5 * 4 * 3 * 2 * 1 = 3628800

120
3628800
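
To see how the recursive calls nest and then unwind, it can help to print each call as it happens. The sketch below is a traced variant of the factorial function. The name `factorial_traced` and the `depth` parameter are additions made here purely for illustration.

```python
def factorial_traced(n, depth=0):
    indent = "  " * depth
    print(f"{indent}factorial({n}) called")
    if n == 0:
        result = 1  # base case: no further recursive call
    else:
        # Recursive case: the call for n waits until the call for n - 1 returns
        result = n * factorial_traced(n - 1, depth + 1)
    print(f"{indent}factorial({n}) returns {result}")
    return result


value = factorial_traced(3)  # value is 6
```

Each level of indentation in the printed trace corresponds to one call that is still waiting for the call below it to finish.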


Let us also see an example of using recursion to calculate the sum $\sum_{i=1}^{n} i$ of the first $n$ natural numbers.

In [157]:
def sum_up_to(n):
    if n == 1:
        return 1
    return n + sum_up_to(n - 1)
    
print(sum_up_to(10)) # 10 + 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2 + 1 = 55
print(sum_up_to(100)) # 100 + 99 + 98 + ... + 3 + 2 + 1 = 5050

55
5050


The two examples above can also be solved iteratively using loops. However, there are problems that are more naturally solved using recursion. For example, as we will see later in this notebook, traversing a tree data structure is often easier using recursion.

Here are the iterative versions of the two functions above:

In [158]:
def factorial(n):
    result = 1
    for i in range(1, n + 1):
        result *= i
    return result

def sum_up_to(n):
    result = 0
    for i in range(1, n + 1):
        result += i
    return result
    # Or return n * (n + 1) // 2 if you want to be fancy

print(factorial(5)) # 5 * 4 * 3 * 2 * 1 = 120
print(factorial(10)) # 10 * 9 * 8 * 7 * 6 * 5 * 4 * 3 * 2 * 1 = 3628800

print(sum_up_to(10)) # 10 + 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2 + 1 = 55
print(sum_up_to(100)) # 100 + 99 + 98 + ... + 3 + 2 + 1 = 5050

120
3628800
55
5050
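
One practical difference between the two styles is that CPython limits how deeply function calls may nest. A recursive function therefore fails for inputs that need too many nested calls, while the loop does not. The sketch below illustrates this. The names `sum_up_to_recursive` and `sum_up_to_iterative` are introduced here only to keep both versions side by side.

```python
import sys


def sum_up_to_recursive(n):
    if n == 1:
        return 1
    return n + sum_up_to_recursive(n - 1)


def sum_up_to_iterative(n):
    result = 0
    for i in range(1, n + 1):
        result += i
    return result


# Pick an n that needs more nested calls than the interpreter currently allows
n = sys.getrecursionlimit() + 100

# The loop handles this n without trouble
print(sum_up_to_iterative(n) == n * (n + 1) // 2)  # True

try:
    sum_up_to_recursive(n)
except RecursionError:
    print(f"RecursionError: {n} nested calls exceed the recursion limit")
```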


## Exercise 1 - Recursive Functions

Implement the following recursive functions (even though some of them can be implemented more efficiently in other ways):

1. Write a recursive function `fibonacci(n)` that calculates the $n$-th Fibonacci number. The Fibonacci sequence is defined as $F_0 = 0$, $F_1 = 1$, and $F_n = F_{n-1} + F_{n-2}$ for $n \geq 2$. The first few Fibonacci numbers are 0, 1, 1, 2, 3, 5, 8, 13, 21, and so on.

2. Write a recursive function `power(x, n)` that calculates $x^n$ for a given real number $x$ and a non-negative integer $n$. Hint: You can use the property $x^n = x \times x^{n-1}$.

3. Write a recursive function `gcd(a, b)` that calculates the greatest common divisor of two non-negative integers $a$ and $b$. The greatest common divisor of $a$ and $b$ is the largest number that divides both $a$ and $b$. Use the Euclidean algorithm to calculate the greatest common divisor. The Euclidean algorithm states that $\gcd(a, b) = \gcd(b, a \bmod b)$, and $\gcd(a, 0) = a$.
   
4. Write a recursive function `is_palindrome(s)` that checks if a given string `s` is a palindrome. A palindrome is a string that reads the same forwards and backwards. For example, "racecar" is a palindrome. Hint: You can compare the first and last characters of the string, and then check if the substring between the first and last characters is a palindrome.

In [159]:
def fibonacci(n):
    if n == 0:
        return 0
    if n == 1:
        return 1
    return fibonacci(n-1) + fibonacci(n-2)

def power(x, n):
    # Base case: x^0 = 1, so n = 0 is handled correctly
    if n == 0:
        return 1
    return x * power(x, n-1)
        
def gcd(a, b):
    if b == 0:
        return a
    return gcd(b, a%b)
    
def is_palindrome(s: str) -> bool:
    if len(s) < 2:
        return True
    
    if s[0] == s[-1]:
        return is_palindrome(s[1:-1])
    else:
        return False

# Test fibonacci()
print(fibonacci(1))  # 1
print(fibonacci(5))  # 5
print(fibonacci(10)) # 55
print(fibonacci(20)) # 6765

# Test power()
print(power(2, 2))   # 4
print(power(4, 4))   # 256
print(power(11, 5))  # 161051
print(power(101, 5)) # 10510100501

# Test gcd()
print(gcd(19, 17)) # 1
print(gcd(9, 12))  # 3
print(gcd(64, 32)) # 32
print(gcd(48, 18)) # 6

# Test is_palindrome()
print(is_palindrome("racecar")) # True
print(is_palindrome("hello"))   # False
print(is_palindrome("level"))   # True
print(is_palindrome("world"))   # False

1
5
55
6765
4
256
161051
10510100501
1
3
32
6
True
False
True
False
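
The plain recursive `fibonacci` recomputes the same values many times, which is one of the inefficiencies hinted at in the exercise text. One common remedy is to cache results that have already been computed. The sketch below uses `functools.lru_cache` from the standard library. The name `fibonacci_cached` is chosen here for illustration.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def fibonacci_cached(n):
    # Base cases: F_0 = 0 and F_1 = 1
    if n < 2:
        return n
    # Each value is computed once; repeated calls are answered from the cache
    return fibonacci_cached(n - 1) + fibonacci_cached(n - 2)


print(fibonacci_cached(20))  # 6765
print(fibonacci_cached(50))  # 12586269025
```

The function is still recursive; only the repeated work is avoided.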


**Extra challenge:** Implement the above functions *iteratively* instead of recursively. That is, write a non-recursive version of each function using loops instead of recursion.

In [160]:
def fibonacci(n):
    if n < 1:
        return 0
    a = 0
    b = 1
    for _ in range(n-1):
        a, b = b, a + b
    return b


def power(x, n):
    # Start from 1 so that x^0 = 1; the name also avoids shadowing the built-in sum()
    result = 1
    for _ in range(n):
        result *= x
    return result

            

def gcd(a, b):
    while b != 0:
        a, b = b, a%b
    return abs(a)
    

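def is_palindrome(s):
    # Compare characters from both ends, moving inwards.
    # This also handles empty and single-character strings.
    left, right = 0, len(s) - 1
    while left < right:
        if s[left] != s[right]:
            return False
        left += 1
        right -= 1
    return True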


# Test fibonacci()
print(fibonacci(1))  # 1
print(fibonacci(5))  # 5
print(fibonacci(10)) # 55
print(fibonacci(20)) # 6765

# Test power()
print(power(2, 2))   # 4
print(power(4, 4))   # 256
print(power(11, 5))  # 161051
print(power(101, 5)) # 10510100501

# Test gcd()
print(gcd(19, 17)) # 1
print(gcd(9, 12))  # 3
print(gcd(64, 32)) # 32
print(gcd(48, 18)) # 6

# Test is_palindrome()
print(is_palindrome("racecar")) # True
print(is_palindrome("hello"))   # False
print(is_palindrome("level"))   # True
print(is_palindrome("world"))   # False

1
5
55
6765
4
256
161051
10510100501
1
3
32
6
True
False
True
False


---

## Trees and Traversing Trees using Recursion

A tree is a data structure that consists of nodes connected by edges. Each node has a value, and can have zero or more children. The topmost node is called the root of the tree.

Decision trees, which you will implement in the first project, are a type of tree data structure. In a decision tree, each node represents a decision, and the edges represent the possible outcomes of the decision.

Let us create the class `Node` to represent a node in a tree. Each node will have a value, and a list of children nodes.

In [161]:
class Node():
    def __init__(self, data):
        self.data = data
        self.children = []

Let us use the `Node` class to create a simple tree manually. We will create the tree shown below with 6 nodes (including the root node).

```plaintext
              root
            /      \
        left       right
       /    \         \  
left-left  left-right  right-left
 ```

In [162]:
# Create a tree using our Node class
node_1 = Node("root")
node_2 = Node("left")
node_3 = Node("right")
node_4 = Node("left-left")
node_5 = Node("left-right")
node_6 = Node("right-left")

node_1.children.append(node_2)
node_1.children.append(node_3)
node_2.children.append(node_4)
node_2.children.append(node_5)
node_3.children.append(node_6)

All good! Now let's write a function to print the tree in a human-readable format.

Let us get back to **recursion**. In the context of trees, recursion is often used to traverse the tree and perform operations on each node.

We can use recursion to traverse the tree we created and print the value of each node. We will write a function `print_tree` that takes a node as input and prints the value of the node, and then calls itself on each child of the node. To make the output more readable, we will also add an argument `level` to keep track of the level of the node in the tree and indent the output accordingly.

In [163]:
# Print the tree using a recursive function
def print_tree(node, level=0):
    print("  " * level + str(node.data))
    for child in node.children:
        print_tree(child, level + 1)

Let us test the `print_tree` function on the tree we created by starting from the root node.

In [164]:
print_tree(node_1)

root
  left
    left-left
    left-right
  right
    right-left
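
The same recursive pattern works for questions about the shape of a tree. As an illustration, the sketch below computes the height of a tree, meaning the number of edges on the longest path from the root down to a leaf. The function name `tree_height` and the `demo_*` variables are not part of the notebook and are introduced here only as an example.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.children = []


def tree_height(node):
    # Base case: a leaf has height 0
    if not node.children:
        return 0
    # Recursive case: one edge down to the tallest child subtree
    return 1 + max(tree_height(child) for child in node.children)


# Rebuild the six-node example tree
demo_root = Node("root")
demo_left = Node("left")
demo_right = Node("right")
demo_root.children = [demo_left, demo_right]
demo_left.children = [Node("left-left"), Node("left-right")]
demo_right.children = [Node("right-left")]

print(tree_height(demo_root))  # 2
print(tree_height(demo_left))  # 1
```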


## Exercise 2 - Functions on Trees

For this exercise, you will implement some functions on trees using recursion. First, let us define a tree using the `Node` class with integer values to test the functions on.
    
```plaintext
              2 (root node)
           /  |  \
          /   |   \
         /    |    \
        21    54   14
       /  \   |   /  \
      69  42  70 117 31
```

In [165]:
root = Node(2)
root.children = [Node(21), Node(54), Node(14)]
root.children[0].children = [Node(69), Node(42)]
root.children[1].children = [Node(70)]
root.children[2].children = [Node(117), Node(31)]

Write the following functions using recursion to operate on the tree:

1. `sum_tree`: This function should take a node as input and return the sum of all the values in the tree rooted at that node.
2. `max_tree`: This function should take a node as input and return the maximum value in the tree rooted at that node.
3. `search_tree`: This function should take a node and a value as input and return `True` if the value is present in the tree rooted at that node, and `False` otherwise.
4. `count_nodes`: This function should take a node as input and return the number of nodes in the tree rooted at that node.
5. `print_leaf_nodes`: This function should take a node as input and print the value of all the leaf nodes in the tree rooted at that node.

In [166]:
def sum_tree(node):
    return node.data + sum([sum_tree(child) for child in node.children])

def max_tree(node):
    return max([node.data] + [max_tree(child) for child in node.children])

def search_tree(node, value):
    return node.data == value or any([search_tree(child, value) for child in node.children])
    
def count_nodes(node):
    return 1 + sum([count_nodes(child) for child in node.children])

def print_leaf_nodes(node):
    if not node.children:
        print(node.data, end=" ")
    else:
        for child in node.children:
            print_leaf_nodes(child)


print(f"The sum of the node values is: {sum_tree(root)}") # 420
print(f"The maximum value in the tree is: {max_tree(root)}") # 117
print(f"Is 41 in the tree? {search_tree(root, 41)}") # False
print(f"Is 42 in the tree? {search_tree(root, 42)}") # True
print(f"The number of nodes in the tree is: {count_nodes(root)}") # 9
print(f"The leaf nodes in the tree are: ", end=" ") # 69, 42, 70, 117, 31
print_leaf_nodes(root)

The sum of the node values is: 420
The maximum value in the tree is: 117
Is 41 in the tree? False
Is 42 in the tree? True
The number of nodes in the tree is: 9
The leaf nodes in the tree are:  69 42 70 117 31 
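
For comparison, recursion on a tree can be replaced by a loop that manages its own stack of nodes still to be visited. The sketch below does this for the sum of the node values. The name `sum_tree_iterative` is chosen here for illustration. The nodes are visited in a different order than in the recursive version, which does not matter for a sum.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.children = []


def sum_tree_iterative(root):
    total = 0
    stack = [root]  # nodes waiting to be visited
    while stack:
        node = stack.pop()
        total += node.data
        # Schedule the children; this replaces the recursive calls
        stack.extend(node.children)
    return total


# Rebuild the example tree from this exercise
demo_root = Node(2)
demo_root.children = [Node(21), Node(54), Node(14)]
demo_root.children[0].children = [Node(69), Node(42)]
demo_root.children[1].children = [Node(70)]
demo_root.children[2].children = [Node(117), Node(31)]

print(sum_tree_iterative(demo_root))  # 420
```

The explicit stack plays the role that the call stack plays in the recursive version, which is why the recursive code is usually shorter.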

---

## Binary Trees

A binary tree is a tree in which each node has at most two children. In a *full* binary tree, each node has either two children or none. Binary trees are simpler than the general trees we have been working with so far, but they are a very important data structure in computer science.

For implementing a decision tree with only binary splits (as you will do in the first project), it is sufficient to use a binary tree. We will rewrite the `Node` class for use with binary trees, where each node has a value, a left child, and a right child. More attributes can be added as needed.

We will also add a method `is_leaf` to check if a node is a leaf node (i.e., it has no children). This is useful for traversing the tree until we reach a leaf node such as in the case of making predictions using a decision tree.

In [167]:
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
    
    def is_leaf(self):
        return self.left is None and self.right is None

We can create a class `Tree` to represent a (binary) tree. The tree will have a root node, and we can add methods to the class to perform operations on the tree. The advantage of using a class to represent a tree is that we can encapsulate the tree operations within the class.

In the first project you will create a class called `DecisionTree` that will represent a decision tree. The `DecisionTree` will have methods to train the tree on a dataset and make predictions using the learned tree.

Let us create a simple binary tree using the `Tree` class.

```plaintext
              2 (root node)
           /      \
          /        \
         /          \
        21          54
       /  \        /  \
      69  42      70  14
```

We will add a method `__len__` to the `Tree` class that returns the number of nodes in the tree. A method called `__len__` is a special method in Python that allows you to customize the behavior of the `len` function on objects of the class.

In [168]:
class Tree:
    def __init__(self):
        self.root = None

    def count_nodes(self, node):
        if node is None:
            return 0
        return 1 + self.count_nodes(node.left) + self.count_nodes(node.right)
    
    def __len__(self):
        return self.count_nodes(self.root)

In [169]:
tree = Tree()

tree.root = Node(2)
tree.root.left = Node(21)
tree.root.right = Node(54)
tree.root.left.left = Node(69)
tree.root.left.right = Node(42)
tree.root.right.left = Node(70)
tree.root.right.right = Node(14)

print(f"The number of nodes in the tree is: {len(tree)}")

The number of nodes in the tree is: 7


## Exercise 3 - A Simple Prediction Method using a Binary Tree

You will now implement a method to make predictions using a binary tree. This is a toy version of what you will do in the first project.

In the `Tree` class below, implement a method `predict(self, value)` that takes an integer value as input and returns the value of the leaf node that the input reaches when traversing the tree. Starting at the root, compare the input value to the value of the current node: if it is less than the node's value, move to the left child; otherwise, move to the right child. Repeat this until a leaf node is reached, and return the `value` attribute of that leaf.

It can be useful to implement a recursive helper function `_predict(self, node, value)` that takes a node and value as input and performs the traversal. Then, the `predict` method can call this helper function with the root node and the input value.

In [170]:
class Tree:
    def __init__(self):
        self.root = None

    def predict(self, value):
        return self._predict(self.root, value)
    
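    def _predict(self, node, value):
        # Walk down from the given node until a leaf is reached.
        # A loop is used here; a recursive call on the chosen child works equally well.
        while not node.is_leaf():
            if value < node.value:
                node = node.left    # smaller values go left
            else:
                node = node.right   # equal or larger values go right
        return node.value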

Test your implementation on the tree we created above by running the code below. The expected output is commented in the code.

In [171]:
tree = Tree()

tree.root = Node(2)
tree.root.left = Node(21)
tree.root.right = Node(54)
tree.root.left.left = Node(69)
tree.root.left.right = Node(42)
tree.root.right.left = Node(70)
tree.root.right.right = Node(14)

print(tree.predict(42)) # 70
print(tree.predict(55)) # 14
print(tree.predict(1)) # 69

70
14
69
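
The `_predict` helper above walks down the tree with a loop. The recursive form suggested in the exercise text could be sketched as follows. The class name `RecursiveTree` is chosen here only to keep it apart from the `Tree` class, and the sketch assumes that every non-leaf node has both a left and a right child.

```python
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None

    def is_leaf(self):
        return self.left is None and self.right is None


class RecursiveTree:
    def __init__(self):
        self.root = None

    def predict(self, value):
        return self._predict(self.root, value)

    def _predict(self, node, value):
        # Base case: a leaf returns its own value
        if node.is_leaf():
            return node.value
        # Recursive case: continue in the left or right subtree
        # (assumes both children exist on every non-leaf node)
        if value < node.value:
            return self._predict(node.left, value)
        return self._predict(node.right, value)


demo_tree = RecursiveTree()
demo_tree.root = Node(2)
demo_tree.root.left = Node(21)
demo_tree.root.right = Node(54)
demo_tree.root.left.left = Node(69)
demo_tree.root.left.right = Node(42)
demo_tree.root.right.left = Node(70)
demo_tree.root.right.right = Node(14)

print(demo_tree.predict(42))  # 70
print(demo_tree.predict(55))  # 14
print(demo_tree.predict(1))   # 69
```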


Above, you implemented a type of conditional traversal of a binary tree. Let us take it a step further and make the prediction method more similar to what you will do in the first project.

Let us start by extending the `Node` class to include the following attributes:

1. `value`: The value of the node (the prediction returned when an input reaches the node). This attribute is only used for leaf nodes.
2. `threshold`: The threshold used to make a decision at the node. This attribute is only used for non-leaf nodes.
3. `feature_index`: The index of the feature used to make a decision at the node. This attribute is only used for non-leaf nodes.
4. `left`: The left child of the node.
5. `right`: The right child of the node.

We will also keep the `is_leaf` method to check if a node is a leaf node.

The prediction method will now be a bit more complex. It will take a feature vector as input and traverse the tree to make a prediction. At each non-leaf node, the feature value at the index given by `feature_index` is compared to the `threshold`. If the feature value is less than the threshold, the left child is visited, otherwise the right child is visited. This process is repeated until a leaf node is reached, and the value of the leaf node is returned.

In [172]:
class Node:
    def __init__(self, feature_index=None, threshold=None, value=None):
        self.feature_index = feature_index
        self.threshold = threshold
        self.value = value
        self.left = None
        self.right = None
    
    def is_leaf(self):
        return self.left is None and self.right is None

Implement the `predict` method in the `Tree` class below. The method should take a feature vector as input and return the value of the leaf node that the feature vector reaches by traversing the tree.

In [173]:
class Tree:
    def __init__(self):
        self.root = None

    def predict(self, features):
        return self._predict(self.root, features)

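    def _predict(self, node, features):
        # Follow one branch per node until a leaf is reached
        while not node.is_leaf():
            if features[node.feature_index] < node.threshold:
                node = node.left    # feature value below the threshold
            else:
                node = node.right   # feature value equal to or above the threshold
        return node.value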


We create a decision tree with the following structure:

```
                (feature_index=1, threshold=50)
                        /                   \
                       /                     \
                      /                       \
    (feature_index=0, threshold=30)          (value=1)
            /               \
           /                 \
          /                   \
    (value=0)                (value=2)
```

The tree above makes decisions based on two features. At the root, if the value of the feature at index 1 is less than 50, we move to the left child; otherwise we move to the right child, which is a leaf with value 1. At the left child, if the value of the feature at index 0 is less than 30, we move to the left leaf (value 0); otherwise we move to the right leaf (value 2). The value of the leaf node reached is the prediction output.

Run the code below to test your implementation. The expected output is commented in the code.

In [174]:
# Create a decision tree
tree = Tree()
tree.root = Node(feature_index=1, threshold=50)
tree.root.right = Node(value=1)
tree.root.left = Node(feature_index=0, threshold=30)
tree.root.left.left = Node(value=0)
tree.root.left.right = Node(value=2)

# Predict using the decision tree
x = [35, 55] # Feature 0 is 35 and feature 1 is 55
print(tree.predict(x)) # 1 since 55 >= 50

x = [25, 45] # Feature 0 is 25 and feature 1 is 45
print(tree.predict(x)) # 0 since 45 < 50 and 25 < 30

x = [45, 25] # Feature 0 is 45 and feature 1 is 25
print(tree.predict(x)) # 2 since 25 < 50 and 45 >= 30

1
0
2
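
The recursive printing idea from earlier in the notebook carries over to this kind of tree. As a sketch, the function below turns a decision tree into indented if/else rules, which can be handy for checking that a tree was built as intended. The name `describe_tree` and the `demo_root` variable are introduced here for illustration, and the function assumes that every non-leaf node has both children.

```python
class Node:
    def __init__(self, feature_index=None, threshold=None, value=None):
        self.feature_index = feature_index
        self.threshold = threshold
        self.value = value
        self.left = None
        self.right = None

    def is_leaf(self):
        return self.left is None and self.right is None


def describe_tree(node, level=0):
    """Return the decision rules below `node` as a list of indented lines."""
    indent = "  " * level
    # Base case: a leaf only reports its prediction
    if node.is_leaf():
        return [f"{indent}predict {node.value}"]
    # Recursive case: describe the left branch, then the right branch
    lines = [f"{indent}if features[{node.feature_index}] < {node.threshold}:"]
    lines += describe_tree(node.left, level + 1)
    lines.append(f"{indent}else:")
    lines += describe_tree(node.right, level + 1)
    return lines


# Rebuild the decision tree from the example above
demo_root = Node(feature_index=1, threshold=50)
demo_root.right = Node(value=1)
demo_root.left = Node(feature_index=0, threshold=30)
demo_root.left.left = Node(value=0)
demo_root.left.right = Node(value=2)

print("\n".join(describe_tree(demo_root)))
```

For the example tree, this prints:

```plaintext
if features[1] < 50:
  if features[0] < 30:
    predict 0
  else:
    predict 2
else:
  predict 1
```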


Congratulations on completing this primer on classes, trees, and recursion! 

**Good luck with the first project!**