# Per-example gradnorms

In [9]:
import torch
import torch.nn.functional as F


class MLP(torch.nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of input features of the first linear layer.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of output features of the last linear layer.
        bias: Whether every linear layer carries a bias term.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, bias=True):
        super().__init__()
        # Consecutive pairs of widths define the linear layers; a ReLU is
        # inserted before every linear layer except the first one.
        widths = (input_dim, hidden_dim, hidden_dim, output_dim)
        stack = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                stack.append(torch.nn.ReLU())
            stack.append(torch.nn.Linear(fan_in, fan_out, bias=bias))
        self.main = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` through the layer stack and return the result."""
        return self.main(x)


class MLPWithPerExampleGradNorm(MLP):
    """MLP that records per-example gradient norms during one batched backward pass.

    Every `torch.nn.Linear` inside `self.main` gets a forward hook that caches
    the layer input and a full backward hook that combines the cached input
    with the gradient w.r.t. the layer output. Inputs and output gradients are
    treated as 2-D tensors of shape (batch_size, features).

    Attributes:
        cached_inputs: Maps a linear module to the input seen in its most
            recent forward pass; entries are consumed by the backward hook.
        perexample_gradnorms: Maps a linear module to its per-example squared
            gradient norm, shape (batch_size,). Cleared by
            `get_per_example_gradnorms`.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, bias=True):
        super().__init__(input_dim, hidden_dim, output_dim, bias)
        self.cached_inputs = {}
        self.perexample_gradnorms = {}
        for m in self.main.modules():
            if isinstance(m, torch.nn.Linear):
                m.register_forward_hook(self._make_linear_forward_hook())
                m.register_full_backward_hook(self._make_linear_backward_hook())

    def _make_linear_forward_hook(self):
        """Return a forward hook that remembers the first positional input of a module."""
        def hook(module, input, output):
            self.cached_inputs[module] = input[0]
        return hook

    def _make_linear_backward_hook(self):
        """Return a backward hook that stores the per-example squared gradient norm.

        The hook raises ValueError if no input was cached for the module, i.e.
        the backward pass was not preceded by a forward pass through it.
        """
        def hook(module, grad_input, grad_output):
            # Use a default so that a missing entry reaches the explicit error
            # below instead of surfacing as a bare KeyError from dict.pop.
            input = self.cached_inputs.pop(module, None)  # (batch_size, in_features)
            if input is None:
                raise ValueError("No cached input found for module during backward pass. How?")
            grad_output = grad_output[0]  # (batch_size, out_features)
            # Compute per-example gradnorms
            # G = sum_i G_i, where G_i = x_i s_i^T is the per-example gradient
            # ||G_i||_F^2 = Tr(G_i^T G_i) = Tr(x_i^T x_i s_i^T s_i) = ||x_i||_2^2 ||s_i||_2^2
            grad_output_sq = grad_output.pow(2).sum(dim=1)
            perexample_gradnorms = input.pow(2).sum(dim=1) * grad_output_sq
            if module.bias is not None:
                # Add bias contribution: ||s_i||_2^2
                perexample_gradnorms += grad_output_sq
            self.perexample_gradnorms[module] = perexample_gradnorms  # (batch_size,)
        return hook

    def get_per_example_gradnorms(self):
        """Return the per-example gradient norms over all linear layers as a list of floats.

        Sums the squared norms recorded by the backward hooks, takes the square
        root, and clears the recorded values so the next backward pass starts
        fresh.

        Raises:
            ValueError: If no backward pass has recorded any values yet.
        """
        if len(self.perexample_gradnorms) == 0:
            raise ValueError("No per-example grad norms computed yet. Run a backward pass first.")
        # Sum per-example gradnorms from all layers
        total_gradnorms = torch.zeros_like(next(iter(self.perexample_gradnorms.values())))
        for gradnorms in self.perexample_gradnorms.values():
            total_gradnorms += gradnorms
        self.perexample_gradnorms = {}  # clear for next use
        return total_gradnorms.sqrt().tolist()


batch_size = 4
model_args = {
    "input_dim": 10,
    "hidden_dim": 20,
    "output_dim": 1,
}
print("Calculating per-example gradient norms (naive vs. smart implementation):\n")

for seed in range(3):

    # Naive implementation: one forward/backward pass per example.
    torch.manual_seed(seed)
    model = MLP(**model_args)
    # Random batch and labels
    x = torch.randn(batch_size, model_args["input_dim"])
    y = torch.randn(batch_size, model_args["output_dim"])
    naive_gradnorms = []
    for i in range(batch_size):
        model.zero_grad()
        loss = F.mse_loss(model(x[[i]]), y[[i]])
        loss.backward()
        # clip_grad_norm_ returns the total gradient norm. The clipping it also
        # applies is irrelevant here because the gradients are zeroed before reuse.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # shortcut to compute norm
        naive_gradnorms.append(grad_norm.item())
    print(f"(Naive) Seed {seed}: Per-example grad norms: {naive_gradnorms}")

    # Smarter implementation: a single batched backward pass with hooks.
    torch.manual_seed(seed)
    smart_model = MLPWithPerExampleGradNorm(**model_args)
    # --- Use same batch and labels as before
    # Copy parameters from the naive model (both classes share the same parameter names).
    smart_model.load_state_dict(model.state_dict())
    smart_model.zero_grad()
    # Sum reduction removes 1 / batch_size from backward. For mean, multiply perexample_gradnorms by batch_size
    loss = F.mse_loss(smart_model(x), y, reduction="sum")
    loss.backward()
    perexample_gradnorms = smart_model.get_per_example_gradnorms()
    print(f"(Smart) Seed {seed}: Per-example grad norms: {perexample_gradnorms}")
    print()

    # Fail loudly if the two implementations disagree beyond float32 round-off.
    torch.testing.assert_close(
        torch.tensor(perexample_gradnorms),
        torch.tensor(naive_gradnorms),
        rtol=1e-4,
        atol=1e-6,
    )

Calculating per-example gradient norms (naive vs. smart implementation):

(Naive) Seed 0: Per-example grad norms: [4.181707382202148, 1.6636853218078613, 4.268483638763428, 1.8983454704284668]
(Smart) Seed 0: Per-example grad norms: [4.181707382202148, 1.6636853218078613, 4.268483638763428, 1.8983453512191772]

(Naive) Seed 1: Per-example grad norms: [4.523588180541992, 0.41972821950912476, 2.3420145511627197, 2.0537424087524414]
(Smart) Seed 1: Per-example grad norms: [4.52358865737915, 0.41972821950912476, 2.3420145511627197, 2.0537424087524414]

(Naive) Seed 2: Per-example grad norms: [0.15932044386863708, 3.5918500423431396, 0.5696545243263245, 3.012357711791992]
(Smart) Seed 2: Per-example grad norms: [0.1593204289674759, 3.5918498039245605, 0.5696545243263245, 3.012357473373413]



  loss.backward()


## Naive impl speed

In [10]:
%%time

batch_size = 1024
model_args = dict(input_dim=256, hidden_dim=512, output_dim=1)

for seed in range(3):
    # Baseline: one forward/backward pass per example.
    torch.manual_seed(seed)
    # Draw the batch and labels first, then build the model, so the random
    # stream is consumed in a fixed order.
    x = torch.randn(batch_size, model_args["input_dim"])
    y = torch.randn(batch_size, model_args["output_dim"])
    model = MLP(**model_args)
    # Scalar fingerprints of the data and of the freshly initialised parameters.
    data_signature = (x.sum() + y.sum()).item()
    model_signature = sum(map(lambda t: t.sum().item(), model.parameters()))

    perexample_gradnorms = []
    for i in range(batch_size):
        model.zero_grad()
        loss = F.mse_loss(model(x[i : i + 1]), y[i : i + 1])
        loss.backward()
        # The return value of clip_grad_norm_ is the total gradient norm.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        perexample_gradnorms.append(grad_norm.item())
    print(
        f"Seed {seed}:",
        f" - Model signature = {model_signature},\tData signature = {data_signature}",
        f" - Per-example grad norms = {perexample_gradnorms}\n",
        sep="\n",
    )

Seed 0:
 - Model signature = 23.65858732513152,	Data signature = -836.2041015625
 - Per-example grad norms = [29.17416763305664, 15.9073486328125, 0.8735095858573914, 10.837361335754395, 35.8778190612793, 16.906208038330078, 19.583072662353516, 0.5693705677986145, 6.071939945220947, 10.3909912109375, 7.332764148712158, 9.699533462524414, 5.973868370056152, 8.744900703430176, 13.382445335388184, 7.366543769836426, 17.127714157104492, 2.558971405029297, 2.2834677696228027, 13.379472732543945, 2.7214338779449463, 25.425153732299805, 9.37035846710205, 5.889320373535156, 2.703874349594116, 6.452304840087891, 10.52841854095459, 5.967186450958252, 5.297926425933838, 0.7846266031265259, 9.718938827514648, 18.41118049621582, 12.14617919921875, 11.719391822814941, 14.81567668914795, 1.6678798198699951, 5.808945655822754, 5.56652307510376, 24.40176010131836, 8.098150253295898, 5.926897048950195, 15.639555931091309, 13.868232727050781, 18.499622344970703, 10.567106246948242, 10.44338321685791, 9.2

## Hook impl speed

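Make sure the model and data signatures printed below match those of the naive run above, so that both cells are timing the same model on the same batch.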

In [11]:
%%time

batch_size = 1024
model_args = dict(input_dim=256, hidden_dim=512, output_dim=1)

for seed in range(3):
    # Hook-based path: a single batched forward/backward pass.
    torch.manual_seed(seed)
    x = torch.randn(batch_size, model_args["input_dim"])
    y = torch.randn(batch_size, model_args["output_dim"])
    # The model is built after the data so the random stream is consumed in a fixed order.
    model = MLPWithPerExampleGradNorm(**model_args)
    # Scalar fingerprints of the data and of the freshly initialised parameters.
    data_signature = (x.sum() + y.sum()).item()
    model_signature = sum(map(lambda t: t.sum().item(), model.parameters()))

    model.zero_grad()
    # With reduction="sum" the backward pass carries no 1 / batch_size factor.
    # For a mean-reduced loss, multiply perexample_gradnorms by batch_size.
    loss = F.mse_loss(model(x), y, reduction="sum")
    loss.backward()
    perexample_gradnorms = model.get_per_example_gradnorms()
    print(
        f"Seed {seed}:",
        f" - Model signature = {model_signature},\tData signature = {data_signature}",
        f" - Per-example grad norms = {perexample_gradnorms}\n",
        sep="\n",
    )

Seed 0:
 - Model signature = 23.65858732513152,	Data signature = -836.2041015625
 - Per-example grad norms = [29.174192428588867, 15.907368659973145, 0.8735107183456421, 10.837377548217773, 35.87785339355469, 16.90621566772461, 19.583093643188477, 0.56937175989151, 6.071944236755371, 10.390999794006348, 7.332773685455322, 9.699540138244629, 5.973875999450684, 8.744909286499023, 13.382454872131348, 7.366549015045166, 17.12773323059082, 2.558974504470825, 2.2834696769714355, 13.379487991333008, 2.7214372158050537, 25.4251766204834, 9.370365142822266, 5.889322757720947, 2.7038753032684326, 6.4523115158081055, 10.528424263000488, 5.967191219329834, 5.297930717468262, 0.7846272587776184, 9.718947410583496, 18.41120147705078, 12.146191596984863, 11.719407081604004, 14.815692901611328, 1.6678818464279175, 5.808952808380127, 5.5665283203125, 24.401775360107422, 8.098155975341797, 5.926903247833252, 15.639564514160156, 13.868242263793945, 18.49962615966797, 10.567116737365723, 10.44339370727539



(The `UserWarning` that `loss.backward()` prints in the hook-based cells above can be safely ignored.)

# Per-example influence with `autograd.Function`

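i.e., for each example $x_i$ in the batch, the inner product $\langle \nabla f(x_i), \nabla f(x_\text{ref}) \rangle$ between its gradient and the gradient of a reference example $x_\text{ref}$. When no reference example is given, this reduces to the squared gradient norm $\lVert \nabla f(x_i) \rVert^2$.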

In [12]:
class LinearInfluenceFunction(torch.autograd.Function):
    """Linear map whose backward pass also accumulates a per-example influence score.

    Forward computes `input @ weight.T + bias` for the batch and, if given,
    for a reference input. Backward adds one value per batch row to `buf`:

    * without a reference: the squared norm of that row's weight (and bias)
      gradient, `||x_i||^2 ||s_i||^2 (+ ||s_i||^2)`;
    * with a reference: the inner product of that row's gradient with the
      reference gradient, `(x_i . x_ref)(s_i . s_ref) (+ s_i . s_ref)`,

    where `s` denotes the gradient w.r.t. the layer output. Inputs are treated
    as 2-D tensors (rows are examples).

    NOTE(review): `precond` multiplies both `input_ref` and `grad_output_ref`
    and is ignored in the norm-only branch; confirm the intended semantics
    before relying on it.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, input_ref=None, precond=None, buf=None, compute_grads=False):
        """Compute the linear outputs for `input` and, optionally, `input_ref`.

        Args:
            input: Batch input, one example per row.
            weight: Weight matrix applied as `input @ weight.T`.
            bias: Optional bias added to every output row.
            input_ref: Optional reference input processed with the same weights.
            precond: Optional factor applied to the reference terms in backward.
            buf: Optional tensor that backward accumulates the per-example
                influence into (in place). When None, no influence is recorded.
            compute_grads: Whether backward also returns weight/bias gradients.

        Returns:
            Tuple `(output, output_ref)`; `output_ref` is None without `input_ref`.
        """
        ctx.save_for_backward(input, weight, bias, input_ref, precond)
        # Don't put the buffer in saved_tensors because there is a check for inplace changes
        # autograd will complain that buf was changed inplace, even though it is not used in any operation
        ctx._influence_buf = buf
        ctx._compute_grads = compute_grads

        def affine(batch):
            # batch @ weight.T + bias
            out = batch @ weight.T
            if bias is not None:
                out += bias.unsqueeze(0).expand_as(out)
            return out

        output = affine(input)
        output_ref = None if input_ref is None else affine(input_ref)
        return output, output_ref

    @staticmethod
    def backward(ctx, grad_output, grad_output_ref):
        """Accumulate per-example influence into the buffer and propagate gradients.

        Returns one gradient slot per forward argument; the slots for
        `precond`, `buf` and `compute_grads` are always None.
        """
        input, weight, bias, input_ref, precond = ctx.saved_tensors
        if precond is None:
            precond_input = precond_grad_output = 1.0
        else:
            precond_input = precond_grad_output = precond

        buf = ctx._influence_buf
        # Without a buffer there is nowhere to accumulate; skip the influence
        # computation instead of failing on `None += tensor` mid-backward.
        if buf is not None:
            if grad_output_ref is None:
                # We only need to compute per-example gradient norms
                grad_output_sq = grad_output.pow(2).sum(dim=1)
                perexample_influence = input.pow(2).sum(dim=1) * grad_output_sq
                if bias is not None:
                    perexample_influence += grad_output_sq
            else:
                input_inner = input.mul(input_ref.mul(precond_input)).sum(dim=1)
                grad_output_inner = grad_output.mul(grad_output_ref.mul(precond_grad_output)).sum(dim=1)
                perexample_influence = input_inner * grad_output_inner
                if bias is not None:
                    perexample_influence += grad_output_inner
            buf.add_(perexample_influence)

        # We also need to backpropagate grad_inputs
        # NOTE(review): grad_weight / grad_bias cover the main batch only; the
        # reference batch's contribution is not added. Confirm this is intended.
        grad_input = grad_weight = grad_bias = grad_input_ref = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight
        if ctx.needs_input_grad[1] and ctx._compute_grads:
            grad_weight = grad_output.T @ input
        if ctx.needs_input_grad[2] and bias is not None and ctx._compute_grads:
            grad_bias = grad_output.sum(0)
        if ctx.needs_input_grad[3] and grad_output_ref is not None:
            grad_input_ref = grad_output_ref @ weight

        return grad_input, grad_weight, grad_bias, grad_input_ref, None, None, None


class LinearInfluence(torch.nn.Linear):
    """`torch.nn.Linear` whose forward pass is routed through `LinearInfluenceFunction`."""

    def forward(self, x, x_ref=None, precond=None, buf=None, compute_grads=False):
        """Apply the layer to `x` (and `x_ref`, if given) via `LinearInfluenceFunction`.

        Returns whatever `LinearInfluenceFunction.apply` returns for these arguments.
        """
        # Arguments are passed positionally, in the order of
        # LinearInfluenceFunction.forward's signature (after ctx).
        call_args = (x, self.weight, self.bias, x_ref, precond, buf, compute_grads)
        return LinearInfluenceFunction.apply(*call_args)


class MLPWithPerExampleInfluence(torch.nn.Module):
    """Three-layer MLP built from `LinearInfluence` layers with ReLU in between.

    A batch and an optional reference input are pushed through the same layers
    side by side, and every layer receives the same accumulation buffer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, bias=True):
        super().__init__()
        self.influence_fc1 = LinearInfluence(input_dim, hidden_dim, bias=bias)
        self.influence_fc2 = LinearInfluence(hidden_dim, hidden_dim, bias=bias)
        self.influence_fc3 = LinearInfluence(hidden_dim, output_dim, bias=bias)
        self.relu = torch.nn.ReLU()

    def forward(self, x, x_ref=None, buf=None):
        """Return `(output, output_ref)` for the batch `x` and optional reference `x_ref`.

        `buf` is mandatory: the caller owns the buffer that the layers
        accumulate into, rather than having one created and returned here.
        """
        assert buf is not None, "Must provide buffer to accumulate per-example influence"
        # Both hidden layers are followed by a ReLU on each stream that exists.
        for hidden_layer in (self.influence_fc1, self.influence_fc2):
            x, x_ref = hidden_layer(x, x_ref=x_ref, buf=buf)
            x = self.relu(x)
            if x_ref is not None:
                x_ref = self.relu(x_ref)
        # The output layer has no activation.
        out, out_ref = self.influence_fc3(x, x_ref=x_ref, buf=buf)
        return out, out_ref

## Test grad norms

### 1D sanity check

In [13]:
for seed in range(3):
    torch.manual_seed(seed)
    x = torch.randn(1, 1)
    y = torch.randn(1, 1)
    w = torch.randn(1, 1)

    # Layer under test: a 1x1 LinearInfluence with a known weight.
    model = LinearInfluence(1, 1, bias=False)
    with torch.no_grad():
        model.weight.copy_(w)
    model.zero_grad()
    influence_buf = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
    output, _ = model(x, buf=influence_buf, compute_grads=True)
    F.mse_loss(output, y, reduction="sum").backward()

    # Reference: a plain torch.nn.Linear with the same weight.
    linear = torch.nn.Linear(1, 1, bias=False)
    with torch.no_grad():
        linear.weight.copy_(w)
    linear.zero_grad()
    F.mse_loss(linear(x), y, reduction="sum").backward()

    print(f"Seed {seed}:")
    print(" - InfluenceLinear grad norm\t=", model.weight.grad.abs().item())
    print(" - Linear grad norm\t\t=", linear.weight.grad.abs().item())
    print(" - Per-example grad norms\t=", influence_buf.sqrt().tolist()[0])
    print()

    # With a single example, the weight gradient and the per-example norm must
    # both agree with the plain Linear layer (up to float32 round-off).
    torch.testing.assert_close(model.weight.grad, linear.weight.grad)
    torch.testing.assert_close(
        influence_buf.sqrt(), linear.weight.grad.abs().reshape(-1), rtol=1e-5, atol=1e-6
    )

Seed 0:
 - InfluenceLinear grad norm	= 9.443461418151855
 - Linear grad norm		= 9.443461418151855
 - Per-example grad norms	= 9.443460464477539

Seed 1:
 - InfluenceLinear grad norm	= 0.2991080582141876
 - Linear grad norm		= 0.2991080582141876
 - Per-example grad norms	= 0.2991080582141876

Seed 2:
 - InfluenceLinear grad norm	= 0.07706676423549652
 - Linear grad norm		= 0.07706676423549652
 - Per-example grad norms	= 0.07706676423549652



### MLP

In [14]:
%%time

batch_size = 1024
model_args = dict(input_dim=256, hidden_dim=512, output_dim=1)

for seed in range(3):
    torch.manual_seed(seed)
    x = torch.randn(batch_size, model_args["input_dim"])
    y = torch.randn(batch_size, model_args["output_dim"])
    # No reference example in this cell, so the buffer collects squared gradient norms.
    # To collect inner products with a reference instead, use:
    #   x_ref = torch.randn(1, model_args["input_dim"])
    #   y_ref = torch.randn(1, model_args["output_dim"])
    x_ref, y_ref = None, None

    model = MLPWithPerExampleInfluence(**model_args)
    # Scalar fingerprints of the data and of the freshly initialised parameters.
    data_signature = (x.sum() + y.sum()).item()
    model_signature = sum(map(lambda t: t.sum().item(), model.parameters()))

    model.zero_grad()
    # One accumulator slot per example, matching x in dtype and device.
    influence_buf = x.new_zeros(x.size(0))
    output, output_ref = model(x, x_ref=x_ref, buf=influence_buf)
    loss = F.mse_loss(output, y, reduction="sum")
    if output_ref is not None:
        loss = loss + F.mse_loss(output_ref, y_ref, reduction="sum")
    loss.backward()
    # After backward, influence_buf holds || grad ||^2 or <x_grad, x_ref_grad>.

    # The buffer stores inner products (squared norms), hence the sqrt below.
    print(
        f"Seed {seed}:",
        f" - Model signature = {model_signature},\tData signature = {data_signature}",
        f" - Per-example grad norms = {influence_buf.sqrt().tolist()}\n",
        sep="\n",
    )

Seed 0:
 - Model signature = 23.65858732513152,	Data signature = -836.2041015625
 - Per-example grad norms = [29.174192428588867, 15.907368659973145, 0.8735106587409973, 10.837379455566406, 35.87784957885742, 16.90621566772461, 19.583093643188477, 0.5693714022636414, 6.071943759918213, 10.390999794006348, 7.3327741622924805, 9.699540138244629, 5.973875522613525, 8.74490737915039, 13.382455825805664, 7.366549015045166, 17.127731323242188, 2.558974504470825, 2.2834699153900146, 13.379488945007324, 2.7214369773864746, 25.4251766204834, 9.370365142822266, 5.889322757720947, 2.7038755416870117, 6.4523115158081055, 10.528423309326172, 5.967191219329834, 5.297930717468262, 0.7846270799636841, 9.718948364257812, 18.41120147705078, 12.146191596984863, 11.719406127929688, 14.815692901611328, 1.6678816080093384, 5.808952808380127, 5.5665283203125, 24.401775360107422, 8.098155975341797, 5.926903247833252, 15.639564514160156, 13.868243217468262, 18.49962615966797, 10.567115783691406, 10.44339466094

## Test inner products

### 1D sanity check

In [15]:
for seed in range(3):
    torch.manual_seed(seed)
    x = torch.randn(1, 1)
    y = torch.randn(1, 1)
    x_ref = torch.randn(1, 1)
    y_ref = torch.randn(1, 1)
    w = torch.randn(1, 1)

    # Layer under test: the buffer should receive <grad(x), grad(x_ref)>.
    model = LinearInfluence(1, 1, bias=False)
    with torch.no_grad():
        model.weight.copy_(w)
    model.zero_grad()
    influence_buf = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
    output, output_ref = model(x, x_ref=x_ref, buf=influence_buf)
    loss = F.mse_loss(output, y, reduction="sum")
    loss = loss + F.mse_loss(output_ref, y_ref, reduction="sum")
    loss.backward()

    # Reference: two separate backward passes through a plain torch.nn.Linear.
    linear = torch.nn.Linear(1, 1, bias=False)
    with torch.no_grad():
        linear.weight.copy_(w)
    linear.zero_grad()
    F.mse_loss(linear(x), y, reduction="sum").backward()
    g = linear.weight.grad.clone()
    linear.zero_grad()
    F.mse_loss(linear(x_ref), y_ref, reduction="sum").backward()
    g_ref = linear.weight.grad.clone()

    print(f"Seed {seed}:")
    print(" - Linear\t\t=", (g * g_ref).sum().item())
    print(" - LinearInfluence\t=", influence_buf.tolist()[0])
    print()

    # Both routes must yield the same inner product (up to float32 round-off).
    torch.testing.assert_close(influence_buf[0], (g * g_ref).sum(), rtol=1e-5, atol=1e-6)

Seed 0:
 - Linear		= 33.205875396728516
 - LinearInfluence	= 33.205875396728516

Seed 1:
 - Linear		= 0.05993038788437843
 - LinearInfluence	= 0.059930384159088135

Seed 2:
 - Linear		= -0.2766566276550293
 - LinearInfluence	= -0.2766566276550293



### MLP

In [16]:
%%time

batch_size = 1024
model_args = {
    "input_dim": 256,
    "hidden_dim": 512,
    "output_dim": 1,
}


def flat_grad(module):
    """Concatenate the existing parameter gradients of `module` into one 1-D tensor."""
    return torch.cat([p.grad.view(-1) for p in module.parameters() if p.grad is not None])


for seed in range(3):
    torch.manual_seed(seed)
    x = torch.randn(batch_size, model_args["input_dim"])
    y = torch.randn(batch_size, model_args["output_dim"])
    # A single reference example; set x_ref = y_ref = None for plain gradient norms.
    x_ref = torch.randn(1, model_args["input_dim"])
    y_ref = torch.randn(1, model_args["output_dim"])

    model = MLPWithPerExampleInfluence(**model_args)
    data_signature = (x.sum() + y.sum()).item()
    model_signature = sum(p.sum().item() for p in model.parameters())

    model.zero_grad()
    influence_buf = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
    output, output_ref = model(x, x_ref=x_ref, buf=influence_buf)
    loss = F.mse_loss(output, y, reduction="sum")
    if output_ref is not None:
        loss = loss + F.mse_loss(output_ref, y_ref, reduction="sum")
    loss.backward()
    # influence_buf should now contain <grad_i, grad_ref> for all i

    # Do it the naive way for verification: a plain MLP with the same parameters,
    # one backward pass for the reference and one per example.
    mlp = MLP(**model_args)
    with torch.no_grad():
        for mlp_p, p in zip(mlp.parameters(), model.parameters()):
            mlp_p.copy_(p)
    mlp.zero_grad()
    F.mse_loss(mlp(x_ref), y_ref, reduction="sum").backward()
    grad_ref = flat_grad(mlp)
    perexample_inners = []
    for i in range(batch_size):
        mlp.zero_grad()
        F.mse_loss(mlp(x[[i]]), y[[i]], reduction="sum").backward()
        grad_i = flat_grad(mlp)
        inner_product = (grad_i * grad_ref).sum().item()
        perexample_inners.append(inner_product)

    print(f"Seed {seed}:")
    print(f" - Model signature = {model_signature},\tData signature = {data_signature}")
    print(f" - <grad_i, grad_ref> (naive)\t= {perexample_inners}")
    print(f" - <grad_i, grad_ref> (fast)\t= {influence_buf.tolist()}\n")

    # Actually verify the fast path against the naive one. The tolerance is
    # loose because both sides are float32 sums over many terms with cancellation.
    torch.testing.assert_close(
        influence_buf, torch.tensor(perexample_inners), rtol=1e-3, atol=1e-2
    )

Seed 0:
 - Model signature = 23.87609591986984,	Data signature = -836.2041015625
 - <grad_i, grad_ref> (naive)	= [68.1590576171875, 36.18339538574219, -4.54060173034668, 32.409751892089844, 95.67410278320312, 61.02016067504883, -41.66485595703125, -6.207691669464111, 14.474365234375, 26.934823989868164, -15.570629119873047, 21.964345932006836, 12.461263656616211, -27.197097778320312, 35.12216567993164, 25.541154861450195, 42.333621978759766, -8.407011985778809, -8.267784118652344, -29.732912063598633, -13.252540588378906, -56.813323974609375, 21.112058639526367, -16.16707992553711, 4.438309192657471, 20.633867263793945, 25.410934448242188, -11.957016944885254, -13.247001647949219, 2.481919050216675, -22.131591796875, 56.616485595703125, -27.006393432617188, -31.40955924987793, 36.964630126953125, -0.19542020559310913, -12.280418395996094, -12.557950019836426, 63.0473518371582, -19.89986228942871, -15.063065528869629, 40.54959487915039, -29.714202880859375, -38.84140396118164, -24.82708