In [35]:
from utils import read_lines
import re
from collections import deque

RE = re.compile(r'Valve (\w\w) .*rate=(\d+);.* valves? (.*+)')

def parse_line(line):
    m = RE.match(line)
    node = m.group(1)
    rate = int(m.group(2))
    dest = m.group(3).split(', ')
    return node, rate, dest

def parse_input(input_file):
    lines = read_lines(input_file)
    rates = {}
    graph = {}
    for line in lines:
        node, rate, dest = parse_line(line)
        rates[node] = rate
        graph[node] = dest
    return rates, graph

def shorted_path(graph):
    ans = {}
    for start in graph:
        dist = {}
        q = deque([start])
        steps = 0
        while q and len(dist) < len(graph):
            cur_len = len(q)
            for _ in range(cur_len):
                node = q.popleft()
                dist[node] = steps
                for next_node in graph[node]:
                    if next_node not in dist:
                        q.append(next_node)
            steps += 1
        ans[start] = dist
    return ans

In [82]:
from functools import cache

def part1(input_file):
    rates, graph = parse_input(input_file)
    min_moves = shorted_path(graph)
    # print(rates)
    # print(graph)
    # print(min_moves)
    value_nodes = {x for x, v in rates.items() if v > 0}
    ans = 0
    
    def backtrack(cur_node, opened_value, turn, cur_score):
        nonlocal ans
        if turn >= 30:
            ans = max(ans, cur_score)
            return
        if not value_nodes:
            ans = max(ans, cur_score + opened_value * (30 - turn))
            return
        for node in value_nodes:
            to_open =  min_moves[cur_node][node] + 1
            if turn + to_open < 30:
                value_nodes.remove(node)
                cur_score += opened_value * to_open
                opened_value += rates[node]
                backtrack(node, opened_value, turn+to_open, cur_score)
                value_nodes.add(node)
                opened_value -= rates[node]
                cur_score -= opened_value * to_open
            else:
                ans = max(ans, cur_score + opened_value * (30 - turn))


    backtrack('AA', 0, 0, 0)
    return ans


    

In [83]:
part1('inputs/day16.txt')

1460

In [84]:
part1('inputs/day16_test.txt')

1651

In [103]:
def part2(input_file):
    rates, graph = parse_input(input_file)
    min_moves = shorted_path(graph)
    # print(rates)
    # print(graph)
    # print(min_moves)
    value_nodes = {x for x, v in rates.items() if v > 0}
    MAX_TURN = 26
    best_score = 0
    remain_valves = set() 
    def backtrack(cur_node, opened_value, turn, cur_score):
        nonlocal best_score
        nonlocal remain_valves
        if turn >= MAX_TURN:
            if cur_score > best_score:
                best_score = cur_score
                remain_valves = {x for x in value_nodes}
            return
        if not value_nodes:
            best_score = max(best_score, cur_score + opened_value * (MAX_TURN - turn))
            remain_valves = set()
            return
        for node in value_nodes:
            to_open =  min_moves[cur_node][node] + 1
            if turn + to_open < MAX_TURN:
                value_nodes.remove(node)
                cur_score += opened_value * to_open
                opened_value += rates[node]
                backtrack(node, opened_value, turn+to_open, cur_score)
                value_nodes.add(node)
                opened_value -= rates[node]
                cur_score -= opened_value * to_open
            else:
                best_score = max(best_score, cur_score + opened_value * (MAX_TURN - turn))
                remain_valves = {x for x in value_nodes}


    backtrack('AA', 0, 0, 0)
    return best_score, remain_valves

In [104]:
part2('inputs/day16.txt')

(1125, {'AM', 'DB', 'FP', 'GU', 'KT', 'OG', 'OX', 'XN', 'YK', 'YS', 'ZR'})