In [1]:
import doctest
import io
import re
from typing import Iterable, List, Optional

In [2]:
# Path to the puzzle input, read by the Part 1 and Part 2 answer cells.
DATA = "input.txt"

# Part 1

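Idea: iterate through the lines, keeping a sliding window of the 3 most recent lines. Each time the next line is pushed into the window (and the oldest one popped), scan the middle line of the window, parsing out any numbers and checking the surrounding cells — including diagonals — for symbols. A number touching a symbol is a part number. Empty (all-`.`) rows pad the window before the first line and after the last, so those lines get scanned as well.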

In [3]:
def is_symbol(c: str) -> bool:
    """Returns True if ``c`` is a symbol.

    A symbol is any character other than an ASCII digit, ``.`` or a newline.
    Only the first character of ``c`` is examined; the empty string is not a
    symbol.
    """
    # Raw string so the backslash reaches the regex engine unchanged (a plain
    # "\." is an invalid string escape and warns on Python 3.12+). Inside a
    # character class "." is already literal, so it needs no escape.
    return re.match(r"[^0-9.\n]", c) is not None

def is_number(c: str) -> bool:
    """Tell whether ``c`` begins with an ASCII decimal digit (``0``-``9``).

    Only the first character is examined, so a string such as ``"7x"``
    counts as a number while the empty string does not.
    """
    return bool(re.match(r"[0-9]", c))

In [4]:
class Buffer:
    """A fixed-size sliding window over the most recent lines of input.

    ``lines`` holds the window, oldest row first; ``initialised`` records
    whether :meth:`reset` has been called to size the window.
    """

    def __init__(self):
        self.initialised = False
        self.lines = []

    def reset(self, line_len: int, n_line: int) -> None:
        """Fill the window with ``n_line`` rows of ``line_len`` dots."""
        blank_row = "." * line_len
        self.lines = [blank_row for _ in range(n_line)]
        self.initialised = True

    def append(self, line: str) -> None:
        """Push ``line`` as the newest row, dropping the oldest row (if any)."""
        self.lines = [*self.lines[1:], line]

In [5]:
def symbol_at_column(lines: List[str], idx: int) -> bool:
    """Returns True if there is a symbol at column ``idx`` in any of ``lines``.

    A column outside a line (negative, or at/after its end) holds no symbol.
    Every line is bounds-checked on its own, so lines of unequal length (for
    example an empty line) are handled, as is an empty ``lines``.
    """
    # Guard each line individually: checking only lines[0] would raise
    # IndexError for a shorter line and miss symbols in a longer one.
    return any(0 <= idx < len(line) and is_symbol(line[idx]) for line in lines)

In [6]:
def parse_part_numbers(above: str, line: str, below: str) -> Iterable[int]:
    """Returns all part numbers in ``line``.

    A number in ``line`` is a part number if any cell of ``above``, ``line``
    or ``below`` touching it (including diagonally) holds a symbol. Part
    numbers are yielded left to right.

    Raises:
        ValueError: if ``line`` contains a character that is neither a digit,
            ``.`` nor a symbol (for instance a newline).

    Example:

        >>> a = "467..114.."
        >>> l = "...*......"
        >>> b = "..35..633."
        >>> list(parse_part_numbers(a, l, b))
        []

        >>> a = "...*......"
        >>> l = "..35..633."
        >>> b = "......#..."
        >>> list(parse_part_numbers(a, l, b))
        [35, 633]

        >>> a = "617*......"
        >>> l = ".....+.58."
        >>> b = "..592....."
        >>> list(parse_part_numbers(a, l, b))
        []

        >>> a = ".....*"
        >>> l = ".....9"
        >>> b = "......"
        >>> list(parse_part_numbers(a, l, b))
        [9]
    """
    window = [above, line, below]
    in_num = False
    seen_symbol = False
    part_num = 0

    for i, c in enumerate(line):
        if is_number(c):
            if not in_num:
                # First digit: also look at the column just left of the number.
                seen_symbol = symbol_at_column(window, i - 1)
            seen_symbol = seen_symbol or symbol_at_column(window, i)
            in_num = True
            part_num = part_num * 10 + int(c)
        elif c == "." or is_symbol(c):
            # Any number in progress has ended; column i is its right neighbour.
            if in_num and (seen_symbol or symbol_at_column(window, i)):
                yield part_num
            in_num = False
            seen_symbol = False
            part_num = 0
        else:
            raise ValueError(f"unknown character '{c}'")

    # A number running to the end of the line has no column to its right, and
    # every column it covers (plus its left neighbour) was checked above.
    if in_num and seen_symbol:
        yield part_num

In [7]:
def part_numbers(lines: Iterable[str]) -> Iterable[int]:
    """Returns all part numbers in the lines of a schematic.

    The rows are streamed through a three-row window so each row is parsed
    with the rows directly above and below it as context. The window is
    padded with all-``.`` rows before the first row and after the last.

    Args:
        lines: the schematic rows, e.g. an open text file or a list of
            strings. Leading/trailing whitespace (including the newline) is
            stripped from each row. Blank rows are read as rows of ``.``;
            blank rows before the first non-blank row are ignored.

    Example:

        >>> lines = [
        ...     "467..114..",
        ...     "...*......",
        ...     "..35..633.",
        ...     "......#...",
        ...     "617*......",
        ...     ".....+.58.",
        ...     "..592.....",
        ...     "......755.",
        ...     "...$.*....",
        ...     ".664.598..",
        ... ]
        >>> list(part_numbers(lines))
        [467, 35, 633, 617, 592, 755, 664, 598]
    """
    b = Buffer()
    for line in lines:
        line = line.strip()
        if not b.initialised:
            if not line:
                # No row yet to size the window from.
                continue
            b.reset(len(line), 3)
        elif not line:
            # An empty row is shorter than its neighbours and the column
            # lookups would index past its end; read it as a row of dots.
            line = "." * len(b.lines[0])

        # pop old line and push new one
        b.append(line)

        # the middle row now has both neighbours: parse it
        yield from parse_part_numbers(b.lines[0], b.lines[1], b.lines[2])

    if b.initialised:
        # push one more empty row so the final row gets parsed too
        b.append("." * len(b.lines[0]))
        yield from parse_part_numbers(b.lines[0], b.lines[1], b.lines[2])

In [8]:
# Run the docstring examples of the functions defined so far.
doctest.testmod()

TestResults(failed=0, attempted=18)

In [9]:
with open(DATA, "r") as f:
    print(sum(part_numbers(f)))

512794


# Part 2

Idea: similar to Part 1, but now for each `*` found in the middle line we parse the numbers out of the surrounding cells. If exactly two numbers are present, the `*` is a gear and its gear ratio (the product of the two numbers) can be computed.

In [10]:
def number_at(line: str, idx: int) -> bool:
    """Tell whether ``line`` has a digit at position ``idx``.

    Out-of-range positions, including negative ones (which would otherwise
    wrap around to the end of the string), report False.
    """
    return 0 <= idx < len(line) and is_number(line[idx])

In [11]:
def parse_number(line: str, idx: int) -> int:
    """Returns the number at ``idx`` in line.

    ``idx`` may point at any digit of the number: the scan first walks left
    to the number's leading digit, then reads digits rightwards until the
    number ends.

    Example:

        >>> parse_number("...*.123", 6)
        123
    """
    # Walk left to the leading digit.
    start = idx
    while number_at(line, start - 1):
        start -= 1

    # Read digits rightwards, building up the value.
    value = int(line[start])
    pos = start + 1
    while number_at(line, pos):
        value = value * 10 + int(line[pos])
        pos += 1
    return value

In [12]:
def _row_neighbours(row: str, col: int) -> List[int]:
    """Returns the numbers in ``row`` touching columns ``col - 1`` to ``col + 1``.

    A digit in the middle column joins the left and right cells into a single
    number, so each number is reported once and at most two are found.
    """
    left, mid, right = (number_at(row, c) for c in (col - 1, col, col + 1))
    found = []
    if left:
        found.append(parse_number(row, col - 1))
    if mid and not left:
        found.append(parse_number(row, col))
    if right and not mid:
        found.append(parse_number(row, col + 1))
    return found


def parse_gear_part_numbers(above: str, line: str, below: str) -> Iterable[List[int]]:
    """Returns all part numbers for each gear in ``line``.

    For every ``*`` in ``line``, yields the list of numbers touching it
    (including diagonally): first those in ``line`` (left, then right), then
    those in ``above``, then those in ``below``, each row read left to right.
    Every ``*`` is reported whatever its number count; the caller decides
    which ones are real gears.

    Example:

        >>> a = "467..114.."
        >>> l = "...*......"
        >>> b = "..35..633."
        >>> list(parse_gear_part_numbers(a, l, b))
        [[467, 35]]
    """
    for col, ch in enumerate(line):
        if ch == "*":
            # In ``line`` itself the middle cell is the '*', so only the
            # left and right neighbours can contribute.
            yield [n for row in (line, above, below) for n in _row_neighbours(row, col)]

In [13]:
def gear_part_numbers(lines: Iterable[str]) -> Iterable[List[int]]:
    """Returns the two part numbers of each gear in the lines of a schematic.

    A gear is a ``*`` touching exactly two numbers. For each gear, in
    reading order, its two adjacent numbers are yielded as a two-element
    list.

    Args:
        lines: the schematic rows, e.g. an open text file or a list of
            strings. Leading/trailing whitespace (including the newline) is
            stripped from each row.

    Example:

        >>> lines = [
        ...     "467..114..",
        ...     "...*......",
        ...     "..35..633.",
        ...     "......#...",
        ...     "617*......",
        ...     ".....+.58.",
        ...     "..592.....",
        ...     "......755.",
        ...     "...$.*....",
        ...     ".664.598..",
        ... ]
        >>> list(gear_part_numbers(lines))
        [[467, 35], [755, 598]]
    """

    def gears(window: Buffer) -> Iterable[List[int]]:
        # parse_gear_part_numbers reports every '*'; only those touching
        # exactly two numbers are gears.
        for pn in parse_gear_part_numbers(window.lines[0], window.lines[1], window.lines[2]):
            if len(pn) == 2:
                yield pn

    b = Buffer()
    for line in lines:
        line = line.strip()
        if not b.initialised:
            b.reset(len(line), 3)

        # pop old line and push new one
        b.append(line)

        # the middle row now has both neighbours: look for its gears
        yield from gears(b)

    if b.initialised:
        # push one more empty row so the final row gets parsed too
        b.append("." * len(b.lines[0]))
        yield from gears(b)

In [14]:
# Re-run the docstring examples, now including the Part 2 functions.
doctest.testmod()

TestResults(failed=0, attempted=25)

In [15]:
with open(DATA, "r") as f:
    print(sum(r1 * r2 for r1, r2 in gear_part_numbers(f)))

67779080
