In [1]:
import math

In [2]:
class Tree(object):
    """Generic n-ary tree node that knows its parent and its ancestor count.

    Instance attributes (all assigned in ``__init__``):
        name         -- label of the node; ``find`` compares against it.
        children     -- list of child ``Tree`` nodes in insertion order.
        parent       -- the ``Tree`` this node is attached to, or ``None``.
        parentsCount -- number of ancestors above this node (0 for a node
                        without a parent); shown by ``__repr__``.
    """

    def __init__(self, name='root', children=None):
        """Create a node and attach each node in ``children`` (if given)."""
        self.name = name
        self.children = []
        self.parent = None
        self.parentsCount = 0
        if children is not None:
            for child in children:
                self.add_child(child)

    def __repr__(self):
        return self.name + " orbits: " + str(self.parentsCount)

    def add_child(self, node):
        """Attach ``node`` as the last child of this node.

        ``parentsCount`` is refreshed for ``node`` and for everything already
        hanging below it, so attaching a pre-built subtree does not leave
        stale counts on its descendants.
        """
        assert isinstance(node, Tree)
        node.parent = self
        self.children.append(node)
        # Explicit stack instead of recursion: a fresh leaf costs O(1) here,
        # a pre-built subtree costs one visit per node.
        pending = [(node, self.parentsCount + 1)]
        while pending:
            current, count = pending.pop()
            current.parentsCount = count
            pending.extend((child, count + 1) for child in current.children)

    def traverseAndPrint(self):
        """Print this node and then every descendant, depth-first pre-order."""
        print(self)
        for c in self.children:
            # The call parentheses matter: without them nothing below the
            # starting node is ever printed.
            c.traverseAndPrint()

    def traverseUpAndPrint(self):
        """Print the names on the path from this node up to the root."""
        node = self
        while node is not None:
            print(node.name)
            node = node.parent

    def getTotalParents(self, count=0):
        """Return ``count`` plus the number of ancestors above this node."""
        node = self
        while node.parent is not None:
            count += 1
            node = node.parent
        return count

    def find(self, name):
        """Return the first node in this subtree (pre-order) whose name equals
        ``name``, or ``None`` when there is no such node."""
        if self.name == name:
            return self
        for c in self.children:
            nxt = c.find(name)
            if nxt is not None:
                return nxt
        return None

    def depth(self, depth=0):
        """Return ``depth`` plus the number of ancestors above this node."""
        return self.getTotalParents(depth)

    def getChildFromMap(self, m):
        """Grow the subtree below this node from a list of pairs.

        ``m`` is a sequence of indexable entries where ``p[0]`` is a parent
        name and ``p[1]`` is the name of a child to create under every node
        carrying that parent name. Children are created in the order their
        entries appear in ``m``; children that already exist are kept and are
        expanded as well.
        """
        # Index the entries by parent name once, instead of rescanning the
        # whole list for every node in the tree.
        by_parent = {}
        for p in m:
            by_parent.setdefault(p[0], []).append(p)

        pending = [self]
        while pending:
            node = pending.pop()
            for p in by_parent.get(node.name, ()):
                node.add_child(Tree(p[1]))
            pending.extend(node.children)

    def getOrbitCounts(self):
        """Return the sum of ancestor counts over this node and all of its
        descendants. The result is also stored in ``self.totalOrbits``."""
        total = 0
        # Carry each node's ancestor count down the traversal rather than
        # walking back up to the root for every single node.
        pending = [(self, self.getTotalParents())]
        while pending:
            node, ancestors = pending.pop()
            total += ancestors
            pending.extend((child, ancestors + 1) for child in node.children)
        self.totalOrbits = total
        return total
        

In [3]:
# Small hand-built tree used to try out find() and depth():
# '*' has the children '1', '2' and '+'; '+' has the children '3' and '4'.
tree = Tree(
    '*',
    [
        Tree('1'),
        Tree('2'),
        Tree('+', [Tree('3'), Tree('4')]),
    ],
)

In [4]:
# Look the node up by name first, then report how many ancestors it has.
node_four = tree.find('4')
print(node_four.depth())

2


In [5]:
# Read the pair list. The context manager closes the handle as soon as the
# text has been loaded.
with open("input.txt", "r") as file:
    # One entry per non-blank line, split on ')' into its fields. Skipping
    # blank lines (rather than unconditionally dropping the last entry) keeps
    # the final record when the file has no trailing newline.
    content = [line.split(')') for line in file.read().splitlines() if line.strip()]

# Keep a shallow snapshot of the parsed list so later cells can start again
# from it.
copy = content.copy()

In [6]:
def LCA(root, n1, n2):
    """Return the deepest node under ``root`` whose subtree contains both names.

    Args:
        root: node to search from; must provide ``children`` and ``find(name)``.
        n1, n2: names of the two nodes to look for.

    Returns:
        The lowest common ancestor node (which may be one of the two nodes
        itself), or ``None`` when ``root``'s subtree does not contain both.

    Note:
        Assumes node names are unique within the tree, so at most one child
        subtree can contain both names.
    """
    # Prefer an answer from deeper in the tree, and stop at the first child
    # that yields one instead of searching the remaining siblings as well.
    for child in root.children:
        deeper = LCA(child, n1, n2)
        if deeper is not None:
            return deeper

    # No child holds both names, so root is the answer iff it holds both itself.
    if root.find(n1) is not None and root.find(n2) is not None:
        return root
    return None

def distance(root, n1, n2):
    """Return the number of edges on the path between the nodes named
    ``n1`` and ``n2`` in the tree under ``root``.

    Computed as depth(n1) + depth(n2) - 2 * depth(lowest common ancestor).

    Raises:
        ValueError: if either name cannot be found under ``root``.
    """
    node1 = root.find(n1)
    node2 = root.find(n2)

    # Fail with a clear message instead of an AttributeError on None below.
    missing = [name for name, node in ((n1, node1), (n2, node2)) if node is None]
    if missing:
        raise ValueError("node(s) not found: " + ", ".join(str(name) for name in missing))

    lca = LCA(root, n1, n2)
    return node1.depth() + node2.depth() - 2 * lca.depth()

In [7]:
# Start from a fresh shallow copy of the saved snapshot of the parsed pairs.
content = copy.copy()

# Grow the tree downward from the 'COM' node using the parent/child pairs.
t = Tree('COM')
t.getChildFromMap(content)

# Sum of every node's ancestor count across the whole tree.
total_orbits = t.getOrbitCounts()
print(total_orbits)

# Edge count between 'YOU' and 'SAN', minus the one edge at each end of the
# path; this is the edge count between the nodes directly above them
# (assumes neither node is an ancestor of the other).
you_to_san = distance(t, 'YOU', 'SAN')
print(you_to_san - 2)

249308
349
