In [None]:
from collections import deque
from functools import cache

def parse_input(input_file):
    """Read a height-map file into a list of rows of ints.

    Each non-blank line becomes one row; each digit character becomes its
    integer height (0-9). A ``.`` character becomes -1, so such tiles can
    never be entered by a step that goes up by exactly one from a digit
    height. Blank lines (e.g. a trailing empty line) are skipped so they do
    not produce empty rows. Any other non-digit character still raises
    ValueError from ``int``.

    Args:
        input_file: Path of the text file to read.

    Returns:
        list[list[int]]: The grid, one inner list per non-blank line.
    """
    grid = []
    with open(input_file, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # An empty row would break later grid[i][j] indexing.
                continue
            grid.append([-1 if ch == "." else int(ch) for ch in line])
    return grid

# Orthogonal neighbour offsets as (row delta, column delta). Rows are the
# file's lines in order, so a negative row delta moves toward the top.
deltas = [
    (-1, 0),  # up
    (1, 0),   # down
    (0, -1),  # left
    (0, 1),   # right
]

def calc_score(start, grid):
    """Count the distinct height-9 cells reachable from ``start``.

    Movement is to an orthogonal neighbour (see ``deltas``) whose height is
    exactly one more than the current cell. Breadth-first search with a
    visited set ensures each height-9 cell is counted at most once.

    Args:
        start: (row, col) tuple of the starting cell.
        grid: Rectangular list of rows of integer heights.

    Returns:
        int: Number of distinct height-9 cells reached.
    """
    rows, cols = len(grid), len(grid[0])
    seen = {start}
    frontier = deque([start])
    peaks = 0
    while frontier:
        r, c = frontier.popleft()
        height = grid[r][c]
        if height == 9:
            # A peak ends the trail; nothing is higher than 9.
            peaks += 1
            continue
        for dr, dc in deltas:
            nr, nc = r + dr, c + dc
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            if (nr, nc) in seen or grid[nr][nc] != height + 1:
                continue
            seen.add((nr, nc))
            frontier.append((nr, nc))
    return peaks


def part1(input_file):
    """Solve part 1: sum of ``calc_score`` over every height-0 cell.

    Args:
        input_file: Path of the height-map file.

    Returns:
        int: Total number of (trailhead, reachable height-9 cell) pairs.
    """
    grid = parse_input(input_file)
    rows, cols = len(grid), len(grid[0])
    trailheads = [
        (r, c) for r in range(rows) for c in range(cols) if grid[r][c] == 0
    ]
    return sum(calc_score(head, grid) for head in trailheads)

def calc_score2(start, grid):
    """Count the distinct uphill paths from ``start`` to any height-9 cell.

    A path steps to an orthogonal neighbour (see ``deltas``) whose height is
    exactly one more than the current cell. The per-cell path counts are
    memoized with ``functools.cache``; the cache is local to this call.

    Args:
        start: (row, col) tuple of the starting cell.
        grid: Rectangular list of rows of integer heights.

    Returns:
        int: Number of distinct paths ending on a height-9 cell.
    """
    rows, cols = len(grid), len(grid[0])

    @cache
    def trails_from(r, c):
        height = grid[r][c]
        if height == 9:
            return 1
        return sum(
            trails_from(r + dr, c + dc)
            for dr, dc in deltas
            if 0 <= r + dr < rows
            and 0 <= c + dc < cols
            and grid[r + dr][c + dc] == height + 1
        )

    r0, c0 = start
    return trails_from(r0, c0)


def part2(input_file):
    """Solve part 2: sum of ``calc_score2`` over every height-0 cell.

    Args:
        input_file: Path of the height-map file.

    Returns:
        int: Total number of distinct uphill paths from all trailheads.
    """
    grid = parse_input(input_file)
    width = len(grid[0])
    total = 0
    for r, row in enumerate(grid):
        for c in range(width):
            if row[c] == 0:
                total += calc_score2((r, c), grid)
    return total


In [7]:
# Part 1 on the small test input (recorded cell output: 36).
part1('input/day10_test.txt')

36

In [8]:
# Part 1 on the full input (recorded cell output: 538).
part1('input/day10.txt')

538

In [10]:
# Part 2 on the small test input (recorded cell output: 81).
part2('input/day10_test.txt')

81

In [11]:
# Part 2 on the full input (recorded cell output: 1110).
part2('input/day10.txt')

1110