In [2]:
from collections import deque
from dataclasses import dataclass, field
from itertools import combinations, permutations
import json
from typing import Optional
import re

from pyprojroot import here

In [3]:
@dataclass
class Valve:
    """One valve (graph node) of the tunnel network.

    Attributes:
        valveId: Identifier of the valve; the driver cell uses it as the key
            of the graph dictionary.
        flowRate: Amount added to the per-minute pressure release once the
            valve has been opened.
        childrenIds: Identifiers of the valves directly connected to this
            one (each connection counts as one step when measuring distance).
    """
    valveId: str
    flowRate: int
    childrenIds: list[str]

In [4]:
def calcDistance(graph: dict[str, Valve], valve1: Valve, valve2: Valve) -> int:
    """Return the length of the shortest route between two valves.

    Runs a breadth-first search along ``childrenIds``. One visited set is
    shared by the whole search, so every valve is expanded at most once and
    the cost stays linear in the size of the graph (tracking visited valves
    per path instead would enumerate every simple path, which is exponential
    on graphs with cycles).

    Args:
        graph: Mapping from valve id to its ``Valve``; used to resolve the
            ids listed in ``childrenIds``.
        valve1: Valve the search starts from.
        valve2: Valve to reach; only its ``valveId`` is consulted.

    Returns:
        ``0`` when both valves share the same id, the number of steps on a
        shortest route otherwise, or ``-1`` when ``valve2`` cannot be reached.

    Raises:
        KeyError: If a valve that gets expanded lists a child id that is
            missing from ``graph``.
    """
    targetId = valve2.valveId

    # valve1 and valve2 are the same
    if valve1.valveId == targetId:
        return 0

    visitedIds = {valve1.valveId}
    queue = deque([(valve1, 0)])
    while queue:
        valve, distanceTraveled = queue.popleft()

        # directly connected: BFS pops valves in order of distance, so the
        # first hit is a shortest route
        if targetId in valve.childrenIds:
            return distanceTraveled + 1

        # indirectly connected: enqueue each not-yet-seen neighbour once
        for childId in valve.childrenIds:
            if childId not in visitedIds:
                visitedIds.add(childId)
                queue.append((graph[childId], distanceTraveled + 1))

    return -1


def calcDistances(graph: dict[str, Valve]) -> dict[str, dict[str, int]]:
    """Build the all-pairs distance table for the valve graph.

    Args:
        graph: Mapping from valve id to its ``Valve``.

    Returns:
        A nested mapping where ``result[a][b]`` is ``calcDistance`` from the
        valve with id ``a`` to the valve with id ``b``, for every ordered
        pair of valves in ``graph`` (including each valve with itself).
    """
    valves = list(graph.values())
    return {
        source.valveId: {
            target.valveId: calcDistance(graph, source, target)
            for target in valves
        }
        for source in valves
    }

In [5]:
def releaseMaxPressure(
    targetGraph: dict[str, Valve],
    distances: dict[str, dict[str, int]],
    minutesAllowed: int,
    startingValveId: str
    ) -> int:
    """Return the most pressure that can be released within the time limit.

    Explores every order of opening the valves in ``targetGraph`` with a
    depth-first search. Moving between two valves costs ``distances[a][b]``
    minutes and opening a valve costs one more minute; a valve opened with
    ``m`` minutes left contributes ``m * flowRate``. That contribution is
    credited in full at the moment the valve is opened, so every visited
    state already carries its final total. This also covers routes that
    finish opening all valves before the time runs out.

    Args:
        targetGraph: Valves worth opening, keyed by valve id.
        distances: ``distances[a][b]`` is the travel time in minutes from
            valve ``a`` to valve ``b``. Negative entries (``calcDistance``
            returns ``-1`` for unreachable valves) are skipped.
        minutesAllowed: Total number of minutes available.
        startingValveId: Id of the valve where the route begins.

    Returns:
        The maximum total pressure released, ``0`` if no valve can be opened
        in time to release anything.
    """
    maxPressureReleased = 0

    # Each entry: (current valve id, minutes left, ids still closed,
    # total pressure already guaranteed by the valves opened so far).
    stack = [(startingValveId, minutesAllowed, frozenset(targetGraph), 0)]

    while stack:
        currentId, minutesLeft, closedIds, pressureReleased = stack.pop()
        maxPressureReleased = max(maxPressureReleased, pressureReleased)

        for valveId in closedIds:
            travelTime = distances[currentId][valveId]

            # unreachable valve
            if travelTime < 0:
                continue

            # travelTime minutes of walking plus one minute to open the valve
            minutesAfterOpening = minutesLeft - travelTime - 1

            # opening it this late would not release anything
            if minutesAfterOpening <= 0:
                continue

            stack.append((
                valveId,
                minutesAfterOpening,
                closedIds - {valveId},
                pressureReleased + minutesAfterOpening * targetGraph[valveId].flowRate,
            ))

    return maxPressureReleased

In [6]:
# Read the puzzle input, located relative to the project root.
path = here('./16/input.txt')
with open(path, 'r') as fp:
    lines = fp.readlines()

# Per line: the first pattern captures two-uppercase-letter ids and digit runs
# that are NOT followed by a comma or the end of the line (expected to yield
# the valve id and its flow rate); the second captures the ids that ARE
# followed by a comma or the end of the line (the connected valves).
# Raw strings keep `\d` from being an invalid string escape; blank lines are
# skipped because they would not unpack into three fields below.
data = [
    (*re.findall(r'([A-Z]{2}|\d+)(?!,|$)', line), re.findall(r'([A-Z]{2})(?=,|$)', line))
    for line in lines
    if line.strip()
]
valvesGraph = {valveId: Valve(valveId, int(flowRate), childrenIds) for valveId, flowRate, childrenIds in data}
distances = calcDistances(valvesGraph)
startingValveId = 'AA'

# Only valves with a positive flow rate are worth travelling to and opening.
targetGraph = {valveId: valve for valveId, valve in valvesGraph.items() if valve.flowRate > 0}

result = releaseMaxPressure(targetGraph, distances, 30, startingValveId)
print(result)

1737
