<a href="https://colab.research.google.com/github/walkerjian/DailyCode/blob/main/count_unival_subtrees.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

A unival tree (short for "universal value" tree) is a tree in which every node, including the root, has the same value.

Given the root of a binary tree, count the number of unival subtrees.

For example, the following tree has 5 unival subtrees (its four leaves, plus the subtree rooted at the `1` that has two `1` children):
````
   0
  / \
 1   0
    / \
   1   0
  / \
 1   1
````
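To see where the count of 5 comes from, here is a small sketch that walks the example tree and records the path to every node whose subtree is unival. `TreeNode`, `all_same` and `unival_roots` are hypothetical names used only here. A path is written as `root` followed by `.L`/`.R` steps.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TreeNode:
    # Hypothetical minimal node type, defined only for this sketch.
    value: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def all_same(node: Optional[TreeNode], value: int) -> bool:
    """True if every node in the subtree rooted at `node` equals `value`."""
    if node is None:
        return True
    return node.value == value and all_same(node.left, value) and all_same(node.right, value)


def unival_roots(node: Optional[TreeNode], path: str = "root") -> List[str]:
    """Paths (in pre-order) to every node whose subtree is unival."""
    if node is None:
        return []
    here = [path] if all_same(node, node.value) else []
    return here + unival_roots(node.left, path + ".L") + unival_roots(node.right, path + ".R")


# The example tree from the problem statement.
example = TreeNode(
    0,
    TreeNode(1),
    TreeNode(0, TreeNode(1, TreeNode(1), TreeNode(1)), TreeNode(0)),
)

print(unival_roots(example))
# → ['root.L', 'root.R.L', 'root.R.L.L', 'root.R.L.R', 'root.R.R']
```

Four of the five paths end at leaves. `root.R.L` is the only internal node that qualifies. The root and `root.R` are excluded because their subtrees contain both `0` and `1`.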


We can break the problem down following the Model-View-Controller (MVC) paradigm:

Model: This is the data representation. In this case, we will use a class to represent the binary tree and its nodes.

View: This is the output of our function, namely the number of unival subtrees in a given binary tree.

Controller: This is the logic that determines the number of unival subtrees.

We'll start by defining our model.
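To make the three roles concrete, here is a minimal sketch that keeps them in separate pieces: the model is plain data, the controller only computes, and the view only formats. The names `Tree`, `subtree_values`, `count_unival` and `render` are hypothetical. The controller here uses a set-of-values check (a subtree is unival exactly when it holds one distinct value). This is a deliberately simple alternative to the notebook's `is_unival` helper, not the notebook's method.

```python
from dataclasses import dataclass
from typing import Optional, Set


# Model: pure data, no counting logic (hypothetical `Tree` type for this sketch).
@dataclass
class Tree:
    value: int
    left: Optional["Tree"] = None
    right: Optional["Tree"] = None


# Controller: the counting logic, written against the model only.
def subtree_values(node: Optional[Tree]) -> Set[int]:
    """Collect the distinct values in the subtree rooted at `node`."""
    if node is None:
        return set()
    return {node.value} | subtree_values(node.left) | subtree_values(node.right)


def count_unival(node: Optional[Tree]) -> int:
    """A non-empty subtree is unival exactly when it holds one distinct value."""
    if node is None:
        return 0
    here = 1 if len(subtree_values(node)) == 1 else 0
    return here + count_unival(node.left) + count_unival(node.right)


# View: presentation only; it never inspects the tree itself.
def render(label: str, count: int) -> str:
    return f"{label}: {count} unival subtree(s)"


example = Tree(0, Tree(1), Tree(0, Tree(1, Tree(1), Tree(1)), Tree(0)))
print(render("example", count_unival(example)))  # → example: 5 unival subtree(s)
```

Because the view takes only a label and a number, the output format can change without touching the counting logic, and vice versa.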

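# Model: a binary tree node holding a value and optional left/right children.
class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

    def __repr__(self):
        return f"Node({self.value}, {self.left}, {self.right})"


# Controller: count the unival subtrees of the tree rooted at `root`.
def count_unival_subtrees(root):
    # True if every node in the subtree rooted at `node` has the given value.
    def is_unival(node, value):
        if node is None:
            return True  # an empty subtree places no constraint on its parent
        if node.value != value:
            return False
        return is_unival(node.left, value) and is_unival(node.right, value)

    if root is None:
        return 0

    # Count the unival subtrees inside each child subtree first...
    total = count_unival_subtrees(root.left) + count_unival_subtrees(root.right)

    # ...then add 1 if the subtree rooted here is itself unival.
    # is_unival rescans the whole subtree at every node, so the worst case
    # (e.g. a path-shaped tree of equal values) is O(n^2) for n nodes.
    if is_unival(root, root.value):
        total += 1

    return total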


def test():
    test_cases = [
        # Sample test case
        (Node(0, Node(1), Node(0, Node(1, Node(1), Node(1)), Node(0))), 5),

        # Additional test cases
        (None, 0),  # Empty tree
        (Node(1), 1),  # Single node
        (Node(1, Node(1), Node(1)), 3),  # All nodes are the same
        (Node(1, Node(1), Node(2)), 2),  # Only one child is different
        (Node(1, Node(1, Node(1), Node(1)), Node(2)), 4),  # Nested unival subtrees
        (Node(1, Node(1, Node(2), Node(1)), Node(2)), 3),  # Mixed nested unival and non-unival subtrees
        (Node(2, Node(2, Node(2, Node(2, Node(2)))), Node(2)), 6),  # Deeply nested unival subtrees
        (Node(1, Node(1, Node(1), None), None), 3),  # Unbalanced tree with unival subtrees
        (Node(1, Node(2, Node(3), Node(4)), Node(5)), 3)  # All values distinct: only the three leaves count
    ]

    for tree, expected in test_cases:
        result = count_unival_subtrees(tree)
        assert result == expected, f"Test failed for tree {tree}. Expected: {expected}, Result: {result}"

    return "All tests passed!"

test()


'All tests passed!'
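`count_unival_subtrees` calls `is_unival` at every node, and `is_unival` rescans the whole subtree below it. For a path-shaped tree of n equal values, that adds up to roughly n²/2 node visits. A common alternative is a single post-order pass in which each call reports two things: how many unival subtrees it contains, and whether its own subtree is unival. A node is then unival when both children are unival and each existing child has the same value. The sketch below assumes only the count is needed. `count_unival_subtrees_fast` is a hypothetical name, and `Node` is redefined with the notebook's constructor so the cell runs on its own.

```python
from typing import Optional, Tuple


class Node:
    # Same constructor shape as the notebook's Node, redefined for self-containment.
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def count_unival_subtrees_fast(root: Optional[Node]) -> int:
    """Count unival subtrees in one post-order pass (O(n) time)."""

    def visit(node: Optional[Node]) -> Tuple[int, bool]:
        # Returns (unival subtrees in this subtree, is this subtree unival?).
        if node is None:
            return 0, True  # an empty subtree never blocks its parent
        left_count, left_ok = visit(node.left)
        right_count, right_ok = visit(node.right)
        total = left_count + right_count
        # A unival child whose root matches node.value is entirely node.value.
        is_unival = (
            left_ok
            and right_ok
            and (node.left is None or node.left.value == node.value)
            and (node.right is None or node.right.value == node.value)
        )
        if is_unival:
            total += 1
        return total, is_unival

    return visit(root)[0]


example = Node(0, Node(1), Node(0, Node(1, Node(1), Node(1)), Node(0)))
print(count_unival_subtrees_fast(example))  # → 5
```

Both this sketch and the original are recursive. A very deep, path-shaped tree can therefore exceed Python's recursion limit (see `sys.getrecursionlimit()`). An explicit stack would avoid that, at the cost of more bookkeeping.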