In [1]:
import numpy as np
from copy import deepcopy

In [2]:
class InfiniteImage:
    """Sparse, unbounded 2-D image of '#' (lit) and '.' (dark) pixels.

    Pixels live in ``store``, keyed by ``(row, col)`` coordinates centred on
    the middle of the starting image. Any coordinate missing from ``store``
    reads as ``default``, the colour of the infinite background.

    NOTE: assumes the starting image is square. A single ``grid_size`` is used
    for both axes, and the side-length parity is taken from the row count.
    """

    def __init__(self, image):
        """Parse ``image`` (one text row per line) into the sparse store.

        Surrounding whitespace is stripped first, so a trailing newline does
        not produce an empty (ragged) final row.
        """
        self.encode = {'#': 1, '.': 0}
        self.array = np.array([list(row) for row in image.strip().splitlines()])
        # Half the side length; used to centre coordinates on the middle.
        self.grid_size = self.array.shape[0] // 2
        offset = self.grid_size
        self.store = {
            (i - offset, j - offset): str(pixel)
            for (i, j), pixel in np.ndenumerate(self.array)
        }
        self.default = '.'

    def __getitem__(self, key):
        """Return the pixel at ``key``, falling back to the background colour."""
        return self.store.get(key, self.default)

    def __setitem__(self, key, value):
        """Store a pixel and widen ``grid_size`` so it covers ``key``."""
        self.store[key] = value
        extent = max(abs(n) for n in key)
        if extent > self.grid_size:
            self.grid_size = extent

    def _parity(self):
        """Return 1 if the starting image has an odd side length, else 0.

        An odd image occupies -grid_size..grid_size on each axis. An even one
        occupies -grid_size..grid_size-1, because centring by ``side // 2``
        leaves one more row/column on the negative side.
        """
        return self.array.shape[0] % 2

    def relevant(self):
        """Yield every coordinate that can change on the next enhancement pass.

        This is the occupied square grown by one pixel on every side.
        """
        low = -self.grid_size - 1
        high = self.grid_size + 1 + self._parity()
        for i in range(low, high):
            for j in range(low, high):
                yield (i, j)

    def code(self, x, y):
        """Read the 3x3 neighbourhood of ``(x, y)`` as a 9-bit binary number.

        Pixels are read row by row, top-left first, with '#' as 1 and '.' as 0.
        The result is used as an index into the enhancement table.
        """
        code = 0
        for i in range(x - 1, x + 2):
            for j in range(y - 1, y + 2):
                code = (code << 1) | self.encode[self[i, j]]
        return code

    def nlit(self):
        """Count lit pixels in ``store``; the infinite background is not counted."""
        return sum(1 for v in self.store.values() if v == '#')

    def __repr__(self):
        """Render the occupied square, one text row per line (no trailing newline)."""
        low = -self.grid_size
        # Use the same parity rule as relevant(), so even-sized images do not
        # gain a spurious extra row/column of background pixels.
        high = self.grid_size + self._parity()
        return '\n'.join(
            ''.join(self[i, j] for j in range(low, high))
            for i in range(low, high)
        )

In [3]:
def count_lit_pixels(image, iterations, enhancement=None):
    """Apply the enhancement algorithm ``iterations`` times and count lit pixels.

    Parameters
    ----------
    image : InfiniteImage
        Starting image. It is not modified: each pass writes into a deep copy.
    iterations : int
        Number of enhancement passes to apply.
    enhancement : sequence of str, optional
        Enhancement lookup table, indexed by ``image.code(i, j)``. Index ``0``
        is used for an all-dark background and ``[-1]`` for an all-lit one, so
        the table is expected to have 512 entries. Defaults to the module-level
        ``algo``, which keeps the original two-argument call working.

    Returns
    -------
    int
        Lit pixels in the stored (finite) region of the final image. If the
        final background ``default`` is '#', the true count is infinite and
        this number covers only the stored region.
    """
    # Resolve the global lazily so callers can run without it defined.
    table = algo if enhancement is None else enhancement
    for _ in range(iterations):
        new = deepcopy(image)
        for (i, j) in image.relevant():
            new[i, j] = table[image.code(i, j)]
        # The infinite background is uniform, so its neighbourhood code is
        # either all zeros or all ones.
        new.default = table[0 if image.default == '.' else -1]
        image = new
    return image.nlit()

In [4]:
# Puzzle input: the enhancement algorithm, a blank line, then the starting image.
with open('input.txt', 'r') as f:
    algo_text, image_text = f.read().strip().split('\n\n', 1)
# Join lines in case the algorithm is wrapped across several lines.
algo = np.array(list(algo_text.replace('\n', '')))
# count_lit_pixels uses algo[-1] for an all-lit background, which is only
# the 511 entry when the table has exactly 512 entries.
assert algo.shape == (512,), f'expected 512 enhancement entries, got {algo.size}'
image = InfiniteImage(image_text)

In [5]:
# Part 1: lit pixels after two enhancement passes of the starting image.
print(f'Part 1: {count_lit_pixels(image, iterations=2)}')

Part 1: 5682


In [6]:
# Part 2: same procedure with 50 passes. It starts again from the original
# image, since count_lit_pixels works on copies and leaves `image` untouched.
print(f'Part 2: {count_lit_pixels(image, iterations=50)}')

Part 2: 17628
