In [118]:
class Snail:
    def __init__(self, left=None, right=None, parent=None):
        self.left = left
        self.right = right
        self.parent = parent
    
    def __repr__(self):
        return f'[{self.left},{self.right}]'

    def __add__(self, other):
        ans = Snail(self, other)
        self.parent = ans
        other.parent = ans
        ans.reduce()
        return ans
    
    def __eq__(self, value: object) -> bool:
        return isinstance(value, Snail) and self.left == value.left and self.right == value.right
    
    def is_left(self) -> bool:
        return self.parent and self.parent.left is self
    
    def is_right(self) -> bool:
        return self.parent and self.parent.right is self
    
    def incr_left_most_child(self, value):
        cur = self
        while isinstance(cur.left, Snail):
            cur = cur.left
        cur.left += value
    
    def incr_right_most_child(self, value):
        cur = self
        while isinstance(cur.right, Snail):
            cur = cur.right
        cur.right += value

    def explode(self, level) -> bool:
        
        if level >= 4 and isinstance(self.left, int) and isinstance(self.right, int):
            le, ri = self.left, self.right
            if self.is_left():
                p = self.parent
                while p.is_left():
                    p = p.parent
                if p.parent:
                    if isinstance(p.parent.left, Snail):
                        p.parent.left.incr_right_most_child(le)
                    else:
                        p.parent.left += le
                if isinstance(self.parent.right, Snail):
                    self.parent.right.incr_left_most_child(ri)
                else:
                    self.parent.right += ri
                self.parent.left = 0
            else:
                if isinstance(self.parent.left, Snail):
                    self.parent.left.incr_right_most_child(le)
                else:
                    self.parent.left += le
                p = self.parent
                while p.is_right():
                    p = p.parent
                if p.parent:
                    if isinstance(p.parent.right, Snail):
                        p.parent.right.incr_left_most_child(ri)
                    else:
                        p.parent.right += ri
                self.parent.right = 0
            return True
        else:
            return (isinstance(self.left, Snail) and self.left.explode(level + 1)) or (isinstance(self.right, Snail) and self.right.explode(level + 1))

    
    def split_number(self, num):
        left = num // 2
        right = num - left
        return Snail(left=left, right=right, parent=self)
    
    def split(self) -> bool:
        if isinstance(self.left, int) and self.left > 9:
            self.left = self.split_number(self.left)
            return True
        elif (isinstance(self.left, Snail) and self.left.split()):
            return True
        elif isinstance(self.right, int) and self.right > 9:
            self.right = self.split_number(self.right)
            return True
        elif (isinstance(self.right, Snail) and self.right.split()):
            return True
        else:
            return False

    def reduce(self):
        # print(self)
        while self.explode(0) or self.split():
            # print(self)
            pass
        return self
    
    def magnitude(self):
        if isinstance(self.left, int):
            le = self.left
        else:
            le = self.left.magnitude()
        if isinstance(self.right, int):
            ri = self.right
        else:
            ri = self.right.magnitude()
        return le * 3 + ri * 2

def parse_snail(s):
    dummy = Snail()
    cur =  dummy
    for c in s:
        if c == '[':
            new_snail = Snail(parent=cur)
            if cur.left is not None:
                cur.right = new_snail
            else:
                cur.left = new_snail
            cur = new_snail
        elif c == ',':
            continue
        elif c == ']':
            cur = cur.parent
        else:
            num = int(c)
            if cur.left is not None:
                cur.right = num
            else:
                cur.left = num
    dummy.left.parent = None
    return dummy.left

def sum_lines(lines):
    ans = parse_snail(lines[0])
    for line in lines[1:]:
        ans += parse_snail(line)
    return ans

from utils import read_lines

def part1(input_file):
    lines = read_lines(input_file)
    sn = sum_lines(lines)
    return sn.magnitude()

def part2(input_file):
    lines = read_lines(input_file)
    ans = 0
    for i in range(len(lines)):
        for j in range(len(lines)):
            if i != j:
                s1 = parse_snail(lines[i])
                s2 = parse_snail(lines[j])
                ans = max(ans, (s1 + s2).magnitude())
    return ans

In [110]:
part1('inputs/day18.txt')

4057

In [119]:
part2('inputs/day18.txt')

4683

In [109]:
assert parse_snail('[[[[[9,8],1],2],3],4]').reduce() == parse_snail('[[[[0,9],2],3],4]')
assert parse_snail('[7,[6,[5,[4,[3,2]]]]]').reduce() == parse_snail('[7,[6,[5,[7,0]]]]')
assert parse_snail('[[6,[5,[4,[3,2]]]],1]').reduce() == parse_snail('[[6,[5,[7,0]]],3]')
assert parse_snail('[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]').reduce() == parse_snail('[[3,[2,[8,0]]],[9,[5,[7,0]]]]]')
assert parse_snail('[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]').reduce() == parse_snail('[[3,[2,[8,0]]],[9,[5,[7,0]]]]')

