In [1]:
import heapq
from tqdm.notebook import tqdm

In [2]:
# Read the puzzle input as a list of maze rows; `with` closes the file handle.
with open("input/16") as input_file:
    data = input_file.read().splitlines()

In [3]:
# Parse the maze into a (row, col) -> char map and locate the S and E tiles.
grid = {}
start = None
end = None
for r, row in enumerate(data):
    for c, elem in enumerate(row):
        grid[(r, c)] = elem
        if elem == "S":
            start = (r, c)
        elif elem == "E":
            end = (r, c)

# Fail here with a clear message instead of a TypeError deep inside the search.
assert start is not None, "maze has no start tile 'S'"
assert end is not None, "maze has no end tile 'E'"

In [4]:
def astart_find_best(pos, direction, maze=None, goal=None):
    """Return the lowest score for walking from ``pos`` to the goal tile (A* search).

    Scoring: moving one tile forward costs 1 and each 90-degree rotation costs
    1000, so a reversal costs 2000. Because the price of the next move depends
    on the current facing, the search state is (tile, facing), not just tile.

    Args:
        pos: (row, col) starting tile.
        direction: (dr, dc) initial facing, one of the four unit vectors.
        maze: mapping (row, col) -> char, where "#" is a wall; defaults to the
            global ``grid``.
        goal: (row, col) target tile; defaults to the global ``end``.

    Returns:
        The minimal score (0 when ``pos`` is the goal), or None if the goal
        cannot be reached.
    """
    if maze is None:
        maze = grid
    if goal is None:
        goal = end

    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    def heuristic(tile):
        # Manhattan distance. Every move costs >= 1 and changes this by at most 1,
        # so it is consistent and A* may settle a state the first time it is popped.
        return abs(goal[0] - tile[0]) + abs(goal[1] - tile[1])

    best = {(pos, direction): 0}
    q = [(heuristic(pos), 0, pos, direction)]
    while q:
        _, cost, here, facing = heapq.heappop(q)
        if cost > best[(here, facing)]:
            continue  # stale entry: a cheaper route to this state was pushed later
        if here == goal:
            # Goal test on pop (not on discovery) guarantees the cost is optimal.
            return cost

        for move in directions:
            neighbour = (here[0] + move[0], here[1] + move[1])
            tile = maze.get(neighbour)
            if not tile or tile == "#":
                continue

            if move == facing:
                new_cost = cost + 1
            elif move == (-facing[0], -facing[1]):
                new_cost = cost + 2001  # two quarter turns, then one step
            else:
                new_cost = cost + 1001  # one quarter turn, then one step

            state = (neighbour, move)
            if new_cost < best.get(state, float("inf")):
                best[state] = new_cost
                heapq.heappush(q, (new_cost + heuristic(neighbour), new_cost, neighbour, move))

    return None


In [5]:
facing_east = (0, 1)  # the reindeer starts on S facing east
part1 = astart_find_best(pos=start, direction=facing_east)
print(f"Answer #1: {part1}")

Answer #1: 90460


# Part 2

In [6]:
# Lower bound on the remaining score from every open tile to `end`, minimised
# over the four facings one could be holding on that tile.
# One Dijkstra run backwards from `end` over (tile, facing) states replaces
# four forward searches per tile. Scoring: 1 per step, 1000 per 90-degree
# turn (2000 to reverse). Open tiles that cannot reach `end` get infinity.

directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
to_end = {}  # (tile, facing) -> cheapest score from that state to `end`
frontier = [(0, (end, facing)) for facing in directions]
heapq.heapify(frontier)
while frontier:
    cost, state = heapq.heappop(frontier)
    if state in to_end:
        continue
    to_end[state] = cost
    tile, move = state
    # State (tile, move) is entered only by stepping along `move` from tile - move,
    # having faced any direction there beforehand.
    prev = (tile[0] - move[0], tile[1] - move[1])
    if grid.get(prev, "#") == "#":
        continue
    for facing in directions:
        if facing == move:
            step = 1
        elif facing == (-move[0], -move[1]):
            step = 2001
        else:
            step = 1001
        if (prev, facing) not in to_end:
            heapq.heappush(frontier, (cost + step, (prev, facing)))

shortest_map = {tile: float("inf") for tile, elem in grid.items() if elem != "#"}
for (tile, _), cost in to_end.items():
    shortest_map[tile] = min(shortest_map[tile], cost)

  0%|          | 0/19881 [00:00<?, ?it/s]

In [7]:
def bfs_find_all(pos, direction, upper_limit, maze=None, goal=None):
    """Count tiles lying on at least one route whose score is <= ``upper_limit``.

    Scoring: 1 per step forward, 1000 per 90-degree turn (2000 to reverse).
    Two Dijkstra runs over (tile, facing) states are combined: one forward from
    ``(pos, direction)`` and one backward from ``goal`` (arriving with any
    facing). A tile counts when, for some facing, the best score to reach that
    state plus the best score from it to ``goal`` fits within ``upper_limit``.
    With ``upper_limit`` equal to the optimal score, this is the number of tiles
    on any best path, start and goal tiles included.

    Args:
        pos: (row, col) starting tile.
        direction: (dr, dc) initial facing.
        upper_limit: highest route score to accept; normally the part 1 answer.
        maze: mapping (row, col) -> char, where "#" is a wall; defaults to the
            global ``grid``.
        goal: (row, col) target tile; defaults to the global ``end``.

    Returns:
        Number of distinct tiles on qualifying routes (0 if there are none).
    """
    if maze is None:
        maze = grid
    if goal is None:
        goal = end
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    def is_open(tile):
        elem = maze.get(tile)
        return bool(elem) and elem != "#"

    def step_cost(facing, move):
        # One step forward plus 1000 per quarter turn needed to face `move`.
        if move == facing:
            return 1
        if move == (-facing[0], -facing[1]):
            return 2001
        return 1001

    def forward_edges(state):
        # Moves available from (tile, facing): step into any open neighbour.
        tile, facing = state
        for move in directions:
            nxt = (tile[0] + move[0], tile[1] + move[1])
            if is_open(nxt):
                yield (nxt, move), step_cost(facing, move)

    def backward_edges(state):
        # Reverse of forward_edges: (tile, move) is entered only by stepping
        # along `move` from tile - move, with any prior facing.
        tile, move = state
        prev = (tile[0] - move[0], tile[1] - move[1])
        if is_open(prev):
            for facing in directions:
                yield (prev, facing), step_cost(facing, move)

    def dijkstra(sources, edges):
        # Settled score of every state reachable for at most upper_limit.
        dist = {}
        heap = [(0, state) for state in sources]
        heapq.heapify(heap)
        while heap:
            cost, state = heapq.heappop(heap)
            if cost > upper_limit:
                break  # every remaining entry is too expensive to matter
            if state in dist:
                continue
            dist[state] = cost
            for nxt, weight in edges(state):
                if nxt not in dist:
                    heapq.heappush(heap, (cost + weight, nxt))
        return dist

    from_start = dijkstra([(pos, direction)], forward_edges)
    to_goal = dijkstra([(goal, facing) for facing in directions], backward_edges)

    inf = float("inf")
    on_route = {
        tile
        for (tile, facing), cost in from_start.items()
        if cost + to_goal.get((tile, facing), inf) <= upper_limit
    }
    return len(on_route)

In [8]:
# Count every tile on a route whose score equals the part 1 optimum.
part2 = bfs_find_all(pos=start, direction=(0, 1), upper_limit=part1)
print(f"Answer #2: {part2}")

Answer #2: 575
