In [78]:
from collections import deque, defaultdict
from functools import cache

def parse_input(input_file):
    grid = []
    with open(input_file) as f:
        for line in f:
            grid.append(line.rstrip())
    return grid

num_pad = [['7', '8', '9'], ['4', '5', '6'], ['1', '2', '3'], ['', '0', 'A']]
dir_pad = [['', '^', 'A'], ['<', 'v', '>']]

def to_key_to_position(board):
    m, n = len(board), len(board[0])
    ans = {}
    for i in range(m):
        for j in range(n):
            if board[i][j]:
                ans[board[i][j]] = (i, j)
    return ans

num_pad_key_to_position = to_key_to_position(num_pad)
dir_pad_key_to_position = to_key_to_position(dir_pad)


moves = {'^': (-1, 0), 'v': (1, 0), '<': (0, -1), '>':(0, 1)}

def shortest_path(board, start, end):
    if start == end:
        return ['']
    m, n = len(board), len(board[0])
    shortest = 100
    steps = 0
    q = deque([(start, '')])
    visited = set([(start, '')])
    ans = []
    while q and steps < shortest:
        cur_len = len(q)
        for _ in range(cur_len):
            (x, y), path = q.popleft()
            for mv, (dx, dy) in moves.items():
                nx, ny = x + dx, y + dy
                if 0 <= nx < m and 0 <= ny < n and board[nx][ny]:
                    if (nx, ny) == end:
                        shortest = steps
                        ans.append(path + mv)
                    elif (nx, ny) not in visited:
                        new_path = path + mv
                        visited.add(((nx, ny), new_path))
                        q.append(((nx, ny), new_path))
        steps += 1
    
    return ans

def shorted_seq_keyboard(seq):
    start = num_pad_key_to_position['A']
    ans = ['']
    for c in seq:
        end = num_pad_key_to_position[c]
        paths = shortest_path(num_pad, start, end)
        new_ans = []
        for p in ans:
            for p1 in paths:
                new_ans.append(p + p1 + 'A')
        start = end
        ans = new_ans
    return ans

@cache
def shortest_path_numpad(start, end):
    return shortest_path(num_pad, start, end)

@cache
def shortest_path_direction(start, end):
    return shortest_path(dir_pad, start, end)

def shorted_seq_direction(seq):
    start = dir_pad_key_to_position['A']
    ans = ['']
    for c in seq:
        end = dir_pad_key_to_position[c]
        paths = shortest_path_direction(start, end)
        new_ans = []
        for p in ans:
            for p1 in paths:
                new_ans.append(p + p1 + 'A')
        start = end
        ans = new_ans
    return ans

def indirection(path1):
    path2 = []
    for p in path1:
        path2 += shorted_seq_direction(p)
    min_len = min(len(p) for p in path2)
    
    path2 = [x for x in path2 if len(x) == min_len]
    return path2

def score(seq):
    path1 = shorted_seq_keyboard(seq)
    path2 = indirection(path1)
    path3 = indirection(path2)

    return len(path3[0]) * int(seq[:-1])


def part1(input_file):
    seqs = parse_input(input_file)

    return sum(score(seq) for seq in seqs)

@cache 
def recur(cur_robot, start_key, end_key, total_robots):
    if start_key == end_key:
        return 1
    if cur_robot == 0:
        start, end = num_pad_key_to_position[start_key], num_pad_key_to_position[end_key]
        path = shortest_path_numpad(start, end)
    else:
        start, end = dir_pad_key_to_position[start_key], dir_pad_key_to_position[end_key]
        path = shortest_path_direction(start, end)
    if cur_robot == total_robots:
        return len(path[0]) + 1
    ans = float('inf')
    for p in path:
        steps = recur(cur_robot + 1, 'A', p[0], total_robots)
        for i in range(1, len(p)):
            steps += recur(cur_robot + 1, p[i-1], p[i], total_robots)
        steps += recur(cur_robot + 1, p[-1], 'A', total_robots)
        ans = min(ans, steps)
    return ans


def part2(input_file, total_bots):
    seqs = parse_input(input_file)
    ans = 0
    for seq in seqs:
        complexity = recur(0, 'A', seq[0], total_bots)
        for i in range(1, len(seq)):
            complexity += recur(0, seq[i-1], seq[i], total_bots)
        print(seq, complexity)
        ans += complexity * int(seq[:-1])
    return ans

In [73]:
part1('input/day21_test.txt')

126384

In [74]:
part1('input/day21.txt')

163086

In [79]:
part2('input/day21_test.txt', 2)

029A 68
980A 60
179A 68
456A 64
379A 64


126384

In [81]:
part2('input/day21.txt', 25)

286A 86475783008
480A 90594397580
140A 87513499934
413A 87288844796
964A 85006969638


198466286401228