In [1]:
from itertools import permutations

from aocd import get_data
import numpy as np


# Raw puzzle input for Advent of Code 2021, day 18: one snailfish number per line.
day18 = get_data(year=2021, day=18)

In [18]:
class snailfish_number():
    """A snailfish number flattened into two parallel lists.

    ``digits`` holds the regular numbers in left-to-right order, and
    ``levels`` holds how many brackets enclose each of them (its nesting
    depth). A pair nested inside four other pairs therefore has both of its
    values at level 5, which is what ``explode`` looks for.
    """

    def __init__(self, instring):
        """Parse a snailfish string such as ``'[[1,2],3]'``.

        A run of consecutive digit characters is read as one regular number,
        so multi-digit values (e.g. ``'[15,3]'``) parse correctly. Commas and
        any other characters are ignored. ``snailfish_number('')`` yields an
        empty number; ``add`` and ``magnitude`` use it as a blank container.
        """
        self.digits = []
        self.levels = []
        depth = 0
        pending = ''  # characters of the regular number currently being read
        for char in instring:
            if '0' <= char <= '9':
                pending += char
                continue
            if pending:
                # depth cannot change inside a number, so the current depth is its level
                self.digits.append(int(pending))
                self.levels.append(depth)
                pending = ''
            if char == '[':
                depth += 1
            elif char == ']':
                depth -= 1
        if pending:  # bare number with nothing after it, e.g. '7'
            self.digits.append(int(pending))
            self.levels.append(depth)

    def explode(self):
        """Explode the leftmost pair whose values sit deeper than level 4.

        The pair's left value is added to the nearest regular number on its
        left and its right value to the nearest one on its right (if they
        exist). The pair is then replaced by a single 0 one level up.

        Assumes the first value deeper than level 4 and its right neighbour
        form a pair of two regular numbers. This holds when depth never
        exceeds 5, as when two reduced numbers are added.

        Returns True if a pair exploded, False if nothing was deep enough.
        """
        idx = next((i for i, lev in enumerate(self.levels) if lev > 4), None)
        if idx is None:
            return False
        left, right = self.digits[idx], self.digits[idx + 1]
        if idx > 0:
            self.digits[idx - 1] += left
        if idx + 2 < len(self.digits):
            self.digits[idx + 2] += right
        # collapse the two-element pair into one regular 0 one level shallower
        self.digits[idx:idx + 2] = [0]
        self.levels[idx:idx + 2] = [self.levels[idx] - 1]
        return True

    def split(self):
        """Split the leftmost regular number that is 10 or greater.

        The number becomes a pair (floor(n/2), ceil(n/2)) one level deeper.

        Returns True if a number was split, False if none was large enough.
        """
        idx = next((i for i, d in enumerate(self.digits) if d > 9), None)
        if idx is None:
            return False
        value = self.digits[idx]
        level = self.levels[idx] + 1
        self.digits[idx:idx + 1] = [value // 2, (value + 1) // 2]
        self.levels[idx:idx + 1] = [level, level]
        return True

    def reduce(self):
        """Reduce in place: explode while possible, otherwise split once, repeat.

        ``or`` short-circuits, so a split is only tried when no explode
        applies. This gives explosions priority over splits, as required.
        """
        while self.explode() or self.split():
            pass

    def add(self, other):
        """Return the reduced sum ``[self, other]`` as a new number.

        Neither operand is modified: the lists are concatenated into fresh
        lists, and every value goes one level deeper under the new outer pair.
        """
        output = snailfish_number('')
        output.digits = self.digits + other.digits
        output.levels = [lev + 1 for lev in self.levels + other.levels]
        output.reduce()
        return output

    def magnitude(self):
        """Return the magnitude (3*left + 2*right, applied recursively).

        Works on a copy, so the number itself is left untouched and repeated
        calls give the same result. A number with values deeper than level 4
        is reduced (on the copy) before its magnitude is taken.

        Raises ValueError if the number does not contain at least one pair.
        """
        work = snailfish_number('')
        work.digits = list(self.digits)
        work.levels = list(self.levels)
        if len(work.digits) < 2:
            raise ValueError('magnitude requires at least one pair')
        if max(work.levels) > 4:
            work.reduce()
        digits, levels = work.digits, work.levels
        while len(digits) > 2:
            # The first value at the deepest level is the left half of a pair of
            # two regular numbers; fold that pair into its magnitude.
            idx = levels.index(max(levels))
            digits[idx:idx + 2] = [3 * digits[idx] + 2 * digits[idx + 1]]
            levels[idx:idx + 2] = [levels[idx] - 1]
        return 3 * digits[0] + 2 * digits[1]
            

In [19]:
# Check against the worked examples in the puzzle statement: one addition
# with its fully reduced result, and one magnitude.
left = snailfish_number('[[[[7,7],[7,7]],[[8,7],[8,7]]],[[[7,0],[7,7]],9]]')
right = snailfish_number('[[[[4,2],2],6],[8,7]]')
total = left.add(right)
print(total.digits)
print(total.levels)

expected = snailfish_number('[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]')
assert total.digits == expected.digits, 'addition/reduction gave the wrong values'
assert list(total.levels) == list(expected.levels), 'addition/reduction gave the wrong nesting'

example_magnitude = snailfish_number('[[[[8,7],[7,7]],[[8,6],[7,7]]],[[[0,7],[6,6]],[8,7]]]').magnitude()
print(example_magnitude)
assert example_magnitude == 3488

[8, 7, 7, 7, 8, 6, 7, 7, 0, 7, 6, 6, 8, 7]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3]
3488


