In [1]:
import doctest
import io
import math
import re
from itertools import cycle
from typing import List, Tuple

In [2]:
# Puzzle input file; a relative path, so it is resolved against the
# notebook's current working directory.
DATA: str = "input.txt"

# Part 1

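Store the directed (possibly cyclic) graph in a dict that maps each node to its `(left, right)` neighbours. Then, starting from `AAA`, follow the L/R instructions (repeating them as often as needed) until `ZZZ` is reached.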

In [3]:
# Adjacency map: node name -> (left neighbour, right neighbour), exactly as
# built by `parse_input` and indexed with [0]/[1] by the walkers below.
Graph = dict[str, tuple[str, str]]

In [4]:
def parse_input(data: io.TextIOBase) -> Tuple[str, Graph]:
    """Returns the instruction string and graph.

    The first line holds the L/R instruction string. Every following
    non-blank line must look like ``NAME = (LEFT, RIGHT)`` where each name is
    made of upper-case letters and digits. Blank lines (the separator after
    the instructions, or a trailing newline at end of file) are skipped.

    Args:
        data: An iterable of text lines, e.g. an open file or ``io.StringIO``.

    Returns:
        An ``(instructions, graph)`` pair where ``graph`` maps each node name
        to its ``(left, right)`` neighbours.

    Raises:
        ValueError: If the input is empty or a node line is malformed.

    Example:
        >>> data = io.StringIO('''LLR
        ...
        ... AAA = (BBB, BBB)
        ... BBB = (AAA, ZZZ)
        ... ZZZ = (ZZZ, ZZZ)
        ... ''')
        >>> ins, graph = parse_input(data)
        >>> ins
        'LLR'
        >>> graph
        {'AAA': ('BBB', 'BBB'), 'BBB': ('AAA', 'ZZZ'), 'ZZZ': ('ZZZ', 'ZZZ')}
    """
    # Raw string: "\(" and "\s" are invalid escapes in a plain string literal
    # and raise a SyntaxWarning on Python 3.12+.
    pattern = re.compile(r"([A-Z0-9]+) = \(([A-Z0-9]+), ([A-Z0-9]+)\)")
    data_iter = iter(data)
    try:
        ins = next(data_iter).strip()
    except StopIteration:
        raise ValueError("input is empty; expected an instruction line") from None
    graph = {}
    for line in data_iter:
        text = line.strip()
        if not text:
            # The blank separator line and any trailing blank lines carry no node.
            continue
        m = pattern.fullmatch(text)
        if m is None:
            raise ValueError(f"malformed node line: {line!r}")
        name, left, right = m.groups()
        graph[name] = (left, right)
    return ins, graph

In [5]:
def n_steps(ins: str, graph: Graph, start: str = "AAA", end: str = "ZZZ") -> int:
    """Returns the number of steps to reach ZZZ from AAA.

    The instruction string is repeated as often as needed; ``"L"`` follows the
    left neighbour and any other character follows the right neighbour.

    Args:
        ins: The instruction string, e.g. ``"LLR"``.
        graph: Maps each node name to its ``(left, right)`` neighbours.
        start: Node to start from (defaults to ``"AAA"``).
        end: Node to stop at (defaults to ``"ZZZ"``).

    Returns:
        The number of moves taken before first standing on ``end``
        (0 if ``start == end``).

    Raises:
        ValueError: If ``end`` can never be reached, i.e. a
            ``(node, instruction position)`` state repeats before ``end`` is
            hit, or ``ins`` is empty and ``start != end``.
        KeyError: If the walk reaches a node that is missing from ``graph``.

    Example:

        >>> ins = "LLR"
        >>> graph = {'AAA': ('BBB', 'BBB'), 'BBB': ('AAA', 'ZZZ'), 'ZZZ': ('ZZZ', 'ZZZ')}
        >>> n_steps(ins, graph)
        6
    """
    if start == end:
        return 0
    if not ins:
        # cycle("") yields nothing, so without this check we would silently
        # report 0 steps even though we never moved.
        raise ValueError(f"cannot reach {end!r} from {start!r} with no instructions")
    element = start
    # The walk is deterministic, so once a (node, instruction position) state
    # repeats it is trapped in a loop that never visits `end`; detect that
    # instead of looping forever.
    seen = set()
    for cnt, (pos, move) in enumerate(cycle(enumerate(ins)), start=1):
        state = (element, pos)
        if state in seen:
            raise ValueError(f"{end!r} is not reachable from {start!r}")
        seen.add(state)
        element = graph[element][0 if move == "L" else 1]
        if element == end:
            return cnt

In [6]:
# Run every docstring example defined so far and fail loudly on any failure,
# so Restart & Run All stops here instead of silently continuing.
doctest_results = doctest.testmod()
assert doctest_results.failed == 0, doctest_results
doctest_results

TestResults(failed=0, attempted=7)

In [7]:
# Load the puzzle input. An explicit encoding avoids depending on the
# platform's locale default (which `open` otherwise uses).
with open(DATA, "r", encoding="utf-8") as f:
    ins, graph = parse_input(f)

In [8]:
# Part 1 answer: number of steps from AAA to ZZZ, shown as the cell output.
n_steps(ins, graph)

19951

# Part 2

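In theory this is the same as Part 1, but with every node ending in `A` moving simultaneously. In practice the naive simulation is too slow. Instead, for each start node we record the step counts at which it stands on a node ending in `Z` before its (node, instruction position) state repeats. We then take the least common multiple of all of these counts. This shortcut assumes that each walk reaches its `Z` node exactly at multiples of its own cycle length, which holds for the puzzle input but not for arbitrary graphs.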

For example, suppose four start states reached a state ending in `Z` after 2, 3, 4, and 6 steps respectively, and then every 2, 3, 4, and 6 steps thereafter. The least common multiple, 12, is the first step at which they are all on a state ending in `Z`.

In [9]:
def n_steps_2(ins: str, graph: Graph, start_state: str) -> List[int]:
    """Returns the number of steps to reach XXZ nodes from `start_state`.

    Walks from ``start_state`` following ``ins`` (repeated as needed) and
    records the step count every time the current node's name ends in
    ``"Z"``. The walk stops as soon as a ``(node, instruction position)``
    state repeats, because from then on it only retraces an earlier loop.

    Args:
        ins: The instruction string; ``"L"`` follows the left neighbour and
            any other character follows the right neighbour.
        graph: Maps each node name to its ``(left, right)`` neighbours.
        start_state: Node to start from.

    Returns:
        Step counts, in increasing order, at which the walk stood on a node
        ending in ``"Z"`` before the first repeated state. Empty if ``ins``
        is empty.

    Example:

        >>> ins = "LR"
        >>> graph = {'11A': ('11B', 'XXX'), '11B': ('XXX', '11Z'), '11Z': ('11B', 'XXX')}
        >>> n_steps_2(ins, graph, "11A")
        [2]
    """
    element = start_state
    visited = set()
    cnts = []
    for cnt, (ins_pos, move) in enumerate(cycle(enumerate(ins))):
        state = (element, ins_pos)
        if state in visited:
            break
        visited.add(state)
        # endswith rather than element[2]: the parser accepts node names of
        # any length, not just three characters.
        if element.endswith("Z"):
            cnts.append(cnt)
        element = graph[element][0 if move == "L" else 1]
    return cnts

In [10]:
# Re-run every docstring example (now including n_steps_2) and fail loudly on
# any failure, so Restart & Run All stops here instead of silently continuing.
doctest_results = doctest.testmod()
assert doctest_results.failed == 0, doctest_results
doctest_results

TestResults(failed=0, attempted=10)

In [11]:
# Re-parse the input so Part 2 starts from a freshly built graph rather than
# the Part 1 variables. An explicit encoding avoids depending on the
# platform's locale default.
with open(DATA, "r", encoding="utf-8") as f:
    ins, graph = parse_input(f)

In [12]:
# Every node whose name ends in "A" is a starting node. endswith (rather than
# s[2]) works for node names of any length, which the parser allows.
start_states = [s for s in graph if s.endswith("A")]
steps_to_z = []
for s in start_states:
    steps_to_z.extend(n_steps_2(ins, graph, s))
# math.lcm() of nothing is 1, which would be a silently wrong answer.
assert steps_to_z, "no start state reached a node ending in Z"

In [13]:
# Part 2 answer: the least common multiple of every recorded step count.
# NOTE(review): this equals the first step at which all walks stand on a
# Z-ending node only if each walk hits its Z node(s) exactly at multiples of
# its own period, as the puzzle input appears to guarantee. It is not valid
# for arbitrary graphs; confirm before reusing.
math.lcm(*steps_to_z)

16342438708751