In [32]:
from collections import Counter, deque
from itertools import permutations

from utils import read_lines

# Axis orders for transform(): rotation[k] is the input axis that becomes output axis k.
rotations = list(permutations(range(3)))

# Per-axis sign flips for transform(); x varies slowest, z fastest.
# Together with `rotations` this spans 6 * 8 = 48 combinations, half of them mirror images.
signs = [-1, 1]
facings = [(sx, sy, sz) for sx in signs for sy in signs for sz in signs]

def parse_input(input_file):
    """Read a scanner report file and return one list of beacon tuples per scanner.

    The file holds ``--- scanner N ---`` header lines, each followed by comma-separated
    integer coordinates (one beacon per line), with blank lines between scanners. Either a
    header or a blank line closes the group being collected, and empty groups are dropped,
    so a trailing blank line or doubled blank lines do not create a spurious empty scanner
    (which could never overlap any other scanner).

    Args:
        input_file: path handed to ``read_lines``; presumably it yields one string per line
            (TODO confirm whether line endings are already removed -- lines are stripped
            here either way).

    Returns:
        list[list[tuple[int, ...]]]: beacons per scanner, in file order.
    """
    scanners = []
    current = []
    for raw_line in read_lines(input_file):
        line = raw_line.strip()
        if not line or line.startswith('---'):
            # Scanner boundary: keep the finished group only if it actually has beacons.
            if current:
                scanners.append(current)
            current = []
        else:
            current.append(tuple(int(value) for value in line.split(',')))
    if current:
        scanners.append(current)
    return scanners

def transform(beacons, rotation, facing):
    """Re-orient beacons: output axis k takes input axis ``rotation[k]`` times ``facing[k]``.

    Returns a new list of 3-tuples; the input list is not modified.
    """
    src_x, src_y, src_z = rotation
    sign_x, sign_y, sign_z = facing
    oriented = []
    for point in beacons:
        oriented.append((point[src_x] * sign_x, point[src_y] * sign_y, point[src_z] * sign_z))
    return oriented


def is_overlap(beacons1, beacons2, min_overlap=12):
    """Return scanner 2's offset relative to scanner 1 if enough beacons coincide.

    Both lists must already be in the same orientation. Every pair (p1, p2) votes for the
    translation ``p1 - p2``; a translation with ``min_overlap`` votes shifts that many
    beacons of ``beacons2`` exactly onto beacons of ``beacons1``. This is O(len1 * len2)
    rather than re-comparing every beacon against every pair of anchor points.

    Args:
        beacons1: beacons in the reference scanner's frame.
        beacons2: beacons of the other scanner, already rotated/flipped.
        min_overlap: number of coinciding beacons required (12 in the puzzle).

    Returns:
        ``(dx, dy, dz)`` such that ``b2 + (dx, dy, dz)`` lies in scanner 1's frame, or
        ``None`` when no translation reaches ``min_overlap``.
    """
    # dict.fromkeys drops duplicate reference points (keeping order) so a repeated beacon
    # cannot be counted twice for the same translation.
    reference = list(dict.fromkeys(beacons1))
    votes = Counter()
    for x1, y1, z1 in reference:
        for x2, y2, z2 in beacons2:
            offset = (x1 - x2, y1 - y2, z1 - z2)
            votes[offset] += 1
            if votes[offset] >= min_overlap:
                return offset
    return None

def _permutation_sign(perm):
    """Return +1 for an even permutation of axis indices, -1 for an odd one."""
    inversions = sum(
        1
        for a in range(len(perm))
        for b in range(a + 1, len(perm))
        if perm[a] > perm[b]
    )
    return -1 if inversions % 2 else 1


def is_connect(scaner1, scaner2):
    """Find an orientation of ``scaner2`` under which it overlaps ``scaner1``.

    Tries each axis order in ``rotations`` with each sign pattern in ``facings``. The
    matrix applied by ``transform`` has determinant sign(rotation) * fx * fy * fz; the
    combinations with determinant -1 are reflections rather than rotations, so they are
    skipped. This halves the search and avoids mirror-image false matches.

    Returns:
        ``(rotation, facing, rel)`` for the first orientation where ``is_overlap`` finds a
        match (``rel`` maps transformed ``scaner2`` beacons into ``scaner1``'s frame), or
        ``None`` if no orientation matches.
    """
    for rotation in rotations:
        parity = _permutation_sign(rotation)
        for facing in facings:
            if parity * facing[0] * facing[1] * facing[2] != 1:
                continue  # improper orientation (mirror image)
            beacons2 = transform(scaner2, rotation, facing)
            rel = is_overlap(scaner1, beacons2)
            if rel is not None:
                return rotation, facing, rel
    return None

