# [Advent of Code 2021](https://adventofcode.com/2021)

# The toolbox


Generalised pieces of code that either can be used in multiple questions or that simply makes understand the implementation easier.

In [1249]:
from collections import defaultdict, Counter, namedtuple
from itertools import chain, count
from heapq import heappush, heappop
from typing import Iterator
import copy 
import io
import math
import operator


def Input(day, parser=str.strip, whole_file=False):
    "Fetch the data input from disk."
    filename = f"../data/advent2021/input{day}.txt"
    with open(filename) as fin:
        if whole_file:
            return parser(fin)
        return mapt(parser, fin)


def mapt(fn, *args):
    "Do a map, and convert the results to a tuple"
    return tuple(map(fn, *args))


avg = lambda n: sum(n) / len(n)


NEIGHBOUR8_DELTAS = (
    (-1, -1), ( 0, -1), (1, -1),
    (-1,  0),           (1,  0),
    (-1,  1), ( 0,  1), (1,  1),
)

NEIGHBOUR4_DELTAS = (
              ( 0, -1),
    (-1,  0),           (1,  0),
              ( 0,  1),
)


def neighbours4(x, y):
    return tuple((x + dx, y + dy) for dx, dy in NEIGHBOUR4_DELTAS)


def neighbours8(x, y):
    return tuple((x + dx, y + dy) for dx, dy in NEIGHBOUR8_DELTAS)


def a_star(start, h_func, moves, cost=lambda s1, s2: 1):
    """A* implementation.

    Finds the shortest sequence of states from to a goal (a state where
    the hueristic function - h_func - is zero). We use a heap
    as our priority queue, and processes those with the smallest
    overall cost first (the cost of the path + the distance to target).

    start:  the initial state to explore from
    h_func: hueristic function that gives a "distance" to to target state,
            when this is zero we're done.
    moves:  function that generates all possible states from the supplied state
            (these can go bakckwards, but will never be processed)
    cost:   the cost of moving from ones state to another

    return: list of states used to find the final state, wll raise an exception
            if none was found.

    """
    # The priority queue that we'll be reading from
    queue = []
    # We often care about the path taken, so persist the lowest costing path
    # to each state
    previous = {start: None}
    # Lookup of state costs, we initialize at zero for the starting state
    costs = {start: 0}

    # Initialize our queue, this is ordered by path cost (f(n) = g(n) + h(n))
    add_to_queue = lambda state: heappush(queue, (costs[state] + h_func(state), state))

    # Recursively walk backwards to build the full path
    get_path = (
        lambda state: [] if state is None else get_path(previous[state]) + [state]
    )

    # Set the intial position and go!
    add_to_queue(start)

    while queue:
        _, state = heappop(queue)
        if h_func(state) == 0:
            # We're done!
            return get_path(state)

        for new_state in moves(state):
            new_cost = costs[state] + cost(state, new_state)

            if new_state not in costs or new_cost < costs[new_state]:
                # We've found a new state or a better path
                costs[new_state] = new_cost
                previous[new_state] = state
                # We've modified our costs in some way, we need
                # to explore from this state so add to the heap
                add_to_queue(new_state)

    # No solution was found
    raise Exception("No solution for A* was discovered.")
    

class Node(object):
    """
    Generic tree node.
    """
    def __init__(self, name):
        self.name = name
        self._children = []
        self.parent = None
        
    def __repr__(self):
        return f'<Node 0x:{id(self)} name={self.name}>'

    @property
    def children(self):
        return tuple(self._children)
    
    @children.setter
    def children(self, value):
        self._children = []
        for val in value:
            self.add_child(val)

    def add_child(self, child):
        self._children.append(child)
        child.parent = self
    
    def remove_child(self, child):
        self._children.remove(child)
    
    @property
    def ancestors(self):
        node = self
        while node.parent:
            yield node.parent
            node = node.parent
    
    @property
    def depth(self):
        return len(list(self.ancestors))
    
    @property
    def is_leaf(self):
        return len(self._children) == 0
    
    @property
    def root(self):
        node = self
        while node.parent:
            node = node.parent
        return node

## [Day 1](https://adventofcode.com/2021/day/1)

In [1250]:
def count_increasing_measurements(scans):
    return sum(scan > scans[index - 1] for index, scan in enumerate(scans[1:], 1))

In [1251]:
data1 = Input(1, int)
count_increasing_measurements(data1)

1121

In [1252]:
assert _ == 1121, 'Day 1.1'

In [1253]:
chunks = [data1[i:i + 3] for i in range(len(data1) - 2)]
count_increasing_measurements(mapt(sum, chunks))

1065

In [1254]:
assert _ == 1065, 'Day 1.2'

## [Day 2](https://adventofcode.com/2021/day/2)

In [1255]:
def follow_commands(commands):
    x = d = 0
    for direction, amount in commands:
        if direction == "forward":
            x += amount
        elif direction == "down":
            d += amount
        elif direction == "up":
            d -= amount
        else:
            print("unknown command!")
    return x * d


def parse_input_2(line):
    chunks = line.split(" ")
    return chunks[0], int(chunks[1])


data2 = Input(2, parse_input_2)
follow_commands(data2)

2036120

In [1256]:
assert _ == 2036120, "Day 2.1"

In [1257]:
def follow_commands(commands):
    x = d = aim = 0
    for direction, amount in commands:
        if direction == "forward":
            x += amount
            d += aim * amount
        elif direction == "down":
            aim += amount
        elif direction == "up":
            aim -= amount
        else:
            print("unknown command!")
    return x * d


follow_commands(data2)

2015547716

In [1258]:
assert _ == 2015547716, "Day 2.2"

## [Day 3](https://adventofcode.com/2021/day/3)

Not the prettiest, but a nice example of using `zip` to unzip the string into individual digits.

In [1259]:
def find_power_consumption(report: list[list[int]]) -> int:
    digits = list(zip(*report))
    avg_bits = [round(avg(digits[bit_index])) for bit_index in range(len(digits))]

    most_common = lambda index: str(avg_bits[index])
    least_common = lambda index: str(int(not avg_bits[index]))

    gamma = "".join(mapt(most_common, range(len(digits))))
    epsilon = "".join(mapt(least_common, range(len(digits))))
    return int(gamma, 2) * int(epsilon, 2)


def parse_input(line):
    return mapt(int, line.strip())


data3 = Input(3, parse_input)
find_power_consumption(data3)

3969000

In [1260]:
assert _ == 3969000, 'Day 3.1'

Part two was fun, here we continue to abuse bools and recurse through `bit_criteria` to remove digits.

