In [19]:
from utils import read_lines
from collections import deque
import heapq

def parse_input(input_file):
    """Load the maze as a grid: one list of single-character strings per line.

    Lines come from the project helper ``read_lines`` (presumably already
    stripped of trailing newlines -- TODO confirm in ``utils``).
    """
    return [list(row) for row in read_lines(input_file)]

def find_nodes(grid):
    """Locate every point of interest in ``grid``.

    A point of interest is any cell that is neither open floor ('.') nor
    wall ('#'), e.g. keys 'a'-'z', doors 'A'-'Z' and the '@' start.
    The scan is row-major and uses the width of the first row for every row.

    Returns:
        dict mapping cell character -> (row, col). If a character occurs more
        than once, the last occurrence in row-major order wins.
    """
    height, width = len(grid), len(grid[0])
    return {
        grid[r][c]: (r, c)
        for r in range(height)
        for c in range(width)
        if grid[r][c] not in ('.', '#')
    }

# (d_row, d_col) offsets of the four orthogonal neighbours of a cell; used by the BFS in build_graph.
deltas = [(-1, 0), (1, 0), (0, -1), (0, 1)]

def build_graph(grid):
    """Build the point-of-interest graph of the maze.

    From every point of interest (as found by ``find_nodes``) a BFS walks
    over open floor ('.') and records each other point of interest it
    touches, together with its step distance. The BFS never continues
    *through* a point of interest, so edges only join directly adjacent
    POIs; longer routes are composed by the search that uses the graph.

    Returns:
        (nodes, edges): ``nodes`` maps label -> (row, col); ``edges`` maps
        label -> list of (neighbour_label, distance) in BFS discovery order.
    """
    nodes = find_nodes(grid)
    rows, cols = len(grid), len(grid[0])
    edges = {}
    for label, start in nodes.items():
        seen = {start}
        frontier = deque([(start, 0)])
        found = []
        while frontier:
            (r, c), dist_here = frontier.popleft()
            for dr, dc in deltas:
                nr, nc = r + dr, c + dc
                if not (0 <= nr < rows and 0 <= nc < cols):
                    continue
                cell = grid[nr][nc]
                if cell == '#' or (nr, nc) in seen:
                    continue
                seen.add((nr, nc))
                if cell == '.':
                    # Open floor: keep expanding.
                    frontier.append(((nr, nc), dist_here + 1))
                else:
                    # Another POI: record the edge but stop here.
                    found.append((cell, dist_here + 1))
        edges[label] = found
    return nodes, edges

def can_reach(key_mask, node):
    """Return True if ``node`` may be entered while holding ``key_mask``.

    Doors ('A'-'Z') require the bit of their lowercase key
    ('a' -> bit 0, ..., 'z' -> bit 25). Every other node (keys, start
    markers) is always enterable.

    Always returns a real bool. Previously the door branch returned the
    masked int (e.g. 2), which is truthy-correct but surprising for a
    predicate.
    """
    if 'A' <= node <= 'Z':
        return bool(key_mask & (1 << (ord(node) - ord('A'))))
    return True

def calc_mask(mask, node):
    """Return ``mask`` with the bit for key ``node`` set.

    Lowercase letters are keys ('a' -> bit 0, ..., 'z' -> bit 25); any other
    node (door, start marker) leaves the mask unchanged.
    """
    is_key = 'a' <= node <= 'z'
    return (mask | (1 << (ord(node) - ord('a')))) if is_key else mask

def dijstra(key_mask, node, steps, dist, graph):
    """Run Dijkstra over ``(key_mask, node)`` states from one seed state.

    Starts at ``node`` holding keys ``key_mask`` after ``steps`` steps. It
    explores every state reachable through ``graph`` and lowers entries of
    ``dist`` whenever a shorter route is found. Keys are picked up on
    arrival (``calc_mask``); doors are entered only when their key is held
    (``can_reach``).

    Args:
        key_mask: bitmask of keys held at the seed (bit i <-> chr(ord('a') + i)).
        node: label of the seed node.
        steps: distance already travelled to reach the seed.
        dist: dict ``(key_mask, node) -> best known distance``; updated in place.
        graph: dict ``node -> [(neighbour, edge_length), ...]``, e.g. the
            ``edges`` returned by ``build_graph``.

    Returns:
        List of ``[total_steps, mask, node]`` for every relaxation that
        improved ``dist`` (a state appears once per improvement).
    """
    inf = float('inf')
    hp = [[steps, key_mask, node]]
    new_states = []
    while hp:
        cur_step, cur_mask, cur_node = heapq.heappop(hp)
        # Lazy deletion: a shorter route to this state has already been
        # recorded, so expanding this stale entry cannot beat it.
        if cur_step > dist.get((cur_mask, cur_node), inf):
            continue
        for next_node, next_step in graph[cur_node]:
            if not can_reach(cur_mask, next_node):
                continue
            next_mask = calc_mask(cur_mask, next_node)
            total_step = cur_step + next_step
            if total_step < dist.get((next_mask, next_node), inf):
                dist[(next_mask, next_node)] = total_step
                heapq.heappush(hp, [total_step, next_mask, next_node])
                new_states.append([total_step, next_mask, next_node])
    return new_states

