In [361]:
import json
from itertools import permutations

import numpy as np

In [363]:
def print_tree(n):
    """Render the subtree rooted at ``n`` in list notation, e.g. ``[[1, 2], 3]``.

    A node with a non-None ``value`` is a regular number and renders as
    ``str(value)``; otherwise it is a pair rendered from its two children.
    """
    if n.value is not None:
        return str(n.value)
    return '[{}, {}]'.format(print_tree(n.left), print_tree(n.right))


class Node:
    """One node of a snailfish-number binary tree.

    A regular number is a node whose ``value`` is set; a pair is a node whose
    ``value`` is None and whose halves live in ``left`` / ``right``. ``parent``
    links back up the tree (None for a root); it is kept up to date by the
    builders in this notebook but is not read by the functions shown here.
    """

    def __init__(self, parent):
        self.parent = parent
        self.left = self.right = None
        self.value = None

    def __repr__(self):
        # Regular numbers show their value; pairs show their full rendered subtree.
        body = print_tree(self) if self.value is None else self.value
        return 'node<{}>'.format(body)


def build_tree(x, parent=None):
    """Convert a nested two-element list (e.g. ``[[1, 2], 3]``) into a ``Node`` tree.

    Args:
        x: a pair; only ``x[0]`` and ``x[1]`` are read. Elements that are
            ``list`` instances become sub-pairs, anything else becomes a leaf value.
        parent: the node the new pair hangs under (None for a root).

    Returns:
        The ``Node`` for the pair ``x``.
    """
    node = Node(parent)

    def make_child(item):
        # Nested lists recurse into sub-pairs; any other item becomes a leaf.
        if type(item) is list:
            return build_tree(item, node)
        leaf = Node(node)
        leaf.value = item
        return leaf

    first, second = x[0], x[1]
    node.left, node.right = make_child(first), make_child(second)
    return node

    
# Build a deeply nested example; the repr should echo the same structure back.
t = build_tree(
    [[[[[9, 8], 1], 2], 3], 4]
)
t

node<[[[[[9, 8], 1], 2], 3], 4]>

In [364]:
def find_explode(n, d=0, prev=None, next_=None, exploded=None):
    """Locate the leftmost pair at depth 4 and its neighbouring regular numbers.

    The tree is walked left to right starting at depth ``d`` (the root is depth 0).
    The first pair met at exactly depth 4 while ``exploded`` is still None is chosen
    and not descended into.

    Returns:
        ``(prev, next_, exploded)``:
        - ``prev``: the last regular-number node seen before the chosen pair, or None.
        - ``next_``: the first regular-number node seen after it, or None.
        - ``exploded``: the chosen pair, or None if there is none. When it is None,
          ``prev`` ends up as the last leaf of the whole tree.

    The ``prev`` / ``next_`` / ``exploded`` arguments seed the search state.
    """
    def visit(node, depth):
        nonlocal prev, next_, exploded
        if node.value is not None:
            # Before the explosion point every leaf is the newest left neighbour;
            # after it, only the first leaf seen counts as the right neighbour.
            if exploded is None:
                prev = node
            elif next_ is None:
                next_ = node
            return
        if exploded is None and depth == 4:
            # Stop here: the pair's own numbers are not its neighbours.
            exploded = node
            return
        visit(node.left, depth + 1)
        visit(node.right, depth + 1)

    visit(n, d)
    return prev, next_, exploded

def explode(prev, next_, exploded):
    """Explode the pair ``exploded`` in place.

    Its left value is added to ``prev`` and its right value to ``next_`` (either
    may be None, in which case that value is dropped). The pair then becomes
    the regular number 0.
    """
    for neighbour, child in ((prev, exploded.left), (next_, exploded.right)):
        if neighbour is not None:
            neighbour.value += child.value
    exploded.left = exploded.right = None
    exploded.value = 0
    

# Single-explode examples: (input, expected tree after exactly one explode step).
# These used to be commented-out alternatives; now every one of them is checked.
EXPLODE_EXAMPLES = [
    ([[[[[9, 8], 1], 2], 3], 4], '[[[[0, 9], 2], 3], 4]'),
    ([7, [6, [5, [4, [3, 2]]]]], '[7, [6, [5, [7, 0]]]]'),
    ([[6, [5, [4, [3, 2]]]], 1], '[[6, [5, [7, 0]]], 3]'),
    ([[3, [2, [1, [7, 3]]]], [6, [5, [4, [3, 2]]]]],
     '[[3, [2, [8, 0]]], [9, [5, [4, [3, 2]]]]]'),
    ([[3, [2, [8, 0]]], [9, [5, [4, [3, 2]]]]],
     '[[3, [2, [8, 0]]], [9, [5, [7, 0]]]]'),
    ([[[[[1, 1], [2, 2]], [3, 3]], [4, 4]], [5, 5]],
     '[[[[0, [3, 2]], [3, 3]], [4, 4]], [5, 5]]'),
    ([[[[0, [4, 5]], [0, 0]], [[[4, 5], [2, 6]], [9, 5]]],
      [7, [[[3, 7], [4, 3]], [[6, 3], [8, 8]]]]],
     '[[[[4, 0], [5, 0]], [[[4, 5], [2, 6]], [9, 5]]], [7, [[[3, 7], [4, 3]], [[6, 3], [8, 8]]]]]'),
]

