## Setup

In [5]:
# Get raw advent-of-code data
from aocd.models import Puzzle

# Select the Advent of Code puzzle for year 2015, day 7
puzzle = Puzzle(year=2015, day=7)
# Keep the puzzle input in a module-level name so later cells can reuse it
input_data = puzzle.input_data

In [6]:
# Import performance checking utility
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path;
# presumably this is where the `common` package imported below lives
sys.path.append(str(Path.cwd().parent))

from common.utils.perf_check import time_solution

## Part a

### NetworkX graph construction
It looks like we need to construct a directed graph from the wires and their connections. Let's use NetworkX for this.

In [None]:
# Imports
import networkx as nx

In [111]:
# Constants
# Maps each gate name to the function that computes its output signal.
# Signals are 16-bit unsigned values (0..65535). Python ints are unbounded, so
# the two operations that can leave that range are masked back into it:
#   * NOT    -- `~x` alone is -(x + 1), i.e. a negative number
#   * LSHIFT -- bits shifted past position 15 must be discarded
# Without the mask, a later RSHIFT would shift those stray high bits back into
# the low 16 bits and corrupt the result.
OP_TO_FUNC = {
    "IS": lambda x: x,  # Manually added for direct connections
    "NOT": lambda x: ~x & 0xFFFF,
    "AND": lambda x, y: x & y,
    "OR": lambda x, y: x | y,
    "LSHIFT": lambda x, y: (x << y) & 0xFFFF,
    "RSHIFT": lambda x, y: x >> y,
}

In [None]:
# Functions
def parse_input_networkx(input_data: str) -> nx.DiGraph:
    """Build a dependency graph of the circuit described by the instructions.

    Every wire on the right-hand side of ``->`` becomes a node carrying the
    attributes ``op`` (operation name) and ``args`` (list of operand tokens).
    Each operand is connected to that wire by a directed edge, and operands
    made up only of digits also get a ``value`` attribute with their integer
    value.
    """
    circuit = nx.DiGraph()

    for instruction in input_data.splitlines():
        expression, target = instruction.split(" -> ")
        tokens = expression.split()
        token_count = len(tokens)

        # Work out the operation name and its operand tokens
        if token_count == 1:
            # Plain assignment: tagged with the custom "IS" operation
            operation, operands = "IS", [tokens[0]]
        elif token_count == 2:
            # Prefix operator followed by its single operand
            operation, operands = tokens[0], [tokens[1]]
        else:
            # Infix operator sitting between its two operands
            operation, operands = tokens[1], [tokens[0], tokens[2]]

        # Register the target wire together with the recipe that produces it
        circuit.add_node(target, op=operation, args=operands)

        # Every operand feeds into the target wire
        circuit.add_edges_from((operand, target) for operand in operands)

        # Literal operands already know their value
        for operand in operands:
            if operand.isdigit():
                circuit.add_node(operand, value=int(operand))

    return circuit


def solve_a_networkx(input_data: str) -> int:
    """Evaluate every wire in dependency order and return the signal on wire 'a'."""
    circuit = parse_input_networkx(input_data)
    node_data = circuit.nodes

    # A topological order guarantees operands are evaluated before their consumers
    for wire in nx.topological_sort(circuit):
        attributes = node_data[wire]
        if "value" not in attributes:
            operation = attributes["op"]
            operands = attributes["args"]
            operand_values = [int(node_data[name]["value"]) for name in operands]

            # Star-unpacking covers both the one- and two-operand operations
            attributes["value"] = OP_TO_FUNC[operation](*operand_values)

    return int(node_data["a"]["value"])

956

In [108]:
# Performance check: time the NetworkX solution and keep the returned
# measurement for the later comparison with the recursive solution
time_a_networkx = time_solution(solve_a_networkx, input_data)

solve_a_networkx takes 1.09 ms


### Recursive approach
Let's see if we even need NetworkX for this problem. Instead of traversing the whole graph, starting from the known values, and then reading off the value of 'a', let's recursively evaluate 'a' from its dependencies.

In [124]:
# Imports
from functools import lru_cache

