In [280]:
import re
from collections import defaultdict
from dataclasses import dataclass, replace

@dataclass(frozen=True, order=True)
class Connection:
    """One logic gate: ``output = left <op> right``.

    Frozen (with the default ``eq=True``) so instances are hashable and can be
    stored in sets and used as dict keys; ``order=True`` makes them sortable by
    the field tuple (left, right, output, op).
    """

    left: str    # first operand wire name
    right: str   # second operand wire name
    output: str  # wire driven by this gate
    op: str      # operator name; run_wires evaluates 'AND', 'OR' and 'XOR'


def run_wires(connections, gates):
    gates = gates.copy()
    connections_by_operand = defaultdict(list)
    for connection in connections:
        connections_by_operand[connection.left].append(connection)
        connections_by_operand[connection.right].append(connection)

    in_degree = defaultdict(int)
    for connection in connections:
        for next_connection in connections_by_operand[connection.output]:
            in_degree[next_connection] += 1

    zero_degrees = [
        connection for connection in connections if in_degree.get(connection, 0) == 0
    ]

    while len(zero_degrees) > 0:
        connection = zero_degrees.pop()
        match connection.op:
            case 'AND':
                gates[connection.output] = (
                    gates[connection.left] & gates[connection.right]
                )
            case 'OR':
                gates[connection.output] = (
                    gates[connection.left] | gates[connection.right]
                )
            case 'XOR':
                gates[connection.output] = (
                    gates[connection.left] ^ gates[connection.right]
                )

        for next_connection in connections_by_operand[connection.output]:
            in_degree[next_connection] -= 1
            if in_degree[next_connection] == 0:
                zero_degrees.append(next_connection)
    return gates

def parse_input(text, return_gates=False):
    """Parse the puzzle text into gate connections (and optionally initial wire values).

    The text has two sections separated by a blank line: ``wire: bit`` lines
    giving initial wire values, then ``left OP right -> output`` gate lines.

    Args:
        text: the full puzzle input.
        return_gates: when True, also return the initial wire values (which
            were previously parsed and discarded). Defaults to False, so
            existing callers keep receiving only the connections.

    Returns:
        ``set[Connection]``, or ``(set[Connection], dict[str, int])`` when
        ``return_gates`` is True.

    Raises:
        ValueError: if the blank-line separator is missing or a non-empty line
            does not match its section's format.
    """
    # Normalise Windows line endings and drop leading/trailing blank lines so
    # the section separator is found reliably.
    text = text.replace('\r\n', '\n').strip()
    value_section, connection_section = text.split('\n\n', 1)

    gates = {}
    for line in value_section.splitlines():
        if not line.strip():
            continue
        m = re.search(r'(.+): (\d)', line)
        if m is None:
            raise ValueError(f'unparseable wire value line: {line!r}')
        wire, value = m.groups()
        gates[wire] = int(value)

    connections = set()
    for line in connection_section.splitlines():
        if not line.strip():
            continue
        m = re.search(r'(\S+) (\S+) (\S+) -> (\S+)', line)
        if m is None:
            raise ValueError(f'unparseable connection line: {line!r}')
        left, op, right, output = m.groups()
        # Canonical operand order, so 'a AND b' and 'b AND a' yield equal Connections.
        left, right = sorted([left, right])
        connections.add(Connection(left, right, output, op))

    if return_gates:
        return connections, gates
    return connections


def get_subset(z_gates, connections):
    """Collect every connection that the given wires transitively depend on.

    Args:
        z_gates: iterable of wire names (typically ``zNN`` outputs) to trace
            back from.
        connections: iterable of gate objects exposing ``left``, ``right`` and
            ``output``.

    Returns:
        set of the connections driving ``z_gates`` plus every connection
        upstream of them. Raises KeyError if a requested wire is not driven by
        any connection.
    """
    # Wire name -> the connection driving it (later entries win, as in a loop).
    driver = {con.output: con for con in connections}

    frontier = [driver[wire] for wire in z_gates]
    seen = set(frontier)
    while frontier:
        current = frontier.pop()
        for operand in (current.left, current.right):
            upstream = driver.get(operand)
            # Initial inputs (x/y wires) have no driver and end the walk.
            if upstream is not None and upstream not in seen:
                seen.add(upstream)
                frontier.append(upstream)
    return seen


In [282]:
# Load the raw puzzle input (path is relative to the notebook's working directory).
with open('../input/24.txt') as input_file:
    text = input_file.read()


In [353]:

