In [17]:
from collections import defaultdict
import heapq

def parse_input(input_file):
    """Read integer coordinates from a text file, one point per line.

    Each non-blank line holds comma-separated integers (e.g. ``162,817,812``).
    Blank lines -- such as an extra trailing newline at the end of the file --
    are skipped instead of crashing on ``int('')``.

    Args:
        input_file: Path to the input text file.

    Returns:
        A list with one list of ints per non-blank line, in file order.
    """
    coords = []
    with open(input_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            coords.append([int(x) for x in line.split(',')])
    return coords

def cal_dist(a, b):
    """Return the *squared* Euclidean distance between two 3-component points.

    The square root is intentionally omitted: comparing pairs only needs a
    monotone measure, and with integer coordinates the result stays an
    exact integer.
    """
    (ax, ay, az), (bx, by, bz) = a, b
    dx, dy, dz = ax - bx, ay - by, az - bz
    return dx ** 2 + dy ** 2 + dz ** 2

def to_dist_matrix(coords):
    """Return the full n x n matrix of pairwise ``cal_dist`` values.

    Row ``r``, column ``c`` holds ``cal_dist(coords[r], coords[c])``; the
    diagonal is filled with ``0`` without calling ``cal_dist``.
    """
    return [
        [0 if row == col else cal_dist(p, q) for col, q in enumerate(coords)]
        for row, p in enumerate(coords)
    ]

def to_heap(coords):
    """Build a min-heap over every unordered pair of points, keyed by distance.

    Each entry is ``(cal_dist(coords[i], coords[j]), i, j)`` with ``i < j``.
    Ties on distance fall back to comparing ``(i, j)``, so pop order is
    deterministic.

    Distances are computed once per pair on the fly. Building the full n x n
    matrix first would evaluate every pair twice and keep O(n^2) extra lists
    alive, while only the upper triangle is ever needed here.

    Args:
        coords: Sequence of points accepted by ``cal_dist``.

    Returns:
        A heapified list of ``n * (n - 1) // 2`` tuples (empty when n < 2).
    """
    n = len(coords)
    heap = [
        (cal_dist(coords[i], coords[j]), i, j)
        for i in range(n)
        for j in range(i + 1, n)
    ]
    heapq.heapify(heap)
    return heap

class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Uses union by rank plus path compression. ``union`` reports whether a
    merge actually happened, so callers can count how many distinct sets
    remain.
    """

    def __init__(self, n):
        # Each element begins as the root of its own singleton set.
        self.parent = list(range(n))
        self.rank = [1] * n

    def find(self, x):
        """Return the root of ``x``'s set, re-pointing every node on the path at it."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: path compression.
        while x != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        Returns:
            True if they were in different sets (a merge happened),
            False if they already shared a root.
        """
        a, b = self.find(x), self.find(y)
        if a == b:
            return False
        # Keep ``a`` as the higher-rank root so the shallower tree hangs under it.
        if self.rank[a] < self.rank[b]:
            a, b = b, a
        self.parent[b] = a
        if self.rank[a] == self.rank[b]:
            self.rank[a] += 1
        return True

def part1(input_file, k):
    """Join the ``k`` closest pairs, then multiply the three largest component sizes.

    Pairs are popped in increasing ``cal_dist`` order. A pair whose points are
    already in the same component still counts toward ``k``; ``union`` simply
    reports that no merge happened.

    Args:
        input_file: Path to the comma-separated coordinate file.
        k: Number of closest pairs to process. A value larger than the number
            of available pairs is clamped, so every pair is joined instead of
            popping from an empty heap.

    Returns:
        Product of the sizes of the three largest connected components.

    Raises:
        ValueError: If fewer than three components exist after joining.
    """
    coords = parse_input(input_file)
    heap = to_heap(coords)
    n = len(coords)
    uf = UnionFind(n)
    # Clamp k so an oversized request cannot pop from an empty heap.
    for _ in range(min(k, len(heap))):
        _, i, j = heapq.heappop(heap)
        uf.union(i, j)

    sizes = defaultdict(int)
    for i in range(n):
        sizes[uf.find(i)] += 1
    largest = heapq.nlargest(3, sizes.values())
    if len(largest) < 3:
        raise ValueError(f'need at least 3 components, found {len(largest)}')
    return largest[0] * largest[1] * largest[2]


def part2(input_file):
    """Join closest pairs until all points form one component.

    Pairs are popped in increasing ``cal_dist`` order; pairs already in the
    same component are skipped (Kruskal-style). The function stops at the
    merge that leaves exactly one component.

    Args:
        input_file: Path to the comma-separated coordinate file.

    Returns:
        Product of the first coordinates of the two points joined by that
        final merge.

    Raises:
        ValueError: If the file holds fewer than two points, so no merge can
            ever happen. Without this guard the loop popped from an empty
            heap and failed with an opaque IndexError.
    """
    coords = parse_input(input_file)
    n = len(coords)
    if n < 2:
        raise ValueError(f'need at least 2 points to connect, found {n}')
    heap = to_heap(coords)
    uf = UnionFind(n)
    components = n
    while heap:
        _, i, j = heapq.heappop(heap)
        if uf.union(i, j):
            components -= 1
            if components == 1:
                return coords[i][0] * coords[j][0]
    # Unreachable for n >= 2: every pair is in the heap, so the graph is complete.
    raise RuntimeError('ran out of pairs before all points were connected')
    

In [13]:
# Part 1 on the test input: join the 10 closest pairs, multiply the 3 largest component sizes.
# NOTE(review): this cell's count (In [13]) predates the definitions cell (In [17]);
# re-check the shown output with Restart Kernel -> Run All.
part1('input/day8_test.txt', 10)

40

In [14]:
# Part 1 on the full input: join the 1000 closest pairs, multiply the 3 largest component sizes.
part1('input/day8.txt', 1000)

29406

In [18]:
# Part 2 on the test input: product of the first coordinates of the pair whose merge
# leaves a single component.
part2('input/day8_test.txt')

25272

In [19]:
# Part 2 on the full input: product of the first coordinates of the pair whose merge
# leaves a single component.
part2('input/day8.txt')

7499461416