# In [105]:
from collections import deque
from functools import reduce

# Generic Tree
class Node:
    """A node of a generic (n-ary) tree.

    Attributes:
        id: identifier that tree searches compare against.
        children: list of child nodes, in the order they were added.
    """

    # NOTE: the parameter keeps the name ``id`` (shadowing the builtin) so
    # existing keyword callers keep working.
    def __init__(self, id):
        self.id = id
        self.children = []

    def add_child(self, node):
        """Append ``node`` to this node's list of children."""
        self.children.append(node)

    def __repr__(self):
        """Return the node's id as a string.

        ``__repr__`` must return a ``str``; converting explicitly keeps
        ``repr``/``str``/printing working for non-string ids such as ints.
        """
        return str(self.id)

class Tree:
    """A generic (n-ary) tree of Node objects.

    ``root`` is None for an empty tree; every query method handles that case.
    """

    def __init__(self):
        self.root = None

    def add(self, node_id, parent=None):
        """Create a Node with ``node_id`` and attach it to the tree.

        Args:
            node_id: id for the new node.
            parent: node to attach the new node under. When omitted (None) the
                new node becomes the root, replacing any existing root.

        Returns:
            The newly created Node.
        """
        node = Node(node_id)
        if parent is None:
            self.root = node
        else:
            parent.add_child(node)
        return node

    def bfs(self, target_id):
        """Breadth-first search for a node whose id equals ``target_id``.

        Returns:
            The first matching node in level order, or None if no node matches
            or the tree is empty.
        """
        if self.root is None:
            return None

        # deque pops from the left in O(1); list.pop(0) would be O(n).
        queue = deque([self.root])
        while queue:
            curr = queue.popleft()
            # Compare with ==, not `is`: equal ids need not be the same object.
            if curr.id == target_id:
                return curr
            queue.extend(curr.children)
        return None

    def dfs(self, target_id, curr=None):
        """Depth-first (pre-order) search for a node whose id equals ``target_id``.

        Args:
            target_id: id to look for.
            curr: subtree to search; defaults to the root.

        Returns:
            The first matching node in pre-order, or None if no node matches
            or the tree is empty.
        """
        curr = self.root if curr is None else curr
        if curr is None:
            return None

        # Compare with ==, not `is`: equal ids need not be the same object.
        if curr.id == target_id:
            return curr

        for child in curr.children:
            found = self.dfs(target_id, child)
            if found is not None:
                return found
        return None

    def is_height_balanced(self, node=None):
        """Return True if the subtree at ``node`` (default: root) is height balanced.

        A node is balanced when all of its child subtrees are balanced and the
        difference between the tallest and shortest child subtree is at most 1.
        A node with a single child compares it against an empty sibling of
        height 0, so that child's subtree may be at most 1 level tall. Leaves
        and the empty tree are balanced.
        """
        node = self.root if node is None else node
        if node is None:
            return True
        return self._balanced_height(node) is not None

    def _balanced_height(self, node):
        """Return the height of a balanced subtree, or None if it is unbalanced.

        Single pass: each subtree's height is computed once instead of being
        recomputed at every ancestor.
        """
        if not node.children:
            return 1

        heights = []
        for child in node.children:
            child_height = self._balanced_height(child)
            if child_height is None:
                return None
            heights.append(child_height)

        # A lone child is compared against a missing (height 0) sibling.
        spread = max(heights) - min(heights) if len(heights) > 1 else heights[0]
        if spread > 1:
            return None
        return 1 + max(heights)

    def height(self, node=None):
        """Return the number of levels in the subtree at ``node`` (default: root).

        A leaf has height 1; an empty tree has height 0.
        """
        node = self.root if node is None else node
        if node is None:
            return 0
        if not node.children:
            return 1
        return 1 + max(self.height(child) for child in node.children)

    def to_string(self, curr, count=0):
        """Return a pre-order dump of the subtree at ``curr``.

        Each node contributes one newline-prefixed line, indented by one '-'
        per depth level below the starting node. An empty subtree (None)
        yields an empty string.
        """
        if curr is None:
            return ""
        parts = ["\n" + "-" * count + str(curr)]
        parts.extend(self.to_string(child, count + 1) for child in curr.children)
        return "".join(parts)

    def __repr__(self):
        return self.to_string(self.root)