In [1261]:
def bit_criteria(
    report: list[list[int]], keep_most_common: bool, bit_index: int = 0
) -> int:
    if len(report) == 1:
        return report[0]

    digits = list(zip(*report))
    avg_bit = avg(digits[bit_index])

    # I can't be bothered to use Decimal ROUND UP here...
    most_common = 1 if avg_bit == 0.5 else round(avg_bit)

    winning_bit = most_common if keep_most_common else int(not most_common)

    return bit_criteria(
        [line for line in report if line[bit_index] == winning_bit],
        keep_most_common,
        bit_index + 1,
    )


def find_life_support_rating(report: list[list[int]]) -> int:
    oxygen_generator = bit_criteria(report[:], True)
    co2_scrubber = bit_criteria(report[:], False)

    to_int = lambda bits: int("".join(mapt(str, bits)), 2)

    return to_int(oxygen_generator) * to_int(co2_scrubber)


find_life_support_rating(data3)

4267809

In [1262]:
assert _ == 4267809, 'Day 3.2'

## [Day 4](https://adventofcode.com/2021/day/4)

Not particularly clean, and I don't really think all the sets are required, I could have simply read out rows and cols from a larger list! Oh well, did the job.

In [1263]:
BingoBoard = namedtuple("BingoBoard", "rows,cols")


def play_bingo(numbers: [int], boards: [BingoBoard], win_last: bool = False) -> int:
    # translate boards to a list of sets so we can easily remove seen numbers
    board_nums = dict(
        (boardId, [*[set(r) for r in board.rows], *[set(c) for c in board.cols]])
        for boardId, board in enumerate(boards)
    )

    # keep track of the most recent winningi board
    winning_board = None

    # track boards that are still in play
    active_boards = set(board_nums.keys())

    seen_numbers = set()
    for number in numbers:
        seen_numbers.add(number)

        for board_id in list(active_boards):
            for row_or_col_nums in board_nums[board_id]:
                row_or_col_nums -= seen_numbers
                if len(row_or_col_nums) == 0:
                    # Don't exit now as we'll need to remove this number
                    # from the row or col too!
                    winning_board = board_id

                    if board_id in active_boards:
                        active_boards.remove(board_id)

        have_a_winner = not win_last and winning_board is not None
        should_last_win = win_last and len(active_boards) == 0
        if have_a_winner or should_last_win:
            break

    if winning_board:
        remaining = set()
        for nums in board_nums[winning_board]:
            remaining |= nums
        return sum(remaining) * number

    print("No winning board found!")
    return -1


def parse_input(lines: [str]) -> [[int], [BingoBoard]]:
    numbers = []
    boards = []

    for line in lines:
        if not numbers:
            numbers = mapt(int, line.strip().split(","))
            continue

        if not line.strip():
            # add a new board
            boards.append(BingoBoard([], []))
            continue
        boards[-1].rows.append(mapt(int, line.strip().split()))

    # rotate rows into cols
    for board in boards:
        board.cols.extend((list(zip(*board.rows))))

    return (numbers, boards)


data4 = Input(4, parse_input, whole_file=True)
play_bingo(*data4)

8580

In [1264]:
assert _ == 8580, "Day 4.1"

In [1265]:
play_bingo(*data4, win_last=True)

9576

In [1266]:
assert _ == 9576, "Day 4.2"

## [Day 5](https://adventofcode.com/2021/day/5)

Slight trick on this one was to sort the inputs so we knew we'd always be _increasing_ in x or y, makes finding the delta much simpler.

In [1267]:
Point = namedtuple("Point", "x,y")


def count_overlapping_points(point_lines, horizontal_only=True):
    area = Counter()

    if horizontal_only:
        point_lines = [
            line
            for line in point_lines
            if line[0].x == line[1].x or line[0].y == line[1].y
        ]

    for start, stop in point_lines:
        if start.x == stop.x:
            delta = Point(0, 1)
        elif start.y == stop.y:
            delta = Point(1, 0)
        elif start.y < stop.y:
            delta = Point(1, 1)
        else:
            delta = Point(1, -1)

        line = [start]
        while line[-1] != stop:
            last_point = line[-1]
            next_point = Point(last_point.x + delta.x, last_point.y + delta.y)
            line.append(next_point)

        area.update(line)

    return len([val for _, val in area.items() if val >= 2])


def parse_input(line):
    chunks = line.split(" -> ")
    points = (
        Point(*mapt(int, chunks[0].split(","))),
        Point(*mapt(int, chunks[1].split(","))),
    )

    # sorting the points ensures we start from the left most one first
    return sorted(points)


data5 = Input(5, parse_input)
count_overlapping_points(data5)

6710

In [1268]:
assert _ == 6710, 'Day 5.1'

In [1269]:
count_overlapping_points(data5, horizontal_only=False)

20121

In [1270]:
assert _ == 20121, 'Day 5.2'

## Day 6

In [1271]:
def grow_lantern_fish(fish: [int], remaining_days: int):
    # first naive recursive implementation
    if remaining_days == 0:
        return fish
    
    next_generation = []
    for f in fish:
        next_val = f - 1
        # -1 here as the zeroth day counts
        if next_val == -1:
            next_generation.append(8)
            next_val = 6
        next_generation.append(next_val)
    return grow_lantern_fish(next_generation, remaining_days - 1)
        
            

data5_input = "1,5,5,1,5,1,5,3,1,3,2,4,3,4,1,1,3,5,4,4,2,1,2,1,2,1,2,1,5,2,1,5,1,2,2,1,5,5,5,1,1,1,5,1,3,4,5,1,2,2,5,5,3,4,5,4,4,1,4,5,3,4,4,5,2,4,2,2,1,3,4,3,2,3,4,1,4,4,4,5,1,3,4,2,5,4,5,3,1,4,1,1,1,2,4,2,1,5,1,4,5,3,3,4,1,1,4,3,4,1,1,1,5,4,3,5,2,4,1,1,2,3,2,4,4,3,3,5,3,1,4,5,5,4,3,3,5,1,5,3,5,2,5,1,5,5,2,3,3,1,1,2,2,4,3,1,5,1,1,3,1,4,1,2,3,5,5,1,2,3,4,3,4,1,1,5,5,3,3,4,5,1,1,4,1,4,1,3,5,5,1,4,3,1,3,5,5,5,5,5,2,2,1,2,4,1,5,3,3,5,4,5,4,1,5,1,5,1,2,5,4,5,5,3,2,2,2,5,4,4,3,3,1,4,1,2,3,1,5,4,5,3,4,1,1,2,2,1,2,5,1,1,1,5,4,5,2,1,4,4,1,1,3,3,1,3,2,1,5,2,3,4,5,3,5,4,3,1,3,5,5,5,5,2,1,1,4,2,5,1,5,1,3,4,3,5,5,1,4,3"
data5 = mapt(int, data5_input.split(','))
len(grow_lantern_fish(data5, 80))

