# Setup

## Imports

In [1]:
# builtins
from __future__ import annotations
from collections import defaultdict, deque
from dataclasses import dataclass
from enum import Enum
import functools
import itertools
import json
import math
import operator

# own stuff
from utils import Stopwatch, chunksof

# external stuff
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R
import z3

## Timing

In [2]:
# Global timer for the whole notebook; every part below calls `stopwatch.add_split()`
# after printing its answer (Stopwatch comes from the local `utils` module --
# presumably it records per-part lap times; confirm there).
stopwatch = Stopwatch()
stopwatch.start()

# Day 1

## Process input

In [None]:
# Groups of integers separated by blank lines; `cals` holds each group's total.
with open('inputs/01') as f:
    chunks = f.read().rstrip().split('\n\n')
cals = [sum(map(int, chunk.splitlines())) for chunk in chunks]

## Part 1

In [None]:
# Part 1: the largest group total.
print(max(cals))
stopwatch.add_split()

## Part 2

In [None]:
# Part 2: sum of the three largest group totals.
print(sum(sorted(cals)[-3:]))
stopwatch.add_split()

# Day 2

## Process input

In [None]:
# Letters A-C (first column) and X-Z (second column) both map to indices 0-2.
mapping = dict(zip('ABCXYZ', (0, 1, 2, 0, 1, 2)))

with open('inputs/02') as f:
    rounds = [(mapping[a], mapping[b]) for a, b in map(str.split, map(str.strip, f))]

## Part 1

In [None]:
def succ(a):
    """Return the shape index that beats shape ``a`` (cyclic successor modulo 3)."""
    return (a + 1) % 3


def pred(a):
    """Return the shape index that loses to shape ``a`` (cyclic predecessor modulo 3)."""
    return (a - 1) % 3
    
def score_round(a, b):
    """Part 1 score of one round: our shape bonus ``b + 1`` plus 0/3/6 for loss/draw/win.

    ``a`` is the opponent's shape index and ``b`` ours; ``b`` wins when it is the
    cyclic successor of ``a``.
    """
    if a == b:
        outcome = 3
    elif b == (a + 1) % 3:
        outcome = 6
    else:
        outcome = 0
    return b + 1 + outcome

print(sum(score_round(a, b) for a, b in rounds))
stopwatch.add_split()

## Part 2

In [None]:
def score_round(a, b):
    """Part 2 score of one round, where ``b`` is the required outcome (0 lose, 1 draw, 2 win).

    Our shape follows from the opponent's shape ``a``; the score is that shape's
    index + 1 plus 0/3/6 for the outcome.
    """
    if b == 1:
        return a + 4
    if b == 0:
        return (a - 1) % 3 + 1
    return (a + 1) % 3 + 7

print(sum(score_round(a, b) for a, b in rounds))
stopwatch.add_split()

# Day 3

## Process input

In [None]:
def to_sack(contents):
    """Split a line into the item sets of its two equal halves (asserts even length)."""
    half, odd = divmod(len(contents), 2)
    assert odd == 0
    return set(contents[:half]), set(contents[half:])

with open('inputs/03') as f:
    sacks = list(map(to_sack, (line.rstrip() for line in f)))

## Part 1

In [None]:
def priority(item):
    """Priority of an item: 'a'-'z' -> 1-26; anything else is scored as 'A'-'Z' -> 27-52."""
    offset = ord('a') - 1 if 'a' <= item <= 'z' else ord('A') - 27
    return ord(item) - offset

# Assumes exactly one item is common to both halves (priority takes a single item).
print(sum(priority(*(left & right)) for left, right in sacks))
stopwatch.add_split()

## Part 2

In [None]:
# Groups are three consecutive sacks; assumes exactly one item is shared by all three.
merged = (left | right for left, right in sacks)
groups = chunksof(3, merged)
badges = (first & second & third for first, second, third in groups)
print(sum(priority(*badge) for badge in badges))
stopwatch.add_split()

# Day 4

## Process input

In [None]:
def parse_range(r):
    """Parse ``'start-end'`` into an ``(int, int)`` tuple."""
    start, end = map(int, r.split('-'))
    return start, end


def parse_row(r):
    """Parse ``'a-b,c-d'`` (trailing whitespace allowed) into two range tuples."""
    first, second = r.rstrip().split(',')
    return parse_range(first), parse_range(second)

with open('inputs/04') as f:
    pairs = list(map(parse_row, f))

## Part 1

In [None]:
def intersect(a, b):
    """Overlap of the closed integer ranges ``a`` and ``b``, or ``None`` when disjoint."""
    (a_lo, a_hi), (b_lo, b_hi) = a, b
    if a_lo > b_hi or a_hi < b_lo:
        return None
    return max(a_lo, b_lo), min(a_hi, b_hi)


def fully_contained(a, b):
    """True when one range contains the other, i.e. their overlap equals one of them."""
    overlap = intersect(a, b)
    return overlap in (a, b)

# Part 1: number of pairs where one range contains the other.
print(sum(fully_contained(a, b) for a, b in pairs))
stopwatch.add_split()

## Part 2

In [None]:
# Part 2: number of pairs whose ranges overlap at all (identity check against None, per PEP 8).
print(sum(1 for a, b in pairs if intersect(a, b) is not None))
stopwatch.add_split()

# Day 5

## Process input

In [None]:
def parse_container_line(s):
    """Return ``(stack_index, crate)`` for every non-blank crate slot of a drawing row.

    Crate characters sit at offsets 1, 5, 9, ... (one slot every 4 characters).
    """
    return [(i, s[j]) for i, j in enumerate(range(1, len(s), 4)) if s[j] != ' ']

def parse_command_line(s):
    """Parse ``'move N from A to B'`` into ``(N, A - 1, B - 1)`` (0-based stack indices)."""
    words = s.split()
    count, source, target = (int(words[k]) for k in (1, 3, 5))
    return count, source - 1, target - 1

with open('inputs/05') as f:
    container_string, command_string = f.read().rstrip().split('\n\n')

container_lines = container_string.split('\n')
command_lines = command_string.split('\n')

# The bottom drawing line carries the stack labels, one per stack.
num_stacks = len(container_lines[-1].split())
stacks = [[] for _ in range(num_stacks)]
# Walk the crate rows bottom-up so each list ends with its top crate.
for line in container_lines[-2::-1]:
    for stack, container in parse_container_line(line):
        stacks[stack].append(container)

commands = list(map(parse_command_line, command_lines))

## Part 1

In [None]:
def apply(command, stacks):
    """Run ``(num, src, dst)`` in place, moving crates one at a time (their order reverses)."""
    num, src, dst = command
    source, target = stacks[src], stacks[dst]
    for _ in range(num):
        target.append(source.pop())

# Work on copies so `stacks` stays untouched for part 2.
stacks_copy = [list(s) for s in stacks]
for c in commands:
    apply(c, stacks_copy)

print(''.join(top[-1] for top in stacks_copy))
stopwatch.add_split()

## Part 2

In [None]:
def apply(command, stacks):
    """Run ``(num, src, dst)`` in place, moving the top ``num`` crates as one block.

    Crate order is preserved. The split index is computed explicitly because a
    ``[-num:]`` slice with ``num == 0`` selects (and would move) the whole
    source stack; ``num == 0`` is now a no-op. ``src == dst`` leaves the stack
    unchanged.
    """
    num, src, dst = command
    split = len(stacks[src]) - num
    moved = stacks[src][split:]
    del stacks[src][split:]
    stacks[dst].extend(moved)

# Fresh copies again, so re-running either part starts from the parsed state.
stacks_copy = [list(s) for s in stacks]
for c in commands:
    apply(c, stacks_copy)

