In [244]:
# Small sample grid (presumably the example from the puzzle statement); pass it
# to `solve` instead of `data` to sanity-check the solver on a tiny input.
example_data = """2413432311323
3215453535623
3255245654254
3446585845452
4546657867536
1438598798454
4457876987766
3637877979653
4654967986887
4564679986453
1224686865563
2546548887735
4322674655533
"""

# Real puzzle input. The `with` block closes the file handle deterministically
# rather than leaving it open until the object happens to be garbage-collected.
with open('puzzle.data', encoding='utf-8') as puzzle_file:
    data = puzzle_file.read()

from helper import *
from dataclasses import dataclass
from heapq import heappop, heappush

@dataclass
class Entry:
    """One search state for the heap-based shortest-path search.

    Entries are ordered by accumulated ``loss`` only, so a min-heap of entries
    always pops the cheapest state first.
    """

    loss: int  # total loss accumulated on the way to ``pos`` (start cell excluded)
    pos: complex  # grid position encoded as complex(x, y)
    dir: complex  # direction of the next straight run (multiplied by a step count)
    path: list[complex]  # stopping points visited so far; not used for the result

    def __lt__(self, other: 'Entry') -> bool:
        """Compare by loss with a *strict* ordering.

        ``heapq`` and ``sorted`` assume ``__lt__`` is a strict order (``a < a``
        must be False); using ``<=`` here would break that contract.
        """
        return self.loss < other.loss

def extend(grid: Grid, entry: Entry, straight_range: tuple[int, int]) -> list[Entry]:
    """Return the states reachable from *entry* by one straight run plus a turn.

    Starting at ``entry.pos`` the crucible moves ``step`` cells along
    ``entry.dir`` for every ``step`` in ``range(*straight_range)``, then turns
    right or left. Each stopping point yields two successor entries (one per
    turn) whose ``loss`` includes the cost of every cell entered on the way.

    Args:
        grid: Cost grid indexed by complex position; ``in`` tests membership.
        entry: The state to expand.
        straight_range: ``(min_steps, max_steps + 1)``, as passed to ``range``.

    Returns:
        The successor entries; empty when fewer than ``min_steps`` cells fit
        before the edge of the grid.
    """
    min_steps, stop = straight_range
    turns = (TURN_RIGHT[entry.dir], TURN_LEFT[entry.dir])
    successors: list[Entry] = []
    pos = entry.pos
    loss = entry.loss
    for step in range(1, stop):
        pos += entry.dir
        if pos not in grid:
            break  # ran off the edge; every longer run is off-grid as well
        loss += grid[pos]
        if step < min_steps:
            continue  # must keep going straight before it may stop and turn
        # Record both turns even when the cell in that direction is off the grid:
        # the stop itself may be the goal (e.g. on a one-row grid), and an
        # off-grid heading simply yields no successors when expanded later.
        successors.extend(Entry(loss, pos, turn, entry.path + [pos]) for turn in turns)
    return successors


def solve(data: str, straight_range: tuple[int, int]) -> 'int | None':
    """Return the minimum total loss from the top-left to the bottom-right cell.

    Runs Dijkstra's algorithm (lazy deletion via a visited set) over states
    ``(position, next direction)``; ``extend`` generates each state's successors.

    Args:
        data: Puzzle text, one row of single-digit costs per line.
        straight_range: ``(min_steps, max_steps + 1)`` for one straight run,
            e.g. ``(1, 4)`` for 1-3 steps or ``(4, 11)`` for 4-10 steps.

    Returns:
        The lowest total loss, or ``None`` if the goal cannot be reached.
    """
    grid = Grid.from_str(data, parse=int)
    start_pos = 0j
    end_pos = complex(grid.width - 1, grid.height - 1)
    if end_pos == start_pos:
        # The start cell's own cost is never counted, so a 1x1 grid costs nothing.
        return 0

    # The crucible may set off rightwards *or* downwards. Seeding RIGHT plus both
    # of its perpendiculars covers "down" without assuming the sign of the y-axis;
    # the perpendicular that points off the grid produces no successors.
    # Pushing one entry at a time keeps `open_list` a valid heap from the start.
    open_list: list[Entry] = []
    for start_dir in (RIGHT, TURN_RIGHT[RIGHT], TURN_LEFT[RIGHT]):
        for successor in extend(grid, Entry(0, start_pos, start_dir, []), straight_range):
            heappush(open_list, successor)

    visited = set()
    while open_list:
        entry = heappop(open_list)

        state = (entry.pos, entry.dir)
        if state in visited:
            continue  # a cheaper entry for this state was already expanded
        visited.add(state)

        if entry.pos == end_pos:
            return entry.loss

        for successor in extend(grid, entry, straight_range):
            heappush(open_list, successor)

    return None

# Part 1: straight runs of 1-3 cells between turns.
solve(data=data, straight_range=(1, 4))

Entry(loss=847, pos=(140+140j), dir=-1j, path=[(1+0j), (1+2j), 2j, 5j, (1+5j), (1+8j), (2+8j), (2+11j), (3+11j), (3+13j), (2+13j), (2+14j), (1+14j), (1+17j), 17j, 20j, (1+20j), (1+21j), 21j, 24j, (1+24j), (1+26j), (2+26j), (2+29j), (1+29j), (1+31j), 31j, 34j, (1+34j), (1+35j), (2+35j), (2+37j), (3+37j), (3+40j), (2+40j), (2+43j), (1+43j), (1+45j), (2+45j), (2+47j), (3+47j), (3+49j), (2+49j), (2+50j), (1+50j), (1+53j), 53j, 55j, (1+55j), (1+58j), (2+58j), (2+61j), (3+61j), (3+64j), (4+64j), (4+67j), (5+67j), (5+70j), (3+70j), (3+73j), (2+73j), (2+76j), (1+76j), (1+79j), 79j, 82j, (1+82j), (1+84j), 84j, 87j, (1+87j), (1+90j), (2+90j), (2+93j), (3+93j), (3+96j), (4+96j), (4+99j), (5+99j), (5+101j), (6+101j), (6+104j), (8+104j), (8+106j), (9+106j), (9+107j), (10+107j), (10+110j), (11+110j), (11+112j), (12+112j), (12+115j), (13+115j), (13+118j), (15+118j), (15+121j), (16+121j), (16+124j), (17+124j), (17+126j), (19+126j), (19+128j), (22+128j), (22+129j), (23+129j), (23+132j), (26+132j), (26+

In [245]:
# Part 2: straight runs of 4-10 cells between turns.
solve(data=data, straight_range=(4, 11))

Entry(loss=997, pos=(140+140j), dir=(-1+0j), path=[(6+0j), (6+4j), (15+4j), (15+8j), (19+8j), (19+3j), (27+3j), (27+7j), (34+7j), (34+1j), (43+1j), (43+5j), (52+5j), (52+1j), (61+1j), (61+5j), (69+5j), (69+9j), (79+9j), (79+4j), (89+4j), (89+0j), (99+0j), (99+4j), (107+4j), (107+8j), (116+8j), (116+13j), (122+13j), (122+17j), (126+17j), (126+25j), (131+25j), (131+34j), (136+34j), (136+43j), (140+43j), (140+53j), (136+53j), (136+63j), (140+63j), (140+73j), (134+73j), (134+83j), (138+83j), (138+92j), (134+92j), (134+101j), (138+101j), (138+111j), (134+111j), (134+116j), (139+116j), (139+124j), (135+124j), (135+134j), (140+134j), (140+140j)])