---
title: "Weighted Union Find and Ground Knuth Bendix Completion"
date: 2026-02-18
---

A union find variant I think is simple and interesting is the "weighted" union find. This is distinguished from union by "size" or "rank" in that the weight is considered a property of the id given by the user, not an internal property of the data structure https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Union_by_size . Who becomes parent of whom is decided by comparing weights: the root with the smaller weight stays the root.

from dataclasses import dataclass,field

@dataclass
class WUF():
    parents : list[int] = field(default_factory=list)
    weights : list[int] = field(default_factory=list)

    def makeset(self, weight):
        id = len(self.parents)
        self.parents.append(id)
        self.weights.append(weight)
        return id
    
    def find(self, x):
        while self.parents[x] != x:
            x = self.parents[x] 
        return x

    def tiebreak(self, x, y):
        # True means x should point at y, i.e. y becomes the parent
        wx, wy = self.weights[x], self.weights[y]
        if wx > wy:
            return True   # x is heavier, so the lighter y wins
        elif wy > wx:
            return False
        else:
            return True # arbitrary tie break

    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            if self.tiebreak(x,y):
                self.parents[x] = y
            else:
                self.parents[y] = x


uf = WUF()
x = uf.makeset(3)
y = uf.makeset(4)
z = uf.makeset(5)
uf.union(x, y)
uf.union(y, z)
assert uf.find(x) == x
assert uf.find(y) == x
assert uf.find(z) == x
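
The property this buys is that the root of every class is a minimum-weight member of that class. Here is a minimal sketch that checks this on random unions. It restates the union find compactly so it runs on its own; `check_min_weight_roots` and its parameters are names made up for this illustration.

```python
import random
from dataclasses import dataclass, field

@dataclass
class WUF:
    parents: list[int] = field(default_factory=list)
    weights: list[int] = field(default_factory=list)

    def makeset(self, weight):
        id = len(self.parents)
        self.parents.append(id)
        self.weights.append(weight)
        return id

    def find(self, x):
        while self.parents[x] != x:
            x = self.parents[x]
        return x

    def union(self, x, y):
        x, y = self.find(x), self.find(y)
        if x != y:
            # same policy as above: the heavier root points at the lighter one,
            # and on a tie x points at y
            if self.weights[x] >= self.weights[y]:
                self.parents[x] = y
            else:
                self.parents[y] = x

def check_min_weight_roots(seed=0, n=50, unions=40):
    # hypothetical helper: do random unions, then check every root
    rng = random.Random(seed)
    uf = WUF()
    ids = [uf.makeset(rng.randint(0, 10)) for _ in range(n)]
    for _ in range(unions):
        uf.union(rng.choice(ids), rng.choice(ids))
    # group ids by root and compare the root's weight to the class minimum
    classes = {}
    for i in ids:
        classes.setdefault(uf.find(i), []).append(i)
    return all(uf.weights[root] == min(uf.weights[i] for i in members)
               for root, members in classes.items())

print(check_min_weight_roots())  # True
```

The argument is a one line induction: a union only ever compares two roots, each of which is already minimal in its own class, and keeps the lighter one.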


The reason I think this is interesting is that we can lift it to an egraph variant that more closely matches ground Knuth-Bendix completion using a Knuth-Bendix ordering https://www.philipzucker.com/ground_kbo/ . Ground Knuth-Bendix ordering is basically comparing terms by size with tie breaking. The memo table is _for serious_ a hash cons. Each "id" describes exactly one term, not an eclass.
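
To make "size with tie breaking" concrete, here is a rough sketch of that comparison on plain nested tuples of the form `(f, args)`, with no hash consing. It assumes every symbol has weight 1, uses Python string order as a stand-in for a symbol precedence, and assumes a symbol is always used at the same arity. The names `size` and `kbo_greater` are made up for this illustration.

```python
def size(t):
    # a term is (head_symbol, tuple_of_argument_terms)
    f, args = t
    return 1 + sum(size(a) for a in args)

def kbo_greater(s, t):
    """True if s is strictly greater than t in this simplified ground ordering."""
    ws, wt = size(s), size(t)
    if ws != wt:
        return ws > wt          # 1. bigger terms are greater
    (f, s_args), (g, t_args) = s, t
    if f != g:
        return f > g            # 2. then compare head symbols (string order as precedence)
    for a, b in zip(s_args, t_args):
        if a != b:
            return kbo_greater(a, b)  # 3. then the first differing argument decides
    return False                # identical terms

a = ("a", ())
b = ("b", ())
fa = ("f", (a,))
fb = ("f", (b,))

assert kbo_greater(fa, b)       # size 2 beats size 1
assert kbo_greater(b, a)        # same size, "b" > "a"
assert kbo_greater(fb, fa)      # same size and head, decided by the argument
assert not kbo_greater(a, a)
```

This version recomputes sizes on every comparison, which is the cost that memoizing the size at construction time avoids.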

In hash consing it often makes sense to memoize other properties of your terms immediately at construction. This can include precomputing the hash of the node and also the size, which is merely the sum of the memoized size of the children + 1. You can also do depth or any other variation you like.
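
As a small sketch of that idea, here is a hash cons that records size and depth when a node is first built. `Node`, `HashCons` and `mk` are illustrative names, not anything from a library. Python dataclasses compute the hash on demand, so this sketch does not memoize it.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class Node:
    f: str
    args: tuple[int, ...]   # ids of already constructed children

@dataclass
class HashCons:
    memo: dict[Node, int] = field(default_factory=dict)
    nodes: list[Node] = field(default_factory=list)
    size: list[int] = field(default_factory=list)
    depth: list[int] = field(default_factory=list)

    def mk(self, f, *args):
        node = Node(f, args)
        id = self.memo.get(node)
        if id is None:
            id = len(self.nodes)
            self.memo[node] = id
            self.nodes.append(node)
            # children already exist, so their memoized values are available
            self.size.append(1 + sum(self.size[a] for a in args))
            self.depth.append(1 + max((self.depth[a] for a in args), default=0))
        return id

hc = HashCons()
a = hc.mk("a")
fa = hc.mk("f", a)
g = hc.mk("g", fa, a)
assert hc.mk("f", a) == fa          # structurally equal terms share an id
print(hc.size[g], hc.depth[g])      # 4 3
```

Note that `size` here is the size of the term read as a tree, so the shared `a` is counted twice in `g(f(a), a)`.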

Extraction becomes trivial, as it is just turning the hash consed tree with `Id` indirection back into a regular tree. The ordering makes the root of each class its smallest term, so `find` already picks the term to extract.

Because `nodes` is in construction order, children always appear before the terms that use them, so sweeping over it feels kind of nice.

Having every id point to the best known term is more like what compiler writers use union finds for https://pypy.org/posts/2022/07/toy-optimizer.html
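
A loose sketch of that compiler flavored usage: each operation carries a forwarding pointer, and the optimizer itself decides which side is the replacement instead of a rank or size heuristic. This is only in the spirit of the linked post, not its actual code. `Op` and `replace_with` are hypothetical names.

```python
class Op:
    def __init__(self, name, *args):
        self.name = name
        self.args = args
        self.forwarded = None   # None means this op is its own representative

    def find(self):
        # follow forwarding pointers to the current representative
        op = self
        while op.forwarded is not None:
            op = op.forwarded
        return op

    def replace_with(self, better):
        # directed union: the caller says `better` is the new representative
        root, better = self.find(), better.find()
        if root is not better:   # avoid creating a self loop
            root.forwarded = better

x = Op("x")
zero = Op("0")
add = Op("add", x, zero)
add.replace_with(x)              # the rewrite x + 0 -> x
assert add.find() is x
assert x.find() is x
```

The weighted union find sits in between: the direction is not arbitrary, but it is computed from the weights rather than chosen by the caller.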

type Id = int  # the `type` statement needs Python 3.12+

@dataclass(frozen=True)
class App:
    f : str
    args : tuple[Id, ...]

@dataclass
class GKB():
    memo : dict[App, Id] = field(default_factory=dict)
    nodes : list[App] = field(default_factory=list)
    parents : list[Id] = field(default_factory=list)
    weights : list[int] = field(default_factory=list)

    def mk_app(self, f, args):
        id = self.memo.get(App(f, args))
        if id is not None:
            return id
        else:
            id = len(self.parents)
            self.memo[App(f, args)] = id
            self.parents.append(id)
            self.nodes.append(App(f, args))
            self.weights.append(1 + sum(self.weights[arg] for arg in args))
            return id

    def find(self, x):
        while self.parents[x] != x:
            x = self.parents[x] 
        return x
    
    def tiebreak(self, x, y): # does Ground KBO basically
        wx, wy = self.weights[x], self.weights[y]
        if wx > wy:
            return True
        elif wy > wx:
            return False
        else:
            appx, appy = self.nodes[x], self.nodes[y]
            if appx.f > appy.f:
                return True
            elif appy.f > appx.f:
                return False
            else:
                assert len(appx.args) == len(appy.args) # assume same length args for now
                for ax, ay in zip(appx.args, appy.args):
                    #ax, ay = self.find(ax), self.find(ay) # perhaps do this. Changes meaning away from terms though
                    if ax != ay:
                        return self.tiebreak(ax,ay)
                assert False, "should never reach here, tiebreak should have been resolved by now"

    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            if self.tiebreak(x, y):
                self.parents[x] = y
            else:
                self.parents[y] = x
        
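    def rebuild(self):
        # Re-canonicalize every node until nothing changes
        done = False
        while not done:
            done = True
            for id in range(len(self.nodes)):
                app = self.nodes[id]
                # the same term with its children replaced by their current representatives
                id1 = self.mk_app(app.f, tuple(self.find(arg) for arg in app.args))
                if self.find(id) != self.find(id1):
                    done = False
                    self.union(id, id1)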


    def extract(self, id : Id):
        app = self.nodes[self.find(id)]
        return (app.f, tuple(self.extract(arg) for arg in app.args))

gkb = GKB()
a = gkb.mk_app("a", ())
a1 = gkb.mk_app("a", ())
assert a == a1
b = gkb.mk_app("b", ())
fa = gkb.mk_app("f", (a,))
fb = gkb.mk_app("f", (b,))
ffa = gkb.mk_app("f", (fa,))
ffb = gkb.mk_app("f", (fb,))
gkb.union(a, b)
gkb.rebuild()

assert gkb.find(ffa) == gkb.find(ffb)
print(gkb.extract(ffb))
gkb.union(ffa, a)
gkb.rebuild()
print(gkb.extract(ffb))

('f', (('f', (('a', ()),)),))


('a', ())

This is another in a sequence of union find variation posts.
