In [1]:
from collections import defaultdict

def parse_input(input_file):
    """Read the puzzle map from ``input_file`` into a mutable grid.

    Args:
        input_file: Path of a text file holding one map row per line.

    Returns:
        A list of rows, each row a list of single-character strings.
        Trailing whitespace (including the newline) is stripped from every
        line, and lines that are empty after stripping are skipped.
    """
    grid = []
    with open(input_file) as f:
        for line in f:
            line = line.rstrip()
            # A blank line (typically a trailing newline at end of file) would
            # become an empty row; code that indexes every row up to
            # len(grid[0]) would then fail, so keep such lines out of the grid.
            if line:
                grid.append(list(line))
    return grid

def find_start(grid):
    """Locate the guard's starting cell, i.e. the first '^' in the grid.

    Rows are scanned top to bottom and each row left to right over its own
    length, so an empty grid or rows of differing length do not raise.

    Args:
        grid: List of rows, each row a sequence of single-character strings.

    Returns:
        ``(row, col)`` of the first '^' found, or ``None`` when the grid
        contains no '^'.
    """
    for i, row in enumerate(grid):
        for j, cell in enumerate(row):
            if cell == '^':
                return (i, j)
    return None

# Guard glyphs, in the order the turn logic cycles through them
# (index + 1 modulo 4 gives the glyph after a turn).
chars = ['^', '>', 'v', '<']
# (row delta, column delta) for one step; index-aligned with `chars`.
directions = [(-1, 0), (0, 1), (1, 0), (0, -1)]

# Lookup table: glyph -> step vector.
moves = dict(zip(chars, directions))

def part1(input_file):
    """Count the distinct cells the guard occupies before walking off the map.

    The map is loaded from ``input_file`` and the guard is walked from the
    '^' cell. Every cell it enters is overwritten with the glyph of its
    current heading; when the next cell is '#', the glyph in the current
    cell is advanced to the next entry of ``chars`` instead of moving.

    Args:
        input_file: Path of the puzzle map.

    Returns:
        Number of cells whose final content is neither '.' nor '#'.
    """
    grid = parse_input(input_file)
    position = find_start(grid)
    rows, cols = len(grid), len(grid[0])
    while True:
        r, c = position
        facing = grid[r][c]
        dr, dc = moves[facing]
        nr, nc = r + dr, c + dc
        if nr < 0 or nr >= rows or nc < 0 or nc >= cols:
            # Next step leaves the map: the walk is over.
            break
        if grid[nr][nc] == '#':
            # Blocked: stay put and rotate the glyph stored in this cell.
            grid[r][c] = chars[(chars.index(facing) + 1) % 4]
            continue
        # Free cell: carry the heading glyph forward, leaving a trail behind.
        grid[nr][nc] = facing
        position = (nr, nc)
    return sum(
        1
        for i in range(rows)
        for j in range(cols)
        if grid[i][j] not in '.#'
    )
    


In [5]:
part1('input/day6_test.txt')

41

In [6]:
part1('input/day6.txt')

4665

In [14]:
# Same glyph / step tables as defined earlier in the notebook, restated for
# the part-2 helpers below.
# Glyph order matters: index + 1 modulo 4 is the heading after a turn.
chars = ['^', '>', 'v', '<']
# (row delta, column delta) per glyph, index-aligned with `chars`.
directions = [(-1, 0), (0, 1), (1, 0), (0, -1)]

# Lookup table: glyph -> step vector.
moves = dict(zip(chars, directions))

def copy_grid(grid):
    """Return a new grid whose rows are independent copies of ``grid``'s rows.

    Each row is duplicated with its own ``copy()`` method, so assigning to a
    cell of the result does not affect the input grid.
    """
    return [row.copy() for row in grid]

