https://adventofcode.com/2021/day/12

In [16]:
%%time
from collections import deque, defaultdict

with open('data/12.txt') as fh:
    data = fh.read()


def load_graph(data):
    D = defaultdict(list)
    for pair in data.split():
        (a, b) = pair.split('-')
        if b != 'start':
            D[a].append(b)
        if a != 'start':
            D[b].append(a)
    return D


def no_small_caves_revisited(pth):
    smallcaves = [x for x in pth if x.islower()]
    return len(set(smallcaves)) == len(smallcaves)

def at_most_one_small_cave_revisited(pth):
    smallcaves = [x for x in pth if x.islower()]
    return len(set(smallcaves)) >= len(smallcaves) - 1
    
def count_all_paths(graph, validate_path=no_small_caves_revisited):
    q = deque([('start',)])
    visited = {('start',)}
    pathcount = 0
    while q:
        pth = q.popleft()
        current = pth[-1]        
        if current == 'end':
            pathcount += 1
            continue
        for nabe in graph[current]:
            newpth = pth + (nabe,)
            if newpth in visited:
                continue
            visited.add(newpth)
            if not validate_path(newpth):
                continue
            q.append(newpth)
    return pathcount


graph = load_graph(data)

part_1 = count_all_paths(graph, no_small_caves_revisited)
print('part_1 =', part_1)

part_2 = count_all_paths(graph, at_most_one_small_cave_revisited)
print('part_2 =', part_2)

part_1 = 4754
part_2 = 143562
CPU times: user 2.23 s, sys: 7.99 ms, total: 2.24 s
Wall time: 2.24 s


In [2]:
%%time

from collections import defaultdict

with open('data/12.txt') as fh:
    data = fh.read()

def load_graph(data):
    D = defaultdict(list)
    for pair in data.split():
        (a, b) = pair.split('-')
        if b != 'start':
            D[a].append(b)
        if a != 'start':
            D[b].append(a)
    return D

def count_paths_recursive(node, graph, maxsmall=0, smallcount=0, smallvisited=None):
    if smallvisited is None:
        smallvisited = set()
    if node == 'end':
        return 1
    if node.islower():
        if node in smallvisited:
            smallcount += 1
            if smallcount > maxsmall:
                return 0
        smallvisited.add(node)
    return sum(count_paths_recursive(nabe, graph, maxsmall, smallcount, smallvisited.copy()) for nabe in graph[node])

graph = load_graph(data)
print('part_1 =', count_paths_recursive('start', graph, maxsmall=0))
print('part_2 =', count_paths_recursive('start', graph, maxsmall=1))

part_1 = 4754
part_2 = 143562
CPU times: user 676 ms, sys: 3.11 ms, total: 679 ms
Wall time: 677 ms
