In [14]:
import itertools
from functools import cache
from tqdm import tqdm

# Read the whole puzzle input as a single string; it is split into lines further down.
with open("input.txt", "r") as input_file:
    lines = input_file.read()


class GenericKeypad:
    """Base class for a grid keypad with a movable cursor.

    Subclasses supply ``keypad`` as a class-level dict mapping each button
    label to its ``(x, y)`` grid coordinate. Grid cells missing from that
    mapping are gaps that the cursor may never enter.
    """

    # Cursor displacement for each direction button, as (dx, dy); y grows downwards.
    directions = {
        "<": (-1, 0),
        ">": (1, 0),
        "^": (0, -1),
        "v": (0, 1),
    }

    @property
    def keypad(self):
        """Button label -> (x, y) coordinate; subclasses must override this."""
        raise NotImplementedError

    @property
    def keypad_reverse(self):
        """Inverse of ``keypad``: (x, y) coordinate -> button label."""
        return {v: k for k, v in self.keypad.items()}

    def __init__(self, start_pos: str):
        """Place the cursor on the button labelled ``start_pos``."""
        self.current_pos = start_pos

    @classmethod
    @cache
    def _get_shortest_paths(
        cls, current_pos: tuple[int, int], target_pos: tuple[int, int]
    ) -> list[str]:
        """Return every shortest move string from ``current_pos`` to ``target_pos``.

        Only steps that reduce the remaining distance are tried, so every
        returned string has length |dx| + |dy|. Paths that would step onto a
        gap are pruned. Paths whose first move is horizontal are listed first.
        Returns ``[""]`` when the two coordinates are equal.

        The cache is keyed on the *class*, not the instance. The result depends
        only on the class's ``keypad`` layout, whereas keying on ``self`` would
        hash the mutable cursor position (in subclasses that hash by it), keep
        every instance alive in the cache, and raise ``TypeError`` for
        unhashable subclasses.

        The returned list is shared through the cache; callers must not mutate it.
        """
        if current_pos == target_pos:
            return [""]

        occupied = set(cls.keypad.values())
        x, y = current_pos
        dx, dy = target_pos[0] - x, target_pos[1] - y

        # At most one horizontal and one vertical step can shorten the path.
        steps = []
        if dx > 0:
            steps.append((">", (x + 1, y)))
        elif dx < 0:
            steps.append(("<", (x - 1, y)))
        if dy > 0:
            steps.append(("v", (x, y + 1)))
        elif dy < 0:
            steps.append(("^", (x, y - 1)))

        all_paths = []
        for symbol, next_pos in steps:
            if next_pos in occupied:
                all_paths.extend(
                    symbol + rest
                    for rest in cls._get_shortest_paths(next_pos, target_pos)
                )
        return all_paths

    def move(self, direction: str):
        """Move the cursor one step in ``direction`` (a key of ``directions``).

        Raises:
            ValueError: if the step would leave the keypad or land on a gap.
        """
        x, y = self.keypad[self.current_pos]
        step_x, step_y = self.directions[direction]
        next_pos = (x + step_x, y + step_y)
        reverse = self.keypad_reverse
        if next_pos not in reverse:
            raise ValueError(f"Invalid direction: {direction}")
        self.current_pos = reverse[next_pos]

    def get_shortest_path(self, target_pos: str):
        """Return all shortest move strings from the cursor to button ``target_pos``.

        Despite the singular name this returns a list (``[""]`` when the cursor
        is already there). The cursor itself is not moved.
        """
        return self._get_shortest_paths(
            self.keypad[self.current_pos], self.keypad[target_pos]
        )


class NumericKeypad(GenericKeypad):
    """Numeric keypad: digits 0-9 plus ``A`` laid out on a 3x4 grid.

    The bottom-left cell (0, 3) has no button and is a gap.
    """

    keypad = {
        "7": (0, 0),
        "8": (1, 0),
        "9": (2, 0),
        "4": (0, 1),
        "5": (1, 1),
        "6": (2, 1),
        "1": (0, 2),
        "2": (1, 2),
        "3": (2, 2),
        "0": (1, 3),
        "A": (2, 3),
    }

    # Hash and equality depend on the class and the *mutable* cursor position,
    # so an instance's hash changes whenever ``current_pos`` does.
    def __hash__(self):
        return hash(("numpad", self.current_pos))

    def __eq__(self, other):
        # Defer to Python's default handling for non-keypads instead of raising
        # AttributeError on ``other.current_pos`` (e.g. ``pad == None``).
        if not isinstance(other, GenericKeypad):
            return NotImplemented
        return self.__class__ == other.__class__ and self.current_pos == other.current_pos


