In [1]:
# Load the raw puzzle input: a list of lines, one moon per line.
with open("inputs/day12-input.txt") as handle:
    puzzle = handle.read().splitlines()

## Part 1

In [2]:
%%time

import math
import re
from dataclasses import dataclass, field
from functools import reduce
from itertools import permutations

import numpy as np

@dataclass(eq=False)  # the generated __eq__ would compare ndarrays and raise ValueError
class Moon:
    """A moon with a 3-D position and velocity, each stored as a NumPy vector.

    Equality is identity-based (``eq=False``): a value-based dataclass ``__eq__``
    cannot be used on ndarray fields because ``array == array`` is elementwise.
    """

    position: np.ndarray
    # Starts at rest; default_factory gives every instance its own array.
    velocity: np.ndarray = field(default_factory=lambda: np.zeros(3, dtype=int))

    def gravity(self, other):
        """Return the velocity change ``other`` induces on this moon.

        Each axis changes by +1, 0 or -1 toward ``other``'s coordinate on that axis.
        """
        return np.sign(other.position - self.position)

    def energy(self):
        """Return potential energy (sum of |position|) times kinetic energy (sum of |velocity|)."""
        return np.abs(self.position).sum() * np.abs(self.velocity).sum()


def get_moons(puzzle):
    """Parse puzzle lines such as ``<x=-1, y=0, z=2>`` into ``Moon`` objects.

    Integers are taken in the order they appear on the line (x, y, z). Blank
    lines are skipped, so a trailing empty line in the input does not break parsing.

    Raises:
        ValueError: if a non-blank line does not contain exactly three integers.
    """
    moons = []

    for line in puzzle:
        if not line.strip():
            continue
        positions = [int(token) for token in re.findall(r"-?\d+", line)]
        if len(positions) != 3:
            raise ValueError(f"expected three coordinates, got {line!r}")

        moons.append(Moon(np.array(positions)))

    return moons


def step_time(moons):
    """Advance every moon by one time step, mutating them in place.

    Phase 1 applies gravity for every ordered pair of distinct moons. Positions
    are only read here, so the visiting order does not matter. Phase 2 moves
    each moon by its freshly updated velocity.
    """
    for i, moon in enumerate(moons):
        for j, other in enumerate(moons):
            if i != j:
                moon.velocity += moon.gravity(other)

    for moon in moons:
        moon.position += moon.velocity


def total_energy(moons):
    """Return the system's total energy: the sum of ``Moon.energy`` over all moons."""
    return sum(Moon.energy(moon) for moon in moons)


# Number of simulation steps after which Part 1 asks for the total energy.
PART1_STEPS = 1000

moons = get_moons(puzzle)
for _ in range(PART1_STEPS):
    step_time(moons)

print(total_energy(moons))

5350
CPU times: user 353 ms, sys: 82.6 ms, total: 436 ms
Wall time: 568 ms


## Part 2

In [3]:
%%time

moons = get_moons(puzzle)

# Per-axis bookkeeping read by save_state: the snapshots seen so far on each
# axis, and whether that axis has already repeated one.
states = [set() for _axis in range(3)]
found = [False] * 3

def save_state(moons):
    """Snapshot every axis that has not cycled yet, marking it found on a repeat.

    An axis snapshot is a flat tuple of (position, velocity) pairs, one pair per
    moon, in moon order. Reads and mutates the module-level ``states`` and
    ``found`` lists.
    """
    for axis in range(3):
        if found[axis]:
            continue

        snapshot = tuple(
            component
            for moon in moons
            for component in (moon.position[axis], moon.velocity[axis])
        )

        if snapshot in states[axis]:
            found[axis] = True
        else:
            states[axis].add(snapshot)


# Axes evolve independently (gravity is computed per axis), so each axis has its
# own period. The step is reversible: previous position = position - velocity,
# and previous velocity = velocity - gravity(previous positions). So each axis's
# orbit is purely periodic, and its first repeated snapshot is the starting one.
# Hence the number of distinct snapshots recorded per axis equals that axis's period.
save_state(moons)
while not all(found):
    step_time(moons)
    save_state(moons)

periods = [len(seen) for seen in states]
# Exact LCM on Python ints: np.lcm.reduce works in int64 and would wrap silently on overflow.
print(reduce(lambda a, b: a * b // math.gcd(a, b), periods, 1))

467034091553512
CPU times: user 8.17 s, sys: 123 ms, total: 8.3 s
Wall time: 8.34 s


## Part 2, vectorized with NumPy

In [4]:
%%time

def get_moons_matrix(puzzle):
    """Parse puzzle lines into a position matrix with one row per moon.

    Reuses ``get_moons`` so both solutions share a single parser instead of
    two copies that could drift apart.
    """
    return np.array([moon.position for moon in get_moons(puzzle)])


# Fresh per-axis bookkeeping for the vectorized run; rebinds the names that
# save_state_vectors reads.
states = [set() for _axis in range(3)]
found = [False] * 3

def save_state_vectors(position, velocity):
    """Record each unresolved axis's snapshot; flag the axis once a snapshot repeats.

    ``position`` and ``velocity`` are 2-D arrays indexed ``[moon, axis]``.
    Reads and mutates the module-level ``states`` and ``found`` lists.
    """
    for xyz in range(3):
        if found[xyz]:
            continue

        # Hashable snapshot: every moon's position on this axis, then every velocity.
        state = tuple(position[:, xyz]) + tuple(velocity[:, xyz])

        if state in states[xyz]:
            found[xyz] = True
        else:
            states[xyz].add(state)


p = get_moons_matrix(puzzle)
v = np.zeros_like(p)  # same shape as p, so any number of moons works

save_state_vectors(p, v)

while not all(found):
    # diffs[i, j] == p[j] - p[i]; summing the signs over j gives moon i's total
    # gravity pull. This replaces the hard-coded p[0]..p[3] sum and works for n moons.
    diffs = p[np.newaxis, :, :] - p[:, np.newaxis, :]
    v += np.sign(diffs).sum(axis=1)
    p += v

    save_state_vectors(p, v)

periods = [len(seen) for seen in states]
# Exact LCM on Python ints: np.lcm.reduce works in int64 and would wrap silently on overflow.
print(reduce(lambda a, b: a * b // math.gcd(a, b), periods, 1))

467034091553512
CPU times: user 5.3 s, sys: 61.7 ms, total: 5.37 s
Wall time: 5.39 s
