In [1]:
from collections import deque
from dataclasses import dataclass, field
from itertools import combinations, permutations
import json
from typing import Optional
import re

from pyprojroot import here

In [2]:
testInput = [
    'Valve AA has flow rate=0; tunnels lead to valves DD, II, BB',
    'Valve BB has flow rate=13; tunnels lead to valves CC, AA',
    'Valve CC has flow rate=2; tunnels lead to valves DD, BB,',
    'Valve DD has flow rate=20; tunnels lead to valves CC, AA, EE',
    'Valve EE has flow rate=3; tunnels lead to valves FF, DD',
    'Valve FF has flow rate=0; tunnels lead to valves EE, GG',
    'Valve GG has flow rate=0; tunnels lead to valves FF, HH',
    'Valve HH has flow rate=22; tunnel leads to valve GG',
    'Valve II has flow rate=0; tunnels lead to valves AA, JJ',
    'Valve JJ has flow rate=21; tunnel leads to valve II'
]

In [3]:
(13 + 2 + 20 + 3 + 22 + 21) * 26

2106

In [4]:
@dataclass
class Valve:
    valveId: str
    flowRate: int
    childrenIds: list[str]

In [5]:
def calcDistance(graph: dict[str, Valve], valve1: Valve, valve2: Valve) -> int:

    @dataclass
    class Node:
        valve: Valve
        distanceTraveled: int = 0
        valvesVisited: set[str] = field(default_factory=set)

    queue = deque([Node(valve1)])
    while queue:
        node = queue.popleft()

        # valve1 and valve 2 are the same
        if node.valve.valveId == valve2.valveId:
            return 0

        # cycled
        if node.valve.valveId in node.valvesVisited:
            continue

        # directly connected
        if valve2.valveId in node.valve.childrenIds:
            return node.distanceTraveled + 1

        # indirectly connected
        for child in node.valve.childrenIds:
            childNode = Node(
                valve = graph[child],
                distanceTraveled = node.distanceTraveled + 1,
                valvesVisited = node.valvesVisited.union([node.valve.valveId])
            )
            queue.append(childNode)

    return -1


def calcDistances(graph: dict[str, Valve]) -> dict[str, dict[str, int]]:
    distances = {}

    for valve1 in graph.values():
        for valve2 in graph.values():
            distances.setdefault(valve1.valveId, {})[valve2.valveId] = calcDistance(graph, valve1, valve2)

    return distances

In [30]:
@dataclass
class Node:
    destValveA: Valve
    destValveB: Valve
    destDistA: int
    destDistB: int
    remainingValveIds: set[str]
    flowRate: int = 0
    pressureReleased: int = 0
    currentMinute: int = 1


