In [5]:
from utils import read_lines
import re
from collections import deque

# Matches e.g. "Valve AA has flow rate=0; tunnels lead to valves DD, II, BB"
# as well as the singular "tunnel leads to valve GG" form.
RE = re.compile(r'Valve (\w\w) .*rate=(\d+);.* valves? (.*)')

def parse_line(line):
    """Parse one puzzle line into ``(valve, flow_rate, neighbour_list)``."""
    valve, rate_text, targets = RE.match(line).groups()
    return valve, int(rate_text), targets.split(', ')

def parse_input(input_file):
    """Read the puzzle file and build the valve tables.

    Returns ``(rates, graph)``: ``rates`` maps valve -> flow rate and
    ``graph`` maps valve -> list of directly connected valves.
    """
    parsed = [parse_line(text) for text in read_lines(input_file)]
    rates = {valve: rate for valve, rate, _ in parsed}
    graph = {valve: targets for valve, _, targets in parsed}
    return rates, graph

def shorted_path(graph):
    """All-pairs shortest path lengths (in moves) of an unweighted graph.

    The name is a historical typo for "shortest"; it is kept so callers still work.

    Args:
        graph: mapping node -> iterable of neighbouring nodes.

    Returns:
        ``{start: {node: distance}}`` for every ``start`` in ``graph``. Nodes
        not reachable from ``start`` are absent from its inner dict.
    """
    ans = {}
    for start in graph:
        dist = {start: 0}
        q = deque([start])
        while q:
            node = q.popleft()
            # A node that only appears as someone's neighbour has no edges.
            for next_node in graph.get(node, ()):
                # Record the distance when a node is first *discovered*: in BFS
                # order that is the shortest one. Recording it only on dequeue
                # let a node be queued twice and be overwritten with a longer
                # distance.
                if next_node not in dist:
                    dist[next_node] = dist[node] + 1
                    q.append(next_node)
        ans[start] = dist
    return ans

In [2]:
from functools import cache

def part1(input_file, max_turn=30):
    """Part 1: most pressure one agent can release, starting at 'AA' at minute 0.

    Backtracks over the order in which to open the valves that have a positive
    flow rate, walking between them along the precomputed shortest paths.

    Args:
        input_file: path of the puzzle input.
        max_turn: length of the time window in minutes (the puzzle uses 30).

    Returns:
        The maximum total pressure released within ``max_turn`` minutes.
    """
    rates, graph = parse_input(input_file)
    min_moves = shorted_path(graph)
    value_nodes = {x for x, v in rates.items() if v > 0}
    ans = 0

    def backtrack(cur_node, opened_value, turn, cur_score):
        # cur_score: pressure released up to minute `turn`.
        # opened_value: combined flow per minute of the valves opened so far.
        nonlocal ans
        # Stopping here lets the open valves run until the end. Any extension
        # only adds to this, so recording it at every step is safe.
        ans = max(ans, cur_score + opened_value * (max_turn - turn))
        # Loop over a snapshot: value_nodes is removed from and re-added to
        # inside the loop, and mutating a set while iterating it is unsafe.
        for node in list(value_nodes):
            moves = min_moves[cur_node].get(node)
            if moves is None:  # unreachable from cur_node
                continue
            to_open = moves + 1  # walk there, then one minute to open it
            if turn + to_open < max_turn:
                value_nodes.remove(node)
                backtrack(node, opened_value + rates[node], turn + to_open,
                          cur_score + opened_value * to_open)
                value_nodes.add(node)

    backtrack('AA', 0, 0, 0)
    return ans

In [6]:
part1(input_file='inputs/day16.txt')

1460

In [7]:
part1(input_file='inputs/day16_test.txt')

1651

In [22]:
def max_score(rates, min_moves, nodes, max_turn):
    """Best pressure one agent can release by opening valves from ``nodes``.

    The agent starts at 'AA' at minute 0 and walks along the ``min_moves``
    shortest paths. It spends one minute opening each valve. A valve opened
    at minute ``d`` releases ``rates[valve] * (max_turn - d)`` in total. The
    start valve is not treated as already open; the previous version credited
    it with ``rates['AA'] * max_turn`` for free.

    Args:
        rates: valve -> flow rate.
        min_moves: ``{src: {dst: moves}}`` shortest path lengths.
        nodes: set of valves the agent may open. It is temporarily mutated
            during the search and restored before returning.
        max_turn: length of the time window in minutes.

    Returns:
        The maximum total pressure released (0 if nothing can be opened).
    """
    ans = 0

    def backtrack(cur_node, day, score):
        # score: total pressure the valves opened so far will release by the
        # end of the window.
        nonlocal ans
        ans = max(ans, score)
        # Snapshot: `nodes` is removed from and re-added to inside the loop,
        # and mutating a set while iterating it is unsafe.
        for node in list(nodes):
            moves = min_moves[cur_node].get(node)
            if moves is None:  # unreachable from cur_node
                continue
            open_day = day + moves + 1
            if open_day < max_turn:
                nodes.remove(node)
                backtrack(node, open_day, score + rates[node] * (max_turn - open_day))
                nodes.add(node)

    backtrack('AA', 0, 0)
    return ans

