# Using `auto_norm` to compute output norms and optimize scaling factors

We go through three examples using `auto_norm`:
1. Compute norms automatically for regular PyTorch modules
2. Build modula norm automatically for regular PyTorch modules
3. Optimize scaling factors

## Ex1: compute norms automatically for regular PyTorch modules

Let's define an ordinary network in plain PyTorch.

In [2]:
import torch
from torch import nn
import torch.nn.functional as F

class MyResBlock(nn.Module):
    """Residual MLP block computing ``x + MLP(x)``.

    The MLP is Linear(8->16) -> ReLU -> Linear(16->16) -> ReLU -> Linear(16->8),
    so inputs and outputs both have a trailing dimension of 8.
    """

    def __init__(self):
        super().__init__()
        # Kept as a positional nn.Sequential so parameter names are
        # net.0.*, net.2.* and net.4.* (the ReLUs own no parameters).
        layers = [
            nn.Linear(8, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Add the MLP branch back onto its input (skip connection)."""
        branch = self.net(x)
        return x + branch


`auto_norm` provides computation on `auto_norm.NormedTensorBase` subclasses, including 
+ `RMS_NormTensor`, 
+ `RMS_RMS_NormTensor`, 
+ `L1_NormTensor` and 
+ `Linf_NormTensor`.

`auto_norm.build_norm_map` is the key entry point: it returns a `norm_map` function that computes (norms of inputs, norms of parameters, norms of buffers) -> norms of outputs.

Its syntax is

```py
def build_norm_map(module: nn.Module, *example_args, dynamic_shapes: Optional = None, **example_kwargs):
    ...

    def norm_map(*normed_args, normed_state_dict, **normed_kwargs):
        # normed_* should generally contain auto_norm.*_NormTensor, instead of usual torch.Tensor
        ...
        return normed_outputs

    return norm_map
```

In [3]:
import auto_norm as auto_norm

# Build the module and an example input. Statement order matters: both draw
# from the global RNG.
net = MyResBlock()
example_input = torch.randn(10, 8, requires_grad=True)

# Build the norm map from `net` and the example input; the traced graph is
# printed in the cell output below.
norm_map = auto_norm.build_norm_map(net, example_input)  # can also specify dynamic dims (e.g., batch), but not necessary for this example

graph():
    %p_net_0_weight : [num_users=1] = placeholder[target=p_net_0_weight]
    %p_net_0_bias : [num_users=1] = placeholder[target=p_net_0_bias]
    %p_net_2_weight : [num_users=1] = placeholder[target=p_net_2_weight]
    %p_net_2_bias : [num_users=1] = placeholder[target=p_net_2_bias]
    %p_net_4_weight : [num_users=1] = placeholder[target=p_net_4_weight]
    %p_net_4_bias : [num_users=1] = placeholder[target=p_net_4_bias]
    %x : [num_users=2] = placeholder[target=x]
    %op__wrapper__torch__c__nn_linear__4507253904 : [num_users=1] = call_function[target=torch.ops.auto_norm.op__wrapper__torch__C__nn_linear__4507253904.default](args = (%x, %p_net_0_weight, %p_net_0_bias), kwargs = {})
    %op__wrapper__torch__ops_aten_aten_relu__4876676816 : [num_users=1] = call_function[target=torch.ops.auto_norm.op__wrapper__torch__ops_aten_aten_relu__4876676816.default](args = (%op__wrapper__torch__c__nn_linear__4507253904,), kwargs = {})
    %op__wrapper__torch__c__nn_linear__4507253905 : 

Construct the normed input and the normed `state_dict`.

In [3]:
# enable debug level logging
import logging, sys
logger = auto_norm.logger
logger.setLevel(logging.DEBUG)
# Attach the stdout handler only once: re-running this notebook cell would
# otherwise add another handler and print every debug line multiple times.
if not any(
    isinstance(h, logging.StreamHandler) and h.stream is sys.stdout
    for h in logger.handlers
):
    logger.addHandler(logging.StreamHandler(stream=sys.stdout))


In [4]:
# Normed input: RMS norm of size 1 over the last dim (elem_dims selects the
# dims the norm is taken over); printed below as norm_size=tensor(1.).
normed_input = auto_norm.RMS_NormTensor(1, elem_dims=(-1,))
print('normed_input: \n', normed_input)


normed_input: 
 RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1,), ...)


In [5]:
# Assign a norm to every parameter of `net`: biases get an RMS_NormTensor of
# size 0 over their last dim, weights an RMS_RMS_NormTensor of size 1 over
# their last two dims (elem_dims means which dims to norm over). Any other
# state_dict entry is skipped.
normed_state_dict = dict()
for name in net.state_dict():
    if name.endswith('bias'):
        normed_state_dict[name] = auto_norm.RMS_NormTensor(0, elem_dims=(-1,))
    elif name.endswith('weight'):
        normed_state_dict[name] = auto_norm.RMS_RMS_NormTensor(1, elem_dims=(-1, -2))

from pprint import pprint
print('normed_state_dict:')
pprint(normed_state_dict)


normed_state_dict:
{'net.0.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.0.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.2.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.2.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.4.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.4.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...)}


Run `norm_map` to compute the output norm.

In [6]:
# Propagate the input and parameter norms through the traced network to get
# the output norm; with DEBUG logging on, each dispatched op is logged (see
# output below).
output_norm = norm_map(normed_input, normed_state_dict=normed_state_dict)
print('output_norm: \n', output_norm)

TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.add_.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.div.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.add_.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.div.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.T

In [51]:
import torch
from torch.overrides import enable_reentrant_dispatch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

@torch.library.custom_op("mylib::foo", mutates_args={}, schema="(Tensor x) -> Tensor")
def foo(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@foo.register_fake
def foo_fake(x: torch.Tensor) -> torch.Tensor:
    print('fake')
    return x

@torch.library.custom_op("mylib::foo1", mutates_args={}, schema="(Tensor x) -> Tensor")
def foo1(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@foo1.register_fake
def foo1_fake(x: torch.Tensor) -> torch.Tensor:
    print('fake')
    return x



_x = 1


class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Debug dispatch mode: print every intercepted op, then run it unchanged."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print('mode', func, types)
        # ``kwargs`` defaults to None and ``**None`` raises TypeError, so fall
        # back to an empty dict.
        return func(*args, **(kwargs or {}))



@torch.library.register_torch_dispatch("mylib::foo", MyMode)
def _(mode, func, types, args, kwargs):
    """MyMode rule for ``mylib::foo``: forward its single input to ``foo1``.

    The nested call runs with reentrant dispatch enabled and ``mode`` pushed
    again; the recorded output shows MyMode's ``mylib::foo1`` rule then fires.
    """
    print('dispatch')
    (inp,) = args
    with enable_reentrant_dispatch(), mode:
        res = foo1(inp)
    print(res)
    return res

@torch.library.register_torch_dispatch("mylib::foo1", MyMode)
def _(mode, func, types, args, kwargs):
    """MyMode rule for ``mylib::foo1``: hand back the single input as-is."""
    print('dispatch 1')
    (only_arg,) = args
    return only_arg

In [52]:
x = torch.randn(3).requires_grad_()
# Call foo under MyMode: the recorded output shows MyMode's mylib::foo rule
# ('dispatch') and then its mylib::foo1 rule ('dispatch 1') firing.
with MyMode():
    y = foo(x)
y

dispatch
dispatch 1
tensor([ 1.0699, -0.1113,  2.8711], requires_grad=True)
mode aten.view.default ()


tensor([ 1.0699, -0.1113,  2.8711],
       grad_fn=<GeneratedBackwardFor_mylib_foo_defaultBackward>)

In [11]:
import torch
from torch.overrides import enable_reentrant_dispatch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

class MyT(torch.Tensor):
    """Plain (non-wrapper) Tensor subclass that logs every op it dispatches.

    Setting ``MyT.LOL = True`` makes it decline every op by returning
    ``NotImplemented``.
    """

    LOL: bool = False

    # __torch_dispatch__ is defined as a classmethod, per PyTorch's subclass
    # protocol; it receives the class rather than a particular tensor.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print('subclass dispatch', func, types)
        if MyT.LOL:
            return NotImplemented
        return super().__torch_dispatch__(func, types, args, kwargs or {})

@torch.library.custom_op("mylib::foo", mutates_args={}, schema="(Tensor x) -> Tensor")
def foo(x: MyT) -> MyT:
    """Custom op ``mylib::foo``: return a fresh copy of ``x``."""
    return x.clone()

@foo.register_fake
def foo_fake(x: MyT) -> MyT:
    """Fake (metadata-only) kernel for ``mylib::foo``.

    Returns a new tensor shaped like ``x`` (matching the real ``clone()``)
    instead of ``x`` itself, so the output does not alias the input.
    """
    print('fake')
    return torch.empty_like(x)

_x = 1


class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Dispatch mode that logs each op and then declines to handle it.

    Returning ``NotImplemented`` lets the op fall through to the next handler;
    the recorded output shows MyT's ``__torch_dispatch__`` running next.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print('mode', func, types)
        return NotImplemented


# NOTE(review): the recorded output for this cell ends in
# "RuntimeError: .tolist() is not supported for tensor subclasses, got MyT".
# Plain call: MyT's __torch_dispatch__ logs each op (see output below).
x = torch.randn(3, out=MyT(3))
print('---')
y = foo(x)
print('--===')
assert torch.allclose(y, x)
print('---')
# This cell's MyMode returns NotImplemented, so the op falls through to MyT's
# handler (output shows 'mode ...' followed by 'subclass dispatch ...').
with MyMode():
    y = foo(x)
print('--===')
print('---')
# MyMode-specific rule for mylib::foo: call foo again with reentrant dispatch
# enabled, also call it on a fake copy of x, and return the real result + 1.
@torch.library.register_torch_dispatch("mylib::foo", MyMode)
def _(mode, func, types, args, kwargs):
    print('dispatch')
    x, = args
    with enable_reentrant_dispatch():
        z = foo(x)
    print(x)
    # NOTE(review): zz is never used; this call only exercises foo on a fake
    # tensor (whose fake kernel prints 'fake').
    zz = (foo(fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, x)))
    return z + 1
# MyT.LOL = True
with MyMode():
    y = foo(fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, x))
    print('--===')
    y = foo(x)
print('--===')
assert torch.allclose(y, x + 1)

subclass dispatch aten.resize_.default (<class '__main__.MyT'>,)
subclass dispatch aten.normal_.default (<class '__main__.MyT'>,)
---
subclass dispatch mylib.foo.default (<class '__main__.MyT'>,)
--===
subclass dispatch aten.allclose.default (<class '__main__.MyT'>,)
---
mode mylib.foo.default (<class '__main__.MyT'>,)
subclass dispatch mylib.foo.default (<class '__main__.MyT'>,)
--===
---
dispatch
fake
FakeTensor(..., size=(3,))
fake
--===
dispatch
subclass dispatch mylib.foo.default (<class '__main__.MyT'>,)
subclass dispatch aten.reshape.default (<class '__main__.MyT'>,)


RuntimeError: .tolist() is not supported for tensor subclasses, got MyT

In [1]:
import torch
import torch.nn.functional

from torch.utils._pytree import tree_map
from torch.overrides import enable_reentrant_dispatch

import contextlib
from typing import Any

import torch
from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten

# Dumping ground for utilities that should eventually make their way into
# PyTorch proper.


@contextlib.contextmanager
def no_dispatch():
    """Context manager that holds a ``torch._C._DisableTorchDispatch`` guard.

    The guard is created on entry and released in ``finally``, so it is
    dropped even if the ``with`` body raises. Presumably (per the guard's
    name) ``__torch_dispatch__`` handling is disabled while it is alive.
    """
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        # Releasing the guard object ends its effect (RAII-style guard,
        # presumably restored on destruction).
        del guard


def tree_map2(fn: Any, pytree1: PyTree, pytree2: PyTree) -> PyTree:
    """Apply ``fn`` leaf-by-leaf to two pytrees that share one structure.

    Returns a pytree with that structure whose leaves are ``fn(a, b)`` for
    corresponding leaves ``a`` of ``pytree1`` and ``b`` of ``pytree2``.
    A structure mismatch is rejected with an ``assert``.
    """
    leaves_a, spec_a = tree_flatten(pytree1)
    leaves_b, spec_b = tree_flatten(pytree2)
    assert spec_a == spec_b
    combined = list(map(fn, leaves_a, leaves_b))
    return tree_unflatten(combined, spec_a)


# NOTE: it is unclear whether this helper is actually needed.
def unmake_subclass(tensor):
    """Re-create ``tensor`` as a plain ``torch.Tensor``.

    Uses ``torch.Tensor._make_subclass(torch.Tensor, tensor)`` while the
    ``no_dispatch()`` guard is active.
    """
    with no_dispatch():
        plain = torch.Tensor._make_subclass(torch.Tensor, tensor)
    return plain


def fill_defaults(args, n, defaults_tail):
    """
    __torch_dispatch__ doesn't guarantee the number of arguments you are
    passed (e.g., defaulted arguments are not passed); but usually it is
    convenient to pad out the arguments list with defaults.  This function
    helps you do that.

    Args:
        args: the list of positional arguments passed to __torch_dispatch__
        n: the number of arguments you are expecting to get
        defaults_tail: default values for the arguments, starting from the
            end of the list

    Returns:
        A new list: ``args`` padded up to length ``n`` with the matching
        trailing entries of ``defaults_tail``. If ``args`` already has ``n``
        or more items, it is returned (as a list) unchanged.

    Raises:
        RuntimeError: if ``defaults_tail`` is too short to pad ``args`` to ``n``.

    Example:

        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
        [1, 2, 3, 4, 5]
        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
        [1, 2, 3, None, None]
    """
    missing = n - len(args)
    if missing > len(defaults_tail):
        raise RuntimeError("not enough defaults to fill arguments")
    r = list(args)
    if missing > 0:
        # The last ``missing`` defaults line up with the absent trailing args.
        r.extend(defaults_tail[len(defaults_tail) - missing:])
    return r

# All of the tensor examples in this zoo inherit from BaseTensor.  Ideally,
# however, they would inherit directly from Tensor.  This is just our staging
# ground for applying behavior that hasn't yet made it into core but that
# we would like to apply by default.
class BaseTensor(torch.Tensor):
    """Staging base class for the tensor subclasses in this notebook.

    Both ``__new__`` and ``__init__`` accept ``(elem, *, requires_grad=None)``,
    so ``BaseTensor(t)`` and ``BaseTensor(t, requires_grad=...)`` both work.
    """

    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # so that super().__new__ calls in cooperating subclasses work together.
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        if requires_grad is None:
            return super().__new__(cls, elem)
        else:
            # Build a ``cls`` instance from ``elem`` with the requested flag.
            return cls._make_subclass(cls, elem, requires_grad)

    # To ensure constructors can cooperate with one another, __init__ must
    # accept (and ignore) the same arguments as __new__. Python passes the
    # constructor call's arguments to both, so without ``requires_grad`` here
    # ``BaseTensor(t, requires_grad=True)`` raised TypeError.
    # (TODO: is ignoring ``elem`` right???)
    def __init__(self, elem, *, requires_grad=None):
        super().__init__()

    # If __torch_dispatch__ is defined (which it will be for all our examples)
    # the default torch function implementation (which preserves subclasses)
    # typically must be disabled
    # __torch_function__ = torch._C._disabled_torch_function_impl

# This file describes how to use wrapper tensors (in the style of TrivialTensorViaComposition)
# to override autograd from __torch_dispatch__.  Ordinarily,
# __torch_dispatch__ runs after autograd, so you have no way of overriding
# the autograd behavior (since it will be handled after you return).  However,
# if we put the autograd tensor *inside* a wrapper tensor (which doesn't
# itself require gradients), we get a chance to interpose (in __torch_dispatch__)
# before you handle gradients on the inner element.
#
# Note that you can also use __torch_function__ instead to implement this
# functionality, so this is mostly a question of whether or not you want to
# target the public Python API, or the internal ATen operators API
# (torch.ops.aten).


class InnerAutogradTensor(torch.Tensor):
    """Wrapper tensor that carries an inner tensor ``elem`` with its own autograd.

    The outer wrapper is always created with ``requires_grad=False`` (see
    ``__new__``); the wrapped tensor is stored on ``self.elem`` so autograd can
    still be run on it directly.
    """

    # NOTE(review): REG is not read anywhere in the visible live code.
    REG = {}

    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        # Outer tensor's autograd is disconnected from the inner tensor's
        # autograd. ``requires_grad`` is accepted but ignored: the wrapper is
        # always created with requires_grad=False so that we can use reentrant
        # dispatch on unwrapped normed tensors to get autograd on norms.
        return cls._make_wrapper_subclass(
            cls,
            elem.size(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=False,
        )

    def __init__(self, elem, *, requires_grad=None):
        # Python passes the constructor's arguments to __init__ as well as
        # __new__, so __init__ must also accept ``requires_grad``; otherwise
        # ``InnerAutogradTensor(t, requires_grad=...)`` raises TypeError.
        # The inner tensor is saved so we can still do autograd on it.
        self.elem = elem

    def __repr__(self):
        if self.requires_grad:
            return f'InnerAutogradTensor({self.elem}, requires_grad=True)'
        else:
            return f'InnerAutogradTensor({self.elem})'

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs=None):
        # Debug trace, then the default (subclass-preserving) behaviour.
        print('subclass torch function', func, types)
        return super().__torch_function__(func, types, args, kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # We can't handle here because we want to reentrant dispatch *without unwrapping*
        #
        # If we don't unwrap, then reentrant means we recurse back to here, infinitely!
        # If we unwrap, then we are not passing the information to the custom ops.
        #
        # Hence, the only sensible way is to handle it in the custom op. We could rely on
        # custom op's ability to dispatch based on tensor subclass, but it only
        # dispatches on a single arg type, which is insufficient for our use case.
        #
        # Therefore, we do in-house dispatch via TensorSubclassDispatcher. Regardless of input,
        # all custom ops will have the same bridging code that dispatches to the same
        # TensorSubclassDispatcher.
        #
        # To avoid registering this same code for all tensor subclasses (and for the default impl
        # that could happen with factory functions that do not have tensor args), we
        #    1. register the bridging code only for the default impl for each custom op
        #    2. here, if `func` is such a custom op, we do
        #
        #       with enable_reentrant_dispatch(), torch.set_grad_enabled(True):
        #           func(*args, **kwargs or {}),
        #
        #       which will invoke the default impl with our bridging code.
        #    3. here, otherwise, we just return NotImplemented as we don't know how to compute norms
        #       for ops that are not our custom ops.
        #
        # NOTE: the plan above is not implemented here yet; currently every op
        # is deferred to torch.Tensor.__torch_dispatch__. (Earlier unreachable
        # experiments after this return have been removed.)
        print('subclass torch dispatch', func, types, args, kwargs)
        return super().__torch_dispatch__(func, types, args, kwargs)


    # def test_basic(self):
    #     x = torch.randn(1, requires_grad=True)
    #     y = InnerAutogradTensor(x)
    #     self.assertFalse(y.requires_grad)
    #     self.assertTrue(y.elem.requires_grad)
    #     z = InnerAutogradTensor(x)
    #     # Although y and z do not require grad, we are still able
    #     # to differentiate
    #     r = y + z
    #     # Note we have to extract out the inner tensor (which requires_grad)
    #     # to actually differentiate
    #     r.sum().elem.backward()
    #     self.assertEqual(x.grad, torch.tensor([2.0]))  # two uses!

    # def test_embedding(self):
    #     input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    #     weights = torch.rand(10, 3, requires_grad=True)
    #     embedding_matrix = InnerAutogradTensor(weights)
    #     r = torch.nn.functional.embedding(input, embedding_matrix)
    #     r.sum().elem.backward()
    #     # Gradient is sparse even though we didn't ask for it in embedding!
    #     self.assertTrue(weights.grad.is_sparse)

    # def test_mixing(self):
    #     # Mixing behavior is confusing.  Let's take a look
    #     w1 = torch.randn(1, requires_grad=True)
    #     w2 = torch.randn(1, requires_grad=True)

    #     # Autograd doesn't "unwrap" variables, they still remember if they
    #     # requires_grad; and in fact, inside __torch_dispatch__ it is willing
    #     # to mix gradients between multiple levels. The current class does
    #     # catch most of these though when it is looking at the different
    #     # arguments
    #     with self.assertRaisesRegex(RuntimeError, "Bad mixup of autograd level"):
    #         x = InnerAutogradTensor(w1) + w2



In [2]:
class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Dispatch mode that prints each intercepted op and then runs it as-is."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print('mode', func, types)
        call_kwargs = kwargs if kwargs else {}
        return func(*args, **call_kwargs)

def foo_fn(x: InnerAutogradTensor, y: InnerAutogradTensor) -> InnerAutogradTensor:
    """Eager kernel body for ``mylib::foo``: sum the inputs' inner tensors.

    Returns a new ``InnerAutogradTensor`` wrapping ``x.elem + y.elem``; the sum
    is computed once and reused for both the debug print and the result.
    """
    total = x.elem + y.elem
    print('foo', x.elem, y.elem, total)
    return InnerAutogradTensor(total)


foo = torch.library.custom_op("mylib::foo", mutates_args={}, schema="(Tensor x, Tensor y) -> Tensor")(foo_fn)


def backward(ctx, grad):
    """Autograd formula for ``mylib::foo`` (two inputs, one output).

    ``foo_fn`` adds its two inputs' inner tensors, so each input receives the
    incoming gradient unchanged. The op is registered without a
    setup_context, so nothing is saved and ``ctx.saved_tensors`` is not read.

    Returns:
        A ``(grad_x, grad_y)`` pair, one gradient per forward input.
    """
    print('backward', grad)
    return grad, grad

from torch._library.autograd import InfoProtocol
from typing import *

def make_autograd_impl(op: torch._ops.OpOverload, info: InfoProtocol) -> Callable:
    """Build an Autograd-key kernel (``with_keyset=True`` signature) for ``op``.

    NOTE(review): this appears to be a local copy of PyTorch's internal
    ``torch._library.autograd.make_autograd_impl`` with a debug ``print`` added
    in ``autograd_impl``; confirm against the installed torch version.

    Args:
        op: the custom operator overload to wrap.
        info: object exposing ``_setup_context_fn`` and ``_backward_fn``
            (e.g. the object returned by ``torch.library.custom_op``).

    Returns:
        ``autograd_impl(keyset, *args, **keyword_only_args)``, which records an
        autograd node (via a generated ``torch.autograd.Function``) only when
        grad mode is on and some tensor argument requires grad.
    """
    from torch._library import utils
    from torch import _C
    from dataclasses import dataclass
    from torch.utils import _pytree
    from torch._library.autograd import supports_tensorlist, not_list_of_tensor

    # Name of the generated autograd.Function subclass.
    name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"

    has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)

    # Dispatch keyset + keyword-only args; passed as the final positional
    # argument to forward / forward_no_grad and stripped off there.
    @dataclass
    class Metadata:
        keyset: _C.DispatchKeySet
        keyword_only_args: Dict[str, Any]

    def forward_no_grad(*args):
        # No-grad path: redispatch to the kernels after Autograd.
        metadata = args[-1]
        args = args[:-1]

        with _C._AutoDispatchBelowAutograd():
            keyset = metadata.keyset
            kwargs = metadata.keyword_only_args
            result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
            return result

    def forward(ctx, *args):
        # Grad path (Function.forward): same redispatch as forward_no_grad,
        # then run the user's setup_context hook if one was registered.
        metadata = args[-1]
        args = args[:-1]

        with _C._AutoDispatchBelowAutograd():
            keyset = metadata.keyset
            kwargs = metadata.keyword_only_args
            result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
            if info._setup_context_fn:
                # The Dispatcher will remove args that are equal to their default
                # values from (args, kwargs). We're going to add it back so that
                # the user can access them.
                #
                # This is OK to do: The Dispatcher removed the args for serialization
                # FC/BC reasons (that is, a graph will not store args that are equal
                # to their default values), but that doesn't matter here. If the user
                # adds a new default arg, then they must update
                # their setup_context (along with the rest of their operator
                # registrations)
                args, kwargs = utils.fill_defaults(op._schema, args, kwargs)

                if has_kwarg_only_args:
                    info._setup_context_fn(
                        ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
                    )
                else:
                    info._setup_context_fn(ctx=ctx, inputs=args, output=result)
            return result

    def backward(ctx, *grads):
        # Delegate to the registered backward formula. The Metadata input is
        # hidden from ctx.needs_input_grad during the call, and a trailing
        # None gradient is returned for it.
        if info._backward_fn:
            try:
                prev_needs_input_grad = ctx.needs_input_grad
                ctx.needs_input_grad = ctx.needs_input_grad[:-1]
                result = info._backward_fn(ctx, *grads)
            finally:
                ctx.needs_input_grad = prev_needs_input_grad
            if isinstance(result, tuple):
                return (*result, None)
            return result, None
        raise RuntimeError(
            f"Trying to backward through {op} but no autograd "
            f"formula was registered. "
            f"Please use register_autograd to add one."
        )

    # Dynamically create the autograd.Function subclass with the name above.
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
        },
    )

    # Ops taking/returning tensor lists need the tensorlist-aware wrapper.
    schema = op._schema
    if any(
        utils.is_tensorlist_like_type(a.type)
        for a in (*schema.arguments, *schema.returns)
    ):
        Generated = supports_tensorlist(Generated)

    # The dispatcher passes any keyword-only-args as kwargs and the
    # rest of the args (even if specified as kwargs) as args.
    def autograd_impl(keyset, *args, **keyword_only_args):
        # Debug trace of every call through this kernel.
        print('autograd_impl', keyset, args, keyword_only_args)
        # Record an autograd node only if grad mode is on and some tensor
        # argument requires grad; otherwise take the no-grad path.
        if _C.is_grad_enabled() and _pytree.tree_any_only(
            torch.Tensor, lambda x: x.requires_grad, args, not_list_of_tensor
        ):
            result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
        else:
            result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
        return result

    return autograd_impl

# Register `backward` as foo's autograd formula.
# NOTE(review): no setup_context is passed, so nothing is saved via
# ctx.save_for_backward and the formula cannot rely on ctx.saved_tensors.
foo.register_autograd(backward)
# Alternative (disabled): install make_autograd_impl as foo's Autograd kernel.
# foo._lib.impl(foo._name, make_autograd_impl(foo._opoverload, foo), "Autograd", with_keyset=True)

# @foo.register_torch_dispatch(InnerAutogradTensor)
# def _(mode, func, types, args, kwargs):
#     print('cop subclass dispatch')
#     return func(*args, **kwargs)

@torch.library.custom_op("mylib::foox", mutates_args={}, schema="() -> Tensor")
def foox() -> InnerAutogradTensor:
    return torch.randn(1, requires_grad=True)



In [3]:

# Define mylib29::foo by hand on a FRAGMENT library so that kernels can be
# registered per dispatch key below via lib.impl.
lib = torch.library.Library('mylib29', 'FRAGMENT')
lib.define('foo(Tensor x, Tensor y) -> Tensor')

def wrap_keyset(fn, desc, pass_keyset=False):
    """Adapt ``fn`` to a kernel that receives the dispatch keyset first.

    The returned wrapper takes ``(keyset, *args, **kwargs)``, logs ``desc``
    and the keyset, and then calls ``fn``. If ``pass_keyset`` is truthy, the
    keyset is appended as an extra trailing positional argument; otherwise it
    is dropped.
    """
    def wrapper(keyset, *args, **kwargs):
        print('wrap_keyset', desc, keyset)
        extra = (keyset,) if pass_keyset else ()
        return fn(*args, *extra, **kwargs)
    return wrapper

# CompositeExplicitAutograd kernel: runs foo_fn directly (the keyset is
# dropped because pass_keyset defaults to False).
lib.impl('foo', wrap_keyset(foo_fn, 'foo'), 'CompositeExplicitAutograd', with_keyset=True)

class Foo(torch.autograd.Function):
    """Experimental autograd.Function used as the Autograd kernel of mylib29::foo.

    ``forward`` receives the two ``InnerAutogradTensor`` inputs plus the
    dispatch keyset (appended by ``wrap_keyset(..., pass_keyset=True)``) and
    calls ``foo_fn`` directly. ``backward`` returns one gradient per forward
    input, with ``None`` for the keyset.
    """

    @staticmethod
    def forward(ctx, i, j, keyset):
        print('autograd', torch.is_grad_enabled(), keyset)
        # Grad is switched back on so the inner ``elem`` tensors used inside
        # foo_fn record their own autograd graph.
        with torch.set_grad_enabled(True):
            return foo_fn(i, j)

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE(review): 0.4 / 0.35 are fixed experimental scales, not the
        # derivative of foo_fn's sum (which would pass grad_output through
        # unchanged to both inputs).
        return (
            InnerAutogradTensor(grad_output.elem * 0.4),
            InnerAutogradTensor(grad_output.elem * 0.35),
            None,
        )


# Autograd kernel: Foo.apply, with the dispatch keyset appended as its last argument.
lib.impl('foo', wrap_keyset(Foo.apply, 'foo autograd', pass_keyset=True), 'Autograd', with_keyset=True)


In [356]:
# Inspect the schema of mylib2::foo.
# NOTE(review): mylib2 is not defined in any visible cell; presumably it was
# registered in an earlier kernel session.
torch.ops.mylib2.foo.default._schema

mylib2::foo(Tensor x, Tensor y) -> Tensor

In [5]:

x = torch.randn(1, requires_grad=True)
# y and z wrap the same leaf x. InnerAutogradTensor.__new__ creates the wrapper
# with requires_grad=False, so requires_grad_() marks the outer wrappers.
y = InnerAutogradTensor(x).requires_grad_()
z = InnerAutogradTensor(x).requires_grad_()
# Foo.apply(y, z).elem

subclass torch function <method 'requires_grad_' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)
subclass torch function <method 'requires_grad_' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)


In [202]:
# Build a DispatchKeySet containing only PreDispatch, to inspect its repr.
torch._C.DispatchKeySet(torch._C.DispatchKey.PreDispatch)

DispatchKeySet(PreDispatch)

In [214]:
# List every DispatchKey name with its numeric value.
torch._C.DispatchKey.__members__

{'Undefined': <DispatchKey.Undefined: 0>,
 'CompositeExplicitAutogradNonFunctional': <DispatchKey.CompositeExplicitAutogradNonFunctional: 149>,
 'CompositeExplicitAutograd': <DispatchKey.CompositeExplicitAutograd: 148>,
 'CompositeImplicitAutogradNestedTensor': <DispatchKey.CompositeImplicitAutogradNestedTensor: 147>,
 'CompositeImplicitAutograd': <DispatchKey.CompositeImplicitAutograd: 145>,
 'AutogradNestedTensor': <DispatchKey.AutogradNestedTensor: 24>,
 'AutogradOther': <DispatchKey.AutogradOther: 22>,
 'Autograd': <DispatchKey.Autograd: 144>,
 'Conjugate': <DispatchKey.Conjugate: 18>,
 'ZeroTensor': <DispatchKey.ZeroTensor: 20>,
 'Negative': <DispatchKey.Negative: 19>,
 'BackendSelect': <DispatchKey.BackendSelect: 12>,
 'ADInplaceOrView': <DispatchKey.ADInplaceOrView: 21>,
 'PythonTLSSnapshot': <DispatchKey.PythonTLSSnapshot: 41>,
 'Python': <DispatchKey.Python: 13>,
 'FuncTorchDynamicLayerFrontMode': <DispatchKey.FuncTorchDynamicLayerFrontMode: 42>,
 'FuncTorchDynamicLayerBackMod

In [262]:
# Redispatch straight to the CompositeExplicitAutograd kernel of mylib15::foo.
# NOTE(review): mylib15 is not defined in any visible cell. In the recorded
# output, out.elem has AddBackward0 but out.grad_fn is None (no outer node).
out = torch.ops.mylib15.foo.default.redispatch(
    torch._C.DispatchKeySet(torch._C.DispatchKey.CompositeExplicitAutograd),
    # torch._C._after_autograd_keyset,
    y, z
)
print(out.elem, out.grad_fn)

foo tensor([1.6930], requires_grad=True) tensor([1.6930], requires_grad=True) tensor([3.3860], grad_fn=<AddBackward0>)
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x121328840> (<class '__main__.InnerAutogradTensor'>,)
tensor([3.3860], grad_fn=<AddBackward0>) None


In [307]:
# Redispatch mylib18::foo to the CPU key with autograd dispatch disabled.
# NOTE(review): mylib18 is not defined in any visible cell. Recorded output:
# out.elem has no grad_fn and out.grad_fn is None.
with torch._C._AutoDispatchBelowAutograd():
    out = torch.ops.mylib18.foo.default.redispatch(
        (
            torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
        ),
        y, z
    )
print(out.elem, out.grad_fn)

wrap_keyset foo DispatchKeySet(CPU)
foo tensor([1.6930], requires_grad=True) tensor([1.6930], requires_grad=True) tensor([3.3860])
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x121328840> (<class '__main__.InnerAutogradTensor'>,)
tensor([3.3860]) None


In [323]:
# Redispatch with {CPU, Python, AutogradCPU} masked by _after_autograd_keyset.
# NOTE(review): the recorded output still shows the Autograd kernel running
# ('wrap_keyset foo autograd ...') and out.grad_fn set to FooBackward.
with torch._C._AutoDispatchBelowAutograd():
    out = torch.ops.mylib18.foo.default.redispatch((
        torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
        |
        torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
        |
        torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
    ) & torch._C._after_autograd_keyset,
    y, z
    )
print(out.elem, out.grad_fn)

wrap_keyset foo autograd DispatchKeySet(CPU, Python, AutogradCPU)
autograd False
wrap_keyset foo DispatchKeySet()
foo tensor([1.6930], requires_grad=True) tensor([1.6930], requires_grad=True) tensor([3.3860])
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x121328840> (<class '__main__.InnerAutogradTensor'>,)
tensor([3.3860]) <torch.autograd.function.FooBackward object at 0x31aed0750>


In [6]:
# Call mylib29::foo through the normal dispatcher entry point.
out = getattr(torch.ops, lib.ns).foo.default(y, z)
print(out.elem, out.grad_fn)


subclass torch function mylib29.foo.default (<class '__main__.InnerAutogradTensor'>,)
wrap_keyset foo autograd DispatchKeySet(CPU, Python, AutogradCPU)
autograd False DispatchKeySet(CPU, Python, AutogradCPU)
foo tensor([0.1011], requires_grad=True) tensor([0.1011], requires_grad=True) tensor([0.2021], grad_fn=<AddBackward0>)
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x10422c500> (<class '__main__.InnerAutogradTensor'>,)
tensor([0.2021], grad_fn=<AddBackward0>) <torch.autograd.function.FooBackward object at 0x16d701950>


In [7]:
# Differentiate out w.r.t. y with a unit InnerAutogradTensor grad_output and
# store the result on y.grad.
y.grad = torch.autograd.grad(out, y, grad_outputs=InnerAutogradTensor(torch.tensor([1.])))[0]

subclass torch function <function grad at 0x123543240> (<class '__main__.InnerAutogradTensor'>,)
subclass torch function <method-wrapper '__set__' of getset_descriptor object at 0x10422c7c0> (<class '__main__.InnerAutogradTensor'>,)


In [8]:
# Inspect the gradient stored on y.
y.grad

subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x10422c7c0> (<class '__main__.InnerAutogradTensor'>,)
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x10422c980> (<class '__main__.InnerAutogradTensor'>,)


InnerAutogradTensor(tensor([0.4000]))

In [9]:
# Backpropagate from out with a unit InnerAutogradTensor grad_output.
out.backward(InnerAutogradTensor(torch.tensor([1.])))

: 

In [388]:
# Inspect y.grad again.
y.grad

: 

In [314]:
# Show the keyset used above to mask out autograd keys before redispatching.
torch._C._after_autograd_keyset

DispatchKeySet(CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, VE, Lazy, MTIA, PrivateUse1, PrivateUse2, PrivateUse3, Meta, FPGA, MAIA, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedXLA, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedLazy, QuantizedMTIA, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, QuantizedMeta, CustomRNGKeyId, MkldnnCPU, SparseCPU, SparseCUDA, SparseHIP, SparseXLA, SparseMPS, SparseIPU, SparseXPU, SparseHPU, SparseVE, SparseLazy, SparseMTIA, SparsePrivateUse1, SparsePrivateUse2, SparsePrivateUse3, SparseMeta, SparseCsrCPU, SparseCsrCUDA, SparseCsrHIP, SparseCsrXLA, SparseCsrMPS, SparseCsrIPU, SparseCsrXPU, SparseCsrHPU, SparseCsrVE, SparseCsrLazy, SparseCsrMTIA, SparseCsrPrivateUse1, SparseCsrPrivateUse2, SparseCsrPrivateUse3, SparseCsrMeta, NestedTensorCPU, NestedTensorCUDA, NestedTensorHIP, NestedTensorXLA, NestedTensorMPS, NestedTensorIPU, NestedTensorXPU, NestedTensorHPU, NestedTensorVE, NestedTens

In [235]:
# Scratch: redispatch mylib::foo straight to its CompositeExplicitAutograd kernel
# (an alternative keyset is left commented out).
out = torch.ops.mylib.foo.default.redispatch(
    torch._C.DispatchKeySet(torch._C.DispatchKey.CompositeExplicitAutograd),
    # torch._C._after_autograd_keyset,
    y, z
)
# In the recorded run out.grad_fn was None, presumably because the Autograd kernel was skipped.
print(out.elem, out.grad_fn)

subclass torch function <method 'untyped_storage' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)
subclass torch function <method 'untyped_storage' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)
foo tensor([1.6930], requires_grad=True) tensor([1.6930], requires_grad=True) tensor([3.3860], grad_fn=<AddBackward0>)
subclass torch function <method 'untyped_storage' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)
subclass torch function <method 'untyped_storage' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x121328840> (<class '__main__.InnerAutogradTensor'>,)
tensor([3.3860], grad_fn=<AddBackward0>) None


In [226]:
# Scratch: call mylib::foo via the private OpOverload._op_dk with DispatchKey.Undefined.
# In the recorded run this went through the subclass's __torch_function__ and
# __torch_dispatch__ and the result carried a generated backward node.
out = torch.ops.mylib.foo.default._op_dk(
    torch._C.DispatchKey.Undefined,
    # 'a',
    # torch._C._after_autograd_keyset,
    y, z
)
print(out.elem, out.grad_fn)

subclass torch function mylib.foo.default (<class '__main__.InnerAutogradTensor'>,)
subclass torch dispatch mylib.foo.default (<class '__main__.InnerAutogradTensor'>,) False False
foo tensor([1.6930], requires_grad=True) tensor([1.6930], requires_grad=True) tensor([3.3860])
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x121328840> (<class '__main__.InnerAutogradTensor'>,)
tensor([3.3860]) <torch.autograd.function.GeneratedBackwardFor_mylib_foo_defaultBackward object at 0x31aba0c50>


In [212]:
# Inspect which dispatch keys have been cached for mylib::foo (private attribute).
torch.ops.mylib.foo.default._dispatch_cache

{<DispatchKey.CompositeExplicitAutograd: 148>: <DispatchKey.CompositeExplicitAutograd: 148>,
 <DispatchKey.Autograd: 144>: <DispatchKey.Autograd: 144>,
 <DispatchKey.Python: 13>: <DispatchKey.Python: 13>}

In [128]:
# List the kernels registered for mylib::foo through its Library handle (private attribute).
foo._lib._op_impls

{'mylib/foo/Autograd', 'mylib/foo/CompositeExplicitAutograd'}

In [125]:
# Call the op normally and check that a backward node is attached to the result.
torch.ops.mylib.foo.default(y, z).grad_fn

subclass torch function mylib.foo.default (<class '__main__.InnerAutogradTensor'>,)
subclass torch dispatch mylib.foo.default (<class '__main__.InnerAutogradTensor'>,) False False
foo tensor([0.2494], requires_grad=True) tensor([0.2494], requires_grad=True) tensor([0.4988])
subclass torch function <method-wrapper '__get__' of getset_descriptor object at 0x121328840> (<class '__main__.InnerAutogradTensor'>,)


<torch.autograd.function.GeneratedBackwardFor_mylib_foo_defaultBackward at 0x31a6d8550>

In [94]:
# Register `foo` as the Autograd kernel of mylib::foo; with_keyset=True passes the
# current DispatchKeySet as the kernel's first argument. In the recorded run this
# emitted an "Overriding a previously registered kernel" warning.
foo._lib.impl(foo._name, foo, "Autograd", with_keyset=True)
# torch.ops.mylib.foo.default._op_dk(torch._C.DispatchKey.Autograd, y, z).elem

  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: mylib::foo(Tensor x, Tensor y) -> Tensor
    registered at /dev/null:185
  dispatch key: Autograd
  previous kernel: no debug info
       new kernel: registered at /dev/null:185 (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1729647038473/work/aten/src/ATen/core/dispatch/OperatorEntry.cpp:162.)
  self.m.impl(


In [83]:
# Invoke mylib::foo at the Autograd key via the private _op_dk and unwrap the inner tensor.
torch.ops.mylib.foo.default._op_dk(torch._C.DispatchKey.Autograd, y, z).elem

subclass torch function mylib.foo.default (<class '__main__.InnerAutogradTensor'>,)
subclass torch dispatch mylib.foo.default (<class '__main__.InnerAutogradTensor'>,) False False
foo tensor([-2.5916], requires_grad=True) tensor([-2.5916], requires_grad=True) tensor([-5.1833])


tensor([-5.1833])

In [69]:
# Reset the subclass's op registry (the commented line shows how an op such as
# mylib.foo.default would be mapped to a handler), then build two InnerAutogradTensor
# inputs that both wrap the same leaf tensor x.
InnerAutogradTensor.REG.clear()
# InnerAutogradTensor.REG[torch.ops.mylib.foo.default] = foo
x = torch.randn(1, requires_grad=True)
y = InnerAutogradTensor(x).requires_grad_()
z = InnerAutogradTensor(x).requires_grad_()

subclass torch function <method 'requires_grad_' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)
subclass torch function <method 'requires_grad_' of 'torch._C.TensorBase' objects> (<class '__main__.InnerAutogradTensor'>,)


In [70]:
# Call through the OpOverloadPacket (dispatches to mylib.foo.default) and unwrap the result.
torch.ops.mylib.foo(y, z).elem

subclass torch function mylib.foo (<class '__main__.InnerAutogradTensor'>,)
subclass torch dispatch mylib.foo.default (<class '__main__.InnerAutogradTensor'>,) False False
foo tensor([-2.5916], requires_grad=True) tensor([-2.5916], requires_grad=True) tensor([-5.1833])


tensor([-5.1833])

In [203]:

# Scratch: differentiate through the subclass by backpropagating from the wrapped
# inner tensor (which requires grad) rather than from the wrapper itself.
# NOTE(review): y and z were created with .requires_grad_() in the setup cell, so the
# original remark that they "do not require grad" may be stale.
r = y + z
# NOTE(review): in the recorded run r.sum() returned a plain Tensor, so `.elem`
# raised AttributeError: 'Tensor' object has no attribute 'elem'.
r.sum().elem.backward()

subclass dispatch aten.add.Tensor (<class '__main__.InnerAutogradTensor'>,) False True


AttributeError: 'Tensor' object has no attribute 'elem'

In [58]:
# Inspect the gradient on the leaf x that y and z both wrap.
x.grad

tensor([2.])

In [15]:
# Scratch: OpOverload.redispatch needs a DispatchKeySet as its first argument; this
# call omits it and raises TypeError (kept as a record of the exploration).
torch.ops.aten.linear.default.redispatch()

TypeError: OpOverload.redispatch() missing 1 required positional argument: 'keyset'

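Computing by hand, the output norm should be $1 \cdot \frac{1}{\sqrt{2}} \cdot 1 \cdot \frac{1}{\sqrt{2}} \cdot 1 + 1 = 1.5$: each linear layer contributes its weight norm of 1, each ReLU a factor of $\frac{1}{\sqrt{2}}$, and the residual connection adds the input norm of 1. This matches the computed output norm.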

Note that the norm type and the element dims (`elem_dims`) are propagated too.

## Ex2: build modula norm automatically for regular PyTorch modules

To compute the modula norm, we need the local "influence" of each weight norm on the output norm. Fortunately, PyTorch autograd can compute this for us!

Let's first mark the norm sizes of all parameters as requiring gradients.

In [7]:
from pprint import pprint

# Make every parameter's norm size require grad, so autograd can later report
# how strongly the output norm depends on each one.
normed_state_dict = {
    param_name: normed_param.norm_size_requires_grad_(True)
    for param_name, normed_param in normed_state_dict.items()
}
print('normed_state_dict:')
pprint(normed_state_dict)

normed_state_dict:
{'net.0.bias': RMS_NormTensor(norm_size=tensor(0., requires_grad=True), elem_dims=(-1,), ...),
 'net.0.weight': RMS_RMS_NormTensor(norm_size=tensor(1., requires_grad=True), elem_dims=(-1, -2), ...),
 'net.2.bias': RMS_NormTensor(norm_size=tensor(0., requires_grad=True), elem_dims=(-1,), ...),
 'net.2.weight': RMS_RMS_NormTensor(norm_size=tensor(1., requires_grad=True), elem_dims=(-1, -2), ...),
 'net.4.bias': RMS_NormTensor(norm_size=tensor(0., requires_grad=True), elem_dims=(-1,), ...),
 'net.4.weight': RMS_RMS_NormTensor(norm_size=tensor(1., requires_grad=True), elem_dims=(-1, -2), ...)}


In [8]:
# Propagate the input norm through the network using the grad-requiring norm sizes.
output_norm = norm_map(normed_input, normed_state_dict=normed_state_dict)
print('output_norm: \n', output_norm)

TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.add_.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.div.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.add_.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.div.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.T

Note the `grad_fn`! Now invoke autograd...

In [9]:
# Check that an OpOverload's overload packet resolves back to the same default overload.
torch.ops.aten.linear.default.overloadpacket.default

<OpOverload(op='aten.linear', overload='default')>

In [24]:
# List the custom ops auto_norm has registered under the torch.ops.auto_norm namespace.
dir(torch.ops.auto_norm)

['__doc__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 '_dir',
 'name',
 'op__constant_scaler_mul__auto_norm_reg_fake_norm_ops_constant_scaler_ConstantScaler__mul_with_scaler__13174249152',
 'op__wrapper__torch__C__nn_linear__4713105744',
 'op__wrapper__torch__C__nn_scaled_dot_product_attention__4713108864',
 'op__wrapper__torch__ops_aten_aten_add_Tensor__6134135248',
 'op__wrapper__torch__ops_aten_aten_randn__6142561744',
 'op__wrapper__torch__ops_aten_aten_relu__6143355280',
 'op__wrapper__torch_nn_functional_layer_norm__4765698144']

In [29]:
# Scratch: call the entry registered for F.linear in REG_FAKE_NORM_OP_REGISTRY
# directly (outside norm_map) on grad-requiring norm tensors, to see whether the
# result gets a grad_fn. In the recorded run y.grad_fn was None.
# from auto_norm.normed_mode_dispatch import normed_mode_propagate
from auto_norm.reg_fake_norm_op_registry import REG_FAKE_NORM_OP_REGISTRY

# with normed_mode_propagate():
x, w, b = (
    normed_input,
    normed_state_dict['net.0.weight'],
    normed_state_dict['net.0.bias'],
)
# NOTE(review): finalize() appears to bind each norm tensor to a concrete example
# tensor (10x8 input, 16x8 weight and 16-element bias for nn.Linear(8, 16)); confirm.
x = x.finalize(torch.empty(10, 8)).requires_grad_()
w = w.finalize(torch.empty(16, 8)).requires_grad_()
b = b.finalize(torch.empty(16)).requires_grad_()
op = REG_FAKE_NORM_OP_REGISTRY[torch.nn.functional.linear]
y = op.normed_dispatcher(x, w, b)
y, y.grad_fn


TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor


(RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1,), ...), None)

In [46]:
# Inspect y's autograd node (shown as None in the recorded output of the previous cell).
y.grad_fn

In [13]:
# Inspect the custom-op lookup registry (empty in the recorded run).
# NOTE(review): this name is not imported by any visible cell.
REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP

{}

In [37]:
# Re-inspect y's autograd node.
y.grad_fn

In [25]:
# Check whether the propagated output norm object itself carries an autograd node.
output_norm.grad_fn

In [8]:
# Scratch: this raised "element 0 of tensors does not require grad and does not have
# a grad_fn"; the differentiable quantity is output_norm.norm_size instead.
output_norm.backward()


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [14]:
# Backpropagate from the scalar output norm size so every grad-requiring parameter
# norm_size receives d(out_norm)/d(norm_size) in .grad.
# NOTE: running this cell twice accumulates gradients into .grad.
output_norm.norm_size.backward()


For the modula norm, we have

$$\|\{W_i\}_i\|_\mathsf{M} := \max_i \frac{\textsf{total\_mass}}{\textsf{mass}_i} \, \textsf{influence}_i \, \|W_i\|,$$

where

$$\textsf{influence}_i = \frac{\partial\, \textsf{out\_norm}}{\partial \|W_i\|}.$$

(In terms of the original modula paper: for an atomic module $\mathsf{M}$, we assume $\mathsf{M.norm}(W) := \alpha \|W\|$ for some choice of norm $\|\cdot\|$ and some scalar $\alpha$.)

In [15]:
# influence_i = d(out_norm) / d||W_i||, read from each norm_size's .grad after backward.
# A parameter whose norm size does not feed into the output norm receives no gradient
# (.grad stays None); its influence is zero, so use a zero tensor instead of None,
# which would otherwise break the modula-norm product downstream.
influences = {
    k: v.norm_size.grad if v.norm_size.grad is not None else torch.zeros_like(v.norm_size)
    for k, v in normed_state_dict.items()
}
print('influences of net.2.weight:')
print(influences['net.2.weight'])

influences of net.2.weight:
tensor(0.5000)


In [16]:
# Give every parameter a "mass": weights get 1, everything else (the biases) 0.1.
masses = {}
for param_name in normed_state_dict:
    masses[param_name] = 1 if param_name.endswith('weight') else 0.1
print('masses:')
pprint(masses)

# Total mass over all parameters, used to normalize each parameter's share.
total_mass = sum(masses.values())
print(f'total_mass: {total_mass:g}')


masses:
{'net.0.bias': 0.1,
 'net.0.weight': 1,
 'net.2.bias': 0.1,
 'net.2.weight': 1,
 'net.4.bias': 0.1,
 'net.4.weight': 1}
total_mass: 3.3


In [17]:
# Per-parameter term: (total_mass / mass_i) * influence_i * ||W_i||.
# The modula norm is the largest of these terms.
modula_terms = {}
for param_name, normed_param in normed_state_dict.items():
    mass_ratio = total_mass / masses[param_name]
    modula_terms[param_name] = mass_ratio * influences[param_name] * normed_param.norm_size.detach()
modula_norm = max(modula_terms.values())
print(f'modula_norm: {modula_norm:.4f}')

modula_norm: 1.6500


## Ex3: Optimize scaling factors

Here the output norm is 1.5, not 1. How can we scale the layers so that the output has unit norm?

Let's use the special class `auto_norm.ConstantScaler` to optimize for scaling factors!

In [18]:
class MyResBlockWithScaling(nn.Module):
    """Residual MLP block (8 -> 16 -> 16 -> 8) with tunable scaling factors.

    Same layout as ``MyResBlock``, except that an ``auto_norm.ConstantScaler``
    follows every Linear layer and another one scales the identity (skip) branch.
    The scalers mark where scaling factors get tuned; by default they are no-ops.
    """

    def __init__(self):
        super().__init__()
        # Layer order fixes the state-dict names (net.0 ... net.7) that
        # build_normed_state_dict_for_scaled_net keys into, so keep it stable.
        layers = [
            nn.Linear(8, 16), auto_norm.ConstantScaler(), nn.ReLU(),
            nn.Linear(16, 16), auto_norm.ConstantScaler(), nn.ReLU(),
            nn.Linear(16, 8), auto_norm.ConstantScaler(),
        ]
        self.net = nn.Sequential(*layers)
        # Scales the skip connection.
        self.idt_scaler = auto_norm.ConstantScaler()

    def forward(self, x):
        skip = self.idt_scaler(x)
        return skip + self.net(x)


# Instantiate the scaled block and build its norm map from the same example input;
# displaying the module shows where the ConstantScalers sit (net.1, net.4, net.7, idt_scaler).
scaled_net = MyResBlockWithScaling()
norm_map_for_scaled_net = auto_norm.build_norm_map(scaled_net, example_input)
scaled_net

ExportFakeFunctionMode dispatching <method-wrapper '__get__' of getset_descriptor object at 0x106d29100>, (<class 'torch.Tensor'>,), (tensor([[-0.5987, -1.4379, -0.0727,  0.7559,  0.5247,  0.2773, -0.1642,  1.4604],
        [ 0.5857, -0.0591, -0.2283,  0.2118, -0.0901, -0.4193, -1.3814,  0.7365],
        [-0.2951,  1.5737, -1.6958, -2.6760, -1.4862, -1.0574,  2.0471, -0.1784],
        [ 0.8522,  0.3332, -1.3924, -2.4705, -0.3980,  1.2932,  0.8630,  0.7114],
        [ 0.5178, -0.8385, -1.1924,  0.2043, -1.0558, -1.9803, -0.8809,  0.4878],
        [ 1.2046,  1.1006,  1.3434, -1.2230,  0.2825, -1.0586, -0.7185, -0.8391],
        [-0.4893,  0.5534, -0.2063,  1.6485,  0.6543, -0.0716, -0.7954, -0.2680],
        [ 0.7354, -0.6773,  0.1088, -0.6905, -1.2562, -0.5258,  0.1183,  1.5214],
        [ 0.6987,  0.2470, -0.6693,  0.3906, -0.4304,  0.2741, -0.6133,  1.1662],
        [-0.0221, -0.7958,  1.3911,  1.2474,  0.5356,  2.9735, -0.3995,  0.8350]],
       requires_grad=True),), None
ExportFake

MyResBlockWithScaling(
  (net): Sequential(
    (0): Linear(in_features=8, out_features=16, bias=True)
    (1): ConstantScaler()
    (2): ReLU()
    (3): Linear(in_features=16, out_features=16, bias=True)
    (4): ConstantScaler()
    (5): ReLU()
    (6): Linear(in_features=16, out_features=8, bias=True)
    (7): ConstantScaler()
  )
  (idt_scaler): ConstantScaler()
)

Now the state dict contains these new scale factors (`net.1.scale`, `net.4.scale`, `net.7.scale` and `idt_scaler.scale`). We can pass any scale factors to `norm_map` via the normed state dict.

In [19]:
def build_normed_state_dict_for_scaled_net(post_linear_scale, idt_scale, module=None):
    """Build the normed state dict passed to ``norm_map_for_scaled_net``.

    Every weight gets an ``RMS_RMS_NormTensor`` of norm size 1 and every bias an
    ``RMS_NormTensor`` of norm size 0. The skip-branch scaler (``idt_scaler.scale``)
    gets ``idt_scale``, and every other ``*scale`` entry (the post-Linear scalers)
    shares the same ``post_linear_scale`` object. Passing tensors created with
    ``requires_grad=True`` therefore lets autograd tune the scales.

    Args:
        post_linear_scale: value used for every post-Linear ``ConstantScaler``.
        idt_scale: value used for ``idt_scaler.scale``.
        module: module whose ``state_dict()`` names are mapped; defaults to the
            global ``scaled_net``.

    Returns:
        dict mapping each state-dict name to its normed value.

    Raises:
        ValueError: if a state-dict entry is not a weight, bias or scale. Such an
            entry would otherwise be silently left out of the normed state dict.
    """
    if module is None:
        module = scaled_net
    normed_state_dict = {}
    for name in module.state_dict():
        if name.endswith('weight'):
            normed_state_dict[name] = auto_norm.RMS_RMS_NormTensor(1, elem_dims=(-1, -2))
        elif name.endswith('bias'):
            normed_state_dict[name] = auto_norm.RMS_NormTensor(0, elem_dims=(-1,))
        elif name == 'idt_scaler.scale':  # must be checked before the generic '*scale' case
            normed_state_dict[name] = idt_scale
        elif name.endswith('scale'):
            normed_state_dict[name] = post_linear_scale
        else:
            raise ValueError(f'no normed value known for state-dict entry {name!r}')
    return normed_state_dict

Let's verify that, with every scale set to 1 (the default), the output norm is the same as without the scalers.

In [20]:
# Sanity check: with both scales at 1 the scalers are no-ops, so the output norm
# should match that of the unscaled network.
normed_state_dict = build_normed_state_dict_for_scaled_net(post_linear_scale=torch.tensor(1.), idt_scale=torch.tensor(1.))


output_norm = norm_map_for_scaled_net(normed_input, normed_state_dict=normed_state_dict)
print('output_norm: \n', output_norm)

TensorSubclassDispatcher dispatching (input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.add_.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.div.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormP

Now let's tune the scaling factors so that the output norm becomes 1!

First, let's prepare the normed state dict with scale factors that require grad:

In [21]:
# Scale factors are plain scalar tensors that require grad, so autograd can tune them.
# The single post_linear_scale tensor is shared by all three post-Linear scalers.
post_linear_scale = torch.tensor(1., requires_grad=True)  # requires grad!
idt_scale = torch.tensor(1., requires_grad=True)
normed_state_dict = build_normed_state_dict_for_scaled_net(post_linear_scale, idt_scale)
print('normed_state_dict:')
pprint(normed_state_dict)

normed_state_dict:
{'idt_scaler.scale': tensor(1., requires_grad=True),
 'net.0.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.0.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.1.scale': tensor(1., requires_grad=True),
 'net.3.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.3.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.4.scale': tensor(1., requires_grad=True),
 'net.6.bias': RMS_NormTensor(norm_size=tensor(0.), elem_dims=(-1,), ...),
 'net.6.weight': RMS_RMS_NormTensor(norm_size=tensor(1.), elem_dims=(-1, -2), ...),
 'net.7.scale': tensor(1., requires_grad=True)}


Now, simply optimize with autograd...

In [22]:
# Tune the shared post-Linear scale and the skip-branch scale by gradient descent on
# (output_norm - 1)^2. normed_state_dict holds these same tensor objects, so the
# in-place updates made by optim.step() are seen by the next norm_map call.
target_norm = torch.tensor(1.)  # loop-invariant: build once instead of every iteration
num_iters, log_every = 200, 50
optim = torch.optim.SGD([post_linear_scale, idt_scale], lr=0.03)
for ii in range(1, num_iters + 1):
    optim.zero_grad()
    output_norm = norm_map_for_scaled_net(normed_input, normed_state_dict=normed_state_dict)
    loss = F.mse_loss(output_norm.norm_size, target_norm)
    if ii % log_every == 0:
        # Logged values are from before this iteration's optimizer step.
        print(f'iter {ii:03d}: loss={loss:.4f} output_norm={output_norm.norm_size:.4f}')
    loss.backward()
    optim.step()

print('post_linear_scale: \n', post_linear_scale)
print('idt_scale: \n', idt_scale)


TensorSubclassDispatcher dispatching (input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.add_.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.mul.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor) -> torch.Tensor
NormPropagateDispatchMode: aten.div.Tensor, ()
NormPropagateDispatchMode: aten.empty.memory_format, ()
TensorSubclassDispatcher dispatching (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor
NormP

We can also verify the result by hand:

In [23]:
import math

# Check the optimized scales by hand. Weights have norm 1 and biases norm 0, so the
# residual branch contributes post_linear_scale**3 (one scaler per Linear) times
# (1/sqrt(2))**2 (one factor per ReLU). The skip branch contributes idt_scale times
# the input norm (1 here).
manual_output_norm = (
    (scaler_contribution := post_linear_scale ** 3) *
    (relu_contribution := (1 / math.sqrt(2)) ** 2) +
    (idt_contribution := idt_scale)
)

# Re-run norm_map at the final scales. The training loop's last `output_norm` was
# computed before its final optimizer step, so it is slightly stale.
final_output_norm = norm_map_for_scaled_net(normed_input, normed_state_dict=normed_state_dict).norm_size

# The hand computation must agree with auto_norm's propagated norm, and both should
# have converged to the unit-norm target.
assert torch.allclose(manual_output_norm, final_output_norm), (manual_output_norm, final_output_norm)
assert torch.allclose(manual_output_norm, torch.tensor(1.))
print(f'manual_output_norm: {manual_output_norm:.4f}')


manual_output_norm: 1.0000
