# Subtree Pattern Matching
The subtree pattern matching problem can be described as follows. 
*Given an input tree and another pattern tree, print the occurrences of the pattern tree in the input tree.* 

Let's consider the following example:
- Input Tree 
<img src="1.png" width="600" height="600" />
- Pattern Tree
<img src="4.png" width="200" height="200" />

From the example, we can see that there are two occurrences of the pattern.
<img src="2.png" width="600" height="600" />
##### Analysing the Subtree Pattern Matching Problem

- As the goal is to print every occurrence of the pattern in the input tree, in the worst case there can be an exponential number of matches.
- This is because the number of (unordered, rooted) subtrees of a tree can be exponential in the number of nodes. For example, a node with $n$ leaf children is the root of $2^n$ different subtrees, one for each subset of its children.
- This suggests a brute-force solution: traverse the tree, generate every subtree, and check each generated subtree against the pattern.
- No algorithm can run in time polynomial in the input size alone, because every match has to be printed and the number of matches can itself be exponential. The running time is therefore lower-bounded by the size of the output.
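
To make the output-size lower bound concrete, consider a hypothetical star-shaped input: a root labelled `a` with `n` leaf children all labelled `b`, and a pattern made of a root `a` with `k` leaf children `b`. Every choice of `k` leaves is a distinct occurrence, so there are C(n, k) matches, which for `k = n/2` grows exponentially in `n`. The sketch below (the star and the helper `count_star_matches` are illustrative assumptions, not part of the notebook; `math.comb` needs Python 3.8+) enumerates those choices explicitly.

```python
from itertools import combinations
from math import comb


def count_star_matches(n, k):
    """Count occurrences of a root with k identical leaf children inside a
    star whose root has n identical leaf children (hypothetical example).

    Each occurrence is determined by which k of the n leaves it uses.
    """
    leaves = range(n)
    return sum(1 for _ in combinations(leaves, k))


for n in (4, 10, 20):
    k = n // 2
    # Enumerated count next to the closed form C(n, k)
    print(n, k, count_star_matches(n, k), comb(n, k))
```

Any algorithm that prints every occurrence must spend at least one step per occurrence, so no amount of cleverness avoids this growth.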

##### Solving the Subtree Pattern Matching Problem

We can come up with the following recurrence. 
***The subtrees starting from any node can be computed by using subtrees of its children in combination and passing via this node.***
<img src="3.png" width="600" height="600" />
From the diagram above, node 1 has three children: 2, 3 and 4. To get all the subtrees starting from node 1, we first compute the subtrees starting from 2, 3 and 4, and then combine them to compute the answer for node 1.

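> That is done using the following idea.
The subtrees starting from 1 include all possible combinations of subtrees starting from its children, joined through node 1. Let $\mathrm{num}(i)$ denote the number of subtrees starting from node $i$. Then
$$\mathrm{num}(1) = \mathrm{num}(2) + \mathrm{num}(3) + \mathrm{num}(4) + \mathrm{num}(2)\,\mathrm{num}(3) + \mathrm{num}(2)\,\mathrm{num}(4) + \mathrm{num}(3)\,\mathrm{num}(4) + \mathrm{num}(2)\,\mathrm{num}(3)\,\mathrm{num}(4) + 1$$
where the final term 1 counts the subtree consisting of node 1 alone.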
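
Grouping the terms shows that the right-hand side is the expansion of $(1+\mathrm{num}(2))(1+\mathrm{num}(3))(1+\mathrm{num}(4))$: each child is either left out or attached through one of its own subtrees. In general, $\mathrm{num}(v) = \prod_{c} (1 + \mathrm{num}(c))$ over the children $c$ of $v$, with $\mathrm{num}(v) = 1$ for a leaf. The sketch below (Python 3.8+ for `math.prod`; the names `children`, `num` and `num_by_subsets` are illustrative) evaluates both forms on the input tree defined later in this notebook.

```python
from functools import lru_cache
from itertools import combinations
from math import prod

# Input tree from the notebook, as an adjacency list (node -> children)
children = {
    '1': ['2', '3', '4'], '2': ['10', '11'], '3': ['5', '6'],
    '4': ['7', '8', '9'], '5': ['19'], '6': ['12'], '7': [],
    '8': ['13', '14'], '9': [], '10': ['15'], '11': ['16', '17', '18'],
    '12': [], '13': [], '14': [], '15': [], '16': [], '17': [],
    '18': [], '19': [],
}


@lru_cache(maxsize=None)
def num(v):
    """Number of subtrees starting from v: each child is either left out
    or attached through one of its own num(child) subtrees."""
    return prod(1 + num(c) for c in children[v])


def num_by_subsets(v):
    """Same count written as in the equation above: sum over every subset
    of v's children of the product of their counts (the empty subset
    contributes the 1 for v on its own)."""
    kids = children[v]
    total = 0
    for r in range(len(kids) + 1):
        for subset in combinations(kids, r):
            total += prod(num_by_subsets(c) for c in subset)
    return total


print(num('1'), num_by_subsets('1'))
```

On this tree both forms give 5880 subtrees starting from node 1 (num(2) = 27, num(3) = 9, num(4) = 20, and 28 · 10 · 21 = 5880), which shows how quickly the count grows even for 19 nodes.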

The equation says:
- Every combination of subtrees taken from the current node's children has to be joined through the current node.
- Each child is either included or left out, so with three children there are 2 × 2 × 2 = 8 combinations, i.e. all the subsets of {2, 3, 4} (the empty subset gives the current node on its own). In code, these subsets can be enumerated with bitmasks.
- So, once the subtrees of all the children are computed, the subtrees of each chosen child are combined with the subtrees of its chosen siblings (the current node's other children).
- This requires a merge step that joins the subtrees of the subset under consideration through the current node.
- Together, these give all the subtrees starting from the current node.

Since we have formulated the recurrence for any node, we can extend this for all nodes in the tree and the answer is a union of all the subtrees starting from each node.
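
A compact sketch of this recurrence and of the final union, under the simplifying assumption that a subtree is represented only by the frozenset of its nodes (the notebook's code below instead keeps an adjacency list per subtree). Each bitmask over a node's children selects a subset (bit `j` set means child `j` is attached), `itertools.product` performs the merge step by picking one subtree from every chosen child, and the current node is added to each combination. The small tree and all names here are hypothetical.

```python
from functools import lru_cache
from itertools import product

# Hypothetical tree: node 1 has children 2, 3, 4; node 3 has child 5
children = {'1': ['2', '3', '4'], '2': [], '3': ['5'], '4': [], '5': []}


@lru_cache(maxsize=None)
def subtrees(v):
    """All subtrees starting from v, each as a frozenset of its nodes."""
    kids = children[v]
    result = []
    for mask in range(1 << len(kids)):
        # Bit j of mask set  <=>  child kids[j] is attached to v
        chosen = [kids[j] for j in range(len(kids)) if mask & (1 << j)]
        # Merge step: pick one subtree from every chosen child, join via v
        for parts in product(*(subtrees(c) for c in chosen)):
            result.append(frozenset([v]).union(*parts))
    return tuple(result)  # immutable, so the cached value cannot be altered


# The answer is the union of the subtrees starting from every node
all_subtrees = set()
for v in children:
    all_subtrees.update(subtrees(v))

print(len(all_subtrees), sum(len(subtrees(v)) for v in children))
```

Every subtree has exactly one topmost node, so the per-node collections are disjoint: the union has 12 + 1 + 2 + 1 + 1 = 17 elements, and the script prints `17 17`.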

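> Once all the subtrees are retrieved, each one has to be matched against the pattern. Because the children of a node are unordered, each tree is reduced to a hash that does not depend on the order of the children. Python lists cannot be used for this: they are unhashable (`hash([1,2,3])` raises a `TypeError`). A `frozenset` is hashable and order-independent, so `hash(frozenset([1,2,3])) == hash(frozenset([2,3,1]))`; alternatively, the children's hashes can be sorted into a tuple, which, unlike a `frozenset`, keeps repeated children. In either case the node's own value has to be part of its hash, otherwise two subtrees with the same shape but different labels hash alike.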

So the final answer only requires comparing the hash of every subtree with the hash of the pattern.
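
A minimal sketch of an order-independent key, assuming a tree is given as an adjacency dict plus a label dict (as with the notebook's `pattern` and `valpattern`). Instead of a hash it builds a canonical nested tuple: the node's own label followed by the sorted canonical forms of its children. Equal keys mean the two labelled trees are identical up to child order, with no risk of hash collisions, and sorting (rather than using a `frozenset`) keeps repeated children. `canonical` and the sample subtrees are illustrative, not part of the notebook.

```python
def canonical(adj, labels, root):
    """Canonical form of the labelled, unordered tree starting from root.

    adj    : dict mapping node -> list of child nodes
    labels : dict mapping node -> label
    """
    kids = sorted(canonical(adj, labels, c) for c in adj[root])
    return (labels[root], tuple(kids))


# Pattern from the notebook
pattern = {'p': ['q', 'r'], 'q': [], 'r': ['s'], 's': []}
valpattern = {'p': 'a', 'q': 'b', 'r': 'c', 's': 'd'}

# Occurrence at node 1, listed with its children in both orders
sub1 = {'1': ['2', '3'], '2': [], '3': ['5'], '5': []}
sub1_reordered = {'1': ['3', '2'], '2': [], '3': ['5'], '5': []}
labels1 = {'1': 'a', '2': 'b', '3': 'c', '5': 'd'}

# Same shape as the pattern but with root label 'b': must NOT match
sub2 = {'2': ['10', '11'], '10': ['15'], '11': [], '15': []}
labels2 = {'2': 'b', '10': 'c', '11': 'b', '15': 'd'}

key = canonical(pattern, valpattern, 'p')
print(key)
print(canonical(sub1, labels1, '1') == key)            # order of children ignored
print(canonical(sub1_reordered, labels1, '1') == key)
print(canonical(sub2, labels2, '2') == key)            # root label differs
print(frozenset(['b', 'b']) == frozenset(['b']))       # a frozenset drops duplicates
```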


```python
# Set up
import copy
import pprint
pp = pprint.PrettyPrinter(indent=4)
```

```python
# Pattern tree as an adjacency list
pattern = {
    'p':['q','r'],
    'q':[],
    'r':['s'],
    's':[]
}
# Pattern node values
valpattern = {
    'p':'a',
    'q':'b',
    'r':'c',
    's':'d'
}
rootpattern = 'p'
```

```python
# Input tree as an adjacency list
tree = {
    '1':['2','3','4'],
    '2':['10','11'],
    '3':['5','6'],
    '4':['7','8','9'],
    '5':['19'],
    '6':['12'],
    '7':[],
    '8':['13','14'],
    '9':[],
    '10':['15'],
    '11':['16','17','18'],
    '12':[],
    '13':[],
    '14':[],
    '15':[],
    '16':[],
    '17':[],
    '18':[],
    '19':[]
}
# Tree node values
valtree={
    '1':'a',
    '2':'b',
    '3':'c',
    '4':'a',
    '5':'d',
    '6':'e',
    '7':'b',
    '8':'c',
    '9':'c',
    '10':'c',
    '11':'b',
    '12':'f',
    '13':'e',
    '14':'d',
    '15':'d',
    '16':'g',
    '17':'f',
    '18':'e',
    '19':'a'
}
roottree = '1'
```

```python
# Add the current node as the common root of every subtree in `trees`,
# with `treeroots` (the chosen children) as its children
def addToTree(trees, commonroot, treeroots):
    for subtree in trees:
        for treeroot in treeroots:
            if commonroot not in subtree:
                subtree[commonroot] = [treeroot]
            else:
                subtree[commonroot].append(treeroot)
    return trees

# Merge two lists of subtrees: every subtree of tree1 is combined with every subtree of tree2
def mergeTree(tree1, tree2):
    if len(tree1) == 0:
        return tree2
    if len(tree2) == 0:
        return tree1
    res = []
    for node1 in tree1:
        for node2 in tree2:
            # copy + update works in Python 2 and 3 (items() + items() is Python 2 only)
            merged = dict(node1)
            merged.update(node2)
            res.append(merged)
    return res

# Get the hash value for the pattern tree.
# The node's own value is combined with the sorted hashes of its children, so the
# result ignores child order but still distinguishes labels and repeated children.
def getHashPattern(tree, root):
    childhashes = sorted(getHashPattern(tree, child) for child in tree[root])
    return hash((valpattern[root], tuple(childhashes)))

# Get the hash value for a subtree of the input tree (same scheme as above)
def getHashTree(tree, root):
    childhashes = sorted(getHashTree(tree, child) for child in tree[root])
    return hash((valtree[root], tuple(childhashes)))
```

```python
# Traverse the tree to find all subtrees starting from each node;
# printMatchedSubtrees then hash-compares them with the pattern.
treesAtNode = {}
def dfs(node):
    # Memoised: the subtrees starting from each node are computed only once
    if node in treesAtNode:
        return treesAtNode[node]
    n = len(tree[node])
    allTrees = []
    # Each bitmask i selects a subset of the node's children
    for i in range(1 << n):
        res = []
        nbors = []
        if i == 0:
            # Empty subset: the node on its own
            res = [{node: []}]
        else:
            for j in range(n):
                if i & (1 << j):
                    res = mergeTree(res, dfs(tree[node][j]))
                    nbors.append(tree[node][j])
        addToTree(res, node, nbors)
        allTrees.extend(res)

    # Cache a copy, because the caller adds its own node to the returned subtrees
    treesAtNode[node] = copy.deepcopy(allTrees)
    return allTrees

def printMatchedSubtrees():
    cnt = 0
    patternHash = getHashPattern(pattern, rootpattern)  # computed once
    for node in tree:
        if node not in treesAtNode:
            continue
        for subtree in treesAtNode[node]:
            if getHashTree(subtree, node) == patternHash:
                pp.pprint(subtree)
                cnt = cnt + 1
    print(cnt)

dfs(roottree)
printMatchedSubtrees()
```

```
{   '1': ['2', '3'], '2': [], '3': ['5'], '5': []}
{   '14': [], '4': ['7', '8'], '7': [], '8': ['14']}
2
```
