In [1]:
from collections import defaultdict
from itertools import combinations

def parse_input(input_file):
    """Read a day-24 puzzle file.

    The file has two sections separated by a blank line: ``name: value``
    initial wire assignments, followed by ``a OP b -> c`` gate lines.

    Returns:
        (values, gates): ``values`` maps wire name -> int; ``gates`` is a list
        of ``[a, op, b, c]`` lists in file order.
    """
    values = {}
    gates = []
    in_gate_section = False
    with open(input_file) as handle:
        for raw in handle:
            text = raw.rstrip()
            if text == '':
                # The blank separator line switches us to the gate section.
                in_gate_section = True
                continue
            if in_gate_section:
                lhs, operator, rhs, _arrow, out = text.split(' ')
                gates.append([lhs, operator, rhs, out])
            else:
                wire, bit = text.split(': ')
                values[wire] = int(bit)
    return values, gates


def calc(a, op, b):
    match op:
        case 'XOR':
            return a ^ b
        case 'AND':
            return a & b
        case 'OR':
            return a | b
        
def part1(input_file):
    """Simulate the gate network and return the number encoded on the z wires.

    Gates are evaluated in repeated passes: each pass fires every gate whose
    two input wires are already known and defers the rest to the next pass.

    Args:
        input_file: path to the puzzle input (see ``parse_input``).

    Returns:
        int: the sum of ``1 << k`` over every wire ``zK`` whose final value is
        truthy.

    Raises:
        ValueError: if some gates can never fire (an input wire that nothing
            drives, or a cycle), instead of looping forever.
    """
    values, gates = parse_input(input_file)
    unsolved = gates
    while unsolved:
        remain = []
        for a, op, b, c in unsolved:
            if a in values and b in values:
                values[c] = calc(values[a], op, values[b])
            else:
                remain.append([a, op, b, c])
        # If nothing fired in this pass, nothing ever will: stop rather than spin.
        if len(remain) == len(unsolved):
            stuck = sorted(c for _, _, _, c in remain)
            raise ValueError(
                f'unresolvable gates (undriven input or cycle), outputs: {stuck}'
            )
        unsolved = remain

    ans = 0
    for wire, bit in values.items():
        if wire.startswith('z') and bit:
            ans |= 1 << int(wire[1:])
    return ans


def to_node(prefix, i):
    """Build a wire name such as ``x07`` or ``z42`` from a prefix and bit index.

    Indices below 10 get one leading zero, so for 0 <= i <= 99 the result is
    always ``prefix`` followed by two digits.
    """
    pad = '0' if i < 10 else ''
    return f'{prefix}{pad}{i}'
    
def check(gates, n_bits=45):
    """Compute, for every wire, which x/y input bits feed into it.

    Not called elsewhere in the visible notebook; it looks like a debugging
    helper for inspecting the circuit.

    Args:
        gates: list of ``[a, op, b, c]`` gates, as returned by ``parse_input``.
        n_bits: number of ``xNN``/``yNN`` input wires to seed (default 45).

    Returns:
        dict mapping wire name -> ``(x_mask, y_mask)``. Bit ``k`` of ``x_mask``
        (or ``y_mask``) is set when input ``xK`` (or ``yK``) is in the wire's
        transitive fan-in. The gate operator is ignored.

    Raises:
        ValueError: if some gates can never be resolved (an input wire outside
            the seeded range that no gate drives, or a cycle), instead of
            looping forever.
    """
    deps = {}
    for i in range(n_bits):
        deps[to_node('x', i)] = (1 << i, 0)
        deps[to_node('y', i)] = (0, 1 << i)

    unsolved = gates
    while unsolved:
        remain = []
        for a, op, b, c in unsolved:
            if a in deps and b in deps:
                deps[c] = (deps[a][0] | deps[b][0], deps[a][1] | deps[b][1])
            else:
                remain.append([a, op, b, c])
        # No progress in a whole pass means the rest can never resolve.
        if len(remain) == len(unsolved):
            stuck = sorted(c for _, _, _, c in remain)
            raise ValueError(
                f'unresolvable gates (undriven input or cycle), outputs: {stuck}'
            )
        unsolved = remain
    return deps

