https://adventofcode.com/2021/day/15

In [3]:
%%time

from dataclasses import dataclass
from heapq import heappush, heappop

import numpy as np

# Read the whole puzzle input as one string; it is parsed later by load_data.
with open('data/15.txt') as fh:
    data = fh.read()

@dataclass
class Pt:
    """One cell of the risk grid, doubling as a node for the shortest-path search.

    Fields (constructor arguments, in order):
        x, y:    column and row of the cell.
        weight:  risk level of entering this cell.
        graph:   the shared ``{(x, y): Pt}`` mapping this cell belongs to;
                 used to look up neighbours.

    ``total`` (best known path cost, ``np.inf`` until reached) and ``prev``
    (predecessor cell on that path, ``None`` until set) are plain class
    attributes rather than dataclass fields, so they are not constructor
    arguments; assigning them on an instance shadows the class default.

    Ordering (``<``) and equality (``==``) look at ``total`` only, which is
    what lets ``heapq`` rank cells by accumulated path cost.
    """
    x: int
    y: int
    weight: int
    graph: dict
    # Deliberately not annotated, so they stay out of the generated __init__.
    total = np.inf
    prev = None

    def __lt__(self, other):
        # Rank cells by accumulated path cost (this is what heapq compares).
        return self.total < other.total

    def __eq__(self, other):
        # NOTE(review): equality is by path cost, not by position -- two
        # different cells with equal totals compare equal.
        return self.total == other.total

    def __repr__(self):
        # Hand-written repr: leaves out ``graph`` (the whole grid) and shows
        # the predecessor only by its coordinates.
        shown = (self.x, self.y, self.weight, self.total, self.prevcoord)
        return 'Pt(x=%s, y=%s, weight=%s, total=%s, prev=%s)' % shown

    @property
    def coord(self):
        """The ``(x, y)`` key of this cell."""
        return self.x, self.y

    @property
    def prevcoord(self):
        """Coordinates of the predecessor cell, or ``None`` when there is none."""
        return None if self.prev is None else self.prev.coord

    @property
    def neighbors(self):
        """Yield the orthogonally adjacent cells present in ``graph``.

        Directions are tried in the order right, down, left, up; positions
        missing from ``graph`` (off the edge of the grid) are skipped.
        """
        grid = self.graph
        for step_x, step_y in ((1, 0), (0, 1), (-1, 0), (0, -1)):
            try:
                cell = grid[(self.x + step_x, self.y + step_y)]
            except KeyError:
                continue
            yield cell


def load_data(data):
    """Parse the puzzle text into a grid of ``Pt`` cells.

    Each whitespace-separated token of *data* is one row (its index is ``y``)
    and each character in it one cell (its index is ``x``) whose weight is
    ``int(char)``. Every cell receives the returned dict as its ``graph`` so it
    can look up its neighbours.

    Returns:
        dict mapping ``(x, y)`` to the ``Pt`` at that position, inserted in
        row-major order.

    Raises:
        ValueError: if a character is not a digit (from ``int``).
    """
    grid = {}
    for row_no, row in enumerate(data.split()):
        grid.update(
            ((col_no, row_no), Pt(col_no, row_no, int(char), grid))
            for col_no, char in enumerate(row)
        )
    return grid

def dijkstra(start):
    """Shortest paths from *start* to every cell reachable from it.

    Works in place and returns None. Afterwards each reachable cell's
    ``total`` is the cheapest cost of a path from *start* to it (the start
    cell's own weight is not counted; ``start.total`` is set to 0), and its
    ``prev`` is its predecessor on such a path. The other cells are expected
    to still carry their initial ``total`` (as straight after ``load_data``).

    Each cell is pushed onto the heap at most once, the first time it is
    reached; the heap orders cells via ``Pt.__lt__`` (by ``total``). That is
    sound here because the cost of a step is the weight of the cell being
    entered, whichever neighbour it is entered from. With non-negative weights
    cells leave the heap in non-decreasing ``total`` order, so a cell's first
    discoverer is already its cheapest predecessor. The relaxation below
    therefore never lowers the ``total`` of a cell that is already queued, and
    no decrease-key is needed.
    """
    start.total = 0
    frontier = [start]       # heap of cells, ordered by total
    queued = {start.coord}   # coords of every cell ever pushed
    while frontier:
        cell = heappop(frontier)
        for nbr in cell.neighbors:
            via_cell = cell.total + nbr.weight
            if via_cell < nbr.total:
                nbr.total, nbr.prev = via_cell, cell
            if nbr.coord in queued:
                continue
            queued.add(nbr.coord)
            heappush(frontier, nbr)

                
# Part 1: cheapest path from the top-left cell to the bottom-right cell.
# For a rectangular grid, min(g) / max(g) are exactly those corner keys.
g = load_data(data)
dijkstra(g[min(g)])
print('part_1 =', g[max(g)].total)

# Part 2

def wrap9(n):
    """Fold a risk level that has gone past 9 back into the range 1..9.

    Values of 9 or less are returned unchanged. Larger values have 9 taken
    off as many times as it takes to reach 9 or below, e.g. 10 -> 1,
    13 -> 4, 18 -> 9, 19 -> 1.
    """
    if n <= 9:
        return n
    # ceil((n - 9) / 9): how many subtractions of 9 bring n down to <= 9.
    steps = -((9 - n) // 9)
    return n - 9 * steps

def embiggen(g, tiles=5):
    """Expand the risk grid *g* in place into a ``tiles`` x ``tiles`` mosaic.

    The original grid is the top-left tile. In the tile ``i`` steps to the
    right and ``j`` steps down, a cell of original weight ``w`` gets weight
    ``wrap9(w + i + j)``, i.e. values past 9 wrap round to 1. New cells are
    ``Pt`` objects that share *g* as their ``graph``.

    Args:
        g: ``{(x, y): Pt}`` grid as produced by ``load_data``; mutated in place.
        tiles: number of tiles along each axis (the puzzle uses 5, the
            default). Values below 2 leave *g* unchanged.

    Returns:
        None.

    Raises:
        ValueError: if *g* is empty (from ``max``).
    """
    # Measure x and y independently: max(g) is the lexicographically largest
    # key, which equals (max x, max y) only when the grid is rectangular.
    width = max(x for x, _ in g) + 1
    height = max(y for _, y in g) + 1
    # Widen first: a strip of tiles to the right of the original...
    for pt in list(g.values()):
        for d in range(1, tiles):
            newpt = Pt(pt.x + d * width, pt.y, wrap9(pt.weight + d), g)
            g[newpt.coord] = newpt
    # ...then copy that whole strip downward, adding one more per tile row,
    # so tile (i, j) ends up at wrap9(w + i + j).
    for pt in list(g.values()):
        for d in range(1, tiles):
            newpt = Pt(pt.x, pt.y + d * height, wrap9(pt.weight + d), g)
            g[newpt.coord] = newpt


# Part 2: the same corner-to-corner search on the grid enlarged to 5x5 tiles.
g = load_data(data)
embiggen(g)
dijkstra(g[min(g)])
print('part_2 =', g[max(g)].total)


part_1 = 472
part_2 = 2851
CPU times: user 1.91 s, sys: 28.3 ms, total: 1.94 s
Wall time: 1.94 s
