In [101]:
import numpy as np
import itertools as it
import functools as ft
import sortedcontainers as sc
import copy
from tqdm.notebook import tqdm


In [5]:
def parse_input(fl) -> np.ndarray:
    """Read a puzzle file into a 2-D character grid.

    Args:
        fl: path to a text file; every non-blank line is one grid row made of
            ``'O'`` (round rock), ``'#'`` (fixed rock) and ``'.'`` (empty) chars.

    Returns:
        ``np.ndarray`` of single characters with shape ``(nrows, ncols)``.
        An input with no rows yields an empty ``(0, 0)`` array so that
        ``nrows, ncols = grid.shape`` still works for callers.

    Raises:
        ValueError: if the rows do not all have the same length.
    """
    with open(fl) as infile:
        # Skip blank lines (e.g. an extra newline at EOF); otherwise they would
        # become empty rows and make the grid ragged.
        rows = [list(line.strip()) for line in infile if line.strip()]
    if not rows:
        return np.empty((0, 0), dtype='<U1')
    if len({len(row) for row in rows}) > 1:
        raise ValueError(f"{fl}: grid rows have differing lengths")
    return np.array(rows)

In [206]:
def solve_p1(grid):
    """Part 1: tilt the platform north once and return the resulting load.

    Args:
        grid: 2-D character array as produced by ``parse_input``.

    Returns:
        The load computed by ``_compute_load`` for the rock positions after a
        single northward tilt.
    """
    height, width = grid.shape
    tilted = move_dirn(
        ixes=grid_to_ixes(grid),
        dirn='N',
        nrows=height,
        ncols=width,
        rebuild_grid=False,
    )
    return _compute_load(ixes=tilted['ixes'], nrows=height)

def solve_p2(grid, num_cycles):
    """Part 2: apply ``num_cycles`` spin cycles and report the final board and load.

    Rather than simulating every cycle, the positions of the round rocks are
    recorded after each cycle. As soon as a state repeats, the board is known to
    be periodic. Only the leftover ``(num_cycles - cycles_done) % period``
    cycles are then simulated.

    Args:
        grid: 2-D character array as produced by ``parse_input``.
        num_cycles: number of N/W/S/E spin cycles to apply. Values ``<= 0``
            leave the board unchanged.

    Returns:
        dict with ``"grid"`` (rendered board string) and ``"load"`` (result of
        ``_compute_load``).
    """
    nrows, ncols = grid.shape
    ixes = grid_to_ixes(grid)
    state2lseen = {}  # round-rock positions -> 0-based cycle index where first seen
    cycles_done = 0
    period = None
    for c in tqdm(range(num_cycles), desc='Cycles'):
        ixes = cycle(
            ixes=ixes,
            nrows=nrows,
            ncols=ncols,
        )['ixes']
        cycles_done = c + 1
        # Only round rocks move, so their sorted positions fully identify the state.
        state = tuple(sorted((ix[0], ix[1]) for ix in ixes if ix[2] == 'O'))
        if state in state2lseen:
            period = c - state2lseen[state]
            break
        state2lseen[state] = c
    # If no repeat was seen, the loop already ran all num_cycles (or none were
    # requested), so there is nothing left to simulate.
    remainder_cycles = 0 if period is None else (num_cycles - cycles_done) % period
    for _ in tqdm(range(remainder_cycles), desc='Rem cycles'):
        ixes = cycle(
            ixes=ixes,
            nrows=nrows,
            ncols=ncols,
        )['ixes']
    return {
        "grid": _build_grid(ixes=ixes, nrows=nrows, ncols=ncols),
        "load": _compute_load(ixes=ixes, nrows=nrows)
    }

def _compute_load(ixes, nrows):
    # filter by `O` type and fetch the row indices
    #   ixes has rows of [r, c, sym] ex: [1, 3, '#']
    ixes = np.array(ixes)
    O_rix = ixes[ixes[:,2]=='O'][:,0].astype(np.int32)
    return (nrows-O_rix).sum()