def releaseMaxPressure(
    graph: dict[str, Valve],
    distances: dict[str, dict[str, int]],
    minutesAllowed: int,
    startingValveId: str
    ) -> int:
        
    maxPressureReleased = 0
    remainingValveIds = set([valve.valveId for valve in graph.values()])
    valveIdCombos = combinations(remainingValveIds, 2)
    stack: list[Node] = []
    nodesVisited = 0
    maxFlowRate = 0

    for valveIdA, valveIdB in valveIdCombos:
        node = Node(
            destValveA = graph[valveIdA],
            destValveB = graph[valveIdB],
            destDistA = distances[startingValveId][valveIdA],
            destDistB = distances[startingValveId][valveIdB],
            remainingValveIds = remainingValveIds
        )
        stack.append(node)

    while stack:
        node = stack.pop()
        nodesVisited += 1

        # valves remaining
        remainingValveIds = node.remainingValveIds - set([node.destValveA.valveId, node.destValveB.valveId])

        # time is up
        if node.currentMinute == minutesAllowed + 1:
            maxPressureReleased = max(maxPressureReleased, node.pressureReleased)
            continue

        # release pressure
        node.pressureReleased += node.flowRate
        maxFlowRate = max(maxFlowRate, node.flowRate)
        node.currentMinute += 1

        # continue traveling A and continue traveling B
        if node.destDistA > 0 and node.destDistB > 0:
            node.destDistA -= 1
            node.destDistB -= 1
            stack.append(node)
        
        # continue traveling A and open valve B
        if node.destDistA > 0 and node.destDistB == 0:
            node.destDistA -= 1
            node.flowRate += node.destValveB.flowRate

            for valveId in remainingValveIds:
                child = Node(
                    destValveA = node.destValveA,
                    destValveB = graph[valveId],
                    destDistA = node.destDistA,
                    destDistB = distances[node.destValveB.valveId][valveId],
                    remainingValveIds = remainingValveIds,
                    flowRate = node.flowRate,
                    pressureReleased = node.pressureReleased,
                    currentMinute = node.currentMinute
                )
                stack.append(child)

        # open valve A and continue traveling B
        if node.destDistA == 0 and node.destDistB > 0:
            node.flowRate += node.destValveA.flowRate
            node.destDistB -= 1

            for valveId in remainingValveIds:
                child = Node(
                    destValveA = graph[valveId],
                    destValveB = node.destValveB,
                    destDistA = distances[node.destValveA.valveId][valveId],
                    destDistB = node.destDistB,
                    remainingValveIds = remainingValveIds,
                    flowRate = node.flowRate,
                    pressureReleased = node.pressureReleased,
                    currentMinute = node.currentMinute
                )
                stack.append(child)

        # open valve A and open valve B
        if node.destDistA == node.destDistB == 0:
            node.flowRate += node.destValveA.flowRate + node.destValveB.flowRate

            # there are at least two valves remaining
            if len(remainingValveIds) > 1:
                valveIdPerms = permutations(remainingValveIds, 2)
                # valveIdPerms = combinations(remainingValveIds, 2)

                for valveIdA, valveIdB in valveIdPerms:
                    child = Node(
                        destValveA = graph[valveIdA],
                        destValveB = graph[valveIdB],
                        destDistA = distances[node.destValveA.valveId][valveIdA],
                        destDistB = distances[node.destValveB.valveId][valveIdB],
                        remainingValveIds = remainingValveIds,
                        flowRate = node.flowRate,
                        pressureReleased = node.pressureReleased,
                        currentMinute = node.currentMinute
                    )
                    stack.append(node)

            # there is one valve remaining
            elif len(remainingValveIds) == 1:
                valveId = list(remainingValveIds)[0]
                node.remainingValveIds = set()
                destDistA = distances[node.destValveA.valveId][valveId]
                destDistB = distances[node.destValveB.valveId][valveId]

                # valveA is closest
                if destDistA < destDistB:
                    node.destValveA = graph[valveId]
                    node.destDistA = destDistA
                    node.destDistB -= 1
                
                # valveB is closest or as close as valveA
                else:
                    node.destValveB = graph[valveId]
                    node.destDistB = destDistB
                    node.destDistA -= 1

                stack.append(node)

            # all already valves opened
            else:
                node.destDistA -= 1
                node.destDistB -= 1
                stack.append(node)
        
        else:
            node.destDistA -= 1
            node.destDistB -= 1
            stack.append(node)
            continue
            
    return nodesVisited, maxFlowRate, maxPressureReleased

In [32]:
path = here('./16/input.txt')
with open(path, 'r') as fp:
    # lines = fp.readlines()
    lines = testInput

data = [(*re.findall('([A-Z]{2}|\d+)(?!,|$)', line), re.findall('([A-Z]{2})(?=,|$)', line)) for line in lines]
valvesGraph = {valveId: Valve(valveId, int(flowRate), childrenIds) for valveId, flowRate, childrenIds in data}
distances = calcDistances(valvesGraph)
startingValveId = 'AA'
targetGraph = {valveId: valve for valveId, valve in valvesGraph.items() if valve.flowRate > 0}

result = releaseMaxPressure(targetGraph, distances, 26, startingValveId)
print(result)

(20077, 1045, 13825)
