In [None]:
import os
from collections import deque
from pathlib import Path

In [None]:
# Puzzle input; set INPUT_NAME to "input21_test.txt" to run on the example map instead.
INPUT_NAME = "input21.txt"
fp = Path().absolute() / "inputs" / INPUT_NAME

with open(fp, "r", encoding="utf-8") as f:
    # splitlines() drops a trailing newline without also dropping the last map row
    # when the file happens not to end with a newline (split("\n")[:-1] would).
    data = f.read().splitlines()

In [None]:
# Preview only the first rows instead of dumping the whole map.
data[:5]

# Part 1

In [None]:
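# Find the distance from S to every garden plot.
# Exactly those plots can be reached in 64 steps that are at most 64 steps away and at an even distance:
# the grid graph is bipartite, so all paths between two given plots have lengths of the same parity,
# and any surplus steps can be spent walking back and forth in pairs.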

In [None]:
# Map dimensions: number of rows and length of a row.
num_rows, num_cols = len(data), len(data[0])
print(num_rows, num_cols)

In [None]:
# Locate the start tile "S" (first occurrence in row-major order) as (row, col).
start = next(
    ((i, j) for i, row in enumerate(data) for j, cell in enumerate(row) if cell == "S"),
    None,
)
if start is None:
    # Without this check a missing S only surfaces later as a NameError on start_x.
    raise ValueError("no start tile 'S' found in the input map")

start_x, start_y = start
print(start_x, start_y)

In [None]:
# Every walkable cell ("." or the start "S") as a (row, col) tuple, in row-major order.
all_plots = [
    (i, j)
    for i, row in enumerate(data)
    for j, cell in enumerate(row)
    if cell in [".", "S"]
]

In [None]:
def get_neighbours(point, data):
    """Return the garden plots one step away from ``point``.

    Candidates are checked in the order up, down, left, right; a candidate is kept
    when it lies inside the map along the axis being moved on and is a plot
    ("." or the start "S"). Rocks and out-of-map cells are skipped.

    Args:
        point: (row, col) of the cell whose neighbours are wanted.
        data: the map as a list of equal-length strings.

    Returns:
        list of (row, col) tuples.
    """
    height = len(data)
    width = len(data[0])
    row, col = point

    # (bounds check for the move, resulting cell) in up/down/left/right order.
    moves = (
        (row > 0, (row - 1, col)),
        (row < height - 1, (row + 1, col)),
        (col > 0, (row, col - 1)),
        (col < width - 1, (row, col + 1)),
    )
    return [
        cand
        for inside, cand in moves
        if inside and data[cand[0]][cand[1]] in [".", "S"]
    ]

In [None]:
# Plot neighbours of every cell of the map (rock cells included; they are never expanded).
neighbours_dict = {
    (x, y): get_neighbours((x, y), data)
    for x in range(num_rows)
    for y in range(num_cols)
}

In [None]:
def dijkstra(start, all_plots, neighbours_dict, verbose=False, log_every=10_000):
    """Return the shortest-path distance (in steps) from ``start`` to every plot.

    Every step costs 1, so a breadth-first search yields exactly the distances
    Dijkstra's algorithm would, in O(V + E) time instead of re-sorting the whole
    frontier before every pop.

    Args:
        start: (row, col) of the starting plot.
        all_plots: plots that each get an entry in the result.
        neighbours_dict: maps a plot to the list of plots one step away.
        verbose: if True, print progress every ``log_every`` expansions and a final summary.
        log_every: progress-print interval used when ``verbose`` is True.

    Returns:
        dict mapping every plot of ``all_plots`` (and ``start``) to its distance,
        or ``float("inf")`` if it cannot be reached.
    """
    shortest_path_distance_all_plots = {plot: float("inf") for plot in all_plots}
    shortest_path_distance_all_plots[start] = 0

    to_expand = deque([start])
    num_iter = 0
    while to_expand:
        if verbose and num_iter % log_every == 0:
            print(f"{num_iter = }, {len(to_expand) = }")

        current = to_expand.popleft()
        next_cost = shortest_path_distance_all_plots[current] + 1

        for neighbour in neighbours_dict[current]:
            # BFS pops plots in nondecreasing distance order, so the first
            # improvement of a neighbour is already its final distance.
            if next_cost < shortest_path_distance_all_plots[neighbour]:
                shortest_path_distance_all_plots[neighbour] = next_cost
                to_expand.append(neighbour)

        num_iter += 1

    if verbose:
        print(f"finished after {num_iter} expansions")

    return shortest_path_distance_all_plots

In [None]:
# Distances from S on the original (untiled) map.
shortest_path_distance_all_plots = dijkstra(
    start=start,
    all_plots=all_plots,
    neighbours_dict=neighbours_dict,
    verbose=True,
)

In [None]:
num_steps_exactly = 64
# A plot is reachable in exactly `num_steps_exactly` steps iff it is at most that far away
# and its distance has the same parity as the step count (surplus steps are spent in
# back-and-forth pairs). Deriving the parity from the step count keeps this correct
# for odd step counts too.
poss = {
    plot: dist
    for plot, dist in shortest_path_distance_all_plots.items()
    if dist <= num_steps_exactly and dist % 2 == num_steps_exactly % 2
}
num_poss = len(poss)
num_poss

# Part 2

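Following the solution described here (this section has produced a wrong answer before, so double-check it):
https://www.reddit.com/r/adventofcode/comments/18o4y0m/2023_day_21_part_2_algebraic_solution_using_only/

Pitfall: a plot is reachable in *exactly* $s$ steps iff its distance $d$ satisfies $d \le s$ and $d \equiv s \pmod 2$. Part 1 uses $s = 64$ (even), but Part 2 uses the odd $N = 26501365$. The "even distance" test from Part 1 must therefore not be reused. Each simulated tile type has to use the parity of its own simulated step budget.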

It is also possible to solve this by quadratic interpolation. On the infinite grid, count the plots reachable in $65$, $65 + 131$ and $65 + 2 \cdot 131$ steps. The number of plots reachable in $65 + 131 M$ steps is a quadratic polynomial in $M$ (see the general formula below), so these three values determine it.

In [None]:
# The tile-based extrapolation relies on rock-free outermost rows and columns of the map.
assert all(data[0][y] == "." for y in range(num_cols))
assert all(data[num_rows - 1][y] == "." for y in range(num_cols))
assert all(data[x][0] == "." for x in range(num_rows))
assert all(data[x][num_cols - 1] == "." for x in range(num_rows))

In [None]:
# Build a 3x3 tiling of the original map (the start tile "S" becomes an ordinary plot ".").
num_rows_tiling = 3 * num_rows
num_cols_tiling = 3 * num_cols

# Row i of the tiling is map row i % num_rows repeated three times horizontally.
# Building whole rows keeps rows and columns the right way round even for a non-square map.
tiling = [(data[i % num_rows] * 3).replace("S", ".") for i in range(num_rows_tiling)]

assert len(tiling) == num_rows_tiling
assert all(len(row) == num_cols_tiling for row in tiling)

# Show the shape rather than dumping the whole tiled map.
(len(tiling), len(tiling[0]))

In [None]:
# S sits in the centre tile of the 3x3 tiling: shifted by one map size along each axis.
start_tiling = (num_rows + start_x, num_cols + start_y)
start_tiling_x, start_tiling_y = start_tiling
print(start_tiling)

In [None]:
# Walkable cells of the 3x3 tiling, in row-major order.
all_plots_tiling = [
    (i, j)
    for i, row in enumerate(tiling)
    for j, cell in enumerate(row)
    if cell in [".", "S"]
]

In [None]:
# Plot neighbours of every cell of the 3x3 tiling.
neighbours_dict_tiling = {
    (x, y): get_neighbours((x, y), tiling)
    for x in range(num_rows_tiling)
    for y in range(num_cols_tiling)
}

In [None]:
# Plots with no neighbouring plot can never be reached. neighbours_dict_tiling also has
# (empty) entries for rock cells, so restrict to actual plots. Note that plots inside an
# enclosed region are unreachable too but do not show up here.
isolated_plots_tiling = [
    point
    for point, neighbours in neighbours_dict_tiling.items()
    if not neighbours and tiling[point[0]][point[1]] in [".", "S"]
]
isolated_plots_tiling

In [None]:
# Distances from the central S across the whole 3x3 tiling.
shortest_path_distance_all_plots_tiling = dijkstra(
    start=start_tiling,
    all_plots=all_plots_tiling,
    neighbours_dict=neighbours_dict_tiling,
    verbose=True,
)

In [None]:
# Number of plots in the 3x3 tiling that can never be reached from S (distance stays infinite).
sum(1 for dist in shortest_path_distance_all_plots_tiling.values() if dist == float("inf"))

In [None]:
# Number of plots in the original map that can never be reached from S (distance stays infinite).
sum(1 for dist in shortest_path_distance_all_plots.values() if dist == float("inf"))

In [None]:
N = 26_501_365  # total number of steps for Part 2
# N = 5000  # step count used with the example map (input21_test.txt)

In [None]:
# Half the map size. The tile-based extrapolation assumes a square map of odd size
# with S exactly in its centre, so check that before relying on it.
assert num_rows == num_cols and num_rows % 2 == 1, "expected a square map of odd size"
h = (num_rows - 1) // 2
assert start == (h, h), "expected S in the centre of the map"
h

In [None]:
# Number of whole map widths walked beyond the centre tile: N = h + M * num_rows.
assert (N - h) % num_rows == 0, "N must end exactly on a tile boundary (N = h + M * num_rows)"
M = int((N - h) // num_rows)
# The tile-count formula used below ((M - 1)**2 * O + M**2 * E + ...) is derived for even M.
assert M % 2 == 0, "the extrapolation formula assumes an even M"
M

In [None]:
def is_offset(i, size=None):
    """Return True if index ``i`` lies OUTSIDE the middle band of the 3x3 tiling.

    Along one axis the tiling consists of three copies of the map, and the middle
    copy covers the indices ``size <= i < 2 * size``. The function is True for the
    two outer copies and False for the middle one.

    Args:
        i: row or column index into the 3x3 tiling.
        size: size of one map copy along that axis; defaults to the global ``num_rows``.

    Returns:
        bool
    """
    if size is None:
        size = num_rows
    return i < size or i > 2 * size - 1

In [None]:
# Count, per tile type of the 3x3 tiling, the plots reachable in *exactly* a given number
# of steps. A plot at distance d is reachable in exactly s steps iff d <= s and d has the
# same parity as s (the grid graph is bipartite; surplus steps go back and forth).
#
# Two step budgets are simulated, both derived from N = h + M * num_rows:
#   steps_short = h + num_rows     = N - (M - 1) * num_rows  -> tips T, small edge pieces B
#   steps_long  = h + 2 * num_rows = N - (M - 2) * num_rows  -> full tiles O/E, large pieces A
# The parity has to come from these budgets. Hard-coding Part 1's "even" (64 steps) counts
# the complementary plots for the odd N of Part 2.
steps_short = h + num_rows
steps_long = h + 2 * num_rows

T = 0  # the 4 tips, summed (tiles sharing an edge with the centre, short budget)
A = 0  # the 4 large diagonal pieces, summed (corner tiles, long budget)
B = 0  # the 4 small diagonal pieces, summed (corner tiles, short budget)
E = 0  # full tiles sharing an edge with the centre (4 copies, divided below)
O = 0  # the full centre tile

for i in range(num_rows_tiling):
    for j in range(num_cols_tiling):
        point = (i, j)
        if tiling[i][j] not in ["S", "."]:
            continue

        dist = shortest_path_distance_all_plots_tiling[point]
        if dist == float("inf"):
            continue

        reachable_short = dist <= steps_short and (steps_short - dist) % 2 == 0
        reachable_long = dist <= steps_long and (steps_long - dist) % 2 == 0

        i_offset = is_offset(i)
        j_offset = is_offset(j)

        if i_offset != j_offset:
            # One of the four tiles sharing an edge with the centre tile.
            if reachable_short:
                T += 1
            if reachable_long:
                E += 1
        elif i_offset:
            # One of the four corner tiles.
            if reachable_long:
                A += 1
            if reachable_short:
                B += 1
        elif reachable_long:
            # The centre tile.
            O += 1

# The four edge-adjacent full tiles are identical copies, so each holds a quarter of E.
assert E % 4 == 0
E = E // 4

print(f"{T = }")
print(f"{A = }")
print(f"{B = }")
print(f"{E = }")
print(f"{O = }")

In [None]:
# Extrapolate the per-tile-type counts to M map widths:
# (M - 1)**2 centre-like full tiles, M**2 opposite-parity full tiles,
# M - 1 large and M small pieces per diagonal edge, plus the four tips.
total_alt = (M - 1) ** 2 * O
total_alt += M ** 2 * E
total_alt += (M - 1) * A
total_alt += M * B
total_alt += T
total_alt

The expected answer is 600090522932119 (obtained by running one of the published solutions).

Cross-check: the same calculation using a 5x5 tiling. It should agree with the 3x3 version above (it previously reproduced the same wrong result).

In [None]:
# Build a 5x5 tiling of the original map (the start tile "S" becomes an ordinary plot ".").
num_rows_tiling_large = 5 * num_rows
num_cols_tiling_large = 5 * num_cols

# Row i of the tiling is map row i % num_rows repeated five times horizontally.
# Building whole rows keeps rows and columns the right way round even for a non-square map.
tiling_large = [
    (data[i % num_rows] * 5).replace("S", ".") for i in range(num_rows_tiling_large)
]

assert len(tiling_large) == num_rows_tiling_large
assert all(len(row) == num_cols_tiling_large for row in tiling_large)

In [None]:
# S sits in the centre tile of the 5x5 tiling: two map sizes from the top-left corner.
start_tiling_large = (2 * num_rows + start_x, 2 * num_cols + start_y)
start_tiling_large_x, start_tiling_large_y = start_tiling_large
print(start_tiling_large)

In [None]:
# Walkable cells of the 5x5 tiling, in row-major order.
all_plots_tiling_large = [
    (i, j)
    for i, row in enumerate(tiling_large)
    for j, cell in enumerate(row)
    if cell in [".", "S"]
]

In [None]:
# Plot neighbours of every cell of the 5x5 tiling.
neighbours_dict_tiling_large = {
    (x, y): get_neighbours((x, y), tiling_large)
    for x in range(num_rows_tiling_large)
    for y in range(num_cols_tiling_large)
}

In [None]:
# Distances from the central S across the whole 5x5 tiling.
shortest_path_distance_all_plots_tiling_large = dijkstra(
    start=start_tiling_large,
    all_plots=all_plots_tiling_large,
    neighbours_dict=neighbours_dict_tiling_large,
    verbose=True,
)

In [None]:
# Classify every plot of the 5x5 tiling by its tile and keep those reachable in *exactly*
# h + 2 * num_rows steps, i.e. the M = 2 instance of the extrapolation.
# A plot at distance d is reachable in exactly s steps iff d <= s and d has the same
# parity as s. For even M this budget has the parity of N (odd for Part 2), not the
# even parity that Part 1's 64 steps required.
steps = h + 2 * num_rows

T = 0  # the 4 tips (tiles two away along the axes), summed
A = 0  # the 4 large diagonal pieces (tiles (1|3, 1|3)), summed
B = 0  # small diagonal pieces: 8 such tiles, two per diagonal, halved below
E = 0  # full tiles sharing an edge with the centre: 4 copies, divided below
O = 0  # the full centre tile

for i in range(num_rows_tiling_large):
    for j in range(num_cols_tiling_large):
        point = (i, j)
        if tiling_large[i][j] not in ["S", "."]:
            continue

        dist = shortest_path_distance_all_plots_tiling_large[point]
        if dist == float("inf"):
            continue

        assert isinstance(dist, int)

        if dist > steps or (steps - dist) % 2 != 0:
            continue

        i_idx = i // num_rows
        j_idx = j // num_cols

        if i_idx == 2 and j_idx == 2:
            O += 1
        elif (i_idx in [1, 3] and j_idx == 2) or (j_idx in [1, 3] and i_idx == 2):
            E += 1
        elif i_idx in [1, 3] and j_idx in [1, 3]:
            A += 1
        elif (i_idx in [0, 4] and j_idx == 2) or (j_idx in [0, 4] and i_idx == 2):
            T += 1
        elif i_idx not in [0, 4] or j_idx not in [0, 4]:
            B += 1
        # Anything left lies in the four 5x5 corner tiles, which should be beyond the budget.

# The two small pieces per diagonal and the four edge-adjacent full tiles are identical copies.
assert B % 2 == 0 and E % 4 == 0
B = B // 2
E = E // 4

print(f"{T = }")
print(f"{A = }")
print(f"{B = }")
print(f"{E = }")
print(f"{O = }")

In [None]:
# Extrapolate the 5x5 per-tile-type counts to M map widths.
full_tiles = (M - 1) ** 2 * O + M ** 2 * E
total_alt = full_tiles + (M - 1) * A + M * B + T
total_alt