In [None]:
from pathlib import Path
import os

In [None]:
# The puzzle input is read from ./inputs/ relative to the current working directory.
fp = Path().absolute() / "inputs" / "input16.txt"
# fp = Path().absolute() / "inputs" / "input16_test.txt"

# splitlines() copes with a missing trailing newline and with "\r\n" line
# endings; split("\n")[:-1] would silently drop the last grid row whenever the
# file does not end in a newline. Empty lines are never valid grid rows.
with open(fp, "r") as f:
    data = [line for line in f.read().splitlines() if line]

In [None]:
# Show only the first few rows rather than dumping the whole grid as cell output.
data[:5]

# Part 1

In [None]:
num_rows = len(data)
# Fail early with a clear message instead of an IndexError on data[0] below.
assert num_rows > 0, "input grid is empty"
num_cols = len(data[0])
# Later cells index data[row][col] for every col < num_cols, so all rows must
# have the same width.
assert all(len(row) == num_cols for row in data), "input grid is not rectangular"
print(num_rows, num_cols)

In [None]:
def process_dot(loc_x, loc_y, dir):
    """Move a beam one tile straight ahead, keeping its heading.

    Coordinates are (row, column): "N"/"S" decrease/increase ``loc_x`` and
    "W"/"E" decrease/increase ``loc_y``. No bounds checking is done here.

    Args:
        loc_x: Row index of the current tile.
        loc_y: Column index of the current tile.
        dir: Current heading, one of "N", "S", "W", "E".

    Returns:
        Tuple ``(next_loc_x, next_loc_y, next_dir)`` with ``next_dir == dir``.

    Raises:
        ValueError: If ``dir`` is not one of the four headings (previously this
            surfaced as an UnboundLocalError).
    """
    # heading -> (row delta, column delta)
    steps = {"N": (-1, 0), "S": (1, 0), "W": (0, -1), "E": (0, 1)}
    if dir not in steps:
        raise ValueError("unknown direction: {!r}".format(dir))

    d_row, d_col = steps[dir]
    return (loc_x + d_row, loc_y + d_col, dir)


def process_forward_slash(loc_x, loc_y, dir):
    """Deflect a beam as a "/" tile does and step one tile in the new heading.

    The heading changes N -> E, S -> W, W -> S and E -> N. Coordinates are
    (row, column); no bounds checking is done here.

    Args:
        loc_x: Row index of the current tile.
        loc_y: Column index of the current tile.
        dir: Incoming heading, one of "N", "S", "W", "E".

    Returns:
        Tuple ``(next_loc_x, next_loc_y, next_dir)`` after the deflection.

    Raises:
        ValueError: If ``dir`` is not one of the four headings (previously this
            surfaced as an UnboundLocalError).
    """
    # incoming heading -> (row delta, column delta, outgoing heading)
    deflections = {
        "N": (0, 1, "E"),
        "S": (0, -1, "W"),
        "W": (1, 0, "S"),
        "E": (-1, 0, "N"),
    }
    if dir not in deflections:
        raise ValueError("unknown direction: {!r}".format(dir))

    d_row, d_col, next_dir = deflections[dir]
    return (loc_x + d_row, loc_y + d_col, next_dir)



def process_back_slash(loc_x, loc_y, dir):
    """Deflect a beam as a "\\" tile does and step one tile in the new heading.

    The heading changes N -> W, S -> E, W -> N and E -> S. Coordinates are
    (row, column); no bounds checking is done here.

    Args:
        loc_x: Row index of the current tile.
        loc_y: Column index of the current tile.
        dir: Incoming heading, one of "N", "S", "W", "E".

    Returns:
        Tuple ``(next_loc_x, next_loc_y, next_dir)`` after the deflection.

    Raises:
        ValueError: If ``dir`` is not one of the four headings (previously this
            surfaced as an UnboundLocalError).
    """
    # incoming heading -> (row delta, column delta, outgoing heading)
    deflections = {
        "N": (0, -1, "W"),
        "S": (0, 1, "E"),
        "W": (-1, 0, "N"),
        "E": (1, 0, "S"),
    }
    if dir not in deflections:
        raise ValueError("unknown direction: {!r}".format(dir))

    d_row, d_col, next_dir = deflections[dir]
    return (loc_x + d_row, loc_y + d_col, next_dir)

