In [1]:
import numpy as np
from typing import Callable, TypeVar
from collections import Counter, deque, defaultdict
import itertools
from functools import cmp_to_key
import regex as re
from intervaltree import Interval, IntervalTree


# Notebook-wide numpy display settings: show up to 30 items at each edge of a
# dimension before eliding, use a very wide line so rows are not wrapped, and
# render floats with three significant digits.
np.set_printoptions(edgeitems=30, linewidth=100000, 
    formatter=dict(float=lambda x: "%.3g" % x))

# Result type produced by the per-line parser handed to `data`.
T = TypeVar('T')

def data(day: int, parser: Callable[[str], T] = str) -> list[T]:
  """Load the input file for `day` and parse it line by line.

  Reads ``./data/day{day}.txt``, strips leading/trailing whitespace from every
  line and passes the stripped text to `parser`.

  Args:
    day: Number used to build the input file name.
    parser: Callable applied to each stripped line; defaults to `str`.

  Returns:
    A list with one parsed value per line of the file, in file order.
  """
  path = f"./data/day{day}.txt"
  parsed = []
  with open(path) as handle:
    for raw_line in handle:
      parsed.append(parser(raw_line.strip()))
  return parsed

# Named line parsers, looked up by key and passed as the `parser` of `data`.
#   'int_list': split a line on whitespace and convert every field to int.
processors = dict(
  int_list=lambda x: list(map(int, x.split())),
)

def search(nodes, start, get_neighbors, end_condition=lambda _, __: False, dfs=True):
    """Traverse a graph from `start`, recording the distance at which each node is expanded.

    Args:
        nodes: Not used by the traversal itself.
        start: Node the traversal begins at (distance 0).
        get_neighbors: Called as ``get_neighbors(node, distance)``; every item
            it produces is queued at ``distance + 1``.
        end_condition: Called as ``end_condition(node, distance)`` for every
            dequeued entry (before the visited check); a truthy result stops
            the traversal at that node.
        dfs: Selects which end of the queue is consumed. NOTE: when true
            (the default) entries are taken from the *left* of the deque,
            i.e. first-in-first-out order; when false they are taken from the
            right, i.e. last-in-first-out order.

    Returns:
        ``(visited, node)`` where `visited` maps each expanded node to its
        distance and `node` is the one that satisfied `end_condition`, or
        ``None`` if the queue ran dry first.
    """
    frontier = deque()
    frontier.append((start, 0))
    seen = {}
    while len(frontier) > 0:
        if dfs:
            current, distance = frontier.popleft()
        else:
            current, distance = frontier.pop()
        if end_condition(current, distance):
            return seen, current
        if current not in seen:
            # Queue every neighbour first, then mark the node as expanded.
            frontier.extend(
                (neighbor, distance + 1)
                for neighbor in get_neighbors(current, distance)
            )
            seen[current] = distance
    return seen, None

In [7]:

def day1():
    """Compare the two integer columns of the day 1 input.

    Returns:
        ``(part1, part2)``: part1 is the sum of absolute differences between
        the two columns after sorting each one independently; part2 is the sum
        of every first-column value multiplied by how many times it occurs in
        the second column.
    """
    rows = data(1, processors['int_list'])
    left, right = zip(*rows)

    # Pair the k-th smallest of each column and accumulate the gaps.
    total_gap = 0
    for low_left, low_right in zip(sorted(left), sorted(right)):
        total_gap += abs(low_left - low_right)

    right_counts = Counter(right)
    similarity = sum(value * right_counts[value] for value in left)
    return total_gap, similarity

day1()

(1941353, 22539317)

In [80]:
def day2():
    """Count safe reports in the day 2 input, with and without dropping one level.

    A report is safe when it is strictly monotonic (ascending or descending)
    and every step between neighbouring levels is between 1 and 3.

    Returns:
        numpy array ``[strict, damped]``: reports that are safe as-is, and
        reports that are safe as-is or after removing a single level.
    """
    def check_safe(report):
        # Sorting makes every step non-negative, so the size bounds can be
        # checked once; the report must then equal that order or its mirror.
        ordered = sorted(report)
        steps = np.diff(ordered)
        if max(steps) > 3 or min(steps) < 1:
            return False
        return report == ordered or report == ordered[::-1]

    def check_safe_damp(report):
        # Returns the (strict, damped) increments contributed by one report.
        if check_safe(report):
            return 1, 1
        keep = len(report) - 1
        # combinations() of size len-1 enumerates every leave-one-out variant
        # while preserving the original level order.
        tolerable = any(
            check_safe(list(shorter))
            for shorter in itertools.combinations(report, keep)
        )
        return (0, 1) if tolerable else (0, 0)

    tally = np.array((0, 0))
    for report in data(2, processors['int_list']):
        tally += check_safe_damp(report)
    return tally

day2()

array([356, 413])