print(''.join(top[-1] for top in stacks_copy))
stopwatch.add_split()

# Day 6

## Process input

In [None]:
# The whole input is read as one string (trailing whitespace removed).
with open('inputs/06') as f:
    signal = f.read().rstrip()

## Part 1

In [None]:
def marker_index(num_chars, signal):
    """Return the end index of the first window of ``num_chars`` distinct characters.

    The value is the number of characters consumed up to and including that
    window. Raises AssertionError (explicitly, so it survives ``python -O``)
    when no such window exists.
    """
    # `+ 1` so the window ending at the last character is checked as well.
    for i in range(len(signal) - num_chars + 1):
        j = i + num_chars
        if len(set(signal[i:j])) == num_chars:
            return j

    raise AssertionError("A signal start was not found")

# Part 1: characters consumed through the first window of 4 distinct characters.
print(marker_index(4, signal))
stopwatch.add_split()

## Part 2

In [None]:
# Part 2: same, with a window of 14 distinct characters.
print(marker_index(14, signal))
stopwatch.add_split()

# Day 7

## Process input

In [None]:
@dataclass
class Directory:
    """A directory node: its name and its children keyed by name."""
    name: str
    contents: dict[str, Node]

    def __eq__(self, other):
        # Directories compare by name only. For other types, defer to Python
        # (NotImplemented) instead of raising AttributeError on `other.name`.
        if not isinstance(other, Directory):
            return NotImplemented
        return self.name == other.name

@dataclass
class File:
    """A file node: its name and size."""
    name: str
    size: int

    def __eq__(self, other):
        # Files compare by name only. For other types, defer to Python
        # (NotImplemented) instead of raising AttributeError on `other.name`.
        if not isinstance(other, File):
            return NotImplemented
        return self.name == other.name

# Any filesystem entry; used in the `Directory.contents` annotation.
Node = Directory | File

def parse_node(s):
    """Build a node from one ``ls`` output line.

    ``'dir NAME'`` gives an empty Directory and ``'SIZE NAME'`` gives a File.
    Lines that are not exactly two words return ``None``.
    """
    words = s.split()
    if len(words) != 2:
        return None
    first, name = words
    if first == 'dir':
        return Directory(name, {})
    return File(name, int(first))

with open('inputs/07') as f:
    # Splitting on '$' yields one block per command: the first line is the
    # command itself, the remaining lines are its output.
    blocks = [b.strip().split('\n') for b in f.read().split('$')[1:]]

# Replay the session to rebuild the directory tree under `root`.
root = Directory('root', {})
current = root
history = []  # ancestors of `current`, innermost last (popped by `cd ..`)
for cmd, *output in blocks:
    match cmd.split():
        case ['cd', '/']:
            current = root
            history = []
        case ['cd', '..']:
            current = history.pop()
        case ['cd', dst]:
            history.append(current)
            # Only works if an earlier `ls` listed `dst` (otherwise KeyError).
            current = current.contents[dst]
        case ['ls']:
            # Replaces (does not merge) the current directory's listing.
            nodes = (parse_node(s) for s in output)
            current.contents = {n.name: n for n in nodes}

## Part 1

In [None]:
def dirsizes(directory):
    """Total sizes of ``directory`` and every directory below it, in post-order.

    Each subdirectory's list (ending with its own total) is appended in turn,
    and the final element is the total size of ``directory`` itself.
    """
    sizes = []
    total = 0
    for node in directory.contents.values():
        if isinstance(node, File):
            total += node.size
        elif isinstance(node, Directory):
            nested = dirsizes(node)
            sizes.extend(nested)
            total += nested[-1]
    sizes.append(total)
    return sizes

ds = dirsizes(root)  # every directory's total size; the root's total is last
print(sum(filter(lambda sz: sz <= 100_000, ds)))
stopwatch.add_split()

## Part 2

In [None]:
# Free space is 70M minus the root's total; 30M must be free afterwards.
min_del_size = ds[-1] - (70_000_000 - 30_000_000)

print(min(filter(lambda sz: sz >= min_del_size, ds)))
stopwatch.add_split()

# Day 8

## Process input

In [None]:
with open('inputs/08') as f:
    trees = np.array([list(map(int, row.rstrip())) for row in f], dtype=int)

## Part 1

In [None]:
def visible(trees):
    """Boolean mask of trees visible from at least one edge of the 2-D grid.

    A tree is visible from a direction when every tree between it and that
    edge is strictly shorter. Trees on the border are always visible.
    """
    m, n = trees.shape
    vis = np.ones(trees.shape, dtype=bool)
    inner_rows = slice(1, m - 1)
    inner_cols = slice(1, n - 1)
    vis[inner_rows, inner_cols] = False

    # Along rows: everything left / right of column j.
    for j in range(1, n - 1):
        h = trees[inner_rows, j][:, None]
        from_left = (trees[inner_rows, :j] < h).all(axis=1)
        from_right = (trees[inner_rows, j + 1:] < h).all(axis=1)
        vis[inner_rows, j] |= from_left | from_right

    # Along columns: everything above / below row i.
    for i in range(1, m - 1):
        h = trees[i, inner_cols]
        from_top = (trees[:i, inner_cols] < h).all(axis=0)
        from_bottom = (trees[i + 1:, inner_cols] < h).all(axis=0)
        vis[i, inner_cols] |= from_top | from_bottom

    return vis
    

print(np.count_nonzero(visible(trees)))
stopwatch.add_split()

## Part 2

In [None]:
def num_vis(trees, h):
    """Viewing distance along axis 0 for each column, from viewers of height ``h``.

    Row 0 of ``trees`` is the tree nearest the viewer. The count includes the
    first tree at least as tall as the viewer, if there is one.
    """
    shorter_so_far = np.logical_and.accumulate(trees < h, axis=0)
    n = shorter_so_far.sum(axis=0)
    n[~shorter_so_far.all(axis=0)] += 1
    return n


def scenic_score(trees):
    """Product of the four viewing distances of each interior tree (0 on the border)."""
    m, n = trees.shape
    score = np.zeros_like(trees)
    inner_rows = slice(1, m - 1)
    inner_cols = slice(1, n - 1)
    score[inner_rows, inner_cols] = 1

    for j in range(1, n - 1):
        h = trees[inner_rows, j]
        leftward = trees[inner_rows, j - 1::-1].T
        rightward = trees[inner_rows, j + 1:].T
        score[inner_rows, j] *= num_vis(leftward, h) * num_vis(rightward, h)

    for i in range(1, m - 1):
        h = trees[i, inner_cols]
        upward = trees[i - 1::-1, inner_cols]
        downward = trees[i + 1:, inner_cols]
        score[i, inner_cols] *= num_vis(upward, h) * num_vis(downward, h)

    return score

print(scenic_score(trees).max())
stopwatch.add_split()

# Day 9

## Process input

In [None]:
def parse_motions(f):
    """Yield one unit step ``[dx, dy]`` per move described by lines like ``'R 4'``.

    'U', 'D' and 'L' are up/down/left; any other letter is treated as right.
    All steps from one line share the same list object.
    """
    for line in f:
        direction, count = line.rstrip().split()
        dx, dy = {'U': (0, 1), 'D': (0, -1), 'L': (-1, 0)}.get(direction, (1, 0))
        step = [dx, dy]
        for _ in range(int(count)):
            yield step
        

with open('inputs/09') as f:
    motions = np.array([*parse_motions(f)], dtype=int)

## Part 1

In [None]:
def dt(h, t):
    """Step the tail at ``t`` takes toward the head at ``h``.

    Zero while they touch (Chebyshev distance <= 1); otherwise one unit per
    axis toward the head.
    """
    gap = h - t
    if np.abs(gap).max() > 1:
        return np.sign(gap)
    return np.zeros(2, dtype=int)


