In [1]:
area = {}

# Parse the maze into a {(x, y): tile} map: x is the column index, y the row
# index (so y grows downward).
with open("inputs/day18-input.txt") as input_file:
    for row_index, row in enumerate(input_file.read().splitlines()):
        area.update(((col_index, row_index), tile) for col_index, tile in enumerate(row))

## Part 1

In [2]:
%%time

from itertools import combinations


# Coordinates of the entrance tile.
position = next(pos for pos, tile in area.items() if tile == "@")

# Graph nodes: the entrance plus every key (lower-case letter). Each entry holds
# the node's map position and a table of shortest paths to other nodes, filled in
# further below.
nodes = {
    tile: {"position": pos, "paths": {}}
    for pos, tile in area.items()
    if tile == "@" or (tile.isalpha() and tile.islower())
}

# Unit steps indexed clockwise starting from "up" (y grows downward):
# 0 = up, 1 = right, 2 = down, 3 = left.
directions = dict(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)]))


def inverse_direction(direction):
    """Return the step that undoes `direction`.

    `direction` is a (dx, dy) offset tuple. Its inverse is the negated offset,
    so this works for any offset, not only the four unit steps in
    `directions`. The null step (0, 0) maps to itself (0, 0).
    """
    dx, dy = direction
    return (-dx, -dy)


def move(position, direction):
    """Return the cell reached by stepping from `position` by the `direction` offset."""
    x, dx = position[0], direction[0]
    y, dy = position[1], direction[1]
    return x + dx, y + dy


def backtrack(start, end, paths, direction=(0, 0), steps=0, keys_needed="", visited=None):
    """Depth-first search for the shortest walk from `start` to `end` on `area`.

    Whenever `end` is reached in fewer steps than the best walk recorded so far,
    ``{"steps": steps, "keys_needed": keys_needed}`` is appended to `paths`.
    The caller then takes the minimum entry.

    Args:
        start: (x, y) cell to search from.
        end: (x, y) target cell.
        paths: list that receives improving results (mutated in place).
        direction: the step that led to `start`; its reverse is not explored.
        steps: number of steps taken so far.
        keys_needed: lower-cased letters of every door crossed so far.
        visited: cells on the current DFS branch. Defaults to a fresh set for
            each top-level call.

    Returns:
        True if `end` was reached on this call, False if the branch was pruned,
        None after exploring neighbours. Callers in this notebook ignore it.
    """
    # A mutable default (``visited=set()``) is created once and shared by every
    # call. If a search were aborted mid-way (e.g. by an exception), stale cells
    # would stay in it and silently break every later search in the kernel.
    if visited is None:
        visited = set()

    # Cells off the map behave like walls instead of raising KeyError.
    tile = area.get(start, "#")

    if tile == "#" or start in visited:
        return False

    # Prune: this branch can no longer beat the best walk already found.
    if paths and steps >= min(paths, key=lambda x: x["steps"])["steps"]:
        return False

    if start == end:
        paths.append({"steps": steps, "keys_needed": keys_needed})
        return True

    # Crossing a door (upper-case letter) requires its lower-case key.
    if tile.isalpha() and tile.isupper():
        keys_needed += tile.lower()

    origin = inverse_direction(direction)
    visited.add(start)

    for d in directions.values():
        if d == origin:
            continue  # don't step straight back to the previous cell

        backtrack(move(start, d), end, paths, d, steps + 1, keys_needed, visited)

    visited.remove(start)


# Shortest walk (and the doors it crosses) between every pair of nodes, stored
# symmetrically in each node's "paths" table.
for a, b in combinations(nodes, 2):
    paths = []
    backtrack(nodes[a]["position"], nodes[b]["position"], paths)

    # A pair with no connecting walk leaves `paths` empty, and min() would raise
    # ValueError. Skip it, as Part 2 does. nodes_reachable only offers targets
    # present in the table.
    if not paths:
        continue

    path = min(paths, key=lambda x: x["steps"])
    nodes[a]["paths"][b] = path
    nodes[b]["paths"][a] = path

CPU times: user 32.8 s, sys: 38.9 ms, total: 32.8 s
Wall time: 33.9 s


