---
title: Knuth Bendix Solver on Z3 AST
date: 2025-03-11
---

Knuth-Bendix completion takes a set of equational axioms and a term ordering (which defines which side of an equation is "simpler") and tries to produce a rewrite rule system that is confluent and terminating.

You can read more here:

- Term Rewriting and All That
- Harrison's Handbook of Practical Logic and Automated Reasoning
- https://www.researchgate.net/publication/220460160_An_Introduction_to_Knuth-Bendix_Completion
- https://en.wikipedia.org/wiki/Knuth%E2%80%93Bendix_completion_algorithm
- Handbook of Automated Reasoning chapter: https://www.cs.tau.ac.il/~nachum/papers/hand-final.pdf

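Confluent means the order in which the rules are applied ultimately doesn't matter: whenever a term can be rewritten two different ways, both results can be rewritten to a common term. Combined with termination, every term has a unique normal form, so greedy use of the rules will find the "smallest" term equal to the input under that theory.

If your system is terminating but not confluent, you need to add some kind of backtracking or search if you want this guarantee.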
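A toy illustration using string rewriting instead of Z3 terms. The helpers `one_step` and `normal_forms` are hypothetical, written only for this sketch: `normal_forms` is the brute-force search mentioned above, exploring every rewrite sequence and collecting the irreducible results. The non-confluent system `a -> b`, `a -> c` produces several normal forms; adding `b -> c` (what completion would derive from the overlap on `a`) leaves exactly one.

```python
def one_step(s, rules):
    """All strings reachable from s by one rewrite, trying every rule at every position."""
    out = set()
    for lhs, rhs in rules:
        i = s.find(lhs)
        while i != -1:
            out.add(s[:i] + rhs + s[i + len(lhs):])
            i = s.find(lhs, i + 1)
    return out


def normal_forms(s, rules):
    """Explore every rewrite sequence from s (assumes the rules terminate)."""
    seen, todo, nfs = set(), [s], set()
    while todo:
        t = todo.pop()
        if t in seen:
            continue
        seen.add(t)
        succs = one_step(t, rules)
        if not succs:
            nfs.add(t)  # irreducible: a normal form
        todo.extend(succs)
    return nfs


nonconfluent = [("a", "b"), ("a", "c")]
completed = nonconfluent + [("b", "c")]
print(sorted(normal_forms("aa", nonconfluent)))  # ['bb', 'bc', 'cb', 'cc']
print(sorted(normal_forms("aa", completed)))     # ['cc']
```

With a confluent and terminating system, the search is unnecessary: any single greedy path reaches the same answer.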

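Knuth-Bendix completion fails when it cannot orient an equation. Term orderings on non-ground terms are necessarily partial orders. The orderings used for completion must be stable under substitution: if `l > r`, then every instance of `l` must be greater than the same instance of `r`. A demon can always fill in the variables of `X * Y = Y * X` to break any orientation: `X := a, Y := b` demands `a * b > b * a`, while `X := b, Y := a` demands `b * a > a * b`. It's too symmetric.

Because of this (unnecessary) failure, Knuth-Bendix completion is not quite a "complete" equational theorem proving procedure. It's close though, and I tend to think of it as one.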
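A tiny check of the commutativity argument. Terms are nested Python tuples `(head, *args)` and variables are uppercase strings; this encoding is a hypothetical one for illustration, not the Z3 representation used below. Python's built-in tuple comparison stands in for an arbitrary total order on these two ground terms: whichever way `X * Y = Y * X` is oriented, one of the two swapped instances goes uphill.

```python
def subst(t, sigma):
    """Apply a substitution; variables are uppercase strings, terms are (head, *args) tuples."""
    if isinstance(t, str):
        return sigma.get(t, t)
    return (t[0],) + tuple(subst(a, sigma) for a in t[1:])


lhs, rhs = ("mul", "X", "Y"), ("mul", "Y", "X")
instances = [{"X": "a", "Y": "b"}, {"X": "b", "Y": "a"}]

# Does the rule lhs -> rhs decrease on each ground instance?
decreasing = [subst(lhs, s) > subst(rhs, s) for s in instances]
print(decreasing)  # [False, True]: never both True, so no orientation works
```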

Paramodulation is the name for brute-force equational search, ordering be damned. Unfailing completion loosens ordinary Knuth-Bendix completion; it is more restricted than brute-force paramodulation, yet it is still complete for equational theorem proving.

Paramodulation roughly means finding all critical pairs between all sides of all equations. Knuth-Bendix restricts attention to critical pairs between left-hand sides of rules, but also enables simplification by the rules. Unfailing completion has a pruning mechanism that limits which critical pairs need to be considered, and it retains the simplification mechanism.

I have found it useful to use Z3's AST as a central interchange format while building up a library of useful stuff. Z3 offers a nice API, a good fast hash cons, de Bruijn binder manipulation, not to mention the SMT solving itself.

All the fiddling with variables is what makes completion complicated. String Knuth-Bendix is more straightforward (https://www.philipzucker.com/string_knuth/), and so is ground term Knuth-Bendix.

They all follow the same pattern though. Pick an ordering.

Find "overlaps" of the left-hand sides of two rules. An overlapping term can be rewritten in two different ways, one per rule. Infer that the two results are equal and add that as a new equation to be processed.

The critical pair https://en.wikipedia.org/wiki/Critical_pair_(term_rewriting) is the pair of terms that results from rewriting an overlap with each of the two rules. The two terms are equal in the equational theory, but the current rules may not yet rewrite them to a common term.
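For the string case, finding overlaps is just string surgery. Here is a minimal sketch (the name `string_critical_pairs` is hypothetical): either a proper suffix of one left-hand side equals a prefix of the other, or one left-hand side sits inside the other. Each overlap word is rewritten both ways.

```python
def string_critical_pairs(rule1, rule2):
    """Critical pairs between string rewrite rules given as (lhs, rhs) pairs."""
    (l1, r1), (l2, r2) = rule1, rule2
    pairs = []
    # A proper suffix of l1 equals a prefix of l2: the overlap word is l1 + l2[k:]
    for k in range(1, min(len(l1), len(l2))):
        if l1[-k:] == l2[:k]:
            pairs.append((r1 + l2[k:], l1[:-k] + r2))
    # l2 occurs inside l1: rewrite that occurrence with rule2 instead
    i = l1.find(l2)
    while i != -1:
        pairs.append((r1, l1[:i] + r2 + l1[i + len(l2):]))
        i = l1.find(l2, i + 1)
    return pairs


# "abc" rewrites to "cc" (via ab -> c) or to "ad" (via bc -> d)
print(string_critical_pairs(("ab", "c"), ("bc", "d")))  # [('cc', 'ad')]
print(string_critical_pairs(("abc", "x"), ("b", "y")))  # [('x', 'ayc')]
```

The function is asymmetric, so you call it with the rules in both orders.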

For terms with variables, we need to use unification rather than pattern matching to find overlaps, because the variables on both sides may need to be instantiated. This is called narrowing. It's a bit fiddly.
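A minimal sketch of syntactic unification (Robinson-style, with an occurs check) on a hypothetical tuple encoding: uppercase strings are variables, lowercase strings are constants, and compound terms are `(head, *args)` tuples. Unlike matching, both sides can get instantiated. The result is a triangular substitution (bindings may mention other bound variables). The code below in the post uses `kd.utils.unify` on Z3 terms instead; this is only an illustration of the idea.

```python
def is_var(t):
    return isinstance(t, str) and t[:1].isupper()


def walk(t, subst):
    """Follow variable bindings until reaching a non-variable or an unbound variable."""
    while is_var(t) and t in subst:
        t = subst[t]
    return t


def occurs(v, t, subst):
    t = walk(t, subst)
    if t == v:
        return True
    return isinstance(t, tuple) and any(occurs(v, a, subst) for a in t[1:])


def unify(s, t, subst=None):
    """Syntactic unification; returns a substitution dict, or None if s and t don't unify."""
    subst = {} if subst is None else subst
    s, t = walk(s, subst), walk(t, subst)
    if s == t:
        return subst
    if is_var(s):
        return None if occurs(s, t, subst) else {**subst, s: t}
    if is_var(t):
        return unify(t, s, subst)
    if isinstance(s, tuple) and isinstance(t, tuple) and s[0] == t[0] and len(s) == len(t):
        for a, b in zip(s[1:], t[1:]):
            subst = unify(a, b, subst)
            if subst is None:
                return None
        return subst
    return None  # head symbol or constant clash


# Mirrors the first overlap found by critical_pair_helper below: -(-Y) against -(-(-X))
print(unify(("neg", ("neg", "Y")), ("neg", ("neg", ("neg", "X")))))  # {'Y': ('neg', 'X')}
print(unify(("mul", "X", "e"), ("mul", "a", "Y")))  # {'X': 'a', 'Y': 'e'}
print(unify("X", ("f", "X")))  # None (occurs check)
```

The two rules' variables must be renamed apart before unifying, which is what the `freshen` call in `all_pairs` below is for.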




```python
import kdrag as kd
import kdrag.smt as smt
import kdrag.rewrite as rw


def critical_pair_helper(
    vs: list[smt.ExprRef], t: smt.ExprRef, lhs: smt.ExprRef, rhs: smt.ExprRef
) -> list[tuple[smt.ExprRef, dict[smt.ExprRef, smt.ExprRef]]]:
    """
    Look for pattern lhs to unify with a subterm of t.
    Returns a list of all the ways lhs -> rhs can be applied, each paired with the substitution resulting from the unification.
    The substitution is returned so that we can apply the other `t -> s` rule once we return.

    This helper is asymmetric between t and lhs. You need to call it twice to get all critical pairs.
    """
    res = []
    if any(t.eq(v) for v in vs):  # Only non-trivial overlaps: unifying a bare variable X with lhs is not interesting.
        return res
    subst = kd.utils.unify(vs, t, lhs)
    if subst is not None:
        res.append((rhs, subst))
    f, children = t.decl(), t.children()
    for n, arg in enumerate(children):
        # recurse into subterms and lift the result under f if we found something
        for s, subst in critical_pair_helper(vs, arg, lhs, rhs):
            args = children[:n] + [s] + children[n + 1 :]
            res.append((f(*args), subst))
    return res

x,y,z = smt.Reals("x y z")
critical_pair_helper([x,y], -(-(-(x))), -(-(y)), y)
```

```
[(y, {y: -x}), (-y, {x: y}), (--y, {x: -y})]
```

```python
def all_pairs(rules):
    """
    Find all the ways the left hand sides of two rules can overlap.
    Returns the derived equations (critical pairs) as a list.
    """
    # TODO. I'm not treating encompassment correctly
    res = []
    for rule1 in rules:
        for rule2 in rules:
            # we're double counting when rule1 = rule2 but whatever
            if any(v1.eq(v2) for v1 in rule1.vs for v2 in rule2.vs):
                rule2 = rule2.freshen()
            vs = rule1.vs + rule2.vs
            for t, subst in critical_pair_helper(vs, rule1.lhs, rule2.lhs, rule2.rhs):
                apply_rule1 = smt.substitute(rule1.rhs, *subst.items())
                apply_rule2 = smt.substitute(t, *subst.items())
                vs1 = list(set(vs) - set(subst.keys()))
                if len(vs1) == 0:
                    res.append(apply_rule1 == apply_rule2)
                else:
                    res.append(
                        smt.ForAll(vs1, apply_rule1 == apply_rule2)
                    )  # new derived equation
    return res

a,b,c,d = smt.Reals("a b c d")
all_pairs([rw.RewriteRule(vs=[], lhs=x, rhs=y) for x,y in [(a,b), (b,c), (a,c), (a,d)]])
```

```
[b == b,
 b == c,
 b == d,
 c == c,
 c == b,
 c == c,
 c == d,
 d == b,
 d == c,
 d == d]
```

You also want to orient equations into rewrite rules. `RewriteRule` is a helper namedtuple that holds the pieces of a rule, and `rw.rewrite_of_expr` parses well-formed rules out of arbitrary Z3 expressions.

```python
rw.rewrite_of_expr(smt.ForAll([x,y], x * 0 == x))
```

```
RewriteRule(vs=[X!0, Y!1], lhs=X!0*0, rhs=X!0, pf=None)
```

```python
def orient(eq : smt.BoolRef | smt.QuantifierRef, order=rw.kbo) -> rw.RewriteRule:
    """Turn an equation into a rule whose lhs is greater than its rhs under `order`."""
    r = rw.rewrite_of_expr(eq)
    if order(r.vs, r.lhs, r.rhs) == rw.Order.GR:
        return r
    elif order(r.vs, r.rhs, r.lhs) == rw.Order.GR:
        return r._replace(lhs=r.rhs, rhs=r.lhs)
    else:
        raise Exception("Cannot orient: " + str(eq))

x,y,z = smt.Reals("x y z")
print(orient(smt.ForAll([x], -(-x) == x)))
print(orient(smt.ForAll([x], x == -(-x))))
print(orient(smt.ForAll([x,y,z], x + z == x + y + z + x + y)))
```

```
RewriteRule(vs=[X!2], lhs=--X!2, rhs=X!2, pf=None)
RewriteRule(vs=[X!3], lhs=--X!3, rhs=X!3, pf=None)
RewriteRule(vs=[X!4, Y!5, Z!6], lhs=X!4 + Y!5 + Z!6 + X!4 + Y!5, rhs=X!4 + Z!6, pf=None)
```


You also want to simplify equations using the current set of rewrite rules. I use my helper function `rewrite` from knuckledragger to do this. This part does not do narrowing; here you want regular pattern matching.

```python
def simplify(t : smt.BoolRef | smt.QuantifierRef, rules : list[rw.RewriteRule]) -> smt.ExprRef:
    # Normalize both sides of the equation with ordinary (matching-based) rewriting
    r = rw.rewrite_of_expr(t)
    lhs = rw.rewrite(r.lhs, rules)
    rhs = rw.rewrite(r.rhs, rules)
    return r._replace(lhs=lhs, rhs=rhs).to_expr()

simplify(smt.ForAll([x], -(-(-(-(-x)))) == -x), [rw.rewrite_of_expr(smt.ForAll([x], -(-x) == x))])
```
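To make the matching-versus-unification distinction concrete without Z3, here is one-way matching plus innermost rewriting on toy tuple terms (uppercase strings are variables). The helpers `match`, `instantiate`, and `rewrite` are hypothetical and are not knuckledragger's `rw.rewrite`. Only variables in the pattern get bound; anything in the term being simplified is treated as rigid.

```python
def is_var(t):
    return isinstance(t, str) and t[:1].isupper()


def match(pat, t, subst=None):
    """One-way matching: bind variables of pat only; t is never instantiated."""
    subst = {} if subst is None else subst
    if is_var(pat):
        if pat in subst:
            return subst if subst[pat] == t else None
        return {**subst, pat: t}
    if isinstance(pat, str) or isinstance(t, str):
        return subst if pat == t else None  # constants must agree exactly
    if pat[0] != t[0] or len(pat) != len(t):
        return None
    for p, a in zip(pat[1:], t[1:]):
        subst = match(p, a, subst)
        if subst is None:
            return None
    return subst


def instantiate(t, subst):
    if is_var(t):
        return subst.get(t, t)
    if isinstance(t, str):
        return t
    return (t[0],) + tuple(instantiate(a, subst) for a in t[1:])


def rewrite(t, rules):
    """Innermost rewriting to a normal form; assumes the rules terminate."""
    if isinstance(t, tuple):
        t = (t[0],) + tuple(rewrite(a, rules) for a in t[1:])
    for lhs, rhs in rules:
        subst = match(lhs, t)
        if subst is not None:
            return rewrite(instantiate(rhs, subst), rules)
    return t


neg_neg = [(("neg", ("neg", "X")), "X")]
five = ("neg", ("neg", ("neg", ("neg", ("neg", "x")))))
print(rewrite(five, neg_neg))  # ('neg', 'x')
print(match(("mul", "X", "e"), ("mul", "a", "Y")))  # None: matching won't bind Y
```

Unification would succeed on that last pair by binding both `X` and `Y`; matching refuses, which is exactly what you want when simplifying.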

This detects trivial `t = t` equations.

```python
def is_trivial(t):
    r = rw.rewrite_of_expr(t)
    return r.lhs.eq(r.rhs)

assert is_trivial(x == x)
assert not is_trivial(x == -(-x))
```

The basic completion method just sprays out critical pairs until all of them can be reduced to trivial equations by the current rules. It's a very brute-force, almost obvious way of attempting to complete the rules (i.e., repair their confluence).

```python
def basic(E, order=rw.kbo):
    R = []
    for eq in E:
        R.append(orient(eq, order=order))
    done = False
    while not done:
        done = True
        # all_pairs returns a list built from a snapshot of R, so appending below is safe
        for eq in all_pairs(R):
            eq1 = simplify(eq, R)
            if not is_trivial(eq1):
                # orient derived equations with the same ordering as the input axioms
                R.append(orient(eq1, order=order))
                done = False
    return R

# TRaaT 7.1.1 Central Groupoid example
T = smt.DeclareSort("CentralGroupoid")
x,y,z = smt.Consts("x y z", T)
mul = smt.Function("mul", T, T, T)
kd.notation.mul.register(T, mul)
E = [smt.ForAll([x,y,z], (x * y) * (y * z) == y)]

basic(E)
```

```
[RewriteRule(vs=[X!9, Y!10, Z!11], lhs=mul(mul(X!9, Y!10), mul(Y!10, Z!11)), rhs=Y!10, pf=None),
 RewriteRule(vs=[X!29, Z!30, Z!31, Y!32], lhs=mul(Y!32, mul(mul(Y!32, Z!31), Z!30)), rhs=mul(Y!32, Z!31), pf=None),
 RewriteRule(vs=[X!41, Z!42, X!43, Y!44], lhs=mul(mul(X!43, mul(X!41, Y!44)), Y!44), rhs=mul(X!41, Y!44), pf=None)]
```

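This is a variation of the Huet strategy for completion. It allows a rule to be removed when a newer rule simplifies its left-hand side (the old rule goes back into the pile of equations), and it keeps right-hand sides simplified.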

```python
def huet(E, order=rw.kbo):
    E = E.copy()
    R = []
    while True:
        while E:
            eq = E.pop()
            eq = simplify(eq, R)
            if is_trivial(eq):
                continue
            r = orient(eq, order=order)
            Rnew = [r]
            for r1 in R:
                lhs1 = rw.rewrite(r1.lhs, [r])
                if lhs1.eq(r1.lhs):
                    # compose: r1 survives, with its rhs normalized by the new rule set
                    rhs1 = rw.rewrite(r1.rhs, R + [r])
                    Rnew.append(r1._replace(rhs=rhs1))
                else:
                    # collapse: the new rule reduces r1's lhs, so r1 goes back into E
                    E.append(r1._replace(lhs=lhs1).to_expr())
            R = Rnew
        for eq in all_pairs(R):
            # By marking rules, we could prune the critical pair search, but I haven't done that.
            # This is something like a semi-naive or given clause optimization:
            # always smash against at least one fresh new thing (a rule in this case).
            # It might help a lot. Performance benchmarking suggests simplify is the bottleneck.
            eq1 = simplify(eq, R)
            if not is_trivial(eq1):
                E.append(eq1)
        if len(E) == 0:
            return R

huet(E)
```

```
[RewriteRule(vs=[Y!320, Z!321, Z!322, X!323], lhs=mul(Y!320, mul(mul(Y!320, Z!321), Z!322)), rhs=mul(Y!320, Z!321), pf=None),
 RewriteRule(vs=[Y!308, Z!309, X!310, X!311], lhs=mul(mul(X!311, mul(X!310, Y!308)), Y!308), rhs=mul(X!310, Y!308), pf=None),
 RewriteRule(vs=[X!272, Y!273, Z!274], lhs=mul(mul(X!272, Y!273), mul(Y!273, Z!274)), rhs=Y!273, pf=None)]
```

I've implemented both the Knuth-Bendix ordering (KBO) and the lexicographic path ordering (LPO) by trying to just copy them out of TRaaT.

I don't find term orderings very intuitive at all.

The basic intuition of KBO is that smaller is better. You break size ties by comparing the precedence of the head symbols and then recursing into the subterms https://www.philipzucker.com/ground_kbo/
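A sketch of ground KBO on toy tuple terms `(head, *args)`. The unit symbol weights and the precedence dictionary are hypothetical choices for illustration; this is not knuckledragger's `rw.kbo`, which also has to handle variables.

```python
def weight(t):
    """With every symbol weighing 1 (a hypothetical choice), weight is just term size."""
    return 1 if isinstance(t, str) else 1 + sum(weight(a) for a in t[1:])


def head(t):
    return t if isinstance(t, str) else t[0]


def args(t):
    return () if isinstance(t, str) else t[1:]


def kbo_gt(s, t, prec):
    """Ground Knuth-Bendix ordering: the heavier term wins; on a weight tie compare
    head symbols by precedence, then the arguments left to right."""
    if weight(s) != weight(t):
        return weight(s) > weight(t)
    if head(s) != head(t):
        return prec[head(s)] > prec[head(t)]
    for a, b in zip(args(s), args(t)):
        if a != b:
            return kbo_gt(a, b, prec)
    return False  # s == t


prec = {"e": 0, "a": 1, "b": 2, "inv": 3, "mul": 4}
left = ("mul", ("mul", "a", "b"), "a")
right = ("mul", "a", ("mul", "b", "a"))
print(kbo_gt(left, right, prec))  # True: same size, but the left argument is heavier
print(kbo_gt(("inv", ("inv", "a")), "a", prec))  # True: the smaller term is below
```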

The basic intuition of LPO is that you want to push some symbols inside other symbols (`add` gets pushed inside `succ` in `add(succ(X), Y) -> succ(add(X, Y))`, even though both sides are about the same size), and also that pushing complexity into right children is better than into left children (as when orienting associativity). A related intuition is that the symbol precedence has some relation to the call graph. A good precedence for functional-programming-like definitions is often similar to the call graph / definitional ordering of those functions.
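A sketch of LPO with variables on the same kind of hypothetical tuple terms (uppercase strings are variables), following the usual three-case textbook definition. It is not knuckledragger's `rw.lpo`. With `add` above `succ` in the (made-up) precedence, it orients the `add`/`succ` example and associativity as described, and it refuses to orient commutativity in either direction.

```python
def is_var(t):
    return isinstance(t, str) and t[:1].isupper()


def head(t):
    return t if isinstance(t, str) else t[0]


def args(t):
    return () if isinstance(t, str) else t[1:]


def occurs(x, t):
    return t == x or (isinstance(t, tuple) and any(occurs(x, a) for a in t[1:]))


def lpo_gt(s, t, prec):
    """s >_lpo t. A variable is only below terms that strictly contain it."""
    if s == t or is_var(s):
        return False
    if is_var(t):
        return occurs(t, s)
    ss, ts = args(s), args(t)
    # (1) some argument of s is >= t
    if any(a == t or lpo_gt(a, t, prec) for a in ss):
        return True
    f, g = head(s), head(t)
    # (2) bigger head symbol, and s dominates every argument of t
    if prec[f] > prec[g]:
        return all(lpo_gt(s, b, prec) for b in ts)
    # (3) same head: s dominates t's arguments and the arguments are lexicographically bigger
    if f == g and all(lpo_gt(s, b, prec) for b in ts):
        for a, b in zip(ss, ts):
            if a != b:
                return lpo_gt(a, b, prec)
    return False


prec = {"succ": 1, "add": 2, "mul": 3}
print(lpo_gt(("add", ("succ", "X"), "Y"), ("succ", ("add", "X", "Y")), prec))  # True
print(lpo_gt(("mul", ("mul", "X", "Y"), "Z"), ("mul", "X", ("mul", "Y", "Z")), prec))  # True
comm = (("mul", "X", "Y"), ("mul", "Y", "X"))
print(lpo_gt(*comm, prec), lpo_gt(*reversed(comm), prec))  # False False
```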

However, these intuitions get pretty mangled in order to deal with variables correctly. 


A commonly used example is completing the axioms of an abstract group. I think this needs LPO rather than KBO. At least, KBO doesn't terminate in reasonable time for me.

Adding in redundant axioms can make it go much faster, since it shortcuts deriving them. It's actually kind of cool that the thing is proving theorems, like the right identity law from the left identity law, and other such things.

```python
T = smt.DeclareSort("AbstractGroup")
x,y,z = smt.Consts("x y z", T)
e = smt.Const("a_e", T)
inv = smt.Function("c_inv", T, T)
mul = smt.Function("b_mul", T, T, T)
kd.notation.mul.register(T, mul)
kd.notation.invert.register(T, inv)
E = [
    smt.ForAll([x], e * x == x),
    # adding in these other redundant axioms makes it easier on the system
    #smt.ForAll([x], x * e == x),
    #smt.ForAll([x], x * inv(x) == e),
    smt.ForAll([x], inv(x) * x == e),
    smt.ForAll([x,y,z], (x * y) * z == x * (y * z)),
    #smt.ForAll([x,y], inv(x * y) == inv(y) * inv(x))
]
#basic(E, order=rw.lpo)
huet(E, order=rw.lpo)
```

```
[RewriteRule(vs=[Z!7317, X!7318], lhs=c_inv(b_mul(X!7318, Z!7317)), rhs=b_mul(c_inv(Z!7317), c_inv(X!7318)), pf=None),
 RewriteRule(vs=[Z!4647, X!4648], lhs=b_mul(X!4648, b_mul(c_inv(X!4648), Z!4647)), rhs=Z!4647, pf=None),
 RewriteRule(vs=[X!4638], lhs=c_inv(c_inv(X!4638)), rhs=X!4638, pf=None),
 RewriteRule(vs=[], lhs=c_inv(a_e), rhs=a_e, pf=None),
 RewriteRule(vs=[X!4633], lhs=b_mul(X!4633, c_inv(X!4633)), rhs=a_e, pf=None),
 RewriteRule(vs=[X!4364], lhs=b_mul(X!4364, a_e), rhs=X!4364, pf=None),
 RewriteRule(vs=[X!4282, Z!4283], lhs=b_mul(c_inv(X!4282), b_mul(X!4282, Z!4283)), rhs=Z!4283, pf=None),
 RewriteRule(vs=[X!4246], lhs=b_mul(a_e, X!4246), rhs=X!4246, pf=None),
 RewriteRule(vs=[X!4243], lhs=b_mul(c_inv(X!4243), X!4243), rhs=a_e, pf=None),
 RewriteRule(vs=[X!4238, Y!4239, Z!4240], lhs=b_mul(b_mul(X!4238, Y!4239), Z!4240), rhs=b_mul(X!4238, b_mul(Y!4239, Z!4240)), pf=None)]
```

# Bits and Bobbles

So it's kind of slow.

It takes about 4 seconds on the group problem. A lot of the time is spent in the simplifier, which in turn may be slow because of all the wrapping and unwrapping when going into and out of Z3 for trivial-ish stuff. Worrying.

This web page https://smimram.github.io/ocaml-alg/kb/ is instantaneous on the group problem.

eprover is also ludicrously faster. I do believe the saturated clause set represents something like the same data as a Knuth-Bendix completion. eprover is a more complicated beast, so it's hard for me to always interpret what it returns when I'm using it off-label, outside of refutational theorem proving.

```
%%file /tmp/group.p

cnf(id_left, axiom, mul(e, X) = X).
cnf(inv_left, axiom, mul(inv(X), X) = e).
cnf(assoc, axiom, mul(mul(X, Y), Z) = mul(X, mul(Y, Z))).
```

```
Overwriting /tmp/group.p
```


```
! time eprover-ho --print-saturated /tmp/group.p # you can also fiddle with the term ordering --term-ordering=KBO6  --order-weights or --precedence
```

```
# Initializing proof state
# Scanning for AC axioms
# mul is associative
#
#cnf(i_0_4, plain, (mul(e,X1)=X1)).
#
#cnf(i_0_5, plain, (mul(inv(X1),X1)=e)).
#
#cnf(i_0_6, plain, (mul(mul(X1,X2),X3)=mul(X1,mul(X2,X3)))).
#
#cnf(i_0_8, plain, (mul(inv(X1),mul(X1,X2))=X2)).
#
#cnf(i_0_14, plain, (mul(inv(e),X1)=X1)).
#
#cnf(i_0_12, plain, (mul(inv(inv(X1)),e)=X1)).
#
#cnf(i_0_16, plain, (mul(inv(inv(e)),X1)=X1)).
#
#cnf(i_0_23, plain, (inv(e)=e)).
#
#cnf(i_0_11, plain, (mul(inv(inv(X1)),X2)=mul(X1,X2))).
#
#cnf(i_0_12, plain, (mul(X1,e)=X1)).
#
#cnf(i_0_33, plain, (inv(inv(X1))=X1)).
#
#cnf(i_0_38, plain, (mul(X1,inv(X1))=e)).
#
#cnf(i_0_39, plain, (mul(X1,mul(inv(X1),X2))=X2)).
##
#cnf(i_0_41, plain, (mul(X1,mul(X2,inv(mul(X1,X2))))=e)).
#
#cnf(i_0_61, plain, (mul(X2,inv(mul(X1,X2)))=inv(X1))).
#
#cnf(i_0_78, plain, (mul(inv(mul(X1,X2)),X1)=inv(X2))).
#
#cnf(i_0_77, plain, (inv(mul(X2,X1))=mul(inv(X1),inv(X2)))).
##
# No proof found!
# SZS status Satisfiable
# Processed positive unit clause
```

Some other links to look at:

- https://www.philipzucker.com/string_knuth/
- https://github.com/codyroux/knuth-bendix
- Twee
- Waldmeister https://www.mpi-inf.mpg.de/departments/automation-of-logic/software/waldmeister/download, a fast Knuth-Bendix implementation in C. Where's the source though?
- Superposition provers are quite related: E prover, Zipperposition, SPASS
- mkbtt https://github.com/bytekid/mkbtt
- https://www.metalevel.at/trs/ Knuth-Bendix in Prolog
- https://github.com/smimram/ocaml-alg
- https://github.com/brandonwillard/mk-rewrite-completion
- https://rg1-teaching.mpi-inf.mpg.de/autrea-ss11/script-4.7.pdf


Just implement paramodulation.



```python
import z3
z3.Z3_DEBUG = False
smt.z3_debug()
smt.z3.Z3_DEBUG = False
smt.z3.z3_debug()
# turning off debug sanity checks shaves about 1s off, 25% speedup
```

```
False
```

```python
%%prun
huet(E, order=rw.lpo)
```

```
         39345171 function calls (39247118 primitive calls) in 10.420 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  3905417    1.420    0.000    1.595    0.000 z3core.py:1567(Check)
654671/654428    0.482    0.000    0.820    0.000 z3core.py:2979(Z3_func_decl_to_ast)
   559621    0.425    0.000    0.714    0.000 z3core.py:3124(Z3_get_ast_kind)
   544941    0.408    0.000    1.739    0.000 z3.py:345(__init__)
   544941    0.405    0.000    0.680    0.000 z3core.py:1637(Z3_inc_ref)
   544705    0.396    0.000    1.389    0.000 z3.py:350(__del__)
  8355539    0.385    0.000    0.385    0.000 z3types.py:24(from_param)
  2120782    0.380    0.000    0.496    0.000 z3.py:400(ctx_ref)
    35444    0.328    0.000    7.737    0.000 utils.py:48(pmatch)
   544705    0.320    0.000    0.373    0.000 z3core.py:1641(Z3_dec_ref)
   339799    0.307    0.000    0.525    0.000 z3core.py:3094(Z3_is_eq_ast)
   234008    0.297    0.000    1
```

```python
T = smt.DeclareSort("AbstractGroup")
x,y,z = smt.Consts("x y z", T)
e = smt.Const("a_e", T)
inv = smt.Function("c_inv", T, T)
mul = smt.Function("b_mul", T, T, T)
kd.notation.mul.register(T, mul)
E = [
    smt.ForAll([x,y], inv(x * y) == inv(y) * inv(x)), #k28
    smt.ForAll([x,y], x * (inv(x) * y) == y), # k 16
    smt.ForAll([x], x * inv(x) == e), # k12
    smt.ForAll([x], inv(e) == e), # k11
    smt.ForAll([x], inv(inv(x)) == x), #k9
    smt.ForAll([x], x * e == x), # k2
    smt.ForAll([x,y], inv(x) * (x * y) == y), #k1
    #smt.ForAll([x], x * e == x),
    #smt.ForAll([x], x * inv(x) == e),
    smt.ForAll([x,y,z], (x * y) * z == x * (y * z)), #R1
    smt.ForAll([x], inv(x) * x == e),
    smt.ForAll([x], e * x == x),
]

basic(E, order=rw.lpo)
#huet(E, order=rw.lpo)
```

```
[RewriteRule(vs=[X!3990, Y!3991], lhs=c_inv(b_mul(X!3990, Y!3991)), rhs=b_mul(c_inv(Y!3991), c_inv(X!3990)), pf=None),
 RewriteRule(vs=[X!3992, Y!3993], lhs=b_mul(X!3992, b_mul(c_inv(X!3992), Y!3993)), rhs=Y!3993, pf=None),
 RewriteRule(vs=[X!3994], lhs=b_mul(X!3994, c_inv(X!3994)), rhs=a_e, pf=None),
 RewriteRule(vs=[X!3995], lhs=c_inv(a_e), rhs=a_e, pf=None),
 RewriteRule(vs=[X!3996], lhs=c_inv(c_inv(X!3996)), rhs=X!3996, pf=None),
 RewriteRule(vs=[X!3997], lhs=b_mul(X!3997, a_e), rhs=X!3997, pf=None),
 RewriteRule(vs=[X!3998, Y!3999], lhs=b_mul(c_inv(X!3998), b_mul(X!3998, Y!3999)), rhs=Y!3999, pf=None),
 RewriteRule(vs=[X!4000, Y!4001, Z!4002], lhs=b_mul(b_mul(X!4000, Y!4001), Z!4002), rhs=b_mul(X!4000, b_mul(Y!4001, Z!4002)), pf=None),
 RewriteRule(vs=[X!4003], lhs=b_mul(c_inv(X!4003), X!4003), rhs=a_e, pf=None),
 RewriteRule(vs=[X!4004], lhs=b_mul(a_e, X!4004), rhs=X!4004, pf=None)]
```

```python
T = smt.DeclareSort("AbstractGroup")
x,y,z = smt.Consts("x y z", T)
e = smt.Const("a_e", T)
inv = smt.Function("c_inv", T, T)
mul = smt.Function("b_mul", T, T, T)
kd.notation.mul.register(T, mul)

print(rw.lpo([x,y], inv(x * y),  inv(y) * inv(x)))
print(rw.lpo([x,y],  inv(y) * inv(x), inv(x * y)))
print([x,y], inv(x * y),  inv(y) * inv(x))
print(orient(smt.ForAll([x,y], inv(x * y) == inv(y) * inv(x)), order=rw.lpo))

inv.name() < mul.name()
```

```
Order.GR
Order.NGE
[x, y] c_inv(b_mul(x, y)) b_mul(c_inv(y), c_inv(x))
RewriteRule(vs=[X!4230, Y!4231], lhs=c_inv(b_mul(X!4230, Y!4231)), rhs=b_mul(c_inv(Y!4231), c_inv(X!4230)), pf=None)

False
```

```python
from dataclasses import dataclass
@dataclass
class KBState:
    E: list[smt.BoolRef | smt.QuantifierRef]
    R: list[rw.RewriteRule]
```