def simulate(motions, num_knots):
    """Positions of every knot at every step of a rope whose head follows ``motions``.

    Returns a float array of shape ``(num_knots, len(motions) + 1, 2)``. Knot 0
    is the head, starting at the origin; each later knot follows the previous one.
    """
    num_steps = len(motions) + 1
    knots = np.zeros((num_knots, num_steps, 2))
    knots[0] = np.cumsum(np.insert(motions, 0, 0, axis=0), axis=0)

    for k in range(1, num_knots):
        leader = knots[k - 1]
        for step in range(1, num_steps):
            tail = knots[k, step - 1]
            knots[k, step] = tail + dt(leader[step], tail)

    return knots

knots = simulate(motions, 2)
print(len(set(map(tuple, knots[1]))))
stopwatch.add_split()

## Part 2

In [None]:
knots = simulate(motions, 10)
print(len(set(map(tuple, knots[-1]))))
stopwatch.add_split()

# Day 10

## Process input

In [None]:
def run(instructions):
    """Yield the X register value during each cycle (X starts at 1).

    ``noop`` lasts one cycle. ``addx V`` lasts two cycles and adds V afterwards.
    Unrecognised lines are skipped.
    """
    x = 1
    for instruction in instructions:
        words = instruction.split()
        if len(words) == 2 and words[0] == 'addx':
            yield x
            yield x
            x += int(words[1])
        elif words == ['noop']:
            yield x

with open('inputs/10') as f:
    x = np.fromiter(run(f), dtype=int)

## Part 1

In [None]:
cycle = np.arange(len(x)) + 1
idx = np.s_[19:221:40]  # cycles 20, 60, ..., 220 as 0-based positions
print((x[idx] * cycle[idx]).sum())
stopwatch.add_split()

## Part 2

In [None]:
# 6x40 image: pixel (i, j) is set to 0 when column j is within one of X at that cycle.
img = np.ones((6, 40))
for i in range(6):
    for j in range(40):
        spritepos = x[40*i + j]
        if abs(j - spritepos) <= 1:
            img[i, j] = 0

plt.imshow(img, cmap='gray')
plt.show()

stopwatch.add_split()

# Day 11

## Process input

In [None]:
class Monkey:
    def __init__(self, s):
        [idline, itemline, opline, testline, trueline, falseline] = [
            l.strip(' :') for l in s.splitlines()
        ]
        self.counter = 0
        
        match idline.split():
            case ['Monkey', num]:
                self.id = int(num)
            case _:
                raise ValueError(f'Failed to match {idline}')
                
        match itemline.split():
            case ['Starting', 'items:', *items]:
                self.initial_items = [int(it.rstrip(',')) for it in items]
                self.items = self.initial_items.copy()
            case _:
                raise ValueError(f'Failed to match {itemline}')
                
        match opline.split():
            case ['Operation:', 'new', '=', 'old', '+', 'old']:
                self.op = lambda x: x + x
            case ['Operation:', 'new', '=', 'old', '*', 'old']:
                self.op = lambda x: x * x
            case ['Operation:', 'new', '=', 'old', '+', y]:
                operand = int(y)
                self.op = lambda x: x + operand
            case ['Operation:', 'new', '=', 'old', '*', y]:
                operand = int(y)
                self.op = lambda x: x * operand
            case _:
                raise ValueError(f'Failed to match {opline}')
                
        match testline.split():
            case ['Test:', 'divisible', 'by', y]:
                divisor = int(y)
                self.div = divisor
                self.test = lambda x: x % divisor == 0
            case _:
                raise ValueError(f'Failed to match {testline}')
                
        match trueline.split():
            case ['If', 'true:', 'throw', 'to', 'monkey', y]:
                recipient_t = int(y)
                self.true = recipient_t
            case _:
                raise ValueError(f'Failed to match {trueline}')
                
        match falseline.split():
            case ['If', 'false:', 'throw', 'to', 'monkey', y]:
                recipient_f = int(y)
                self.false = recipient_f
            case _:
                raise ValueError(f'Failed to match {falseline}')
        
    def do_turn(self, monkies):
        for item in self.items:
            self.counter += 1
            item = self.op(item) // 3
            if self.test(item):
                monkies[self.true].items.append(item)
            else:
                monkies[self.false].items.append(item)
        self.items = []
    
    def do_turn_2(self, monkies, lcm):
        for item in self.items:
            self.counter += 1
            item = self.op(item) % lcm
            if self.test(item):
                monkies[self.true].items.append(item)
            else:
                monkies[self.false].items.append(item)
        self.items = []
    
    def reset(self):
        self.counter = 0
        self.items = self.initial_items.copy()

with open('inputs/11') as f:
    monkies = list(map(Monkey, f.read().rstrip().split('\n\n')))

## Part 1

In [None]:
# Reset first: part 2 below (and any earlier run of this cell) mutates `monkies`,
# so without this a re-run would continue from stale counters and items.
for monkey in monkies:
    monkey.reset()

for _ in range(20):
    for monkey in monkies:
        monkey.do_turn(monkies)

# Product of the two largest inspection counts.
print(np.prod(np.sort([monkey.counter for monkey in monkies])[-2:]))

stopwatch.add_split()

## Part 2

In [None]:
for m in monkies:
    m.reset()

# Divisibility by every monkey's divisor is preserved modulo their LCM.
lcm = np.lcm.reduce([m.div for m in monkies])

for _ in range(10000):
    for monkey in monkies:
        monkey.do_turn_2(monkies, lcm)

print(np.prod(sorted(monkey.counter for monkey in monkies)[-2:]))

stopwatch.add_split()

# Day 12

## Process input

In [None]:
elevations = []

# Heights are letter offsets from 'a'; 'S' counts as 'a' and 'E' as 'z'.
with open('inputs/12') as f:
    for i, line in enumerate(f):
        el = []
        for j, c in enumerate(line.rstrip()):
            if c == 'S':
                start = (i, j)
            elif c == 'E':
                end = (i, j)
            el.append(ord({'S': 'a', 'E': 'z'}.get(c, c)) - ord('a'))
        elevations.append(el)

elevations = np.array(elevations)

## Part 1

In [None]:
def distances(elevations, start):
    """Breadth-first step counts from ``start`` to every cell of the height grid.

    A move to an orthogonal neighbour is allowed when it climbs at most one
    unit (any descent is fine). Unreachable cells keep ``np.iinfo(int).max``.
    """
    rows, cols = elevations.shape
    steps = np.full(elevations.shape, np.iinfo(int).max)
    seen = np.zeros(elevations.shape, dtype=bool)
    queue = deque([(start, 0)])
    while queue:
        (i, j), dist = queue.popleft()
        if seen[i, j]:
            continue
        seen[i, j] = True
        steps[i, j] = dist
        for k, l in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)):
            if (0 <= k < rows and 0 <= l < cols
                    and elevations[k, l] <= elevations[i, j] + 1
                    and not seen[k, l]):
                queue.append(((k, l), dist + 1))

    return steps

# Part 1: fewest steps from `start` to `end`.
print(distances(elevations, start)[end])
stopwatch.add_split()

## Part 2

In [None]:
# Part 2: negating the heights reverses the climbing rule, so a BFS from `end`
# gives every cell's distance *to* `end`; take the best height-0 cell.
print(distances(-elevations, end)[elevations == 0].min())
stopwatch.add_split()

# Day 13

## Process input

In [None]:
# Packets are JSON-compatible nested lists of ints: parse them as data
# (json.loads) instead of executing the input file with eval().
with open('inputs/13') as f:
    packets = [json.loads(p) for p in f.read().split()]

## Part 1

