In [1]:
from heapq import heappop, heappush
from collections import defaultdict, deque

In [2]:
def get_input(filename):
    """Parse a maze text file into an open-tile map plus start and end positions.

    Each character in "SE." marks an open tile; every other character
    (e.g. '#') is treated as a wall and left out of the map. Lines are
    whitespace-stripped before columns are numbered.

    Args:
        filename: path to a text file with one maze row per line.

    Returns:
        (grid, source, target): grid maps every open (row, col) to 1,
        source is the (row, col) of 'S', target is the (row, col) of 'E'.
        If a marker appears more than once, the last occurrence wins.

    Raises:
        ValueError: if the file contains no 'S' or no 'E'.
    """
    grid = {}
    # Bind up front so a missing marker produces a clear error instead of an
    # UnboundLocalError at the return statement.
    source = target = None
    with open(filename) as f:
        for i, line in enumerate(f):
            for j, s in enumerate(line.strip()):
                if s in "SE.":
                    grid[i, j] = 1
                if s == "S":
                    source = (i, j)
                elif s == "E":
                    target = (i, j)
    if source is None or target is None:
        raise ValueError(f"{filename}: maze must contain both an 'S' and an 'E' tile")
    return grid, source, target

def dijkstra(source, target, grid, step_cost=1, turn_cost=1000, start_dir=(0, 1)):
    """Find the cheapest route from `source` to `target` through `grid`.

    A search state is (row, col, drow, dcol): a position plus the unit
    direction being faced. From a state you can step one tile forward (only
    if that tile is in `grid`) for `step_cost`, or rotate 90 degrees in place
    for `turn_cost`.

    Args:
        source: (row, col) start position.
        target: (row, col) goal position.
        grid: container of open (row, col) positions; only `in` is used on it.
        step_cost: cost of one forward move (default 1).
        turn_cost: cost of one 90-degree rotation (default 1000).
        start_dir: initial facing as (drow, dcol), default (0, 1).

    Returns:
        (cost, prev, drow, dcol) the first time a state on `target` is popped:
        cost is the minimal total cost, prev maps each reached state to the
        single state it was cheapest reached from (first found wins ties), and
        (drow, dcol) is the facing on arrival. Returns None if `target` is
        unreachable.
    """
    i, j = source
    di, dj = start_dir
    cost = {(i, j, di, dj): 0}
    prev = {}
    pq = [(0, i, j, di, dj)]
    while pq:
        c, i, j, di, dj = heappop(pq)
        here = (i, j, di, dj)
        if c > cost[here]:
            # Stale queue entry: this state was already expanded at a lower cost.
            continue
        if (i, j) == target:
            return c, prev, di, dj
        # The two 90-degree turns from (di, dj) face (dj, di) and (-dj, -di).
        candidates = [((i, j, dj, di), turn_cost), ((i, j, -dj, -di), turn_cost)]
        if (i + di, j + dj) in grid:
            candidates.append(((i + di, j + dj, di, dj), step_cost))
        for state, weight in candidates:
            alt = c + weight
            if state not in cost or alt < cost[state]:
                cost[state] = alt
                prev[state] = here
                heappush(pq, (alt, *state))
    # Queue drained without ever popping the target: it is unreachable.
    return None


In [3]:
grid, source, target = get_input("16_input.txt")
result = dijkstra(source, target, grid)
if result is None:
    # dijkstra() yields None when its queue drains before reaching the target;
    # unpacking that directly would fail with an opaque TypeError.
    raise ValueError("E is not reachable from S")
cost, prev, di, dj = result
print(cost)

91464


In [4]:
def dijkstra2(source, target, grid, step_cost=1, turn_cost=1000, start_dir=(0, 1)):
    """Dijkstra over (row, col, drow, dcol) states that keeps *every* best predecessor.

    Unlike a single-answer search, this does not stop at `target`: it drains
    the whole queue so that every equally cheap way into each state is
    recorded. States located on `target` are not expanded further.

    Args:
        source: (row, col) start position.
        target: (row, col) goal position.
        grid: container of open (row, col) positions; only `in` is used on it.
        step_cost: cost of one forward move (default 1).
        turn_cost: cost of one 90-degree rotation (default 1000).
        start_dir: initial facing as (drow, dcol), default (0, 1).

    Returns:
        (prev, cost): prev is a defaultdict(set) mapping each reached state to
        the set of states that reach it at its minimal cost; cost is a
        defaultdict (default inf) mapping each reached state to that cost.
    """
    i, j = source
    di, dj = start_dir
    cost = defaultdict(lambda: float("inf"))
    cost[(i, j, di, dj)] = 0
    prev = defaultdict(set)
    pq = [(0, i, j, di, dj)]
    while pq:
        c, i, j, di, dj = heappop(pq)
        here = (i, j, di, dj)
        if c > cost[here]:
            # Stale queue entry superseded by a cheaper one already expanded.
            continue
        if (i, j) == target:
            continue

        # The two 90-degree turns from (di, dj) face (dj, di) and (-dj, -di).
        candidates = [((i, j, dj, di), turn_cost), ((i, j, -dj, -di), turn_cost)]
        if (i + di, j + dj) in grid:
            candidates.append(((i + di, j + dj, di, dj), step_cost))

        for state, weight in candidates:
            alt = c + weight
            best = cost[state]
            if alt < best:
                cost[state] = alt
                prev[state] = {here}
                heappush(pq, (alt, *state))
            elif alt == best:
                # Another equally cheap way in. Record it but do not re-queue:
                # the state is already pending at this cost, and expanding it
                # again would only repeat work once per tied path.
                prev[state].add(here)
    return prev, cost

def get_path(node, prev):
    """List every predecessor chain that leads back from `node`.

    Walks `prev` (state -> iterable of parent states) depth-first from `node`
    until it reaches a state with no recorded parents. A parent already on
    the current chain is skipped, so a cycle ends that branch without
    emitting anything.

    Returns:
        A list of chains. Each chain holds the (row, col) prefix of every
        state, ordered from the root of the chain to `node`.
    """
    found = []
    pending = deque([([node], node)])
    while pending:
        chain, head = pending.pop()
        # Membership is tested first so a defaultdict `prev` is never grown.
        parents = prev[head] if head in prev else ()
        if not parents:
            found.append([(state[0], state[1]) for state in chain[::-1]])
            continue
        pending.extend(
            (chain + [parent], parent) for parent in parents if parent not in chain
        )
    return found


In [5]:
grid, source, target = get_input("16_input.txt")
prev, cost = dijkstra2(source, target, grid)

# Only states that sit on the E tile are goals. Filtering *all* states by
# `c == min_cost` would also select unrelated states elsewhere in the maze
# whose cost merely equals the optimum, and count their routes as well.
end_costs = {state: c for state, c in cost.items() if state[:2] == target}
if not end_costs:
    raise ValueError("E is not reachable from S")
min_cost = min(end_costs.values())
targets = [state for state, c in end_costs.items() if c == min_cost]

# Flood backwards through the best-predecessor graph, visiting each state once.
# Every prev edge raises cost by a positive step/turn cost, so the graph is
# acyclic. That means this gives the same tile set as enumerating every best
# path with get_path, without the exponential blow-up when many routes tie.
on_best_path = set(targets)
stack = list(targets)
while stack:
    state = stack.pop()
    for parent in prev.get(state, ()):
        if parent not in on_best_path:
            on_best_path.add(parent)
            stack.append(parent)
print(len({state[:2] for state in on_best_path}))

494
