In [78]:
from utils import read_lines
import math
from collections import deque

NO_WIND = 0
NORTH = 1
SOUTH = 2
WEST = 4
EAST = 8

all_blizzard = [NORTH, SOUTH, WEST, EAST]

char_to_v = {
    '.': NO_WIND,
    '<': WEST,
    '>': EAST,
    '^': NORTH,
    'v': SOUTH,
}

deltas = [(-1, 0), (1, 0), (0, -1), (0, 1)]

def valid_moves(m, n, i, j):
    ans = [(i, j)]
    for di, dj in deltas:
        ni, nj = i + di, j + dj
        if 0 <= ni < m and 0<= nj < n:
            ans.append((ni, nj))
    return ans

def parse_input(input_file):
    lines = read_lines(input_file)
    matrix = []
    for line in lines[1:-1]:
        row = [char_to_v[c] for c in line[1:-1]]
        matrix.append(row)
    return matrix

def flow(m, n, i, j, v):

    match v:
        case 1:
            if i > 0:
                i -= 1
            else:
                i = m - 1
        case 2:
            if i < m - 1:
                i += 1
            else:
                i = 0
        case 4:
            if j > 0:
                j -= 1
            else:
                j = n - 1
        case 8:
            if j < n - 1:
                j += 1
            else:
                j = 0
        case _:
            raise ValueError(f'illegal blizzard value {v}')
    return i, j

def move_blizzard(matrix):
    m = len(matrix)
    n = len(matrix[0])
    ans = [[0] * n for _ in range(m)]
    for i, row in enumerate(matrix):
        for j, v in enumerate(row):
            for b in all_blizzard:
                if v & b:
                    ni, nj = flow(m, n, i, j, b)
                    ans[ni][nj] |= b
    return ans


def calc_all_matrix(matrix):
    ans = [matrix]
    while True:
        nm = move_blizzard(ans[-1])
        if nm == matrix:
            break
        ans.append(nm)
    return ans

def part1(input_file):
    matrix = parse_input(input_file)
    m = len(matrix)
    n = len(matrix[0])
    
    all_matrix = calc_all_matrix(matrix)
    peroid = len(all_matrix)
    start = (-1, 0)
    end = (m-1, n-1)

    seen = set([(start, 0)])
    q = deque()
    q.append(start)
    turn = 0
    while q:
        cur_len = len(q)
        turn += 1
        mi = turn % peroid
        mt = all_matrix[mi]
        for _ in range(cur_len):
            i, j = q.popleft()
            for ni, nj in valid_moves(m, n, i, j):
                if (ni, nj) == end:
                    return turn + 1
                if ((ni, nj), mi) not in seen and (ni == -1 or ni == m or mt[ni][nj] == 0):
                    q.append((ni, nj))
                    seen.add(((ni, nj), mi))


def part2(input_file):
    matrix = parse_input(input_file)
    m = len(matrix)
    n = len(matrix[0])
    
    all_matrix = calc_all_matrix(matrix)
    peroid = len(all_matrix)

    start = (-1, 0)
    end = (m-1, n-1)
    seen = set([(start, 0)])
    q = deque()
    q.append(start)
    turn = 0
    at_goal = False
    while q and not at_goal:
        cur_len = len(q)
        turn += 1
        mi = turn % peroid
        mt = all_matrix[mi]
        for _ in range(cur_len):
            i, j = q.popleft()
            for ni, nj in valid_moves(m, n, i, j):
                if (ni, nj) == end:
                    at_goal = True
                    break
                if ((ni, nj), mi) not in seen and (ni == -1 or ni == m or mt[ni][nj] == 0):
                    q.append((ni, nj))
                    seen.add(((ni, nj), mi))
    
    turn += 1
    print(turn)

    start = (m, n-1)
    end = (0, 0)
    seen = set([(start, turn%peroid)])
    q = deque()
    q.append(start)
    at_goal = False
    while q and not at_goal:
        cur_len = len(q)
        turn += 1
        mi = turn % peroid
        mt = all_matrix[mi]
        for _ in range(cur_len):
            i, j = q.popleft()
            for ni, nj in valid_moves(m, n, i, j):
                if (ni, nj) == end:
                    at_goal = True
                    break
                if ((ni, nj), mi) not in seen and (ni == -1 or ni == m or mt[ni][nj] == 0):
                    q.append((ni, nj))
                    seen.add(((ni, nj), mi))
    turn += 1
    print(turn)

    start = (-1, 0)
    end = (m-1, n-1)
    seen = set([(start, turn % peroid)])
    q = deque()
    q.append(start)
    at_goal = False
    while q and not at_goal:
        cur_len = len(q)
        turn += 1
        mi = turn % peroid
        mt = all_matrix[mi]
        for _ in range(cur_len):
            i, j = q.popleft()
            for ni, nj in valid_moves(m, n, i, j):
                if (ni, nj) == end:
                    at_goal = True
                    break
                if ((ni, nj), mi) not in seen and (ni == -1 or ni == m or mt[ni][nj] == 0):
                    q.append((ni, nj))
                    seen.add(((ni, nj), mi))
    turn += 1
    return turn

In [72]:
part1('inputs/day24_test.txt')

18

In [71]:
part1('inputs/day24.txt') + 1 # why need +1 here?

308

In [79]:
part2('inputs/day24_test.txt')

18
40


54

In [80]:
part2('inputs/day24.txt') + 1 

307
598


908