In [None]:
def compare(a, b):
    match (a, b):
        case (int(x), int(y)):
            return x - y
        case (int(x), list(ys)):
            return compare([x], ys)
        case (list(xs), int(y)):
            return compare(xs, [y])
        case ([], []):
            return 0
        case ([], [y, *ys]):
            return -1
        case ([x, *xs], []):
            return 1
        case ([x, *xs], [y, *ys]):
            return cmp if (cmp := compare(x, y)) != 0 else compare(xs, ys)
        case _:
            raise ValueError("No match")

# Use a distinct name: `pairs` is day 4's list of range pairs, and rebinding it
# to a one-shot zip iterator breaks re-running day 4's cells afterwards.
packet_pairs = zip(packets[0::2], packets[1::2])
# Sum of the 1-based indices of pairs that are already in order.
print(sum(i for i, (left, right) in enumerate(packet_pairs, start=1) if compare(left, right) < 0))

stopwatch.add_split()

## Part 2

In [None]:
# Sort all packets plus the two divider packets; multiply the dividers' 1-based positions.
dividers = [[[2]], [[6]]]
sorted_packets = sorted([*packets, *dividers], key=functools.cmp_to_key(compare))
print(math.prod(1 + sorted_packets.index(d) for d in dividers))

stopwatch.add_split()

# Day 14

## Process input

In [None]:
class Tile(Enum):
    """Contents of an occupied cave cell; positions missing from a cave dict are empty."""
    ROCK = 1
    SAND = 2


initial_cave = {}

def connect(points):
    """Yield every integer point on the polyline through ``points``, endpoints included.

    ``points`` is a ``(k, 2)`` integer array of vertices. Each segment is walked
    in ``np.sign`` steps, so consecutive vertices must differ along one axis
    only (or exactly diagonally); otherwise the walk never reaches the next
    vertex. A single vertex yields just that point, and no vertices yield
    nothing. Previously a one-vertex path raised UnboundLocalError.
    """
    if len(points) == 0:
        return
    for a, b in zip(points[:-1], points[1:]):
        d = np.sign(b - a)
        k = 0
        while not np.all((p := a + k*d) == b):
            yield p[0], p[1]
            k += 1
    # Each segment stops just before its end vertex, so emit the last vertex here.
    last = points[-1]
    yield last[0], last[1]
            
with open('inputs/14') as f:
    for line in f:
        # Each vertex is "x,y": parse the integers directly rather than eval()-ing input text.
        points = np.array([[int(v) for v in p.split(',')] for p in line.rstrip().split(' -> ')])
        initial_cave |= {p: Tile.ROCK for p in connect(points)}

## Part 1

In [None]:
cave = initial_cave.copy()
lowest_rock = max(y for _, y in cave)

# Depth-first sand simulation: `visit[-1]` is the current grain position. A
# position with nothing open below settles as sand; reaching the lowest rock
# row means sand now falls forever. The loop names deliberately avoid the
# module-level `x` (day 10's register trace) and `neighbors` (day 18's helper),
# which a re-run of this cell would otherwise overwrite.
visit = [(500, 0)]
count = 0
while True:
    sand_x, sand_y = visit[-1]
    if sand_y == lowest_rock:
        break

    # Pushed down-right, down-left, down, so straight down is popped first.
    open_below = tuple(
        p for p in ((sand_x+1, sand_y+1), (sand_x-1, sand_y+1), (sand_x, sand_y+1))
        if p not in cave
    )
    if not open_below:
        cave[visit.pop()] = Tile.SAND
        count += 1
    else:
        visit.extend(open_below)

print(count)
stopwatch.add_split()

## Part 2

In [None]:
cave = initial_cave.copy()
floor_level = max(y for _, y in cave) + 2

# Fill every position reachable from the source via the three downward moves
# without reaching the floor; each filled position is one unit of sand. Loop
# names avoid rebinding day 10's module-level `x` on re-runs.
visit = [(500, 0)]
count = 0
while visit:
    sand_x, sand_y = visit.pop()
    cave[sand_x, sand_y] = Tile.SAND
    count += 1
    if sand_y < floor_level - 1:
        visit.extend(
            p for p in ((sand_x+1, sand_y+1), (sand_x-1, sand_y+1), (sand_x, sand_y+1))
            if p not in cave
        )

print(count)
stopwatch.add_split()

# Day 15

## Process input

In [None]:
def parse(line):
    """Parse a sensor line into ``((sensor_x, sensor_y), (beacon_x, beacon_y))``."""
    words = line.split()
    sx, sy, bx, by = (int(words[k].strip('xy=:,')) for k in (2, 3, 8, 9))
    return (sx, sy), (bx, by)

with open('inputs/15') as f:
    sensors, beacons = map(list, zip(*map(parse, f)))

## Part 1

In [None]:
def dist(p, q):
    match p, q:
        case (x, y), (z, w):
            return abs(x - z) + abs(y - w)
        case _:
            raise ValueError(f'Invalid points {p}, {q}')

def merge(unmerged):
    """Merge overlapping or touching half-open intervals.

    ``unmerged`` must be sorted by start. The first interval is kept as given;
    later ones are unpacked as ``(start, end)``. Returns a new list.
    """
    if not unmerged:
        return []

    first, *rest = unmerged
    merged = [first]
    for start, end in rest:
        lo, hi = merged[-1]
        if start <= hi:
            merged[-1] = (lo, max(hi, end))
        else:
            merged.append((start, end))
    return merged
            
# Each sensor's Manhattan distance to its reported beacon.
ranges = [dist(s, b) for s, b in zip(sensors, beacons)]
y = 2_000_000
# A sensor whose range reaches row y covers the half-open x-interval
# [sx - dx, sx + dx + 1), where dx is the range left after descending to y.
# (The walrus in the generator binds `dx` at module level.)
no_beacons = merge(sorted(
    (sx - dx, sx + dx + 1)
    for (sx, sy), r in zip(sensors, ranges)
    if (dx := r - abs(sy - y)) >= 0
))

# Beacons on row y lie inside their own sensor's interval but are beacon cells,
# so they are subtracted from the covered count.
num_beacons = sum(1 for _, by in set(beacons) if by == y)
print(sum(b - a for a, b in no_beacons) - num_beacons)

stopwatch.add_split()

## Part 2

In [None]:
N = 4_000_000

sol = z3.Solver()
px = z3.Int('px')
py = z3.Int('py')
tf = z3.Int('Tuning frequency')


def z3_dist(sx, sy):
    """Symbolic Manhattan distance from sensor ``(sx, sy)`` to the unknown point ``(px, py)``.

    Kept under its own name: rebinding ``dist`` here would shadow part 1's
    ``dist(p, q)`` and break re-running that cell.
    """
    return (
        z3.If(sx - px < 0, px - sx, sx - px) +
        z3.If(sy - py < 0, py - sy, sy - py)
    )


sol.add(px >= 0, py >= 0, px <= N, py <= N)
# The unknown point lies strictly outside every sensor's range.
sol.add([z3_dist(sx, sy) > r for (sx, sy), r in zip(sensors, ranges)])
sol.add(tf == 4_000_000 * px + py)

if sol.check() == z3.sat:
    print(sol.model()[tf])
else:
    print('unsat :(')

stopwatch.add_split()

# Day 16

## Process input

In [None]:
def parse(row):
    """Parse a valve line into ``(name, (flow_rate, tunnel_targets))``."""
    words = row.split()
    rate = int(words[4].strip('rate=;'))
    return words[1], (rate, [w.rstrip(',') for w in words[9:]])

with open('inputs/16') as f:
    rows = list(map(parse, f))