def furthest_made(gates, n_bits=45):
    """Walk the circuit as a ripple-carry adder and report how far it is correct.

    For each bit ``i`` the expected wiring is::

        predigit  = x_i XOR y_i
        precarry1 = x_i AND y_i
        z_i       = carry_{i-1} XOR predigit     (bit 0: z00 = predigit)
        precarry2 = carry_{i-1} AND predigit
        carry_i   = precarry1 OR precarry2       (bit 0: carry_0 = precarry1)

    The carry out of the top bit must be the extra output wire ``z{n_bits}``.
    Gates are looked up by (unordered operand pair, operator), so operand order
    in the input does not matter.

    Args:
        gates: list of ``[a, op, b, c]`` gates.
        n_bits: width of the x/y inputs (default 45).

    Returns:
        (last_ok, correct):
            ``last_ok`` is the highest index ``k`` such that ``z00`` .. ``zk``
            are all verified. It is ``-1`` if even ``z00`` is wrong, and
            ``n_bits`` when the final carry ``z{n_bits}`` also checks out.
            ``correct`` is the set of wires proven correct by the verified bits.
    """
    ops = {}
    for a, op, b, c in gates:
        ops[(frozenset([a, b]), op)] = c

    def get_res(a, b, op):
        # None when no gate combines these two wires with this operator.
        return ops.get((frozenset([a, b]), op))

    carries = {}
    correct = set()
    prev_intermediates = set()
    for i in range(n_bits):
        x, y, z = to_node('x', i), to_node('y', i), to_node('z', i)
        predigit = get_res(x, y, 'XOR')
        precarry1 = get_res(x, y, 'AND')
        if i == 0:
            # Half adder: the XOR must drive z00 directly. Report failure
            # rather than asserting, so callers running a search never crash.
            if predigit != z:
                return -1, correct
            carries[i] = precarry1
            continue
        digit = get_res(carries[i - 1], predigit, 'XOR')
        if digit != z:
            return i - 1, correct
        # z_i is right, so its XOR operands are right, and so are the previous
        # bit's AND outputs that were combined into carry_{i-1}.
        correct.add(carries[i - 1])
        correct.add(predigit)
        correct.update(prev_intermediates)

        precarry2 = get_res(carries[i - 1], predigit, 'AND')
        carries[i] = get_res(precarry2, precarry1, 'OR')
        prev_intermediates = {precarry1, precarry2}

    if n_bits > 0:
        # The top carry is the extra output bit. A miswired final carry
        # chain means the adder is not complete.
        if carries[n_bits - 1] != to_node('z', n_bits):
            return n_bits - 1, correct
        correct.add(carries[n_bits - 1])
        correct.update(prev_intermediates)
    return n_bits, correct

def part2(input_file, n_swaps=4):
    """Find the swapped gate outputs that break the adder and return the answer.

    Greedy search: start from how far ``furthest_made`` can verify the
    unmodified circuit. Then try exchanging the output wires of every pair of
    gates, and keep the first exchange that lets it verify more bits. This
    repeats up to ``n_swaps`` times. Wires already proven correct, and ``z00``,
    are never exchanged.

    Args:
        input_file: path to the puzzle input.
        n_swaps: maximum number of output-wire exchanges to make (default 4).

    Returns:
        str: every exchanged wire name, sorted and joined with commas.

    Side effects:
        Prints a line each time an improving swap is found.
    """
    _, gates = parse_input(input_file)  # initial wire values are not needed here
    base, base_used = furthest_made(gates)
    swaps = []
    for _ in range(n_swaps):
        for i, j in combinations(range(len(gates)), 2):
            a_i, op_i, b_i, c_i = gates[i]
            a_j, op_j, b_j, c_j = gates[j]
            if 'z00' in (c_i, c_j):
                continue
            if c_i in base_used or c_j in base_used:
                continue
            # Tentatively exchange the two output wires and re-verify.
            gates[i] = [a_i, op_i, b_i, c_j]
            gates[j] = [a_j, op_j, b_j, c_i]
            attempt, attempt_used = furthest_made(gates)
            if attempt > base:
                print(f"Found a good swap. Got to a higher iteration number: {attempt}")
                swaps.extend((c_i, c_j))
                base, base_used = attempt, attempt_used
                break
            # No progress: undo the exchange.
            gates[i] = [a_i, op_i, b_i, c_i]
            gates[j] = [a_j, op_j, b_j, c_j]
        else:
            # A full pass found nothing and restored every trial swap. The
            # circuit is unchanged, so another pass would find nothing either.
            break
    return ','.join(sorted(swaps))
    


In [22]:
# Part 1 on the small test input.
part1('input/day24_test.txt')

2024

In [23]:
# Part 1 on the real puzzle input.
part1('input/day24.txt')

