In [16]:
from utils import read_lines
from collections import defaultdict, deque

def parse_input(input_file):
    """Read the donut maze and record the open tile attached to each label.

    Args:
        input_file: path handed to ``read_lines``.

    Returns:
        ``(lines, labels)``.
        - ``lines`` is the row list exactly as ``read_lines`` returned it.
        - ``labels`` is a ``defaultdict(list)`` mapping each two-letter label
          to the ``(row, col)`` tiles it marks, in discovery order: top and
          bottom margins, then left and right margins, then the inner ring.
    """
    grid = read_lines(input_file)
    rows = len(grid)
    cols = len(grid[0])
    found = defaultdict(list)

    def is_letter(ch):
        return 'A' <= ch <= 'Z'

    # Top and bottom margins: labels read downwards across two rows.
    for c in range(cols):
        if grid[0][c] != ' ':
            found[grid[0][c] + grid[1][c]].append((2, c))
        if grid[rows - 1][c] != ' ':
            found[grid[rows - 2][c] + grid[rows - 1][c]].append((rows - 3, c))

    # Left and right margins: labels read rightwards across two columns.
    for r in range(rows):
        row = grid[r]
        if row[0] != ' ':
            found[row[0] + row[1]].append((r, 2))
        if row[cols - 1] != ' ':
            found[row[cols - 2] + row[cols - 1]].append((r, cols - 3))

    # Inner ring. Each probe is (dot_dr, dot_dc, mate_dr, mate_dc), tried in
    # this priority order. The first probe whose offset tile is '.' and whose
    # mate offset is a letter wins.
    probes = (
        (-1, 0, 1, 0),   # '.' directly above a vertical label
        (2, 0, 1, 0),    # '.' directly below a vertical label
        (0, 2, 0, 1),    # '.' directly right of a horizontal label
        (0, -1, 0, 1),   # '.' directly left of a horizontal label
    )
    for r in range(2, rows - 2):
        for c in range(2, cols - 2):
            if not is_letter(grid[r][c]):
                continue
            for dot_dr, dot_dc, mate_dr, mate_dc in probes:
                if grid[r + dot_dr][c + dot_dc] == '.' and is_letter(grid[r + mate_dr][c + mate_dc]):
                    found[grid[r][c] + grid[r + mate_dr][c + mate_dc]].append((r + dot_dr, c + dot_dc))
                    break
    return grid, found

# Orthogonal unit moves as (row delta, column delta): up, down, left, right.
deltas = [
    (-1, 0),
    (1, 0),
    (0, -1),
    (0, 1),
]

def part1(input_file):
    """Return the fewest steps from ``AA`` to ``ZZ`` in the flat donut maze.

    Moving to an orthogonally adjacent ``'.'`` tile costs one step. Jumping
    from a portal tile to its partner tile also costs one step.

    Args:
        input_file: path forwarded to ``parse_input``.

    Returns:
        The shortest step count, or ``None`` if ``ZZ`` cannot be reached.
    """
    matrix, labels = parse_input(input_file)
    start = labels['AA'][0]
    end = labels['ZZ'][0]
    # A label with exactly two tiles is a portal pair; link both directions.
    portals = {}
    for points in labels.values():
        if len(points) == 2:
            p1, p2 = points
            portals[p1] = p2
            portals[p2] = p1

    m = len(matrix)
    visited = {start}
    q = deque([start])
    step = 0
    while q:
        # Drain exactly one BFS layer so `step` equals that layer's distance.
        for _ in range(len(q)):
            x, y = q.popleft()
            if (x, y) == end:
                return step
            for dx, dy in deltas:
                nx, ny = x + dx, y + dy
                # Walkability is a property of the destination tile. Bound the
                # column by that row's own length so ragged rows cannot raise.
                if (0 <= nx < m and 0 <= ny < len(matrix[nx])
                        and matrix[nx][ny] == '.' and (nx, ny) not in visited):
                    visited.add((nx, ny))
                    q.append((nx, ny))
            if (x, y) in portals:
                target = portals[(x, y)]
                if target not in visited:
                    visited.add(target)
                    q.append(target)
        step += 1
    return None
    

    

