In [258]:
import re
from collections import Counter

In [2]:
# Load the puzzle input: one node description per line, surrounding blank space removed.
with open('input.txt') as input_file:
    raw_text = input_file.read()
data = raw_text.strip().splitlines()

In [378]:
class Node():
    '''Node in a tree graph that can compute (and cache) the sum of descendant weights.

    Attributes:
        id: the node's name.
        weight: this node's own weight.
        parent: the parent Node, or None for the root.
        children: list of child Nodes. While the input is being parsed this may
            temporarily hold child id strings until they are resolved to Nodes.
    '''

    def __init__(self, id, weight, parent=None, children=None):
        self.id = id
        self.weight = weight
        self.parent = parent
        self.children = children if children else []
        # Cached result of sum_children(); None means "not computed yet".
        self._children_sum = None

    @property
    def siblings(self):
        '''All children of this node's parent, including this node itself.

        Raises AttributeError when called on the root (its parent is None).
        '''
        return self.parent.children

    def sum_branch(self):
        '''Total weight of this node plus all of its descendants.'''
        return self.sum_children() + self.weight

    def sum_children(self):
        '''Total weight of all descendants, excluding this node's own weight.

        The result is cached on first call and never invalidated, so the tree
        should not be modified after sums have been computed.
        '''
        # Compare against None: a legitimately cached sum of 0 must be reused,
        # not mistaken for "not cached yet".
        if self._children_sum is not None:
            return self._children_sum

        # Each child's branch sum is its own weight plus its descendants.
        total = sum(child.sum_branch() for child in self.children)
        self._children_sum = total
        return total

In [379]:
# Parse lines of the form "name (weight)" or "name (weight) -> child1, child2".
# Child names are stored first and resolved to Node objects in a second pass,
# so that children defined later in the input can still be found.
nodes = {}
nodes_with_children = []

for row in data:
    head, _, tail = row.partition('->')
    # split() (no argument) tolerates repeated whitespace; avoids shadowing builtin `id`.
    node_id, weight_text = head.split()
    weight = int(weight_text.strip('()'))
    # An arrow followed by nothing means "no children", not a child named ''.
    child_ids = [name.strip() for name in tail.split(',')] if tail.strip() else None
    node = Node(node_id, weight, children=child_ids)
    if node.children:
        nodes_with_children.append(node)
    nodes[node_id] = node

# Second pass: replace child id strings with Node objects and set parent links.
for node in nodes_with_children:
    child_nodes = []
    for child_id in node.children:
        child = nodes[child_id]
        child.parent = node
        child_nodes.append(child)
    node.children = child_nodes

### Part 1: find the root (the only node without a parent)

In [380]:
# Start from any linked node and follow parent links upward; the node
# with no parent is the root of the tree.
root_node = nodes_with_children[0]
while root_node.parent:
    root_node = root_node.parent

In [381]:
# Part 1 answer: id of the node that has no parent.
root_node.id

'mwzaxaj'

### Part 2: find the corrected weight of the single node that unbalances the tree

In [382]:
def find_diff(node):
    '''Return the weight the single mis-weighted node should have.

    Starting at `node`, repeatedly descend into the child whose branch sum
    differs from its siblings'. When a node's own children are all balanced,
    that node is the culprit: its correct weight is the branch sum its
    siblings share, minus the weight of its descendants.

    Assumes exactly one node is wrong and that every unbalanced node has at
    least three children, so the odd value is unambiguous.

    Raises:
        ValueError: if the tree rooted at `node` is already balanced.
    '''
    target_sum = None  # branch sum shared by the siblings of the current node
    while True:
        branch_sums = [child.sum_branch() for child in node.children]
        counts = Counter(branch_sums).most_common()

        if len(counts) > 1:
            # Unbalanced: remember the common sum and descend into the odd child.
            target_sum = counts[0][0]
            odd_sum = counts[-1][0]
            node = node.children[branch_sums.index(odd_sum)]
        else:
            # Children balanced (or leaf): this node itself carries the wrong weight.
            if target_sum is None:
                raise ValueError('tree is already balanced')
            return target_sum - node.sum_children()

In [383]:
# Part 2 answer: corrected weight of the single node whose branch sum differs from its siblings'.
find_diff(root_node)

1219