def grid_to_ixes(grid):
    """Convert a character grid into a sorted list of ``[row, col, symbol]`` entries.

    The list contains every fixed rock (``'#'``) and round rock (``'O'``), plus a
    ring of ``'-'`` sentinel cells just outside the grid: column ``-1`` and
    ``ncols`` for every row, and row ``-1`` and ``nrows`` for every column.
    The sentinels give a sliding round rock something to stop against at the
    edge of the board.

    Args:
        grid: 2-D character array as produced by ``parse_input``.

    Returns:
        list of ``[row, col, symbol]`` lists, sorted lexicographically.
    """
    nrows, ncols = grid.shape

    # Sentinel walls: left/right of every row, then above/below every column.
    border = [[r, -1, '-'] for r in range(nrows)]
    border += [[r, ncols, '-'] for r in range(nrows)]
    border += [[-1, c, '-'] for c in range(ncols)]
    border += [[nrows, c, '-'] for c in range(ncols)]

    # Fixed rocks stay put; round rocks are the entries later moves relocate.
    rocks = [list(entry) for entry in _get_sym_ixes(grid, sym='#', append_sym=True)]
    rocks += [list(entry) for entry in _get_sym_ixes(grid, sym='O', append_sym=True)]

    return sorted(border + rocks)

def _get_sym_ixes(grid, sym, append_sym=False):
    rix, cix = np.where(grid==sym)
    if append_sym:
        return list(zip(rix, cix, [sym]*len(rix)))
    else:
        return list(zip(rix, cix))


def cycle(ixes, nrows, ncols, rebuild_grid=False, debug=False):
    """Run one spin cycle: tilt north, then west, then south, then east.

    Args:
        ixes: ``[row, col, symbol]`` entries as produced by ``grid_to_ixes``.
        nrows, ncols: grid dimensions.
        rebuild_grid: if True, each tilt's result (and so the returned dict)
            also carries the rendered board string under ``"grid"``.
        debug: print the rendered board after each tilt. The board is rendered
            for printing even when ``rebuild_grid`` is False, so this prints a
            real grid rather than ``None``.

    Returns:
        The ``move_dirn`` result dict for the final (east) tilt.
    """
    move_fn = ft.partial(move_dirn, nrows=nrows, ncols=ncols, rebuild_grid=rebuild_grid)
    for d in ['N', 'W', 'S', 'E']:
        res = move_fn(ixes=ixes, dirn=d)
        ixes = res['ixes']
        if debug:
            # res['grid'] is None unless rebuild_grid was requested; render on demand.
            grid_str = res['grid'] if rebuild_grid else _build_grid(ixes, nrows, ncols)
            print(f"After {d}")
            print(grid_str)
            print('\n-----------\n')
    return res


def move_dirn(ixes, dirn, nrows, ncols, rebuild_grid=False):
    """Tilt the board so every round rock slides as far as it can in one direction.

    The entries are sorted so that, scanning in order, each round rock is
    immediately preceded by whatever it will come to rest against: the border
    sentinel, a fixed rock, or an already-moved round rock. The rock is then
    placed one step past that predecessor.

    Args:
        ixes: ``[row, col, symbol]`` entries, e.g. ``[8, -1, '-']`` or
            ``[1, 0, 'O']``.
        dirn: one of ``'N'``, ``'W'``, ``'S'``, ``'E'``.
        nrows, ncols: grid dimensions (only used to render the board).
        rebuild_grid: also render the board string when True.

    Returns:
        dict with ``"ixes"`` (the moved entries) and ``"grid"`` (rendered
        board string, or None when ``rebuild_grid`` is False).

    Raises:
        NotImplementedError: for any other ``dirn``.
    """
    if dirn == 'N':
        # column by column, top row first; land just below the predecessor
        order, land = (
            lambda x: (x[1], x[0], x[2]),
            lambda prev_r, prev_c, r, c, sym: [prev_r + 1, c, sym],
        )
    elif dirn == 'W':
        # row by row, leftmost column first; land just right of the predecessor
        order, land = (
            lambda x: (x[0], x[1], x[2]),
            lambda prev_r, prev_c, r, c, sym: [r, prev_c + 1, sym],
        )
    elif dirn == 'S':
        # column by column, bottom row first; land just above the predecessor
        order, land = (
            lambda x: (x[1], -x[0], x[2]),
            lambda prev_r, prev_c, r, c, sym: [prev_r - 1, c, sym],
        )
    elif dirn == 'E':
        # row by row, rightmost column first; land just left of the predecessor
        order, land = (
            lambda x: (x[0], -x[1], x[2]),
            lambda prev_r, prev_c, r, c, sym: [r, prev_c - 1, sym],
        )
    else:
        raise NotImplementedError()

    moved = _move_dirn_helper(ixes=ixes, new_O_pos_fn=land, sort_fn=order)
    return {
        "ixes": moved,
        "grid": _build_grid(moved, nrows, ncols) if rebuild_grid else None,
    }

