In [19]:
from utils import read_lines


# (row offset, column offset) of the four orthogonal neighbours of a tile:
# up, down, left, right.
deltas = [(-1, 0), (1, 0), (0, -1), (0, 1)]

def matrix_to_string(matrix):
    """Flatten a grid into one string, concatenating its rows top to bottom."""
    return ''.join(tile for row in matrix for tile in row)

def calc_score(matrix):
    """Return the sum of 2**k over every tile k that holds '#'.

    Tiles are numbered from 0 in the order produced by matrix_to_string
    (row by row, left to right).
    """
    flat = matrix_to_string(matrix)
    return sum(1 << pos for pos, tile in enumerate(flat) if tile == '#')

def count_neighbors(matrix, i, j):
    """Count the '#' tiles orthogonally adjacent to row i, column j.

    Positions that fall outside the grid are ignored.
    """
    hits = 0
    for d_row, d_col in deltas:
        r, c = i + d_row, j + d_col
        if not 0 <= r < len(matrix):
            continue
        if not 0 <= c < len(matrix[0]):
            continue
        if matrix[r][c] == '#':
            hits += 1
    return hits

def spread(matrix):
    """Advance a single-level grid by one step and return the new grid.

    Rules, applied to every tile simultaneously using the neighbour counts
    of the *input* grid:
      * a '#' tile stays '#' only if exactly one orthogonal neighbour is '#';
      * a '.' tile becomes '#' if one or two orthogonal neighbours are '#';
      * every other tile is '.' in the result.

    The result has the same shape as ``matrix`` (previously the size was
    hard-coded to 5x5, which truncated larger grids and raised IndexError on
    smaller ones). The grid is expected to be rectangular, because
    count_neighbors bounds columns by the length of the first row.
    ``matrix`` itself is not modified.
    """
    new_matrix = [['.'] * len(row) for row in matrix]
    for i, row in enumerate(matrix):
        for j, tile in enumerate(row):
            cnt = count_neighbors(matrix, i, j)
            if (tile == '#' and cnt == 1) or (tile == '.' and cnt in (1, 2)):
                new_matrix[i][j] = '#'
    return new_matrix

def parse_input(input_file):
    """Read the grid from input_file as a list of rows, each a list of characters.

    Lines come from read_lines; each line is split into its individual
    characters without any further processing.
    """
    return [list(line) for line in read_lines(input_file)]

def print_matrix(matrix):
    """Print the grid to stdout, one row per line."""
    for text in map(''.join, matrix):
        print(text)

def part1(input_file):
    """Return calc_score of the first grid layout that appears twice.

    The grid parsed from input_file is stepped with spread until a layout
    (as given by matrix_to_string) has been seen before; the starting
    layout counts as seen.
    """
    matrix = parse_input(input_file)
    layout = matrix_to_string(matrix)
    seen = set()
    while layout not in seen:
        seen.add(layout)
        matrix = spread(matrix)
        layout = matrix_to_string(matrix)
    return calc_score(matrix)

def count_row(matrix, i):
    """Count the '#' tiles in the first five columns of row i."""
    return sum(1 for j in range(5) if matrix[i][j] == '#')

def count_col(matrix, j):
    """Count the '#' tiles in the first five rows of column j."""
    column = [matrix[i][j] for i in range(5)]
    return column.count('#')

def count_neighbors2(all_mats, level, i, j):
    """Count '#' neighbours of tile (i, j) on ``level`` of the nested grids.

    ``all_mats`` maps a level number to a 5x5 grid. Three sources contribute:
      * same-level orthogonal neighbours, except the centre tile (2, 2),
        which is never counted directly;
      * if level - 1 exists, a tile on the outer border also neighbours the
        tile of level - 1 that sits next to that level's centre on the
        matching side;
      * if level + 1 exists, a tile next to the centre also neighbours the
        whole facing edge (row or column) of level + 1.
    """
    grid = all_mats[level]
    total = 0

    # Same-level neighbours; the centre is a placeholder and is skipped.
    for d_row, d_col in deltas:
        r, c = i + d_row, j + d_col
        if (r, c) == (2, 2):
            continue
        if 0 <= r < len(grid) and 0 <= c < len(grid[0]) and grid[r][c] == '#':
            total += 1

    # Border tiles touch one tile of level - 1 per border they lie on
    # (a corner lies on two borders, so it is checked for both).
    if level - 1 in all_mats:
        outer = all_mats[level - 1]
        by_row = {0: (1, 2), 4: (3, 2)}
        by_col = {0: (2, 1), 4: (2, 3)}
        for table, index in ((by_row, i), (by_col, j)):
            if index in table:
                r, c = table[index]
                if outer[r][c] == '#':
                    total += 1

    # Tiles around the centre touch an entire edge of level + 1.
    if level + 1 in all_mats:
        inner = all_mats[level + 1]
        facing_edge = {
            (1, 2): (count_row, 0),
            (3, 2): (count_row, 4),
            (2, 1): (count_col, 0),
            (2, 3): (count_col, 4),
        }
        if (i, j) in facing_edge:
            counter, edge = facing_edge[(i, j)]
            total += counter(inner, edge)

    return total