# Split the parsed rows into the tunnel graph and the per-valve flow rates.
graph = {valve: tunnels for valve, (_, tunnels) in rows}
flow_rate = {valve: rate for valve, (rate, _) in rows}

## Part 1

In [None]:
def prune_impossibles(states, t, flow_rate):
    """Keep only states whose optimistic final value can still reach the current best.

    ``states`` maps ``(valve, opened)`` to the value accumulated so far with
    ``t`` minutes left. The bound opens the remaining valves in descending flow
    order, one every two minutes, the first one flowing for ``t - 1`` minutes.
    """
    max_value = max(states.values())

    def value_ub(value, opened):
        unopened = sorted((f for u, f in flow_rate.items() if u not in opened), reverse=True)
        return value + sum(t_ * f for t_, f in zip(range(t-1, 1, -2), unopened))

    return {(u, o): v for (u, o), v in states.items() if value_ub(v, o) >= max_value}


def max_release(graph, flow_rate, debug=False):
    """Maximum total release achievable in 30 minutes by one agent starting at ``'AA'``.

    Minute-by-minute search. Each minute, states sharing a ``(valve, opened)``
    key collapse to their best value and hopeless ones are pruned. Each
    survivor then either opens its valve or walks a tunnel. ``history`` keeps
    already-expanded ``(valve, opened, value)`` triples from being revisited.

    Args:
        graph: valve -> list of neighbouring valves.
        flow_rate: valve -> flow rate.
        debug: print per-minute frontier sizes.
    """
    history = set()
    visit = [('AA', frozenset(), 0)]
    for t in range(30, 0, -1):
        # Nothing unseen left to expand: the previous minute's states are final.
        # Without this guard an empty frontier makes max() fail on an empty dict.
        if not visit:
            break

        if debug:
            print(f't = {t}')
            print(f'  {len(visit)} visit(s)')
        states = defaultdict(int)
        for u, opened, value in visit:
            states[u, opened] = max(states[u, opened], value)
            history.add((u, opened, value))

        if debug:
            print(f'  {len(states)} state(s)')

        states = prune_impossibles(states, t, flow_rate)

        if debug:
            print(f'  {len(states)} pruned state(s)')

        visit = []
        for (u, opened), value in states.items():
            if u not in opened and flow_rate[u] > 0:
                # Opening uses this minute; the valve then flows for t - 1 minutes.
                u_open = opened | {u}
                u_value = value + (t-1)*flow_rate[u]
                if (u, u_open, u_value) not in history:
                    visit.append((u, u_open, u_value))
            visit.extend((v, opened, value) for v in graph[u] if (v, opened, value) not in history)

    return max(states.values())

# Part 1: best release for a single agent over 30 minutes.
print(max_release(graph, flow_rate, debug=False))

stopwatch.add_split()

## Part 2

In [None]:
def prune_impossibles(states, t, flow_rate):
    """Keep only two-agent states whose optimistic value can still reach the current best.

    ``states`` maps ``((u, v), opened)`` to the value accumulated so far with
    ``t`` minutes left. The bound lets the two agents open the two largest
    remaining valves every two minutes, the first pair flowing for ``t - 1``
    minutes.
    """
    max_value = max(states.values())

    def value_ub(value, opened):
        unopened = sorted((f for u, f in flow_rate.items() if u not in opened), reverse=True)
        # Pad to an even count: zip(unopened[::2], unopened[1::2]) would otherwise
        # drop the last (smallest) remaining valve, making the upper bound too low
        # and letting states that can still win be pruned.
        if len(unopened) % 2:
            unopened.append(0)
        return value + sum(t_ * (f + g) for t_, f, g in zip(range(t-1, 1, -2), unopened[::2], unopened[1::2]))

    return {(uv, o): v for (uv, o), v in states.items() if value_ub(v, o) >= max_value}

def psort(a, b):
    """Return ``a`` and ``b`` as an ordered pair, so ``(u, v)`` and ``(v, u)`` share one key."""
    if a <= b:
        return a, b
    return b, a

def max_release(graph, flow_rate, debug=False):
    """Maximum total release in 26 minutes by two agents, both starting at valve 'AA'.

    Same minute-by-minute search as the single-agent version, but a state's
    position is an ordered pair of valves (normalised with ``psort``), and each
    step expands the four open/move combinations of the two agents. Returns the
    best value in the last non-empty state set.
    """
    history = set()
    visit = [(('AA', 'AA'), frozenset(), 0)]
    for t in range(26, 0, -1):
        # No unseen successors: the previous minute's states are final.
        if not visit:
            break

        if debug:
            print(f't = {t}')
            print(f'  {len(visit)} visit(s)')

        # Collapse duplicates of the same (positions, opened) to their best value.
        states = defaultdict(lambda: 0)
        for uv, opened, value in visit:
            states[uv, opened] = max(states[uv, opened], value)
        history |= set(visit)

        if debug:
            print(f'  {len(states)} state(s)')

        states = prune_impossibles(states, t, flow_rate)

        if debug:
            print(f'  {len(states)} pruned state(s)')

        visit = set()
        for ((u, v), opened), value in states.items():
            # Case 1: open both u and v
            # (`v != u` stops both agents from opening the same valve.)
            can_open_u = u not in opened and flow_rate[u] > 0
            can_open_v = v not in opened and flow_rate[v] > 0 and v != u
            if can_open_u and can_open_v:
                opened_ = opened | {u, v}
                # A valve opened now flows for the remaining t - 1 minutes.
                value_ = value + (t-1)*(flow_rate[u] + flow_rate[v])
                visit.add(((u, v), opened_, value_))

            # Case 2: open u, move from v
            if can_open_u:
                opened_ = opened | {u}
                value_ = value + (t-1)*flow_rate[u]
                visit |= {(psort(u, v_), opened_, value_) for v_ in graph[v]}

            # Case 3: move from u, open v
            if can_open_v:
                opened_ = opened | {v}
                value_ = value + (t-1)*flow_rate[v]
                visit |= {(psort(u_, v), opened_, value_) for u_ in graph[u]}

            # Case 4: move from both u and v
            visit |= {(psort(u_, v_), opened, value) for u_ in graph[u] for v_ in graph[v]}
        # Only keep successors that have never been expanded.
        visit = [v for v in visit if v not in history]


    return max(states.values())

# Part 2: best release for two agents over 26 minutes.
print(max_release(graph, flow_rate, debug=False))

stopwatch.add_split()

# Day 17

## Process input

In [None]:
# '<' pushes left (-1); every other character pushes right (+1).
with open('inputs/17') as f:
    jets = [{'<': -1}.get(c, 1) for c in f.read().rstrip()]

## Part 1

In [None]:
def rocks(n):
    """Yield ``n`` shape functions, cycling through the five rock shapes in order.

    Each shape function maps an offset ``(dx, dy)`` to the set of cells the
    rock occupies when its lower-left corner sits at that offset.
    """
    generators = [
        lambda dx, dy: {
            (x + dx, y + dy) for x, y in [ (0, 0), (1, 0), (2, 0), (3, 0) ]
        },
        lambda dx, dy: {
            (x + dx, y + dy) for x, y in [ (0, 1), (1, 0), (1, 1), (1, 2), (2, 1) ]
        },
        lambda dx, dy: {
            (x + dx, y + dy) for x, y in [ (0, 0), (1, 0), (2, 0), (2, 1), (2, 2) ]
        },
        lambda dx, dy: {
            (x + dx, y + dy) for x, y in [ (0, 0), (0, 1), (0, 2), (0, 3) ]
        },
        lambda dx, dy: {
            (x + dx, y + dy) for x, y in [ (0, 0), (0, 1), (1, 0), (1, 1) ]
        }
    ]
    for i in range(n):
        yield generators[i % 5]