def _move_dirn_helper(ixes, new_O_pos_fn, sort_fn):
    ixes = sorted(ixes, key=sort_fn)
    for i in range(1, len(ixes)):
        (prev_r, prev_c, prev_sym) = ixes[i-1]
        (r, c, sym) = ixes[i]
        if sym in {'-', '#'}:
            continue
        ixes[i] = new_O_pos_fn(prev_r=prev_r, prev_c=prev_c, r=r, c=c, sym=sym)
    return ixes

def _build_grid(ixes, nrows, ncols):
    new_grid = np.array([['.']*ncols for _ in range(nrows)])
    for r, c, sym in ixes:
        if sym == '-':
            continue
        # print(f"{r=}, {c=}, {sym=},{nrows=},{ncols=}")
        new_grid[r,c] = sym
    return '\n'.join((''.join(r)for r in new_grid))

In [201]:
# Worked example from the puzzle statement, followed by the real puzzle input.
test, inputs = (
    parse_input("data/day14-test.txt"),
    parse_input("data/day14-input.txt"),
)
test

array([['O', '.', '.', '.', '.', '#', '.', '.', '.', '.'],
       ['O', '.', 'O', 'O', '#', '.', '.', '.', '.', '#'],
       ['.', '.', '.', '.', '.', '#', '#', '.', '.', '.'],
       ['O', 'O', '.', '#', 'O', '.', '.', '.', '.', 'O'],
       ['.', 'O', '.', '.', '.', '.', '.', 'O', '#', '.'],
       ['O', '.', '#', '.', '.', 'O', '.', '#', '.', '#'],
       ['.', '.', 'O', '.', '.', '#', 'O', '.', '.', 'O'],
       ['.', '.', '.', '.', '.', '.', '.', 'O', '.', '.'],
       ['#', '.', '.', '.', '.', '#', '#', '#', '.', '.'],
       ['#', 'O', 'O', '.', '.', '#', '.', '.', '.', '.']], dtype='<U1')

In [202]:
# Part 1 answers: (example, real input).
tuple(solve_p1(grid=g) for g in (test, inputs))

(136, 108813)

In [213]:
# Part 2: one billion spin cycles. solve_p2 stops simulating once the board
# state repeats and then runs only the leftover cycles modulo the period.
NUM_CYCLES = 1_000_000_000

# (label, grid, print the final board?) -- only the example's board is printed.
for run_ix, (label, puzzle_grid, show_grid) in enumerate(
    (("Test", test, True), ("Input", inputs, False))
):
    if run_ix:
        print('\n')  # two blank lines between runs
    print(label)
    print('-' * len(label))
    res = solve_p2(grid=puzzle_grid, num_cycles=NUM_CYCLES)
    print(f"\nAns: {res['load']} ")
    if show_grid:
        print(res['grid'])

Test
----


Cycles:   0%|          | 0/1000000000 [00:00<?, ?it/s]

Rem cycles:   0%|          | 0/3 [00:00<?, ?it/s]


Ans: 64 
.....#....
....#...O#
.....##...
...#......
.....OOO#.
.O#...O#.#
....O#...O
......OOOO
#....###.O
#.OOO#..OO


Input
-----


Cycles:   0%|          | 0/1000000000 [00:00<?, ?it/s]

Rem cycles:   0%|          | 0/3 [00:00<?, ?it/s]


Ans: 104533 
