---
title: A Contextual Union Finds
date: 2026-01-13
---

Something that is desired in egraph rewriting is rewriting under assumptions. The canonical example of this is rewriting inside the branches of an if-then-else such as `if x = y and y != 0 then x/y else x+2`. Obviously in the first branch we know that `x/y` can be reduced to the constant `1`. However, we do not know that `x=y` globally. Another case that Eytan showed was `max(x,y) - min(x,y) = abs(x - y)`. You may also want assumptions just to see if you can make progress in an expression, and then output those assumptions later (people do this sort of thing to try and simplify traces from a symbolic executor by assuming non-aliasing of addresses). Or you may want them for the cases of an inductive proof.

The technique of assume nodes gives a way to encode this into an egraph rewriting system. The colored egraph work tries to bake this in.

As is often the case, I think there is a lot to be learned by taking a step back to look at the simpler case of a contextual union find. I think it's actually fairly straightforward.

The basic idea is to maintain a hierarchy of union finds. Unions asserted into the child union finds should not affect the parents, but find operations inside the children may have to look inside the parent.


If you can assume the parent union find stays fixed, that simplifies things. Then a persistent union find is acceptable. This is the sort of thing that occurs in a backtracking solver.
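As a rough sketch of that simpler fixed-parent case: each context can keep a small overlay dictionary on top of a parent arena that is assumed never to change. The name `FrozenParentContext` and its layout are made up for illustration, and this is an overlay rather than a true persistent data structure. Because the parent is frozen, the overlay only ever relates parent roots, so "parent find, then overlay find" is enough.

```python
class FrozenParentContext:
    """Child context over a parent union find that is assumed never to change."""

    def __init__(self, parents: list[int]):
        self.parents = parents               # parent arena, treated as read-only
        self.overlay: dict[int, int] = {}    # parent root -> smaller parent root

    def _parent_find(self, x: int) -> int:
        while self.parents[x] != x:
            x = self.parents[x]
        return x

    def find(self, x: int) -> int:
        # Normalize in the parent first, then follow context-local edges.
        x = self._parent_find(x)
        while x in self.overlay:
            x = self.overlay[x]
        return x

    def union(self, x: int, y: int) -> int:
        x, y = self.find(x), self.find(y)
        if x != y:
            x, y = max(x, y), min(x, y)
            self.overlay[x] = y
        return y


parents = [0, 0, 2, 3]                 # globally, 0 = 1
ctx_a = FrozenParentContext(parents)
ctx_b = FrozenParentContext(parents)
ctx_a.union(1, 2)                      # assumed only inside ctx_a

assert ctx_a.find(2) == ctx_a.find(0)  # ctx_a sees 0 = 1 = 2
assert ctx_b.find(2) != ctx_b.find(0)  # ctx_b is unaffected
assert ctx_b.find(1) == ctx_b.find(0)  # both inherit the global 0 = 1
```

This stops being correct as soon as `parents` is mutated after the overlay has entries.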

But we want to handle the case where new equalities are being discovered both in the global union find and in the child union finds.

# Sparse Union Finds

There are at least 3 flavors of union find. One flavor uses refcells, another uses a vector arena, and a third uses hashmaps.

I like the latter 2 more because they give you a handle on the entire union find as a single entity, which can be useful for sweeping if need be.

This is the vector arena style. It's nice that it only requires a vector and hence has fast lookup.


In [None]:
from dataclasses import dataclass, field
@dataclass
class UFArena():
    parents : list[int] = field(default_factory=list)
    def makeset(self):
        eid = len(self.parents)
        self.parents.append(eid)
        return eid
    def find(self, x : int):
        while self.parents[x] != x:
            x = self.parents[x]
        return x
    def union(self, x : int, y : int):
        x,y = self.find(x), self.find(y)
        if x != y:
            if x < y:
                x,y = y,x
            self.parents[x] = y
        return y
    def rebuild(self):
        for i in range(len(self.parents)):
            self.parents[i] = self.find(i)

uf = UFArena()
x,y,z = [uf.makeset() for i in range(3)]
uf.union(x,y)
uf

UFArena(parents=[0, 0, 2])
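To illustrate what "sweeping" the whole union find might look like, here is a small sketch that walks an arena-style parent list once and groups every id by its root. The helper name `classes` is hypothetical and works on a bare `list[int]` so that it stands alone.

```python
from collections import defaultdict


def classes(parents: list[int]) -> dict[int, list[int]]:
    """Group every id in an arena-style union find by its root."""

    def find(x: int) -> int:
        while parents[x] != x:
            x = parents[x]
        return x

    groups: dict[int, list[int]] = defaultdict(list)
    for i in range(len(parents)):
        groups[find(i)].append(i)
    return dict(groups)


example = classes([0, 0, 2, 2, 0])
assert example == {0: [0, 1, 4], 2: [2, 3]}
```

This kind of whole-structure pass is awkward in the refcell style, where there is no single object that knows about every node.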

This is a different style. The vector above is in a sense being used as a `dict[int,int]`. What is nice about the hashmap style is that it is more space efficient if you have very sparse unions, and also that it supports arbitrary hashable objects as keys.

In [72]:
@dataclass
class UFDict():
    uf : dict[object,object] = field(default_factory=dict)
    def find(self, x):
        while x in self.uf:
            x = self.uf[x]
        return x
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            y,x = min(x,y), max(x,y)
            self.uf[x] = y
        return y
    def rebuild(self):
        for k in self.uf.keys():
            self.uf[k] = self.find(k)
    def items(self):
        return self.uf.items()

uf = UFDict()
uf.union(0,1)
uf.union(0,2)
uf

UFDict(uf={1: 0, 2: 0})
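A small standalone sketch of the "arbitrary keys" point, using strings. The free functions `find` and `union` here are illustrative stand-ins for the methods above. One caveat that follows from using `min`/`max` to pick the representative: the keys have to be mutually comparable, not just hashable.

```python
def find(uf: dict, x):
    while x in uf:
        x = uf[x]
    return x


def union(uf: dict, x, y):
    x, y = find(uf, x), find(uf, y)
    if x != y:
        x, y = max(x, y), min(x, y)
        uf[x] = y
    return y


uf = {}
union(uf, "b", "a")
union(uf, "c", "b")

assert uf == {"b": "a", "c": "a"}
assert find(uf, "c") == "a"
# An id that was never unioned takes no space and is its own root.
assert find(uf, "z") == "z" and "z" not in uf
```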

# A Context

The basic structure of a single context.

You unfortunately _do_ need to do some search-like stuff in `find` if you want to avoid false negatives. It is possible for the parent union find to receive an update such that any strategy of eagerly finding and bouncing around between the parent and child union finds misses the pathway to the truly canonical node.

As an example, consider a starting state of

```
[0,1,2,3]
{}
```

We then receive 1=2=3 in the context uf

```
[0,1,2,3]
{3 : 2, 2 : 1}
```

If we then receive `union(0,2)` in the parent context we get to

```
[0,1,0,3]
{3 : 2, 2 : 1}
```

Now if we eagerly take the lower union find on `find(3)` we get `3 -> 2 -> 1`, which misses the pathway from `2 -> 0` in the parent union find.
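The state above can be replayed directly. This sketch hard-codes the two structures and tries two plausible "bouncing" strategies (the function names are made up for illustration). Both report that `0` and `3` are different, even though `0 = 2` globally and `2 = 3` in the context.

```python
parent = [0, 1, 0, 3]      # global union find after union(0, 2)
child = {3: 2, 2: 1}       # context union find holding 1 = 2 = 3


def parent_find(x):
    while parent[x] != x:
        x = parent[x]
    return x


def child_find(x):
    while x in child:
        x = child[x]
    return x


def child_then_parent(x):
    return parent_find(child_find(x))


def alternate_until_stable(x):
    # Keep bouncing child -> parent until nothing changes.
    while True:
        y = parent_find(child_find(x))
        if y == x:
            return x
        x = y


naive = (child_then_parent(0), child_then_parent(3))
alternating = (alternate_until_stable(0), alternate_until_stable(3))

# False negatives: both strategies leave 0 and 3 in different classes.
assert naive == (0, 1)
assert alternating == (0, 1)
```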

I believe this sort of problem can be cooked up for any fixed scheme of bouncing around between the parent and child union finds.

Having said all that, with path compression, the search doesn't have to be paid over and over, so maybe it's not all bad.

This is a microcosm of theory combination actually, in that it is harder to combine (union) rewrite rule sets than you might think. Rebuilding is running completion again. Min is a mutually compatible ordering.
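One way to read "rebuilding is running completion again": the brute-force way to combine the two union finds is to throw every edge from both into a fresh one and flatten it. The sketch below does this with a hypothetical `combine` helper. It gives the right answer on the example, at the cost of work linear in the size of the parent, which is the cost a contextual `find` is trying to avoid.

```python
def combine(parent: list[int], child: dict[int, int]) -> dict[int, int]:
    """Re-union every parent and child edge into one fresh union find."""
    uf: dict[int, int] = {}

    def find(x):
        while x in uf:
            x = uf[x]
        return x

    def union(x, y):
        x, y = find(x), find(y)
        if x != y:
            x, y = max(x, y), min(x, y)
            uf[x] = y

    for i, p in enumerate(parent):
        union(i, p)
    for k, v in child.items():
        union(k, v)
    # Flatten so that every non-root key points directly at its root.
    for k in uf:
        uf[k] = find(k)
    return uf


combined = combine([0, 1, 0, 3], {3: 2, 2: 1})
assert combined == {1: 0, 2: 0, 3: 0}
```

Using `min` as the tie-breaker in both structures is what lets the two sets of edges be merged without reorienting anything.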

Do I really need to maintain an explicit enumeration of each class? That kind of sucks.

In [60]:
from collections import defaultdict
@dataclass
class UFContext():
    parentuf : UFArena
    uf : dict[object,object] = field(default_factory=dict)
    # Could also use linked list based or tree based enumerator
    ids : dict[object, set[object]] = field(default_factory=lambda: defaultdict(set))
    def makeset(self):
        x = self.parentuf.makeset()
        return x
    def find(self, x):
        while x in self.uf:
            x = self.uf[x]
        if x not in self.ids:
            return self.parentuf.find(x)
        # We could compress ys with respect to parents
        else:
            ys = self.ids[x]
            return min(min(self.parentuf.find(y) for y in ys), self.parentuf.find(x))
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            y,x = min(x,y), max(x,y)
            self.uf[x] = y
            self.ids[y] |= self.ids[x]
            self.ids[y].add(x)
        return y
    def rebuild(self):
        for k in self.uf.keys():
            self.uf[k] = self.find(k)

uf0 = UFArena()
uf1 = UFContext(uf0)

x,y,z,w = [uf1.makeset() for i in range(4)]
uf1
uf1.union(y,z)
uf1.union(z,w)
uf1
uf0.union(x,z)
uf1
uf1.find(x)
assert uf1.find(w) == uf1.find(z) == uf1.find(x) # holds here, thanks to the ids enumeration
uf1

UFContext(parentuf=UFArena(parents=[0, 1, 0, 3]), uf={2: 1, 3: 1}, ids=defaultdict(<class 'set'>, {1: {2, 3}, 2: set(), 3: set()}))

# Structural Canonization of Union Finds for Keys



In [81]:
from collections import defaultdict
type CanonUF = object
@dataclass
class UFContextKeyed():
    biguf : UFArena = field(default_factory=UFArena)
    context_ufs : dict[CanonUF, UFContext] = field(default_factory=dict)
    def makeset(self):
        x = self.biguf.makeset()
        return x
    def make_key(self, *eqs):
        uf = UFDict()
        for l,r in eqs:
            uf.union(self.biguf.find(l), self.biguf.find(r))
        uf.rebuild()
        print(uf)
        return tuple(sorted(uf.items()))
    def make_context(self, *eqs):
        key = self.make_key(*eqs)
        uf = self.context_ufs.get(key)
        if uf is None:
            uf = UFContext(self.biguf)
            for l,r in eqs:
                uf.union(l,r)
            self.context_ufs[key] = uf
            return key, uf
        else:
            return key, uf
    def find(self, ctx, x):
        return self.context_ufs[ctx].find(x)
    def union(self, x, y, ctx=None):
        if ctx is None:
            return self.biguf.union(x,y)
        else:
            return self.context_ufs[ctx].union(x,y)
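    def rebuild(self):
        # TODO: re-canonicalize keys against biguf and merge contexts on key collision.
        # For now this only rebuilds each context's own union find.
        for ctx in self.context_ufs.values():
            ctx.rebuild()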
uf = UFContextKeyed()
x,y,z,w = [uf.makeset() for _ in range(4)]
key, uf1 = uf.make_context((x,y))
uf.union(y, z, ctx=key)
uf


UFDict(uf={1: 0})


UFContextKeyed(biguf=UFArena(parents=[0, 1, 2, 3]), context_ufs={((1, 0),): UFContext(parentuf=UFArena(parents=[0, 1, 2, 3]), uf={1: 0, 2: 0}, ids=defaultdict(<class 'set'>, {0: {1, 2}, 1: set(), 2: set()}))})
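The point of `make_key` is that two differently written sets of assumptions describing the same partition should land on the same context. A standalone sketch of that normalization follows; `canonical_key` is an illustrative name, and it skips the step of first canonicalizing ids through the global union find.

```python
def canonical_key(eqs):
    """Normalize a list of equations into a hashable, order-independent key."""
    uf = {}

    def find(x):
        while x in uf:
            x = uf[x]
        return x

    for l, r in eqs:
        l, r = find(l), find(r)
        if l != r:
            uf[max(l, r)] = min(l, r)
    # Flatten, then sort the (id, root) pairs to get a canonical form.
    return tuple(sorted((k, find(k)) for k in uf))


k1 = canonical_key([(0, 1), (1, 2)])
k2 = canonical_key([(2, 0), (1, 0), (0, 0)])
k3 = canonical_key([(0, 1)])

assert k1 == k2 == ((1, 0), (2, 0))   # same partition, same key
assert k3 == ((1, 0),)                # a weaker assumption gets a different key
```

A key like this is only canonical relative to the global union find at the moment it was built. Later global unions can make two distinct keys describe the same context, which is why the keys would need rebuilding and merging on collision.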

## Two Failures
At some point I thought just calling find on the child uf and then the parent uf would work. It does not.

Then I thought a form of search during find might be sufficient. It is not. This search isn't really that much better than maintaining the eclass set anyway.

I dunno. I may be missing something nice to do. If you figure it out please do tell. I will tell you that some things you try that feel intuitively fine are wrong.

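The following is wrong: `find` only follows edges forward, from an id toward its representative, so it can miss ids that point into the same class from the other union find.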

In [None]:
@dataclass
class UFContext():
    parentuf : UFArena
    uf : dict[object,object] = field(default_factory=dict)
    def makeset(self):
        return self.parentuf.makeset()
    def find(self, x):
        seen = set([x])
        todo = [x]
        while todo:
            x = todo.pop()
            y = self.uf.get(x)
            if y is not None and y not in seen:
                seen.add(y)
                todo.append(y)
            y = self.parentuf.parents[x]
            if y != x and y not in seen:
                seen.add(y)
                todo.append(y)
        y = min(seen)
        #for x in seen: # might as well path compress
        #    if y != x:
        #        self.uf[x] = y
        return y
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            y,x = min(x,y), max(x,y)
            self.uf[x] = y
        return y
    def rebuild(self):
        for k in self.uf.keys():
            self.uf[k] = self.find(k)

uf0 = UFArena()
x,y,z = [uf0.makeset() for i in range(3)]
uf1 = UFContext(uf0)
uf2 = UFContext(uf0)

uf1.union(x,y)
uf1
assert uf1.find(x) == uf1.find(y)
assert uf0.find(x) != uf0.find(y)
assert uf2.find(x) != uf2.find(y)

uf0.union(y,z) # contexts inherit 
assert uf2.find(y) == uf2.find(z)
assert uf1.find(x) == uf1.find(z)

uf1


UFContext(parentuf=UFArena(parents=[0, 1, 1]), uf={1: 0})

In [32]:
uf0 = UFArena()
x,y,z,w = [uf0.makeset() for i in range(4)]
uf1 = UFContext(uf0)

uf1.union(y,z)
uf1.union(z,w)
uf1

UFContext(parentuf=UFArena(parents=[0, 1, 2, 3]), uf={2: 1, 3: 1})

In [35]:
uf0.union(x,z)
uf1
uf1.find(x)
uf1.find(w) # uh oh!
uf1

UFContext(parentuf=UFArena(parents=[0, 1, 0, 3]), uf={2: 1, 3: 1})

In [31]:
uf1.rebuild()
uf1.rebuild()
uf1.find(w)
uf1

UFContext(parentuf=UFArena(parents=[0, 1, 0, 3]), uf={2: 0, 3: 1})

# Structural Canonization of Union Finds for Keys

The contextual union find
Show the counterexamples
Normalizing union finds as keys.

https://github.com/eytans/easter-egg/blob/master/src/colors.rs