def heights(n, jet_pattern=None):
    """Yield the tower height before any rock falls and after each of ``n`` rocks.

    Args:
        n: number of rocks to drop.
        jet_pattern: horizontal pushes (-1 left, +1 right), cycled forever.
            Defaults to the notebook-level ``jets``, so existing calls such as
            ``heights(2022)`` are unchanged. Passing it explicitly removes the
            hidden global dependency.

    The chamber is 7 cells wide and ``y == 0`` is the floor, so the highest
    occupied ``y`` is the tower height.
    """
    if jet_pattern is None:
        jet_pattern = jets
    chamber = set()
    h = 0
    yield h
    movements = itertools.cycle(jet_pattern)
    for rock_gen in rocks(n):
        # New rock: 2 cells from the left wall, 3 empty rows above the tower.
        rock = rock_gen(2, h + 4)

        for dx in movements:
            # horizontal push, undone if it hits a wall or settled rock
            rock, prev = { (x + dx, y) for x, y in rock }, rock
            if chamber & rock or any(x < 0 or x >= 7 for x, _ in rock):
                rock = prev

            # fall one row; on contact undo it and the rock comes to rest
            rock, prev = { (x, y - 1) for x, y in rock }, rock
            if chamber & rock or min(y for _, y in rock) == 0:
                rock = prev
                break

        chamber |= rock
        h = max(h, *(y for _, y in rock))
        yield h

# Heights never decrease, so the largest yielded value is the height after 2022 rocks.
print(max(heights(2022)))

stopwatch.add_split()

## Part 2

In [None]:
def find_period(dy, fingerprint_size=100):
    """Smallest shift at which the last ``fingerprint_size`` deltas appear again earlier.

    Raises ValueError when no shift fits inside ``dy``.
    """
    end = len(dy)
    begin = end - fingerprint_size
    tail = dy[-fingerprint_size:]
    period = 1
    while period <= begin:
        if np.all(tail == dy[begin - period:end - period]):
            return period
        period += 1
    raise ValueError('dy does not contain a period')


def warmup_length(dy, period, fingerprint_size=100):
    """First index whose ``fingerprint_size``-long window repeats ``period`` positions later."""
    wu = 0
    while not np.all(dy[wu:wu + fingerprint_size] == dy[wu + period:wu + period + fingerprint_size]):
        wu += 1
    return wu


def height_after(n, y):
    """Tower height after ``n`` rocks, where ``y[k]`` is the height after ``k`` rocks.

    For ``n`` inside ``y`` the value is read directly. Otherwise the per-rock
    height deltas are assumed to become periodic after a warm-up prefix, and the
    result is extrapolated from one period.
    """
    if n < len(y):
        return y[n]

    dy = np.diff(y)
    period = find_period(dy)
    warmup = warmup_length(dy, period)
    cycle = dy[warmup:warmup + period]
    full_cycles, remainder = divmod(n - warmup, period)
    return y[warmup] + full_cycles * cycle.sum() + cycle[:remainder].sum()

target = 1_000_000_000_000
# Simulate 3000 rocks, then extrapolate the periodic height deltas to `target` rocks.
print(height_after(target, np.fromiter(heights(3000), dtype=int)))

stopwatch.add_split()

# Day 18

## Process input

In [None]:
with open('inputs/18') as f:
    # Each line is "x,y,z": parse the integers directly instead of eval()-ing input text.
    droplet = {tuple(int(c) for c in xyz.split(',')) for xyz in f}

## Part 1

In [None]:
def neighbors(x, y, z):
    """Return the set of the six cells sharing a face with cell ``(x, y, z)``."""
    offsets = ((0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0))
    return {(x + dx, y + dy, z + dz) for dx, dy, dz in offsets}


def regions(droplet):
    """Split a set of cells into face-connected components (list of sets).

    Works on ``droplet.copy()``, so the argument is left unchanged.
    """
    unassigned = droplet.copy()
    components = []
    while unassigned:
        component = set()
        frontier = [unassigned.pop()]
        while frontier:
            cell = frontier.pop()
            component.add(cell)
            touching = neighbors(*cell) & unassigned
            unassigned -= touching
            frontier.extend(touching)
        components.append(component)
    return components


def area(region):
    """Exposed surface area of unit cubes: 6 per cube, minus 2 per face shared with an earlier cube."""
    total = 0
    counted = set()
    for cell in region:
        total += 6 - 2 * len(neighbors(*cell) & counted)
        counted.add(cell)
    return total

rs = regions(droplet)
# Surface area is additive over face-disjoint components.
print(sum(area(r) for r in rs))
# Record the timing split, as every other part does.
stopwatch.add_split()

## Part 2

In [None]:
def air_pockets(droplet):
    """Return the air regions enclosed by ``droplet`` as a generator of cell sets.

    Air is every non-droplet cell in the droplet's bounding box grown by one cell
    on each side. The component containing the box's minimum corner is the
    outside air and is excluded.
    """
    lows = [min(cell[k] for cell in droplet) - 1 for k in range(3)]
    highs = [max(cell[k] for cell in droplet) + 2 for k in range(3)]
    rx, ry, rz = (range(lo, hi) for lo, hi in zip(lows, highs))
    air = {cell for cell in itertools.product(rx, ry, rz) if cell not in droplet}
    anchor = (rx[0], ry[0], rz[0])
    return (r for r in regions(air) if anchor not in r)

# Filling the enclosed air pockets leaves only the exterior surface.
filled_droplet = droplet.union(*air_pockets(droplet))
print(area(filled_droplet))
# Record the timing split, as every other part does.
stopwatch.add_split()

# Day 19

## Process input

In [None]:
def parse(blueprint):
    """Parse a blueprint line into ``(id, costs)``.

    ``costs`` is a 4x4 int array whose row ``r`` is the cost of robot type ``r``
    in each of the four resources (columns ordered like the rows).
    """
    words = blueprint.split()
    id_ = int(words[1].rstrip(':'))
    costs = np.zeros((4, 4), dtype=int)
    costs[0, 0] = int(words[6])
    costs[1, 0] = int(words[12])
    costs[2, :2] = int(words[18]), int(words[21])
    costs[3, 0], costs[3, 2] = int(words[27]), int(words[30])
    return id_, costs

with open('inputs/19') as f:
    blueprints = list(map(parse, f))

## Part 1

In [None]:
def value(blueprint, n):
    """Best final amount of resource 3 (presumably geodes) reachable in ``n`` minutes.

    Depth-first search that jumps straight to the next robot build instead of
    simulating idle minutes one by one, pruned with an optimistic bound.
    ``blueprint`` is an ``(id, costs)`` pair where ``costs[i]`` is the
    4-resource cost of robot type ``i``; robot ``i`` produces resource ``i``.
    """
    id_, costs = blueprint
    # No point owning more robots of a kind than the largest single-build cost
    # in that resource; resource 3 is effectively uncapped.
    robot_ub = costs.max(axis=0)
    robot_ub[3] = 100
    new_robot = np.eye(4, dtype=int)  # row i adds one robot of type i
    best = 0
    # State: (minutes left, resource stock, robot counts); start with one robot 0.
    stack = [(n, np.array((0, 0, 0, 0)), np.array((1, 0, 0, 0)))]
    while stack:
        time, resources, robots = stack.pop()
        # Value if nothing more is built: stock plus production for the remaining time.
        best = max(best, resources[3] + time*robots[3])

        if time <= 1:
            continue

        # Optimistic bound: additionally build one type-3 robot every remaining minute.
        value_ub = resources[3] + time*robots[3] + time*(time-1) // 2
        if value_ub <= best:
            continue

        under_ub = robots < robot_ub
        # Robot 2 costs resource 1 and robot 3 costs resource 2 (see `costs`), so they
        # are only worth targeting once something produces those resources.
        production_towards = np.array([True, True, robots[1] > 0, robots[2] > 0])
        for i in np.nonzero(under_ub & production_towards)[0]:
            t = time
            re = resources.copy()
            # Wait, accumulating production, until robot i is affordable.
            while any(costs[i] > re):
                t -= 1
                re += robots
            # Building takes a minute; only worthwhile if production time remains after.
            if t > 1:
                stack.append((t - 1, re + robots - costs[i], robots + new_robot[i]))

    return best