In [76]:
def day3():
    """Sum the products of the ``mul(a,b)`` instructions in the day 3 input.

    The input lines are concatenated and scanned once from left to right.
    ``don't()`` disables and ``do()`` re-enables the ``mul`` instructions that
    follow it; instructions are enabled at the start.

    Compared with matching ``don't()...do()`` spans, the single scan also
    handles a trailing ``don't()`` that is never followed by ``do()`` and a
    ``do()`` that immediately follows ``don't()``. Sums are kept as integers
    rather than being routed through complex floats.

    Returns:
        ``(total, enabled_total)``: the sum of all products, and the sum of
        the products of instructions that were enabled when encountered.
    """
    instructions = ''.join(data(3))
    # Groups 1-2: mul operands, group 3: do(), group 4: don't().
    token = re.compile(r"mul\((\d+),(\d+)\)|(do\(\))|(don't\(\))")

    total, enabled_total = 0, 0
    enabled = True
    for match in token.finditer(instructions):
        if match[3]:
            enabled = True
        elif match[4]:
            enabled = False
        else:
            product = int(match[1]) * int(match[2])
            total += product
            if enabled:
                enabled_total += product
    return total, enabled_total

day3()

(182780583, 90772405)

In [184]:
def day4():
    """Solve day 4: count 'XMAS' words and crossed diagonal 'MAS' pairs in the letter grid.

    Grid cells are addressed with complex numbers: the real part is the column
    index and the imaginary part is the row index.

    Returns:
        tuple: ``(part1, part2)``. part1 is the number of diagonal 'XMAS' hits
        plus the number of horizontal/vertical hits whose letters were
        re-checked against the grid. part2 is the number of centre cells that
        are shared by exactly two diagonal 'MAS' hits.
    """
    grid = np.array(data(4, lambda x: np.array(list(x))))
    ymax, xmax = grid.shape

    def find_target_occurences(target):
        """Return the set of ``(start, end)`` cell pairs found for `target`.

        For every grid cell a `search` is run with that cell as start; `end`
        is any cell that was expanded at distance ``len(target) - 1`` while
        holding the last letter of `target`. The steps in between are single
        moves in any of the 8 directions and are not required to stay on one
        line, so callers filter the pairs by displacement afterwards.
        """
        occurences = set()

        def get_neighbors(current, distance):
            """Yield the in-bounds neighbours of `current` if it holds ``target[distance]``.

            Side effect: when `current` holds the final letter of `target` it
            is added to the enclosing `total` set and nothing is yielded.
            """
            col, row = int(current.real), int(current.imag)
            target_letter = target[distance]
            # Wrong letter at this depth: dead end, expand nothing.
            if grid[row, col] != target_letter:
                return
            # Last letter matched: record the end cell and stop expanding.
            if distance == len(target)-1:
                total.add(current)
                return
            # 4 axis-aligned and 4 diagonal unit steps.
            for v in [1, -1, 1j, -1j, 1+1j, 1-1j, -1+1j, -1-1j]:
                new = current + v
                x, y = int(new.real), int(new.imag)
                if not (y >= 0 and x >= 0 and y < ymax and x < xmax):
                    continue
                yield new
        
        for j in range(ymax):
            for i in range(xmax):
                # `total` is rebound for every start cell; get_neighbors reads
                # it through the closure while `search` is running.
                total = set()
                coordinate = i+1j*j
                # Called with the default queue order of `search`; the return
                # value is ignored, only the side effect on `total` matters.
                search(grid, coordinate, get_neighbors)
                for end in total:
                    occurences.add((coordinate, end))
        return occurences

    def find_diags(hits, l):
        """Count, per centre cell, the hits whose displacement is exactly `l` on both axes.

        Returns a Counter keyed by the (complex) midpoint between start and end.
        """
        centers = Counter()
        for start, end in hits:
            distance = end-start
            if abs(distance.real) == l and abs(distance.imag) == l:
                center = start + distance/2
                centers[center] += 1
        return centers
    
    def find_straights(hits, target):
        """Yield the start of every hit that lies on one row or column and spells `target`.

        The displacement alone does not prove the intermediate letters are on
        the line, so the grid slice between start and end is compared with
        `target` (or its reverse when the hit runs towards lower indices).
        """
        rev = target[::-1]
        l = len(target)
        td = l-1
        for start, end in hits:
            i, j = int(start.real), int(start.imag)
            d = end-start
            if (
                # left-to-right
                (d.real == td and not d.imag and ''.join(grid[j, i:i+l]) == target)
                # right-to-left (slice reads the reversed word)
                or (d.real == -td and not d.imag and ''.join(grid[j, i-td:i+1]) == rev)
                # top-to-bottom
                or (not d.real and d.imag == td and ''.join(grid[j:j+l, i]) == target)
                # bottom-to-top (slice reads the reversed word)
                or (not d.real and d.imag == -td and ''.join(grid[j-td:j+1, i]) == rev)
            ):
                yield start

    target = 'XMAS'
    matches = find_target_occurences(target)
    part1 = sum(find_diags(matches, len(target)-1).values()) + len(list(find_straights(matches, target)))

    # Part 2 searches for 'MAS' (the target without its first letter) and keeps
    # the centres that two diagonal hits have in common.
    centers = find_diags(find_target_occurences(target[1:]), len(target)-2)
    part2 = sum([1 if centers[x] == 2 else 0 for x in centers])

    return (part1, part2)