def mark_grid(grid, start):
    """Walk the guard from ``start`` and stamp its trail into ``grid`` in place.

    The heading is read from the glyph stored in the guard's current cell.
    Entering a free cell copies that glyph into it; facing a '#' advances the
    glyph in the current cell to the next entry of ``chars``. The walk stops
    when the next step would leave the grid.

    Args:
        grid: Mutable list-of-lists map; modified in place.
        start: ``(row, col)`` of the cell holding the guard's glyph.

    Returns:
        None.
    """
    height, width = len(grid), len(grid[0])
    row, col = start
    while True:
        facing = grid[row][col]
        d_row, d_col = moves[facing]
        nxt_row, nxt_col = row + d_row, col + d_col
        if nxt_row < 0 or nxt_row >= height or nxt_col < 0 or nxt_col >= width:
            # Guard steps off the map.
            return
        if grid[nxt_row][nxt_col] == '#':
            # Obstacle ahead: rotate in place.
            grid[row][col] = chars[(chars.index(facing) + 1) % 4]
        else:
            # Advance, carrying the heading glyph into the new cell.
            grid[nxt_row][nxt_col] = facing
            row, col = nxt_row, nxt_col

def explore(grid, start):
    """Simulate the guard from ``start`` (initially heading '^') and detect loops.

    The guard moves one cell at a time in its current heading and switches to
    the next heading in ``chars`` whenever the cell ahead is '#'. A loop is
    reported as soon as the guard is in a (row, col, heading) state it has
    been in before. This check is exact, and it also terminates when the guard
    is walled in on all four sides and can only spin in place.

    As a side effect the trail is drawn into ``grid`` in place: '|' for
    vertical travel, '-' for horizontal travel, '+' where the two cross or
    where the guard turns.

    Args:
        grid: Mutable list-of-lists map; modified in place.
        start: ``(row, col)`` of the guard's starting cell.

    Returns:
        True if the guard is caught in a loop, False if it leaves the grid.
    """
    cur = start
    m, n = len(grid), len(grid[0])
    cur_d = '^'
    grid[start[0]][start[1]] = '|'
    # Every (row, col, heading) state the guard has been in so far.
    seen = set()
    while True:
        i, j = cur
        state = (i, j, cur_d)
        if state in seen:
            # Same cell with the same heading: the walk repeats forever.
            return True
        seen.add(state)
        di, dj = moves[cur_d]
        ni, nj = i + di, j + dj
        if not (0 <= ni < m and 0 <= nj < n):
            return False
        if grid[ni][nj] != '#':
            # Draw the trail; only untouched or perpendicular marks change.
            if cur_d in '^v':
                if grid[ni][nj] == '.':
                    grid[ni][nj] = '|'
                elif grid[ni][nj] == '-':
                    grid[ni][nj] = '+'
            else:
                if grid[ni][nj] == '.':
                    grid[ni][nj] = '-'
                elif grid[ni][nj] == '|':
                    grid[ni][nj] = '+'
            cur = (ni, nj)
        else:
            # Blocked: stay in place, take the next heading, mark the corner.
            cur_d = chars[(chars.index(cur_d) + 1) % 4]
            grid[i][j] = '+'

def part2(input_file):
    """Count the cells where one extra obstacle makes the guard loop.

    The unobstructed walk is traced once on a copy of the map. Every cell on
    that trail except the start is then tried as an extra '#' on a fresh copy
    of the map, and the walk is re-simulated with ``explore``.

    Args:
        input_file: Path of the puzzle map.

    Returns:
        Number of tried cells for which ``explore`` reports a loop.
    """
    grid = parse_input(input_file)
    start = find_start(grid)
    rows, cols = len(grid), len(grid[0])

    # Trail of the original walk: visited cells hold one of the heading glyphs.
    trail = copy_grid(grid)
    mark_grid(trail, start)

    loops = 0
    for r in range(rows):
        for c in range(cols):
            # Skip the start cell and any cell the guard never reaches.
            if (r, c) == start or trail[r][c] not in moves:
                continue
            trial = copy_grid(grid)
            trial[r][c] = '#'
            if explore(trial, start):
                loops += 1
    return loops

def print_grid(grid):
    """Print the grid, one row per output line, with the cells concatenated."""
    for text in map(''.join, grid):
        print(text)

In [15]:
part2('input/day6_test.txt')

6

In [16]:
part2('input/day6.txt')

1688