# Compare Binary Trees

Interview Question #11

Write an efficient algorithm that compares two binary search trees. The method returns `True` if the trees are identical (same topology, with equal values in corresponding nodes); otherwise it returns `False`.

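Algorithm:
1. Traverse both trees simultaneously, visiting corresponding nodes in the same order.
2. Compare each pair of corresponding nodes: if exactly one of them is missing, or their values differ, the trees are not identical.
3. If every pair matches, the trees are identical. This takes O(N) time for N nodes.

Note: comparing the two in-order sequences alone is not enough. Every BST built from the same set of keys yields the same sorted in-order sequence, regardless of its shape.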

In [1]:
# Node: one vertex of the binary tree
class Node:
    """A single vertex of a binary tree.

    Holds a payload in ``data`` plus links to at most two children,
    ``leftChild`` and ``rightChild``; both links start out as ``None``.
    """

    def __init__(self, data):
        self.data = data
        # No children yet; callers attach them after construction.
        self.leftChild, self.rightChild = None, None

In [32]:
# Implementing BST Class
class BinarySearchTree():
    """Unbalanced binary search tree built from ``Node`` objects.

    Smaller values go into ``leftChild`` subtrees and larger values into
    ``rightChild`` subtrees. Inserting a value that is already present is
    a no-op.
    """

    def __init__(self):
        # Empty tree: no root node yet.
        self.root = None

    def insert(self, data):
        """Insert *data* into the tree; values already present are ignored."""
        if self.root is None:
            self.root = Node(data)
        else:
            self.insertNode(data, self.root)

    def insertNode(self, data, node):
        """Insert *data* into the subtree rooted at *node*.

        O(log N) when the tree is balanced, O(N) when it is not. Walks down
        iteratively so list-shaped trees (e.g. built from sorted input)
        cannot exceed Python's recursion limit.
        """
        while True:
            if data < node.data:
                if node.leftChild is None:
                    node.leftChild = Node(data)
                    return
                node = node.leftChild
            elif data > node.data:
                if node.rightChild is None:
                    node.rightChild = Node(data)
                    return
                node = node.rightChild
            else:
                # Equal value already stored: duplicates are not inserted.
                return

    def remove(self, data):
        """Remove the node holding *data*, if present.

        ``removeNode`` returns the (possibly new) subtree root, which is
        stored back into ``self.root`` so that deleting the root works.
        """
        if self.root is not None:
            self.root = self.removeNode(data, self.root)

    def removeNode(self, data, node):
        """Remove *data* from the subtree rooted at *node*.

        Returns the root of the updated subtree (``None`` if the subtree
        became empty); callers must store it back into the parent's link.
        Prints a message describing which removal case was applied.
        """
        if node is None:
            # Value not present in this subtree.
            return None

        if data < node.data:
            # Item we're looking for is in the left subtree
            node.leftChild = self.removeNode(data, node.leftChild)
        elif data > node.data:
            # Item we're looking for is in the right subtree
            node.rightChild = self.removeNode(data, node.rightChild)
        else:
            # Node we're at is the node to be removed: 3 cases.
            if node.leftChild is None and node.rightChild is None:
                print('Removing leaf node')
                return None
            if node.leftChild is None:
                print('Removing node with single right child')
                return node.rightChild
            if node.rightChild is None:
                print('Removing node with single left child')
                return node.leftChild

            # Two children: copy the in-order predecessor (largest value of
            # the left subtree) into this node, then delete it from there.
            print('Removing node with two children')
            predecessor = self.getPredecessor(node.leftChild)
            node.data = predecessor.data
            node.leftChild = self.removeNode(predecessor.data, node.leftChild)

        # Every path that did not return above keeps this node as the root.
        return node

    def getPredecessor(self, node):
        """Return the rightmost (largest) node of the subtree rooted at *node*."""
        while node.rightChild is not None:
            node = node.rightChild
        return node

    def getMinValue(self):
        """Return the smallest value in the tree, or ``None`` if it is empty."""
        if self.root is not None:
            return self.getMin(self.root)
        return None

    def getMin(self, node):
        """Return the smallest value in the subtree rooted at *node*."""
        while node.leftChild is not None:
            node = node.leftChild
        return node.data

    def getMaxValue(self):
        """Return the largest value in the tree, or ``None`` if it is empty."""
        if self.root is not None:
            return self.getMax(self.root)
        return None

    def getMax(self, node):
        """Return the largest value in the subtree rooted at *node*."""
        while node.rightChild is not None:
            node = node.rightChild
        return node.data

    def traverse(self):
        """Print every value in the tree in ascending (in-order) order."""
        if self.root is not None:
            self.traverseInOrder(self.root)

    def traverseInOrder(self, node):
        """Print the subtree rooted at *node* in order: left, root, right.

        Uses an explicit stack rather than recursion so deep, unbalanced
        trees do not hit Python's recursion limit.
        """
        stack = []
        while stack or node is not None:
            if node is not None:
                # Descend as far left as possible, remembering the path.
                stack.append(node)
                node = node.leftChild
            else:
                node = stack.pop()
                print(node.data)
                node = node.rightChild

In [35]:
class TreeComparator():
    """Structural equality check for two binary trees."""

    def compareTrees(self, node1, node2):
        """Return True if the trees rooted at *node1* and *node2* are identical.

        Identical means the same shape and equal values (compared with
        ``==``) in every pair of corresponding nodes; two empty trees
        (``None``) are identical. Nodes must expose ``data``, ``leftChild``
        and ``rightChild``. Runs in O(N) time and uses an explicit stack,
        so deep, unbalanced trees do not hit Python's recursion limit.
        """
        pending = [(node1, node2)]
        while pending:
            first, second = pending.pop()
            if first is None or second is None:
                # Both must be missing; exactly one missing means the
                # shapes differ.
                if first is not second:
                    return False
                continue
            # Compare by value: `is` would test object identity and wrongly
            # reject equal-but-distinct objects (large ints, strings, floats).
            if first.data != second.data:
                return False
            # The left and right subtrees must match as well.
            pending.append((first.rightChild, second.rightChild))
            pending.append((first.leftChild, second.leftChild))
        return True
            

In [36]:
# Build two trees from the same insertion sequence and a third from only the
# first four values, then compare them structurally.
insertion_order = (4, 5, 10, 25, 3, 2, 1)

bst = BinarySearchTree()
for value in insertion_order:
    bst.insert(value)

bst2 = BinarySearchTree()
for value in insertion_order:
    bst2.insert(value)

bst3 = BinarySearchTree()
for value in insertion_order[:4]:
    bst3.insert(value)

comparator = TreeComparator()
print(comparator.compareTrees(bst.root, bst2.root))  # same inserts, same shape
print(comparator.compareTrees(bst.root, bst3.root))  # bst3 lacks 3, 2 and 1

True
False
