In [1]:
from enum import Enum

# Life-cycle state of one Conway cube. Values 1 and 2 match what the
# functional form Enum('State', ('active', 'inactive')) would assign.
class State(Enum):
    active = 1
    inactive = 2

In [36]:
from collections import Counter
from itertools import product
from typing import Generator

from numba import jit

class Dimensions(object):
    """Sparse N-dimensional Conway-cube universe.

    Only active cells are stored: ``self.cells`` maps a coordinate tuple of
    length ``dims`` to ``State.active``; any coordinate missing from the dict
    is inactive. ``self.mins`` / ``self.maxs`` are inclusive per-axis bounds
    that enclose every active cell (they are only ever widened).
    """

    def __init__(self, dims: int, data: str):
        """Seed the universe from a 2-D text slice.

        Args:
            dims: number of spatial dimensions; must be >= 2 because the
                input slice itself occupies the first two axes.
            data: newline-separated rows where ``'#'`` marks an active cell.
                The row index becomes axis 0, the column index axis 1, and
                every extra axis starts at 0.

        Raises:
            ValueError: if ``dims`` is smaller than 2.
        """
        if dims < 2:
            raise ValueError('dims must be at least 2')
        rows = [line.rstrip() for line in data.split('\n')]
        self.dims = dims
        self.mins = [0] * dims
        # Bounds come from the real grid size. Using the row count for both
        # axes (as before) left cells of wide inputs outside the bounds.
        height = len(rows)
        width = max((len(row) for row in rows), default=0)
        self.maxs = [max(height - 1, 0), max(width - 1, 0)] + [0] * (dims - 2)
        self.cells = {}
        # Every offset in {-1, 0, 1}^dims except the all-zero one, i.e. the
        # 3**dims - 1 neighbour directions of a cell.
        self.gen = [diff for diff in product(range(-1, 2), repeat=dims)
                    if any(diff)]
        padding = [0] * (dims - 2)
        for x, row in enumerate(rows):
            for y, c in enumerate(row):
                if c == '#':
                    self.cells[tuple([x, y] + padding)] = State.active

    def set_mins_and_maxs(self, *args):
        """Widen the inclusive bounds so they contain the point ``args``.

        Raises:
            ValueError: if the number of coordinates differs from ``dims``.
        """
        if len(args) != self.dims:
            raise ValueError('wrong dimensions')
        for i, value in enumerate(args):
            self.mins[i] = min(self.mins[i], value)
            self.maxs[i] = max(self.maxs[i], value)

    def count_active_neighbors(self, *args) -> int:
        """Return how many of the 3**dims - 1 neighbours of ``args`` are active.

        Raises:
            ValueError: if the number of coordinates differs from ``dims``.
        """
        if len(args) != self.dims:
            raise ValueError('wrong dimensions')
        return sum(
            1 for diff in self.gen
            if self.cells.get(tuple(a + d for a, d in zip(args, diff)),
                              State.inactive) == State.active
        )

    def next_cycle(self):
        """Advance the universe by one generation, in place.

        An active cell stays active with exactly 2 or 3 active neighbours;
        an inactive cell becomes active with exactly 3. Instead of scanning
        the whole bounding box, each active cell adds one vote to each of its
        neighbours, so the work scales with the number of active cells rather
        than with the box volume.
        """
        votes = Counter(
            tuple(c + d for c, d in zip(coords, diff))
            for coords, state in self.cells.items()
            if state == State.active
            for diff in self.gen
        )
        new_cells = {}
        # Sorted so the new dict keeps the lexicographic insertion order the
        # previous box-scan implementation produced.
        for coords in sorted(votes):
            neighbors = votes[coords]
            was_active = self.cells.get(coords, State.inactive) == State.active
            if neighbors == 3 or (neighbors == 2 and was_active):
                new_cells[coords] = State.active
                self.set_mins_and_maxs(*coords)
        # Active cells with no votes (zero active neighbours) are simply not
        # carried over, i.e. they become inactive.
        self.cells = new_cells

    def count_all_active(self) -> int:
        """Return the number of active cells (only active cells are stored)."""
        return len(self.cells)

In [22]:
# Load the worked example as a 3-D universe and show its initial active cells.
with open('testcase1.txt') as fn:
    universe = Dimensions(dims=3, data=fn.read())
universe.cells

{(0, 1, 0): <State.active: 1>,
 (1, 2, 0): <State.active: 1>,
 (2, 0, 0): <State.active: 1>,
 (2, 1, 0): <State.active: 1>,
 (2, 2, 0): <State.active: 1>}

In [23]:
# Step the example universe once. This cell is not idempotent: it expects the
# loading cell above to have just built `universe`, and re-running it advances
# the simulation again. The assertion turns that hidden-state drift into a
# visible failure instead of a silently different number.
universe.next_cycle()
assert universe.count_all_active() == 11, 'expected 11 active cubes after cycle 1'
universe.count_all_active()

11

In [24]:
# Second step of the example universe. Like the previous stepping cell, this
# only yields the documented count when run exactly once, in order; the
# assertion flags out-of-order or repeated execution.
universe.next_cycle()
assert universe.count_all_active() == 21, 'expected 21 active cubes after cycle 2'
universe.count_all_active()

21

In [25]:
# Third step of the example universe. This is order-dependent like the other
# stepping cells; the assertion catches a stale or re-run kernel state.
universe.next_cycle()
assert universe.count_all_active() == 38, 'expected 38 active cubes after cycle 3'
universe.count_all_active()

38

In [26]:
def all_solution(input_file: str, dims: int, cycles: int=6) -> int:
    """Simulate ``cycles`` generations from a puzzle file; return the active count.

    Args:
        input_file: path of a text file whose rows use '#' for active cubes.
        dims: number of spatial dimensions passed to ``Dimensions``
            (this notebook uses 3 and 4).
        cycles: number of ``next_cycle`` steps to run (default 6).

    Returns:
        ``count_all_active()`` of the universe after the last step.
    """
    with open(input_file) as source:
        grid_text = source.read()
    world = Dimensions(dims, grid_text)
    for _ in range(cycles):
        world.next_cycle()
    return world.count_all_active()

In [27]:
# Worked example in 3-D: pin the known answer so regressions fail loudly.
example_3d = all_solution('testcase1.txt', dims=3)
assert example_3d == 112, 'worked example should leave 112 active cubes in 3-D'
example_3d

112

In [28]:
# 3-D simulation of input.txt with the default six cycles spelled out.
all_solution(input_file='input.txt', dims=3, cycles=6)

401

In [29]:
# Worked example in 4-D: pin the known answer so regressions fail loudly.
example_4d = all_solution('testcase1.txt', dims=4)
assert example_4d == 848, 'worked example should leave 848 active cubes in 4-D'
example_4d

848

In [60]:
# 4-D simulation of input.txt with the default six cycles spelled out.
all_solution(input_file='input.txt', dims=4, cycles=6)

2224

In [38]:
%prun all_solution('input.txt', 3)

 