# Day 10: Factory

## Part 1

In [147]:
# Part 1 is developed against the puzzle's example input.
with open('example.txt', 'r') as f:
    # splitlines() removes line terminators without discarding the final line
    # when the file lacks a trailing newline (split('\n')[:-1] would drop it).
    data = f.read().splitlines()

In [148]:
# `re` is otherwise only imported in the Part 2 cell further down, so running
# this file top to bottom would raise NameError without a local import here.
import re

# Build machines as [diagram, buttons, joltage] for every non-blank input line.
machines = []
for row in data:
    if not row.strip():
        continue  # tolerate blank lines (e.g. extra trailing newlines)

    # Drop every bracket/paren/brace character, then split on any whitespace:
    # the first token is the light diagram, the last the joltage list, and
    # everything in between is one comma-separated button each.
    tokens = re.sub(r'[\[\](){}]', '', row).split()

    diagram = [1 if ch == '#' else 0 for ch in tokens[0]]
    buttons = [[int(i) for i in button.split(',')] for button in tokens[1:-1]]
    joltage = [int(i) for i in tokens[-1].split(',')]

    machines.append([diagram, buttons, joltage])

In [151]:
def get_potential_buttons(diagram, buttons):
    """Rank the buttons that touch at least one position marked 1 in ``diagram``.

    Args:
        diagram: list of ints; positions equal to 1 are the ones that still
            need toggling.
        buttons: iterable of buttons, each an iterable of light indices.

    Returns:
        dict mapping ``tuple(button)`` to how many of its indices hit a
        position marked 1 (repeated indices counted each time), ordered by that
        count descending. Ties keep the order in which the buttons first
        appeared. Buttons with a count of 0 are left out.
    """
    lit = {pos for pos, value in enumerate(diagram) if value == 1}

    hits = {}
    for button in buttons:
        count = sum(1 for idx in button if idx in lit)
        if count:
            hits[tuple(button)] = count

    # sorted() is stable even with reverse=True, so equal counts keep
    # their insertion order.
    ranked = sorted(hits.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)

def press_button(state, button):
    """Return a new light state with every position listed in ``button`` toggled.

    A listed position's value ``v`` becomes ``(v + 1) % 2``; all other positions
    are copied unchanged. ``state`` itself is not modified.
    """
    return [
        (value + 1) % 2 if pos in button else value
        for pos, value in enumerate(state)
    ]

# Global search state shared by the recursive configure() calls; both are
# reset whenever configure() is entered with depth == 0.
ops = set()  # press counts (depths) at which the target diagram has been reached
prev_states = {}  # tuple(state) -> shallowest depth at which that state was visited
def configure(diagram, state, buttons, depth=0):
    """Search for the fewest button presses that turn ``state`` into ``diagram``.

    Recursive depth-first search. Each level presses one candidate button (via
    press_button, which toggles the listed light indices mod 2) and recurses
    with ``depth + 1``. Depths at which ``state == diagram`` is reached are
    recorded in the global ``ops`` set, which also serves as a pruning bound.

    Args:
        diagram: target light pattern, a list of 0/1 ints.
        state: current light pattern, same length as ``diagram``.
        buttons: list of buttons, each an iterable of light indices it toggles.
        depth: presses made so far. Call with the default 0: that resets the
            global ``ops``/``prev_states`` search state.

    Returns:
        ``depth`` if ``state`` already equals ``diagram`` (and the branch was not
        pruned); otherwise the smallest press count recorded in ``ops`` so far,
        or -1 when none has been found. Pruned branches return -1.
    """
    global ops
    global prev_states

    if(depth == 0):
        # Fresh top-level search: forget results from any previous machine.
        ops = set()
        prev_states = {}
    elif((len(ops) > 0 and depth >= min(ops)) or depth > len(diagram)):
        # Prune: this branch can no longer beat the best solution found, or it
        # is deeper than the number of lights. The latter presumably relies on a
        # minimal solution pressing each button at most once and needing at
        # most rank <= len(diagram) buttons.
        return -1

    # Skip states already reached with fewer presses; otherwise remember this depth.
    if(tuple(state) in prev_states and prev_states.get(tuple(state)) < depth):
        return -1
    else:
        prev_states[tuple(state)] = depth

    if(diagram == state):
        return depth
    
    # Mismatch mask: 1 where the current light differs from the target, else 0.
    needed_operations = []
    for i in range(len(diagram)):
        if(diagram[i] != state[i]):
            needed_operations.append(1)
        else:
            needed_operations.append(0)

    # Only buttons that toggle at least one mismatched light, most hits first.
    potential_buttons = get_potential_buttons(needed_operations, buttons)

    if(len(potential_buttons) == 0 and state != diagram):
        return -1
    elif(state == diagram):
        # NOTE(review): unreachable — equality already returned above.
        return depth
    
    # Until a first solution is found, try the buttons that hit the fewest
    # mismatched lights first (ascending order). The heuristic's rationale is
    # not documented here.
    if(ops == set() and len(potential_buttons) > 1):
        potential_buttons = {k:v for k, v in sorted(potential_buttons.items(), key=lambda item: item[1], reverse=False)}

    
    for button in potential_buttons:
        new_state = press_button(state, button)
        solved_depth = configure(diagram, new_state, buttons, depth+1)

        # Record any positive depth the recursive call reports (it may be a new
        # solution depth or an already-recorded min(ops)).
        if(solved_depth > 0 and solved_depth not in ops):
            ops.add(solved_depth)
    
    return -1 if ops == set() else min(ops)

