You can build a dict with a technique similar to a list comprehension: write a generator expression (round brackets instead of square ones) and pass it to `dict()`. Each item produced must be a `(key, value)` pair; when a key needs several values, put them in a list or a tuple.

Examples: `dict((x, x**2) for x in range(10))`, 
`dict((x, [x**2, x**3]) for x in range(10))`

In [151]:
import re
import numpy as np
from functools import reduce

In [54]:
# split into dict with (weight, [above]) tuples 
def extract_program_information(file_name):
    """Parse a tower description file into a dict of programs.

    Each non-blank line is expected to look like ``name (weight)`` or
    ``name (weight) -> above1, above2``.

    Args:
        file_name: Path of the text file to read.

    Returns:
        A dict mapping each program name to a ``(weight, names_above)``
        tuple, where ``weight`` is an int and ``names_above`` is the list of
        names that follow ``->`` (empty when the line has no ``->`` part).

    Raises:
        ValueError: If a non-blank line does not match the expected format.
    """
    # Raw string: the pattern is full of backslash escapes meant for `re`.
    line_pattern = re.compile(r'(\w+) \((\d+)\)( -> ([\w, ]+))?')
    nodes = {}
    # The context manager closes the file even if parsing fails.
    with open(file_name) as input_file:
        for raw_line in input_file:
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. trailing ones at end of file).
                continue
            match = line_pattern.match(line)
            if match is None:
                raise ValueError(f"cannot parse program line: {line!r}")
            name, weight, _, above = match.groups()
            nodes[name] = (int(weight), above.split(', ') if above else [])
    return nodes



In [118]:
# Parse the test input file and the real input file into node dicts.
test_nodes  = extract_program_information('input_2017_07_test.txt')
real_nodes = extract_program_information('input_2017_07.txt')

## Part 1: find the bottom node

Key insight: the bottom node is the only program that does not appear in any other program's list of programs above it.

In [94]:
def find_bottom_program(nodes):
    """Return the name of the program that no other program lists above it.

    Args:
        nodes: Dict mapping program name to ``(weight, names_above)``.

    Returns:
        One name from ``nodes`` that does not occur in any ``names_above``
        list. If several names qualify, an arbitrary one is returned.

    Raises:
        KeyError: If every program appears in some ``names_above`` list
            (or ``nodes`` is empty), so there is nothing to pop.
    """
    # Collect every name that sits above some other program. A set
    # comprehension also copes with the case where no program holds anything,
    # which an initializer-less reduce() would reject.
    held_programs = {
        name_above
        for entry in nodes.values()
        for name_above in entry[1]
    }
    return (set(nodes) - held_programs).pop()

In [95]:
# Bottom program of the test input.
find_bottom_program(test_nodes)

'tknk'

In [119]:
# Bottom program of the real input.
find_bottom_program(real_nodes)

'svugo'

## Part 2 

In [203]:
# Recursive function to calculate weight 
def get_weight(name):
    """Return the total weight of ``name`` plus everything stacked above it.

    Reads the module-level ``nodes`` dict, which maps a program name to
    ``(weight, names_above)``.

    Args:
        name: Key of the program in ``nodes``.

    Returns:
        The program's own weight plus the recursively computed weight of
        each program listed above it.
    """
    weight_sum = nodes[name][0]
    for name_above in nodes[name][1]:
        # Recurse into this same function (the original called a
        # non-existent ``get_weights``).
        weight_sum += get_weight(name_above)
    return weight_sum

In [204]:
def is_balanced(name):
    """Return True when all towers directly above ``name`` weigh the same.

    Reads the module-level ``nodes`` dict and uses ``get_weight`` for the
    total weight of each program listed above ``name``.

    Args:
        name: Key of the program in ``nodes``.

    Returns:
        True if the programs above ``name`` have at most one distinct total
        weight, False otherwise. A program with nothing above it counts as
        balanced.
    """
    distinct_weights = set(get_weight(x) for x in nodes[name][1])
    # ``<= 1`` rather than ``== 1``: a leaf has no towers above it, so there
    # is nothing that could be out of balance.
    return len(distinct_weights) <= 1

In [205]:
def identify_unbalanced_branch(name):
    """Return the program above ``name`` whose tower weight is the outlier.

    The outlier is the tower whose total weight is farthest from the mean of
    all tower weights directly above ``name``; on a tie the first one wins.
    Reads the module-level ``nodes`` dict and uses ``get_weight``.
    """
    names_above = nodes[name][1]
    tower_weights = np.array([get_weight(candidate) for candidate in names_above])
    deviations = np.abs(tower_weights - np.mean(tower_weights))
    outlier_position = np.argmax(deviations)
    return names_above[outlier_position]

In [207]:
def find_correct_weight(nodes_to_use):
    """Return the weight the single wrong program needs to balance the tower.

    Starting from the bottom program, repeatedly step into the branch whose
    total weight differs from its siblings until a program whose own
    sub-towers are balanced is reached. That program's own weight is the one
    to correct.

    Side effect: rebinds the module-level ``nodes`` to ``nodes_to_use``,
    because the helper functions read that global.

    Args:
        nodes_to_use: Dict mapping program name to ``(weight, names_above)``.

    Returns:
        The corrected own weight of the offending program. The value is a
        float because the sibling average is computed with true division.

    Raises:
        ValueError: If the tower is already balanced at the bottom program,
            so there is no weight to correct.
    """
    global nodes
    nodes = nodes_to_use
    n = find_bottom_program(nodes)

    # Total weights of the offending program and its siblings; stays None
    # when the loop never runs (tower already balanced).
    sibling_weights = None
    while not is_balanced(n):
        sibling_weights = [get_weight(x) for x in nodes[n][1]]
        n = identify_unbalanced_branch(n)

    if sibling_weights is None:
        raise ValueError("tower is already balanced; no weight to correct")

    branch_weight = get_weight(n)
    # Average total weight of the other (correct) siblings.
    correct_branch_weight = (np.sum(sibling_weights) - branch_weight) / (len(sibling_weights) - 1)
    # Shift the program's own weight by the amount its branch is off.
    return nodes[n][0] + correct_branch_weight - branch_weight

In [210]:
# Corrected weight for the test input.
find_correct_weight(test_nodes)

60.0

In [211]:
# Corrected weight for the real input.
find_correct_weight(real_nodes)

1152.0