In [21]:
# Second worked example: show the flattened values of the reduced sum.
lhs = snailfish_number('[[[[6,7],[6,7]],[[7,7],[0,7]]],[[[8,7],[7,7]],[[8,8],[8,0]]]]')
rhs = snailfish_number('[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]')
lhs.add(rhs).digits

[7, 0, 7, 7, 7, 7, 7, 8, 7, 7, 8, 8, 7, 7, 8, 7]

In [22]:
# Example homework assignment from the puzzle statement, one number per line.
example_homework = '[[[0,[5,8]],[[1,7],[9,6]]],[[4,[1,2]],[[1,4],2]]]\n[[[5,[2,8]],4],[5,[[9,9],0]]]\n[6,[[[6,2],[5,6]],[[7,6],[4,7]]]]\n[[[6,[0,7]],[0,9]],[4,[9,[9,0]]]]\n[[[7,[6,4]],[3,[1,3]]],[[[5,5],1],9]]\n[[6,[[7,3],[3,2]]],[[[3,8],[5,7]],4]]\n[[[[5,4],[7,7]],8],[[8,3],8]]\n[[9,3],[[9,9],[6,[4,9]]]]\n[[2,[[7,7],7]],[[5,8],[[9,3],[0,2]]]]\n[[[[5,2],5],[8,[3,7]]],[[5,[7,5]],[4,4]]]'

example_numbers = list(map(snailfish_number, example_homework.split('\n')))

# Add the numbers left to right; each sum is reduced inside add().
running_sum = example_numbers[0]
for k in range(1, len(example_numbers)):
    running_sum = running_sum.add(example_numbers[k])

print(running_sum.magnitude())

4140


In [23]:
# Part 1: sum the whole homework assignment in order and report its magnitude.
homework = list(map(snailfish_number, day18.split('\n')))

homework_sum = homework[0]
for k in range(1, len(homework)):
    homework_sum = homework_sum.add(homework[k])

print(homework_sum.magnitude())

3725


That's it! It turned out much smaller than I had thought it would be.

In [25]:
# Part 2: largest magnitude obtainable by adding two *different* homework numbers.
# add() puts self on the left and other on the right, so both orders of every
# pair are tried. permutations() only pairs distinct positions, so no number
# is ever added to itself.
numbers = [snailfish_number(line) for line in day18.split('\n')]

maxmag = max(
    (first.add(second).magnitude() for first, second in permutations(numbers, 2)),
    default=0,
)

print(maxmag)

4832


Once I set up Part I correctly, Part II was easy!