In [None]:
from google.colab import drive
# Colab-only: mount Google Drive so the puzzle input read in a later cell
# (under /content/drive/MyDrive/...) is reachable. `google.colab` is only
# importable inside a Colab runtime.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install z3-solver

Collecting z3-solver
  Downloading z3_solver-4.8.14.0-py2.py3-none-manylinux1_x86_64.whl (33.0 MB)
[K     |████████████████████████████████| 33.0 MB 269 kB/s 
[?25hInstalling collected packages: z3-solver
Successfully installed z3-solver-4.8.14.0


In [None]:
import time

INPUT_PATH = '/content/drive/MyDrive/AdventOfCode2021/aoc24.txt'

# Tokenise each instruction line, e.g. 'add x 5' -> ['add', 'x', '5'].
# str.split() already discards the trailing newline; blank lines are skipped
# so every parsed instruction has at least an opcode.
with open(INPUT_PATH, 'r') as file:
    data = [line.split() for line in file if line.strip()]

In [None]:
from z3 import *
def arithmetic_logic_unit(operations, is_max):
    """Find the largest or smallest input accepted by an ALU program, via z3.

    The program is executed symbolically. Each ``inp`` reads the next input
    digit, a 64-bit bit-vector constrained to 1..9. Every other instruction
    binds its result to a fresh bit-vector. The input is accepted when
    register ``z`` ends at 0. An ``Optimize`` solver then searches for the
    extreme accepted input, with the objective being the decimal number
    formed by the digits.

    Args:
        operations: tokenised instructions, e.g. ``['inp', 'w']`` or
            ``['add', 'x', '5']``. Empty entries are ignored.
        is_max: truthy to return the largest accepted number, falsy for the
            smallest.

    Returns:
        The accepted digits concatenated as a string, one digit per ``inp``.

    Raises:
        ValueError: on an unknown opcode, or if the solver does not find an
            input that leaves ``z`` at 0.
    """
    solver = z3.Optimize()
    zero, one = z3.BitVecVal(0, 64), z3.BitVecVal(1, 64)

    # One symbolic digit per `inp` instruction (14 for the real puzzle input).
    digit_count = sum(1 for ins in operations if ins and ins[0] == 'inp')
    digits = [z3.BitVec(f'digit_{k}', 64) for k in range(digit_count)]
    for d in digits:
        solver.add(1 <= d, d <= 9)
    next_digit = iter(digits)

    registers = {'w': zero, 'x': zero, 'y': zero, 'z': zero}
    for i, instruction in enumerate(operations):
        if not instruction:
            continue
        op, a = instruction[0], instruction[1]
        if op == 'inp':
            registers[a] = next(next_digit)
            continue

        b = instruction[2]
        b = registers[b] if b in registers else int(b)
        c = z3.BitVec(f'v_{i}', 64)
        if op == 'add':
            solver.add(c == registers[a] + b)
        elif op == 'mul':
            solver.add(c == registers[a] * b)
        elif op == 'div':
            # BitVec `/` is z3's signed division (bvsdiv).
            solver.add(b != 0)
            solver.add(c == registers[a] / b)
        elif op == 'mod':
            # Restricting to a >= 0, b > 0 keeps z3's signed `%` unambiguous.
            solver.add(registers[a] >= 0)
            solver.add(b > 0)
            solver.add(c == registers[a] % b)
        elif op == 'eql':
            solver.add(c == z3.If(registers[a] == b, one, zero))
        else:
            raise ValueError(f'unknown opcode {op!r} in instruction {i}')
        registers[a] = c
    solver.add(registers['z'] == 0)

    # Objective is the model number itself: the last digit has weight 1, the
    # one before it 10, ... (< 10**14 for 14 digits, so no 64-bit overflow).
    # Starting the sum at `zero` keeps it a z3 expression even with no digits.
    number = sum((10 ** k * d for k, d in enumerate(reversed(digits))), zero)
    if is_max:
        solver.maximize(number)
    else:
        solver.minimize(number)

    if solver.check() != z3.sat:
        raise ValueError('no input digits leave register z at 0')
    model = solver.model()
    return ''.join(str(model.eval(d, model_completion=True)) for d in digits)

In [None]:
# time.perf_counter() is monotonic and high-resolution, unlike wall-clock
# time.time(), so it is the right clock for measuring elapsed time.
start_time = time.perf_counter()
print(arithmetic_logic_unit(data, True))  # part 1: largest accepted number
print("--- %s seconds ---" % (time.perf_counter() - start_time))

94992994195998
--- 13.023061037063599 seconds ---


In [None]:
# time.perf_counter() is monotonic and high-resolution, unlike wall-clock
# time.time(), so it is the right clock for measuring elapsed time.
start_time = time.perf_counter()
print(arithmetic_logic_unit(data, False))  # part 2: smallest accepted number
print("--- %s seconds ---" % (time.perf_counter() - start_time))

21191861151161
--- 14.826415538787842 seconds ---