exploded_trees = []
for raw, expected in EXPLODE_EXAMPLES:
    tree = build_tree(raw)
    prev, next_, exploded = find_explode(tree)
    if exploded is not None:
        explode(prev, next_, exploded)
    assert print_tree(tree) == expected, (raw, print_tree(tree))
    exploded_trees.append(tree)

# Display the first example, as this cell did before.
t = exploded_trees[0]
t

node<[[[[0, 9], 2], 3], 4]>

In [365]:
def find_split(n):
    """Return the leftmost regular-number node with value >= 10, or None.

    The search is depth-first, left child before right child; missing (None)
    children are skipped.
    """
    if n.value is not None:
        return n if n.value >= 10 else None
    for child in (n.left, n.right):
        if child is None:
            continue
        hit = find_split(child)
        if hit is not None:
            return hit
    return None

def split(n):
    """Turn the regular number ``n`` into a pair in place.

    The new left half is ``value // 2`` and the right half is the remainder
    (``value - value // 2``), so odd values round the right half up.
    """
    half_down = n.value // 2
    half_up = n.value - half_down
    n.value = None

    def leaf(amount):
        child = Node(n)
        child.value = amount
        return child

    n.left = leaf(half_down)
    n.right = leaf(half_up)

# The 12 is the leftmost number >= 10, so one split turns it into [6, 6].
t = build_tree([[3, [2, [8, 0]]], [12, [5, [4, [3, 2]]]]])
splitted = find_split(t)
if splitted is not None:
    split(splitted)
t

node<[[3, [2, [8, 0]]], [[6, 6], [5, [4, [3, 2]]]]]>

In [381]:
def reduce(t):
    """Fully reduce the tree ``t`` in place and return it.

    Each step explodes the leftmost pair at depth 4 if there is one. Otherwise
    it splits the leftmost number >= 10. Steps repeat until neither applies.
    """
    while True:
        prev, next_, exploded = find_explode(t)
        if exploded is not None:
            explode(prev, next_, exploded)
            continue
        splitted = find_split(t)
        if splitted is None:
            return t
        split(splitted)

# Reduction example: [[[[4,3],4],4],[7,[[8,4],9]]] + [1,1], written as one pair.
t = build_tree([[[[[4, 3], 4], 4], [7, [[8, 4], 9]]], [1, 1]])
reduce(t)

node<[[[[0, 7], 4], [[7, 8], [6, 0]]], [8, 1]]>

In [382]:
def magnitude(n):
    """Return the magnitude of the tree rooted at ``n``.

    A regular number's magnitude is its value. A pair's magnitude is
    ``3 * magnitude(left) + 2 * magnitude(right)``.

    Raises:
        ValueError: if a pair node (``value`` is None) is missing a child.
            Previously this surfaced as an UnboundLocalError.
    """
    if n.value is not None:
        return n.value
    if n.left is None or n.right is None:
        raise ValueError('pair node is missing a child')
    return 3 * magnitude(n.left) + 2 * magnitude(n.right)
    
# Magnitude of a fully reduced example tree.
t = build_tree(
    [[[[8, 7], [7, 7]], [[8, 6], [7, 7]]], [[[0, 7], [6, 6]], [8, 7]]]
)
magnitude(t)

3488

In [383]:
def sum_tree(t1, t2):
    """Add two trees without reducing.

    Returns a new root pair with ``t1`` as its left half and ``t2`` as its
    right half, re-parenting both. The inputs are reused, not copied.
    """
    root = Node(None)
    root.left, root.right = t1, t2
    for half in (t1, t2):
        half.parent = root
    return root

In [384]:
# Example inputs. Each is named so none is silently overwritten;
# `test` picks the one parsed into `numbers` below.
test_single_sum = """
[[[[4,3],4],4],[7,[[8,4],9]]]
[1,1]
""".strip()

test_pair_list = """
[1,1]
[2,2]
[3,3]
[4,4]
[5,5]
[6,6]
""".strip()

test_homework = """
[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]
""".strip()

test = test_homework

# Snailfish numbers are valid JSON arrays, so parse them without eval().
numbers = [json.loads(line) for line in test.splitlines() if line.strip()]

In [385]:
# Puzzle input: one snailfish number per line.
INPUT_PATH = 'input.txt'

with open(INPUT_PATH, 'r') as f:
    input_ = f.read()

# Parse with json.loads rather than eval() so the file contents cannot execute
# code; skip blank lines (e.g. stray trailing newlines).
numbers = [json.loads(line) for line in input_.splitlines() if line.strip()]

In [389]:
# Part 1: add the numbers left to right, fully reducing after each addition,
# then report the magnitude of the final sum.
t = build_tree(numbers[0])
for raw in numbers[1:]:
    n = build_tree(raw)
    t = reduce(sum_tree(t, n))
magnitude(t)


4116

# Part 2

In [390]:
# Part 2: largest magnitude of x + y over ordered pairs of *different* input
# lines. Addition is not commutative, so both orders are tried.
# Trees are rebuilt for every pair because reduce() mutates its argument in place.
max_t = None
max_m = -1
for nn1, nn2 in permutations(numbers, 2):
    t = reduce(sum_tree(build_tree(nn1), build_tree(nn2)))
    m = magnitude(t)
    if m > max_m:
        max_m, max_t = m, t

In [391]:
# Part 2 answer: the largest magnitude found over all ordered pairs.
max_m

4638