In [18]:
from utils import read_lines
from functools import cache

def parse_input(input_file):
    """Read the seating rules and build a directed pairwise score matrix.

    Each non-blank line is split on whitespace and interpreted as:
    token 0 = the person, token 2 = ``'gain'`` or ``'lose'``, token 3 = the
    integer amount, last token = the neighbour (a trailing ``'.'`` is
    stripped). Names get consecutive indices in order of first appearance
    (person before neighbour on each line).

    Args:
        input_file: path handed to ``read_lines``.

    Returns:
        An ``n x n`` list of lists where ``matrix[a][b]`` is the signed score
        person ``a`` gets from sitting next to ``b`` (negated for ``'lose'``).
        Pairs never mentioned in the input stay 0.
    """
    names = {}
    rules = []
    # Single pass: works whether read_lines returns a list or a one-shot iterator.
    for line in read_lines(input_file):
        parts = line.split()
        if not parts:
            # Tolerate blank lines (e.g. a trailing empty line in the input file).
            continue
        person, neighbour = parts[0], parts[-1].rstrip('.')
        amount = int(parts[3])
        if parts[2] == 'lose':
            amount = -amount
        for name in (person, neighbour):
            # len(names) is evaluated before insertion -> next free index.
            names.setdefault(name, len(names))
        rules.append((names[person], names[neighbour], amount))

    n = len(names)
    matrix = [[0] * n for _ in range(n)]
    for a, b, amount in rules:
        matrix[a][b] = amount
    return matrix

def normalize(matrix):
    """Fold a directed score matrix into a symmetric one, in place.

    For every unordered pair ``(i, j)`` with ``i < j``, both ``matrix[i][j]``
    and ``matrix[j][i]`` are overwritten with their sum, so each cell ends up
    holding the combined score of the pair. The diagonal is not touched.
    Returns ``None``.
    """
    size = len(matrix)
    for row in range(size):
        for col in range(row + 1, size):
            combined = matrix[row][col] + matrix[col][row]
            matrix[row][col] = matrix[col][row] = combined
        
def best_seating(matrix, closed=True):
    """Maximum total score over all orderings that seat every person exactly once.

    Bitmask dynamic programming (Held-Karp style) over
    (person seated last, set of people already seated).

    Args:
        matrix: square list of lists of pairwise scores; ``matrix[a][b]`` is
            added whenever ``b`` is seated directly after ``a``. Intended to be
            symmetric (as produced by ``normalize``).
        closed: if True the arrangement is a cycle, so the last person also
            sits next to the first. If False it is an open chain and the two
            ends are not joined.

    Returns:
        The best achievable total; 0 when ``matrix`` is empty.
    """
    n = len(matrix)
    if n == 0:
        # Nobody to seat. Without this guard the closed case would yield -inf.
        return 0
    full = (1 << n) - 1

    @cache
    def dp(cur, mask):
        # Best score for seating everyone not yet in `mask`, `cur` seated last.
        if mask == full:
            # Closing edge back to person 0 (the fixed start in the closed case).
            return matrix[cur][0] if closed else 0
        best = float('-inf')
        for nxt in range(n):
            bit = 1 << nxt
            if not mask & bit:
                best = max(best, matrix[cur][nxt] + dp(nxt, mask | bit))
        return best

    # A cycle is rotation-invariant, so fixing person 0 as the start loses
    # nothing. An open chain may begin with anyone.
    starts = [0] if closed else range(n)
    return max(dp(s, 1 << s) for s in starts)


def part1(input_file):
    """Best total score when everyone from ``input_file`` sits around a circular table."""
    matrix = parse_input(input_file)
    normalize(matrix)
    return best_seating(matrix, closed=True)


def part2(input_file):
    """Best total score when the seating is an open chain (the ends are not paired).

    Presumably this models the puzzle's extra guest, who scores 0 with
    everyone: dropping that guest from the circle leaves an open chain.
    TODO confirm against the puzzle statement.
    """
    matrix = parse_input(input_file)
    normalize(matrix)
    return best_seating(matrix, closed=False)

In [12]:
# Part 1: maximum total score over all circular seatings (the last person also sits next to the first).
part1('inputs/day13.txt')

709

In [19]:
# Part 2: maximum total score when the seating is an open chain (the two ends are not scored against each other).
part2('inputs/day13.txt')

668

In [10]:
# Inspect the pairwise score matrix after normalize() has made it symmetric:
# matrix[i][j] == matrix[j][i] == combined score of the pair (i, j).
matrix = parse_input('inputs/day13.txt')
normalize(matrix)
matrix

[[0, 57, -164, 25, 38, -38, 83, -130],
 [57, 0, -62, -6, 33, 54, -107, 87],
 [-164, -62, 0, 83, 94, 149, -42, -5],
 [25, -6, 83, 0, -163, 41, 98, 46],
 [38, 33, 94, -163, 0, 23, -68, 100],
 [-38, 54, 149, 41, 23, 0, 59, 15],
 [83, -107, -42, 98, -68, 59, 0, -10],
 [-130, 87, -5, 46, 100, 15, -10, 0]]