In [141]:
# Functions
def parse_input_recursive(input_data: str) -> dict[str, tuple[str, list[str]]]:
    """Parse input data into a {wire: (op, [operands])} dictionary."""
    wire_expressions = {}

    for line in input_data.splitlines():
        # Parse each line into input and output
        inp, out = line.split(" -> ")
        input_tokens = inp.split()

        # Collect the arguments and operations from the input
        if len(input_tokens) == 1:
            # Direct connection
            wire_expressions[out] = ("IS", [input_tokens[0]])
        elif len(input_tokens) == 2:
            # Unary operation
            op, arg = input_tokens
            wire_expressions[out] = (op, [arg])
        else:
            # Binary operation
            left, op, right = input_tokens
            wire_expressions[out] = (op, [left, right])

    return wire_expressions


def solve_a_recursive(
    input_data: str | dict[str, tuple[str, list[str]]],
) -> tuple[int, dict[str, tuple[str, list[str]]]]:
    """Evaluate the value of wire 'a' using recursion and memoization.

    Args:
        input_data: Either the raw puzzle text, which is parsed with
            ``parse_input_recursive``, or an already parsed
            ``{wire: (op, [operands])}`` mapping, which is used as-is.

    Returns:
        A ``(value_of_a, wire_expressions)`` tuple. The mapping is handed back
        so that callers can modify and re-evaluate it without parsing again.

    Raises:
        KeyError: If a referenced wire (including 'a') has no expression, or
            an operation name is missing from ``OP_TO_FUNC``.
    """
    wire_expressions = parse_input_recursive(input_data) if isinstance(input_data, str) else input_data

    # The cache is deliberately unbounded. Literal tokens are cached alongside
    # wires, so a cache sized by the number of wires would evict wire values
    # and force whole sub-circuits to be recomputed (and a size of 0, for an
    # empty mapping, would switch memoization off entirely). The cache is
    # created per call, so it cannot outlive this evaluation.
    @lru_cache(maxsize=None)
    def get_wire_value(token: str) -> int:
        """Evaluate the value of a wire by recursively evaluating its expression arguments."""
        if token.isdigit():
            # It's a direct integer value
            return int(token)

        op, args = wire_expressions[token]

        # Operands are evaluated left to right; unpacking handles both the
        # one-operand and the two-operand operations.
        return OP_TO_FUNC[op](*(get_wire_value(arg) for arg in args))

    # Return the value of wire 'a' and the wire expressions for recursive graph parsing
    return int(get_wire_value("a")), wire_expressions

In [138]:
# Performance check: time the recursive solution, then report how it compares
# with the NetworkX measurement as a ratio of the two timings
time_a_recursive = time_solution(solve_a_recursive, input_data)
print(f"This is {time_a_networkx / time_a_recursive:.1f}x faster than the NetworkX approach.")

solve_a_recursive takes 0.19 ms
This is 5.7x faster than the NetworkX approach.


In [None]:
# Submit answer: element [0] of the returned tuple is the value of wire 'a';
# the response to the submission shows up in the cell output
puzzle.answer_a = solve_a_recursive(input_data)[0]

[32mThat's the right answer!  You are one gold star closer to powering the weather machine. [Continue to Part Two][0m


## Part b
To avoid having to parse the input data twice, I've modified the recursive solution to return both the value of wire 'a' and the parsed wire expressions.

In [142]:
# Functions
def solve_b(input_data: str) -> int:
    """Override wire 'b' with the first signal on 'a' and return the new signal on 'a'."""
    # First pass: evaluate 'a' and keep the parsed circuit for reuse
    first_signal, circuit = solve_a_recursive(input_data)

    # Rewire 'b' so it is driven directly by the signal just computed
    circuit["b"] = ("IS", [str(first_signal)])

    # Second pass: evaluate 'a' again on the modified circuit
    second_signal, _ = solve_a_recursive(circuit)
    return int(second_signal)

In [143]:
# Performance check: time the part-b solution on the puzzle input
time_b = time_solution(solve_b, input_data)

solve_b takes 0.28 ms


In [144]:
# Submit answer: solve_b returns the re-evaluated value of wire 'a';
# the response to the submission shows up in the cell output
puzzle.answer_b = solve_b(input_data)

[32mThat's the right answer!  You are one gold star closer to powering the weather machine.You have completed Day 7! You can [Shareon
  Bluesky
Twitter
Mastodon] this victory or [Return to Your Advent Calendar].[0m