46362252142374

In [3]:
# Part 2: greedy search for swapped gate outputs (prints each improving swap).
part2('input/day24.txt')

Found a good swap. Got to a higher iteration number: 12
Found a good swap. Got to a higher iteration number: 24
Found a good swap. Got to a higher iteration number: 37
Found a good swap. Got to a higher iteration number: 45


'cbd,gmh,jmq,qrh,rqf,z06,z13,z38'

In [None]:
# Day       Time  Rank  Score       Time   Rank  Score
#  24   00:23:11  1798      0   01:56:56    699      0

import re
from functools import cache
from itertools import combinations


ans1, ans2 = 0, 0
with open("input/day24.txt", "r") as f:
    text = f.read()

# Normalise Windows line endings and trim surrounding blank lines, so that a
# trailing newline at end of file does not show up as an empty gate line.
inputs, gates = text.replace("\r\n", "\n").strip().split("\n\n", 1)

input_pattern = r"([xy]\d\d): ([10])"
finished = {}  # wire name -> known bit value (0 or 1)
for line in inputs.splitlines():
    match = re.search(input_pattern, line)
    if match is None:
        raise ValueError(f"Unrecognised initial-value line: {line!r}")
    input_name, val = match.groups()
    finished[input_name] = int(val)

gate_pattern = r"([a-z0-9]{3}) ([XORAND]+) ([a-z0-9]{3}) -> ([a-z0-9]{3})"
ops = set()
op_list = []  # same gates as `ops`, but in file order
for line in gates.splitlines():
    if not line.strip():
        continue  # tolerate stray blank lines in the gate section
    match = re.search(gate_pattern, line)
    if match is None:
        raise ValueError(f"Unrecognised gate line: {line!r}")
    x1, op, x2, res = match.groups()
    ops.add((x1, x2, res, op))
    op_list.append((x1, x2, res, op))

# Part 1: simulation.
# The gates form a DAG whose sources are the x/y input wires. A wire can feed
# several gates, so this is not strictly a tree. Evaluating gates in order of
# increasing depth guarantees both operands are ready when a gate is reached.

# Index gates by the wire they drive; this assumes every wire has one driver.
parents = {res: (x1, x2) for x1, x2, res, _ in ops}  # output wire -> operand wires
op_map = {res: op for _, _, res, op in ops}  # output wire -> "XOR" / "OR" / "AND"

@cache
def get_depth(reg):
    """Length of the longest gate chain from any known wire to ``reg``.

    Wires already present in ``finished`` have depth 0. A gate output is one
    deeper than the deeper of its two operands (looked up in ``parents``).
    Results are memoised, so they reflect ``finished`` as it was on first call.
    """
    if reg not in finished:
        assert reg in parents
        left, right = parents[reg]
        return 1 + max(get_depth(left), get_depth(right))
    return 0

# Evaluate gates shallowest-first, so both operands are always in `finished`.
GATE_FUNCS = {
    "XOR": lambda a, b: a ^ b,
    "OR": lambda a, b: a | b,
    "AND": lambda a, b: a & b,
}
# Stable sort by depth; ties keep the iteration order of `ops`.
eval_order = sorted((res for _, _, res, _ in ops), key=get_depth)
for reg in eval_order:
    x1, x2 = parents[reg]
    finished[reg] = GATE_FUNCS[op_map[reg]](finished[x1], finished[x2])

# Wire zNN carries bit NN of the result. Pick the z wires out explicitly
# instead of assuming they are the lexicographically last wire names.
ans1 = sum(
    bit << int(name[1:])
    for name, bit in finished.items()
    if name.startswith("z") and name[1:].isdigit()
)