# Part 1: sum of id * best value over all blueprints with 24 minutes.
print(sum(bp_id * value((bp_id, costs), 24) for bp_id, costs in blueprints))

stopwatch.add_split()

## Part 2

In [None]:
# Part 2: product of the best values of the first three blueprints with 32 minutes.
print(math.prod(value(bp, 32) for bp in blueprints[:3]))

stopwatch.add_split()

# Day 20

## Process input

In [None]:
with open('inputs/20') as f:
    numbers = list(map(int, f))

## Part 1

In [None]:
class Number:
    """One value in the circular doubly linked list used for mixing."""

    def __init__(self, value):
        self.value = value
        # Ring links; the caller wires them up after construction.
        self.pred = None
        self.succ = None

    def move(self, k):
        """Shift this node ``k`` places forward, swapping with its successor each time."""
        for _ in range(k):
            before = self.pred
            after = self.succ
            beyond = after.succ
            # before -> self -> after -> beyond  becomes  before -> after -> self -> beyond.
            # Assignment order matters when the ring is tiny and these names alias.
            before.succ = after
            after.succ = self
            self.succ = beyond
            beyond.pred = self
            self.pred = after
            after.pred = before

    def to_list(self, n):
        """Values of ``n`` consecutive nodes, starting with this one."""
        values = []
        node = self
        for _ in range(n):
            values.append(node.value)
            node = node.succ
        return values

def mix(numlist, N):
    """Run one mixing pass: in list order, each node moves forward ``value mod (N - 1)`` places.

    In a ring of ``N`` nodes, moving a node ``N - 1`` places returns it to the
    same relative order, hence the modulus.
    """
    ring_steps = N - 1
    for node in numlist:
        node.move(node.value % ring_steps)
    

numlist = [Number(n) for n in numbers]

# Link the nodes into a ring. The loop names deliberately differ from `pred`/`succ`:
# those are day 2's module-level helpers, and rebinding them here would break
# re-running day 2.
for prev_node, node, next_node in zip([numlist[-1], *numlist[:-1]], numlist, [*numlist[1:], numlist[0]]):
    node.pred = prev_node
    node.succ = next_node

N = len(numbers)
mix(numlist, N)

# Grove coordinates: the values 1000, 2000 and 3000 places after the 0.
l = numlist[0].to_list(N)
i0 = l.index(0)
print(sum(l[(i0 + k) % N] for k in (1000, 2000, 3000)))

stopwatch.add_split()

## Part 2

In [None]:
dk = 811589153

numlist = [Number(dk * n) for n in numbers]

# Link the nodes into a ring without rebinding day 2's module-level `pred`/`succ`.
for prev_node, node, next_node in zip([numlist[-1], *numlist[:-1]], numlist, [*numlist[1:], numlist[0]]):
    node.pred = prev_node
    node.succ = next_node

N = len(numbers)
for _ in range(10):
    mix(numlist, N)

l = numlist[0].to_list(N)
i0 = l.index(0)
print(sum(l[(i0 + k) % N] for k in (1000, 2000, 3000)))

stopwatch.add_split()

# Day 21

## Process input

In [None]:
def parse(row):
    """Parse ``'name: 5'`` into ``(name, 5)`` and ``'name: a + b'`` into ``(name, (op, 'a', 'b'))``.

    ``op`` is the matching ``operator`` function; ``/`` maps to ``truediv``.
    """
    name, expr = row.split(': ')
    ops = {'+': operator.add, '-': operator.sub, '*': operator.mul, '/': operator.truediv}
    tokens = expr.split()
    if len(tokens) == 1:
        (literal,) = tokens
        v = int(literal)
    elif len(tokens) == 3:
        a, op, b = tokens
        v = (ops[op], a, b)
    return name, v

with open('inputs/21') as f:
    monke = dict(map(parse, f))

## Part 1

In [None]:
def evaluate(monke):
    evaluated = monke.copy()
    def inner(k):
        match evaluated[k]:
            case int(v):
                pass
            case (f, a, b):
                v = f(inner(a), inner(b))
            case _:
                raise RuntimeError(f'Failed to match {evaluated[k]}')
        # print(f'{k} -> {v}')
        evaluated[k] = v
        return v
    inner('root')
    return evaluated

# Part 1: value of 'root'. Division is operator.truediv, so this may print as a float.
print(evaluate(monke)['root'])

stopwatch.add_split()

## Part 2

In [None]:
def solve(monke):
    fixed = monke.copy()
    _, a, b = fixed['root']
    fixed['root'] = (operator.eq, a, b)
    del fixed['humn']
    
    s = z3.Solver()
    consts = {'humn': z3.Real('humn'), 'root': z3.Bool('root')}
    consts |= {k: z3.Real(k) for k in fixed if k != 'root'}
    for k in fixed:
        match fixed[k]:
            case int(v):
                s.add(consts[k] == v)
            case (f, a, b):
                s.add(consts[k] == f(consts[a], consts[b]))
            case _:
                raise RuntimeError(f'Failed to match {evaluated[k]}')
    s.add(consts['root'] == True)
    
    print(s.check())
    print()
    m = s.model()
    print(f'humn = {m[consts["humn"]]}')
    # print()
    # print(*(f'{c} = {m[c]}' for c in m), sep='\n')

solve(monke)

stopwatch.add_split()

# Day 22

## Process input

In [3]:
tiles = {}
with open('inputs/22') as f:
    map_, path_s = f.read().rstrip().split('\n\n')
    map_ = map_.splitlines()

# Board tiles ('.' open, '#' wall) keyed by 1-based (column, row); spaces are
# off the board.
for y, row in enumerate(map_, start=1):
    for x, tile in enumerate(row, start=1):
        if tile != ' ':
            tiles[x, y] = tile

# Tokenise the path into step counts and turns. Turns are stored as +1 (R) and
# -1 (L) so they can be added to a facing index; the list is reversed so the
# next instruction can be popped off the end.
path = []
for is_count, run in itertools.groupby(path_s, key=lambda ch: '0' <= ch <= '9'):
    if is_count:
        path.append(int(''.join(run)))
    else:
        for turn in run:
            assert turn in ('L', 'R'), f'unexpected character in path: {turn!r}'
            path.append(1 if turn == 'R' else -1)
path.reverse()

# First/last occupied column of every row and row of every column, so wrapping
# off an edge is a lookup instead of a scan over every tile.
row_xs = defaultdict(list)
col_ys = defaultdict(list)
for x, y in tiles:
    row_xs[y].append(x)
    col_ys[x].append(y)
row_span = {y: (min(xs), max(xs)) for y, xs in row_xs.items()}
col_span = {x: (min(ys), max(ys)) for x, ys in col_ys.items()}

# graph[x, y][facing] is the open tile reached by one step from open tile
# (x, y) in that facing (0 right, 1 down, 2 left, 3 up), wrapping around the
# board, or None if that step runs into a wall.
graph = {}
for (x, y), tile in tiles.items():
    if tile == '#':
        continue
    neighbours = [None] * 4
    for facing, (dx, dy) in enumerate([(1, 0), (0, 1), (-1, 0), (0, -1)]):
        nxt = (x + dx, y + dy)
        if nxt not in tiles:
            if dx:
                # Off the row: re-enter at its opposite end.
                first, last = row_span[y]
                nxt = (last if dx < 0 else first, y)
            else:
                # Off the column: re-enter at its opposite end.
                first, last = col_span[x]
                nxt = (x, last if dy < 0 else first)
        if tiles[nxt] == '.':
            neighbours[facing] = nxt
    graph[x, y] = neighbours