346063

In [1272]:
assert _ == 346063, 'Day 6.1'

In [1273]:
def count_fish_in_state(fish: [int], days: int) -> int:
    # find the fish in each state, creating a list of length 8
    days_state = [fish.count(days_left) for days_left in range(9)]

    for _ in range(days):
        reproducing_fish = days_state[0]
        # remove the fish that have reproduced
        days_state = days_state[1:]
        # reset day 8 with new fish
        days_state.append(reproducing_fish)
        # move that have reproduced to day 6
        days_state[6] += reproducing_fish
    return sum(days_state)

count_fish_in_state(data5, 256)

1572358335990

In [1274]:
assert _ == 1572358335990, 'Day 6.2'

## [Day 7](https://adventofcode.com/2021/day/7)

Here, we're looking to minimise [aggregating discrepancies](https://www.johnmyleswhite.com/notebook/2013/03/22/modes-medians-and-means-an-unifying-perspective/).

For part one, we're looking at **absolute deviation**, or that the value of the discrepancy increases linearly with the distance from it. We're seeking to minimise this value, which equates to the median value of the set.

In [1275]:
def align_crab_subs(positions):
    positions = sorted(positions)
    median = positions[int(len(positions) / 2)]
    return sum(abs(pos - median) for pos in positions)


data7 = Input(7, lambda l: mapt(int, l.split(",")))[0]
align_crab_subs(data7)

341534

In [1276]:
assert _ == 341534, 'Day 7.1'