In [3]:
%%time

def keys_available(keys, keys_needed):
    """Return True if every key listed in `keys_needed` is already in `keys`."""
    for required in keys_needed:
        if required not in keys:
            return False
    return True

def nodes_reachable(node, keys):
    """Lazily yield the names of `node`'s neighbours that can be visited next.

    A neighbour qualifies when it has not been collected yet (is not in `keys`)
    and every door on the stored path to it can be opened with `keys`.
    """
    def _is_open(target):
        if target in keys:
            return False
        return keys_available(keys, node["paths"][target]["keys_needed"])

    return filter(_is_open, node["paths"])


# State shared by backtrack_nodes: every node name, how many there are (the
# length of the collected-keys string once everything is collected), and the
# memo table.
nodes_set = set(nodes)
nodes_length = len(nodes_set)
memory = {}

def backtrack_nodes(node, keys="", steps=0):
    """Find the cheapest order in which to collect every node, starting at `node`.

    Args:
        node: name of the node being entered now (not yet in `keys`).
        keys: names already collected, in collection order.
        steps: steps walked so far.

    Returns:
        ``{"keys": <full collection order>, "steps": <total steps>}`` for the
        best completion found, or False if no node can be reached from here.
    """
    # Memoize on (current node, nodes still missing). The best remaining cost
    # does not depend on the order in which earlier nodes were collected.
    remaining = tuple(sorted(nodes_set - set(keys)))
    state = (node, remaining)

    cached = memory.get(state)
    if cached is not None:
        extra_steps, extra_keys = cached
        return {"keys": keys + extra_keys, "steps": steps + extra_steps}

    collected = keys + node
    if len(collected) == nodes_length:
        return {"keys": collected, "steps": steps}

    here = nodes[node]
    candidates = []
    for target in nodes_reachable(here, collected):
        result = backtrack_nodes(target, collected, steps + here["paths"][target]["steps"])
        if result:
            candidates.append(result)

    if not candidates:
        return False

    best = min(candidates, key=lambda sol: sol["steps"])
    # Store the extra cost and the collection-order suffix (this node onward),
    # so a later hit can append them to its own prefix.
    memory[state] = (best["steps"] - steps, best["keys"][-len(remaining):])
    return best


# Start at the entrance with nothing collected; prints {'keys': ..., 'steps': ...}.
print(backtrack_nodes(node="@", keys="", steps=0))

{'keys': '@ejmuzxwnaptiokyqlrvscfhdgb', 'steps': 5808}
CPU times: user 1.31 s, sys: 14.9 ms, total: 1.33 s
Wall time: 1.34 s


## Part 2

In [4]:
%%time

multi_area = area.copy()
center = next(pos for pos, tile in multi_area.items() if tile == "@")

# Rewrite the 3x3 block around the entrance. The entrance and its orthogonal
# neighbours become walls, and the diagonal corners become robot starts
# "1".."4", in row-major order.
multi_area.update({
    move(center, (dx, dy)): tile
    for dy, row in enumerate(("1#2", "###", "3#4"), start=-1)
    for dx, tile in enumerate(row, start=-1)
})


def multi_backtrack(start, end, paths, direction=(0, 0), steps=0, keys_needed="", visited=None):
    """Depth-first search for the shortest walk from `start` to `end` on `multi_area`.

    Whenever `end` is reached in fewer steps than the best walk recorded so far,
    ``{"steps": steps, "keys_needed": keys_needed}`` is appended to `paths`.
    If `end` is unreachable, `paths` stays empty.

    Args:
        start: (x, y) cell to search from.
        end: (x, y) target cell.
        paths: list that receives improving results (mutated in place).
        direction: the step that led to `start`; its reverse is not explored.
        steps: number of steps taken so far.
        keys_needed: lower-cased letters of every door crossed so far.
        visited: cells on the current DFS branch. Defaults to a fresh set for
            each top-level call.

    Returns:
        True if `end` was reached on this call, False if the branch was pruned,
        None after exploring neighbours. Callers in this notebook ignore it.
    """
    # A mutable default (``visited=set()``) is shared by every call. A search
    # aborted mid-way would leave stale cells behind for all later searches.
    if visited is None:
        visited = set()

    # Cells off the map behave like walls instead of raising KeyError.
    tile = multi_area.get(start, "#")

    if tile == "#" or start in visited:
        return False

    # Prune: this branch can no longer beat the best walk already found.
    if paths and steps >= min(paths, key=lambda x: x["steps"])["steps"]:
        return False

    if start == end:
        paths.append({"steps": steps, "keys_needed": keys_needed})
        return True

    # Crossing a door (upper-case letter) requires its lower-case key.
    if tile.isalpha() and tile.isupper():
        keys_needed += tile.lower()

    origin = inverse_direction(direction)
    visited.add(start)

    for d in directions.values():
        if d == origin:
            continue  # don't step straight back to the previous cell

        multi_backtrack(move(start, d), end, paths, d, steps + 1, keys_needed, visited)

    visited.remove(start)