In [None]:
def get_to_follow(loc_x, loc_y, dir):
    """List the in-bounds states a beam moves to from one (tile, heading) state.

    Reads the module-level grid ``data`` and its size ``num_rows`` /
    ``num_cols``. The tile at ``data[loc_x][loc_y]`` decides which of the
    ``process_*`` movers apply:

    * "."  -> straight on
    * "/"  -> forward-slash deflection
    * "\\" -> back-slash deflection
    * "|"  -> straight on for N/S headings, split into two beams for W/E
    * "-"  -> straight on for W/E headings, split into two beams for N/S

    A split is produced by applying both deflections, which yields the two
    headings perpendicular to the incoming one.

    Args:
        loc_x: Row index of the current tile.
        loc_y: Column index of the current tile.
        dir: Current heading, one of "N", "S", "W", "E".

    Returns:
        List of ``(row, column, heading)`` tuples that lie inside the grid,
        in the order forward-slash result then back-slash result for splits.

    Raises:
        ValueError: If the tile is not one of the five known characters.
    """
    tile = data[loc_x][loc_y]
    heading_is_vertical = dir in ("N", "S")
    heading_is_horizontal = dir in ("W", "E")
    both_deflections = (process_forward_slash, process_back_slash)

    if tile == ".":
        movers = (process_dot,)
    elif tile == "/":
        movers = (process_forward_slash,)
    elif tile == "\\":
        movers = (process_back_slash,)
    elif tile == "|":
        if heading_is_vertical:
            movers = (process_dot,)
        elif heading_is_horizontal:
            # split into two beams
            movers = both_deflections
        else:
            movers = ()
    elif tile == "-":
        if heading_is_horizontal:
            movers = (process_dot,)
        elif heading_is_vertical:
            # split into two beams
            movers = both_deflections
        else:
            movers = ()
    else:
        raise ValueError

    origin = (loc_x, loc_y, dir)
    candidates = [mover(loc_x, loc_y, dir) for mover in movers]

    # Keep only successors that stay on the grid and differ from the origin state.
    return [
        cand
        for cand in candidates
        if 0 <= cand[0] < num_rows and 0 <= cand[1] < num_cols and cand != origin
    ]

In [None]:
# Precompute, for every (row, column, heading) state on the grid, the list of
# in-bounds successor states, so the traversals below are plain dict lookups.
# A comprehension keeps the loop variables out of the notebook namespace, so
# the built-in `dir` is not rebound globally.
to_follow_dict = {
    (row, col, heading): get_to_follow(row, col, heading)
    for row in range(num_rows)
    for col in range(num_cols)
    for heading in ("N", "S", "E", "W")
}

In [18]:
def find_followed_states(loc_x_initial, loc_y_initial, dir_initial):
    """Collect every state reachable from a start state, iteratively.

    Successors are looked up in the module-level ``to_follow_dict``.

    Args:
        loc_x_initial: Row index of the start tile.
        loc_y_initial: Column index of the start tile.
        dir_initial: Start heading, one of "N", "S", "W", "E".

    Returns:
        Set of ``(row, column, heading)`` states visited, including the start
        state itself.

    Raises:
        KeyError: If a state is missing from ``to_follow_dict``.
    """
    start = (loc_x_initial, loc_y_initial, dir_initial)

    # States are recorded as soon as they are pushed, so one set-membership
    # test replaces the linear `in to_follow` scan of the pending list.
    followed = {start}
    to_follow = [start]

    while to_follow:
        current = to_follow.pop()
        for follow_cand in to_follow_dict[current]:
            if follow_cand not in followed:
                followed.add(follow_cand)
                to_follow.append(follow_cand)

    return followed