day4()

(2599, 1948)

In [66]:
def day5():
    """Validate and repair the page orderings of the day 5 input.

    The input has two sections separated by a blank line: ``parent|child``
    rules (the parent has to appear before the child) and comma separated
    page lists.

    Returns:
        ``(part1, part2)``: the sum of the middle pages of the lists that
        already respect the rules, and the sum of the middle pages of the
        remaining lists after sorting them according to the rules.

    Raises:
        ValueError: if the input contains no blank separator line.
    """
    text = data(5)
    split = text.index('')
    # Skip stray blank lines in the second section instead of failing on int('').
    lists = [tuple(map(int, x.split(','))) for x in text[split+1:] if x]

    # parents[child] = pages that must be printed before `child`.
    parents = defaultdict(set)
    for rule in text[:split]:
        parent, child = map(int, rule.split('|'))
        parents[child].add(parent)

    def check_illegal(nums):
        """Return True if some page appears after a page it must precede."""
        illegal = set()
        for num in nums:
            if num in illegal:
                return True
            # Once `num` has been seen, none of its parents may follow it.
            illegal.update(parents.get(num, ()))
        return False

    def compare(a, b):
        """Three-way comparison: negative when `a` has to come before `b`."""
        if a in parents.get(b, ()):
            return -1
        if b in parents.get(a, ()):
            return 1
        # No rule relates the two pages: fall back to numeric order.
        return (a > b) - (a < b)

    part1, part2 = 0, 0
    for nums in lists:
        if not check_illegal(nums):
            part1 += nums[len(nums)//2]
        else:
            part2 += sorted(nums, key=cmp_to_key(compare))[len(nums)//2]

    return part1, part2

day5()

(6041, 4884)

In [20]:
def day6(debug=False):
    """Walk the guard across the day 6 grid and count the distinct cells visited.

    The guard starts on the ``'^'`` cell moving towards lower row indices,
    turns to the next heading whenever the cell ahead is not ``'.'`` and stops
    as soon as the next step would leave the grid.

    Args:
        debug: Accepted but not used.

    Returns:
        Number of distinct cells the guard stood on, including the start.
    """
    grid = np.array(data(6, list))
    ymax, xmax = grid.shape
    y, x = np.argwhere(grid == '^')[0]
    grid[y, x] = '.'

    # Headings as (dx, dy); advancing the index by one is the guard's turn.
    headings = [(1, 0), (0, 1), (-1, 0), (0, -1)]
    heading = len(headings) - 1  # start with (0, -1)

    seen = {(y, x)}
    while True:
        dx, dy = headings[heading]
        ny, nx = y + dy, x + dx
        inside = 0 <= nx < xmax and 0 <= ny < ymax
        if not inside:
            break
        if grid[ny, nx] != '.':
            heading = (heading + 1) % len(headings)
            continue
        y, x = ny, nx
        seen.add((y, x))

    return len(seen)

day6()

4374

In [37]:
def day6(debug=False):
    """Count the cells where one extra obstacle traps the day 6 guard in a loop.

    The unobstructed route is simulated first; every cell on it except the
    guard's start cell is then tried as an obstacle position, and the number
    of positions that make the guard loop is printed.

    NOTE: this definition reuses the name ``day6`` of the earlier cell and
    therefore replaces it once executed.

    Args:
        debug: Accepted but not used.
    """
    grid = np.array(data(6, list))
    ymax, xmax = grid.shape
    y, x = np.argwhere(grid == '^')[0]
    grid[y, x] = '.'
    # Headings as (dx, dy), listed in the order the guard turns through them.
    turns = [(1, 0), (0, 1), (-1, 0), (0, -1)]
    start = (int(y), int(x))

    def run_guard(y, x, obstacle=(-1, -1)):
        """Simulate the guard from (y, x) with an optional extra obstacle cell.

        Returns ``(True, states)`` when a (row, col, heading) state is about
        to repeat, i.e. the guard loops, and ``(False, visited)`` with the set
        of visited (row, col) cells when the guard leaves the grid.
        """
        directions, v = itertools.cycle(turns), turns[-1]
        # Begin one step behind the start so that the first move records the
        # start cell itself in `visited` and `states`.
        y, x = int(y)-v[1], int(x)-v[0]
        visited, states = set(), set()
        while True:
            ny, nx = y+v[1], x + v[0]
            if (ny, nx, v) in states:
                return True, states
            elif (nx < 0 or ny < 0 or nx >= xmax or ny >= ymax):
                return False, visited
            elif grid[ny, nx] != '.' or (ny, nx) == obstacle:
                v = next(directions)
                continue
            y, x = ny, nx
            visited.add((y, x))
            states.add((y, x, v))

    _, original = run_guard(y, x)

    # The start cell must not be used as an obstacle: because the simulation
    # begins one step behind it, blocking it would make the guard set off from
    # the wrong cell and could be miscounted as a loop.
    candidates = original - {start}
    print(sum(run_guard(y, x, cell)[0] for cell in candidates))

day6(True)

1705
