In [1]:
from dataclasses import dataclass
from itertools import combinations

In [2]:
# Raw puzzle input: one "x,y,z" integer triple per line (parsed into Points below).
# Read inside a `with` block so the file handle is closed deterministically.
with open("input/08") as input_file:
    data = input_file.read().splitlines()

In [3]:
@dataclass()
class Point:
    """One 3-D integer point from the input, tagged with its line index."""

    id: int  # 0-based input line number; used as the point's key in `res`
    x: int
    y: int
    z: int

In [4]:
# One Point per input line; its id is the line's 0-based position.
points = [
    Point(line_no, *(int(coord) for coord in line.split(",")))
    for line_no, line in enumerate(data)
]

In [5]:
def dist(a, b):
    """Return the squared Euclidean distance between points `a` and `b`.

    The square root is deliberately skipped: pairs are only ever compared
    against each other, and squaring preserves their ordering.
    """
    dx = a.x - b.x
    dy = a.y - b.y
    dz = a.z - b.z
    return dx * dx + dy * dy + dz * dz

In [6]:
# Every unordered pair of points as [squared distance, id_a, id_b].
res = [[dist(a, b), a.id, b.id] for a, b in combinations(points, 2)]

In [7]:
# Sort pairs closest-first (lists compare element-wise, so equal distances
# fall back to comparing the ids).
res.sort()

In [8]:
def merge_clusters(clusters):
    """Collapse overlapping sets until every pair of sets is disjoint.

    Each sweep folds a set into the first already-kept set it shares an
    element with, or keeps a copy of it otherwise. Sweeps repeat until one
    folds nothing, and that sweep's output is returned. The input sets are
    copied, never modified.

    Args:
        clusters: iterable of sets (e.g. sets of point ids).

    Returns:
        A new list of pairwise-disjoint sets with the same overall union.
    """
    while True:
        merged_any = False
        kept = []
        for group in clusters:
            target = next((k for k in kept if k & group), None)
            if target is None:
                kept.append(set(group))
            else:
                target |= group
                merged_any = True
        clusters = kept
        if not merged_any:
            return clusters

# Part 1: product of the three largest clusters after linking the 1000 closest pairs

In [9]:
# Attach each of the 1000 closest pairs to the first cluster that already holds
# either endpoint, or start a new cluster. A pair bridging two clusters only
# lands in one of them; merge_clusters reconciles that in the next cell.
clusters_part1 = []
for _, id1, id2 in res[:1000]:
    home = next(
        (group for group in clusters_part1 if id1 in group or id2 in group),
        None,
    )
    if home is None:
        clusters_part1.append({id1, id2})
    else:
        home |= {id1, id2}

In [10]:
# Pairs above were attached to only one of the clusters they touch; collapse
# overlapping sets so every cluster is a single disjoint set of point ids.
clusters_part1 = merge_clusters(clusters=clusters_part1)

In [11]:
# Number of points in each Part 1 cluster.
sizes = [len(group) for group in clusters_part1]

In [12]:
# Ascending order, so the largest clusters are the last entries.
sizes.sort()

In [13]:
# Product of the three largest cluster sizes (`sizes` is ascending).
# Slicing instead of indexing [-3], [-2], [-1] avoids an IndexError when fewer
# than three clusters formed. Points never paired are absent from `sizes`; as
# size-1 groups they would not change the product anyway.
part1 = 1
for size in sizes[-3:]:
    part1 *= size
print(f"Answer #1: {part1}")

Answer #1: 72150


# Part 2: product of the X coordinates of the pair that first links every point into one cluster

In [14]:
# Map each squared distance to its pair. `res` is sorted ascending and dicts
# keep insertion order, so iterating closest_map visits pairs closest-first.
# A comprehension keeps its loop variables local, so the `dist()` function is
# not overwritten (a plain `for dist, ...` loop would break re-running the
# pair-distance cell).
closest_map = {pair_dist: (id1, id2) for pair_dist, id1, id2 in res}
# Keying by distance silently keeps only the last pair when two pairs tie, so
# enforce the no-ties assumption instead of relying on a one-off manual check.
assert len(closest_map) == len(res), "two pairs share a distance; closest_map would drop one"

In [15]:
# Link pairs closest-first (closest_map follows the sorted order of `res`)
# with a union-find, stopping at the pair that first joins every point into a
# single cluster. Each pair costs near-constant time instead of rescanning all
# clusters after every pair. Point ids are their 0-based positions in
# `points`, so they index `parent` directly.
parent = list(range(len(points)))


def find_root(node):
    """Return the representative id of `node`'s cluster, halving paths as it walks."""
    while parent[node] != node:
        parent[node] = parent[parent[node]]
        node = parent[node]
    return node


n_clusters = len(points)  # every point starts as its own cluster
most_recent = None  # stays None only if there are no pairs to link
for id1, id2 in closest_map.values():
    root1, root2 = find_root(id1), find_root(id2)
    if root1 == root2:
        continue  # already in the same cluster; nothing to merge
    parent[root1] = root2
    n_clusters -= 1
    if n_clusters == 1:
        # This pair is the one that made all points a single cluster.
        most_recent = (id1, id2)
        break

In [16]:
# Part 2 answer: multiply the X coordinates of the two points in `most_recent`.
part2 = 1
for x_coord in (p.x for p in points if p.id in most_recent):
    part2 *= x_coord

In [17]:
# Report the Part 2 result.
print(f"Answer #2: {part2}")

Answer #2: 3926518899
