# [Advent of Code 2021](https://adventofcode.com/2021)

# The toolbox


Generalised pieces of code that can either be used in multiple questions or that simply make the implementation easier to understand.

In [3]:
from collections import defaultdict, Counter, namedtuple
from itertools import chain, count
import operator


def Input(day, parser=str.strip, whole_file=False):
    """Load the puzzle input for ``day`` from disk.

    With ``whole_file`` set, the open file handle is handed to ``parser`` in
    one go and its result is returned. Otherwise ``parser`` is applied to
    every line of the file and the results come back as a tuple.
    """
    path = f"../data/advent2021/input{day}.txt"
    with open(path) as handle:
        parsed = parser(handle) if whole_file else mapt(parser, handle)
    return parsed


def mapt(fn, *args):
    """Apply ``fn`` across the given iterable(s) and return the results as a tuple."""
    results = map(fn, *args)
    return tuple(results)


def avg(n):
    """Return the arithmetic mean of the sized collection ``n``.

    Raises ZeroDivisionError when ``n`` is empty.
    """
    return sum(n) / len(n)


# (dx, dy) offsets to the eight cells surrounding a grid position, diagonals
# included. The layout mirrors the relative position of each neighbour.
NEIGHBOUR8_DELTAS = (
    (-1, -1), ( 0, -1), (1, -1),
    (-1,  0),           (1,  0),
    (-1,  1), ( 0,  1), (1,  1),
)

# (dx, dy) offsets to the four orthogonally adjacent cells only.
NEIGHBOUR4_DELTAS = (
              ( 0, -1),
    (-1,  0),           (1,  0),
              ( 0,  1),
)


def neighbours4(x, y):
    """Return the four orthogonal neighbours of ``(x, y)`` as a tuple of coordinates.

    No bounds checking is done here; callers discard off-grid positions.
    """
    adjacent = []
    for dx, dy in NEIGHBOUR4_DELTAS:
        adjacent.append((x + dx, y + dy))
    return tuple(adjacent)


def neighbours8(x, y):
    """Return the eight surrounding neighbours of ``(x, y)``, diagonals included.

    No bounds checking is done here; callers discard off-grid positions.
    """
    surrounding = []
    for dx, dy in NEIGHBOUR8_DELTAS:
        surrounding.append((x + dx, y + dy))
    return tuple(surrounding)

## [Day 1](https://adventofcode.com/2021/day/1)

In [4]:
def count_increasing_measurements(scans):
    """Count how many measurements are strictly larger than the one before them."""
    consecutive = zip(scans, scans[1:])
    return sum(current > previous for previous, current in consecutive)

In [5]:
data1 = Input(1, int)
count_increasing_measurements(data1)

1121

In [6]:
assert _ == 1121, 'Day 1.1'

In [7]:
# Build every sliding window of three consecutive measurements, then compare
# the window sums in the same way as the single measurements in part one.
chunks = [data1[i:i + 3] for i in range(len(data1) - 2)]
count_increasing_measurements(mapt(sum, chunks))

1065

In [8]:
assert _ == 1065, 'Day 1.2'

## [Day 2](https://adventofcode.com/2021/day/2)

In [9]:
def follow_commands(commands):
    """Follow ``(direction, amount)`` commands; return horizontal position times depth.

    "forward" moves horizontally, while "down" and "up" change the depth
    directly. Any other direction is reported on stdout and skipped.
    """
    horizontal = 0
    depth = 0
    for direction, amount in commands:
        if direction == "up":
            depth -= amount
        elif direction == "down":
            depth += amount
        elif direction == "forward":
            horizontal += amount
        else:
            print("unknown command!")
    return horizontal * depth


def parse_input_2(line):
    """Split a ``"<direction> <amount>"`` line into ``(direction, int(amount))``."""
    tokens = line.split(" ")
    direction = tokens[0]
    amount = int(tokens[1])
    return direction, amount


data2 = Input(2, parse_input_2)
follow_commands(data2)

2036120

In [10]:
assert _ == 2036120, "Day 2.1"

In [11]:
def follow_commands(commands):
    """Follow ``(direction, amount)`` commands using an aim; return position times depth.

    "down" and "up" only adjust the aim. "forward" moves horizontally and
    changes the depth by ``aim * amount``. Any other direction is reported
    on stdout and skipped.
    """
    horizontal = 0
    depth = 0
    aim = 0
    for direction, amount in commands:
        if direction == "up":
            aim -= amount
        elif direction == "down":
            aim += amount
        elif direction == "forward":
            horizontal += amount
            depth += aim * amount
        else:
            print("unknown command!")
    return horizontal * depth


follow_commands(data2)

2015547716

In [12]:
assert _ == 2015547716, "Day 2.2"

## [Day 3](https://adventofcode.com/2021/day/3)

Not the prettiest, but a nice example of using `zip` to unzip the string into individual digits.

