In [1]:
import numpy as np
import inspect
import itertools

In [2]:
# Define an interaction: the loss as a function of ablation bits (1 = ablated).
# Here r1 and r2 are redundant: ablating just one of them (with s intact)
# leaves the loss at 0, so you need to ablate both to get the loss to increase.
# s is in series with the r1/r2 pair, so ablating s always breaks the circuit.
def ablation_to_loss(r1, r2, s):
    redundant_pair_ablated = r1 & r2
    return redundant_pair_ablated | s

signature = inspect.signature(ablation_to_loss)
bit_count = len(signature.parameters)
# Enumerate all 2**bit_count ablations as {head name: bit} dicts, in
# lexicographic order of the bit tuples produced by itertools.product.
all_ablations = [
    dict(zip(signature.parameters, bits))
    for bits in itertools.product((0, 1), repeat=bit_count)
]
all_ablations

[{'r1': 0, 'r2': 0, 's': 0},
 {'r1': 0, 'r2': 0, 's': 1},
 {'r1': 0, 'r2': 1, 's': 0},
 {'r1': 0, 'r2': 1, 's': 1},
 {'r1': 1, 'r2': 0, 's': 0},
 {'r1': 1, 'r2': 0, 's': 1},
 {'r1': 1, 'r2': 1, 's': 0},
 {'r1': 1, 'r2': 1, 's': 1}]

In [3]:
# Construct every ablation that touches at most `degree` heads. Each such
# ablation is keyed by its "term": the tuple of ablated head names, listed
# in order of increasing size (the empty term first).
degree = 2

terms = []
for size in range(degree + 1):
    terms.extend(itertools.combinations(signature.parameters, size))

term_to_ablation = {
    term: {name: int(name in term) for name in signature.parameters}
    for term in terms
}
term_to_ablation

{(): {'r1': 0, 'r2': 0, 's': 0},
 ('r1',): {'r1': 1, 'r2': 0, 's': 0},
 ('r2',): {'r1': 0, 'r2': 1, 's': 0},
 ('s',): {'r1': 0, 'r2': 0, 's': 1},
 ('r1', 'r2'): {'r1': 1, 'r2': 1, 's': 0},
 ('r1', 's'): {'r1': 1, 'r2': 0, 's': 1},
 ('r2', 's'): {'r1': 0, 'r2': 1, 's': 1}}

In [4]:
# Design matrix for the fit: mat[ablation_index, term_index] is 1.0 when the
# ablation ablates every head named in the term, else 0.0. The empty term ()
# is vacuously satisfied, so its column is the constant (intercept) column.
rows = [
    [all(ablation[head] for head in term) for term in terms]
    for ablation in all_ablations
]
mat = np.array(rows, dtype=np.float64)
del rows
mat

array([[1., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 1., 0., 0., 0.],
       [1., 0., 1., 0., 0., 0., 0.],
       [1., 0., 1., 1., 0., 0., 1.],
       [1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 0., 1., 0., 1., 0.],
       [1., 1., 1., 0., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1.]])

In [5]:
def term_name(term):
    """Format a term (tuple of head names) as a space-separated product.

    An empty result -- e.g. the empty term, i.e. the constant/intercept --
    is rendered as "1".
    """
    joined = " ".join(term)
    return joined if joined else "1"

In [6]:
# Least-squares fit of the loss as a polynomial in the ablation bits, using
# only the terms up to the chosen degree. np.linalg.lstsq returns an empty
# residual array when the system is square/underdetermined or rank-deficient,
# hence the guard before printing it.
targets = [ablation_to_loss(**ablation) for ablation in all_ablations]
sol, resid, _, _ = np.linalg.lstsq(mat, targets, rcond=None)
assert len(sol) == len(terms)
if len(resid):
    # resid[0] is the sum of squared residuals of the fit.
    print(f"Residual error: {resid[0]:.3f}")
for term, coef in zip(terms, sol):
    print(f"{coef:6.3f} * {term_name(term)}")

Residual error: 0.125
-0.125 * 1
 0.250 * r1
 0.250 * r2
 1.250 * s
 0.500 * r1 r2
-0.500 * r1 s
-0.500 * r2 s


In [7]:
# Compute the naive pairwise interaction metric
#   loss(ablate x and y) - loss(ablate x) - loss(ablate y) + loss(ablate nothing)
# i.e. the second-order finite difference of the loss in x and y (with all other
# heads intact). The baseline term makes the metric insensitive to a constant
# offset in the loss; it also makes the metric equal the x*y coefficient of the
# exact full-degree expansion. For ablation_to_loss as currently defined,
# loss(ablate nothing) == 0, so the printed values are the same as without it.
loss_ablate_nothing = ablation_to_loss(**term_to_ablation[()])
for x, y in itertools.combinations(signature.parameters, 2):
    loss_ablate_x_and_y = ablation_to_loss(**term_to_ablation[x, y])
    loss_ablate_x = ablation_to_loss(**term_to_ablation[x,])
    loss_ablate_y = ablation_to_loss(**term_to_ablation[y,])
    metric = loss_ablate_x_and_y - loss_ablate_x - loss_ablate_y + loss_ablate_nothing
    print(f"{metric:6.3f} * {term_name((x, y))}")

 1.000 * r1 r2
 0.000 * r1 s
 0.000 * r2 s


In [8]:
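# Observe: if degree = 3 (the number of heads), the basis of terms spans every
# possible ablation, so the fit is exact and its pairwise coefficients match the
# naive metric above. At degree = 2 they disagree -- we're now asking:
#   "How can you best explain the behavior as degree 2 interactions?"
# And in that best degree-2 fit there *is* an interaction between r1 and s, and
# also between r2 and s (coefficient -0.5 each), even though the naive metric is 0.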