In [2]:
from dataclasses import dataclass, field
from enum import IntEnum

@dataclass
class State:
    """Registers and execution state of the 3-bit computer.

    Only register A must be supplied; B, C and the program counter start at
    zero, and every instance gets its own fresh, empty output list.
    """
    a: int                      # register A
    b: int = field(default=0)   # register B
    c: int = field(default=0)   # register C
    pc: int = field(default=0)  # position in the flat opcode/operand list
    out: list[int] = field(default_factory=list)  # values emitted by `out`

# Expected layout: three register lines, one ignored separator line, then the
# program line; each value follows ': '. Registers stay strings here and are
# converted with int() where they are used.
with open('17.txt') as f:
    # strip() first so trailing blank lines/whitespace don't break the
    # fixed five-way unpacking below.
    A, B, C, _, p = [line.split(': ')[-1] for line in f.read().strip().splitlines()]
program = list(map(int, p.split(',')))


# Part 1: Computation

In [3]:
class Op(IntEnum):
    """Opcodes of the 3-bit computer, as interpreted by ``run``."""
    adv = 0  # A = A >> combo
    bxl = 1  # B = B ^ literal
    bst = 2  # B = combo % 8
    jnz = 3  # if A != 0: pc = literal
    bxc = 4  # B = B ^ C (operand ignored)
    out = 5  # emit combo % 8
    bdv = 6  # B = A >> combo
    cdv = 7  # C = A >> combo

# Opcodes whose operand is a *combo* operand (4-6 read registers A-C);
# all other opcodes use their operand as a literal (or ignore it).
# frozenset: this is a module-level constant and must not be mutated.
READ_COMBO = frozenset({Op.adv, Op.bdv, Op.cdv, Op.bst, Op.out})

def run(
    program: list[tuple[Op, int]],
    state: State,
    desired: list[int] = None,
):
    while state.pc < len(program):
        op = program[state.pc]
        operand = program[state.pc + 1]
        state.pc += 2
        op = Op(op)
        if op in READ_COMBO:
            match operand:
                case 4: operand = state.a
                case 5: operand = state.b
                case 6: operand = state.c
        # print(state, op.name, operand)
        match op:
            case Op.adv:
                state.a = state.a >> operand
            case Op.bxl:
                state.b = state.b ^ operand
            case Op.bst:
                state.b = operand % 8
            case Op.jnz:
                if state.a != 0:
                    state.pc = operand
            case Op.bxc:
                state.b = state.b ^ state.c
            case Op.out:
                state.out.append(operand % 8)
            case Op.bdv:
                state.b = state.a >> operand
            case Op.cdv:
                state.c = state.a >> operand
        if desired is not None:
            if len(state.out) >= len(desired):
                return state
            if not all(a == d for a, d in zip(state.out, desired)):
                return state
    return state

state = run(program, State(a=int(A), b=int(B), c=int(C)))
print(*state.out, sep=',')

5,0,3,5,7,6,1,5,4


# Part 2: Taking a Nibble out of A

Let's figure out how to analyse the code we've been given:

In [4]:
def pprint(program: list[int]):
    for i in range(len(program) // 2):
        op = program[2*i]
        operand = program[2*i+1]
        if op in READ_COMBO:
            match operand:
                case 4: operand = 'A'
                case 5: operand = 'B'
                case 6: operand = 'C'
        match op:
            case Op.adv:
                line = f'A = A >> {operand}'
            case Op.bxl:
                line = f'B = B ^ {operand}'
            case Op.bst:
                line = f'B = {operand} % 8'
            case Op.jnz:
                line = f'if A != 0: PC = {operand}'
            case Op.bxc:
                line = 'B = B ^ C'
            case Op.out:
                line = f'print({operand} % 8)'
            case Op.bdv:
                line = f'B = A >> {operand}'
            case Op.cdv:
                line = f'C = A >> {operand}'
        print(i, line)

pprint(program=program)

0 B = A % 8
1 B = B ^ 1
2 C = A >> B
3 B = B ^ 5
4 A = A >> 3
5 B = B ^ C
6 print(B % 8)
7 if A != 0: PC = 0


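While a bit arcane, if we rewrite it we see a bit more of the underlying methodology. Note that the jump test comes last, so this is a do-while loop: the body always runs at least once, even when A starts at 0.
```
do:
   B = (A % 8) ^ 1
   C = A >> B
   push ((B ^ 5) ^ C) % 8
   A = A // 8
while A != 0
```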

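Because `C = A >> B` pulls in bits from higher up in A, we can't simply undo the XORs one output at a time. However, the loop consumes A three bits at a time (`A // 8`), so this is really a function of the _octal digits_ of A, emitting one output per digit. More telling is that each iteration only looks at the current digit and the digits above it, never at the ones already consumed. That is to say, output _i_ depends ONLY on input digits _i_ and higher (digit 0 being the least significant). So we can fix A's digits from the most significant end downwards, checking one output at a time. First, the same loop written directly in Python: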

In [None]:
def emulate(a: int):
    """Yield the outputs of this puzzle's program for register A = ``a``.

    Hand-translated from the disassembly: each iteration consumes the lowest
    octal digit of ``a`` and emits one value. The machine tests ``jnz`` only
    after the loop body, so the body always runs at least once -- for
    ``a == 0`` it still emits a single value, exactly like ``run``.
    """
    while True:
        b = (a % 8) ^ 1
        c = a >> b
        yield ((b ^ 5) ^ c) % 8
        a //= 8
        # do-while: the jump back happens only while A is non-zero.
        if a == 0:
            return

from random import Random, randint

# Cross-check the hand-translated `emulate` against the real interpreter on a
# batch of random inputs. A fixed seed makes Restart & Run All perform the same
# checks every time. Sampling starts at 1: A == 0 is a special case (the
# machine still emits one value there, because `jnz` is tested after the body).
rng = Random(17)
for a_start in (rng.randint(1, 1_000_000) for _ in range(100)):
    assert list(emulate(a_start)) == run(program, State(a_start)).out, a_start

In [19]:
def solve(a_start: int = 0, j: int | None = None):
    """Yield every value of register A whose output reproduces ``program``.

    Output value j depends only on octal digits j and above of A (each loop
    iteration shifts A right by 3 before the next one). So digits are chosen
    from the most significant one downward: with the higher digits fixed in
    ``a_start``, try all 8 values for digit j and recurse on those whose
    output matches ``program[j]``.

    Args:
        a_start: A with the digits above position ``j`` already chosen.
        j: Octal digit position to choose next. Defaults to the most
            significant one, ``len(program) - 1`` (previously hard-coded as
            15, which only fits a 16-value program).

    Yields:
        Complete candidate values of A.
    """
    if j is None:
        j = len(program) - 1
    if j < 0:
        yield a_start
        return
    for n in range(8):
        a = a_start + n * 8**j
        output = list(emulate(a))
        # One output per octal digit: a zero top digit gives too short an output.
        if len(output) != len(program):
            continue
        if output[j] == program[j]:
            yield from solve(a, j - 1)

min(solve(a_start=0))

164516454365621