def swap(output1, output2, connections):
    """Return a copy of ``connections`` with the output wires of two gates exchanged.

    The gate that drove ``output1`` now drives ``output2`` and vice versa; each
    gate keeps its operands and operator. The input set is not modified.

    Args:
        output1: output wire of the first gate.
        output2: output wire of the second gate.
        connections: set of Connection (dataclass) objects.

    Returns:
        A new set of Connection objects.

    Raises:
        ValueError: if no connection drives ``output1`` or ``output2``.
    """
    def driver_of(output):
        # First connection (in iteration order) driving this wire, like the
        # original next(...) lookup, but with a readable error instead of a
        # bare StopIteration.
        for con in connections:
            if con.output == output:
                return con
        raise ValueError(f'no connection drives wire {output!r}')

    con1 = driver_of(output1)
    con2 = driver_of(output2)
    swapped = {replace(con1, output=con2.output), replace(con2, output=con1.output)}
    return (connections - {con1, con2}) | swapped
    

In [379]:
# Output-wire pairs whose swap repairs the adder (checked by test 1 below).
SWAPS = [
    ('qjj', 'gjc'),
    ('z17', 'wmp'),
    ('z26', 'gvm'),
    ('z39', 'qsb'),
]

connections = parse_input(text)
for output1, output2 in SWAPS:
    connections = swap(output1, output2, connections)

In [None]:
# test 1: single-bit adder check.
# For each input bit k, set x_k = a and y_k = b with every other input bit 0,
# then confirm z_k == a ^ b (sum) and z_{k+1} == a & b (carry out).
# Inputs are rebuilt per case so no computed outputs leak between runs.
n_bits = len({
    wire
    for connection in connections
    for wire in (connection.left, connection.right)
    if wire[0] == 'x' and wire[1:].isdigit()
})
driven = {connection.output for connection in connections}

for k in range(n_bits):
    sum_wire, carry_wire = f'z{k:02d}', f'z{k + 1:02d}'
    z_gates = [sum_wire] + ([carry_wire] if carry_wire in driven else [])
    subset = get_subset(z_gates, connections)
    for a in range(2):
        for b in range(2):
            inputs = {f'{prefix}{i:02d}': 0 for prefix in 'xy' for i in range(n_bits)}
            inputs[f'x{k:02d}'] = a
            inputs[f'y{k:02d}'] = b
            outputs = run_wires(subset, inputs)
            assert outputs[sum_wire] == a ^ b, (sum_wire, a, b)
            if carry_wire in driven:
                assert outputs[carry_wire] == a & b, (carry_wire, a, b)

In [378]:
# Puzzle answer: the eight swapped wire names, sorted and comma-joined.
','.join(sorted('qjj gjc z17 wmp z26 gvm z39 qsb'.split()))

'gjc,gvm,qjj,qsb,wmp,z17,z26,z39'

In [371]:

# test 2: carry propagation through the low n bits.
# x = 1 and y = 2**n - 1, so x + y = 2**n: every sum bit z00..z{n-1} should be 0.
# Inputs are built fresh here instead of reusing `values` from another cell.
n = 30
values = {}
for i in range(n):
    values[f'x{i:02d}'] = 0
    values[f'y{i:02d}'] = 1
values['x00'] = 1
z_gates = [f'z{i:02d}' for i in range(n)]
result = run_wires(get_subset(z_gates, connections), values)
{
    name: value
    for name, value in sorted(result.items())
    if name.startswith('z')
}

{'z00': 0,
 'z01': 0,
 'z02': 0,
 'z03': 0,
 'z04': 0,
 'z05': 0,
 'z06': 0,
 'z07': 0,
 'z08': 0,
 'z09': 0,
 'z10': 0,
 'z11': 0,
 'z12': 0,
 'z13': 0,
 'z14': 0,
 'z15': 0,
 'z16': 0,
 'z17': 0,
 'z18': 0,
 'z19': 0,
 'z20': 0,
 'z21': 0,
 'z22': 0,
 'z23': 0,
 'z24': 0,
 'z25': 0,
 'z26': 1,
 'z27': 1,
 'z28': 1,
 'z29': 1}

In [292]:
def to_graphviz(connections, path='graphviz.txt'):
    """Write the gate network to ``path`` as a standalone Graphviz DOT digraph.

    Each gate output becomes a node labelled "<op> <output>", and each operand
    gets an edge into the output it feeds. Connections are sorted so the file
    is identical across runs (string hashing, and hence set iteration order,
    varies between interpreter sessions).

    Args:
        connections: iterable of gate objects exposing ``left``, ``right``,
            ``output`` and ``op``.
        path: destination file; defaults to ``'graphviz.txt'`` as before.
    """
    ordered = sorted(connections, key=lambda c: (c.output, c.left, c.right, c.op))

    lines = ['digraph {']
    lines.extend(f'    {c.output} [label="{c.op} {c.output}"];' for c in ordered)
    for c in ordered:
        lines.append(f'    {c.left} -> {c.output};')
        lines.append(f'    {c.right} -> {c.output};')
    lines.append('}')

    with open(path, 'w') as f:
        f.write('\n'.join(lines) + '\n')

In [343]:
# Export the current (swapped) circuit for visual inspection with Graphviz.
to_graphviz(connections=connections)