# Part 1 answer: sum of the minimal press counts over every machine.
# Joltage values "can be safely ignored" in this part.
total = 0
for index, machine in enumerate(machines):
    diagram = machine[0]
    buttons = machine[1]

    cur_state = [0] * len(diagram)  # every light starts off
    presses = configure(diagram, cur_state, buttons)
    if presses < 0:
        # configure() reports "no solution" as -1; adding that to the total
        # would silently produce a wrong answer.
        raise ValueError(f'machine {index} cannot be configured with its buttons')
    total += presses

# print(total)

## Part 2

In [146]:
import z3
import re

def create_z3_vars(buttons):
    """Create one z3 integer variable per button and group them by counter index.

    Only the ``buttons`` argument is used (no global state), so the function
    works for any caller, not just the module-level loop.

    Args:
        buttons: list of buttons, each a list of counter (joltage) indices it
            affects.

    Returns:
        ``(z3_buttons, z3_indices)``. ``z3_buttons[k]`` is the variable named
        ``b{k}`` for ``buttons[k]``. ``z3_indices`` maps each counter index to
        the list of button variables that affect it, in button order.
    """
    z3_buttons = []
    z3_indices = {}  # {joltage_index: [button_x, button_y, button_z]}
    for index, button in enumerate(buttons):
        z3_button = z3.Int(f'b{index}')
        z3_buttons.append(z3_button)

        for i in button:
            z3_indices.setdefault(i, []).append(z3_button)
    return z3_buttons, z3_indices

def create_z3_conditions(z3_buttons, z3_indices, joltage):
    """Build the solver constraints for one machine.

    Every joltage counter gets an equality constraint: the sum of the button
    variables affecting it must equal its target. A counter that no button
    touches therefore gets ``0 == target``, which makes a non-zero target
    unsatisfiable instead of silently unconstrained. Every button variable
    must also be non-negative.

    Args:
        z3_buttons: button press-count variables.
        z3_indices: mapping from counter index to the variables affecting it.
        joltage: target value for each counter index.

    Returns:
        A list of constraint lists: one single-element list per joltage counter
        (in index order), followed by one list of ``button >= 0`` constraints.

    Raises:
        IndexError: if ``z3_indices`` references a counter outside ``joltage``.
    """
    for index in z3_indices:
        if not 0 <= index < len(joltage):
            raise IndexError(f'button affects counter {index}, but only {len(joltage)} counters exist')

    conditions = []
    for index, target in enumerate(joltage):
        # sum() of an empty list is 0, so untouched counters must have target 0.
        affecting = z3_indices.get(index, [])
        conditions.append([sum(affecting) == target])

    # buttons cannot be pressed a negative number of times
    conditions.append([button >= 0 for button in z3_buttons])

    return conditions


with open('input.txt', 'r') as f:
    # splitlines() removes line terminators without discarding the final line
    # when the file lacks a trailing newline (split('\n')[:-1] would drop it).
    data = f.read().splitlines()

# Build machines as [diagram, buttons, joltage] for every non-blank input line.
machines = []
for row in data:
    if not row.strip():
        continue  # tolerate blank lines (e.g. extra trailing newlines)

    # Drop every bracket/paren/brace character, then split on any whitespace:
    # the first token is the light diagram, the last the joltage list, and
    # everything in between is one comma-separated button each.
    tokens = re.sub(r'[\[\](){}]', '', row).split()

    diagram = [1 if ch == '#' else 0 for ch in tokens[0]]
    buttons = [[int(i) for i in button.split(',')] for button in tokens[1:-1]]
    joltage = [int(i) for i in tokens[-1].split(',')]

    machines.append([diagram, buttons, joltage])

# Part 2 answer: for each machine, minimize the total number of presses whose
# per-counter sums hit the joltage targets, then add up the minima.
total = 0
for machine in machines:
    buttons = machine[1]
    joltage = machine[2]

    z3_buttons, z3_indices = create_z3_vars(buttons)
    z3_conditions = create_z3_conditions(z3_buttons, z3_indices, joltage)

    o = z3.Optimize()
    for condition in z3_conditions:
        o.add(condition)

    o.minimize(sum(z3_buttons))

    # Check the solver result explicitly: on unsat/unknown there is no usable
    # model, and o.model() would only fail later with a less helpful error.
    result = o.check()
    if result != z3.sat:
        raise ValueError(f'no press counts satisfy machine {machine}: solver returned {result}')

    model = o.model()

    # model_completion=True guarantees a concrete value for every variable;
    # as_long() converts the z3 integer directly instead of via a string.
    model_total = sum(
        model.evaluate(button, model_completion=True).as_long()
        for button in z3_buttons
    )
    total += model_total

# print(total)