def multi_backtrack_nodes(current_nodes, keys="", steps=0):
    """Find the cheapest way for several robots to collect every node in `multi_nodes`.

    At each step, exactly one robot walks from its current node to a reachable,
    not-yet-collected node.

    Args:
        current_nodes: sequence of node names, one per robot (any length).
        keys: names already collected, in collection order.
        steps: total steps walked by all robots so far.

    Returns:
        ``{"keys": <full collection order>, "steps": <total steps>}`` for the
        best completion found, or False if no robot can reach a new node.
    """
    # Memoize on (robot positions, nodes still missing). The best remaining cost
    # does not depend on the order in which earlier nodes were collected.
    state = (tuple(current_nodes), tuple(sorted(multi_nodes_set - set(keys))))
    if state in memory:
        add_steps, add_keys = memory[state]
        return {"keys": keys + add_keys, "steps": steps + add_steps}

    # Every node a robot is standing on counts as collected.
    new_keys = keys
    for node in current_nodes:
        if node not in new_keys:
            new_keys += node

    # This does not change inside the robot loop, so check it once up front.
    if len(new_keys) == multi_nodes_length:
        return {"keys": new_keys, "steps": steps}

    solutions = []
    # Iterate over however many robots there are, rather than a hard-coded 4.
    for robot, node in enumerate(current_nodes):
        for n in nodes_reachable(multi_nodes[node], new_keys):
            new_steps = steps + multi_nodes[node]["paths"][n]["steps"]
            # list() copies lists and tuples alike, so any sequence is accepted.
            new_nodes = list(current_nodes)
            new_nodes[robot] = n

            solution = multi_backtrack_nodes(new_nodes, new_keys, new_steps)
            if solution:
                solutions.append(solution)

    if not solutions:
        return False

    remaining_length = len(state[1])
    min_solution = min(solutions, key=lambda x: x["steps"])
    # Store the extra cost and the collection-order suffix for the nodes still
    # missing at this state, so a later hit can append them to its own prefix.
    memory[state] = (min_solution["steps"] - steps, min_solution["keys"][-remaining_length:])
    return min_solution

# Part 2 nodes: every key plus the robot starts (digit tiles).
multi_nodes = {
    tile: {"position": pos, "paths": {}}
    for pos, tile in multi_area.items()
    if tile.isdigit() or (tile.isalpha() and tile.islower())
}
multi_nodes_set = set(multi_nodes)
multi_nodes_length = len(multi_nodes_set)

# Shortest walk between every pair of nodes. Pairs that multi_backtrack cannot
# connect are left out of each other's "paths" table.
for a, b in combinations(multi_nodes, 2):
    paths = []
    multi_backtrack(multi_nodes[a]["position"], multi_nodes[b]["position"], paths)
    if not paths:
        continue

    path = min(paths, key=lambda candidate: candidate["steps"])
    multi_nodes[a]["paths"][b] = path
    multi_nodes[b]["paths"][a] = path

# Reset the memo table: Part 1 used the same global name.
memory = {}
print(multi_backtrack_nodes(list("1234")))

{'keys': '1234uqejmzxwnaptioklyrvscfhdgb', 'steps': 1992}
CPU times: user 15.7 s, sys: 169 ms, total: 15.9 s
Wall time: 16.2 s
