In [97]:
from tqdm import tqdm
import re
import numpy as np
from collections import defaultdict
import heapq as heap

# Load the whole puzzle input and keep it as a list of lines without terminators.
with open('data.txt') as file:
    contents = file.read()
data = contents.splitlines()

In [98]:
# Maps valve name -> record with its flow rate and the names of adjacent valves.
valves = {}

for line in data:
    # Every line must contain exactly one ';': the part before it holds the
    # valve name and rate, the part after it holds the tunnel list.
    head, tail = line.split(';')
    head_words = head.split(' ')

    # Tunnel targets are ', '-separated; the first chunk still carries the
    # leading words, so only its last space-separated token is kept.
    targets = tail.split(', ')
    targets[0] = targets[0].split(' ')[-1]

    label = head_words[1]
    valves[label] = {
        'open': False,  # initial flag; not read by the code visible in this notebook
        'rate': int(head_words[-1].split('=')[-1]),
        'next': targets,
        'name': label,
    }

In [99]:
# Memoised results of search(), keyed by start valve name.
dist_cache = {}

# Graph object the entries in dist_cache were computed for; a call with a
# different graph object empties the cache instead of returning stale data.
_dist_cache_graph = None

# Cost reported for valves that cannot be reached from the start valve.
_UNREACHABLE_COST = 1000


def search(valves, start):
    """Compute shortest hop counts and paths from ``start`` to every valve.

    Every tunnel costs one step. Results are memoised in ``dist_cache`` per
    start valve for the graph object passed in; passing a different graph
    object discards the memoised entries. In-place edits of the same graph
    object are not detected.

    Args:
        valves: Mapping of valve name -> record whose ``'next'`` entry lists
            the names of directly connected valves.
        start: Name of the valve to measure from.

    Returns:
        A ``(costs, paths)`` tuple. ``costs[name]`` is the number of steps
        from ``start`` (0 for ``start`` itself, 1000 for unreachable valves).
        ``paths[name]`` lists the valves visited after ``start`` up to and
        including ``name``; it is empty for ``start`` and for unreachable
        valves. The returned dicts are the cached objects themselves, so
        callers should not mutate them.
    """
    global _dist_cache_graph

    if _dist_cache_graph is not valves:
        dist_cache.clear()
        _dist_cache_graph = valves

    cached = dist_cache.get(start)
    if cached is not None:
        return cached

    # Only valves discovered so far appear in `best`, so no numeric stand-in
    # for infinity is needed and arbitrarily long routes are handled.
    best = {start: 0}
    paths = {key: [] for key in valves}
    visited = set()
    pq = [(0, start)]

    while pq:
        _, node = heap.heappop(pq)
        visited.add(node)
        path = paths[node]

        for adj_node in valves[node]['next']:
            if adj_node in visited:
                continue

            new_cost = best[node] + 1
            if adj_node not in best or best[adj_node] > new_cost:
                best[adj_node] = new_cost
                paths[adj_node] = path + [adj_node]
                heap.heappush(pq, (new_cost, adj_node))

    costs = {key: best.get(key, _UNREACHABLE_COST) for key in valves}
    dist_cache[start] = (costs, paths)

    return costs, paths

#search(valves, 'AA')

In [100]:
#%%timeit

# Memoised results of get_best_path(), keyed by (position, open valves, minutes left).
cache = {}


def get_best_path(key, open_valves, minutes_left, closed_valves):
    """Return the most pressure that can still be released from this state.

    Each minute is spent either opening the valve at ``key`` (it then adds
    its rate for every minute that remains afterwards) or taking one step
    along a shortest path, as reported by ``search``, towards a closed valve.
    Results are memoised in the module-level ``cache``.

    Args:
        key: Name of the valve currently being visited.
        open_valves: Set of valve names treated as already open.
        minutes_left: Minutes remaining, including the one about to be spent.
        closed_valves: Set of valve names that may still be opened. The cache
            key does not include it, so it is expected to be determined by
            ``open_valves`` (the driver passes the complement).

    Returns:
        The best total release achievable from this state; 0 when time is up,
        every valve is open, or no closed valve can be reached.
    """
    if minutes_left == 0 or len(valves) == len(open_valves):
        return 0

    # frozenset gives equal sets the same key no matter how they were built;
    # tuple(set) depends on iteration order, so equal states could miss.
    cache_key = (key, frozenset(open_valves), minutes_left)
    # Membership test rather than truthiness: a memoised 0 is a valid hit.
    if cache_key in cache:
        return cache[cache_key]

    time_after = minutes_left - 1
    alts = []

    # Option 1: spend this minute opening the current valve.
    if key not in open_valves:
        alts.append(
            time_after * valves[key]["rate"]
            + get_best_path(
                key,
                open_valves | {key},
                minutes_left=time_after,
                closed_valves=closed_valves - {key},
            )
        )

    # Option 2: take one step towards some closed valve. Several targets can
    # share the same first step, so each distinct step is explored once.
    _, paths = search(valves, key)
    first_steps = {paths[target][0] for target in closed_valves if paths[target]}
    for step in first_steps:
        alts.append(
            get_best_path(
                step,
                open_valves,
                minutes_left=time_after,
                closed_valves=closed_valves,
            )
        )

    # With nothing to open here and nowhere useful to go, nothing more is released.
    total_score = max(alts, default=0)
    cache[cache_key] = total_score
    return total_score


