In [22]:
from itertools import combinations

# Worked example grid; both parts are asserted against it below.
test_input_str = """467..114..
...*......
..35..633.
......#...
617*......
.....+.58.
..592.....
......755.
...$.*....
.664.598.."""

# Read the real puzzle input. The context manager closes the handle promptly
# (instead of leaking it until garbage collection), and the explicit encoding
# avoids depending on the platform default.
with open("./puzzle_input/day3.txt", encoding="utf-8") as puzzle_file:
    puzzle_input_str = puzzle_file.read()


def neighbours(grid, r, c):
    """Yield the in-bounds cells among the eight neighbours of ``(r, c)``.

    Cells are produced row by row, from top-left to bottom-right, and the
    centre cell itself is never yielded. Each neighbouring row's column
    bound is taken from that row's own length. A ragged grid, such as one
    with an empty final line, therefore never yields an index that would
    fail on ``grid[nr][nc]``.

    :param grid: indexable sequence of rows (e.g. a list of strings).
    :param r: row index of the centre cell.
    :param c: column index of the centre cell.
    :returns: generator of ``(row, col)`` tuples.
    """
    for dr in (-1, 0, 1):
        nr = r + dr
        if not 0 <= nr < len(grid):
            continue
        row_width = len(grid[nr])
        for dc in (-1, 0, 1):
            if dr == 0 and dc == 0:
                continue  # the centre cell is not its own neighbour
            nc = c + dc
            if 0 <= nc < row_width:
                yield (nr, nc)


def find_numbers(r, row):
    """Yield every run of consecutive digits in ``row``.

    :param r: row index, recorded in each returned position.
    :param row: the row to scan (a string, or any iterable of characters).
    :returns: generator of ``(number, positions)`` pairs. ``number`` is the
        digit run parsed as an int. ``positions`` is a tuple of ``(r, c)``
        coordinates of its digits, left to right. A tuple can be iterated
        more than once.
    """
    digit_chars = []
    digit_positions = []

    for c, ch in enumerate(row):
        if ch.isdigit():
            digit_chars.append(ch)
            digit_positions.append((r, c))
        elif digit_chars:
            yield (int("".join(digit_chars)), tuple(digit_positions))
            digit_chars = []
            digit_positions = []

    # A number touching the right edge has no following non-digit to close
    # it, so emit it here; otherwise it would be silently dropped.
    if digit_chars:
        yield (int("".join(digit_chars)), tuple(digit_positions))


def parse_input(input_str):
    """Yield the non-empty rows of the schematic contained in ``input_str``.

    ``str.splitlines`` accepts both LF and CRLF line endings, so no stray
    carriage return is left at the end of a row; ``is_symbol`` would
    otherwise count it as a symbol. Blank lines are skipped. The usual one
    comes from a trailing newline at end of file, and an empty row would
    otherwise appear in the grid.
    """
    for line_str in input_str.splitlines():
        if line_str:
            yield line_str


def is_symbol(cell):
    """Return True when ``cell`` is neither the ``.`` filler nor a digit."""
    if cell == ".":
        return False
    return not cell.isdigit()


def symbols_adjacent(grid, positions):
    """Yield the coordinates of symbol cells next to any of ``positions``.

    A symbol cell touching several of the given positions is yielded once
    per position it touches, so the output may contain duplicates.

    :param grid: indexable sequence of row strings.
    :param positions: iterable of ``(row, col)`` cells to inspect around.
    """
    for row_idx, col_idx in positions:
        yield from (
            (nr, nc)
            for nr, nc in neighbours(grid, row_idx, col_idx)
            if is_symbol(grid[nr][nc])
        )


def part_one(input_str: str) -> int:
    """Return the sum of every number that touches at least one symbol.

    Adjacency includes diagonals, as produced by ``neighbours``.
    """
    grid = list(parse_input(input_str))
    # symbols_adjacent yields non-empty (row, col) tuples, which are always
    # truthy, so any() is True exactly when at least one symbol is adjacent.
    return sum(
        number
        for row_idx, row in enumerate(grid)
        for number, cells in find_numbers(row_idx, row)
        if any(symbols_adjacent(grid, cells))
    )


# Check the sample grid before running on the real puzzle input.
assert part_one(test_input_str) == 4361

print("part one:", part_one(puzzle_input_str))

part one: 556057


In [23]:
def part_two(input_str: str) -> int:
    """Return the sum of gear ratios in the schematic.

    A gear is a ``*`` cell adjacent (diagonals included) to exactly two
    numbers, and its ratio is the product of those two numbers. Other symbol
    kinds never count. A ``*`` touching one number, or three or more, is
    not a gear.
    """
    grid = list(parse_input(input_str))
    # Map each '*' coordinate to the numbers that touch it.
    numbers_by_star = {}

    for r, row in enumerate(grid):
        for number, positions in find_numbers(r, row):
            # A number can reach the same '*' through several of its digits;
            # a set records it only once per star.
            stars = {
                (sr, sc)
                for sr, sc in symbols_adjacent(grid, positions)
                if grid[sr][sc] == "*"
            }
            for star in stars:
                numbers_by_star.setdefault(star, []).append(number)

    return sum(
        nums[0] * nums[1]
        for nums in numbers_by_star.values()
        if len(nums) == 2
    )


# Check the sample grid before running on the real puzzle input.
assert part_two(test_input_str) == 467835

print("part two:", part_two(puzzle_input_str))

part two: 82824352