def part1_1(input_file):
    """Part 1 via the generic single-agent search: 30 minutes from 'AA'."""
    rates, graph = parse_input(input_file)
    useful_valves = {valve for valve, rate in rates.items() if rate > 0}
    return max_score(rates, shorted_path(graph), useful_valves, 30)

def part2(input_file):
    """Part 2 by exhaustive two-agent search: you and the elephant, 26 minutes.

    Both agents start at 'AA' at minute 0. At each step, the agent that
    becomes free earliest does one of two things:

    - walks to and opens a remaining valve, or
    - retires for the rest of the game, after which the other agent carries
      on alone.

    This covers every pair of disjoint routes, so the result is exact. A
    simple upper bound prunes branches that cannot beat the best score found
    so far.

    The previous version had two gaps that could miss the optimum:

    - it stopped as soon as the earlier agent had no reachable valve, without
      letting the other agent continue;
    - it never let an agent stop voluntarily.
    """
    rates, graph = parse_input(input_file)
    min_moves = shorted_path(graph)
    nodes = {x for x, v in rates.items() if v > 0}
    max_turn = 26
    ans = 0

    def backtrack(mover, other, other_retired, score):
        # mover / other: (current valve, minute that agent becomes free).
        # `mover` is the agent free earliest, or the only active one once
        # `other_retired` is True.
        # score: total pressure the valves opened so far will release.
        nonlocal ans
        ans = max(ans, score)
        cur_node, day = mover
        # No active agent is free before `day`, so no remaining valve can open
        # before minute day + 1. This bounds what is still gainable.
        if score + sum(rates[n] for n in nodes) * (max_turn - day - 1) <= ans:
            return
        # Snapshot: nodes is mutated inside the loop.
        for node in list(nodes):
            moves = min_moves[cur_node].get(node)
            if moves is None:  # unreachable from cur_node
                continue
            open_day = day + moves + 1
            if open_day >= max_turn:
                continue
            gained = score + rates[node] * (max_turn - open_day)
            nodes.remove(node)
            if other_retired or open_day < other[1]:
                backtrack((node, open_day), other, other_retired, gained)
            else:
                # The other agent is now free first, so it moves next.
                backtrack(other, (node, open_day), False, gained)
            nodes.add(node)
        if not other_retired:
            # The mover stops here for good; the other agent continues alone.
            backtrack(other, mover, True, score)

    backtrack(('AA', 0), ('AA', 0), False, 0)
    return ans

In [24]:
part1_1(input_file='inputs/day16_test.txt')

1651

In [33]:
from itertools import combinations
def _best_score_by_valve_set(rates, min_moves, nodes, max_turn, start='AA'):
    """Best single-agent pressure for each set of valves it can open in time.

    Runs one DFS over every feasible opening order, starting at ``start`` at
    minute 0. Returns ``{bitmask: best_score}``, where bit ``i`` of the mask
    stands for ``nodes[i]``. The empty set (mask 0, score 0) is included.
    """
    bit_of = {node: 1 << i for i, node in enumerate(nodes)}
    best = {}

    def dfs(cur_node, day, mask, score):
        if score > best.get(mask, -1):
            best[mask] = score
        for node in nodes:
            bit = bit_of[node]
            moves = min_moves[cur_node].get(node)
            if mask & bit or moves is None:  # already open, or unreachable
                continue
            open_day = day + moves + 1
            if open_day < max_turn:
                dfs(node, open_day, mask | bit, score + rates[node] * (max_turn - open_day))

    dfs(start, 0, 0, 0)
    return best


def part2_2(input_file):
    """Part 2: you and the elephant each get 26 minutes, starting from 'AA'.

    The two agents open disjoint sets of valves, so the answer is the best
    sum, over two disjoint valve sets, of their best single-agent scores.

    The earlier version only tried splits with exactly ``len(nodes) // 2``
    valves on one side. It could miss the optimum when one agent's best route
    opens more than ``len(nodes) - len(nodes) // 2`` valves.
    """
    rates, graph = parse_input(input_file)
    min_moves = shorted_path(graph)
    nodes = [x for x, v in rates.items() if v > 0]
    max_turn = 26

    best = _best_score_by_valve_set(rates, min_moves, nodes, max_turn)
    # Highest scores first, so both loops can stop once no pair can beat `ans`.
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    ans = 0
    for i, (mask1, score1) in enumerate(ranked):
        if 2 * score1 <= ans:
            break
        for j in range(i, len(ranked)):
            mask2, score2 = ranked[j]
            if score1 + score2 <= ans:
                break
            if (mask1 & mask2) == 0:  # disjoint valve sets
                ans = score1 + score2
    return ans

In [35]:
part2_2(input_file='inputs/day16.txt')

2117