# Part 2: treat the circuit as a ripple-carry adder and walk it bit by bit.
# Each time an output zNN matches the expected structure (a "commit point"),
# the wires it depended on are recorded as known-correct, so the swap search
# below never touches them.
# NOTE(review): this redefines `furthest_made` from the first cell, which takes
# [a, op, b, c] gates instead of (x1, x2, res, op). Running this cell and then
# re-running `part2` from the first cell will feed it the wrong layout.
def furthest_made(op_list, n_bits=45):
    """Return how far the adder is correctly wired, plus the wires proven correct.

    Expected wiring for bit ``i`` (bit 0 has no incoming carry)::

        predigit  = x_i XOR y_i
        precarry1 = x_i AND y_i
        z_i       = carry_{i-1} XOR predigit     (bit 0: z00 = predigit)
        precarry2 = carry_{i-1} AND predigit
        carry_i   = precarry1 OR precarry2       (bit 0: carry_0 = precarry1)

    The carry out of the top bit must be the extra output ``z{n_bits}``.

    Args:
        op_list: iterable of ``(x1, x2, res, op)`` gate tuples.
        n_bits: width of the x/y inputs (default 45).

    Returns:
        (last_ok, correct):
            ``last_ok`` is the highest ``k`` such that ``z00`` .. ``zk`` are
            verified. It is ``-1`` if ``z00`` itself is wrong, and ``n_bits``
            when the final carry also checks out.
            ``correct`` is the set of wires proven correct so far.
    """
    ops = {}
    for x1, x2, res, op in op_list:
        # frozenset makes the operand pair hashable and order-independent.
        ops[(frozenset([x1, x2]), op)] = res

    def get_res(x1, x2, op):
        # None when no gate combines these two wires with this operator.
        return ops.get((frozenset([x1, x2]), op))

    carries = {}
    correct = set()
    prev_intermediates = set()
    for i in range(n_bits):
        pos = f"{i:02d}"
        predigit = get_res(f"x{pos}", f"y{pos}", "XOR")
        precarry1 = get_res(f"x{pos}", f"y{pos}", "AND")
        if i == 0:
            # Half adder: the XOR must drive z00 directly. Report failure
            # instead of asserting, so the swap search never crashes here.
            if predigit != "z00":
                return -1, correct
            carries[i] = precarry1
            continue
        digit = get_res(carries[i - 1], predigit, "XOR")
        if digit != f"z{pos}":
            return i - 1, correct

        # z_i is right, so carries[i-1] and predigit are right. So are the
        # previous bit's AND outputs that were combined into carries[i-1].
        correct.add(carries[i - 1])
        correct.add(predigit)
        correct.update(prev_intermediates)

        # Next, compute this bit's carry.
        precarry2 = get_res(carries[i - 1], predigit, "AND")
        carries[i] = get_res(precarry1, precarry2, "OR")
        prev_intermediates = {precarry1, precarry2}

    if n_bits > 0:
        # The top carry is the extra output bit; verify it too.
        if carries[n_bits - 1] != f"z{n_bits:02d}":
            return n_bits - 1, correct
        correct.add(carries[n_bits - 1])
        correct.update(prev_intermediates)
    return n_bits, correct

swaps = set()

# Greedy repair: up to four times, find the first pair of gates whose output
# wires, when exchanged, let furthest_made verify more adder bits.
base, base_used = furthest_made(op_list)
for _ in range(4):
    # try swapping
    for i, j in combinations(range(len(op_list)), 2):
        x1_i, x2_i, res_i, op_i = op_list[i]
        x1_j, x2_j, res_j, op_j = op_list[j]
        # z00 must stay where it is (furthest_made requires x00 XOR y00 -> z00).
        if "z00" in (res_i, res_j):
            continue
        # Don't swap wires that have already been proven correct.
        if res_i in base_used or res_j in base_used:
            continue
        # Switch output wires and re-check.
        op_list[i] = x1_i, x2_i, res_j, op_i
        op_list[j] = x1_j, x2_j, res_i, op_j
        attempt, attempt_used = furthest_made(op_list)
        if attempt > base:
            print(f"Found a good swap. Got to a higher iteration number: {attempt}")
            swaps.add((res_i, res_j))
            base, base_used = attempt, attempt_used
            break
        # Switch output wires back.
        op_list[i] = x1_i, x2_i, res_i, op_i
        op_list[j] = x1_j, x2_j, res_j, op_j
    else:
        # A whole pass found no improving swap and restored every trial swap.
        # The state is unchanged, so further passes would find nothing either.
        break
print(sorted(swaps))  # sorted: set iteration order of strings varies between runs

ans2 = ",".join(sorted(wire for pair in swaps for wire in pair))

print(f"Part 1 answer: {ans1}\nPart 2 answer: {ans2}")

Found a good swap. Got to a higher iteration number: 12
Found a good swap. Got to a higher iteration number: 24
Found a good swap. Got to a higher iteration number: 37
Found a good swap. Got to a higher iteration number: 45
{('z06', 'jmq'), ('cbd', 'rqf'), ('z13', 'gmh'), ('z38', 'qrh')}
Part 1 answer: 46362252142374
Part 2 answer: cbd,gmh,jmq,qrh,rqf,z06,z13,z38
