# Using `auto_norm` to compute output norms and optimize scaling factors

We go through three examples using `auto_norm`:
1. Compute norms automatically for regular PyTorch modules
2. Build modula norm automatically for regular PyTorch modules
3. Optimize scaling factors

See the end of this notebook for an FAQ and the current state of `auto_norm`.

## Ex1: compute norms automatically for regular PyTorch modules

Let's define an ordinary residual block in plain PyTorch.

In [18]:
import torch
from torch import nn
import torch.nn.functional as F

class MyResBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(8, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
        )

    def forward(self, x):
        return x + self.net(x)


`auto_norm` operates on subclasses of `auto_norm.NormedTensorBase`, including:
+ `RMS_NormTensor`,
+ `RMS_RMS_NormTensor`,
+ `L1_NormTensor`, and
+ `Linf_NormTensor`.
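For intuition, here is a plain-Python sketch of the norms these class names suggest. It is only an assumption, not confirmed by this notebook, that `RMS` means root-mean-square, that `RMS_RMS` is the induced RMS→RMS operator norm (which equals $\sqrt{d_\text{in}/d_\text{out}}$ times the spectral norm), and that `L1`/`Linf` are the ordinary unnormalized ℓ1/ℓ∞ norms. None of the helper names below are part of `auto_norm`.

```python
import math
import random

def rms_norm(x):
    """Root-mean-square norm of a vector: sqrt(mean(x_i ** 2))."""
    return math.sqrt(sum(v * v for v in x) / len(x))

def l1_norm(x):
    """Plain (unnormalized) l1 norm: sum of absolute values."""
    return sum(abs(v) for v in x)

def linf_norm(x):
    """l-infinity norm: largest absolute entry."""
    return max(abs(v) for v in x)

def matvec(W, x):
    """Multiply a matrix given as a list of rows by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def spectral_norm(W, iters=200, seed=0):
    """Largest singular value of W, via power iteration on W^T W."""
    rng = random.Random(seed)
    Wt = [list(col) for col in zip(*W)]  # transpose of W
    v = [rng.gauss(0.0, 1.0) for _ in range(len(W[0]))]
    for _ in range(iters):
        v = matvec(Wt, matvec(W, v))  # v <- W^T W v
        length = math.sqrt(sum(c * c for c in v))
        v = [c / length for c in v]   # keep v at unit l2 length
    return math.sqrt(sum(c * c for c in matvec(W, v)))  # ||W v||_2 for unit v

def rms_rms_norm(W):
    """Induced RMS -> RMS operator norm: sqrt(d_in / d_out) * spectral norm."""
    d_out, d_in = len(W), len(W[0])
    return math.sqrt(d_in / d_out) * spectral_norm(W)

W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # d_out = 2, d_in = 3
x = [1.0, -1.0, 0.5]
print(rms_norm(matvec(W, x)) <= rms_rms_norm(W) * rms_norm(x))  # True
```

The defining property is $\|Wx\|_\text{RMS} \le \|W\|_{\text{RMS}\to\text{RMS}}\,\|x\|_\text{RMS}$, which is what lets per-layer norms be chained through a network.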

`auto_norm.build_norm_map` is the key entry point. It returns a `norm_map` function that maps (norms of inputs, norms of parameters, norms of buffers) to norms of outputs.

Its signature is:

```py
def build_norm_map(module: nn.Module, *example_args, dynamic_shapes: Optional = None, **example_kwargs):
    ...

    def norm_map(*normed_args, normed_state_dict, **normed_kwargs):
        # normed_* should generally contain auto_norm.*_NormTensor, instead of usual torch.Tensor
        ...
        return normed_outputs

    return norm_map
```

In [19]:
import auto_norm

net = MyResBlock()
example_input = torch.randn(10, 8, requires_grad=True)

norm_map = auto_norm.build_norm_map(net, example_input)  # can also specify dynamic dims (e.g., batch), but not necessary for this example

Construct the normed input and the normed state dict:

In [20]:
normed_input = auto_norm.RMS_NormTensor(1, elem_dims=(-1,))
print('normed_input: \n', normed_input)


normed_input: 
 RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1,), ...)


In [21]:
normed_state_dict = {}
for name in net.state_dict():
    if name.endswith('weight'):
        normed_state_dict[name] = auto_norm.RMS_RMS_NormTensor(1, elem_dims=(-1, -2))  # elem_dims means which dims to norm over
    elif name.endswith('bias'):
        normed_state_dict[name] = auto_norm.RMS_NormTensor(0, elem_dims=(-1,))

print('normed_state_dict:')
from pprint import pprint
pprint(normed_state_dict)


normed_state_dict:
{'net.0.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.0.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.2.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.2.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.4.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.4.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...)}


Run `norm_map` to compute the output norm.

In [22]:
output_norm = norm_map(normed_input, normed_state_dict=normed_state_dict)
print('output_norm: \n', output_norm)

output_norm: 
 RMS_NormTensor(
    norm_size=tensor(1.5000),
    elem_dims=(1,),
    unwrapped=FakeTensor(..., size=(10, 8)),
)


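Computing by hand, the residual branch gives $1 \cdot \frac{1}{\sqrt{2}} \cdot 1 \cdot \frac{1}{\sqrt{2}} \cdot 1 = 0.5$ (the input norm, a weight norm of 1 for each linear layer, a factor of $\frac{1}{\sqrt{2}}$ for each ReLU, and zero bias norms), and the skip connection adds the input norm of 1, for a total of $0.5 + 1 = 1.5$. So yay!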

Note that the norm type (`RMS_NormTensor`) and the normed dimensions (`elem_dims`) are propagated to the output too.
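The hand calculation can also be written as a tiny scalar "norm map" in plain Python. This is an illustrative sketch, not how `auto_norm` is implemented: it assumes a linear layer maps an input norm to `weight_norm * input_norm + bias_norm` (the triangle inequality with an operator norm), and it reuses the $\frac{1}{\sqrt{2}}$ ReLU factor from the calculation above. All function names are hypothetical.

```python
import math

RELU_FACTOR = 1 / math.sqrt(2)  # assumption: the per-ReLU factor from the hand calculation

def linear_norm(x_norm, w_norm, b_norm):
    """Triangle-inequality bound: ||W x + b|| <= ||W|| * ||x|| + ||b||."""
    return w_norm * x_norm + b_norm

def relu_norm(x_norm):
    """Scale by the assumed ReLU factor."""
    return RELU_FACTOR * x_norm

def res_block_norm(x_norm, weight_norms, bias_norms):
    """Scalar mirror of MyResBlock: x + Linear(ReLU(Linear(ReLU(Linear(x)))))."""
    h = x_norm
    last = len(weight_norms) - 1
    for i, (w_norm, b_norm) in enumerate(zip(weight_norms, bias_norms)):
        h = linear_norm(h, w_norm, b_norm)
        if i < last:  # a ReLU follows every Linear except the last one
            h = relu_norm(h)
    return x_norm + h  # the skip connection adds the input norm

print(res_block_norm(1.0, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]))  # ≈ 1.5, up to floating-point rounding
```

Raising the middle weight norm from 1 to 2 raises the output norm from 1.5 to 2.0, a slope of 0.5; that slope is the "influence" that Ex2 extracts with autograd.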

## Ex2: build modula norm automatically for regular PyTorch modules

To compute the modula norm, we need the local "influence" of each weight's norm on the output norm, i.e., the derivative of the output norm with respect to that weight norm. Fortunately, PyTorch autograd can compute this for us!

Let's first mark the norm sizes in the normed state dict (both weights and biases) as requiring gradients.

In [23]:
normed_state_dict = {k: v.norm_size_requires_grad_(True) for k, v in normed_state_dict.items()}
print('normed_state_dict:')
from pprint import pprint
pprint(normed_state_dict)

normed_state_dict:
{'net.0.bias': RMS_NormTensor(norm_size=tensor(0., requires_grad=True), elem_dims=(-1,), ...),
 'net.0.weight': RMS_RMS_NormTensor(norm_size=tensor(1., requires_grad=True), elem_dims=(-1, -2), ...),
 'net.2.bias': RMS_NormTensor(norm_size=tensor(0., requires_grad=True), elem_dims=(-1,), ...),
 'net.2.weight': RMS_RMS_NormTensor(norm_size=tensor(1., requires_grad=True), elem_dims=(-1, -2), ...),
 'net.4.bias': RMS_NormTensor(norm_size=tensor(0., requires_grad=True), elem_dims=(-1,), ...),
 'net.4.weight': RMS_RMS_NormTensor(norm_size=tensor(1., requires_grad=True), elem_dims=(-1, -2), ...)}


In [24]:
output_norm = norm_map(normed_input, normed_state_dict=normed_state_dict)
print('output_norm: \n', output_norm)

output_norm: 
 RMS_NormTensor(
    norm_size=tensor(1.5000, grad_fn=<AddBackward0>),
    elem_dims=(1,),
    unwrapped=FakeTensor(..., size=(10, 8)),
)


Note the `grad_fn`! Now invoke autograd...

In [None]:
output_norm.norm_size.backward()


In [28]:
sensitivities = {k: v.norm_size.grad for k, v in normed_state_dict.items()}
print('sensitivity of net.2.weight:')
print(sensitivities['net.2.weight'])



sensitivity of net.2.weight:
tensor(0.5000)
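Because the bias norms are 0 here, the hand calculation from Ex1 reduces the output norm to $1 + \frac{w_0 w_2 w_4}{2}$, where $w_0, w_2, w_4$ are the three weight norms, so $\frac{\partial}{\partial w_2} = \frac{w_0 w_4}{2} = 0.5$ at $w_0 = w_4 = 1$. As an independent sanity check that needs no autograd, here is a central finite difference on that closed form (plain Python; the closed form relies on the same propagation assumptions as the hand calculation, and the helper names are hypothetical).

```python
import math

RELU_FACTOR = 1 / math.sqrt(2)  # ReLU factor taken from the Ex1 hand calculation

def output_norm(w0, w2, w4, x_norm=1.0):
    """Closed-form output norm of MyResBlock with zero bias norms (assumed rules)."""
    branch = w4 * RELU_FACTOR * w2 * RELU_FACTOR * w0 * x_norm
    return x_norm + branch  # the skip connection adds the input norm

def central_difference(f, x, h=1e-6):
    """Approximate f'(x) with a symmetric difference quotient."""
    return (f(x + h) - f(x - h)) / (2 * h)

# Sensitivity of the output norm to the norm of net.2.weight (w2) at w0 = w2 = w4 = 1.
sens_w2 = central_difference(lambda w2: output_norm(1.0, w2, 1.0), 1.0)
print(f"sensitivity of w2: {sens_w2:.4f}")  # prints: sensitivity of w2: 0.5000
```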


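For the modula norm, we have

$$\|\{W_i\}_i\|_M := \max_i \frac{\text{total\_mass}}{\text{mass}_i}\, \text{influence}_i\, \|W_i\|$$

where $\text{influence}_i$ is the sensitivity of the output norm to $\|W_i\|$, as computed above.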

In [30]:
masses = {k: 1 if k.endswith('weight') else 0.1 for k in normed_state_dict}
print('masses:')
pprint(masses)

total_mass = sum(masses.values())
print(f'total_mass: {total_mass:g}')


masses:
{'net.0.bias': 0.1,
 'net.0.weight': 1,
 'net.2.bias': 0.1,
 'net.2.weight': 1,
 'net.4.bias': 0.1,
 'net.4.weight': 1}
total_mass: 3.3


In [31]:
modula_norm = max(
    total_mass / masses[k] * sensitivities[k] * normed_state_dict[k].norm_size.detach()
    for k in normed_state_dict
)
print(f'modula_norm: {modula_norm:.4f}')

modula_norm: 1.6500
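Spelling out the arithmetic: the branch norm is the product $w_0 w_2 w_4 / 2$, so by symmetry every weight has influence $0.5$ (matching the sensitivity printed above), while every bias has norm $0$ and therefore contributes $0$ regardless of its influence:

$$\|\{W_i\}_i\|_M = \max\left(\frac{3.3}{1}\cdot 0.5 \cdot 1,\ \frac{3.3}{0.1}\cdot \text{influence}_b \cdot 0\right) = 1.65$$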


## Ex3: Optimize scaling factors

The output norm above is 1.5 rather than 1. How can we rescale the layers so that the output has unit norm?

Let's use the special class `auto_norm.ConstantScaler` to optimize the scaling factors!

In [48]:
class MyResBlockWithScaling(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(8, 16),
            auto_norm.ConstantScaler(),  # insert a scaler wherever we want a tunable scale; by default it is a no-op (scale=1)
            nn.ReLU(),
            nn.Linear(16, 16),
            auto_norm.ConstantScaler(),
            nn.ReLU(),
            nn.Linear(16, 8),
            auto_norm.ConstantScaler(),
        )
        self.idt_scaler = auto_norm.ConstantScaler()

    def forward(self, x):
        return self.idt_scaler(x) + self.net(x)


scaled_net = MyResBlockWithScaling()
norm_map_for_scaled_net = auto_norm.build_norm_map(scaled_net, example_input)
scaled_net

MyResBlockWithScaling(
  (net): Sequential(
    (0): Linear(in_features=8, out_features=16, bias=True)
    (1): ConstantScaler()
    (2): ReLU()
    (3): Linear(in_features=16, out_features=16, bias=True)
    (4): ConstantScaler()
    (5): ReLU()
    (6): Linear(in_features=16, out_features=8, bias=True)
    (7): ConstantScaler()
  )
  (idt_scaler): ConstantScaler()
)

Now the state dict also contains these new scale factors, and we can pass any values for them to `norm_map` through the normed state dict.

In [49]:
def build_normed_state_dict_for_scaled_net(post_linear_scale, idt_scale):
    normed_state_dict = {}
    for name in scaled_net.state_dict():
        if name.endswith('weight'):
            normed_state_dict[name] = auto_norm.RMS_RMS_NormTensor(1, elem_dims=(-1, -2))
        elif name.endswith('bias'):
            normed_state_dict[name] = auto_norm.RMS_NormTensor(0, elem_dims=(-1,))
        elif name == 'idt_scaler.scale':
            normed_state_dict[name] = idt_scale
        elif name.endswith('scale'):
            normed_state_dict[name] = post_linear_scale
    return normed_state_dict

Let's verify that, with every scale set to 1 (the default), the output norm is the same as for the network without scalers.

In [50]:
normed_state_dict = build_normed_state_dict_for_scaled_net(post_linear_scale=torch.tensor(1.), idt_scale=torch.tensor(1.))


output_norm = norm_map_for_scaled_net(normed_input, normed_state_dict=normed_state_dict)
print('output_norm: \n', output_norm)

output_norm: 
 RMS_NormTensor(
    norm_size=tensor(1.5000),
    elem_dims=(1,),
    unwrapped=FakeTensor(..., size=(10, 8)),
)


Now let's tune the scaling factors so that the output norm becomes 1!

First, let's prepare the normed state dict with scale factors that require grad:

In [52]:
post_linear_scale = torch.tensor(1., requires_grad=True)  # requires grad! shared by all three post-linear scalers
idt_scale = torch.tensor(1., requires_grad=True)
normed_state_dict = build_normed_state_dict_for_scaled_net(post_linear_scale, idt_scale)
print('normed_state_dict:')
pprint(normed_state_dict)

normed_state_dict:
{'idt_scaler.scale': tensor(1., requires_grad=True),
 'net.0.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.0.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.1.scale': tensor(1., requires_grad=True),
 'net.3.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.3.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.4.scale': tensor(1., requires_grad=True),
 'net.6.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.6.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.7.scale': tensor(1., requires_grad=True)}


Now, simply optimize with autograd...

In [54]:
optim = torch.optim.SGD([post_linear_scale, idt_scale], lr=0.01)
for ii in range(1, 201):
    optim.zero_grad()
    output_norm = norm_map_for_scaled_net(normed_input, normed_state_dict=normed_state_dict)
    loss = F.mse_loss(output_norm.norm_size, torch.tensor(1.))
    if ii % 50 == 0:
        print(f'iter {ii:03d}: loss={loss:.4f} output_norm={output_norm.norm_size:.4f}')
    loss.backward()
    optim.step()

print('post_linear_scale: \n', post_linear_scale)
print('idt_scale: \n', idt_scale)


iter 050: loss=0.0000 output_norm=1.0000
iter 100: loss=0.0000 output_norm=1.0000
iter 150: loss=0.0000 output_norm=1.0000
iter 200: loss=0.0000 output_norm=1.0000
post_linear_scale: 
 tensor(0.7549, requires_grad=True)
idt_scale: 
 tensor(0.7849, requires_grad=True)


We can also verify the result by hand:

In [58]:
import math

manual_output_norm = (
    (scaler_contribution := post_linear_scale ** 3) *
    (relu_contribution := (1 / math.sqrt(2)) ** 2) +
    (idt_contribution := idt_scale)
)
assert torch.allclose(manual_output_norm, torch.tensor(1.))
print(f'manual_output_norm: {manual_output_norm:.4f}')


manual_output_norm: 1.0000
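Note that unit output norm does not pin down a unique answer: with unit weight norms and zero bias norms, the target is $\frac{s^3}{2} + t = 1$ for post-linear scale $s$ and identity scale $t$, which is a one-parameter family of solutions, so the result depends on where the optimizer starts. The plain-Python sketch below makes this concrete by running gradient descent with hand-derived gradients from two starting points. The closed form is the manual formula above, not an `auto_norm` API, and the helper names are hypothetical.

```python
def scaled_output_norm(s, t):
    """Output norm of the scaled block: s**3 / 2 + t (unit weight norms, zero biases)."""
    return s ** 3 / 2 + t

def fit_scales(s, t, lr=0.01, steps=1000):
    """Plain gradient descent on (scaled_output_norm(s, t) - 1) ** 2."""
    for _ in range(steps):
        err = scaled_output_norm(s, t) - 1.0
        grad_s = 2 * err * 1.5 * s ** 2  # d/ds of err**2
        grad_t = 2 * err                 # d/dt of err**2
        s -= lr * grad_s
        t -= lr * grad_t
    return s, t

# Same start as the notebook versus a start with a small post-linear scale.
s_a, t_a = fit_scales(1.0, 1.0)
s_b, t_b = fit_scales(0.2, 0.9)
for s, t in [(s_a, t_a), (s_b, t_b)]:
    print(f"s={s:.4f} t={t:.4f} output_norm={scaled_output_norm(s, t):.4f}")
```

Both runs reach unit output norm, but with quite different scales, so any extra preference (for example, keeping the identity path dominant) has to be expressed through the initialization or an additional loss term.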