In [13]:
def find_power_consumption(report: list[list[int]]) -> int:
    """Return gamma times epsilon for a report of equal-length bit rows.

    Gamma is built from the rounded average of each bit column and epsilon
    from the inverse of those bits. Rounding is done with ``round``, so an
    exact 0.5 average rounds to 0.
    """
    columns = list(zip(*report))
    rounded_bits = [round(avg(column)) for column in columns]

    gamma = "".join(str(bit) for bit in rounded_bits)
    epsilon = "".join(str(int(not bit)) for bit in rounded_bits)
    return int(gamma, 2) * int(epsilon, 2)


def parse_input(line):
    """Turn one line of digits into a tuple of ints, ignoring surrounding whitespace."""
    digits = line.strip()
    return mapt(int, digits)


data3 = Input(3, parse_input)
find_power_consumption(data3)

3969000

In [14]:
assert _ == 3969000, 'Day 3.1'

Part two was fun, here we continue to abuse bools and recurse through `bit_criteria` to remove digits.

In [15]:
def bit_criteria(
    report: list[list[int]], keep_most_common: bool, bit_index: int = 0
) -> list[int]:
    """Filter ``report`` bit by bit until a single row is left and return it.

    Starting at ``bit_index``, only the rows whose bit at the current
    position equals the winning bit are kept, then the next position is
    examined. The winning bit is the most common one when
    ``keep_most_common`` is true (a tie counts as 1), otherwise the least
    common one (a tie counts as 0).

    Raises:
        ValueError: if no row is left to choose from.
        IndexError: if the rows run out of bits before one row remains.
    """
    candidates = list(report)

    while len(candidates) > 1:
        # Count the ones with integers so a tie is detected exactly, without
        # relying on float averages or on how round() treats 0.5.
        ones = sum(row[bit_index] for row in candidates)
        most_common = 1 if 2 * ones >= len(candidates) else 0
        winning_bit = most_common if keep_most_common else 1 - most_common

        candidates = [row for row in candidates if row[bit_index] == winning_bit]
        bit_index += 1

    if not candidates:
        raise ValueError("no report line satisfies the bit criteria")
    return candidates[0]


def find_life_support_rating(report: list[list[int]]) -> int:
    """Return the oxygen generator rating multiplied by the CO2 scrubber rating.

    Both ratings are the rows selected by ``bit_criteria`` (keeping the most
    and the least common bits respectively), read as binary numbers.
    """

    def as_number(bits):
        return int("".join(mapt(str, bits)), 2)

    oxygen_row = bit_criteria(report[:], True)
    co2_row = bit_criteria(report[:], False)
    return as_number(oxygen_row) * as_number(co2_row)


find_life_support_rating(data3)

4267809

In [16]:
assert _ == 4267809, 'Day 3.2'

## [Day 4](https://adventofcode.com/2021/day/4)

Not particularly clean, and I don't really think all the sets are required, I could have simply read out rows and cols from a larger list! Oh well, did the job.

In [17]:
BingoBoard = namedtuple("BingoBoard", "rows,cols")


def play_bingo(numbers: [int], boards: [BingoBoard], win_last: bool = False) -> int:
    """Play bingo and return the score of the deciding board.

    Args:
        numbers: the numbers in the order they are drawn.
        boards: the boards in play, each carrying its rows and its columns.
        win_last: when false the first board to complete a line decides the
            game; when true the draw continues and the last board to
            complete a line decides it.

    Returns:
        The sum of the deciding board's unmarked numbers multiplied by the
        number that completed it, or -1 (after printing a message) when no
        board ever completes a line.
    """
    # Every row and column becomes its own set of still-unmarked numbers, so
    # marking is a discard and a completed line is simply an empty set.
    unmarked = {
        board_id: [set(line) for line in (*board.rows, *board.cols)]
        for board_id, board in enumerate(boards)
    }

    # Boards that have not completed a line yet.
    active_boards = set(unmarked)

    # Both are compared against None throughout: board id 0 is a valid
    # winner even though it is falsy.
    winning_board = None
    winning_number = None

    for number in numbers:
        for board_id in sorted(active_boards):
            completed = False
            for line in unmarked[board_id]:
                # Keep going after a line completes: the number has to be
                # cleared from the board's other lines before scoring.
                line.discard(number)
                if not line:
                    completed = True

            if completed:
                winning_board = board_id
                winning_number = number
                active_boards.remove(board_id)

        if winning_board is not None and (not win_last or not active_boards):
            break

    if winning_board is None:
        print("No winning board found!")
        return -1

    remaining = set().union(*unmarked[winning_board])
    return sum(remaining) * winning_number


def parse_input(lines: [str]) -> [[int], [BingoBoard]]:
    """Parse the drawn numbers and the bingo boards from the input lines.

    The first line holds the comma separated numbers. After that every blank
    line opens a new board and every other line adds a whitespace separated
    row to the most recent board. Columns are derived from the rows at the end.
    """
    draws = []
    cards = []

    for raw in lines:
        text = raw.strip()
        if not draws:
            draws = mapt(int, text.split(","))
        elif text:
            cards[-1].rows.append(mapt(int, text.split()))
        else:
            # a blank line announces the next board
            cards.append(BingoBoard([], []))

    # transposing the rows yields the columns
    for card in cards:
        card.cols.extend(zip(*card.rows))

    return draws, cards