def find_conntections(scanners):
    """Breadth-first search that links every reachable scanner back to scanner 0.

    Returns a dict mapping scanner index -> ``(parent_index, relative_coord, rotation,
    facing)``, with ``0 -> None`` for the root. Applying ``transform(..., rotation, facing)``
    to a child's beacons and adding ``relative_coord`` expresses them in the parent's frame.
    Scanners that never overlap an already-reached scanner are absent from the result.
    """
    total = len(scanners)
    links = {0: None}
    frontier = deque([0])
    while frontier and len(links) < total:
        parent = frontier.popleft()
        for child in range(total):
            if child in links:
                continue
            match = is_connect(scanners[parent], scanners[child])
            if match:
                rotation, facing, rel = match
                links[child] = (parent, rel, rotation, facing)
                frontier.append(child)
    return links


    
def part1(input_file):
    """Count distinct beacons after mapping every scanner's report into scanner 0's frame.

    Scanner 0's beacons are taken as-is. Every other scanner's beacons are pushed up the
    parent chain from ``find_conntections`` (re-orient, then translate) until they reach
    scanner 0.
    """
    scanners = parse_input(input_file)
    all_beacons = set(scanners[0])
    connections = find_conntections(scanners)
    for idx in range(1, len(scanners)):
        points = scanners[idx]
        node = idx
        while node != 0:
            parent, offset, rotation, facing = connections[node]
            ox, oy, oz = offset[0], offset[1], offset[2]
            oriented = transform(points, rotation=rotation, facing=facing)
            points = [(x + ox, y + oy, z + oz) for x, y, z in oriented]
            node = parent
        all_beacons.update(points)
    return len(all_beacons)

def calc_coords(conn):
    """Position of every scanner expressed in scanner 0's frame.

    ``conn`` maps scanner index -> ``(parent, offset, rotation, facing)``, with ``0 -> None``
    (the shape returned by ``find_conntections``). Each scanner's own origin is walked up its
    parent chain and re-oriented and shifted at every hop, the same way a beacon would be.

    Args:
        conn: connection map whose keys are expected to be exactly ``0 .. len(conn) - 1``;
            a missing index raises ``KeyError``.

    Returns:
        list of ``(x, y, z)`` tuples indexed by scanner number (scanner 0 is the origin).
    """
    n = len(conn)
    coords = [(0, 0, 0)] * n
    for i in range(1, n):
        pos = (0, 0, 0)
        j = i
        while j != 0:
            j, offset, rotation, facing = conn[j]
            # Same mapping as transform(): output axis k <- input axis rotation[k] * facing[k],
            # then shift into the parent's frame.
            pos = tuple(offset[k] + pos[rotation[k]] * facing[k] for k in range(3))
        coords[i] = pos
    return coords

def manhattan_dist(c1, c2):
    """Taxicab (L1) distance over the first three coordinates of ``c1`` and ``c2``."""
    return sum(abs(c1[axis] - c2[axis]) for axis in range(3))

def part2(input_file):
    """Largest Manhattan distance between any two scanners placed in scanner 0's frame.

    Returns 0 when fewer than two scanner positions are available.
    """
    connections = find_conntections(parse_input(input_file))
    positions = calc_coords(connections)
    return max(
        (
            manhattan_dist(positions[a], positions[b])
            for a in range(len(positions) - 1)
            for b in range(a + 1, len(positions))
        ),
        default=0,
    )


In [24]:
# Part 1 on the worked example; the assertion guards against regressions on re-run.
part1_example = part1('inputs/day19_test.txt')
assert part1_example == 79, f'expected 79 on the example input, got {part1_example}'
part1_example

79

In [25]:
part1(input_file='inputs/day19.txt')

398

In [33]:
# Part 2 on the worked example; the assertion guards against regressions on re-run.
part2_example = part2('inputs/day19_test.txt')
assert part2_example == 3621, f'expected 3621 on the example input, got {part2_example}'
part2_example

3621

In [34]:
part2(input_file='inputs/day19.txt')

10965

In [27]:
# part2() returns the final distance, not the connection map, so build the map directly.
conn = find_conntections(parse_input('inputs/day19_test.txt'))
conn

{0: None,
 1: (0, (68, -1246, -43), (0, 1, 2), (-1, 1, -1)),
 3: (1, (160, -1134, -23), (0, 1, 2), (1, 1, 1)),
 4: (1, (88, 113, -1104), (1, 2, 0), (1, -1, -1)),
 2: (4, (168, -1125, 72), (1, 0, 2), (1, 1, -1))}

In [29]:
# Reuse calc_coords instead of a pasted copy of its body, so the two cannot drift apart.
coords = calc_coords(conn)
coords


[(0, 0, 0),
 [68, -1246, -43],
 [1105, -1205, 1229],
 [-92, -2380, -20],
 [-20, -1133, 1061]]