In [1]:
from copy import deepcopy

In [2]:
# Read the puzzle input as a list of lines (newlines stripped). The context
# manager closes the file handle as soon as the read is done instead of
# leaving it open for the lifetime of the kernel.
with open("input/11.txt") as input_file:
    data = input_file.read().splitlines()

In [3]:
# Offsets (d_row, d_col) to the eight cells surrounding a position; the
# (0, 0) offset is left out because a cell is not its own neighbour.
neighs = [
    (d_row, d_col)
    for d_row in (-1, 0, 1)
    for d_col in (-1, 0, 1)
    if (d_row, d_col) != (0, 0)
]

In [4]:
def add_pos_func(x, y):
    """Return ``x + y``; the per-coordinate operation applied by ``add_pos``."""
    return x + y


def add_pos(a, b):
    """Add two positions coordinate by coordinate and return a tuple.

    Coordinates are paired in order; if the inputs differ in length the
    extra coordinates of the longer one are ignored.
    """
    return tuple(add_pos_func(left, right) for left, right in zip(a, b))

In [5]:
class Cavern:
    """Grid of integer energy levels that is advanced one step at a time.

    Attributes:
        grid: dict mapping ``(row, col)`` tuples to the cell's energy level.
        flashqueue: dict holding the cells that flash during the current
            step; the value is ``False`` while the cell is queued and
            ``True`` once its flash has been propagated to its neighbours.
            It is emptied at the end of every ``step``.
        tot_flashes: running total of flashes over all steps taken so far.
        part2: when true, ``step`` reports whether every cell flashed.
    """

    def __init__(self, data, part2=False):
        """Build the grid from ``data``, an iterable of digit strings (one per row).

        Raises:
            ValueError: if a character cannot be converted with ``int``.
        """
        self.grid = {
            (r_idx, c_idx): int(elem)
            for r_idx, row in enumerate(data)
            for c_idx, elem in enumerate(row)
        }
        self.flashqueue = {}
        self.tot_flashes = 0
        self.part2 = part2

    def update_neighbors(self, cur_pos):
        """Raise the energy of the cells around the flashing cell ``cur_pos``.

        Neighbours that are off the grid are ignored. Neighbours already in
        ``flashqueue`` (queued or processed) are left untouched: they are
        reset to 0 at the end of the step regardless of further increments.

        Returns:
            list of positions that crossed the flash threshold because of
            this call and were newly added to ``flashqueue``.
        """
        newly_queued = []
        for neigh in neighs:
            pos = add_pos(cur_pos, neigh)
            if pos not in self.grid or pos in self.flashqueue:
                continue
            self.grid[pos] += 1
            if self.grid[pos] > 9:
                self.flashqueue[pos] = False
                newly_queued.append(pos)
        return newly_queued

    def flash(self):
        """Propagate all flashes for the current step, then reset flashed cells to 0."""
        for pos, elem in self.grid.items():
            if elem > 9:
                self.flashqueue[pos] = False

        # Worklist of cells whose flash has not been propagated yet. Each cell
        # enters it at most once, so no copy or rescan of the queue is needed.
        pending = [pos for pos, done in self.flashqueue.items() if not done]
        while pending:
            pos = pending.pop()
            self.flashqueue[pos] = True
            pending.extend(self.update_neighbors(pos))

        # Only values of existing keys are reassigned, which is safe while
        # iterating over the dict.
        for pos, elem in self.grid.items():
            if elem > 9:
                self.grid[pos] = 0

    def step(self):
        """Advance the grid by one step.

        Every cell gains 1 energy, flashes are propagated, and the number of
        cells that flashed is added to ``tot_flashes``. The flash queue is
        always cleared afterwards so the next step starts from a clean state.

        Returns:
            ``True`` when ``part2`` is set and every cell of the grid flashed
            during this step (an empty grid counts as "every cell"),
            otherwise ``False``.
        """
        for pos in self.grid:
            self.grid[pos] += 1
        self.flash()

        num_flashes = len(self.flashqueue)
        self.tot_flashes += num_flashes
        self.flashqueue = {}

        # Compare against the real grid size rather than a fixed 100 so that
        # grids of any dimensions are handled.
        return bool(self.part2) and num_flashes == len(self.grid)

# Part 1

In [6]:
# Advance a fresh cavern by 100 steps and read off the accumulated flash count.
cavern = Cavern(data)
for s in range(100):
    cavern.step()
part1 = cavern.tot_flashes
print(part1)
# Regression check: the expected total for this particular input file.
assert part1 == 1719

1719


# Part 2

In [7]:
# Keep stepping a fresh cavern until step() first reports a truthy result.
cavern = Cavern(data, part2=True)
s = 0
while not cavern.step():
    s += 1
# s counts the steps that did not report; the reporting step is the next one.
part2 = s + 1
print(part2)
assert part2 == 232

232
