https://adventofcode.com/2022/day/21

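Does using networkx make it too easy? Nah. Each monkey becomes a node in a directed graph, with an edge from each operand to the monkey that uses it, so a topological sort visits every monkey after its inputs.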

In [20]:
import operator
from math import isnan

import networkx as nx

# Sentinel for "value not computed yet": formula nodes start as NaN and are
# checked with isnan(). In Part 2 humn is also set to NaN, so any arithmetic
# that depends on it yields NaN as well.
NaN = float("NaN")

In [2]:
# Read the whole puzzle input (path relative to the notebook) as one string.
with open("data/21.txt") as puzzle_file:
    data = puzzle_file.read()

In [3]:
# Small example input. The cells below evaluate it to root = 152 (Part 1)
# and humn = 301 (Part 2).
testdata = """\
root: pppw + sjmn
dbpl: 5
cczh: sllz + lgvd
zczc: 2
ptdq: humn - dvpt
dvpt: 3
lfqf: 4
humn: 5
ljgn: 2
sjmn: drzm * dbpl
sllz: 4
pppw: cczh / lfqf
lgvd: ljgn * ptdq
drzm: hmdt - zczc
hmdt: 32
"""

In [68]:
# Forward evaluators keyed by the operator symbol used in the input.
# "/" is true division, so results become floats; "=" only matters once
# Part 2 rewrites root's formula.
ops = dict([
    ("+", operator.add),
    ("-", operator.sub),
    ("*", operator.mul),
    ("/", operator.truediv),
    ("=", operator.eq),
])

def load_graph(data):
    """Parse puzzle text into a dependency DiGraph.

    Each line is either ``name: <int>`` or ``name: <lhs> <op> <rhs>``.

    * Constant monkeys get ``value`` set to the integer.
    * Formula monkeys get ``value=NaN`` and ``formula=(op, lhs, rhs)``.
    * Each formula monkey gets an edge from each operand to itself, so a
      topological sort yields operands before the monkeys that use them.

    Raises ValueError on malformed lines. This includes a right-hand side
    that is neither a single token nor three tokens.
    """
    graph = nx.DiGraph()
    for entry in data.strip().splitlines():
        name, expr = entry.split(": ")
        tokens = expr.split()
        if len(tokens) == 3:
            lhs, symbol, rhs = tokens
            graph.add_node(name, value=NaN, formula=(symbol, lhs, rhs))
            graph.add_edges_from([(lhs, name), (rhs, name)])
        elif len(tokens) == 1:
            graph.add_node(name, value=int(tokens[0]))
        else:
            raise ValueError(tokens)
    return graph

In [69]:
%%time
g = load_graph(testdata)
nodes = nx.nodes(g)
# Topological order guarantees both operands are resolved before a formula
# node is reached; nodes that already hold a number are skipped.
for n in nx.topological_sort(g):
    node = nodes[n]
    if not isnan(node["value"]):
        continue
    op, p1, p2 = node["formula"]
    lhs_value = nodes[p1]["value"]
    rhs_value = nodes[p2]["value"]
    node["value"] = ops[op](lhs_value, rhs_value)
nodes["root"]["value"]

CPU times: user 369 µs, sys: 0 ns, total: 369 µs
Wall time: 384 µs


152.0

In [71]:
%%time
g = load_graph(data)
nodes = nx.nodes(g)
# Topological order guarantees both operands are resolved before a formula
# node is reached; nodes that already hold a number are skipped.
for n in nx.topological_sort(g):
    node = nodes[n]
    if not isnan(node["value"]):
        continue
    op, p1, p2 = node["formula"]
    lhs_value = nodes[p1]["value"]
    rhs_value = nodes[p2]["value"]
    node["value"] = ops[op](lhs_value, rhs_value)
nodes["root"]["value"]

CPU times: user 19.5 ms, sys: 0 ns, total: 19.5 ms
Wall time: 18.9 ms


38731621732448.0

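Part 2

Now root checks whether its two operands are equal, and humn is the unknown. Set humn to NaN and re-run the forward pass: every monkey that depends on humn stays NaN. Then set root's unknown operand equal to the known one and invert each operation back down to humn.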

In [162]:
def make_solvers():
    """Return inverse solvers for each operator symbol, keyed like ``ops``.

    Each solver takes ``(a, b, c)`` for the relation ``a <op> b = c``:

    * ``c`` is known.
    * Exactly one of ``a`` / ``b`` is NaN (the unknown). If ``a`` is not
      NaN, ``b`` is assumed to be the unknown.

    Every solver returns ``(a, b)`` with the unknown operand filled in. All
    of them, including ``"="``, return a 2-tuple, so callers can always
    write ``a1, b1 = solver(a, b, c)``.
    """

    def add(a, b, c):
        """a + b = c"""
        if isnan(a):
            return c - b, b
        return a, c - a

    def sub(a, b, c):
        """a - b = c"""
        if isnan(a):
            return b + c, b
        return a, a - c

    def mul(a, b, c):
        """a * b = c"""
        if isnan(a):
            return c / b, b
        return a, c / a

    def div(a, b, c):
        """a / b = c"""
        if isnan(a):
            return b * c, b
        return a, a / c

    def eq(a, b, c):
        """a == b; ``c`` (the comparison's result) is not needed.

        The unknown side simply takes the known side's value. Returns a
        2-tuple like the other solvers so the caller's unpacking works.
        """
        if isnan(a):
            return b, b
        return a, a

    return {
        "+": add,
        "-": sub,
        "*": mul,
        "/": div,
        "=": eq,
    }

solvers = make_solvers()

In [189]:
%%time
g = load_graph(testdata)
nodes = g.nodes  # NodeView (same as nx.nodes(g)); nodes[name] -> attribute dict

CPU times: user 77 µs, sys: 11 µs, total: 88 µs
Wall time: 93 µs


In [190]:
# Part 2: humn is the unknown, so overwrite its number with the NaN sentinel.
nodes["humn"].update(value=NaN)

In [191]:
# Part 2: root becomes an equality check between its two original operands.
oldop, p1, p2 = nodes["root"]["formula"]
root = nodes["root"]
root.update(formula=("=", p1, p2))

In [192]:
for n in nx.topological_sort(g):
    node = nodes[n]
    # Skip nodes that already hold a value, and humn, which is NaN but has no
    # formula. Arithmetic on humn's NaN stays NaN all the way up; the "=" at
    # root then compares NaN with a number and stores False.
    if not isnan(node["value"]) or "formula" not in node:
        continue
    op, p1, p2 = node["formula"]
    node["value"] = ops[op](nodes[p1]["value"], nodes[p2]["value"])

In [194]:
[*g.predecessors("root")]  # root's operands (edges point operand -> result)

['pppw', 'sjmn']

In [195]:
tuple(nodes[name] for name in ("pppw", "sjmn"))

({'value': nan, 'formula': ('/', 'cczh', 'lfqf')},
 {'value': 150, 'formula': ('*', 'drzm', 'dbpl')})

In [196]:
# root now reads "lhs == rhs". Exactly one side depends on humn; the other
# already has a concrete value from the forward pass. Pin the humn-dependent
# side (the starter) to that value and solve backwards from it.
# Both the starter name and its value are derived from the graph rather than
# copied from the output above. The cell is also safe to re-run, because
# has_path does not look at node values.
lhs, rhs = nodes["root"]["formula"][1:]
starter, known = (lhs, rhs) if nx.has_path(g, "humn", lhs) else (rhs, lhs)
assert nx.has_path(g, "humn", starter) and not nx.has_path(g, "humn", known), \
    "expected exactly one side of root to depend on humn"
nodes[starter]["value"] = nodes[known]["value"]

In [207]:
# Walk from `starter` down to humn, undoing one operation per step.
# At each step the node's value is known (set by the step above).
# Assumes, as the solvers do, that exactly one of its two operands is
# unknown (NaN) at each step. Operands off this path already have values
# from the forward pass.
# nx.shortest_path raises if `starter` does not depend on humn, instead of
# silently leaving humn as NaN.
path = nx.shortest_path(g, "humn", starter)  # humn -> ... -> starter
for n in reversed(path[1:]):  # starter first; humn itself has no formula
    node = nodes[n]
    op, p1, p2 = node["formula"]
    pa, pb = nodes[p1], nodes[p2]
    pa["value"], pb["value"] = solvers[op](pa["value"], pb["value"], node["value"])

print(f"humn = {nodes['humn']['value']}")

humn = 301.0


In [205]:
g.nodes["humn"]

{'value': 301.0}

In [208]:
%%time
g = load_graph(data)
nodes = g.nodes  # NodeView (same as nx.nodes(g)); nodes[name] -> attribute dict

CPU times: user 15.1 ms, sys: 0 ns, total: 15.1 ms
Wall time: 14.5 ms


In [209]:
# Part 2: humn is the unknown, so overwrite its number with the NaN sentinel.
nodes["humn"].update(value=NaN)

In [210]:
# Part 2: root becomes an equality check between its two original operands.
oldop, p1, p2 = nodes["root"]["formula"]
root = nodes["root"]
root.update(formula=("=", p1, p2))

In [211]:
for n in nx.topological_sort(g):
    node = nodes[n]
    # Skip nodes that already hold a value, and humn, which is NaN but has no
    # formula. Arithmetic on humn's NaN stays NaN all the way up; the "=" at
    # root then compares NaN with a number and stores False.
    if not isnan(node["value"]) or "formula" not in node:
        continue
    op, p1, p2 = node["formula"]
    node["value"] = ops[op](nodes[p1]["value"], nodes[p2]["value"])

In [212]:
[*g.predecessors("root")]  # root's operands (edges point operand -> result)

['lsbv', 'bsgz']

In [213]:
tuple(nodes[name] for name in ("lsbv", "bsgz"))

({'value': nan, 'formula': ('*', 'mgtt', 'lvcj')},
 {'value': 2228768553328.0, 'formula': ('*', 'dsbn', 'mmwh')})

In [214]:
# root now reads "lhs == rhs". Exactly one side depends on humn; the other
# already has a concrete value from the forward pass. Pin the humn-dependent
# side (the starter) to that value and solve backwards from it.
# Both the starter name and its value are derived from the graph rather than
# copied from the output above. The cell is also safe to re-run, because
# has_path does not look at node values.
lhs, rhs = nodes["root"]["formula"][1:]
starter, known = (lhs, rhs) if nx.has_path(g, "humn", lhs) else (rhs, lhs)
assert nx.has_path(g, "humn", starter) and not nx.has_path(g, "humn", known), \
    "expected exactly one side of root to depend on humn"
nodes[starter]["value"] = nodes[known]["value"]

In [215]:
%%time
# Walk from `starter` down to humn, undoing one operation per step.
# At each step the node's value is known (set by the step above).
# Assumes, as the solvers do, that exactly one of its two operands is
# unknown (NaN) at each step. Operands off this path already have values
# from the forward pass.
# nx.shortest_path raises if `starter` does not depend on humn, instead of
# silently leaving humn as NaN.
path = nx.shortest_path(g, "humn", starter)  # humn -> ... -> starter
for n in reversed(path[1:]):  # starter first; humn itself has no formula
    node = nodes[n]
    op, p1, p2 = node["formula"]
    pa, pb = nodes[p1], nodes[p2]
    pa["value"], pb["value"] = solvers[op](pa["value"], pb["value"], node["value"])

print(f"humn = {nodes['humn']['value']}")

humn = 3848301405790.0
CPU times: user 5.44 ms, sys: 164 µs, total: 5.6 ms
Wall time: 5.37 ms