# In [106]:
def test():
    """Smoke-test Tree: searches, height and the height-balance check."""
    tree = Tree()
    a = tree.add('a')
    b = tree.add('b', a)
    c = tree.add('c', a)
    g = tree.add('g', c)
    tree.add('h', g)
    tree.add('d', b)
    tree.add('e', b)
    tree.add('f', b)

    # Both searches must return the very node object that add() created.
    assert tree.dfs('g') is g
    assert tree.bfs('g') is g
    assert tree.dfs('a') is a
    # Ids that are not in the tree yield None rather than raising.
    assert tree.dfs('z') is None
    assert tree.bfs('z') is None

    # Longest root-to-leaf path a -> c -> g -> h has 4 nodes.
    assert tree.height() == 4

    balanced_tree = Tree()
    m = balanced_tree.add('m')
    balanced_tree.add('n', m)
    balanced_tree.add('o', m)

    assert balanced_tree.height() == 2
    # 'c' has a single child whose subtree (g -> h) is 2 levels tall, so
    # `tree` is not height balanced.
    assert not tree.is_height_balanced()
    assert balanced_tree.is_height_balanced()

    print('tests pass')


test()

# In [19]:
# Binary tree classes
class BinaryTreeNode:
    """A binary tree node: a value plus optional left and right subtrees."""

    def __init__(self, val, left=None, right=None):
        # A child left as None means there is no subtree on that side.
        self.val, self.left, self.right = val, left, right

    def __repr__(self):
        """Show the node as its value converted with ``str``."""
        return str(self.val)
    
class BinaryTree:
    """A binary tree identified by its root node (None for an empty tree)."""

    def __init__(self, root):
        self.root = root

    def is_valid(self):
        """Return True if the tree is a valid binary search tree.

        Every node's value must be strictly greater than all values in its
        left subtree and strictly less than all values in its right subtree,
        so duplicate values make the tree invalid. An empty tree is valid.
        There is no numeric range limit: any mutually orderable values work.
        """
        # None means "unbounded"; the previous hard-coded +/-2**32 bounds
        # wrongly rejected valid trees holding larger values.
        return self.is_valid_helper(self.root, None, None)

    def is_valid_helper(self, node, minv=None, maxv=None):
        """Return True if the subtree at ``node`` is a BST within (minv, maxv).

        Args:
            node: subtree root, or None for an empty subtree (always valid).
            minv: exclusive lower bound on every value, or None for no bound.
            maxv: exclusive upper bound on every value, or None for no bound.
        """
        # Explicit stack instead of recursion, so degenerate (list-shaped)
        # trees deeper than the recursion limit do not raise RecursionError.
        pending = [(node, minv, maxv)]
        while pending:
            curr, low, high = pending.pop()
            if curr is None:
                continue
            if high is not None and not curr.val < high:
                return False
            if low is not None and not curr.val > low:
                return False
            # Left descendants must stay below curr.val; right ones above it.
            pending.append((curr.left, low, curr.val))
            pending.append((curr.right, curr.val, high))
        return True

    def to_string(self, curr, count=0):
        """Return a pre-order dump of the subtree at ``curr``.

        Each node contributes one newline-prefixed line, indented by one '-'
        per depth level; the left child is listed before the right. An empty
        subtree (None) yields an empty string.
        """
        if curr is None:
            return ""
        parts = ["\n" + "-" * count + str(curr)]
        for child in (curr.left, curr.right):
            if child is not None:
                parts.append(self.to_string(child, count + 1))
        return "".join(parts)

    def __repr__(self):
        return self.to_string(self.root)

# In [20]:
def test_binary():
    """Smoke-test BinaryTree.is_valid on one valid and two invalid BSTs."""
    # 1 has 2 as its *left* child (2 > 1), so this tree is not a BST.
    t5 = BinaryTreeNode(5)
    t4 = BinaryTreeNode(4, None, t5)
    t3 = BinaryTreeNode(3, t4, None)
    t2 = BinaryTreeNode(2)
    t1 = BinaryTreeNode(1, t2, t3)
    tree = BinaryTree(t1)

    print(tree)

    assert not tree.is_valid()

    # Valid BST: 5 with left child 2 (children 1 and 3) and right child 6.
    b6 = BinaryTreeNode(6)
    b1 = BinaryTreeNode(1)
    b3 = BinaryTreeNode(3)
    b2 = BinaryTreeNode(2, b1, b3)
    b5 = BinaryTreeNode(5, b2, b6)

    # Root at b5 (not the leaf b1) so the whole structure is validated
    # rather than a single node.
    bin_tree = BinaryTree(b5)

    assert bin_tree.is_valid()

    # Every parent/child pair is locally ordered, but 6 sits in the right
    # subtree of 10 while being smaller than 10, so the tree is invalid.
    bad6 = BinaryTreeNode(6)
    bad20 = BinaryTreeNode(20)
    bad15 = BinaryTreeNode(15, bad6, bad20)
    bad5 = BinaryTreeNode(5)
    bad10 = BinaryTreeNode(10, bad5, bad15)

    bad_tree = BinaryTree(bad10)
    assert not bad_tree.is_valid()

    print('tests pass')


test_binary()


# Output of the cell above:
# 1
# -2
# -3
# --4
# ---5
# tests pass