Then for part two, we're told that it's the sum of integers up to a limit, the more general form of which is `sum(1..n) -> n(n+1)/2` ([wikipedia](https://en.wikipedia.org/wiki/1_%2B_2_%2B_3_%2B_4_%2B_%E2%8B%AF)). If we expand we have `n^2 + n / 2` , because `n^2 >> n` we can make an approximation that we're looking to minimise the the average square distance, which is another way of expressing the mean.

This tells us the we can expect the answer to be around the mean value - I got lucky, as it was just the integer of the mean!

In [1277]:
def align_crab_subs2(positions):
    mean = int(avg(positions))
    return sum(int(abs(pos - mean) * (abs(pos - mean) + 1) / 2) for pos in positions)

align_crab_subs2(data7)

93397632

In [1278]:
assert _ == 93397632, 'Day 7.2'

## [Day 8](https://adventofcode.com/2021/day/8)

In [1279]:
def count_easy_digits(mappings: [[str], [str]]) -> int:
    outputs = [out for _, out in mappings]
    return sum(len(digits) in (2, 3, 4, 7) for digits in chain(*outputs))


def parse_input(line):
    inputs, outputs = line.strip().split(" | ")
    return inputs.split(" "), outputs.split(" ")


data8 = Input(8, parse_input)
count_easy_digits(data8)

362

In [1280]:
assert _ == 362, "Day 8.1"

For the second part I spent a bunch of time staring at the digits and figuring out the logical combinations for them. I did toy with creating a general solution for this but it didn't seem worth it given the number of characters!

```
  0:      1:      2:      3:      4:
 aaaa    ....    aaaa    aaaa    ....
b    c  .    c  .    c  .    c  b    c
b    c  .    c  .    c  .    c  b    c
 ....    ....    dddd    dddd    dddd
e    f  .    f  e    .  .    f  .    f
e    f  .    f  e    .  .    f  .    f
 gggg    ....    gggg    gggg    ....

  5:      6:      7:      8:      9:
 aaaa    aaaa    aaaa    aaaa    aaaa
b    .  b    .  .    c  b    c  b    c
b    .  b    .  .    c  b    c  b    c
 dddd    dddd    ....    dddd    dddd
.    f  e    f  .    f  e    f  .    f
.    f  e    f  .    f  e    f  .    f
 gggg    gggg    ....    gggg    gggg
 ```

In [1281]:
def created_digit_to_code_map(codes):
    mapping = {}
    lengths = defaultdict(list)

    for code in codes:
        lengths[len(code)].append(frozenset(code))

    # Map unqiue lengths
    mapping[1] = lengths[2][0]
    mapping[4] = lengths[4][0]
    mapping[7] = lengths[3][0]
    mapping[8] = lengths[7][0]

    # 0, 6, 9
    for code in lengths[6]:
        # 6 does not contain 1
        if not mapping[1].issubset(code):
            mapping[6] = code
        # 9 containes 4
        elif mapping[4].issubset(code):
            mapping[9] = code
        # must be 0
        else:
            mapping[0] = code

    # 2, 3, 5
    for code in lengths[5]:
        # 3 contains 1
        if mapping[1].issubset(code):
            mapping[3] = code
        # 5 is a subset of 9
        elif code.issubset(mapping[9]):
            mapping[5] = code
        # must be 2
        else:
            mapping[2] = code
    return mapping


def build_mapping(mappings: [[str], [str]]):
    total = 0
    for inputs, outputs in mappings:
        mapping = created_digit_to_code_map(inputs)

        code_to_num = dict((chars, val) for val, chars in mapping.items())
        digits = [code_to_num[frozenset(out)] for out in outputs]
        num = 0
        for out in outputs:
            num *= 10
            num += code_to_num[frozenset(out)]
        total += num
    return total


build_mapping(data8)

1020159

In [1282]:
assert _ == 1020159, 'Day 8.2'

## [Day 9](https://adventofcode.com/2021/day/9)

In [1283]:
def find_low_points(grid: [[int]]) -> [(int, int)]:
    low_points = []

    for y, row in enumerate(grid):
        for x, cell in enumerate(row):
            for (nx, ny) in neighbours4(x, y):
                if nx < 0 or ny < 0 or nx >= len(row) or ny >= len(grid):
                    continue
                if grid[ny][nx] <= cell:
                    break
            else:
                low_points.append((x, y))
    return low_points


def calclualate_risk(grid: [[int]]) -> int:
    low_points = find_low_points(grid)
    return sum(grid[y][x] + 1 for (x, y) in low_points)


def parse_data(line):
    return mapt(int, line.strip())


data9 = Input(9, parse_data)
calclualate_risk(data9)

603

In [1284]:
assert _ == 603, 'Day 9.1'

In [1285]:
def find_basins(grid: [[int]]) -> [[(int, int)]]:
    roots = find_low_points(grid)

    basins = []

    xmax = len(grid[0])
    ymax = len(grid)

    for root in roots:
        basin = [root]
        to_explore = [root]

        while to_explore:
            x, y = to_explore.pop()

            for (nx, ny) in neighbours4(x, y):
                if nx < 0 or ny < 0 or nx >= xmax or ny >= ymax:
                    # out of bounds
                    continue

                if grid[ny][nx] == 9:
                    # hit a high point, and cannot continue
                    continue

                if (nx, ny) in basin:
                    # seen before, can ignore this time
                    continue
                to_explore.append((nx, ny))
                basin.append((nx, ny))

        basins.append(basin)
    return basins


def avoid_largest_basins(grid: [[int]]) -> int:
    basins = find_basins(grid)
    basins = sorted(basins, key=lambda b: len(b), reverse=True)
    return len(basins[0]) * len(basins[1]) * len(basins[2])


avoid_largest_basins(data9)

786780

## [Day 10](https://adventofcode.com/2021/day/10)

Using a stack here makes life a lot simpler, it's a classic parsing technique.

In [1286]:
def is_valid_line(line: str) -> (bool, str):
    """
    Checks if line is valid, if it is return True and any missing
    characters, otherwise False and the first offending character.
    """
    stack = []

    brackets = {"[": "]", "(": ")", "<": ">", "{": "}"}
    close_brackets = dict((val, key) for key, val in brackets.items())

    for char in line:
        if char in brackets:
            # new opening bracket
            stack.append(char)
            continue
        if stack[-1] != close_brackets[char]:
            # got a different close bracket, this is an error
            return False, char
        stack.pop()
    return True, "".join(brackets[k] for k in stack[::-1])


def find_syntax_error_score(lines):
    scores = {
        ")": 3,
        "]": 57,
        "}": 1197,
        ">": 25137,
    }
    total_score = 0
    for line in lines:
        is_valid, error_char = is_valid_line(line)
        if not is_valid:
            total_score += scores[error_char]
    return total_score


data10 = Input(10)
find_syntax_error_score(data10)

464991

In [1287]:
assert _ == 464991, 'Day 10.1'

In [1288]:
def calculate_autocomplete_score(chars):
    scores = {
        ")": 1,
        "]": 2,
        "}": 3,
        ">": 4,
    }
    score = 0

    for char in chars:
        score *= 5
        score += scores[char]
    return score


def find_autocomplete_score(lines):
    scores = []

    for line in lines:
        is_valid, missing_chars = is_valid_line(line)
        if is_valid:
            scores.append(calculate_autocomplete_score(missing_chars))

    sorted_scores = sorted(scores)
    mid_index = int(len(sorted_scores) / 2)
    return sorted_scores[mid_index]


find_autocomplete_score(data10)

3662008566

In [1289]:
assert _ == 3662008566, 'Day 10.2'

## [Day 11](https://adventofcode.com/2021/day/11)

Today's was quite fun! The only real difficulty was reading the question.. using `>= 9` held me up for much longer than it should have!

In [1290]:
def advance_octopus_grid(octopus_grid: [[int]]) -> ([[int]], int):
    """
    Advances all octopuses state.

    Returns a new grid of octopusses and the number that flashed during
    this step.
    """
    octopus_grid = [l.copy() for l in octopus_grid]
    # first advance all values
    remaining_flashes = set()
    for x in range(10):
        for y in range(10):
            octopus_grid[y][x] += 1

            if octopus_grid[y][x] > 9:
                remaining_flashes.add((x, y))

    did_flash = set()
    while remaining_flashes:
        x, y = remaining_flashes.pop()

        if (x, y) in did_flash:
            # this octopus has flashed, ignore
            continue

        for nx, ny in neighbours8(x, y):
            if nx < 0 or nx > 9 or ny < 0 or ny > 9:
                # out of bounds, ignore
                continue

            octopus_grid[ny][nx] += 1

            if octopus_grid[ny][nx] > 9:
                remaining_flashes.add((nx, ny))

        did_flash.add((x, y))

    # For all that flashed, set their energy to 0
    for x, y in list(did_flash):
        octopus_grid[y][x] = 0

    return octopus_grid, len(did_flash)


def count_flashes(octopus_grid: [[int]], n: int) -> int:
    total_flashes = 0

    for _ in range(n):
        octopus_grid, flashes = advance_octopus_grid(octopus_grid)
        total_flashes += flashes
    return total_flashes


data11 = """1443668646
7686735716
4261576231
3361258654
4852532611
5587113732
1224426757
5155565133
6488377862
8267833811""".split(
    "\n"
)

data11 = [list(mapt(int, l)) for l in data11]

count_flashes(data11, 100)

1743

In [1291]:
assert _ == 1743

In [1292]:
def find_first_simultaneous_flash(octopus_grid: [[int]]) -> int:
    for step in count(1):
        octopus_grid, flashes = advance_octopus_grid(octopus_grid)
        if flashes == 100:
            break
    return step


find_first_simultaneous_flash(data11)

364

In [1293]:
assert _ == 364

## [Day 12](https://adventofcode.com/2021/day/12)

First graph question! Below uses an adjacency list to keep track of connected nodes, and performs a depth-first-search on the tree, recursively exploring a single path until it terminates and tracking branches as we go. There is undoubtedly an efficient technique for reusing explored paths, which I may come and rewrite at some point.

In [1294]:
def find_all_paths(tree, multi_visit=False, node="start", current_path=None):
    if current_path is None:
        current_path = ["start"]

    if node == "end":
        yield current_path

    for child_node in tree[node]:
        path = current_path[:]

        if child_node.islower() and child_node in path:
            if not multi_visit:
                # cannot revisit a small cave
                continue
            if current_path[0] == "+":
                # already visited a small cave in this path, cannot revisit
                continue
            if child_node in ("start", "end"):
                # do not allow visiting end roots more than once
                continue
            # mark this path to say we've visited a small cave already
            path.insert(0, "+")

        path.append(child_node)

        yield from find_all_paths(tree, multi_visit, child_node, path)


def count_paths(tree, multi_visit=False):
    path_count = 0
    for path in find_all_paths(tree, multi_visit):
        path_count += 1
    return path_count


def parse_input(lines):
    nodes = defaultdict(list)

    for line in lines:
        left, right = line.strip().split("-")
        # connections are bi-directional
        nodes[left].append(right)
        nodes[right].append(left)

    return nodes

In [1295]:
data12 = Input(12, parse_input, whole_file=True)
count_paths(data12, False)

3887

In [1296]:
assert _ == 3887, 'Day 12.1'

In [1297]:
%%time
count_paths(data12, True)

CPU times: user 1.81 s, sys: 4.87 ms, total: 1.82 s
Wall time: 1.82 s


104834

In [1298]:
assert _ == 104834, 'Day 12.2'

## [Day 13: Transparent Origami](https://adventofcode.com/2021/day/13)

Today's problem wasn't too tricky, the only thing to recognise here was how to perform the folding of the grid. The question tells us that we're always folding UP or LEFT, so we only need to move points that are higher in x or y respectively.

In [1299]:
def perform_fold(grid, axis, fold_point):
    # which points are being flipped, x or y?
    index = 1 if axis == "y" else 0

    new_grid = set()
    for x, y in grid:
        new_point = [x, y]

        # we only care about points on the other side of the fold
        if new_point[index] > fold_point:
            # as we're folding LEFT or UP, this is negative
            distance_to_fold = new_point[index] - fold_point
            new_point[index] -= 2 * distance_to_fold
        new_grid.add(tuple(new_point))
    return list(new_grid)


def fold_grid(grid, folds):
    for fold in folds:
        axis, fold_point = fold
        grid = perform_fold(grid, axis, fold_point)
    return grid


def print_grid(grid):
    max_x = max(x for x, _ in grid)
    max_y = max(y for _, y in grid)

    for y in range(max_y + 1):
        row = []
        for x in range(max_x + 1):
            char = "#" if (x, y) in grid else "."
            row.append(char)
        print("".join(row))


def count_points_after_folds(grid, folds, display_output=False):
    grid = fold_grid(grid, folds)
    if display_output:
        print_grid(grid)
    return len(grid)


def parse_input(lines):
    grid = []
    folds = []

    for line in lines:
        line = line.strip()

        if not line:
            continue

        if "fold" in line:
            left_chunk, value = line.split('=')
            axis = left_chunk[-1]
            folds.append((axis, int(value)))
        else:
            x, y = line.split(",")
            point = (int(x), int(y))
            grid.append(point)
    return grid, folds


data13 = Input(13, parse_input, whole_file=True)
count_points_after_folds(data13[0], data13[1][:1])

602

In [1300]:
assert _ == 602, "Day 13.1"

In [1301]:
count_points_after_folds(data13[0], data13[1], True)

.##...##..####...##.#..#.####..##..#..#
#..#.#..#.#.......#.#..#....#.#..#.#.#.
#....#..#.###.....#.####...#..#....##..
#....####.#.......#.#..#..#...#....#.#.
#..#.#..#.#....#..#.#..#.#....#..#.#.#.
.##..#..#.#.....##..#..#.####..##..#..#


92

In [1302]:
assert _ == 92, "Day 13.2"

## [Day 14](https://adventofcode.com/2021/day/14)

Another problem that cannot be brute forced, as the memory complexity gets out of hand quite quickly. I started with a simple recursive solution, not knowing that the state of the overall polymer wouldn't be required for the second part... 

In [1314]:
def grow_poylmer(polymer: str, rules: dict[str, str]) -> str:
    pairs = [a + b for a, b in zip(polymer, polymer[1:])]
    new_polymer = [polymer[0]]

    for pair in pairs:
        if pair in rules:
            new_polymer.append(rules[pair])
        new_polymer.append(pair[1])
    return "".join(new_polymer)


def polymerization(polymer: str, rules: dict[str, str]) -> int:
    for _ in range(10):
        polymer = grow_poylmer(polymer, rules)
    chars = Counter(polymer)
    counts = chars.most_common()
    return counts[0][1] - counts[-1][1]


def parse_input(lines: [str]) -> (str, dict[str, str]):
    polymer, _, *lines = lines
    rules = {}
    for line in lines:
        pair, char = line.strip().split(" -> ")
        rules[pair] = char
    return polymer, rules


data14 = Input(14, parse_input, whole_file=True)

L = """NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C""".split('\n')

data14 = parse_input(L)

polymerization(*data14)

3800

In [1304]:
assert _ == 2891, 'Day 14.1'

AssertionError: Day 14.1

Similar to the lantern fish from day 6, the brute force solution simply will not work for the second part. Again the key observation is that the order of what is growing is not required for the result, so we only need to track the pair and character frequencies.

In [1312]:
def polymerization_count(polymer: str, rules: dict[str, str], n: int=40) -> int:
    # rather than keep track of the entire polymer, instead track the growth
    pair_counts = defaultdict(int)
    char_counts = defaultdict(int)

    # initalise our tracking dicts
    for a, b in zip(polymer, polymer[1:]):
        pair_counts[a + b] += 1

    for char in polymer:
        char_counts[char] += 1

    for _ in range(n):
        current_occurances = list(pair_counts.items())
        for pair, occurance in current_occurances:
            if pair in rules:
                new_char = rules[pair]
                # all current pairs are broken...
                pair_counts[pair] -= occurance
                # ...and new ones created
                pair_counts[pair[0] + new_char] += occurance
                pair_counts[new_char + pair[1]] += occurance

                char_counts[new_char] += occurance

    return max(char_counts.values()) - min(char_counts.values())



polymerization_count(*data14)

5174673055076

In [None]:
assert _ == 4607749009683, 'Day 14.2'

## [Day 15: Chiton](https://adventofcode.com/2021/day/15)

Our first offical BFS! As we know the distance to the target state (although not the cost), we our hueristic is simply the manhatten distance from the current location to the target. 

In [None]:
def find_cheapest_path(grid):
    start = (0, 0)

    max_y = len(grid) - 1
    max_x = len(grid[0]) - 1

    def h_func(state):
        x, y = state
        return max_y - y + max_x - x

    def moves(state):
        x, y = state
        for nx, ny in neighbours4(x, y):
            if nx < 0 or ny < 0 or nx > max_x or ny > max_y:
                continue
            yield (nx, ny)

    def cost(old_state, new_state):
        ox, oy = old_state
        x, y = new_state
        return grid[oy][ox] + grid[y][x]

    cheapest_path = a_star(start, h_func, moves, cost)
    total_cost = sum(grid[y][x] for (x, y) in cheapest_path)
    return total_cost - grid[0][0]


data15 = Input(15, lambda l: mapt(int, l.strip()))
find_cheapest_path(data15)

In [None]:
assert _ == 393, "Day 15.1"

In [None]:
def find_cheapest_path_large(tile):
    # we need to build a larger grid...
    x_lim = len(tile[0])
    y_lim = len(tile)

    grid = []

    for _ in range(y_lim * 5):
        grid.append([None] * x_lim * 5)

    for y_step, y_start in enumerate(range(0, y_lim * 5, y_lim)):
        for x_step, x_start in enumerate(range(0, x_lim * 5, x_lim)):
            for dy in range(y_lim):
                for dx in range(x_lim):
                    x = x_start + dx
                    y = y_start + dy

                    original_value = tile[dy][dx]
                    new_value = original_value + x_step + y_step
                    if new_value > 9:
                        new_value = new_value % 10 + 1
                    grid[y][x] = new_value
    # now the grid has been built, find the solution!
    return find_cheapest_path(grid)


find_cheapest_path_large(data15)

In [None]:
assert _ == 2823, 'Day 15.2'

## Day 16

In [None]:
Packet = namedtuple("Packet", "version,type_id,children,val")


def hex_to_binary(hex_str):
    to_hex = lambda c: format(int(c, 16), "b").rjust(4, "0")
    return "".join(map(to_hex, hex_str))


def unpack_literal_value(packet_stream):
    unpacked_binary = ""
    while True:
        chunk = packet_stream.read(5)
        more_bit, body = chunk[0], chunk[1:]
        unpacked_binary += body
        if int(more_bit) == 0:
            break
    return int(unpacked_binary, 2)


def read_packet(packet_stream):
    version_str = packet_stream.read(3)

    if version_str == "":
        # reached the end of the stream
        return

    version = int(version_str, 2)
    type_id = int(packet_stream.read(3), 2)
    sub_packets = []
    val = None

    if type_id == 4:
        # the packet is a literal value
        val = unpack_literal_value(packet_stream)
    else:
        # this is some kind of operator
        length_type_id = int(packet_stream.read(1), 2)
        if length_type_id == 0:
            # the next 15 bits are sub packets, create a new stream and consume
            sub_packet_length = int(packet_stream.read(15), 2)
            val = sub_packet_length

            substream = io.StringIO(packet_stream.read(val))
            while True:
                packet = read_packet(substream)
                if not packet:
                    break
                sub_packets.append(packet)
        else:
            # the next N packets are subpackets
            val = int(packet_stream.read(11), 2)
            for _ in range(val):
                packet = read_packet(packet_stream)
                sub_packets.append(packet)

    return Packet(version, type_id, sub_packets, val)


def find_transmission_version_total(transmission):
    binary = hex_to_binary(transmission)
    stream = io.StringIO(binary)

    def sum_versions(packet):
        total = packet.version
        total += sum(map(sum_versions, packet.children))
        return total

    packet = read_packet(stream)
    return sum_versions(packet)


data15 = "220D6448300428021F9EFE668D3F5FD6025165C00C602FC980B45002A40400B402548808A310028400C001B5CC00B10029C0096011C0003C55003C0028270025400C1002E4F19099F7600142C801098CD0761290021B19627C1D3007E33C4A8A640143CE85CB9D49144C134927100823275CC28D9C01234BD21F8144A6F90D1B2804F39B972B13D9D60939384FE29BA3B8803535E8DF04F33BC4AFCAFC9E4EE32600C4E2F4896CE079802D4012148DF5ACB9C8DF5ACB9CD821007874014B4ECE1A8FEF9D1BCC72A293A0E801C7C9CA36A5A9D6396F8FCC52D18E91E77DD9EB16649AA9EC9DA4F4600ACE7F90DFA30BA160066A200FC448EB05C401B8291F22A2002051D247856600949C3C73A009C8F0CA7FBCCF77F88B0000B905A3C1802B3F7990E8029375AC7DDE2DCA20C2C1004E4BE9F392D0E90073D31634C0090667FF8D9E667FF8D9F0C01693F8FE8024000844688FF0900010D8EB0923A9802903F80357100663DC2987C0008744F8B5138803739EB67223C00E4CC74BA46B0AD42C001DE8392C0B0DE4E8F660095006AA200EC198671A00010E87F08E184FCD7840289C1995749197295AC265B2BFC76811381880193C8EE36C324F95CA69C26D92364B66779D63EA071008C360098002191A637C7310062224108C3263A600A49334C19100A1A000864728BF0980010E8571EE188803D19A294477008A595A53BC841526BE313D6F88CE7E16A7AC60401A9E80273728D2CC53728D2CCD2AA2600A466A007CE680E5E79EFEB07360041A6B20D0F4C021982C966D9810993B9E9F3B1C7970C00B9577300526F52FCAB3DF87EC01296AFBC1F3BC9A6200109309240156CC41B38015796EABCB7540804B7C00B926BD6AC36B1338C4717E7D7A76378C85D8043F947C966593FD2BBBCB27710E57FDF6A686E00EC229B4C9247300528029393EC3BAA32C9F61DD51925AD9AB2B001F72B2EE464C0139580D680232FA129668"
find_transmission_version_total(data15)

In [None]:
assert _ == 1012, 'Day 15.1'

In [None]:
def perform_transmssion(transmission):
    binary = hex_to_binary(transmission)
    stream = io.StringIO(binary)    
    root_packet = read_packet(stream)
    
    def unpack_packet(packet):
        unpacked_children = [unpack_packet(child) for child in packet.children]
        type_operations = {
            0: lambda: sum(unpacked_children),
            1: lambda: reduce(op.mul, unpacked_children),
            2: lambda: min(unpacked_children),
            3: lambda: max(unpacked_children),
            4: lambda: packet.val,
            5: lambda: int(op.gt(*unpacked_children)),
            6: lambda: int(op.lt(*unpacked_children)),
            7: lambda: int(op.eq(*unpacked_children))
        }
        return type_operations[packet.type_id]()    
    return unpack_packet(root_packet)
            
    

perform_transmssion(data15)

In [None]:
assert _ == 2223947372407, 'Day 16.2'

## [Day 17: Trick Shot](https://adventofcode.com/2021/day/17)

This is a really nasty solution that just brute-forces the solution by testing a wide range of possible values. Absolutely zero elegance here at all and probably my least favourite solution to AoC ever!

In [None]:
def exhaust_until_miss(dx, dy, target):
    x = y = max_y = 0
    for _ in range(1000):
        x += dx
        y += dy

        max_y = max(max_y, y)

        if (
            target[0][0] <= x
            and x <= target[0][1]
            and target[1][0] <= y
            and y <= target[1][1]
        ):
            return max_y
        elif x > target[0][1] or y < target[1][0]:
            # have overshot, can break early
            break

        dy -= 1
        dx -= dx > 0 - dx < 0

    return -1


def trickshot(target: [[int, int], [int, int]]) -> int:
    winner = -1
    total = 0

    # we know the target is max -248 in y, so any more and we immediately overshoot
    for dy in range(-250, 1000):
        has_hit = False
        # we know that the target is to the right, so there's no point looking at negative values of dx
        for dx in range(1000):
            max_y = exhaust_until_miss(dx, dy, target)
            lands_in_target = max_y >= 0
            winner = max(winner, max_y)
            total += lands_in_target

            if not has_hit and lands_in_target:
                has_hit = True
            elif has_hit and not lands_in_target:
                # we will never land in the target for further values of dx
                break

    return winner, total


data17 = ((29, 73), (-248, -194))

answers = trickshot(data17)
print(f"Part 1: {answers[0]}")
print(f"Part 2: {answers[1]}")

## [Day 18: Snailfish](https://adventofcode.com/2021/day/18)


What a problem! This was really fun once I spotted the continual pre-order traversal, but it took me quite a while to actually conceptualise that.

In [None]:
class SnailFishNode(Node):
    def __add__(self, other: SnailFishNode) -> SnailFishNode:
        """
        Support a simple node1 + node2 operation.
        """
        root = SnailFishNode(name=None)

        root.add_child(self.copy())
        root.add_child(other.copy())

        # keep reducing the tree until no changes occur
        while root._reduce():
            pass
        return root

    def __str__(self) -> str:
        """Utility for viewing the tree as Advent of Code displays it."""
        if self.is_leaf:
            return str(self.name)
        return "[" + ",".join(str(child) for child in self.children) + "]"

    def copy(self) -> SnailFishNode:
        """Create a fresh copy of this SnailFishNode."""
        new_node = SnailFishNode(self.name)

        for child in self.children:
            new_node.add_child(child.copy())

        return new_node

    def _reduce(self) -> bool:
        """
        Perform a single reduction and return if was succesful or not.
        """
        for node in visit_pre_order(self):
            if node.depth == 4 and not node.is_leaf:
                node._explode()
                return True

        for node in visit_pre_order(self):
            if node.is_leaf and node.name >= 10:
                node._split()
                return True

        return False

    def _explode(self) -> None:
        nodes = visit_pre_order(self.root)

        previous_leaf_node = None
        next_leaf_node = None

        for node in nodes:
            if node == self:
                # find the next leaf that are NOT under this node
                # advance past the two children
                next(nodes)
                next(nodes)

                for node in nodes:
                    if node.is_leaf:
                        next_leaf_node = node
                        break
                break

            if node.is_leaf:
                previous_leaf_node = node

        # add current children to nearest neighbours...
        if previous_leaf_node:
            previous_leaf_node.name += self.children[0].name
        if next_leaf_node:
            next_leaf_node.name += self.children[1].name
        # ...and update current node
        self.name = 0
        self.children = []

    def _split(self) -> None:
        value = self.name
        left = SnailFishNode(math.floor(value / 2))
        right = SnailFishNode(math.ceil(value / 2))

        self.children = [left, right]
        self.name = None

    @property
    def magnitude(self) -> int:
        total = 0

        if self.is_leaf:
            return self.name

        total += 3 * self.children[0].magnitude
        total += 2 * self.children[1].magnitude

        return total


def visit_pre_order(node: Node) -> Iterator[Node]:
    """Visit a Tree in pre-order.
    
      A
     / \
    B   C
    
    A -> B -> C
    """
    yield node
    for child in node.children:
        yield from visit_pre_order(child)


def snailfish_to_tree(snailfish_number, parent=None):
    if parent is None:
        parent = SnailFishNode(None)

    for child in snailfish_number:
        if isinstance(child, list):
            child_node = SnailFishNode(None)
            snailfish_to_tree(child, child_node)
        else:
            child_node = SnailFishNode(child)
        parent.add_child(child_node)

    return parent


# hmmm using eval to parse test input. Lovely.
parse_input = lambda l: snailfish_to_tree(eval(l))
data18 = Input(18, parse_input)


def reduce_trees(trees, verbose=True) -> SnailFishNode:
    reduced_tree = trees[0]
    for tree in trees[1:]:
        if verbose:
            print(f"  {t}")
            print(f"+ {tree}")

        reduced_tree += tree

        if verbose:
            print(f"= {reduced_tree}\n")
    return reduced_tree

Lets check that the reduction is being applied correctly..

In [None]:
test_input = """[[[0,[4,5]],[0,0]],[[[4,5],[2,6]],[9,5]]]
[7,[[[3,7],[4,3]],[[6,3],[8,8]]]]
[[2,[[0,8],[3,4]]],[[[6,7],1],[7,[1,6]]]]
[[[[2,4],7],[6,[0,5]]],[[[6,8],[2,8]],[[2,1],[4,5]]]]
[7,[5,[[3,8],[1,4]]]]
[[2,[2,2]],[8,[8,1]]]
[2,9]
[1,[[[9,3],9],[[9,0],[0,7]]]]
[[[5,[7,4]],7],1]
[[[[4,2],2],6],[8,7]]""".split('\n')

reduce_trees(mapt(parse_input, test_input), True).magnitude

In [None]:
reduce_trees(data18, False).magnitude

In [None]:
assert _ == 3793, 'Day 18.1'

In [None]:
%%time
def find_largest_sum(trees):
    winner =  0
    
    for idx, left in enumerate(trees[:-2]):
        for right in trees[idx + 1:]:
            winner = max(winner, (left + right).magnitude)
            winner = max(winner, (right + left).magnitude)
    return winner

find_largest_sum(data18)

In [None]:
assert _ == 4695, 'Day 18.2'

## [Day 19: Beacon Scanner](https://adventofcode.com/2021/day/19)

Well this one was pretty horrific. Rather than cycle through each permutation I did something like this:

1. Each scanner has a series of beacons it’s scanned. For each beacon find the manhatten distance to each other beacon. This gives us an orientation agnostic layout.
2. Compare these distances between different scanners, those that share >= 12 distances overlap.
3. Of those that overlap, compare the  points that share a distance. Treat one as fixed, and inspect the other, identify the correct order of the points against the first and the signs those points need.
4. Update all points in the second with the new order and signs
5. The absolute location of this scanner can be found by comparing it to the first.


In [None]:
%%time
class Scanner:
    def __init__(self, location):
        self._points = []
        self.location = location

    @property
    def points(self):
        return tuple(self._points)

    @points.setter
    def points(self, points):
        self._points = points

    def add_point(self, scan: [int, int, int]):
        self._points.append(scan)

    @property
    def distances(self):
        """Manhatten distances from one point to every other point."""
        for idx, point in enumerate(self.points):
            distances = []
            for other_point in self.points:
                dist = manhatten_distance(point, other_point)
                distances.append(dist)
            yield idx, distances

    def find_overlaps(self, other_scanner):
        for scan1_point_idx, distances in self.distances:
            for scan2_point_idx, other_distances in other_scanner.distances:
                overlaps = set(distances) & set(other_distances)
                if len(overlaps) >= 12:
                    yield scan1_point_idx, scan2_point_idx, overlaps


def find_scanner_locations(scanners):
    scanners[0].location = (0, 0, 0)

    # loop through each known scanner, looking for overlaps until
    # we've found them all
    while [s for s in scanners if s.location is None]:
        for scanner_idx, scanner in enumerate(scanners):
            if scanner.location is None:
                # don't know where this one is, so can't use it to find another..
                continue

            for other_scanner_idx, other_scanner in enumerate(scanners):
                if other_scanner.location:
                    # already know this one's location, can ignore
                    continue

                overlaps = list(scanner.find_overlaps(other_scanner))
                if overlaps:
                    # these scanners intersect, we can reorient one to the other
                    relative_location = reorient_scanner(
                        scanner, other_scanner, overlaps
                    )
                    other_scanner.location = add_points(
                        scanner.location, relative_location
                    )


def reorient_scanner(scanner, other_scanner, overlaps):
    # we have to beacons and a set of distances that the overlap
    # to with in their own scanners. if we convert the distances back
    # to the beacons they represent, we can find the ordering
    for scan1_point_idx, scan2_point_idx, overlap in overlaps:
        diffs1 = []
        diffs2 = []

        # find the points that share this distance
        beacon_1 = scanner.points[scan1_point_idx]
        beacon_2 = other_scanner.points[scan2_point_idx]

        for distance in overlap:
            if distance == 0:
                continue

            for p in scanner.points:
                if manhatten_distance(beacon_1, p) == distance:
                    diffs1.append(sub_points(beacon_1, p))
                    break
            for p in other_scanner.points:
                if manhatten_distance(beacon_2, p) == distance:
                    diffs2.append(sub_points(beacon_2, p))
                    break

        # we have two coordinates that should share the same absolute values in different
        # orders and signs
        for diff1, diff2 in zip(diffs1, diffs2):
            ordering = []
            signs = []

            for p in diff1:
                try:
                    position = [abs(d) for d in diff2].index(abs(p))
                except ValueError:
                    # not sure why this happens.. it shouldn't :|
                    continue
                ordering.append(position)
                signs.append(p // diff2[position])

            # make sure that we've seen three unique values
            if len(set(ordering)) < 3:
                continue

            break
        break

    # we should have the ordering and signs of the second points in relation to the first
    # so let's find the relative location of the second scanner and update it's points
    new_points = []
    for p in other_scanner.points:
        new_points.append(tuple(p[idx] * sign for idx, sign in zip(ordering, signs)))
    other_scanner.points = new_points

    relative_location = sub_points(
        scanner.points[scan1_point_idx], other_scanner.points[scan2_point_idx]
    )
    return relative_location


def manhatten_distance(p1, p2):
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) + abs(p1[2] - p2[2])


def sub_points(p1, p2):
    return tuple(q1 - q2 for q1, q2 in zip(p1, p2))


def add_points(p1, p2):
    return tuple(q1 + q2 for q1, q2 in zip(p1, p2))


def parse_input(lines):
    scanners = []
    for line in lines:
        if not line.strip():
            continue
        if "scanner" in line:
            scanners.append(Scanner(None))
            continue
        point = tuple(int(c) for c in line.split(","))
        scanners[-1].add_point(point)
    return scanners


data19 = Input(19, parse_input, True)


def find_beacons(scanners):
    find_scanner_locations(scanners)
    beacons = set()
    for scanner in scanners:
        for point in scanner.points:
            beacons.add(add_points(point, scanner.location))
    beacons = sorted(list(beacons))
    return beacons


def count_beacons(scanners):
    beacons = find_beacons(scanners)
    return len(beacons)


count_beacons(data19)

In [None]:
assert _ == 457, 'Day 19.1'

In [None]:
def find_largest_distance(scanners):
    locations = [s.location for s in scanners]
    largest = 0
    for i, s1 in enumerate(scanners):
        for s2 in scanners[i + 1 :]:
            largest = max(largest, manhatten_distance(s1.location, s2.location))
    return largest


find_largest_distance(data19)

In [None]:
assert _ == 13243, 'Day 19.2'

## [Day 20: Trench Map](https://adventofcode.com/2021/day/20)

This one was mean.

The insight here is the key itself. For our input, the zeroth bit is ON, and the 511th bit OFF. This means that every value in our infinite grid that doesn't touch our active area will translate to ON. So, the first time the area expands, every point will along the boundary be ON. Now, the next time we expand, every bit that that contributes to a value will be ON, or the 511th bit in our key. Thus, it will translate to OFF.

This just means that every time we expand, the border is ON, then the next time, it is OFF.

In [None]:
def print_image(on_pixels):
    x_min, x_max, y_min, y_max = limits(on_pixels)
    for y in range(y_min - 1, y_max + 2):
        for x in range(x_min - 1, x_max + 2):
            c = "#" if (x, y) in on_pixels else "."
            print(c, end="")
        print("")
    print("")


def limits(pixels):
    return (
        min(p[0] for p in pixels),
        max(p[0] for p in pixels),
        min(p[1] for p in pixels),
        max(p[1] for p in pixels),
    )


def enhance_image(on_pixels, key, iteration):
    new_on_pixels = set()

    x_min, x_max, y_min, y_max = limits(on_pixels)

    modifable_area = {
        (x, y) for x in range(x_min - 1, x_max + 2) for y in range(y_min - 1, y_max + 2)
    }
    active_area = {
        (x, y) for x in range(x_min, x_max + 1) for y in range(y_min, y_max + 1)
    }

    boundary_bit = key[0] if iteration % 2 == 1 else "."

    for x, y in modifable_area:
        offset = read_image_int(on_pixels, x, y, active_area, boundary_bit)
        new_pixel = key[offset]
        if new_pixel == "#":
            new_on_pixels.add((x, y))
    return new_on_pixels


def get_neighbours(x, y):
    neighbours = list(neighbours8(x, y))
    neighbours.insert(4, (x, y))
    return neighbours


def read_image_int(on_pixels, x, y, active_area, boundary_bit):
    chars = ""
    for p in get_neighbours(x, y):
        if boundary_bit == "#":
            c = "1" if p in on_pixels or p not in active_area else "0"
        else:
            c = "1" if p in on_pixels else "0"
        chars += c

    return int(chars, 2)


def parse_input(lines):
    key, _, *pixels = lines
    on_pixels = set()
    for y, line in enumerate(pixels):
        for x, char in enumerate(line):
            if char == "#":
                on_pixels.add((x, y))
    return on_pixels, key


data20 = Input(20, parse_input, True)
pixels, key = data20


def enhance_n_times(pixels, key, n):
    for i in range(n):
        pixels = enhance_image(pixels, key, i)
    return len(pixels)


enhance_n_times(pixels, key, 2)

In [None]:
assert _ == 5583, 'Day 20.1'

In [None]:
enhance_n_times(pixels, key, 50)

In [None]:
assert _ == 19592, 'Day 20.2'