## Setup

In [None]:
import sys
from pathlib import Path

from aocd import get_data, submit

In [2]:
# Put the parent of the current working directory on sys.path so that project
# packages living there (e.g. `common`) can be imported from this notebook.
# The membership check keeps re-runs of this cell from piling up duplicate entries.
parent_directory = str(Path.cwd().parent)
if parent_directory not in sys.path:
    sys.path.append(parent_directory)

In [153]:
# Fetch the raw puzzle input for Advent of Code 2024, day 20
data: str = get_data(day=20, year=2024)

## Part a

In [None]:
# Imports
from common.utils.dict_grid import find_object_in_grid, text_to_dict

In [None]:
# Functions
def find_standard_path_positions(grid: dict[complex, str]) -> dict[complex, int]:
    """Map each position on the single track from "S" to "E" to its step index.

    The maze is expected to have exactly one route: from every track cell other
    than "E" there is exactly one non-wall neighbour besides the cell just left.
    Positions missing from ``grid`` are treated as walls.

    Args:
        grid: Maze cells keyed by complex position; "#" marks a wall.

    Returns:
        Dict mapping each track position to its index along the path ("S" is 0).

    Raises:
        ValueError: If the walk reaches a dead end before arriving at "E".
    """
    # Find start, end and initial position
    position = previous_position = find_object_in_grid(grid, "S")
    path = [position]
    end = find_object_in_grid(grid, "E")

    while position != end:
        for direction in (1j, 1, -1j, -1):
            new_position = position + direction

            # Skip the cell we just came from and any wall (or off-grid position)
            if new_position == previous_position or grid.get(new_position, "#") == "#":
                continue

            # Take exactly one step per iteration so the walk stops as soon as "E" is reached
            position, previous_position = new_position, position
            path.append(position)
            break
        else:
            # No way forward: without this the while loop would spin forever
            raise ValueError(f"Dead end at {position} before reaching the end")

    # Return the positions and their order in the path
    return {pos: index for index, pos in enumerate(path)}

In [154]:
# Turn the puzzle text into a grid dict keyed by complex position
grid: dict[complex, str] = text_to_dict(data)

In [155]:
# Index every track cell by its position along the honest (no-cheat) route
standard_path: dict[complex, int] = find_standard_path_positions(grid)

In [None]:
# Part a: a cheat lasts 2 picoseconds, i.e. it steps through exactly one wall onto
# the cell directly behind it. Jumping from path index i to path index j this way
# takes 2 steps, so it saves j - i - 2 picoseconds; count cheats saving at least 100.
# Landing cells that are not on the track map to index 0 and can never qualify.
cheats_over_100_a = sum(
    1
    for start, start_step in standard_path.items()
    for step in (1j, 1, -1j, -1)
    if grid[start + step] == "#"
    and standard_path.get(start + 2 * step, 0) - start_step - 2 >= 100
)

In [None]:
# Send the part a result to Advent of Code (2024, day 20)
submit(cheats_over_100_a, year=2024, day=20, part="a")

## Part b

In [189]:
# Functions
def find_cheat_paths_from_step(
    position: complex,
    step_count: int,
    standard_path: dict[complex, int],
    grid: dict[complex, str],
    minimum_gain: int = 100,
    max_cheat_steps: int = 20,
) -> int:
    """Count the cheat end positions reachable from ``position`` that save enough time.

    A cheat ignores walls: it may move through any cell of ``grid`` (but never
    leave it) for at most ``max_cheat_steps`` steps, and must end on a track cell
    (a key of ``standard_path``). Ending at ``end`` after ``d`` steps saves
    ``standard_path[end] - step_count - d``. Only the shortest ``d`` to each cell
    matters, because a longer route to the same end can only save less.

    Args:
        position: Where the cheat starts.
        step_count: Index of ``position`` along the standard path.
        standard_path: Mapping of track positions to their index along the path.
        grid: Every position of the maze; cheats cannot step outside it.
        minimum_gain: Minimum time a cheat must save to be counted.
        max_cheat_steps: Maximum number of steps a cheat may last.

    Returns:
        Number of distinct cheat end positions saving at least ``minimum_gain``.
    """
    from collections import deque

    if max_cheat_steps < 0 or position not in grid:
        return 0

    # Breadth-first search reaches every cell exactly once, at its minimal distance;
    # the previous depth-first search could re-expand the same cell many times.
    distances = {position: 0}
    queue = deque([position])
    cheat_count = 0

    while queue:
        current = queue.popleft()
        distance = distances[current]

        # Only track cells can end a cheat; walls have no path index
        end_step = standard_path.get(current)
        if end_step is not None and end_step - step_count - distance >= minimum_gain:
            cheat_count += 1

        # Do not expand beyond the cheat duration limit
        if distance == max_cheat_steps:
            continue

        for direction in (1j, 1, -1j, -1):
            neighbor = current + direction
            if neighbor in grid and neighbor not in distances:
                distances[neighbor] = distance + 1
                queue.append(neighbor)

    return cheat_count

In [None]:
# Part b: accumulate, over every track cell as a cheat start, the cheats allowed by
# find_cheat_paths_from_step's defaults (up to 20 steps, saving at least 100)
cheats_over_100_b = 0
for start, start_step in standard_path.items():
    cheats_over_100_b += find_cheat_paths_from_step(start, start_step, standard_path, grid)

In [None]:
# Send the part b result to Advent of Code (2024, day 20)
submit(cheats_over_100_b, year=2024, day=20, part="b")