In [1]:
# Load the raw puzzle text; get_moons later turns each line into one moon.
with open("inputs/day12-input.txt") as input_file:
    raw_text = input_file.read()
    puzzle = raw_text.splitlines()

## Part 1

In [2]:
%%time

import re
from dataclasses import dataclass, field
from functools import reduce
from itertools import permutations
from math import gcd

import numpy as np

@dataclass(eq=False)
class Moon:
    """A moon with a position vector and a velocity vector.

    ``eq=False`` keeps identity-based ``==`` and hashing. The dataclass-generated
    field-wise ``__eq__`` would compare numpy arrays element-wise and raise
    "truth value of an array ... is ambiguous" for two distinct moons. It would
    also make instances unhashable.
    """

    # Current coordinates; presumably a length-3 integer array (get_moons builds it from parsed ints) — confirm for other callers.
    position: np.ndarray
    # Starts at rest as a length-3 integer zero vector; step_time updates it in place.
    velocity: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=int))

    def gravity(self, other):
        """Return the per-axis velocity change that ``other`` induces on this moon.

        Each component is +1, 0 or -1: this moon is pulled one unit toward
        ``other`` on every axis where their positions differ.
        """
        return np.sign(other.position - self.position)

    def energy(self):
        """Return total energy: potential (sum of |position|) times kinetic (sum of |velocity|)."""
        return np.abs(self.position).sum() * np.abs(self.velocity).sum()


def get_moons(puzzle):
    """Parse puzzle lines of the form ``<x=-1, y=0, z=2>`` into Moon objects.

    Parameters
    ----------
    puzzle : iterable of str
        One moon per line. Lines containing no integers (e.g. blank lines)
        are skipped instead of crashing on ``int('')``.

    Returns
    -------
    list of Moon
        Moons in input order, each with the Moon default (zero) velocity.
    """
    moons = []

    for line in puzzle:
        # Extract every signed integer directly rather than stripping "<", ">"
        # and "x=", "y=", "z=" by hand; tolerant of extra whitespace.
        coordinates = [int(token) for token in re.findall(r"-?\d+", line)]
        if not coordinates:
            continue

        moons.append(Moon(np.array(coordinates)))

    return moons


def step_time(moons):
    """Advance every moon by one time step, mutating them in place.

    Every velocity is updated from the current positions before any moon
    moves, so all pulls are computed from the same snapshot of positions.
    """
    # Phase 1: each ordered pair (i, j), i != j, adds j's pull to i's velocity.
    for i, moon in enumerate(moons):
        for j, other in enumerate(moons):
            if i != j:
                moon.velocity += moon.gravity(other)

    # Phase 2: move each moon by its freshly updated velocity.
    for moon in moons:
        moon.position += moon.velocity


def total_energy(moons):
    """Return the energy of the whole system: the sum of every moon's energy."""
    energy = 0
    for moon in moons:
        energy += Moon.energy(moon)
    return energy


PART1_STEPS = 1000  # number of simulation steps part 1 asks for

moons = get_moons(puzzle)
for _ in range(PART1_STEPS):
    step_time(moons)

print(total_energy(moons))

5350
CPU times: user 266 ms, sys: 159 ms, total: 425 ms
Wall time: 186 ms


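## Part 2

Each axis evolves independently of the others. The cell therefore finds the cycle length of the x, y and z sub-systems separately. The full system first repeats at the least common multiple of those three lengths.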

In [3]:
%%time

moons = get_moons(puzzle)

# Per axis (x, y, z): the set of (position, velocity) snapshots seen so far,
# and a flag that save_state sets once that axis revisits a snapshot.
states = [set() for _axis in range(3)]
found = [False] * 3

def save_state(moons, seen_states=None, axis_found=None):
    """Record each unfinished axis's current state and flag axes that repeat.

    The state of one axis is the tuple ``(p0, v0, p1, v1, ...)`` of every
    moon's position and velocity along that axis. For each axis not yet
    flagged, the current tuple is either added to that axis's set or, if it
    is already present, the axis is flagged as found.

    Why ``len(seen_states[axis])`` is the period: a step can be undone. The
    previous position is ``position - velocity``, and gravity depends only on
    positions, so the previous velocity follows too. The first state to repeat
    is therefore the initial one, and the set holds exactly one full cycle.

    Parameters
    ----------
    moons : sequence
        Objects with indexable ``position`` and ``velocity`` attributes.
    seen_states : list of set, optional
        One set per axis, mutated in place. Defaults to the module-level
        ``states``.
    axis_found : list of bool, optional
        One flag per axis, mutated in place. Defaults to the module-level
        ``found``.
    """
    # No ``global`` statement needed: the module-level lists are only
    # mutated, never rebound.
    if seen_states is None:
        seen_states = states
    if axis_found is None:
        axis_found = found

    for axis, axis_states in enumerate(seen_states):
        if axis_found[axis]:
            continue  # this axis's cycle length is already known

        snapshot = tuple(
            value
            for moon in moons
            for value in (moon.position[axis], moon.velocity[axis])
        )

        if snapshot in axis_states:
            axis_found[axis] = True
        else:
            axis_states.add(snapshot)


# Record the initial state, then step until every axis has cycled back.
save_state(moons)
while not all(found):
    step_time(moons)
    save_state(moons)

# Each axis's cycle length is the number of distinct states it visited, and
# the whole system repeats at their least common multiple. Compute it with
# Python ints, which cannot overflow, instead of np.lcm.reduce, whose
# fixed-width integers can wrap silently.
periods = [len(axis_states) for axis_states in states]
print(reduce(lambda a, b: a * b // gcd(a, b), periods, 1))

467034091553512
CPU times: user 9.24 s, sys: 104 ms, total: 9.34 s
Wall time: 9.41 s