# Valves with a rate of 0 are handed to the search as already open.
zero_rates = [name for name, info in valves.items() if not info["rate"]]
# Every other valve starts out closed.
valves_left = [name for name in valves if name not in zero_rates]

# Start at "AA" with 30 minutes available.
get_best_path("AA", set(zero_rates), 30, set(valves_left))


KeyboardInterrupt: 

In [81]:
# Display the names of the valves whose rate is non-zero.
valves_left

['BB', 'CC', 'DD', 'EE', 'HH', 'JJ']

In [59]:
# Memoised results of get_best_path_2(), keyed by (positions, open valves, minutes left).
cache = {}


def get_best_path_2(key, elephant, open_valves, minutes_left):
    """Return the most pressure two actors can still release from this state.

    Every minute each actor either opens the valve it is standing at (adding
    that valve's rate for every minute that remains afterwards) or moves
    through one tunnel. All combinations of the two actions are tried, and
    results are memoised in the module-level ``cache``.

    Args:
        key: Name of the valve the first actor is at.
        elephant: Name of the valve the second actor is at.
        open_valves: Iterable of valve names treated as already open; lists,
            sets and frozensets are all accepted.
        minutes_left: Minutes remaining, including the one about to be spent.

    Returns:
        The best total release achievable from this state; 0 when time is up,
        every valve is open, or no action is available.
    """
    # Normalising once gives O(1) membership tests and a hashable,
    # order-independent cache key component.
    open_valves = frozenset(open_valves)

    if minutes_left == 0 or len(valves) == len(open_valves):
        return 0

    # The two actors are interchangeable, so their positions are stored as an
    # unordered pair.
    cache_key = (frozenset((key, elephant)), open_valves, minutes_left)
    # Membership test rather than truthiness: a memoised 0 is a valid hit.
    if cache_key in cache:
        return cache[cache_key]

    time_after = minutes_left - 1
    elephant_can_open = elephant not in open_valves
    alts = []

    if key not in open_valves:
        open_score = time_after * valves[key]["rate"]

        # Both actors open their valves (only possible at different valves).
        if elephant_can_open and elephant != key:
            alts.append(
                open_score
                + time_after * valves[elephant]["rate"]
                + get_best_path_2(
                    key, elephant, open_valves | {key, elephant}, minutes_left=time_after
                )
            )

        # First actor opens, second actor moves.
        for elephant_move in valves[elephant]["next"]:
            alts.append(
                open_score
                + get_best_path_2(
                    key, elephant_move, open_valves | {key}, minutes_left=time_after
                )
            )

    for new_valve in valves[key]["next"]:
        # First actor moves, second actor opens.
        if elephant_can_open:
            alts.append(
                time_after * valves[elephant]["rate"]
                + get_best_path_2(
                    new_valve, elephant, open_valves | {elephant}, minutes_left=time_after
                )
            )

        # Both actors move.
        for elephant_move in valves[elephant]["next"]:
            alts.append(
                get_best_path_2(
                    new_valve, elephant_move, open_valves, minutes_left=time_after
                )
            )

    # No available action means nothing more can be released.
    total_score = max(alts, default=0)
    cache[cache_key] = total_score
    return total_score


# Valves with a rate of 0 are handed to the search as already open.
zero_rates = [name for name, info in valves.items() if not info["rate"]]

# Both actors start at "AA" with 26 minutes available.
get_best_path_2("AA", "AA", zero_rates, 26)


KeyboardInterrupt: 