In [19]:
def find_followed_states_recursive(loc_x, loc_y, dir, res):
    """Add every state reachable from a start state to ``res``, recursively.

    Successors are looked up in the module-level ``to_follow_dict``. The start
    state itself is not added here; the caller is expected to seed ``res``.

    Args:
        loc_x: Row index of the start tile.
        loc_y: Column index of the start tile.
        dir: Start heading, one of "N", "S", "W", "E".
        res: Set of already visited ``(row, column, heading)`` states; it is
            mutated in place.

    Returns:
        None. The result is the updated ``res``.
    """
    successors = to_follow_dict[(loc_x, loc_y, dir)]

    # Single-successor stretches are walked in a loop so that only branching
    # points cost a stack frame, which keeps the recursion depth low.
    while len(successors) == 1:
        only_next = successors[0]
        if only_next in res:
            return
        res.add(only_next)
        successors = to_follow_dict[only_next]

    # Either a dead end (nothing to iterate) or a branching point: recurse
    # into every branch that has not been visited yet.
    for branch in successors:
        if branch in res:
            continue
        res.add(branch)
        find_followed_states_recursive(*branch, res)

In [20]:
def find_num_energised_states(loc_x_initial, loc_y_initial, dir_initial, recursive=True):
    """Count the distinct tiles visited by a beam from the given start state.

    Args:
        loc_x_initial: Row index of the start tile.
        loc_y_initial: Column index of the start tile.
        dir_initial: Start heading, one of "N", "S", "W", "E".
        recursive: If true, traverse with ``find_followed_states_recursive``;
            otherwise use the iterative ``find_followed_states``.

    Returns:
        Number of distinct ``(row, column)`` positions among the visited
        states, including the start tile.
    """
    if recursive:
        # The recursive traversal does not record its own start state, so seed it.
        visited = {(loc_x_initial, loc_y_initial, dir_initial)}
        find_followed_states_recursive(loc_x_initial, loc_y_initial, dir_initial, visited)
    else:
        visited = find_followed_states(loc_x_initial, loc_y_initial, dir_initial)

    # A tile counts once however many headings it was crossed with.
    return len({(row, col) for row, col, _ in visited})

In [None]:
# Part 1: start on the top-left tile (row 0, column 0) heading east.
find_num_energised_states(0, 0, "E", recursive=True)

In the timings below, the recursive traversal came out quite a bit faster than the iterative one!

In [None]:
%%timeit
# Benchmark the recursive traversal from the Part 1 start state.
find_num_energised_states(0, 0, "E", recursive=True)

In [None]:
%%timeit
# Benchmark the iterative (explicit stack) traversal from the same start state.
find_num_energised_states(0, 0, "E", recursive=False)

# Part 2

In [None]:
# Every edge tile paired with the heading that points into the grid:
# left column heading east, right column heading west, top row heading south,
# bottom row heading north (corner tiles therefore appear twice).
start_states = [
    *((row, 0, "E") for row in range(num_rows)),
    *((row, num_cols - 1, "W") for row in range(num_rows)),
    *((0, col, "S") for col in range(num_cols)),
    *((num_rows - 1, col, "N") for col in range(num_cols)),
]

In [None]:
# Sanity check: one start state per edge tile and side, i.e. 2 * (num_rows + num_cols).
len(start_states)

In [None]:
# Part 2: best tile count over all edge start states. max() replaces the manual
# running maximum (which carried an unused enumerate index and a float -inf
# sentinel); default=0 keeps the result an int if there are no start states.
max_num_energised_states = max(
    (
        find_num_energised_states(*start_state, recursive=True)
        for start_state in start_states
    ),
    default=0,
)

In [None]:
# Part 2 answer: the largest number of visited tiles over all start states.
print(max_num_energised_states)