In [1]:
from __future__ import annotations
from collections import UserList, deque
from typing import Union, List
from pyterator import iterate
import enum

In [2]:
class Direction(enum.Enum):
    LEFT = 0
    RIGHT = 1
    
    def __invert__(self):
        return list(Direction)[1-self.value]

In [3]:
NEST_LEVEL = 4

class Snf(UserList):
    
    def __init__(self, data):
        # assume every ch is <10
        if isinstance(data, str):
            self.data = [int(x) if x.isdigit() else x 
                         for x in data]
        elif isinstance(data, list):
            self.data = data
        else:
            raise ValueError

            
    def __repr__(self) -> str:
        return "".join([str(x) for x in self])
    
    
    def __add__(self, rhs: Snailfish) -> Snf:
        snf = Snf(["["] + self.data + [","] + rhs.data + ["]"] )
        snf.reduce_()
        return snf
    
    
    def reduce_(self):
        while True:
            while start_end := self.get_nested_indices():
                self.explode_(*start_end)
                
            indices_bignums = self.indices_big_numbers()
            if not indices_bignums:
                break

            for i in indices_bignums:
                self.split_at_(i)
                if self.get_nested_indices():
                    break
                
                
    def reduce(self) -> Snf:
        snf = Snf(self.data.copy())
        snf.reduce_()
        return snf
    
    
    def get_nested_indices(self) -> Optional[tuple]:
        """The indices including the brackets"""
        stack = []
        for i, ch in enumerate(self):
            if ch == '[':
                stack.append(ch)
            elif ch == ']':
                stack.pop()

            if len(stack) == NEST_LEVEL+1: # 1 includes the outermost sq bracket
                # Lookahead
                j = i + 1
                while self[j] != ']':
                    j += 1
                return i, j

        return None

    
    def explode_(self, start, end):
        direction = self.get_nearest_direction(start, end)
        self.add_nearest_(start, end, direction)
        self.add_nearest_(start, end, ~direction)
        del self[start:end]
    

    def __delitem__(self, indices: Union[slice,int]):
        del self.data[indices]  # default behaviour
        if isinstance(indices, slice):
            self[indices.start] = 0
    
    
    def split_at_(self, idx):
        assert is_int(self[idx]) and self[idx] > 9
        
        left, addend = divmod(self[idx], 2)
        to_add = ["[", left, ",", left+addend, "]"]
        del self[idx]
        self.data = self.data[:idx] + to_add + self.data[idx:]
    
    
    def get_nearest_direction(self, start, end) -> Direction:
        assert start >= 1
        if self[start-1] == ',':
            return Direction.LEFT
        else:
            return Direction.RIGHT
        
        
    def add_nearest_(self, start, end, direction):

        if direction==Direction.RIGHT:
            begin, diff, tmp = end, +1, self[end-1]
        else:
            begin, diff, tmp = start, -1, self[start+1]

        i = begin+diff
        while 0<=i<len(self) and not is_int(self[i]):
            i += diff

        if not 0<=i<len(self):
            return
        
        self[i] += tmp
        
        
    def indices_big_numbers(self) -> list:
        idx = [i for i,x in enumerate(self) if is_int(x) and x > 9]
        idx = [x+i*4 for i,x in enumerate(idx)]
        return idx

In [4]:
def is_int(x):
    return isinstance(x, int)

In [None]:
assert Snf(["[", 15, ",", 1, "]"]).reduce() == Snf("[[7,8],1]")
assert Snf(["[", 14, ",", 1, "]"]).reduce() == Snf("[[7,7],1]")

In [None]:
Snf(["[", 15, ",", 1, "]"]).indices_big_numbers()

In [None]:
snf = Snf("[[[[0,7],4],[15,[0,13]]],[1,1]]")
del snf[13]
snf[13] = 15


In [None]:
inputs = [("""[1,1]
[2,2]
[3,3]
[4,4]""", "[[[[1,1],[2,2]],[3,3]],[4,4]]"),
("""[1,1]
[2,2]
[3,3]
[4,4]
[5,5]""", "[[[[3,0],[5,3]],[4,4]],[5,5]]"),
("""[1,1]
[2,2]
[3,3]
[4,4]
[5,5]
[6,6]""", "[[[[5,0],[7,4]],[5,5]],[6,6]]")]

In [None]:
for inp, expected in inputs:
    assert iterate(inp.split()).map(Snf).reduce(lambda a,b: a+b) == Snf(expected)

In [None]:
assert Snf("[[[[[9,8],1],2],3],4]").reduce() == Snf("[[[[0,9],2],3],4]")
assert Snf("[7,[6,[5,[4,[3,2]]]]]").reduce() == Snf("[7,[6,[5,[7,0]]]]")
assert Snf("[[6,[5,[4,[3,2]]]],1]").reduce() == Snf("[[6,[5,[7,0]]],3]")
assert Snf("[[3,[2,[1,[7,3]]]],[6,[5,[4,[3,2]]]]]").reduce() == Snf("[[3,[2,[8,0]]],[9,[5,[7,0]]]]")
assert Snf("[[3,[2,[8,0]]],[9,[5,[4,[3,2]]]]]").reduce() == Snf("[[3,[2,[8,0]]],[9,[5,[7,0]]]]")