data4 = Input(4, parse_input, whole_file=True)
play_bingo(*data4)

8580

In [18]:
assert _ == 8580, "Day 4.1"

In [19]:
play_bingo(*data4, win_last=True)

9576

In [20]:
assert _ == 9576, "Day 4.2"

## [Day 5](https://adventofcode.com/2021/day/5)

Slight trick on this one was to sort the inputs so we knew we'd always be _increasing_ in x or y, makes finding the delta much simpler.

In [21]:
Point = namedtuple("Point", "x,y")


def count_overlapping_points(point_lines, horizontal_only=True):
    """Count the grid points that are covered by at least two lines.

    Args:
        point_lines: iterable of ``(start, stop)`` pairs whose items expose
            ``x`` and ``y``. The endpoints may be given in either order.
        horizontal_only: when true, only horizontal and vertical lines are
            considered and every diagonal line is skipped.

    Returns:
        The number of points covered by two or more lines.

    Raises:
        ValueError: if a diagonal line is considered that is not at 45
            degrees, since it cannot be walked in whole grid steps.
    """
    area = Counter()

    for start, stop in point_lines:
        span_x = stop.x - start.x
        span_y = stop.y - start.y

        if span_x and span_y:
            if horizontal_only:
                continue
            if abs(span_x) != abs(span_y):
                raise ValueError(f"line {start} -> {stop} is not at 45 degrees")

        # The sign of each span gives the direction to step in, so the walk
        # always ends on `stop` regardless of the order of the endpoints.
        step_x = (span_x > 0) - (span_x < 0)
        step_y = (span_y > 0) - (span_y < 0)
        steps = max(abs(span_x), abs(span_y))

        area.update(
            Point(start.x + offset * step_x, start.y + offset * step_y)
            for offset in range(steps + 1)
        )

    return sum(1 for hits in area.values() if hits >= 2)


def parse_input(line):
    """Parse ``"x1,y1 -> x2,y2"`` into a sorted list holding the two Points."""
    chunks = line.split(" -> ")
    endpoints = [Point(*mapt(int, chunks[index].split(","))) for index in (0, 1)]

    # sorted order puts the left-most point first
    endpoints.sort()
    return endpoints


data5 = Input(5, parse_input)
count_overlapping_points(data5)

6710

In [22]:
assert _ == 6710, 'Day 5.1'

In [23]:
count_overlapping_points(data5, horizontal_only=False)

20121

In [24]:
assert _ == 20121, 'Day 5.2'

## [Day 6](https://adventofcode.com/2021/day/6)

In [29]:
def grow_lantern_fish(fish: [int], remaining_days: int):
    """Simulate every lantern fish individually and return the final timers.

    Each day every timer drops by one. A fish whose timer is 0 instead adds
    a new fish with timer 8 and resets its own timer to 6. The population is
    advanced in a loop rather than by recursion, so the number of days is
    not limited by the interpreter's recursion depth.

    Args:
        fish: the current timer of every fish.
        remaining_days: how many days to simulate.

    Returns:
        ``fish`` itself when ``remaining_days`` is 0, otherwise a new list
        of timers.

    Raises:
        ValueError: if ``remaining_days`` is negative.
    """
    if remaining_days < 0:
        raise ValueError("remaining_days must not be negative")

    for _ in range(remaining_days):
        next_generation = []
        for timer in fish:
            if timer == 0:
                # the zeroth day counts: spawn a new fish and start over
                next_generation.append(8)
                next_generation.append(6)
            else:
                next_generation.append(timer - 1)
        fish = next_generation
    return fish
        
            

data5_input = "1,5,5,1,5,1,5,3,1,3,2,4,3,4,1,1,3,5,4,4,2,1,2,1,2,1,2,1,5,2,1,5,1,2,2,1,5,5,5,1,1,1,5,1,3,4,5,1,2,2,5,5,3,4,5,4,4,1,4,5,3,4,4,5,2,4,2,2,1,3,4,3,2,3,4,1,4,4,4,5,1,3,4,2,5,4,5,3,1,4,1,1,1,2,4,2,1,5,1,4,5,3,3,4,1,1,4,3,4,1,1,1,5,4,3,5,2,4,1,1,2,3,2,4,4,3,3,5,3,1,4,5,5,4,3,3,5,1,5,3,5,2,5,1,5,5,2,3,3,1,1,2,2,4,3,1,5,1,1,3,1,4,1,2,3,5,5,1,2,3,4,3,4,1,1,5,5,3,3,4,5,1,1,4,1,4,1,3,5,5,1,4,3,1,3,5,5,5,5,5,2,2,1,2,4,1,5,3,3,5,4,5,4,1,5,1,5,1,2,5,4,5,5,3,2,2,2,5,4,4,3,3,1,4,1,2,3,1,5,4,5,3,4,1,1,2,2,1,2,5,1,1,1,5,4,5,2,1,4,4,1,1,3,3,1,3,2,1,5,2,3,4,5,3,5,4,3,1,3,5,5,5,5,2,1,1,4,2,5,1,5,1,3,4,3,5,5,1,4,3"
# NOTE(review): despite the names, this is the Day 6 puzzle input. Rebinding
# `data5` here replaces the Day 5 data that an earlier cell parsed.
data5 = mapt(int, data5_input.split(','))
len(grow_lantern_fish(data5, 80))

