## Setup

In [1]:
# Get raw advent-of-code data
from aocd.models import Puzzle

puzzle = Puzzle(year=2025, day=8)
input_data = puzzle.input_data
example = puzzle.examples[0]

In [2]:
import sys
from pathlib import Path

sys.path.append(str(Path.cwd().parent))

from common.utils.perf_check import check_time

## Part a

### Pure numpy approach
I started with a pure numpy approach to compute pairwise distances and form circuits. However, this approach was not efficient for large datasets due to the O(n^2) complexity of distance calculations.


In [3]:
# Imports
import numpy as np

In [4]:
# Functions
def parse_input(input_data: str) -> np.ndarray:
    """Parse input data into a numpy array of box coordinates."""
    return np.loadtxt(input_data.splitlines(), delimiter=",", dtype=np.int32)


def calc_pairwise_distances_numpy(boxes: np.ndarray) -> tuple[tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Calculate pairwise Euclidean distances between boxes using numpy."""
    # Construct the upper triangle indices to avoid duplicates and self-distances
    i_idx, j_idx = np.triu_indices(len(boxes), k=1)  # cSpell:ignore triu
    return (
        (i_idx, j_idx),  # Return the indices of upper triangle pairs
        np.linalg.norm(boxes[i_idx] - boxes[j_idx], axis=1),  # And the corresponding distances
    )


def find_closest_box_pairs(
    dists: np.ndarray,
    i_idx: np.ndarray,
    j_idx: np.ndarray,
    *,
    num_closest_pairs: int = 1_000,
    cut_off_higher_partition: bool = True,
) -> list[tuple[int, int]]:
    """Find the closest box pairs for an array of distance and their upper triangle indices."""
    # Adjust number of closest pairs k if necessary
    k = min(num_closest_pairs, len(dists))

    # First, we can use argpartition to efficiently get the indices of the k smallest distances
    partitioned_dists = np.argpartition(dists, k - 1)
    k_shortest_dists = partitioned_dists[:k]

    # Then, sort these k distances to get them in exact order
    sorted_k_dists = k_shortest_dists[np.argsort(dists[k_shortest_dists])]

    # Re-include distances beyond the k-th smallest if needed
    sorted_dists = (
        np.concatenate([sorted_k_dists, partitioned_dists[k:]]) if not cut_off_higher_partition else sorted_k_dists
    )

    # Get the closest box pairs for the sorted distances
    return list(zip(i_idx[sorted_dists], j_idx[sorted_dists], strict=True))


def get_unique_circuits(closest_box_pairs: list[tuple[int, int]], boxes_count: int) -> set[tuple[int, ...]]:
    """Get unique circuits from the circuits dictionary."""
    # Initialize circuits for each box
    circuits = {i: {i} for i in range(boxes_count)}

    for a, b in closest_box_pairs:
        # Merge circuits for each closest pair if they are not already connected
        if a not in circuits[b] and b not in circuits[a]:
            combined_circuit = circuits[a] | circuits[b]
            for box in combined_circuit:
                # Update the circuit for each box in the combined circuit
                circuits[box] = combined_circuit

    # Get unique circuits
    return {tuple(c) for c in circuits.values()}


def find_largest_circuits_numpy(input_data: str, num_closest_pairs: int = 1_000) -> int:
    """Form circuits by connecting the closest boxes. Return the product of the sizes of the three largest circuits."""
    # Load input data
    boxes = parse_input(input_data)

    # Construct upper triangle pairs and calculate their distances
    (i_idx, j_idx), dists = calc_pairwise_distances_numpy(boxes)

    # Get the closest box pairs
    closest_box_pairs = find_closest_box_pairs(dists, i_idx, j_idx, num_closest_pairs=num_closest_pairs)

    # Get unique circuits
    unique_circuits = get_unique_circuits(closest_box_pairs, len(boxes))

    # Multiply the sizes of the three largest circuits
    return int(np.prod(sorted(len(c) for c in unique_circuits)[-3:]))

In [5]:
# Correctness check
str(find_largest_circuits_numpy(example.input_data, num_closest_pairs=10)) == example.answer_a

True

In [9]:
# Performance check
numpy_time_a = check_time(find_largest_circuits_numpy, input_data, number=50, repeat_times=3)
print(f"The pure-numpy implementation takes {numpy_time_a:.0f} ms per run.")

The pure-numpy implementation takes 21 ms per run.


### Scipy optimization
SciPy has some neat optimizations for distance computations. This speeds up the distance calculations significantly compared to the pure numpy approach.

In [5]:
# Imports
from scipy.spatial.distance import pdist

In [17]:
# Functions
def find_closest_pair_indices(
    pts: np.ndarray, k: int = 1_000, *, cut_off_higher_partition: bool = True
) -> list[tuple[int, int]]:
    """Find the k closest pairs for an array of points."""
    # Calculate pairwise distances using scipy pdist
    dists = pdist(pts, metric="euclidean")

    # First, we can use argpartition to efficiently get the indices of the k smallest distances
    partitioned = np.argpartition(dists, k - 1)[:k]

    # Then, sort these k distances to get them in exact order
    sorted_k = partitioned[np.argsort(dists[partitioned])]

    if cut_off_higher_partition:
        final_idx = sorted_k
    else:
        # Include all remaining indices after the sorted k
        remaining = np.setdiff1d(np.arange(len(dists)), sorted_k, assume_unique=True)
        final_idx = np.concatenate([sorted_k, remaining])

    # Get upper triangle indices
    i_idx, j_idx = np.triu_indices(len(pts), k=1)

    # Return the indices of the closest pairs in the original point-array
    return list(zip(i_idx[final_idx], j_idx[final_idx], strict=True))


def find_largest_circuits_scipy(input_data: str, num_closest_pairs: int = 1_000) -> int:
    """Form circuits by connecting the closest boxes. Return the product of the sizes of the three largest circuits."""
    # Load input data
    boxes = parse_input(input_data)

    # Get the closest box pairs using a common scipy-based function
    closest_box_pairs = find_closest_pair_indices(boxes, k=num_closest_pairs)

    # Get unique circuits
    unique_circuits = get_unique_circuits(closest_box_pairs, len(boxes))

    # Multiply the sizes of the three largest circuits
    return int(np.prod(sorted(len(c) for c in unique_circuits)[-3:]))

In [19]:
# Correctness check
str(find_largest_circuits_scipy(example.input_data, num_closest_pairs=10)) == example.answer_a

True

In [20]:
# Performance check
scipy_time_a = check_time(find_largest_circuits_scipy, input_data)
print(f"The scipy implementation takes {scipy_time_a:.1f} ms per run.")
print(f"This is {numpy_time_a / scipy_time_a:.1f}x faster than the pure-numpy implementation.")

The scipy implementation takes 7.2 ms per run.
This is 2.8x faster than the pure-numpy implementation.


In [191]:
# Submit answer
puzzle.answer_a = find_largest_circuits_scipy(input_data)

[32mThat's the right answer!  You are one gold star closer to decorating the North Pole. [Continue to Part Two][0m


## Part b
I'm quite sure we can keep the same approach for calculating distance, we just need to go further down the sorted list of closest pairs to ensure all boxes are connected into a single circuit.

I quickly figured that storing the circuits explicitly was slowing things down, so I first implemented a union-find data structure to efficiently manage the connections between boxes.


### Union-find approach
I started off by adapting the closest box pair iteration to use a union-find data structure to efficiently manage the connected components of boxes. This allows us to quickly determine when all boxes are connected into a single circuit.

In [11]:
class UnionFind:
    """Union-find algorithm with path compression and union by size."""

    def __init__(self, n: int):
        self.roots = list(range(n))
        self.sizes = [1] * n
        self.components = n

    def find(self, x: int) -> int:
        """Find the root of x with path compression."""
        p = self.roots
        while p[x] != x:
            # Traverse up the tree
            p[x] = p[p[x]]
            # Compress path
            x = p[x]
        return x

    def union(self, a: int, b: int) -> bool:
        """Merge the sets containing a and b. Return True if merged, False if already merged."""
        root_a, root_b = self.find(a), self.find(b)

        # Already connected
        if root_a == root_b:
            return False

        # Ensure root_a is the larger tree
        if self.sizes[root_a] < self.sizes[root_b]:
            root_a, root_b = root_b, root_a

        # Merge smaller tree into larger tree
        self.roots[root_b] = root_a
        self.sizes[root_a] += self.sizes[root_b]

        # Decrease component count
        self.components -= 1
        return True

In [12]:
def find_final_box_pair_union_find(box_pairs: list[tuple[int, int]], boxes_count: int) -> tuple[int, int]:
    """Find final connection that connects all boxes into a single circuit using UnionFind."""
    uf = UnionFind(boxes_count)
    for a, b in box_pairs:
        # If there is only one component (connected circuit) left, return the current pair
        if uf.union(a, b) and uf.components == 1:
            return a, b
    msg = "There were not enough connections to connect all boxes into a single circuit."
    raise ValueError(msg)


def find_last_connection_union_find(input_data: str) -> int:
    """Form circuits by connecting the closest boxes. Return the product of the sizes of the three largest circuits."""
    # Load input data
    boxes = parse_input(input_data)

    # Get the closest box pairs using a common scipy-based function
    sorted_box_pairs = find_closest_pair_indices(boxes, cut_off_higher_partition=False)

    # Find the final box pair that connects all boxes into a single circuit
    final_box_pair = find_final_box_pair_union_find(sorted_box_pairs, len(boxes))

    # Multiply the X-coordinates of the final connected boxes
    return int(boxes[final_box_pair[0]][0].astype(np.int64) * boxes[final_box_pair[1]][0].astype(np.int64))

In [13]:
# Correctness check
str(find_last_connection_union_find(example.input_data)) == example.answer_b

True

In [16]:
# Performance check
time_union_find_b = check_time(find_last_connection_union_find, input_data, number=5)
print(f"The union-find approach takes {time_union_find_b:.2f} ms per run.")

The union-find approach takes 50.65 ms per run.


### k-NN graph
The union find approach still takes 50 ms per run, which is decent but I think we can do better. Using k-NN graphs should help here.

In [None]:
import heapq
from math import log2

from scipy.spatial import cKDTree

In [None]:
def find_last_connection_knn(input_data: str) -> int:
    """Use a k-NN graph to find neighbors.

    We reuse the union-find structure to keep track of connected components.
    """
    # Load input data
    boxes = parse_input(input_data).astype(np.int64)
    N = len(boxes)  # noqa: N806 # Capitalize variable to follow k-NN convention

    # Build a k-D tree and union-find structure for efficient nearest neighbor search
    tree = cKDTree(boxes)
    uf = UnionFind(N)

    # Initialize heap and seen set
    heap = []
    seen = set()  # pid = a*N + b for unordered pair (a<b)

    # Set the initial k at roughly sqrt(N)
    k = int(2 ** (-(-log2(N) // 2)))
    while uf.components > 1:
        if not heap:
            # Ensure k does not exceed N-1
            k = min(k, N - 1)

            # Fetch k+1 nearest neighbors (including self). Shape (N, k+1)
            dists, idxs = tree.query(boxes, k=k + 1)

            # Construct edges excluding self
            rows = np.repeat(np.arange(N), k)
            cols = idxs[:, 1:].ravel()

            # Distances
            ds = dists[:, 1:].ravel()

            if rows.size:
                # Add edges to the heap
                mn = np.minimum(rows, cols).astype(np.int16)
                mx = np.maximum(rows, cols).astype(np.int16)
                for a, b, d in zip(mn, mx, ds, strict=False):
                    # Unique pair id
                    pid = int(a) * N + int(b)
                    if pid in seen:
                        continue
                    seen.add(pid)

                    # Push to heap
                    heapq.heappush(heap, (float(d), int(a), int(b)))

            if k >= N - 1 and not heap:
                msg = "No edges available to connect points"
                raise ValueError(msg)

            if not heap:
                # Increase k and retry
                k = min(N - 1, k * 2)
                continue

        # Process the smallest edge
        d, a, b = heapq.heappop(heap)

        # If there is only one component (connected circuit) left, return the current pair
        if uf.union(a, b) and uf.components == 1:
            return int(boxes[a, 0] * boxes[b, 0])

    msg = "Could not connect all points."
    raise ValueError(msg)

In [53]:
# Correctness check
str(find_last_connection_knn(example.input_data)) == example.answer_b

True

In [69]:
# Performance check
time_knn_b = check_time(find_last_connection_knn, input_data, number=10)
print(f"The k-NN approach takes {time_knn_b:.2f} ms per run.")
print(f"This is {time_union_find_b / time_knn_b:.1f}x faster than the union-find implementation.")

The k-NN approach takes 15.32 ms per run.
This is 3.3x faster than the union-find implementation.


In [234]:
# Submit answer
puzzle.answer_b = find_last_connection_knn(input_data)

[32mThat's the right answer!  You are one gold star closer to decorating the North Pole.You have completed Day 8! You can [Shareon
  Bluesky
Twitter
Mastodon] this victory or [Return to Your Advent Calendar].[0m
