In [76]:
from elf import (
    get_puzzle_input,
    submit_puzzle_answer,
)
import math
from itertools import combinations
from functools import reduce
from collections import defaultdict

In [2]:
# Raw puzzle input for year 2025, day 8, fetched via the project's `elf` helper.
# Parsed below as one comma-separated integer point per line.
text = get_puzzle_input(2025, 8)

## Part 1: connect the 1000 closest pairs of points, then multiply the sizes of the three largest clusters

In [27]:
def distance(p1, p2):
    """Return the Euclidean distance between two points.

    Works for points of any (matching) dimension; this notebook uses 3-D
    ``(x, y, z)`` tuples. ``math.dist`` raises ``ValueError`` when the two
    points have different numbers of coordinates, rather than silently
    ignoring extra coordinates the way a fixed 3-term formula does.
    """
    return math.dist(p1, p2)

In [22]:
class UnionFind:
    """Disjoint-set forest with union by rank and path compression.

    Elements must be hashable. They are registered either explicitly via
    ``make_set`` or lazily the first time ``find`` encounters them.
    """

    def __init__(self):
        self.parent = {}  # element -> parent element; roots point at themselves
        self.rank = {}    # element -> rank used for union by rank

    def make_set(self, x):
        """Register ``x`` as a singleton set; does nothing if already known."""
        if x in self.parent:
            return
        self.parent[x] = x
        self.rank[x] = 0

    def find(self, x):
        """Return the root of the set containing ``x``, compressing its path."""
        if x not in self.parent:
            self.make_set(x)

        # Pass 1: climb to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Pass 2: re-point every node on the path directly at the root.
        node = x
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node

        return root

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b``; no-op if already merged."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return

        # Arrange for ``ra`` to be the root of higher (or equal) rank.
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra

        self.parent[rb] = ra
        # Only a tie in rank can increase the merged tree's rank.
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1

In [31]:
# One point per non-blank input line, written as comma-separated integers
# (e.g. "x,y,z"). Blank lines are skipped; previously any blank line made
# `int('')` raise ValueError.
points = [
    tuple(map(int, line.split(',')))
    for line in text.splitlines()
    if line.strip()
]

In [44]:
# Every unordered pair of points, nearest first. `sorted` is stable, so pairs
# at equal distance keep the order in which `combinations` produced them.
closest_pairs = sorted(
    combinations(points, 2),
    key=lambda pair: distance(*pair),
)

In [78]:
# How many of the closest pairs to connect in part 1.
NUM_CONNECTIONS = 1000

uf = UnionFind()
# Register every point first, so points not touched by the first
# NUM_CONNECTIONS pairs still show up as single-point clusters.
for p in points:
    uf.make_set(p)

# `union` registers unseen points itself, so no separate make_set per pair.
for a, b in closest_pairs[:NUM_CONNECTIONS]:
    uf.union(a, b)

# Group every point under its root representative (`find` also compresses paths).
clusters = defaultdict(list)
for x in uf.parent:
    clusters[uf.find(x)].append(x)

In [82]:
# (root, members) items ordered largest cluster first; ties keep insertion order.
clusters_by_size = sorted(clusters.items(), key=lambda item: -len(item[1]))
top_3 = clusters_by_size[:3]

In [43]:
# Part 1 answer: product of the sizes of the three largest clusters.
# Guard so a malformed result is not submitted when fewer than three exist.
assert len(top_3) == 3, "expected at least three clusters"
part1_answer = math.prod(len(members) for _, members in top_3)
result = submit_puzzle_answer(2025, 8, 1, part1_answer)
print(result.is_correct, result.message)

True 171503 is correct. Star awarded.


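## Part 2: keep connecting closest pairs until every point is in one cluster, then multiply the x-coordinates of the last pair connected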

In [53]:
class UnionFind:
    """Disjoint-set forest (union by rank, path compression) that counts sets.

    This redefinition shadows the part-1 ``UnionFind`` for the rest of the
    notebook. It adds ``num_clusters``: the number of disjoint sets among the
    elements registered so far.
    """

    def __init__(self):
        self.parent = {}       # element -> parent element; roots point at themselves
        self.rank = {}         # element -> rank used for union by rank
        self.num_clusters = 0  # count of disjoint sets currently tracked

    def make_set(self, x):
        """Register ``x`` as a new singleton set; does nothing if already known."""
        if x in self.parent:
            return
        self.parent[x] = x
        self.rank[x] = 0
        self.num_clusters += 1

    def find(self, x):
        """Return the root of the set containing ``x``, compressing its path."""
        if x not in self.parent:
            self.make_set(x)

        # Pass 1: climb to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Pass 2: re-point every node on the path directly at the root.
        node = x
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node

        return root

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b``; no-op if already merged."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return

        # Arrange for ``ra`` to be the root of higher (or equal) rank.
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra

        self.parent[rb] = ra
        # Only a tie in rank can increase the merged tree's rank.
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1

        # Two distinct sets became one.
        self.num_clusters -= 1

In [83]:
all_points = set(points)

# Register every point up front: `num_clusters == 1` then means *all* points
# are connected. This replaces the per-iteration
# `all_points.difference(uf.parent)`, which rebuilt an O(len(points)) set on
# every one of up to ~n^2/2 pairs.
uf = UnionFind()
for p in all_points:
    uf.make_set(p)

# Connect pairs nearest-first; the pair whose union leaves a single cluster is the answer.
for a, b in closest_pairs:
    uf.union(a, b)
    if uf.num_clusters == 1:
        answer = (a, b)
        break
else:
    raise ValueError("closest_pairs never connected all points into a single cluster")

In [84]:
# The pair of points whose connection first merged every point into one cluster.
answer

((90968, 13304, 86478), (99700, 4569, 92874))

In [75]:
# Part 2 answer: product of the x-coordinates of the final connecting pair.
first_point, second_point = answer
result = submit_puzzle_answer(2025, 8, 2, first_point[0] * second_point[0])
print(result.is_correct, result.message)

True 9069509600 is correct. Star awarded.