346063

In [None]:
assert _ == 346063, 'Day 6.1'

In [31]:
def count_fish_in_state(fish: [int], days: int) -> int:
    """Return the number of lantern fish after ``days`` days.

    Only the number of fish per timer value is tracked, in a list of nine
    buckets for the timers 0 through 8.
    """
    timers = [fish.count(timer) for timer in range(9)]

    for _ in range(days):
        # split off the fish that reproduce today; the rest move down a day
        spawning, *shifted = timers
        # the parents restart their cycle on day 6...
        shifted[6] += spawning
        # ...and their offspring enter on day 8
        timers = shifted + [spawning]
    return sum(timers)

count_fish_in_state(data5, 256)

1572358335990

In [34]:
assert _ == 1572358335990, 'Day 6.2'

## [Day 7](https://adventofcode.com/2021/day/7)

Here, we're looking to minimise [aggregating discrepancies](https://www.johnmyleswhite.com/notebook/2013/03/22/modes-medians-and-means-an-unifying-perspective/).

For part one, we're looking at **absolute deviation**, or that the value of the discrepancy increases linearly with the distance from it. We're seeking to minimise this value, which equates to the median value of the set.

In [35]:
def align_crab_subs(positions):
    """Return the total distance needed to move every position onto the median."""
    ordered = sorted(positions)
    target = ordered[len(ordered) // 2]

    total = 0
    for position in ordered:
        total += abs(position - target)
    return total


data7 = Input(7, lambda l: mapt(int, l.split(",")))[0]
align_crab_subs(data7)

341534

In [36]:
assert _ == 341534, 'Day 7.1'

Then for part two, we're told that the cost is the sum of the integers up to the distance moved, the general form of which is `sum(1..n) -> n(n+1)/2` ([wikipedia](https://en.wikipedia.org/wiki/1_%2B_2_%2B_3_%2B_4_%2B_%E2%8B%AF)). If we expand this we have `(n^2 + n) / 2`; because `n^2 >> n` we can approximate the problem as minimising the average squared distance, which is minimised by the mean.

This tells us that we can expect the answer to be around the mean value - I got lucky, as it was just the integer part of the mean!

In [37]:
def align_crab_subs2(positions):
    """Return the cheapest total fuel when moving ``d`` steps costs ``1 + 2 + ... + d``.

    The best target lies within half a unit of the mean, but not always at
    its integer part, so the few integers around the mean are all evaluated
    and the cheapest one wins. All arithmetic stays in integers.

    Args:
        positions: a non-empty sequence of integer positions.

    Raises:
        ZeroDivisionError: if ``positions`` is empty.
    """

    def fuel(target):
        return sum(
            abs(pos - target) * (abs(pos - target) + 1) // 2 for pos in positions
        )

    floor_mean = sum(positions) // len(positions)
    # floor(mean) - 1 .. floor(mean) + 2 covers every integer that can lie
    # within half a unit of the mean, on either side of it.
    return min(fuel(target) for target in range(floor_mean - 1, floor_mean + 3))


def non_linear_cost(crab_position, target_position):
    """Return the fuel needed to move between the two positions.

    Moving ``d`` steps costs ``1 + 2 + ... + d``, which is computed with the
    closed form ``d * (d + 1) / 2`` instead of summing a range.
    """
    distance_required = abs(crab_position - target_position)
    return distance_required * (distance_required + 1) // 2


align_crab_subs2(data7)

93397632

In [38]:
assert _ == 93397632, 'Day 7.2'

## [Day 8](https://adventofcode.com/2021/day/8)

In [39]:
def count_easy_digits(mappings: [[str], [str]]) -> int:
    """Count the output patterns whose length is 2, 3, 4 or 7.

    ``mappings`` is a collection of ``(inputs, outputs)`` pairs; only the
    outputs are inspected.
    """
    easy_lengths = (2, 3, 4, 7)
    return sum(
        1
        for _, outputs in mappings
        for pattern in outputs
        if len(pattern) in easy_lengths
    )


def parse_input(line):
    """Split ``"<inputs> | <outputs>"`` into two lists of space separated patterns."""
    inputs, outputs = (half.split(" ") for half in line.strip().split(" | "))
    return inputs, outputs


data8 = Input(8, parse_input)
count_easy_digits(data8)

362

In [40]:
assert _ == 362, "Day 8.1"

For the second part I spent a bunch of time staring at the digits and figuring out the logical combinations for them. I did toy with creating a general solution for this but it didn't seem worth it given the number of characters!

```
  0:      1:      2:      3:      4:
 aaaa    ....    aaaa    aaaa    ....
b    c  .    c  .    c  .    c  b    c
b    c  .    c  .    c  .    c  b    c
 ....    ....    dddd    dddd    dddd
e    f  .    f  e    .  .    f  .    f
e    f  .    f  e    .  .    f  .    f
 gggg    ....    gggg    gggg    ....

  5:      6:      7:      8:      9:
 aaaa    aaaa    aaaa    aaaa    aaaa
b    .  b    .  .    c  b    c  b    c
b    .  b    .  .    c  b    c  b    c
 dddd    dddd    ....    dddd    dddd
.    f  e    f  .    f  e    f  .    f
.    f  e    f  .    f  e    f  .    f
 gggg    gggg    ....    gggg    gggg
 ```

In [41]:
def created_digit_to_code_map(codes):
    """Work out which scrambled segment pattern stands for which digit.

    Returns a dict mapping each digit to the frozenset of the characters in
    its pattern. 1, 4, 7 and 8 are recognised by their unique pattern
    lengths; the remaining digits follow from subset relations with
    patterns that are already known.
    """
    by_length = defaultdict(list)
    for code in codes:
        by_length[len(code)].append(frozenset(code))

    # digits whose pattern length is unique
    digit_to_code = {
        1: by_length[2][0],
        4: by_length[4][0],
        7: by_length[3][0],
        8: by_length[7][0],
    }
    one = digit_to_code[1]
    four = digit_to_code[4]

    # six segments: 0, 6 or 9
    for segments in by_length[6]:
        if not one <= segments:
            digit = 6  # the only one of the three that lacks part of 1
        elif four <= segments:
            digit = 9  # the only one of the three that covers 4
        else:
            digit = 0
        digit_to_code[digit] = segments

    # five segments: 2, 3 or 5
    for segments in by_length[5]:
        if one <= segments:
            digit = 3  # the only one of the three that covers 1
        elif segments <= digit_to_code[9]:
            digit = 5  # fits inside 9
        else:
            digit = 2
        digit_to_code[digit] = segments
    return digit_to_code


def build_mapping(mappings: [[str], [str]]):
    """Decode the output digits of every entry and return the sum of the values.

    For each ``(inputs, outputs)`` pair the digit patterns are worked out
    from the inputs, and the outputs are then read as one decimal number,
    most significant digit first.
    """
    total = 0
    for inputs, outputs in mappings:
        digit_to_code = created_digit_to_code_map(inputs)
        code_to_digit = {code: digit for digit, code in digit_to_code.items()}

        value = 0
        for out in outputs:
            # frozenset lookup: the output's characters may be in any order
            value = value * 10 + code_to_digit[frozenset(out)]
        total += value
    return total


build_mapping(data8)

1020159

In [42]:
assert _ == 1020159, 'Day 8.2'

## [Day 9](https://adventofcode.com/2021/day/9)

In [43]:
def find_low_points(grid: [[int]]) -> [(int, int)]:
    """Return the ``(x, y)`` of every cell that is lower than all its orthogonal neighbours.

    Neighbours that fall outside the grid are ignored; a neighbour of equal
    height disqualifies the cell.
    """
    lows = []

    for y, row in enumerate(grid):
        for x, cell in enumerate(row):
            is_lowest = all(
                grid[ny][nx] > cell
                for nx, ny in neighbours4(x, y)
                if 0 <= nx < len(row) and 0 <= ny < len(grid)
            )
            if is_lowest:
                lows.append((x, y))
    return lows


def calclualate_risk(grid: [[int]]) -> int:
    """Return the sum of ``height + 1`` over every low point of the grid."""
    risk = 0
    for x, y in find_low_points(grid):
        risk += grid[y][x] + 1
    return risk


def parse_data(line):
    """Turn one row of single-digit heights into a tuple of ints."""
    heights = line.strip()
    return mapt(int, heights)


data9 = Input(9, parse_data)
calclualate_risk(data9)

603

In [44]:
assert _ == 603, 'Day 9.1'

In [45]:
def find_basins(grid: [[int]]) -> [[(int, int)]]:
    """Flood-fill outwards from every low point and return one basin per low point.

    A basin is the list of ``(x, y)`` cells reachable from its low point
    through orthogonal steps without crossing a cell of height 9. The cells
    are listed in the order they were discovered, starting with the low
    point itself.
    """
    roots = find_low_points(grid)

    xmax = len(grid[0])
    ymax = len(grid)

    basins = []
    for root in roots:
        basin = [root]
        # Mirrors `basin` as a set so the "seen before" check stays constant
        # time instead of scanning the growing list.
        seen = {root}
        to_explore = [root]

        while to_explore:
            x, y = to_explore.pop()

            for nx, ny in neighbours4(x, y):
                if not (0 <= nx < xmax and 0 <= ny < ymax):
                    # out of bounds
                    continue
                if grid[ny][nx] == 9:
                    # hit a high point, and cannot continue
                    continue
                if (nx, ny) in seen:
                    # seen before, can ignore this time
                    continue

                seen.add((nx, ny))
                basin.append((nx, ny))
                to_explore.append((nx, ny))

        basins.append(basin)
    return basins


def avoid_largest_basins(grid: [[int]]) -> int:
    """Return the product of the sizes of the three largest basins.

    Raises IndexError when the grid has fewer than three basins.
    """
    sizes = sorted((len(basin) for basin in find_basins(grid)), reverse=True)
    return sizes[0] * sizes[1] * sizes[2]


avoid_largest_basins(data9)

786780

## [Day 10](https://adventofcode.com/2021/day/10)

Using a stack here makes life a lot simpler, it's a classic parsing technique.

In [46]:
def is_valid_line(line: str) -> (bool, str):
    """
    Check the bracket nesting of ``line``.

    Returns ``(True, missing)`` when no bracket is closed wrongly, where
    ``missing`` holds the closing brackets still needed to complete the
    line (empty when nothing is open). Returns ``(False, char)`` with the
    first offending closing bracket otherwise; a closing bracket that
    arrives while nothing is open counts as offending too.

    Raises KeyError for a character that is not one of the eight brackets.
    """
    brackets = {"[": "]", "(": ")", "<": ">", "{": "}"}
    openers = {closing: opening for opening, closing in brackets.items()}

    stack = []
    for char in line:
        if char in brackets:
            # new opening bracket
            stack.append(char)
            continue

        expected_opener = openers[char]
        if not stack or stack[-1] != expected_opener:
            # nothing left to close, or a different bracket is open
            return False, char
        stack.pop()
    return True, "".join(brackets[opener] for opener in reversed(stack))


def find_syntax_error_score(lines):
    """Add up the penalty of the offending character of every invalid line."""
    penalties = {
        ")": 3,
        "]": 57,
        "}": 1197,
        ">": 25137,
    }
    verdicts = (is_valid_line(line) for line in lines)
    return sum(penalties[char] for is_valid, char in verdicts if not is_valid)


data10 = Input(10)
find_syntax_error_score(data10)

464991

In [47]:
assert _ == 464991, 'Day 10.1'

In [48]:
def calculate_autocomplete_score(chars):
    """Score a run of closing brackets by reading it as a base-5 number.

    The digits are ``)`` = 1, ``]`` = 2, ``}`` = 3 and ``>`` = 4.
    """
    points = {")": 1, "]": 2, "}": 3, ">": 4}

    total = 0
    for bracket in chars:
        total = total * 5 + points[bracket]
    return total


def find_autocomplete_score(lines):
    """Return the middle autocomplete score over all lines that are not corrupted.

    Raises IndexError when no line qualifies.
    """
    verdicts = [is_valid_line(line) for line in lines]
    scores = [
        calculate_autocomplete_score(missing_chars)
        for is_valid, missing_chars in verdicts
        if is_valid
    ]
    scores.sort()
    return scores[len(scores) // 2]


find_autocomplete_score(data10)

3662008566

In [49]:
assert _ == 3662008566, 'Day 10.2'

## [Day 11](https://adventofcode.com/2021/day/11)

Today's was quite fun! The only real difficulty was reading the question.. using `>= 9` held me up for much longer than it should have!

In [50]:
def advance_octopus_grid(octopus_grid: [[int]]) -> ([[int]], int):
    """
    Advance the state of all octopuses by one step.

    Works on a rectangular grid of any size, given as rows of energy
    levels. The input is left untouched.

    Returns a new grid of octopuses and the number that flashed during
    this step.
    """
    grid = [list(row) for row in octopus_grid]
    height = len(grid)
    width = len(grid[0]) if grid else 0

    # first advance all values
    remaining_flashes = set()
    for y in range(height):
        for x in range(width):
            grid[y][x] += 1
            if grid[y][x] > 9:
                remaining_flashes.add((x, y))

    did_flash = set()
    while remaining_flashes:
        x, y = remaining_flashes.pop()

        if (x, y) in did_flash:
            # an octopus flashes at most once per step
            continue
        did_flash.add((x, y))

        for nx, ny in neighbours8(x, y):
            if not (0 <= nx < width and 0 <= ny < height):
                # out of bounds, ignore
                continue

            grid[ny][nx] += 1
            if grid[ny][nx] > 9:
                remaining_flashes.add((nx, ny))

    # For all that flashed, set their energy to 0
    for x, y in did_flash:
        grid[y][x] = 0

    return grid, len(did_flash)


def count_flashes(octopus_grid: [[int]], n: int) -> int:
    """Advance the grid ``n`` steps and return how many flashes happened in total."""
    flashes_per_step = []
    for _ in range(n):
        octopus_grid, flashed = advance_octopus_grid(octopus_grid)
        flashes_per_step.append(flashed)
    return sum(flashes_per_step)


# Puzzle input pasted inline: ten rows of ten single-digit energy levels.
data11 = """1443668646
7686735716
4261576231
3361258654
4852532611
5587113732
1224426757
5155565133
6488377862
8267833811""".split(
    "\n"
)

# Convert every row string into a list of ints.
data11 = [list(mapt(int, l)) for l in data11]

count_flashes(data11, 100)

1743

In [51]:
assert _ == 1743

In [52]:
def find_first_simultaneous_flash(octopus_grid: [[int]]) -> int:
    """Return the first step (counting from 1) in which every octopus flashes.

    The number of octopuses is taken from the grid itself rather than being
    assumed. The search only ends once such a step occurs.
    """
    octopus_count = sum(len(row) for row in octopus_grid)

    for step in count(1):
        octopus_grid, flashes = advance_octopus_grid(octopus_grid)
        if flashes == octopus_count:
            return step


find_first_simultaneous_flash(data11)

364

In [53]:
assert _ == 364

## [Day 12](https://adventofcode.com/2021/day/12)

First graph question! Below uses an adjacency list to keep track of connected nodes, and performs a depth-first-search on the tree, recursively exploring a single path until it terminates and tracking branches as we go. There is undoubtedly an efficient technique for reusing explored paths, which I may come and rewrite at some point.

In [54]:
def find_all_paths(tree, multi_visit=False, node="start", current_path=None):
    """Yield every path from ``node`` to "end" through the cave graph.

    Args:
        tree: adjacency mapping from a cave name to its connected caves.
        multi_visit: when false no lower-case cave may be entered twice;
            when true a single lower-case cave other than "start" and "end"
            may be entered twice per path.
        node: the cave to continue from.
        current_path: the caves visited so far, ``node`` being the last.

    Yields:
        Lists of cave names. A path that used its one allowed revisit
        carries a leading "+" marker element.
    """
    if current_path is None:
        current_path = ["start"]

    if node == "end":
        yield current_path
        # "end" can never be entered twice, so nothing reachable from here
        # could produce another path; stop instead of searching on.
        return

    revisit_used = current_path[0] == "+"

    for child_node in tree[node]:
        if child_node.islower() and child_node in current_path:
            if not multi_visit:
                # cannot revisit a small cave
                continue
            if revisit_used:
                # already revisited a small cave in this path
                continue
            if child_node in ("start", "end"):
                # do not allow visiting end roots more than once
                continue
            # mark this path to say we've revisited a small cave already
            path = ["+", *current_path, child_node]
        else:
            path = [*current_path, child_node]

        yield from find_all_paths(tree, multi_visit, child_node, path)


def count_paths(tree, multi_visit=False):
    """Return how many paths ``find_all_paths`` produces for the given tree."""
    return sum(1 for _ in find_all_paths(tree, multi_visit))


def parse_input(lines):
    """Build an adjacency list (a defaultdict of lists) from ``"a-b"`` lines."""
    adjacency = defaultdict(list)

    edges = (line.strip().split("-") for line in lines)
    for first, second in edges:
        # every connection can be travelled in both directions
        adjacency[first].append(second)
        adjacency[second].append(first)

    return adjacency

In [55]:
data12 = Input(12, parse_input, whole_file=True)
count_paths(data12, False)

3887

In [56]:
assert _ == 3887, 'Day 12.1'

In [57]:
%%time
count_paths(data12, True)

CPU times: user 1.78 s, sys: 4.05 ms, total: 1.78 s
Wall time: 1.78 s


104834

In [58]:
assert _ == 104834, 'Day 12.2'

## [Day 13: Transparent Origami](https://adventofcode.com/2021/day/13)

Today's problem wasn't too tricky, the only thing to recognise here was how to perform the folding of the grid. The question tells us that we're always folding UP or LEFT, so we only need to move points that are higher in x or y respectively.

In [101]:
def perform_fold(grid, axis, fold_point):
    """Fold the points of ``grid`` along ``axis`` at ``fold_point``.

    With ``axis == "y"`` the y coordinates are folded, otherwise the x
    coordinates. Points beyond the fold are mirrored across it and points
    that end up on the same spot merge. Returns the folded points as a list.
    """

    def mirror(value):
        # only the far side of the fold moves (folding is always LEFT or UP)
        return 2 * fold_point - value if value > fold_point else value

    if axis == "y":
        folded = {(x, mirror(y)) for x, y in grid}
    else:
        folded = {(mirror(x), y) for x, y in grid}
    return list(folded)


def fold_grid(grid, folds):
    """Apply each ``(axis, fold_point)`` fold in order and return the resulting grid."""
    for axis, fold_point in folds:
        grid = perform_fold(grid, axis, fold_point)
    return grid


def print_grid(grid):
    """Print the points of ``grid`` as rows of "#" (point) and "." (empty).

    The picture spans from (0, 0) to the largest x and y in the grid.
    Nothing is printed for an empty grid.
    """
    # A set keeps the per-cell lookup below constant time.
    points = {(x, y) for x, y in grid}
    if not points:
        return

    max_x = max(x for x, _ in points)
    max_y = max(y for _, y in points)

    for y in range(max_y + 1):
        print("".join("#" if (x, y) in points else "." for x in range(max_x + 1)))


def count_points_after_folds(grid, folds, display_output=False):
    """Fold the grid and return how many points remain, optionally printing the result."""
    folded = fold_grid(grid, folds)
    if display_output:
        print_grid(folded)
    return len(folded)


def parse_input(lines):
    """Parse the dot coordinates and the fold instructions.

    Returns ``(grid, folds)``: a list of ``(x, y)`` points and a list of
    ``(axis, position)`` folds, both in input order. Blank lines are skipped.
    """
    dots = []
    instructions = []

    for raw in lines:
        text = raw.strip()
        if not text:
            continue

        if "fold" not in text:
            x, y = text.split(",")
            dots.append((int(x), int(y)))
            continue

        # the axis is the character right before the "="
        prefix, position = text.split('=')
        instructions.append((prefix[-1], int(position)))
    return dots, instructions


data13 = Input(13, parse_input, whole_file=True)
count_points_after_folds(data13[0], data13[1][:1])

602

In [99]:
assert _ == 602, "Day 13.1"

In [100]:
count_points_after_folds(data13[0], data13[1], True)

 ##   ##  ####   ## #  # ####  ##  #  #
#  # #  # #       # #  #    # #  # # # 
#    #  # ###     # ####   #  #    ##  
#    #### #       # #  #  #   #    # # 
#  # #  # #    #  # #  # #    #  # # # 
 ##  #  # #     ##  #  # ####  ##  #  #


92

In [93]:
assert _ == 92, "Day 13.2"

## [Day 14](https://adventofcode.com/2021/day/14)

Another problem that cannot be brute forced, as the memory complexity gets out of hand quite quickly. I started with a simple recursive solution, not knowing that the state of the overall polymer wouldn't be required for the second part... 

In [174]:
def grow_poylmer(polymer, rules):
    """Return ``polymer`` after one round of pair insertion.

    For every adjacent pair that has a rule, the rule's character is
    inserted between the two elements.
    """
    pieces = [polymer[0]]

    for left, right in zip(polymer, polymer[1:]):
        pair = left + right
        if pair in rules:
            pieces.append(rules[pair])
        pieces.append(right)
    return "".join(pieces)


def polymerization(polymer, rules, n=10):
    """Grow the polymer ``n`` times and compare the element frequencies.

    Args:
        polymer: the polymer template.
        rules: mapping from an element pair to the element to insert.
        n: number of growth steps (10 by default).

    Returns:
        The count of the most common element minus the count of the least
        common element in the grown polymer.
    """
    for _ in range(n):
        polymer = grow_poylmer(polymer, rules)

    # most_common() sorts by descending count: first is max, last is min
    counts = Counter(polymer).most_common()
    return counts[0][1] - counts[-1][1]


def parse_input(lines):
    """Parse the polymer template and the ``"AB -> C"`` insertion rules.

    The first non-blank line is the template; every later non-blank line is
    a rule. Returns ``(template, rules)``, the template being ``None`` when
    there are no lines at all.
    """
    template = None
    insertions = {}

    for raw in lines:
        text = raw.strip()

        if not template:
            template = text
            continue

        if text:
            pair, element = text.split(" -> ")
            insertions[pair] = element
    return template, insertions


# Parse the puzzle input. The file handle is passed whole so that the
# template line and the rules can be read in a single pass.
data14 = Input(14, parse_input, whole_file=True)
polymerization(*data14)

2891

In [120]:
assert _ == 2891, 'Day 14.1'

Similar to the lantern fish from day 6, the brute force solution simply will not work for the second part. Again the key observation is that the order of what is growing is not required for the result, so we only need to track the pair and character frequencies.

In [176]:
def polymerization_count(polymer, rules, n=40):
    """Return most-common minus least-common element count after ``n`` insertion steps.

    Only the frequency of every element and of every adjacent pair is
    tracked, so the polymer itself is never built.
    """
    pair_totals = Counter(a + b for a, b in zip(polymer, polymer[1:]))
    element_totals = Counter(polymer)

    for _ in range(n):
        next_totals = Counter()
        for pair, quantity in pair_totals.items():
            if pair not in rules:
                # no rule for this pair, so it survives untouched
                next_totals[pair] += quantity
                continue

            inserted = rules[pair]
            # every occurrence splits into a left and a right pair
            next_totals[pair[0] + inserted] += quantity
            next_totals[inserted + pair[1]] += quantity
            element_totals[inserted] += quantity
        pair_totals = next_totals

    return max(element_totals.values()) - min(element_totals.values())


polymerization_count(*data14)

4607749009683

In [170]:
assert _ == 4607749009683, 'Day 14.2'