In [314]:
from tqdm import tqdm
import numpy as np
import heapq as heap

# Parse the digit grid: every character of every line becomes one integer cell.
with open("data.txt") as file:
    data = np.array([[int(ch) for ch in line] for line in file.read().splitlines()])

dim = data.shape

# Replicate the grid along leading state axes so that a full search state can
# index the entry cost directly:
#   traffic_map   -> shape (4, 3, *dim)      indexed by (dir, steps, y, x)
#   traffic_map_2 -> shape (2, 4, 11, *dim)  indexed by (crazy, dir, steps, y, x)
traffic_map = np.tile(data, (4, 3, 1, 1))
traffic_map_2 = np.tile(data, (2, 4, 11, 1, 1))

In [315]:
dirs = [(0,1), (0,-1), (1, 0), (-1,0)]


def get_neighbours(dir, steps, y, x):
    """Return the set of states reachable in one move from ``(dir, steps, y, x)``.

    ``dir`` is an index into ``dirs`` (each entry is a ``(dy, dx)`` offset) and
    ``steps`` is the straight-run counter carried by the state.

    * Going straight is offered only while ``steps < 2`` and increments the
      counter.
    * The two perpendicular turns are always offered and reset the counter
      to 0.

    Candidates whose ``(y, x)`` lies outside the module-level ``dim`` are
    discarded.
    """
    dy, dx = dirs[dir]

    candidates = []
    if steps < 2:
        candidates.append((dir, steps + 1, y + dy, x + dx))

    # Swapping the offset components (and negating the swap) yields the two
    # headings perpendicular to the current one.
    for turn_dy, turn_dx in ((dx, dy), (-dx, -dy)):
        candidates.append((dirs.index((turn_dy, turn_dx)), 0, y + turn_dy, x + turn_dx))

    height, width = dim[0], dim[1]
    return {
        state
        for state in candidates
        if 0 <= state[2] < height and 0 <= state[3] < width
    }

def search(data, neighbour_func, starts=((0, 0, 0, 0), (2, 0, 0, 0))):
    """Dijkstra search over the state space described by ``neighbour_func``.

    Parameters
    ----------
    data : numpy.ndarray
        Cost array indexed by a full state tuple; ``data[state]`` is the cost
        paid when moving *into* ``state``.
    neighbour_func : callable
        Called as ``neighbour_func(*state)``; must return an iterable of
        successor state tuples that are valid indices into ``data``.
    starts : iterable of tuples, optional
        States seeded with cost 0. Defaults to the two start states that were
        previously hard-coded in this function.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``data`` holding the cheapest cost found for every
        state; states that were never reached stay at ``inf``.

    Notes
    -----
    A state is final the first time it is popped, which assumes the costs in
    ``data`` are non-negative (TODO confirm for inputs other than digit grids).
    """
    costs = np.full(data.shape, np.inf)
    visited = set()
    pq = []

    for start in starts:
        start = tuple(start)  # a tuple indexes a single element and is hashable
        costs[start] = 0
        heap.heappush(pq, (0, start))

    while pq:
        _, node = heap.heappop(pq)

        # A state can sit in the queue several times (once per improvement).
        # Only its first, cheapest pop matters; skip the stale duplicates
        # instead of re-expanding them.
        if node in visited:
            continue
        visited.add(node)

        for adj_node in neighbour_func(*node):
            if adj_node in visited:
                continue

            new_cost = costs[node] + data[adj_node]
            if new_cost < costs[adj_node]:
                costs[adj_node] = new_cost
                heap.heappush(pq, (new_cost, adj_node))

    return costs

# Run the search over the (dir, steps, y, x) state space of `traffic_map`.
c = search(traffic_map, get_neighbours)
# Cheapest cost over every (dir, steps) combination at the last row / last
# column of the grid; as the cell's final expression it is displayed as output.
c[:, :, -1, -1].min()

859.0

In [322]:
def get_neighbours_2(crazy, dir, steps, y, x):
    """Return the set of states reachable in one move from the given state.

    A state is ``(crazy, dir, steps, y, x)``: ``dir`` indexes ``dirs`` and
    ``steps`` is the straight-run counter.

    * Going straight is offered while ``steps < 10``. The successor's
      ``crazy`` flag is 1 while its counter is at most 3 and 0 afterwards.
    * The two perpendicular turns are offered only when ``crazy == 0``; a
      turn yields ``crazy = 1`` and ``steps = 1``.

    Candidates whose ``(y, x)`` lies outside the module-level ``dim`` are
    discarded.
    """
    dy, dx = dirs[dir]
    height, width = dim[0], dim[1]

    def inside(state):
        return 0 <= state[3] < height and 0 <= state[4] < width

    candidates = []

    if steps < 10:
        run_length = steps + 1
        flag = 1 if run_length <= 3 else 0
        candidates.append((flag, dir, run_length, y + dy, x + dx))

    if crazy == 0:
        # Swapping the offset components (and negating the swap) yields the
        # two headings perpendicular to the current one.
        for turn_dy, turn_dx in ((dx, dy), (-dx, -dy)):
            candidates.append(
                (1, dirs.index((turn_dy, turn_dx)), 1, y + turn_dy, x + turn_dx)
            )

    return set(filter(inside, candidates))

def search(data, neighbour_func, starts=((1, 0, 0, 0, 0), (1, 2, 0, 0, 0))):
    """Dijkstra search over the state space described by ``neighbour_func``.

    Parameters
    ----------
    data : numpy.ndarray
        Cost array indexed by a full state tuple; ``data[state]`` is the cost
        paid when moving *into* ``state``.
    neighbour_func : callable
        Called as ``neighbour_func(*state)``; must return an iterable of
        successor state tuples that are valid indices into ``data``.
    starts : iterable of tuples, optional
        States seeded with cost 0. Defaults to the two start states that were
        previously hard-coded in this function.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``data`` holding the cheapest cost found for every
        state; states that were never reached stay at ``inf``.

    Notes
    -----
    A state is final the first time it is popped, which assumes the costs in
    ``data`` are non-negative (TODO confirm for inputs other than digit grids).
    """
    costs = np.full(data.shape, np.inf)
    visited = set()
    pq = []

    for start in starts:
        start = tuple(start)  # a tuple indexes a single element and is hashable
        costs[start] = 0
        heap.heappush(pq, (0, start))

    while pq:
        _, node = heap.heappop(pq)

        # A state can sit in the queue several times (once per improvement).
        # Only its first, cheapest pop matters; skip the stale duplicates
        # instead of re-expanding them.
        if node in visited:
            continue
        visited.add(node)

        for adj_node in neighbour_func(*node):
            if adj_node in visited:
                continue

            new_cost = costs[node] + data[adj_node]
            if new_cost < costs[adj_node]:
                costs[adj_node] = new_cost
                heap.heappush(pq, (new_cost, adj_node))

    return costs

# Run the search over the (crazy, dir, steps, y, x) state space of `traffic_map_2`.
c = search(traffic_map_2, get_neighbours_2)
# Cheapest cost at the last row / last column, restricted to states whose
# `crazy` flag is 0 (get_neighbours_2 clears it once the straight-run counter
# exceeds 3); as the cell's final expression it is displayed as output.
c[0, :, :,  -1, -1].min()

1027.0