In [84]:
from typing import List

class Grid:
    """Grid of rolls parsed from text lines of ``"@"`` (roll) and ``"."`` (empty).

    The grid is stored as ``self.grid[y][x]``, with 1 for a roll and 0 for an
    empty cell.

    Naming note (kept for backward compatibility with existing callers):
    parameters called ``row`` / ``p_row`` / ``i_row`` hold the *horizontal*
    index x (position within a line). Parameters called ``column`` /
    ``p_column`` / ``i_column`` hold the *vertical* index y (which line).
    """

    def __init__(self, input_data: List[List[str]]):
        """Parse ``input_data`` into a list of 0/1 integer rows.

        Args:
            input_data: One entry per grid line. Each entry is iterated
                character by character, so a plain ``str`` per line (as
                produced by ``str.splitlines``) works as well as a list of
                single-character strings.

        Raises:
            ValueError: If a character other than ``"@"`` or ``"."`` appears.
        """
        self.grid = []
        for line in input_data:
            row = []
            for entry in line:
                if entry == "@":
                    row.append(1)
                elif entry == ".":
                    row.append(0)
                else:
                    raise ValueError("Found unknown string ", entry)
            self.grid.append(row)

    def check_position(self, row: int, column: int) -> int:
        """Return 1 if a roll occupies x=``row``, y=``column``, else 0.

        Out-of-bounds coordinates count as empty (0), so callers can probe the
        neighbours of edge cells without special-casing the border.

        Raises:
            ValueError: If the cell holds something other than 0 or 1.
        """
        # Check the vertical index first so a row's length is only read for a
        # row that exists. This also makes empty and ragged grids safe.
        if column < 0 or column >= len(self.grid):
            return 0
        line = self.grid[column]
        if row < 0 or row >= len(line):
            return 0
        value = line[row]
        if value == 1:
            return 1
        if value == 0:
            return 0
        # Previously the exception was *returned* instead of raised, which
        # silently added an exception object to neighbour counts.
        raise ValueError("Found unknown value in grid:", value)

    def count_neighbors(self, p_row: int, p_column: int) -> int:
        """Count occupied cells among the 8 neighbours of x=``p_row``, y=``p_column``.

        The cell itself is excluded. Out-of-bounds neighbours count as empty.
        """
        return sum(
            self.check_position(p_row + dx, p_column + dy)
            for dx in (-1, 0, 1)
            for dy in (-1, 0, 1)
            if (dx, dy) != (0, 0)
        )

    def check_grid(self, max: int) -> List[List[int]]:
        """Return ``[x, y]`` for every roll with fewer than ``max`` occupied neighbours.

        Args:
            max: Strict upper bound on the neighbour count. The name shadows
                the builtin but is kept because callers pass it by keyword.

        Returns:
            ``[x, y]`` pairs in row-major order (top to bottom, then left to
            right). The grid is not modified.
        """
        valid_rolls = []
        for i_column, line in enumerate(self.grid):
            # Iterate each row's own length so ragged grids do not index past the end.
            for i_row, cell in enumerate(line):
                if cell == 0:
                    continue
                if self.count_neighbors(i_row, i_column) < max:
                    valid_rolls.append([i_row, i_column])
        return valid_rolls

    def count_grid(self, max: int) -> int:
        """Return how many rolls have fewer than ``max`` occupied neighbours."""
        return len(self.check_grid(max))

    def print_grid(self) -> None:
        """Print the grid using ``"@"`` for rolls and ``"."`` for empty cells."""
        for line in self.grid:
            for entry in line:
                print("@" if entry == 1 else ".", end="")
            print()

    def remove_roll(self, y: int, x: int) -> None:
        """Mark the cell at line ``y``, position ``x`` as empty.

        Note the argument order is (y, x), the reverse of ``check_grid``'s pairs.
        """
        self.grid[y][x] = 0

    def empty_grid(self, max: int = 4) -> int:
        """Repeatedly remove accessible rolls until no more can be removed.

        Each round collects every roll with fewer than ``max`` occupied
        neighbours, evaluated on the grid as it was at the start of the round.
        It then removes them all at once. Rounds stop when one removes nothing.
        The grid is modified in place.

        Args:
            max: Neighbour-count threshold passed to ``check_grid``. Defaults
                to 4, the value previously hard-coded here.

        Returns:
            Total number of rolls removed. The total is also printed, as
            before, so existing callers see the same output.
        """
        total = 0
        while True:
            rolls_to_remove = self.check_grid(max=max)
            if not rolls_to_remove:
                break
            # check_grid yields [x, y]; remove_roll takes (y, x).
            for x, y in rolls_to_remove:
                self.remove_roll(y, x)
            total += len(rolls_to_remove)
        print(total)
        return total


In [68]:
# Part a: on the small example input, count the rolls that currently have
# fewer than four occupied neighbours.
with open("test.txt") as f:
    data = list(f.read().splitlines())
g = Grid(data)

print(g.count_grid(max=4))

13


In [86]:
# Part b: keep removing every roll with fewer than four occupied neighbours
# until nothing more can be removed; empty_grid prints the total removed.
with open("full.txt") as f:
    data = list(f.read().splitlines())
g = Grid(data)

g.empty_grid();

9122