class DirectionKeypad(GenericKeypad):
    """Directional keypad: the four arrows plus ``A`` laid out on a 3x2 grid.

    The top-left cell (0, 0) has no button and is a gap.
    """

    keypad = {
        "^": (1, 0),
        "A": (2, 0),
        "<": (0, 1),
        "v": (1, 1),
        ">": (2, 1),
    }

    # Hash and equality depend on the class and the *mutable* cursor position,
    # so an instance's hash changes whenever ``current_pos`` does.
    def __hash__(self):
        return hash(("dirpad", self.current_pos))

    def __eq__(self, other):
        # Defer to Python's default handling for non-keypads instead of raising
        # AttributeError on ``other.current_pos`` (e.g. ``pad == None``).
        if not isinstance(other, GenericKeypad):
            return NotImplemented
        return self.__class__ == other.__class__ and self.current_pos == other.current_pos

In [24]:
def get_all_combs(buttons_to_press: str, keypad):
    """Return every shortest button sequence that types ``buttons_to_press`` on ``keypad``.

    Each target button contributes one of its shortest move strings followed by
    ``"A"``. The result is the set of all concatenations (the Cartesian product
    over buttons). Returns ``(sequences, end_pos)`` and also sets
    ``keypad.current_pos`` to ``end_pos``. This happens on every call, not only
    on cache misses.

    The expensive part is memoised in ``_all_combs_from`` on the immutable key
    (buttons, keypad class, start position). Caching directly on ``keypad`` is
    unsafe because the keypads' hash follows the mutable ``current_pos``, which
    the computation itself changes after the key has been hashed.
    """
    combs, end_pos = _all_combs_from(buttons_to_press, type(keypad), keypad.current_pos)
    keypad.current_pos = end_pos
    # Hand out a fresh set so callers cannot corrupt the cached frozenset.
    return set(combs), end_pos


@cache
def _all_combs_from(buttons_to_press: str, keypad_cls, start_pos):
    """Cached core of ``get_all_combs``; returns ``(frozenset_of_sequences, end_pos)``.

    Assumes ``keypad_cls(start_pos)`` builds a keypad whose cursor is on
    ``start_pos``. The throw-away instance is moved instead of the caller's.
    """
    scratch = keypad_cls(start_pos)
    per_button = []

    for char in buttons_to_press:
        shortest_paths = scratch.get_shortest_path(char)

        # Every shortest path ends on `char`; follow one to advance the cursor.
        if shortest_paths:
            for k in shortest_paths[0]:
                scratch.move(k)

        per_button.append([path + "A" for path in shortest_paths] or ["A"])

    combs = frozenset("".join(x) for x in itertools.product(*per_button))
    return combs, scratch.current_pos

@cache
def length_for_comb(comb: str, pad_nr: int, max_dep: int):
    """Minimal number of presses needed one level up to type ``comb`` on a directional keypad.

    A fresh ``DirectionKeypad`` starting on ``"A"`` types ``comb`` button by
    button. Each button is costed via ``get_all_paths`` at depth
    ``pad_nr + 1``, and the cursor is carried over from one button to the next.
    """
    pad = DirectionKeypad("A")
    presses = 0
    for button in comb:
        cost, pad.current_pos = get_all_paths(button, pad_nr + 1, pad, max_dep)
        presses += cost
    return presses


def get_all_paths(buttons_to_press: str, pad_nr: int, keypad: GenericKeypad, max_dep: int):
    """Return ``(min_presses, cursor_pos)`` for typing ``buttons_to_press`` on ``keypad``.

    At ``pad_nr == max_dep`` the cost is the length of the shortest candidate
    sequence. Below that depth, each candidate is costed recursively through
    ``length_for_comb`` and the cheapest one wins. ``keypad`` is left on the
    position reported by ``get_all_combs``.
    """
    candidates, keypad.current_pos = get_all_combs(buttons_to_press, keypad)

    if pad_nr != max_dep:
        best = min(
            (length_for_comb(candidate, pad_nr, max_dep) for candidate in candidates),
            default=float("inf"),
        )
    else:
        best = min(map(len, candidates))

    return best, keypad.current_pos

def total_complexity(codes: str, max_dep: int) -> int:
    """Sum, over every code in ``codes`` (one per line), of presses x numeric value.

    For each code, a fresh ``NumericKeypad`` starting on ``"A"`` is typed
    character by character via ``get_all_paths`` with depth ``max_dep``. The
    total press count is multiplied by the code with its last character
    dropped, parsed as an int. Lines are stripped, so ``\\r`` endings are
    tolerated, and blank lines are skipped instead of crashing ``int("")``.
    """
    score = 0
    for line in tqdm(codes.splitlines()):
        code = line.strip()
        if not code:
            continue
        num_pad = NumericKeypad("A")
        presses = sum(get_all_paths(char, 0, num_pad, max_dep)[0] for char in code)
        score += presses * int(code[:-1])
    return score


# The two results differ only in the recursion depth passed as ``max_dep``.
score = total_complexity(lines, 2)
print(score)

score = total_complexity(lines, 25)
print(score)

100%|██████████| 5/5 [00:00<00:00, 108.01it/s]


246990


100%|██████████| 5/5 [00:00<00:00, 19.16it/s]

306335137543664



