In [32]:
import heapq

# Unit steps as (row delta, column delta), indexed by facing. part1/part2
# start at facing 0 and treat (face + 1) % 4 as a right turn and
# (face + 3) % 4 as a left turn.
directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
# Stand-in cost for search states that have no recorded distance yet.
INF = float('inf')

def parse_input(input_file):
    """Load a text file as a grid of single characters.

    Args:
        input_file: Path of the file to read.

    Returns:
        A list with one entry per line of the file; each entry is the list
        of that line's characters after trailing whitespace (including the
        newline) has been stripped. A blank line yields an empty list.
    """
    with open(input_file) as handle:
        return [list(raw.rstrip()) for raw in handle]


def find_start(grid):
    """Locate the first 'S' cell in the grid, scanning row by row.

    Each row is scanned over its own length, so an empty grid, empty rows
    and rows of differing widths are all handled without indexing errors.

    Args:
        grid: Sequence of rows, each a sequence of single-character cells.

    Returns:
        A ``(row, column)`` tuple for the first 'S' found in row-major
        order, or ``None`` when the grid contains no 'S'.
    """
    for r, row in enumerate(grid):
        for c, cell in enumerate(row):
            if cell == 'S':
                return r, c
    # No start marker anywhere in the grid.
    return None
            
def print_grid(grid):
    """Print the grid to standard output, one row per line.

    Args:
        grid: Iterable of rows; each row is an iterable of strings that are
            concatenated without a separator.
    """
    rendered = map(''.join, grid)
    for text in rendered:
        print(text)

def part1(input_file):
    """Return the lowest cost of travelling from the 'S' cell to an 'E' cell.

    Runs a priority-queue shortest-path search over (position, facing)
    states. The search starts on 'S' with facing index 0. Stepping one cell
    in the current facing costs 1 and is allowed unless the target cell is
    '#'; turning a quarter turn in place (either way) costs 1000.

    Args:
        input_file: Path of the maze text file, passed to ``parse_input``.

    Returns:
        The cost of the first 'E' state taken off the queue, or ``None`` if
        the queue empties without reaching one.
    """
    grid = parse_input(input_file)
    origin = find_start(grid)
    best = {(origin, 0): 0}
    frontier = [(0, origin, 0)]
    while frontier:
        cost, pos, heading = heapq.heappop(frontier)
        # Skip queue entries superseded by a cheaper route to the same state.
        if best[(pos, heading)] < cost:
            continue
        if grid[pos[0]][pos[1]] == 'E':
            return cost

        row, col = pos
        d_row, d_col = directions[heading]
        ahead = (row + d_row, col + d_col)

        # Candidate moves in the order: forward step, left turn, right turn.
        candidates = []
        if grid[ahead[0]][ahead[1]] != '#':
            candidates.append((cost + 1, ahead, heading))
        for quarter in (3, 1):
            candidates.append((cost + 1000, pos, (heading + quarter) % 4))

        for new_cost, new_pos, new_heading in candidates:
            key = (new_pos, new_heading)
            if new_cost < best.get(key, INF):
                best[key] = new_cost
                heapq.heappush(frontier, (new_cost, new_pos, new_heading))
    return None



def part2(input_file):
    """Count the tiles lying on at least one lowest-cost route from 'S' to 'E'.

    Uses the same search as ``part1``: states are (position, facing), the
    search starts on 'S' with facing index 0, a forward step onto a non-'#'
    cell costs 1 and a quarter turn in place costs 1000.

    Rather than storing a full copy of the route in every queue entry, the
    search records, for each state, the set of predecessor states that
    reach it at its best known cost. The tiles are then collected by walking
    those predecessor links back from every 'E' state reached at the lowest
    cost. This keeps the work proportional to the number of states instead
    of the number of distinct tied routes.

    Args:
        input_file: Path of the maze text file, passed to ``parse_input``.

    Returns:
        The number of distinct grid positions that appear on any lowest-cost
        route, or 0 when no 'E' cell is reached.
    """
    grid = parse_input(input_file)
    start = find_start(grid)
    origin = (start, 0)
    dist = {origin: 0}
    # preds[state] holds every state that reaches `state` at cost dist[state].
    preds = {origin: set()}
    hp = [(0, start, 0)]
    best_score = INF
    end_states = []
    while hp:
        cur_d, cur_pos, face = heapq.heappop(hp)
        # Everything still queued is at least this expensive, so once the
        # popped cost exceeds the best 'E' cost no further ties can appear.
        if cur_d > best_score:
            break
        state = (cur_pos, face)
        # Skip queue entries superseded by a cheaper route to the same state.
        if cur_d > dist[state]:
            continue
        if grid[cur_pos[0]][cur_pos[1]] == 'E':
            best_score = cur_d
            end_states.append(state)

        x, y = cur_pos
        dx, dy = directions[face]
        nx, ny = x + dx, y + dy
        moves = []
        if grid[nx][ny] != '#':
            moves.append((cur_d + 1, (nx, ny), face))
        moves.append((cur_d + 1000, cur_pos, (face + 3) % 4))  # left turn
        moves.append((cur_d + 1000, cur_pos, (face + 1) % 4))  # right turn

        for new_d, new_pos, new_face in moves:
            new_state = (new_pos, new_face)
            known = dist.get(new_state, INF)
            if new_d < known:
                # Strictly better: earlier predecessors are no longer optimal.
                dist[new_state] = new_d
                preds[new_state] = {state}
                heapq.heappush(hp, (new_d, new_pos, new_face))
            elif new_d == known:
                # Tied route: remember this predecessor too; the state is
                # already queued at this cost, so no extra push is needed.
                preds[new_state].add(state)

    # Walk predecessor links back from every lowest-cost 'E' state.
    tiles = set()
    pending = list(end_states)
    seen = set(end_states)
    while pending:
        state = pending.pop()
        tiles.add(state[0])
        for prev in preds[state]:
            if prev not in seen:
                seen.add(prev)
                pending.append(prev)
    return len(tiles)

In [33]:
part1('input/day16_test.txt')

7036

In [34]:
part1('input/day16_test2.txt')

11048

In [35]:
part1('input/day16.txt')

85480

In [36]:
part2('input/day16_test.txt')

45

In [37]:
part2('input/day16.txt')

518