def spread2(all_mats):
    """Advance every level of the nested grids by one step.

    ``all_mats`` maps level number -> 5x5 grid and is not modified; a new
    dict is returned. Existing levels are updated with the same survive /
    infest rules as the flat grid, using count_neighbors2 and always leaving
    the centre tile (2, 2) as '.'.

    Two extra levels may be added, both derived from the *old* state:
      * level min - 1, when some border edge of the lowest level holds one
        or two '#' tiles (the tile beside the new level's centre on that
        side becomes '#');
      * level max + 1, when a tile beside the centre of the highest level is
        '#' (the whole facing edge of the new level becomes '#').
    """
    next_mats = {}
    for level, grid in all_mats.items():
        fresh = [['.'] * 5 for _ in range(5)]
        for i in range(5):
            for j in range(5):
                if (i, j) == (2, 2):
                    continue
                cnt = count_neighbors2(all_mats, level, i, j)
                tile = grid[i][j]
                if (tile == '#' and cnt == 1) or (tile == '.' and cnt in (1, 2)):
                    fresh[i][j] = '#'
        next_mats[level] = fresh

    lowest, highest = min(all_mats), max(all_mats)
    lowest_grid, highest_grid = all_mats[lowest], all_mats[highest]

    # Each border edge of the lowest level feeds exactly one tile of the
    # (currently empty) level below it: (bugs on that edge, fed tile).
    edge_feeds = [
        (count_row(lowest_grid, 0), (1, 2)),
        (count_row(lowest_grid, 4), (3, 2)),
        (count_col(lowest_grid, 0), (2, 1)),
        (count_col(lowest_grid, 4), (2, 3)),
    ]
    infested = [cell for bugs, cell in edge_feeds if bugs in (1, 2)]
    if infested:
        below = [['.'] * 5 for _ in range(5)]
        for r, c in infested:
            below[r][c] = '#'
        next_mats[lowest - 1] = below

    # A '#' beside the centre of the highest level infests the whole facing
    # edge of the (currently empty) level above it.
    full_rows = [row for (r, c), row in (((1, 2), 0), ((3, 2), 4))
                 if highest_grid[r][c] == '#']
    full_cols = [col for (r, c), col in (((2, 1), 0), ((2, 3), 4))
                 if highest_grid[r][c] == '#']
    if full_rows or full_cols:
        above = [['.'] * 5 for _ in range(5)]
        for row in full_rows:
            above[row] = ['#'] * 5
        for col in full_cols:
            for line in above:
                line[col] = '#'
        next_mats[highest + 1] = above

    return next_mats

def count_matrix(matrix):
    """Count the '#' tiles in the top-left 5x5 area of the grid."""
    return sum(
        1
        for i in range(5)
        for j in range(5)
        if matrix[i][j] == '#'
    )

def part2(input_file, round=200):
    """Return the total number of '#' tiles across all levels after ``round`` steps.

    The grid parsed from input_file starts as level 0; spread2 is applied
    ``round`` times and the '#' tiles of every resulting level are summed.

    NOTE: the parameter name ``round`` shadows the builtin inside this
    function; it is kept unchanged because callers may pass it by keyword.
    """
    levels = {0: parse_input(input_file)}
    for _step in range(round):
        levels = spread2(levels)

    total = 0
    for grid in levels.values():
        total += count_matrix(grid)
    return total


In [14]:
part1('inputs/day24_test.txt')

2129920

In [12]:
part1('inputs/day24.txt')

18370591

In [20]:
part2('inputs/day24_test.txt', 10)

99

In [21]:
part2('inputs/day24.txt')

2040