In [1]:
import heapq
from collections import Counter
from tqdm.notebook import tqdm
from itertools import combinations

In [2]:
# Read the puzzle input; the context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open("input/20") as puzzle_file:
    data = puzzle_file.read().splitlines()

In [3]:
# Map every (row, col) coordinate to its character and remember where the
# start ("S") and end ("E") markers are.
grid = {}
start = None
end = None
for r, row in enumerate(data):
    for c, elem in enumerate(row):
        grid[(r, c)] = elem
        if elem == "S":
            start = (r, c)
        elif elem == "E":
            end = (r, c)

# Fail here with a clear message rather than with an opaque TypeError
# (indexing None) inside the searches further down.
if start is None or end is None:
    raise ValueError("input must contain both an 'S' and an 'E' cell")

In [4]:
# The four orthogonal unit steps as (d_row, d_col).
directions = [
    (-1, 0),  # up
    (1, 0),   # down
    (0, -1),  # left
    (0, 1),   # right
]

In [5]:
def astar_with_path(pos):
    q = []
    heapq.heappush(q, (0, 0, pos, [pos]))

    visited = set()
    while q:
        h_cost, cost, pos, path = heapq.heappop(q)
        for direction in directions:
            neighbour = (pos[0] + direction[0], pos[1] + direction[1])
            neighbour_elem = grid.get(neighbour)
            if not neighbour_elem or neighbour_elem == "#" or neighbour in visited:
                continue
            
            visited.add(neighbour)

            if neighbour == end:
                # +1 for the end node
                path = path + [end]
                return cost + 1, visited, path
                
            new_cost = cost + 1
            heur = abs(end[0] - neighbour[0]) + abs(end[1] - neighbour[1])
            heapq.heappush(q, (heur + new_cost, new_cost, neighbour, path + [neighbour]))
    return 0

In [6]:
# Baseline search with no cheats: route length, every discovered track cell,
# and the ordered S->E route itself.
best_distance, visited, path = astar_with_path(pos=start)

In [7]:
# Every wall cell adjacent to a discovered track cell, once per adjacent
# track cell. A wall touching the track on at least two sides is a candidate
# cheat: opening it may connect two stretches of the route.
neighbours = [
    (row + d_row, col + d_col)
    for row, col in visited
    for d_row, d_col in directions
    if grid.get((row + d_row, col + d_col)) == "#"
]
candidates = [wall for wall, hits in Counter(neighbours).items() if hits >= 2]

In [8]:
def astar_with_limit(pos, limit):
    """Shortest distance from ``pos`` to the global ``end``, or -1 past ``limit``.

    A* with a Manhattan-distance heuristic over the module-level ``grid`` dict,
    stepping in the four module-level ``directions``. ``"#"`` cells and
    coordinates missing from ``grid`` are blocked. A cell is closed when it is
    popped, and the search returns as soon as ``end`` is generated from a popped
    cell.

    Args:
        pos: (row, col) start cell.
        limit: the search gives up (returns -1) once it pops an entry whose
            step count exceeds ``limit``.

    Returns:
        Number of steps to ``end`` (0 when ``pos`` is ``end``). The result can
        be ``limit + 1`` when ``end`` is generated from a cell at exactly
        ``limit``. Returns -1 when the search gives up at ``limit`` or when
        ``end`` is unreachable. Previously the unreachable case returned None,
        which made ``0 < result`` raise TypeError in callers.
    """
    if pos == end:
        return 0

    q = [(0, 0, pos)]
    visited = set()
    while q:
        _, cost, current = heapq.heappop(q)
        if cost > limit:
            return -1
        if current in visited:
            # Stale duplicate of a cell that has already been expanded.
            continue
        visited.add(current)
        for dr, dc in directions:
            neighbour = (current[0] + dr, current[1] + dc)
            if neighbour in visited:
                continue
            neighbour_elem = grid.get(neighbour)
            if not neighbour_elem or neighbour_elem == "#":
                continue

            if neighbour == end:
                # +1 for the step onto the end cell.
                return cost + 1

            heur = abs(end[0] - neighbour[0]) + abs(end[1] - neighbour[1])
            heapq.heappush(q, (heur + cost + 1, cost + 1, neighbour))

    # Queue exhausted without reaching end.
    return -1

In [9]:
# Try each candidate wall as a cheat: open it, re-run the bounded search, and
# record how many steps the shortcut saves (when it saves any).
saves = []
for cand in tqdm(candidates):
    grid[cand] = "."
    try:
        new_distance = astar_with_limit(start, best_distance)
    finally:
        # Always put the wall back, even if the search raises or the cell is
        # interrupted; otherwise the global `grid` stays modified and every
        # later cell silently works on a corrupted map.
        grid[cand] = "#"
    if 0 < new_distance < best_distance:
        saves.append(best_distance - new_distance)

  0%|          | 0/7823 [00:00<?, ?it/s]

In [10]:
# Number of cheats whose saving reaches the 100-step threshold.
part1 = sum(
    n_cheats for saving, n_cheats in Counter(saves).items() if saving >= 100
)
print(f"Answer #1: {part1}")

Answer #1: 1372


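# Part 2

Cheats may now cover a larger Manhattan distance. A cheat jumps from a track cell `a` to a track cell `b` that comes later on the S→E `path`. It costs their Manhattan distance instead of the number of steps between them along `path`, and the difference is the saving. The cell below counts such pairs within Manhattan distance 20 (part 2). As a cross-check of part 1, it also counts pairs within distance 2. In both cases only savings of at least 100 are kept.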

In [11]:
# Step index of each cell along the S->E route (0 = start).
path_dict = dict(zip(path, range(len(path))))

In [12]:
def count_cheats(track, max_cheat, min_saving=100):
    """Count cheats of Manhattan length <= ``max_cheat`` saving >= ``min_saving``.

    A cheat jumps from cell ``a`` at route index ``i`` to a later cell ``b`` at
    index ``j``. It costs their Manhattan distance instead of the ``j - i``
    steps of the normal route, so its saving is ``j - i - distance``. Each
    (a, b) pair is counted at most once.

    Rather than testing all O(len(track)**2) pairs, only the cells inside the
    Manhattan diamond of radius ``max_cheat`` around each ``a`` are probed,
    which is O(len(track) * max_cheat**2).

    Args:
        track: ordered list of (row, col) cells from start to end.
        max_cheat: maximum Manhattan length of a cheat.
        min_saving: minimum number of steps a cheat must save to be counted.

    Returns:
        The number of qualifying (a, b) pairs.
    """
    index = {cell: i for i, cell in enumerate(track)}
    # All offsets with |dr| + |dc| <= max_cheat.
    offsets = [
        (dr, dc)
        for dr in range(-max_cheat, max_cheat + 1)
        for dc in range(abs(dr) - max_cheat, max_cheat - abs(dr) + 1)
    ]
    total = 0
    for i, (r, c) in enumerate(track):
        for dr, dc in offsets:
            j = index.get((r + dr, c + dc))
            # j > i keeps only forward jumps along the route, so each pair
            # is counted once (and a cell never pairs with itself).
            if j is not None and j > i and j - i - abs(dr) - abs(dc) >= min_saving:
                total += 1
    return total


part1_opt = count_cheats(path, max_cheat=2)
part2 = count_cheats(path, max_cheat=20)

In [13]:
# Report the cross-check of part 1 and the part 2 answer, one per line.
print(f"Answer #1 opt: {part1_opt}", f"Answer #2: {part2}", sep="\n")

Answer #1 opt: 1372
Answer #2: 979014