## Part 1

In [4]:
def solve(graph, path):
    """Follow the path over the flat, wrapping board and return the password.

    graph: maps each open tile (column, row) to its four neighbours indexed
        by facing (0 right, 1 down, 2 left, 3 up), with None where the step
        is blocked.
    path: instructions in reverse order (next one last): step counts
        alternating with turns (+1 = R, -1 = L). The list is not modified.

    Returns 1000 * row + 4 * column + facing.
    """
    steps = list(path)  # work on a copy so the caller's list survives
    # Start on the leftmost open tile of the top row, facing right.
    tile = min(x for x, y in graph if y == 1), 1
    facing = 0
    while steps:
        for _ in range(steps.pop()):
            nxt = graph[tile][facing]
            if nxt is None:  # wall ahead: stop this instruction early
                break
            tile = nxt
        if steps:
            facing = (facing + steps.pop()) % 4
    column, row = tile
    return 1000 * row + 4 * column + facing

print(solve(graph, path.copy()))

stopwatch.add_split()

89224


## Part 2

In [32]:
# Fold the flat map onto a cube. Each face is an L x L block of the map whose
# top-left (column, row) corner, 0-based, is listed in `corners`.
L = 50
corners = [(50, 0), (100, 0), (50, 50), (0, 100), (50, 100), (0, 150)]
# Per face: in-face offsets (j = column offset, i = row offset) of its open
# tiles as homogeneous column vectors (j, i, 0, 1); walls are left out.
sides = [np.array([[j, i, 0, 1] for i in range(L) for j in range(L) if map_[r+i][c+j] == '.']).T for c, r in corners]
# Per-face affine maps (3x4) from (j, i, 0, 1) to a 3D point. The third input
# coordinate is always 0, so only columns 0, 1 and 3 affect the result; the
# third column only makes the matrix invertible once completed to 4x4 (the
# Part 2 solver inverts it to map back to the flat map). Resulting faces:
#   0: z = 0   (x = j,      y = i)
#   1: x = 50  (y = i,      z = j + 1)
#   2: y = 50  (x = j,      z = i + 1)
#   3: x = -1  (y = 49 - i, z = j + 1)
#   4: z = 51  (x = j,      y = 49 - i)
#   5: y = -1  (x = i,      z = j + 1)
# The six faces are disjoint, so no 3D point lies on two faces.
# NOTE(review): these matrices hard-code 50 instead of using L, and encode
# this particular input's net layout; they must match the solver's copy.
T = np.array([
    [[1, 0,  0,  0], [0,  1,  0,  0], [0, 0,  1,  0]],
    [[0, 0, -1, 50], [0,  1,  0,  0], [1, 0,  0,  1]],
    [[1, 0,  0,  0], [0,  0, -1, 50], [0, 1,  0,  1]],
    [[0, 0,  1, -1], [0, -1,  0, 49], [1, 0,  0,  1]],
    [[1, 0,  0,  0], [0, -1,  0, 49], [0, 0, -1, 51]],
    [[0, 1,  0,  0], [0,  0,  1, -1], [1, 0,  0,  1]]
])
sides = [A@x for A, x in zip(T, sides)]
# Set of 3D points of every open tile on the cube surface (this replaces the
# flat `tiles` dict built while processing the input).
tiles = {(x, y, z) for x, y, z in np.concatenate(sides, axis=1).T}

def solve(tiles, path):
    """Walk the path over the folded cube and return the flat-map password.

    tiles: set of (x, y, z) points of every open tile on the cube surface,
        using the same face transforms as the table below (walls absent).
    path: instructions in reverse order (next one last): step counts
        alternating with turns (+1 = R, -1 = L). The list is not modified.

    Returns 1000 * row + 4 * column + facing in flat-map coordinates.
    """
    steps = list(path)  # work on a copy so the caller's list survives

    # Start on the leftmost open tile of the top map row: the y == 0 edge of
    # face 0, the only face lying in the z == 0 plane.
    p = np.array([min(x for x, y, z in tiles if y == 0 and z == 0), 0, 0])
    d = np.array([1, 0, 0])   # heading: facing right on face 0
    n = np.array([0, 0, -1])  # outward normal of the current face

    while steps:
        for _ in range(steps.pop()):
            if tuple(p + d) in tiles:
                p += d
            elif tuple(p + d - n) in tiles:
                # Step over a cube edge onto the adjacent face: the new heading
                # points away from the old face and the old heading becomes
                # the new outward normal.
                p += d - n
                d, n = -n, d
            else:
                break  # wall ahead (walls are not in `tiles`)
        if not steps:
            break
        # Turn about the current normal: L is +90 degrees, R is -90 degrees.
        angle = np.pi / 2 if steps.pop() == -1 else -np.pi / 2
        d = np.around(R.from_rotvec(angle * n).apply(d)).astype(int)

    normals = [(0, 0, -1), (1, 0, 0), (0, 1, 0), (-1, 0, 0), (0, 0, 1), (0, -1, 0)]
    side = normals.index(tuple(n))
    # Face -> cube affine transforms and top-left map corners; these must match
    # the ones used to build `tiles`.
    to_cube = np.array([
        [[1, 0,  0,  0], [0,  1,  0,  0], [0, 0,  1,  0]],
        [[0, 0, -1, 50], [0,  1,  0,  0], [1, 0,  0,  1]],
        [[1, 0,  0,  0], [0,  0, -1, 50], [0, 1,  0,  1]],
        [[0, 0,  1, -1], [0, -1,  0, 49], [1, 0,  0,  1]],
        [[1, 0,  0,  0], [0, -1,  0, 49], [0, 0, -1, 51]],
        [[0, 1,  0,  0], [0,  0,  1, -1], [1, 0,  0,  1]]
    ])[side]
    corner = np.array([(50, 0), (100, 0), (50, 50), (0, 100), (50, 100), (0, 150)])[side]

    # Columns 0 and 1 of the transform are the 3D images of one step right
    # (+column) and one step down (+row) on this face, so the facing index
    # (0 right, 1 down, 2 left, 3 up) is read off them rather than from a
    # hand-written table.
    right, down = to_cube[:, 0], to_cube[:, 1]
    headings = [right, down, -right, -down]
    facing = next(k for k, h in enumerate(headings) if np.array_equal(h, d))

    # Invert the transform (completed to 4x4) to recover the in-face offsets;
    # round before casting so float noise cannot truncate e.g. 48.9999 to 48.
    affine = np.vstack([to_cube, [0, 0, 0, 1]])
    offset = np.linalg.solve(affine, np.append(p, 1))[:2]
    c, r = (np.rint(offset) + corner + 1).astype(int)
    return 1000 * r + 4 * c + facing

print(solve(tiles, path.copy()))

stopwatch.add_split()

136182


# Day 23

## Process input

## Part 1

## Part 2

# Day X

## Process input

## Part 1

## Part 2

# Performance analysis

In [None]:
t = np.array(stopwatch.stop(add_split=False), dtype=int)
fig, ax = plt.subplots()
ax.plot(np.arange(len(t)), t / 1e9)
ax.set(title='Stopwatch splits', xlabel='stars', ylabel='time (s)')
# `grid` takes a bool; a string like 'on' only works because it is truthy
# (so 'off' would turn the grid on as well).
ax.grid(True)
plt.show()