## Setup

In [1]:
# Get raw advent-of-code data
from aocd.models import Puzzle

# 2015 day 9 puzzle: `input_data` is the full puzzle input; `example` is the first entry of
# `puzzle.examples`, which the correctness-check cells below pass to check_example.
puzzle = Puzzle(year=2015, day=9)
input_data = puzzle.input_data
example = puzzle.examples[0]

In [3]:
# Import performance checking utility
import sys
from pathlib import Path

# Make the `common` package (imported below) resolvable from the parent of the working
# directory. Only add the entry once so re-running this cell doesn't keep appending
# duplicate paths to sys.path.
parent_dir = str(Path.cwd().parent)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from common.utils.perf_check import check_example, time_solution

## Part a

### Combinatorial approach
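This is a variant of the classic traveling salesman problem: we need the shortest route that visits every location exactly once, but (unlike classic TSP) we don't have to return to the starting point. Let's start by brute-forcing all permutations of locations to find the shortest route.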

In [4]:
# Imports
from collections import defaultdict
from itertools import permutations
from typing import Literal

In [102]:
# Constants
# Reducer that picks the best route length for a given `mode` ("min" for part a, "max" for part b)
OP_MAP = dict(min=min, max=max)

In [103]:
# Functions
def solve_a_combinatorial(input_data: str, *, mode: Literal["min", "max"] = "min") -> int:
    """Find the shortest (or longest) route through all locations by brute force.

    Every ordering of the locations is a candidate route; its length is the sum of the
    distances between consecutive stops (there is no return leg to the start).

    Args:
        input_data: One edge per line, formatted ``"<a> to <b> = <distance>"``.
        mode: ``"min"`` for the shortest route (part a), ``"max"`` for the longest (part b).

    Returns:
        The length of the best route under ``mode``.

    Raises:
        KeyError: If ``mode`` is not a key of ``OP_MAP``, or if a pair of adjacent
            locations in some ordering has no listed distance.
    """
    distances: defaultdict[str, dict[str, int]] = defaultdict(dict)

    # Construct symmetric distance map: {location_a: {location_b: distance, ...}, ...}
    for line in input_data.splitlines():
        a, _, b, _, d = line.split()
        distances[a][b] = distances[b][a] = int(d)

    locations = list(distances)

    # Distances are symmetric, so a route and its reverse have the same length. Only keep
    # the orientation whose first stop sorts before its last stop, halving the work.
    # Routes with fewer than two stops have no distinct reverse, so they are always kept.
    routes = (perm for perm in permutations(locations) if len(perm) < 2 or perm[0] < perm[-1])

    # Find the shortest / longest route length
    return OP_MAP[mode](sum(distances[x][y] for x, y in zip(route, route[1:])) for route in routes)

In [104]:
# Correctness check
# Runs the solver on the puzzle example and reports whether it matches the expected part "a" answer
check_example(solve_a_combinatorial, example, "a")

solve_a_combinatorial found answer 605, which is the correct solution for part A!


True

In [105]:
# Performance check
# Kept as the baseline timing: later cells divide it by their own timings to report speedups
time_a_combinatorial = time_solution(solve_a_combinatorial, input_data)

solve_a_combinatorial takes 26.70 ms


### Recursive approach
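We can speed up the combinatorial approach by using recursion, memoization, and a bitmask to represent the set of visited cities. The best remaining distance depends only on the current city and the set of cities already visited. Memoizing on that pair means each of the $n \cdot 2^n$ states is solved once, instead of walking all $n!$ orderings.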


In [93]:
# Imports
from functools import lru_cache

In [None]:
def parse_to_matrix(input_data: str) -> tuple[list[list[int]], int]:
    """Parse input data into a symmetric distance matrix and the number of cities.

    Cities are assigned indices in sorted-name order, so the matrix layout is the same on
    every run (iterating a plain ``set`` of strings depends on hash randomization).

    Args:
        input_data: One edge per line, formatted ``"<a> to <b> = <distance>"``.

    Returns:
        ``(matrix, n)`` where ``matrix[i][j]`` is the distance between cities ``i`` and ``j``
        and ``n`` is the number of distinct cities. Empty input gives ``([], 0)``.

    Note:
        Pairs with no listed distance (and the diagonal) stay at 0, so a missing edge looks
        like a free hop. The solvers using this matrix appear to assume every pair is
        listed -- NOTE(review): confirm for non-complete inputs.
    """
    lines = [line.split() for line in input_data.splitlines()]

    # Collect unique cities and map them to stable indices so visited sets can be bitmasks
    cities = sorted({p[0] for p in lines} | {p[2] for p in lines})
    city_map = {name: i for i, name in enumerate(cities)}
    n = len(cities)

    # Initialize matrix with 0s, then fill both directions of each listed edge
    matrix = [[0] * n for _ in range(n)]
    for u, _, v, _, d in lines:
        i, j = city_map[u], city_map[v]
        matrix[i][j] = matrix[j][i] = int(d)

    return matrix, n


def solve_a_recursive(input_data: str, *, mode: Literal["min", "max"] = "min") -> int:
    """Find the shortest / longest route through all locations by recursively checking all paths as bitmasks.

    Cities are numbered by ``parse_to_matrix``, and the set of visited cities is encoded as a
    bitmask (bit ``i`` set means city ``i`` has been visited). The nested helper returns the
    best remaining distance from a (current city, visited set) state. It is memoized, so each
    state is computed at most once.

    Args:
        input_data: One edge per line, formatted ``"<a> to <b> = <distance>"``.
        mode: ``"min"`` for the shortest route, ``"max"`` for the longest.

    Returns:
        The length of the best route under ``mode``.

    Raises:
        ValueError: If no route can be found (e.g. the input is empty).
        KeyError: If ``mode`` is not a key of ``OP_MAP``.
    """
    matrix, n = parse_to_matrix(input_data)
    choose = OP_MAP[mode]
    # Sentinel meaning "no route found"; it loses every comparison made by `choose`
    worst = float("-inf") if mode == "max" else float("inf")

    # Bitmask when all cities have been visited
    visited_all = (1 << n) - 1

    # Unbounded cache: there are n * 2**(n - 1) reachable states (e.g. 1,024 for 8 cities), so a
    # fixed maxsize such as 1,000 evicts live entries and throws away the memoization. The
    # cache belongs to this nested function, which is recreated on every call.
    @lru_cache(maxsize=None)
    def find_path_length(current_city: int, visited_bitmask: int) -> int | float:
        # If the mask is full, every city has been visited and nothing remains to travel
        if visited_bitmask == visited_all:
            return 0

        best_distance = worst
        for nxt in range(n):
            if not visited_bitmask & (1 << nxt):
                # Cost = distance to next + best cost of finishing the route from next
                cost = matrix[current_city][nxt] + find_path_length(nxt, visited_bitmask | (1 << nxt))
                best_distance = choose(best_distance, cost)

        return best_distance

    # Start from every city and keep the best (shortest or longest) total
    result = choose((find_path_length(i, 1 << i) for i in range(n)), default=worst)
    if result == worst:
        msg = "No valid path found"
        raise ValueError(msg)
    return int(result)

In [107]:
# Correctness check
# The recursive solver must reproduce the same part "a" example answer as the brute-force one
check_example(solve_a_recursive, example, "a")

solve_a_recursive found answer 605, which is the correct solution for part A!


True

In [None]:
# Performance check
time_a_recursive = time_solution(solve_a_recursive, input_data)
# Ratio of baseline to recursive timing; > 1 means the recursive solver is faster
# (assumes time_solution returns a duration, as its "takes ... ms" output suggests)
print(f"This is {time_a_combinatorial / time_a_recursive:.1f}x faster than the combinatorial approach.")

solve_a_recursive takes 0.83 ms
This is 32.2x faster than the combinatorial approach.


### Networkx approach
Let's see how well networkx builtins can handle this problem.

In [64]:
# Imports
import networkx as nx
from networkx.algorithms import approximation as approx

In [None]:
# Functions
def parse_input_to_graph(input_data: str) -> nx.Graph:
    """Build an undirected, weighted graph from ``"<a> to <b> = <distance>"`` lines.

    Each line contributes one edge between its two locations, with the integer distance
    stored under the ``weight`` edge attribute.
    """
    edges = (line.split() for line in input_data.splitlines())
    graph = nx.Graph()
    for src, _to, dst, _eq, dist in edges:
        graph.add_edge(src, dst, weight=int(dist))
    return graph


def solve_a_networkx_approx(input_data: str) -> int:
    """Approximate the shortest route through all locations with networkx's greedy TSP.

    ``cycle=False`` asks for an open path rather than a closed tour. Because ``greedy_tsp``
    is a heuristic, the returned length may exceed the true optimum.
    """
    g = parse_input_to_graph(input_data)
    route = approx.traveling_salesman_problem(g, cycle=False, method=approx.greedy_tsp)
    route_length = nx.path_weight(g, route, weight="weight")
    return int(route_length)

In [90]:
# Performance check
time_a_networkx = time_solution(solve_a_networkx_approx, input_data)
print(f"This is {time_a_combinatorial / time_a_networkx:.1f}x faster than the combinatorial approach.")

# Compare the heuristic's answer against the exact brute-force answer
print("\nSolution A (exact combinatorial):", solve_a_combinatorial(input_data))
print("Solution A (networkx approximation):", solve_a_networkx_approx(input_data))

solve_a_networkx_approx takes 0.17 ms
This 153.2x faster than the combinatorial approach.

Solution A (exact combinatorial): 141
Solution A (networkx approximation): 143


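The networkx approximation is a lot faster! However, it doesn't give the correct solution by default (143 instead of 141). We can improve the approximation by writing a custom solver that starts from each node in the graph in turn, greedily hops to the nearest unvisited neighbor, and keeps the best complete path. Note that this is still a heuristic: it is not guaranteed to find the optimal route, even if it happens to here.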

In [152]:
def solve_a_networkx_tsp(input_data: str, mode: Literal["min", "max"] = "min") -> int:
    """Find a short (or long) route through all locations with a multi-start greedy heuristic.

    From every possible starting node, repeatedly hop to the unvisited neighbor with the
    smallest (``mode="min"``) or largest (``mode="max"``) edge weight, and keep the best
    complete path found. This is a nearest-neighbor heuristic. It is fast, and it matched
    the exact answers on the inputs checked in this notebook, but it is not guaranteed to
    find the optimal route.

    Args:
        input_data: One edge per line, formatted ``"<a> to <b> = <distance>"``.
        mode: ``"min"`` for the shortest route, ``"max"`` for the longest.

    Returns:
        The length of the best greedy route under ``mode``.

    Raises:
        ValueError: If no start node yields a path through every node (e.g. empty input).
        KeyError: If ``mode`` is not a key of ``OP_MAP``.
    """
    graph = parse_input_to_graph(input_data)
    choose = OP_MAP[mode]
    node_count = graph.number_of_nodes()

    best_distances = []

    for start_node in graph.nodes():
        # Ordered path for this starting point, plus a set mirror for O(1) membership checks
        path = [start_node]
        visited = {start_node}
        current_node = start_node

        # Build the path greedily until all nodes are visited
        while len(path) < node_count:
            # Get neighbors that haven't been visited yet, with their edge weights
            neighbors = [
                (n, graph[current_node][n]["weight"]) for n in graph.neighbors(current_node) if n not in visited
            ]

            if not neighbors:
                # Dead end: this start cannot reach every node greedily
                break

            # Pick the best neighbor based on the mode
            next_node = choose(neighbors, key=lambda x: x[1])[0]

            path.append(next_node)
            visited.add(next_node)
            current_node = next_node

        # Only complete paths are candidates
        if len(path) == node_count:
            best_distances.append(nx.path_weight(graph, path, weight="weight"))

    if not best_distances:
        msg = "No valid path found"
        raise ValueError(msg)
    return int(choose(best_distances))

In [None]:
# Performance check
time_a_networkx_improved = time_solution(solve_a_networkx_tsp, input_data)
print(
    f"This is {time_a_combinatorial / time_a_networkx_improved:.1f}x faster than the combinatorial approach"
    f" and {time_a_recursive / time_a_networkx_improved:.1f}x faster than the recursive approach!"
)

# The greedy heuristic matches the exact answer on this input, but that isn't guaranteed in general
print("\nSolution A (exact combinatorial):", solve_a_combinatorial(input_data))
print("Solution A (networkx approximation):", solve_a_networkx_tsp(input_data))

solve_networkx_tsp takes 0.12 ms
This 221.1x faster than the combinatorial approach and 6.9x faster than the recursive approach!

Solution A (exact combinatorial): 141
Solution A (networkx approximation): 141


### Held-Karp approach
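Finally, I came across the Held-Karp algorithm, a classic dynamic-programming solution to the traveling salesman problem. It tabulates the same (visited set, current city) states as the recursive approach, but bottom-up instead of top-down. This gives an exact answer in $O(n^2 \cdot 2^n)$ time. Let's see if we can implement it here.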

In [148]:
def solve_a_held_karp(input_data: str, *, mode: Literal["min", "max"] = "min") -> int:
    """Find the shortest / longest route through all locations with the Held-Karp algorithm.

    ``dp[mask][i]`` holds the best distance of a path that visits exactly the cities in
    ``mask`` (bit ``i`` set means city ``i`` is visited) and ends at city ``i``. The table
    has ``2**n`` rows of ``n`` entries, and each entry tries ``n`` extensions.

    Args:
        input_data: One edge per line, formatted ``"<a> to <b> = <distance>"``.
        mode: ``"min"`` for the shortest route, ``"max"`` for the longest.

    Returns:
        The length of the best route under ``mode``.

    Raises:
        ValueError: If no route can be found (e.g. the input is empty).
        KeyError: If ``mode`` is not a key of ``OP_MAP``.
    """
    matrix, n = parse_to_matrix(input_data)
    choose = OP_MAP[mode]
    # Sentinel for "no path reaches this state yet"; it loses every comparison under `choose`
    worst = float("inf") if mode == "min" else float("-inf")

    # dp[mask][i] = best distance visiting 'mask' cities, ending at city 'i'
    dp = [[worst] * n for _ in range(1 << n)]

    # Initial state: starting at any single city costs 0
    for i in range(n):
        dp[1 << i][i] = 0

    # Every transition adds a city, so new_mask > mask. Iterating masks in increasing order
    # therefore guarantees dp[mask] is final before it is extended.
    for mask in range(1, 1 << n):
        row = dp[mask]
        for curr in range(n):
            # Skip if 'curr' isn't in the subset, or no path reaches this state
            if not mask & (1 << curr) or row[curr] == worst:
                continue

            # From 'curr', try moving to each city not yet visited
            for nxt in range(n):
                if mask & (1 << nxt):
                    continue

                new_mask = mask | (1 << nxt)
                dp[new_mask][nxt] = choose(dp[new_mask][nxt], row[curr] + matrix[curr][nxt])

    # Result is the best value in the last row (all cities visited, any end city)
    best = choose(dp[(1 << n) - 1], default=worst)
    if best == worst:
        msg = "No valid path found"
        raise ValueError(msg)
    return int(best)

In [149]:
# Correctness check
# Held-Karp must reproduce the same part "a" example answer as the other exact solvers
check_example(solve_a_held_karp, example, "a")

solve_a_held_karp found answer 605, which is the correct solution for part A!


True

In [151]:
# Performance check
# Stored under its own name so the earlier networkx timing (`time_a_networkx`) isn't clobbered
time_a_held_karp = time_solution(solve_a_held_karp, input_data)

# A ratio below 1 means Held-Karp was slower, so report it as "Nx slower" rather than "0.2x faster"
ratio_vs_greedy = time_a_networkx_improved / time_a_held_karp
greedy_comparison = (
    f"{ratio_vs_greedy:.1f}x faster than" if ratio_vs_greedy >= 1 else f"{1 / ratio_vs_greedy:.1f}x slower than"
)
print(
    f"This is {time_a_combinatorial / time_a_held_karp:.1f}x faster than the combinatorial approach,"
    f" {time_a_recursive / time_a_held_karp:.1f}x faster than the recursive approach,"
    f" and {greedy_comparison} the networkx approximation."
)

solve_a_held_karp takes 0.72 ms
This is 36.9x faster than the combinatorial approach, 1.1x faster than the recursive approach, and 0.2x faster than the networkx approximation!


In [99]:
# Submit answer
# Uses the exact Held-Karp solver (mode defaults to "min", i.e. the shortest route)
puzzle.answer_a = solve_a_held_karp(input_data)

## Part b
This does not change the required logic; we just want the longest route instead of the shortest. I've added a `mode: Literal["min", "max"]` parameter to each solver (except the built-in networkx approximation) to handle this.

In [153]:
# Correctness check
# Part b reuses the part-a solvers in "max" mode; each must match the example's part "b" answer.
# Only the last call's return value is displayed as the cell's output.
check_example(solve_a_combinatorial, example, "b", mode="max")
check_example(solve_a_recursive, example, "b", mode="max")
check_example(solve_a_networkx_tsp, example, "b", mode="max")
check_example(solve_a_held_karp, example, "b", mode="max")

solve_a_combinatorial found answer 982, which is the correct solution for part B!
solve_a_recursive found answer 982, which is the correct solution for part B!
solve_a_networkx_tsp found answer 982, which is the correct solution for part B!
solve_a_held_karp found answer 982, which is the correct solution for part B!


True

In [None]:
# Submit answer
# Use the exact Held-Karp solver (as in part a) rather than the greedy networkx heuristic,
# which is not guaranteed to find the true longest route.
puzzle.answer_b = solve_a_held_karp(input_data, mode="max")

[32mThat's the right answer!  You are one gold star closer to powering the weather machine.You have completed Day 9! You can [Shareon
  Bluesky
Twitter
Mastodon] this victory or [Return to Your Advent Calendar].[0m