def part1(input_file):
    """Fewest steps needed to collect every key in the maze in ``input_file``.

    The search space is states ``(key_mask, node)``: the point of interest
    the walker stands on, and the keys collected so far. A single
    ``dijstra`` run from the start state ``(0, '@')`` fills ``dist`` with
    the shortest distance to every reachable state. The answer is therefore
    the smallest distance among states that hold every key. There is no need
    to re-run Dijkstra from each discovered state.

    Returns:
        int step count, or None if some key can never be collected.
    """
    grid = parse_input(input_file)
    _, edges = build_graph(grid)

    # Bitmask with one bit per key present in the maze.
    end_state = 0
    for label in edges:
        end_state = calc_mask(end_state, label)

    dist = {(0, '@'): 0}
    dijstra(0, '@', 0, dist, edges)

    finished = [d for (mask, _), d in dist.items() if mask == end_state]
    return min(finished) if finished else None
    
        

In [18]:
# Part 1 on the test maze (recorded output: 86).
part1('inputs/day18_test.txt')

86

In [20]:
# Part 1 on inputs/day18.txt (recorded output: 4420).
part1('inputs/day18.txt')

4420

In [43]:
def build_graph2(grid):
    """Split the maze into four quadrants (part 2) and build the POI graph.

    The 3x3 area centred on the single '@' is rewritten as::

        @#!
        ###
        $#%

    so that each of four robots starts in its own quadrant. The robots are
    labelled '@', '!', '$' and '%'. These are non-letters, so ``can_reach``
    always allows them and ``calc_mask`` adds no key bit for them.

    Edge construction is identical to part 1, so the rewritten maze is
    handed to ``build_graph`` rather than duplicating the BFS here.

    NOTE(review): assumes the 8 cells around '@' are open floor, as in the
    part-2 layout; a key or door there would be overwritten and dropped.

    Args:
        grid: list of rows (lists of single-character strings). It is NOT
            modified; a copy is rewritten.

    Returns:
        (nodes, edges) as produced by ``build_graph`` for the rewritten maze.
    """
    # Work on a copy so the caller's grid is left untouched.
    grid = [list(row) for row in grid]
    start_i, start_j = find_nodes(grid)['@']

    # Wall off the centre and its four orthogonal neighbours.
    grid[start_i][start_j] = '#'
    for di, dj in deltas:
        grid[start_i + di][start_j + dj] = '#'

    # Place one robot on each diagonal corner.
    corners = ((-1, -1), (-1, 1), (1, -1), (1, 1))
    for (di, dj), label in zip(corners, '@!$%'):
        grid[start_i + di][start_j + dj] = label

    return build_graph(grid)

def dijstra2(key_mask, node, steps, dist, graph):
    """Multi-robot variant of Dijkstra over ``(key_mask, robots)`` states.

    ``node`` is a sequence of robot positions (node labels). On every
    transition exactly one robot moves along one edge of ``graph``. Keys are
    collected on arrival (``calc_mask``), and doors are entered only when
    their key is held (``can_reach``). Any number of robots is supported.

    Args:
        key_mask: bitmask of keys held at the seed (bit i <-> chr(ord('a') + i)).
        node: sequence of robot node labels at the seed.
        steps: distance already travelled to reach the seed.
        dist: dict ``(key_mask, robots_tuple) -> best known distance``;
            updated in place.
        graph: dict ``node -> [(neighbour, edge_length), ...]``, e.g. the
            ``edges`` returned by ``build_graph2``.

    Returns:
        List of ``[total_steps, mask, robots_tuple]`` for every relaxation
        that improved ``dist`` (a state appears once per improvement).
    """
    inf = float('inf')
    hp = [[steps, key_mask, node]]
    new_states = []
    while hp:
        cur_step, cur_mask, cur_nodes = heapq.heappop(hp)
        cur_nodes = tuple(cur_nodes)
        # Lazy deletion: skip entries superseded by a shorter recorded route.
        if cur_step > dist.get((cur_mask, cur_nodes), inf):
            continue
        for i, cur_node in enumerate(cur_nodes):
            for next_node, next_step in graph[cur_node]:
                if not can_reach(cur_mask, next_node):
                    continue
                next_mask = calc_mask(cur_mask, next_node)
                total_step = cur_step + next_step
                # Only robot i moves; the others stay put.
                next_nodes = cur_nodes[:i] + (next_node,) + cur_nodes[i + 1:]
                if total_step < dist.get((next_mask, next_nodes), inf):
                    dist[(next_mask, next_nodes)] = total_step
                    heapq.heappush(hp, [total_step, next_mask, next_nodes])
                    new_states.append([total_step, next_mask, next_nodes])
    return new_states

def part2(input_file):
    """Fewest total steps for four robots to collect every key (part 2).

    The maze is split into quadrants by ``build_graph2``, with the robots
    labelled '@', '!', '$', '%'. A single ``dijstra2`` run from the start
    state fills ``dist`` with the shortest distance to every reachable
    ``(key_mask, robots)`` state. The answer is the smallest distance among
    states that hold every key.

    Returns:
        int step count, or None if some key can never be collected.
    """
    grid = parse_input(input_file)
    _, edges = build_graph2(grid)

    # Bitmask with one bit per key present in the maze.
    end_state = 0
    for label in edges:
        end_state = calc_mask(end_state, label)

    # Must match the robot labels placed by build_graph2.
    start = ('@', '!', '$', '%')
    dist = {(0, start): 0}
    dijstra2(0, start, 0, dist, edges)

    finished = [d for (mask, _), d in dist.items() if mask == end_state]
    return min(finished) if finished else None

In [44]:
# Part 2 on the second test maze (recorded output: 24).
part2('inputs/day18_test2.txt')

24

In [45]:
# Part 2 on inputs/day18.txt (recorded output: 2128).
part2('inputs/day18.txt')

2128