In [108]:
ss = """  [[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
+ [7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
= [[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]

  [[[[4,0],[5,4]],[[7,7],[6,0]]],[[8,[7,7]],[[7,9],[5,0]]]]
+ [[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
= [[[[6,7],[6,7]],[[7,7],[0,7]]],[[[8,7],[7,7]],[[8,8],[8,0]]]]

  [[[[6,7],[6,7]],[[7,7],[0,7]]],[[[8,7],[7,7]],[[8,8],[8,0]]]]
+ [[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
= [[[[7,0],[7,7]],[[7,7],[7,8]]],[[[7,7],[8,8]],[[7,7],[8,7]]]]

  [[[[7,0],[7,7]],[[7,7],[7,8]]],[[[7,7],[8,8]],[[7,7],[8,7]]]]
+ [7,[5,[[3,8],[1,4]]]]
= [[[[7,7],[7,8]],[[9,5],[8,7]]],[[[6,8],[0,8]],[[9,9],[9,0]]]]

  [[[[7,7],[7,8]],[[9,5],[8,7]]],[[[6,8],[0,8]],[[9,9],[9,0]]]]
+ [[2,[2,2]],[8,[8,1]]]
= [[[[6,6],[6,6]],[[6,0],[6,7]]],[[[7,7],[8,9]],[8,[8,1]]]]

  [[[[6,6],[6,6]],[[6,0],[6,7]]],[[[7,7],[8,9]],[8,[8,1]]]]
+ [2,9]
= [[[[6,6],[7,7]],[[0,7],[7,7]]],[[[5,5],[5,6]],9]]

  [[[[6,6],[7,7]],[[0,7],[7,7]]],[[[5,5],[5,6]],9]]
+ [1,[[[9,3],9],[[9,0],[0,7]]]]
= [[[[7,8],[6,7]],[[6,8],[0,8]]],[[[7,7],[5,0]],[[5,5],[5,6]]]]

  [[[[7,8],[6,7]],[[6,8],[0,8]]],[[[7,7],[5,0]],[[5,5],[5,6]]]]
+ [[[5,[7,4]],7],1]
= [[[[7,7],[7,7]],[[8,7],[8,7]]],[[[7,0],[7,7]],9]]

  [[[[7,7],[7,7]],[[8,7],[8,7]]],[[[7,0],[7,7]],9]]
+ [[[[4,2],2],6],[8,7]]
= [[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]
"""

def check_all(s):
    lines = s.split('\n')
    for i in range(0, len(lines), 4):
        s1 = parse_snail(lines[i][2:])
        s2 = parse_snail(lines[i + 1][2:])
        s3 = parse_snail(lines[i + 2][2:])
        s_sum = s1 + s2
        if s_sum != s3:
            print(lines[i][2:])
            print(lines[i + 1][2:])
            print(lines[i + 2][2:])
            print(s_sum)
            print('-----')
check_all(ss)

In [111]:
example = """[1,1]
[2,2]
[3,3]
[4,4]"""
assert sum_lines(example.split('\n')) == parse_snail('[[[[1,1],[2,2]],[3,3]],[4,4]]')

example = """[1,1]
[2,2]
[3,3]
[4,4]
[5,5]"""
assert sum_lines(example.split('\n')) == parse_snail('[[[[3,0],[5,3]],[4,4]],[5,5]]')

example = """[1,1]
[2,2]
[3,3]
[4,4]
[5,5]
[6,6]"""
assert sum_lines(example.split('\n')) == parse_snail('[[[[5,0],[7,4]],[5,5]],[6,6]]')

example = """[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]"""
assert sum_lines(example.split('\n')) == parse_snail('[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]')

In [112]:
example = """[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]"""
sum_lines(example.split('\n'))

[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]

In [116]:
example = """[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]"""

lines = example.split('\n')
ans = 0
for i in range(len(lines)):
    for j in range(len(lines)):
        if i != j:
            s1 = parse_snail(lines[i])
            s2 = parse_snail(lines[j])
            ans = max(ans, (s1 + s2).magnitude())
ans

3993

In [66]:
example = """[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]"""
s = sum_lines(example.split('\n'))
print(s)
print(s.magnitude())

[[[[6,6],[6,7]],[[7,7],[7,7]]],[[[7,7],[8,7]],[[9,7],[0,7]]]]
4126


In [114]:
assert parse_snail('[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]').magnitude() == 3488