In [15]:
# Parse inputs/day20_test.txt, then show the grid size (rows, columns) and the labels found.
matrix, labels = parse_input(input_file='inputs/day20_test.txt')
print(f'{len(matrix)} {len(matrix[0])}')
print(labels)

37 35
defaultdict(<class 'list'>, {'BU': [(34, 11), (21, 26)], 'JP': [(34, 15), (28, 21)], 'AA': [(2, 19)], 'CP': [(34, 19), (8, 21)], 'VT': [(11, 32), (23, 26)], 'DI': [(15, 2), (21, 8)], 'ZZ': [(17, 2)], 'AS': [(17, 32), (8, 17)], 'JO': [(19, 2), (28, 13)], 'LF': [(21, 32), (28, 15)], 'YN': [(23, 2), (13, 26)], 'QG': [(23, 32), (17, 26)]})


In [17]:
# Part 1 on inputs/day20_test.txt (recorded output: 58).
part1(input_file='inputs/day20_test.txt')

58

In [18]:
# Part 1 on the full input file.
part1(input_file='inputs/day20.txt')

596

In [19]:

def part2(input_file, max_level=None):
    """Return the fewest steps from ``AA`` to ``ZZ`` in the recursive donut maze.

    Search states are ``((row, col), level)``. Walking costs one step, and so
    does taking a portal.
    - An inner portal moves to ``level + 1``; an outer portal moves to
      ``level - 1``.
    - Outer portals are impassable at level 0.
    - ``ZZ`` only counts as the exit when reached at level 0.

    Args:
        input_file: path forwarded to ``parse_input``.
        max_level: optional cap on depth. ``None`` (default) searches without
            a cap. When set, portals leading deeper than ``max_level`` are not
            taken, which bounds the state space and guarantees termination.

    Returns:
        The shortest step count, or ``None`` if no path exists (within
        ``max_level`` when given).
    """
    matrix, labels = parse_input(input_file)
    start = labels['AA'][0]
    end = labels['ZZ'][0]
    # A label with exactly two tiles is a portal pair; link both directions.
    portals = {}
    for points in labels.values():
        if len(points) == 2:
            p1, p2 = points
            portals[p1] = p2
            portals[p2] = p1

    m, n = len(matrix), len(matrix[0])

    def is_outer(x, y):
        # Outer-ring portal tiles sit on row 2 / m-3 or column 2 / n-3,
        # the positions parse_input assigns to margin labels.
        return x in (2, m-3) or y in (2, n-3)

    # Use a set display so the start *state* is stored; set((start, 0))
    # would instead hold `start` and `0` as two separate members.
    start_state = (start, 0)
    visited = {start_state}
    q = deque([start_state])
    step = 0
    while q:
        # Drain exactly one BFS layer so `step` equals that layer's distance.
        for _ in range(len(q)):
            (x, y), level = q.popleft()
            if (x, y) == end and level == 0:
                return step
            for dx, dy in deltas:
                nx, ny = x + dx, y + dy
                state = ((nx, ny), level)
                # Walkability is a property of the destination tile. Bound the
                # column by that row's own length so ragged rows cannot raise.
                if (0 <= nx < m and 0 <= ny < len(matrix[nx])
                        and matrix[nx][ny] == '.' and state not in visited):
                    visited.add(state)
                    q.append(state)
            if (x, y) in portals:
                n_level = level - 1 if is_outer(x, y) else level + 1
                # Level -1 does not exist, i.e. outer portals are walls at
                # level 0. Optionally refuse to go deeper than max_level.
                if n_level < 0 or (max_level is not None and n_level > max_level):
                    continue
                state = (portals[(x, y)], n_level)
                if state not in visited:
                    visited.add(state)
                    q.append(state)
        step += 1
    return None

In [20]:
# Part 2 on inputs/day20_test2.txt (recorded output: 396).
part2(input_file='inputs/day20_test2.txt')

396

In [21]:
# Part 2 on the full input file.
part2(input_file='inputs/day20.txt')

7610