In [None]:
assert Snf("[[[[4,3],4],4],[7,[[8,4],9]]]") + Snf("[1,1]") == Snf("[[[[0,7],4],[[7,8],[6,0]]],[8,1]]")

In [None]:
inp = """[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]"""

In [None]:
assert iterate(inp.split()).map(Snf).reduce(lambda a,b: a+b) == Snf("[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]")

In [5]:
def magnitude(pair: List[int]):
    
    if is_int(pair):
        return pair
    
    lhs, rhs = pair
    return 3*magnitude(lhs) + 2*magnitude(rhs)

In [None]:
assert magnitude([9,1]) == 29
assert magnitude([1,9]) == 21
assert magnitude([[9,1],[1,9]]) == 129
assert magnitude([[1,2],[[3,4],5]]) == 143
assert magnitude([[[[0,7],4],[[7,8],[6,0]]],[8,1]])
assert magnitude([[[[1,1],[2,2]],[3,3]],[4,4]]) == 445
assert magnitude([[[[3,0],[5,3]],[4,4]],[5,5]]) == 791
assert magnitude([[[[5,0],[7,4]],[5,5]],[6,6]]) == 1137
assert magnitude([[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]) == 3488


In [None]:
inp = """[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]
[[[5,[2,8]],4],[5,[[9,9],0]]]
[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]
[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]
[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]
[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]
[[[[5,4],[7,7]],8],[[8,3],8]]
[[9,3],[[9,9],[6,[4,9]]]]
[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]
[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]"""

In [None]:
answer = iterate(inp.split()).map(Snf).reduce(lambda a,b: a+b)
answer

In [None]:
assert answer == Snf("[[[[6,6],[7,6]],[[7,7],[7,0]]],[[[7,7],[7,7]],[[7,8],[9,9]]]]")

In [None]:
answer = eval(str(answer))
answer

In [None]:
magnitude(answer)

In [6]:
inp = """[[[[3,9],[0,5]],[4,6]],3]
[[[8,[3,0]],[[8,4],[9,4]]],[[[0,9],4],[[3,8],2]]]
[3,[3,[[2,8],[1,4]]]]
[8,[[3,6],[[8,9],[4,1]]]]
[[3,[[5,8],[3,3]]],[[[9,3],[6,3]],[[7,0],[8,8]]]]
[[6,[[5,8],7]],[8,[[1,6],7]]]
[[[[8,6],[9,3]],[3,[2,7]]],[[[6,7],[2,8]],[6,7]]]
[[[9,[1,6]],0],[[7,3],[2,4]]]
[[[[4,9],3],6],[[7,5],8]]
[[[[8,3],8],[2,[6,5]]],[[6,[1,9]],[0,2]]]
[[[9,9],[[9,8],1]],[[[7,4],[1,4]],[[1,1],4]]]
[[5,[[8,2],[8,6]]],[9,[7,[8,9]]]]
[[4,6],[8,[3,[1,2]]]]
[[[2,[7,9]],7],[[2,0],[9,2]]]
[[4,9],[[[3,4],[2,9]],5]]
[[[[0,0],[3,7]],[[6,1],8]],[[[4,0],4],8]]
[[4,[[8,9],[2,2]]],[[[1,8],[2,7]],[[6,8],0]]]
[[7,5],[[7,0],1]]
[[[5,[1,0]],1],[[[7,7],[2,2]],[[4,2],8]]]
[[[7,1],[7,3]],[2,0]]
[[[[6,2],3],[3,[5,2]]],[[7,2],[[9,5],[0,1]]]]
[[[[0,3],2],6],9]
[[[9,8],[[7,8],[5,9]]],[[[4,8],[0,2]],[[6,8],[2,3]]]]
[2,[[3,7],9]]
[[[9,9],1],[7,[7,[5,8]]]]
[[8,[1,1]],[8,8]]
[[[[3,3],[1,4]],[[5,3],4]],[5,2]]
[[[[0,9],1],[[3,8],8]],[9,[[8,8],[0,7]]]]
[[[9,4],1],[[9,7],[[6,1],[9,5]]]]
[[[1,[4,0]],9],[[3,7],2]]
[[[5,[0,5]],[5,[9,2]]],[[[2,2],[8,0]],[3,[7,8]]]]
[[[[8,2],3],3],[[[5,4],[0,5]],9]]
[[[3,[6,2]],0],[[[7,3],[6,3]],[[6,3],2]]]
[[6,1],[[[1,2],2],[9,4]]]
[[[1,[9,0]],[[8,2],[4,9]]],[[0,[9,6]],[[0,4],[4,0]]]]
[9,[4,[7,0]]]
[[7,2],[[9,5],8]]
[[6,[[0,6],0]],[[[2,0],[4,1]],[[9,5],4]]]
[[[6,[0,0]],5],[[[5,2],[7,3]],[[2,8],[3,2]]]]
[[[2,7],[[8,2],2]],[[5,[0,6]],[[9,8],[0,4]]]]
[[[8,9],[[4,1],2]],[[[3,4],[4,5]],[[7,4],0]]]
[[5,[2,[2,1]]],[[5,6],[[6,2],[3,0]]]]
[[8,[0,0]],[[6,1],[9,[1,3]]]]
[[[9,[5,8]],5],[[8,[6,6]],[7,5]]]
[3,2]
[[8,[[6,3],[8,4]]],[[2,7],[8,[9,5]]]]
[[[4,[9,1]],[[3,6],[8,8]]],[[[9,0],6],[[3,7],6]]]
[[9,[[4,9],6]],[[8,2],[1,3]]]
[[[2,[4,3]],[[5,6],[7,3]]],7]
[[[[0,1],7],[[9,1],9]],[[[0,1],[6,5]],1]]
[[[7,[5,3]],[[6,6],6]],[[2,7],3]]
[[1,[[5,8],[1,7]]],[[[5,0],[4,7]],[[3,3],[3,7]]]]
[[[[8,8],[2,6]],[1,2]],[[[2,6],4],[1,[1,8]]]]
[5,[[8,[8,2]],0]]
[[6,[[5,9],[8,4]]],[7,[5,9]]]
[[7,3],[[[2,5],4],[[1,1],8]]]
[[[0,1],7],[0,8]]
[[7,[6,6]],[2,9]]
[[[[1,9],1],[[4,8],5]],[[0,[8,3]],[[0,9],[1,5]]]]
[[[0,9],[[6,7],5]],[4,[[1,1],[0,6]]]]
[[[6,1],7],[[[1,4],8],[[9,0],4]]]
[5,[3,[[0,7],[4,9]]]]
[[[[6,0],[1,5]],[[1,5],1]],[[1,[7,1]],[[6,2],7]]]
[[[9,0],8],[[[4,1],[5,4]],[4,[5,1]]]]
[3,[5,9]]
[6,[6,5]]
[[1,[8,0]],[9,0]]
[[[[1,8],3],0],[7,[[0,8],6]]]
[[[[4,2],2],3],[[2,5],[[9,2],4]]]
[[1,[[1,1],[8,4]]],[[[8,1],0],[0,2]]]
[[[[0,7],[8,7]],[9,6]],0]
[[3,7],[[1,[0,9]],[1,[7,6]]]]
[[[[3,5],[4,6]],[[7,1],[8,0]]],6]
[[7,[5,[7,7]]],[4,[5,3]]]
[1,[[[0,0],[4,6]],[7,[1,9]]]]
[[[3,7],[7,[0,6]]],[7,[5,3]]]
[[[[5,3],0],2],[[[2,7],[7,9]],[[1,4],3]]]
[[[[8,3],9],[[8,3],[7,4]]],[[4,[6,0]],[7,[3,7]]]]
[[[6,[5,0]],8],[[[4,5],3],[1,[5,9]]]]
[[7,8],[[6,8],[[8,4],[3,1]]]]
[[[2,7],[6,3]],[[0,0],4]]
[[1,[[6,5],[4,8]]],[[8,[2,7]],[[7,8],[6,8]]]]
[[[2,3],[7,7]],[0,[3,3]]]
[5,[[2,8],[2,[6,9]]]]
[[[[6,3],2],[[2,8],9]],[[[5,6],[8,0]],[[9,3],[5,0]]]]
[[[[6,2],7],[6,1]],[[[5,9],4],4]]
[[[[7,2],[0,4]],[[6,7],7]],[6,[[8,5],[9,0]]]]
[[[[9,6],8],[2,[3,7]]],6]
[[0,[[1,0],4]],[5,[[7,4],[2,4]]]]
[[[[4,4],[4,7]],[[7,4],3]],5]
[[[[8,2],[0,3]],[[7,2],1]],[[7,[1,2]],6]]
[[[3,8],[3,1]],[7,7]]
[[[6,5],[[8,7],4]],3]
[[7,[2,[2,5]]],[9,1]]
[9,2]
[[4,[2,9]],[[4,[2,9]],0]]
[[[0,2],[[2,1],[9,2]]],[[6,[8,2]],[4,[3,8]]]]
[1,[[[2,2],6],[[3,5],6]]]
[[[9,[4,8]],[1,4]],[4,[1,[9,1]]]]
[[[8,0],[[8,4],3]],9]"""

In [None]:
answer = iterate(inp.split()).map(Snf).reduce(lambda a,b: a+b)

In [None]:
answer

In [None]:
magnitude(eval(str(answer)))

In [None]:
Snf("[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]") + Snf("[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]")

In [None]:
iterate([1,2]).product([3,4]).to_list()

In [8]:
from itertools import combinations

In [11]:
g = (
    iterate(combinations(inp.split(),2))
    .starmap(lambda a,b: Snf(a)+Snf(b))
    .map(lambda number: magnitude(eval(str(number))))
    .to_list()
)

In [12]:
g.sort()

In [15]:
g[-2]

4712

In [10]:
# 4746 too high