In [1]:
%matplotlib widget

import abc
import torch.nn.modules.module
from torch.export import export, Dim
from torch.export.exported_program import ExportedProgram, InputKind, OutputKind
import uuid
import torch._dynamo
from torch._functorch.aot_autograd import aot_module, aot_module_simplified
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func
from torch._decomp import core_aten_decompositions
from torch._dynamo.backends.inductor import inductor
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._guards import detect_fake_mode
from typing import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import functools
import math
import random


from collections import namedtuple
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn, optim
import torch.nn.functional as F
import torch
import functools
import os

In [282]:
import functools
import inspect
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from torch.overrides import enable_reentrant_dispatch
from collections import defaultdict
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._pytree import tree_map
from typeguard import typechecked
from multimethod import multimethod

ENABLE_NORM_DISPATCH = True

@contextlib.contextmanager
def disable_norm_dispatch():
    global ENABLE_NORM_DISPATCH
    old_flag = ENABLE_NORM_DISPATCH
    ENABLE_NORM_DISPATCH = False
    yield
    ENABLE_NORM_DISPATCH = old_flag

HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}


def get_output_fake_tensors(func, *args, **kwargs):
    # Create a fake mode
    fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
    def convert_from_real_tensor(x):
        if isinstance(x, torch.Tensor):
            return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, x)
        return x
    # Fakeify some real tensors
    with fake_mode, disable_norm_dispatch():
        args = tree_map(convert_from_real_tensor, args)
        kwargs = tree_map(convert_from_real_tensor, kwargs)
        return func(*args, **kwargs)


def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        # sig = inspect.signature(func)
        # func = typechecked(func)
        functools.update_wrapper(func, torch_function)
        assert torch_function not in HANDLED_FUNCTIONS
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        if ENABLE_NORM_DISPATCH:
            print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
            if any(issubclass(t, NormTensorBase) for t in types):
                return NotImplemented
        return func(*args, **(kwargs or {}))


class NormTensorBase(torch.Tensor):
    @staticmethod
    def __new__(cls, norm_size: Union[float, torch.Tensor], *, size: torch.Size,
                dtype: torch.dtype, device: torch.device, requires_grad=None):
        return cls._make_wrapper_subclass(cls, size, dtype=dtype, device=device, requires_grad=False)

    def __init__(self, norm_size: Union[float, torch.Tensor], *, size: torch.Size,
                 dtype: torch.dtype, device: torch.device, requires_grad=None):
        if isinstance(norm_size, torch.Tensor):
            assert requires_grad is None
            self._norm_size = norm_size
        else:
            self._norm_size = torch.full((), norm_size, dtype=torch.float32, requires_grad=requires_grad)

    @property
    def norm_size(self) -> torch.Tensor:
        return self._norm_size

    def __repr__(self):
        return f"{self.__class__.__name__}(norm_size={self.norm_size!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        if ENABLE_NORM_DISPATCH and func in HANDLED_FUNCTIONS:
            with enable_reentrant_dispatch():
                return HANDLED_FUNCTIONS[func](*args, **kwargs)
        # for handler, sig in HANDLED_FUNCTIONS.get(func, []):
        #     print(f"Trying {handler}", sig, args, kwargs)
        #     try:
        #         bound = sig.bind(*args, **kwargs)
        #     except TypeError as e:
        #         continue
        #     with enable_reentrant_dispatch():
        #         out = handler(*bound.args, **bound.kwargs)
        #     print(out.norm_size.__class__)
        #     print(out)
        #     print(out.norm_size)
        #     return out
        return NotImplemented


class RMS_NormTensor(NormTensorBase):
    pass

class RMS_RMS_NormTensor(NormTensorBase):
    pass

class L1_NormTensor(NormTensorBase):
    pass


@implements(torch.ops.aten.unsqueeze.default)
@multimethod
def unsqueeze(input: RMS_NormTensor, dim: int) -> RMS_NormTensor:
    out_fake = get_output_fake_tensors(torch.ops.aten.unsqueeze.default, input, dim)
    return RMS_NormTensor(input.norm_size, size=out_fake.size(), dtype=out_fake.dtype, device=out_fake.device)

@implements(torch.ops.aten.squeeze_.dim)
@multimethod
def squeeze(input: RMS_NormTensor, dim: int) -> RMS_NormTensor:
    out_fake = get_output_fake_tensors(torch.ops.aten.squeeze_.dim, input, dim)
    return RMS_NormTensor(input.norm_size, size=out_fake.size(), dtype=out_fake.dtype, device=out_fake.device)

@implements(torch.ops.aten.t.default)
@multimethod
def t(input: RMS_RMS_NormTensor) -> RMS_RMS_NormTensor:
    assert input.ndim == 2
    return RMS_RMS_NormTensor(input.norm_size, size=input.size()[::-1], dtype=input.dtype, device=input.device)

# @t.register
# def t(input: RMSNormTensor) -> RMSRMSNormTensor:
#     assert input.ndim == 2
#     print(input.size(), 'l')
#     return RMSRMSNormTensor(input.norm_size + 2, size=input.size()[::-1], dtype=input.dtype, device=input.device)

@implements(torch.ops.aten.addmm.default)
@multimethod
def addmm(input: RMS_NormTensor, mat1: RMS_NormTensor, mat2: RMS_RMS_NormTensor, *, beta: float = 1, alpha: float = 1) -> RMS_NormTensor:
    # output = input * beta + mat1 @ mat2 * alpha
    final_norm_size = input.norm_size * beta + mat1.norm_size * mat2.norm_size * alpha
    out_fake = get_output_fake_tensors(torch.ops.aten.addmm.default, input, mat1, mat2, beta=beta, alpha=alpha)
    return RMS_NormTensor(final_norm_size, size=out_fake.size(), dtype=out_fake.dtype, device=out_fake.device)

@addmm.register
def _(input: RMS_NormTensor, mat1: RMS_NormTensor, mat2: RMS_RMS_NormTensor, *, beta: float = 1, alpha: float = 1) -> RMS_NormTensor:
    # output = input * beta + mat1 @ mat2 * alpha
    final_norm_size = input.norm_size * beta + mat1.norm_size * mat2.norm_size * alpha
    out_fake = get_output_fake_tensors(torch.ops.aten.addmm.default, input, mat1, mat2, beta=beta, alpha=alpha)
    return RMS_NormTensor(final_norm_size, size=out_fake.size(), dtype=out_fake.dtype, device=out_fake.device)


@implements(torch.ops.aten.mm.default)
@multimethod
def mm(input: RMS_NormTensor, mat1: RMS_RMS_NormTensor) -> RMS_NormTensor:
    final_norm_size = input.norm_size * mat1.norm_size
    out_fake = get_output_fake_tensors(torch.ops.aten.mm.default, input, mat1)
    return RMS_NormTensor(final_norm_size, size=out_fake.size(), dtype=out_fake.dtype, device=out_fake.device)


In [324]:
with DispatchLog():
    y = torch.ops.aten.linear.default(
        RMS_NormTensor(3, size=(15,), dtype=torch.float32, device=torch.device("cpu"), requires_grad=True),
        RMS_RMS_NormTensor(4, size=(16, 15), dtype=torch.float32, device=torch.device("cpu"), requires_grad=True),
        # RMSNormTensor(3.5, size=(16,), dtype=torch.float32, device=torch.device("cpu"), requires_grad=True),
    )
    print(y, y.norm_size)


Dispatch Log: aten.full.default(*([], 3), **{'dtype': torch.float32, 'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.full.default(*([], 4), **{'dtype': torch.float32, 'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.t.default(*(RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True)),), **{})
Dispatch Log: aten.t.default(*(RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True)),), **{})
Dispatch Log: aten.unsqueeze.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), 0), **{})
Dispatch Log: aten.unsqueeze.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), 0), **{})
Dispatch Log: aten.mm.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True))), **{})
Dispatch Log: aten.mm.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True))), **{})
Dispatch Log: aten.mul.T

In [71]:
import torch
import torch.nn.functional
import torch

# All of the tensor examples in this zoo inherit from BaseTensor.  Ideally,
# however, they would inherit directly from Tensor.  This is just our staging
# ground for applying behavior that hasn't yet made it into core but that
# we would like to apply by default.
class BaseTensor(torch.Tensor):
    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # to ensure that super().__new__ can cooperate with each other
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        if requires_grad is None:
            return super().__new__(cls, elem)
        else:
            return cls._make_subclass(cls, elem, requires_grad)

    # To ensure constructors can cooperate with one another, must accept and
    # ignore element tensor (TODO: is this right???)
    def __init__(self, elem):
        super().__init__()

    # If __torch_dispatch__ is defined (which it will be for all our examples)
    # the default torch function implementation (which preserves subclasses)
    # typically must be disabled
    __torch_function__ = torch._C._disabled_torch_function_impl

from torch.utils._pytree import tree_map
import contextlib
from typing import Any

import torch
from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten

# Dumping ground for utilities that should eventual make their way into
# PyTorch proper


@contextlib.contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def tree_map2(fn: Any, pytree1: PyTree, pytree2: PyTree) -> PyTree:
    flat_args1, spec1 = tree_flatten(pytree1)
    flat_args2, spec2 = tree_flatten(pytree2)
    assert spec1 == spec2
    return tree_unflatten([fn(i, j) for i, j in zip(flat_args1, flat_args2)], spec1)


# IDK if this is actually useful or not
def unmake_subclass(tensor):
    with no_dispatch():
        return torch.Tensor._make_subclass(torch.Tensor, tensor)


def fill_defaults(args, n, defaults_tail):
    """
    __torch_dispatch__ doesn't guarantee the number of arguments you are
    passed (e.g., defaulted arguments are not passed); but usually it is
    convenient to pad out the arguments list with defaults.  This function
    helps you do that.

    Args:
        args: the list of positional arguments passed to __torch_dispatch__
        n: the number of arguments you are expecting to get
        defaults_tail: default values for the arguments, starting from the
            end of the list

    Example:

        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
        [1, 2, 3, 4, 5]
        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
        [1, 2, 3, None, None]]
    """
    if n - len(defaults_tail) > len(args):
        raise RuntimeError("not enough defaults to fill arguments")
    r = list(args)
    for i in range(len(args), n):
        r.append(defaults_tail[i - n + len(defaults_tail)])
    return r

from torch.overrides import enable_reentrant_dispatch

# This file describes how to use wrapper tensors (ala TrivialTensorViaComposition)
# to override autograd from __torch_dispatch__.  Ordinarily,
# __torch_dispatch__ runs after autograd, so you have no way of overriding
# the autograd behavior (since it will be handled after you return).  However,
# if we put the autograd tensor *inside* a wrapper tensor (which doesn't
# itself require gradients), we get a chance to interpose (in __torch_dispatch__)
# before you handle gradients on the inner element.
#
# Note that you can also use __torch_function__ instead to implement this
# functionality, so this is mostly a question of whether or not you want to
# target the public Python API, or the internal ATen operators API
# (torch.ops.aten).


class InnerAutogradTensor(BaseTensor):
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        # Outer tensor's autograd is now disconnected from the inner
        # tensors autograd...
        return super().__new__(cls, elem, requires_grad=False)

    def __init__(self, elem):
        # ... but note that we save the inner tensor, so we can still
        # do autograd on operations on the inside!
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, cls):
                return t.elem
            elif isinstance(t, torch.Tensor) and t.requires_grad:
                # If any other argument at this level does require gradients
                # it will not interact with our inner Tensor and thus this
                # should fail.
                raise RuntimeError("Bad mixup of autograd level")
            else:
                return t

        def wrap(t):
            # Micro-optimization: not necessary to rewrap if the output tensor
            # doesn't require gradients
            if (
                isinstance(t, torch.Tensor)
                and not isinstance(t, cls)
                and t.requires_grad
            ):
                return cls(t)
            else:
                return t

        with enable_reentrant_dispatch():
            # Override gradient behavior
            x = torch.randn(3, requires_grad=True)
            print(x, x + x, addself(x))
            if func == torch.ops.aten.embedding.default:
                args = fill_defaults(args, 5, [-1, False, False])
                weight, indices, padding_idx, scale_grad_by_freq, _sparse = map(
                    unwrap, args
                )
                assert not kwargs
                # Force sparse gradients.  We could have also done this by
                # defining a custom autograd function.
                return cls(func(weight, indices, padding_idx, scale_grad_by_freq, True))

            return tree_map(
                wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
            )


def addself(x):
    return x + x


In [73]:
torch.ops.aten.embedding.default._schema


aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor

In [75]:


input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
weights = torch.rand(10, 3, requires_grad=True)
embedding_matrix = InnerAutogradTensor(weights)
r = torch.ops.aten.embedding.default(embedding_matrix, input)
r.sum().elem.backward()

tensor([-0.8549, -0.8576,  1.1171], requires_grad=True) tensor([-1.7098, -1.7153,  2.2342], grad_fn=<AddBackward0>) tensor([-1.7098, -1.7153,  2.2342], grad_fn=<AddBackward0>)
tensor([0.7289, 1.0692, 0.3923], requires_grad=True) tensor([1.4578, 2.1384, 0.7846], grad_fn=<AddBackward0>) tensor([1.4578, 2.1384, 0.7846], grad_fn=<AddBackward0>)


In [68]:
weights.grad


tensor(indices=tensor([[1, 2, 4, 5, 4, 3, 2, 9]]),
       values=tensor([[1., 1., 1.],
                      [1., 1., 1.],
                      [1., 1., 1.],
                      [1., 1., 1.],
                      [1., 1., 1.],
                      [1., 1., 1.],
                      [1., 1., 1.],
                      [1., 1., 1.]]),
       size=(10, 3), nnz=8, layout=torch.sparse_coo)

# Let's always attach to a export graph node (which has a fake tensor)

In [274]:
import functools
import inspect
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode
from torch.overrides import enable_reentrant_dispatch, TorchFunctionMode, _get_current_function_mode, _get_current_function_mode_stack
from collections import defaultdict
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map
from torch.fx.operator_schemas import (
    _torchscript_schema_to_signature,
)
from torch._guards import detect_fake_mode, active_fake_mode
import torch.utils._pytree as pytree
import typing
import numbers
from typing import *
from collections import OrderedDict
from torch._ops import OpOverload, OpOverloadPacket, _has_script_object_arg

In [544]:
class NormedTensorBase(torch.Tensor):
    _backing_tensor: Optional[torch.Tensor]

    @staticmethod
    def __new__(cls, norm_size: Union[float, torch.Tensor], elem_dims: Optional[Tuple[int, ...]] = None, *,
                backing_tensor: Optional[torch.Tensor] = None, requires_grad: Optional[bool] = None):
        if issubclass(cls.__base__, NormedTensorBase) and cls.__base__ != NormedTensorBase:
            raise TypeError(f"NormedTensorBase can only be subclassed with one level of inheritance")
        if backing_tensor is None:
            backing_tensor = torch.empty((0,))  # this is a placeholder so that _make_wrapper_subclass doesn't fail, will have finalize=False
        else:
            assert type(backing_tensor) in (torch.Tensor, FakeTensor)
        return cls._make_wrapper_subclass(cls, backing_tensor.size(), dtype=backing_tensor.dtype, device=backing_tensor.device,
                                          requires_grad=False)  # NB: false here so that we can use reentrant dispatch on unwrapped normed tensors to get autograd on norms

    def __init__(self, norm_size: Union[float, torch.Tensor], elem_dims: Optional[Tuple[int, ...]] = None, *,
                 backing_tensor: Optional[torch.Tensor] = None, requires_grad: Optional[bool] = None):
        if isinstance(norm_size, torch.Tensor):
            assert requires_grad is None
            self._norm_size = norm_size
        else:
            self._norm_size = torch.full((), norm_size, dtype=torch.float32, requires_grad=requires_grad)
        if backing_tensor is not None:
            # finalized
            if elem_dims is None:
                # default
                elem_dims = tuple(range(backing_tensor.ndim))
            elem_dims = tuple(sorted(d % backing_tensor.ndim for d in elem_dims))
        self._elem_dims = elem_dims
        self._backing_tensor = backing_tensor

    def finalize(self, backing_tensor: torch.Tensor) -> Self:
        assert not self._finalized
        return self.__class__(self._norm_size, elem_dims=self._elem_dims, backing_tensor=backing_tensor)

    def elem_dims_are(self, dims: Iterable[int]) -> bool:
        # FIXME: figure out a good broadcasting API
        assert self._finalized
        return self._elem_dims == tuple(sorted(d % self.ndim for d in dims))

    def same_elem_dims(self, other: 'NormedTensorBase') -> bool:
        # broadcasting
        assert self._finalized and other._finalized
        _ = torch.broadcast_shapes(self.shape, other.shape)
        # convert to negative indexing
        return self.neg_elem_dims == other.neg_elem_dims

    @property
    def _finalized(self):
        return self._backing_tensor is not None

    @property
    def norm_size(self) -> torch.Tensor:
        assert self._finalized
        return self._norm_size

    @property
    def elem_dims(self) -> Tuple[int, ...]:
        assert self._finalized
        return self._elem_dims

    @property
    def neg_elem_dims(self) -> Tuple[int, ...]:
        return tuple(d - self.ndim for d in self._elem_dims)

    @property
    def unwrapped(self) -> torch.Tensor:
        assert self._finalized
        return self._backing_tensor

    def __repr__(self):
        if self._finalized:
            return f"""{self.__class__.__name__}(
    norm_size={self.norm_size!r},
    elem_dims={self.elem_dims!r},
    unwrapped={self.unwrapped!r},
)"""
        else:
            return f"""{self.__class__.__name__}(norm_size={self.norm_size!r}, NOT_FINALIZED)"""


    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        return NotImplemented
        print(f"base cls Dispatch Log: {func}, {types}")
        # with enable_reentrant_dispatch():
        if func in REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP:
            with enable_reentrant_dispatch(), torch.set_grad_enabled(True):
                x = torch.randn(3, requires_grad=True)
                print(x, x + x)
                return REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP[func].normed_dispatcher(*args, **(kwargs or {}))
        return func(*args, **(kwargs or {}))
        return NotImplemented
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        if ENABLE_NORM_DISPATCH and func in HANDLED_FUNCTIONS:
            with enable_reentrant_dispatch():
                return HANDLED_FUNCTIONS[func](*args, **kwargs)
        # for handler, sig in HANDLED_FUNCTIONS.get(func, []):
        #     print(f"Trying {handler}", sig, args, kwargs)
        #     try:
        #         bound = sig.bind(*args, **kwargs)
        #     except TypeError as e:
        #         continue
        #     with enable_reentrant_dispatch():
        #         out = handler(*bound.args, **bound.kwargs)
        #     print(out.norm_size.__class__)
        #     print(out)
        #     print(out.norm_size)
        #     return out
        return NotImplemented


class RMS_NormTensor(NormedTensorBase):
    pass

class RMS_RMS_NormTensor(NormedTensorBase):
    pass

class L1_NormTensor(NormedTensorBase):
    pass

class Linf_NormTensor(NormedTensorBase):
    pass


In [588]:
torch.library.Library('modula', "FRAGMENT")._destroy()

class NormedTensorDispatcher:
    # dispatches things based on the classes of NormTensorBase arguments

    def __init__(self, ref_sig: inspect.Signature, *, ignored_params: Iterable[str] = ()):
        self.ref_sig = ref_sig
        self.ignored_params = tuple(ignored_params)
        self.handled_functions = OrderedDict()
        functools.update_wrapper(self, ref_sig)

        dispatch_key_arg_names = []
        for param in self.ref_sig.parameters.values():
            if inspect.isclass(param.annotation) and issubclass(param.annotation, torch.Tensor) and param.name not in self.ignored_params:
                dispatch_key_arg_names.append(param.name)
        self.dispatch_key_arg_names = tuple(sorted(dispatch_key_arg_names))

    @staticmethod
    def _assert_specialized(ref_sig: inspect.Signature, specialized_sig: inspect.Signature, *,
                            allow_non_normed_tensor_inputs: bool = False):
        try:
            def only_normed_tensor(ty):
                if origin := typing.get_origin(ty):
                    return only_normed_tensor(origin) and all(only_normed_tensor(t) for t in typing.get_args(ty))
                if inspect.isclass(ty) and issubclass(ty, torch.Tensor):
                    return inspect.isclass(ty) and issubclass(ty, NormedTensorBase) # and ty != NormedTensorBase
                return True

            def is_compatible_type(ref_type, specialized_type):
                # print(ref_type, specialized_type)
                if ref_origin := typing.get_origin(ref_type):
                    if not is_compatible_type(typing.get_origin(specialized_type), ref_origin):
                        return False
                    ref_args = typing.get_args(ref_type)
                    specialized_args = typing.get_args(specialized_type)
                    if len(ref_args) != len(specialized_args):
                        return False
                    return all(is_compatible_type(ref_t, specialized_t) for ref_t, specialized_t in zip(ref_args, specialized_args))
                if ref_type == specialized_type:
                    return True
                if specialized_type is typing.Any:
                    return True
                if ref_type is numbers.Number:
                    return specialized_type == float
                if specialized_type in (torch.dtype, torch.layout) and ref_type is int:
                    return True
                if inspect.isclass(ref_type) and inspect.isclass(specialized_type):
                    return issubclass(specialized_type, ref_type)
                return False

            assert set(ref_sig.parameters.keys()) == set(specialized_sig.parameters.keys()), f"Function has a different signature"
            for param_name in ref_sig.parameters.keys():
                ref_param = ref_sig.parameters[param_name]
                specialized_param = specialized_sig.parameters[param_name]
                if not allow_non_normed_tensor_inputs:
                    assert only_normed_tensor(specialized_param.annotation), f"Specialized {specialized_sig} has a non-normed tensor parameter {param_name}"
                assert is_compatible_type(ref_param.annotation, specialized_param.annotation), f"Parameter {param_name} has a different type"

        except AssertionError as e:
            raise TypeError(f"Specialized {specialized_sig} has a different signature from {ref_sig}") from e

    def register(self, specialized_func: Optional[Callable] = None, *, allow_non_normed_tensor_inputs: bool = False):
        def decorator(specialized_func):
            specialized_sig = inspect.signature(specialized_func)
            self._assert_specialized(self.ref_sig, specialized_sig, allow_non_normed_tensor_inputs=allow_non_normed_tensor_inputs)
            dispatch_key = tuple(specialized_sig.parameters[name].annotation for name in self.dispatch_key_arg_names)
            if not allow_non_normed_tensor_inputs:
                assert all(inspect.isclass(t) and issubclass(t, NormedTensorBase) for t in dispatch_key)
            assert dispatch_key not in self.handled_functions
            # print(dispatch_key, specialized_func)
            self.handled_functions[dispatch_key] = specialized_func
            return specialized_func
        if specialized_func is None:
            return decorator
        return decorator(specialized_func)

    def __call__(self, *args, **kwargs):
        bound = self.ref_sig.bind(*args, **kwargs)
        dispatch_key = tuple(bound.arguments[name].__class__ for name in self.dispatch_key_arg_names)
        for k, fn in self.handled_functions.items():
            if all(issubclass(q, k) for q, k in zip(dispatch_key, k)):
                return fn(*args, **kwargs)
        raise NotImplementedError(f"No dispatch rule found for {dispatch_key}")


REG_FAKE_NORM_OP_REGISTRY: Dict[Callable, 'RegFakeNormOp'] = {}
REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP: Dict[Callable, 'RegFakeNormOp'] = {}

class RegFakeNormOp:
    reg_sig: inspect.Signature
    wrapper_custom_op: torch.library.CustomOpDef
    wrapper_custom_op_entrypoint: Callable
    normed_dispatcher: NormedTensorDispatcher

    @property
    def register_norm(self):
        return self.normed_dispatcher.register

    @property
    def register_fake(self):
        return self.wrapper_custom_op.register_fake

    def __init__(self, func: Callable, *, schema: Optional[str] = None, func_prefix: str = 'wrapper'):
        if isinstance(func, OpOverload):
            # for torch lib ops, we need the schema. inspect.signature gives (*args, **kwargs)
            schema = str(func._schema)
            reg_sig = _torchscript_schema_to_signature(func._schema)  # this overwrites the signature if provided
        else:
            if schema is not None:
                reg_sig = _torchscript_schema_to_signature(torch._C.parse_schema(schema))
            else:
                # this may error, so last resort
                reg_sig = inspect.signature(func)

        for param in reg_sig.parameters.values():
            assert param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD), f"Parameter {param.name} is var positional or var keyword"

        # register a new op
        func_name = f"op__{func_prefix}__{(func.__module__ + '.' + func.__qualname__).replace('::', '_').replace('.', '_')}__{id(func)}"
        op_id = f"modula::{func_name}"
        func = functools.partial(func)
        func.__signature__ = reg_sig
        if schema is not None:
            # name it nameless
            nameless_schema = '(' + schema.split('(', 1)[1]
        else:
            nameless_schema = None
        wrapper_custom_op: torch.library.CustomOpDef = torch.library.custom_op(op_id, func, mutates_args=(), schema=nameless_schema)
        wrapper_custom_op.register_fake(func)  # can be modified by self.register_fake

        # def torch_dispatch_wrapper(mode, func, types, args=(), kwargs=None):
        #     kwargs = kwargs or {}
        #     print(f"COP Dispatch Log: {func}, {types}")
        #     with enable_reentrant_dispatch():
        #         x = torch.randn(3, requires_grad=True)
        #         print(x, x + x)
        #         return self.normed_dispatcher(*args, **kwargs)

        # wrapper_custom_op.register_torch_dispatch(RMS_NormTensor, torch_dispatch_wrapper)
        # wrapper_custom_op.register_torch_dispatch(L1_NormTensor, torch_dispatch_wrapper)
        # wrapper_custom_op.register_torch_dispatch(RMS_RMS_NormTensor, torch_dispatch_wrapper)

        self.reg_sig = reg_sig
        self.wrapper_custom_op = wrapper_custom_op
        self.wrapper_custom_op_entrypoint = getattr(torch.ops.modula, func_name).default
        self.normed_dispatcher = NormedTensorDispatcher(reg_sig)
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        return self.wrapper_custom_op(*args, **kwargs)

    # def call_fake(self, *args, fake_mode: Optional[FakeTensorMode] = None, **kwargs):
    #     fake_mode = fake_mode or active_fake_mode()
    #     if fake_mode is None:
    #         fake_mode = FakeTensorMode()
    #     def convert_from_real_tensor(x):
    #         if isinstance(x, torch.Tensor):
    #             return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, x)
    #         return x
    #     # Fakeify some real tensors
    #     with fake_mode:
    #         args = tree_map(convert_from_real_tensor, args)
    #         kwargs = tree_map(convert_from_real_tensor, kwargs)
    #         return self(*args, **kwargs)


class ExportFakeFunctionMode(TorchFunctionMode):
    # Used when exporting, to attach custom ops to the export graph.
    # The resulting graph should only contain `wrapper_custom_op`, .
    # Even ATen core IR ops should be wrapped in `wrapper_custom_op`.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # print(f"Dispatch Log: {func}, {types}")
        kwargs = kwargs or {}
        if func in REG_FAKE_NORM_OP_REGISTRY:
            return REG_FAKE_NORM_OP_REGISTRY[func](*args, **kwargs)
        # if any(issubclass(t, NormTensorBase) for t in types):
        #     return NotImplemented
        return func(*args, **kwargs)


MODULAR_EXPORTING = False

@contextlib.contextmanager
def modula_export():
    global MODULAR_EXPORTING
    assert not MODULAR_EXPORTING, "Cannot nest modula_export"
    MODULAR_EXPORTING = True
    with ExportFakeFunctionMode():
        yield
    MODULAR_EXPORTING = False




def finalize_normed_out(unfinalized_normed_out, fake_out):
    flat_fake_out, fake_out_tree_spec = pytree.tree_flatten(fake_out)
    flat_unfinalized_normed_out, unfinalized_normed_out_tree_spec = pytree.tree_flatten(unfinalized_normed_out)
    assert pytree.treespec_dumps(fake_out_tree_spec) == pytree.treespec_dumps(unfinalized_normed_out_tree_spec), f"Tree spec mismatch"
    return pytree.tree_unflatten(
        [
            normed.finalize(out) for normed, out in zip(flat_unfinalized_normed_out, flat_fake_out)
        ],
        fake_out_tree_spec,
    )


class NormPropagateDispatchMode(TorchDispatchMode):
    # Used when propagating norms on an exported graph, which contains only `wrapper_custom_op`.
    # We handle here instead of `wrapper_custom_op.register_torch_dispatch(exact_type, ...)` because we want to
    # capture all NormedTensorBase subclasses, and don't want to register a dispatch rule for each one.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

    def _call_fake_with_normed_args(self, op: RegFakeNormOp, *args, **kwargs):
        def convert_from_normed_tensor(x):
            if isinstance(x, NormedTensorBase):
                return self.fake_mode.fake_tensor_converter.from_real_tensor(self.fake_mode, x.unwrapped)  # also works on fake tensor
            return x

        with self.fake_mode:
            args = tree_map(convert_from_normed_tensor, args)
            kwargs = tree_map(convert_from_normed_tensor, kwargs)
            return op(*args, **kwargs)

    def __torch_dispatch__(self, func, types, args, kwargs):
        # print(f"Dispatch Log: {func}, {types}", self.enabled,active_fake_mode())
        kwargs = kwargs or {}
        # if not any(issubclass(t, NormedTensorBase) for t in types):
        if func in REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP:  # NB: actual factories like torch.empty won't be in here since this only contains wrapped versions
            # normed mode
            with enable_reentrant_dispatch(), self, torch.set_grad_enabled(True):
                op = REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP[func]
                unfinalized_normed = op.normed_dispatcher(*args, **kwargs)
            fake = self._call_fake_with_normed_args(op, *args, **kwargs)
            return finalize_normed_out(unfinalized_normed, fake)
        # fake or real mode
        assert not any(issubclass(t, NormedTensorBase) for t in types)
        return func(*args, **kwargs)
        return NotImplemented
        # if any(issubclass(t, NormTensorBase) for t in types):
        #     return NotImplemented
        # return func(*args, **kwargs)


@contextlib.contextmanager
def norm_propagate_dispatch():
    with NormPropagateDispatchMode() as mode:
        yield mode


def reg_fake_norm_op(op: Optional[Callable] = None, *, schema: Optional[str] = None, func_prefix: str = 'wrapper') -> RegFakeNormOp:
    def decorator(op):
        if op not in REG_FAKE_NORM_OP_REGISTRY:
            reg_fake_norm_op = RegFakeNormOp(op, schema=schema, func_prefix=func_prefix)
            REG_FAKE_NORM_OP_REGISTRY[op] = reg_fake_norm_op
            REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP[reg_fake_norm_op.wrapper_custom_op_entrypoint] = reg_fake_norm_op
        return REG_FAKE_NORM_OP_REGISTRY[op]
    if op is None:
        return decorator
    return decorator(op)



class ConstantScaler(nn.Module):
    @reg_fake_norm_op(func_prefix='constant_scaler_mul')
    def _mul_with_scaler(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        assert scale.ndim == 0
        return input * scale

    @_mul_with_scaler.register_norm(allow_non_normed_tensor_inputs=True)
    def _(input: NormedTensorBase, scale: torch.Tensor) -> NormedTensorBase:
        assert scale.ndim == 0
        return input.__class__(input.norm_size * scale, elem_dims=input.elem_dims)

    scale: torch.Tensor

    def __init__(self, scale: float):
        super().__init__()
        self.register_buffer('scale', torch.tensor(scale, dtype=torch.float32))

    def forward(self, x):
        return ConstantScaler._mul_with_scaler(x, self.scale)
        if MODULAR_EXPORTING:
            # do something that modula can detect
            return ConstantScaler._OP(x, self.scale)
        return x * self.scale


In [589]:
@reg_fake_norm_op(torch.nn.functional.linear, schema="linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor").register_norm
def linear(input: RMS_NormTensor, weight: RMS_RMS_NormTensor, bias: Optional[RMS_NormTensor] = None) -> RMS_NormTensor:
    assert input.elem_dims_are(dims=(-1,))
    assert weight.elem_dims_are(dims=(-1, -2))
    final_norm_size = input.norm_size * weight.norm_size
    if bias is not None:
        assert bias.elem_dims_are(dims=(-1,))
        final_norm_size += bias.norm_size
    return RMS_NormTensor(final_norm_size, elem_dims=(-1,))

In [590]:
@reg_fake_norm_op(torch.ops.aten.randn.default).register_norm
def randn(size: List[int], *, dtype: Optional[torch.dtype] = None, layout: Optional[torch.layout] = torch.strided, device: Optional[torch.device] = None, pin_memory: Optional[bool] = False) -> RMS_NormTensor:
    return RMS_NormTensor(1, elem_dims=None)

In [591]:
@reg_fake_norm_op(torch.ops.aten.add.Tensor).register_norm
def add(input: RMS_NormTensor, other: RMS_NormTensor, *, alpha: float = 1) -> RMS_NormTensor:
    assert input.same_elem_dims(other)  # FIXME
    return RMS_NormTensor(input.norm_size + other.norm_size * alpha, elem_dims=input.neg_elem_dims)


In [592]:
@reg_fake_norm_op(torch.nn.functional.layer_norm,
                  schema="layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05) -> Tensor").register_norm
def layer_norm(input: RMS_NormTensor, normalized_shape: List[int], weight: Optional[Linf_NormTensor] = None, bias: Optional[RMS_NormTensor] = None, eps: float = 1e-05) -> RMS_NormTensor:
    # assert input.elem_dims_are(dims=normalized_shape)
    # FIXME: this is wrong
    output_norm_size = input.norm_size
    if weight is not None:
        output_norm_size += weight.norm_size
    if bias is not None:
        output_norm_size += bias.norm_size
    return RMS_NormTensor(output_norm_size, elem_dims=input.elem_dims)


In [593]:
torch.ops.aten.relu.default._schema


aten::relu(Tensor self) -> Tensor

In [594]:
@reg_fake_norm_op(torch.ops.aten.relu.default).register_norm
def relu(input: RMS_NormTensor) -> RMS_NormTensor:
    return RMS_NormTensor(input.norm_size / np.sqrt(2), elem_dims=input.elem_dims)


In [595]:
@reg_fake_norm_op(torch.nn.functional.scaled_dot_product_attention,
                  schema="sdpa(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout=0.0, bool is_causal=False) -> Tensor").register_norm
def scaled_dot_product_attention(query: RMS_NormTensor, key: RMS_NormTensor, value: RMS_NormTensor, attn_mask: Optional[RMS_NormTensor] = None,
                                 dropout: float = 0.0, is_causal: bool = False) -> RMS_NormTensor:
    return value


In [596]:
from transformer import GPT

In [597]:
net = GPT(4, 256, 3, 64, 256)

example_input = torch.randint(0, 256, (2, 64)),
with modula_export():
    ep = torch.export.export(
        net,
        example_input,
        dynamic_shapes=[{0: batch}]
    )

    ep = ep.run_decompositions()

gm = ep.module()

In [598]:
print(gm)

GraphModule(
  (lm_head): Module()
  (transformer): Module(
    (wp_embedding): Module()
    (h): Module(
      (0): Module(
        (ln_1): Module()
        (attn): Module(
          (in_proj): Module()
          (out_proj): Module()
        )
        (ln_2): Module()
        (mlp): Module(
          (fc1): Module()
          (fc2): Module()
        )
      )
      (1): Module(
        (ln_1): Module()
        (attn): Module(
          (in_proj): Module()
          (out_proj): Module()
        )
        (ln_2): Module()
        (mlp): Module(
          (fc1): Module()
          (fc2): Module()
        )
      )
      (2): Module(
        (ln_1): Module()
        (attn): Module(
          (in_proj): Module()
          (out_proj): Module()
        )
        (ln_2): Module()
        (mlp): Module(
          (fc1): Module()
          (fc2): Module()
        )
      )
    )
    (ln_f): Module()
  )
)



def forward(self, idx):
    idx, = fx_pytree.tree_flatten_spec(([idx], {}), self._in_sp

In [599]:
from torch.export import Dim, ExportedProgram

batch = Dim('batch')

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(15, 16)
        self.net = nn.Sequential(
            nn.Linear(15, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
        )
        self.scaler = ConstantScaler(2)

    def forward(self, x):
        v = self.scaler(x) + torch.randn(15)
        return self.linear(v) + self.net(v)

net = MyNet()
example_input = torch.randn(10, 15, requires_grad=True),

with modula_export():
    ep: ExportedProgram = torch.export.export(
        net,
        example_input,
        dynamic_shapes=[{0: batch}]
    )

    ep = ep.run_decompositions()

gm = ep.module()
# https://pytorch.org/docs/stable/export.ir_spec.html#node
# gm = aot_module(gm, (torch.randn(2, 15, requires_grad=True),))


In [600]:
print(ep)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_linear_weight: "f32[16, 15]", p_linear_bias: "f32[16]", p_net_0_weight: "f32[16, 15]", p_net_0_bias: "f32[16]", p_net_2_weight: "f32[16, 16]", p_net_2_bias: "f32[16]", p_net_4_weight: "f32[16, 16]", p_net_4_bias: "f32[16]", b_scaler_scale: "f32[]", x: "f32[s0, 15]"):
             # File: /Users/S_sn/miniconda3/lib/python3.11/site-packages/torch/_library/custom_ops.py:669 in __call__, code: return self._opoverload(*args, **kwargs)
            op__constant_scaler_mul____main___constant_scaler__mul_with_scaler__13450329472: "f32[s0, 15]" = torch.ops.modula.op__constant_scaler_mul____main___ConstantScaler__mul_with_scaler__13450329472.default(x, b_scaler_scale);  x = b_scaler_scale = None
            
             # File: /var/folders/ky/gxqpxwvx29ggdsqzf67zslfh0000gn/T/ipykernel_69251/3892711258.py:19 in forward, code: v = self.scaler(x) + torch.randn(15)
            op__wrapper__torch__ops_aten_aten_rand

In [601]:
list(ep.graph_module.graph.nodes)

[p_linear_weight,
 p_linear_bias,
 p_net_0_weight,
 p_net_0_bias,
 p_net_2_weight,
 p_net_2_bias,
 p_net_4_weight,
 p_net_4_bias,
 b_scaler_scale,
 x,
 op__constant_scaler_mul____main___constant_scaler__mul_with_scaler__13450329472,
 op__wrapper__torch__ops_aten_aten_randn__4804563472,
 op__wrapper__torch__ops_aten_aten_add_tensor__4740564240,
 op__wrapper__torch__c__nn_linear__4535862672,
 op__wrapper__torch__c__nn_linear__4535862673,
 op__wrapper__torch__ops_aten_aten_relu__4806405200,
 op__wrapper__torch__c__nn_linear__4535862674,
 op__wrapper__torch__ops_aten_aten_relu__4806405201,
 op__wrapper__torch__c__nn_linear__4535862675,
 op__wrapper__torch__ops_aten_aten_add_tensor__4740564241,
 output]

In [602]:
ep.module_call_graph

[ModuleCallEntry(fqn='', signature=ModuleCallSignature(inputs=[], outputs=[], in_spec=TreeSpec(tuple, None, [TreeSpec(tuple, None, [*]),
   TreeSpec(dict, [], [])]), out_spec=*)),
 ModuleCallEntry(fqn='linear', signature=None),
 ModuleCallEntry(fqn='net', signature=None),
 ModuleCallEntry(fqn='net.0', signature=None),
 ModuleCallEntry(fqn='net.1', signature=None),
 ModuleCallEntry(fqn='net.2', signature=None),
 ModuleCallEntry(fqn='net.3', signature=None),
 ModuleCallEntry(fqn='net.4', signature=None),
 ModuleCallEntry(fqn='scaler', signature=None)]

In [603]:
nodes = list(ep.graph.nodes)

In [606]:
nodes[0].format_node()

'%p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]'

In [607]:
ep.graph_signature.input_specs

[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None),
 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None),
 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_net_0_weight'), target='net.0.weight', persistent=None),
 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_net_0_bias'), target='net.0.bias', persistent=None),
 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_net_2_weight'), target='net.2.weight', persistent=None),
 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_net_2_bias'), target='net.2.bias', persistent=None),
 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_net_4_weight'), target='net.4.weight', persistent=None),
 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_net_4_bias'), target='net.4.bias', persistent=None),
 Inp

In [614]:
ep.graph_signature.input_specs[0].kind.__class__

<enum 'InputKind'>

In [615]:
from torch.export.graph_signature import InputKind
inputs = []
for node, spec in zip(nodes, ep.graph_signature.input_specs):
    if spec.kind == InputKind.PARAMETER:
        target = spec.target
        if target.endswith('.weight'):
            input = RMS_RMS_NormTensor(4, elem_dims=(-1, -2), backing_tensor=node.meta['val'], requires_grad=True)
        elif target.endswith('.bias'):
            input = RMS_NormTensor(2.5, elem_dims=(-1,), backing_tensor=node.meta['val'], requires_grad=True)
        else:
            raise ValueError(f"Unknown parameter: {target}")
    elif spec.kind == InputKind.BUFFER:
        assert spec.target == 'scaler.scale'
        input = net.scaler.scale.clone().requires_grad_(True)
    elif spec.kind == InputKind.USER_INPUT:
        input = RMS_NormTensor(3, elem_dims=(-1,), backing_tensor=node.meta['val'], requires_grad=True)
    inputs.append(input)


In [562]:
for node in nodes:
    print(node, node.meta.keys())
    print('--'*10)


p_linear_weight dict_keys(['val', 'tensor_meta', 'source_fn_stack', 'from_node', 'seq_nr', 'example_value'])
--------------------
p_linear_bias dict_keys(['val', 'tensor_meta', 'source_fn_stack', 'from_node', 'seq_nr', 'example_value'])
--------------------
b_scaler_scale dict_keys(['val', 'tensor_meta', 'from_node', 'seq_nr', 'example_value', 'source_fn_stack'])
--------------------
x dict_keys(['val', 'tensor_meta'])
--------------------
op__constant_scaler_mul____main___constant_scaler__mul_with_scaler__13464146048 dict_keys(['stack_trace', 'nn_module_stack', 'torch_fn', 'source_fn_stack', 'original_aten', 'from_node', 'seq_nr', 'val', 'tensor_meta'])
--------------------
op__wrapper__torch__ops_aten_aten_randn__4804563472 dict_keys(['stack_trace', 'nn_module_stack', 'torch_fn', 'source_fn_stack', 'original_aten', 'from_node', 'seq_nr', 'val', 'tensor_meta'])
--------------------
op__wrapper__torch__ops_aten_aten_add_tensor__4740564240 dict_keys(['stack_trace', 'nn_module_stack', 't

In [563]:
nodes[2].meta

{'val': FakeTensor(..., size=()),
 'tensor_meta': TensorMetadata(shape=torch.Size([]), dtype=torch.float32, requires_grad=False, stride=(), memory_format=torch.contiguous_format, is_quantized=False, qparams={}),
 'from_node': [('op__constant_scaler_mul____main___constant_scaler__mul_with_scaler__13464146048_default',
   <OpOverload(op='modula.op__constant_scaler_mul____main___ConstantScaler__mul_with_scaler__13464146048', overload='default')>)],
 'seq_nr': 9977,
 'example_value': FakeTensor(..., size=()),
 'source_fn_stack': [('op__constant_scaler_mul____main___constant_scaler__mul_with_scaler__13464146048_default',
   <OpOverload(op='modula.op__constant_scaler_mul____main___ConstantScaler__mul_with_scaler__13464146048', overload='default')>)]}

In [564]:
nodes[4].stack_trace

'  File "/var/folders/ky/gxqpxwvx29ggdsqzf67zslfh0000gn/T/ipykernel_69251/3024122329.py", line 12, in forward\n    v = self.scaler(x) + torch.randn(15)\n  File "/var/folders/ky/gxqpxwvx29ggdsqzf67zslfh0000gn/T/ipykernel_69251/3538699669.py", line 287, in forward\n    return ConstantScaler._mul_with_scaler(x, self.scale)\n  File "/var/folders/ky/gxqpxwvx29ggdsqzf67zslfh0000gn/T/ipykernel_69251/3538699669.py", line 150, in __call__\n    return self.wrapper_custom_op(*args, **kwargs)\n  File "/Users/S_sn/miniconda3/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 669, in __call__\n    return self._opoverload(*args, **kwargs)\n'

In [618]:
with norm_propagate_dispatch() as mode:
    out, = ep.graph_module(*inputs)
out


RMS_NormTensor(
    norm_size=tensor(284.0711, grad_fn=<AddBackward0>),
    elem_dims=(1,),
    unwrapped=FakeTensor(..., size=(s0, 16)),
)

In [619]:
torch.autograd.grad(out.norm_size, inputs[-2])

(tensor(108.),)

In [497]:
with norm_propagate_dispatch():
    out = ep.graph_module(
        RMS_RMS_NormTensor(4, elem_dims=(-1, -2), backing_tensor=nodes[0].meta['val'], requires_grad=True),
        RMS_NormTensor(2.5, elem_dims=(-1,), backing_tensor=nodes[1].meta['val'], requires_grad=True),
        net.scaler.scale.requires_grad_(True),
        RMS_NormTensor(3, elem_dims=(-1,), backing_tensor=nodes[3].meta['val'], requires_grad=True),
    )

Dispatch Log: aten.full.default, () True None
Dispatch Log: aten.full.default, () True None
Dispatch Log: aten.full.default, () True None
Dispatch Log: modula.op__constant_scaler_mul____main___ConstantScaler__mul_with_scaler__13464011456.default, (<class '__main__.RMS_NormTensor'>,) True None
Dispatch Log: aten.mul.Tensor, () True None
Dispatch Log: aten.empty.memory_format, () True None
Dispatch Log: modula.op__wrapper__torch__ops_aten_aten_randn__4804563472.default, () True None
Dispatch Log: aten.empty.memory_format, () True None
Dispatch Log: aten.full.default, () True None
Dispatch Log: modula.op__wrapper__torch__ops_aten_aten_add_Tensor__4740564240.default, (<class '__main__.RMS_NormTensor'>,) True None
Dispatch Log: aten.mul.Tensor, () True None
Dispatch Log: aten.add.Tensor, () True None
Dispatch Log: aten.empty.memory_format, () True None
Dispatch Log: modula.op__wrapper__torch__C__nn_linear__4535862672.default, (<class '__main__.RMS_NormTensor'>, <class '__main__.RMS_RMS_Norm

In [455]:
out

(RMS_NormTensor(
     norm_size=tensor(30.5000, grad_fn=<AddBackward0>),
     elem_dims=(1,),
     unwrapped=FakeTensor(..., size=(s0, 16)),
 ),)

In [354]:
nodes[0].meta['norm_val'] = RMS_RMS_NormTensor(4, elem_dims=(-1, -2), backing_tensor=nodes[0].meta['val'], requires_grad=True)
nodes[1].meta['norm_val'] = RMS_NormTensor(2.5, elem_dims=(-1,), backing_tensor=nodes[1].meta['val'], requires_grad=True)
nodes[2].meta['norm_val'] = net.scaler.scale.requires_grad_(True)
# nodes[2].meta['norm_val'] = RMS_NormTensor(3, elem_dims=(-1,), backing_tensor=nodes[2].meta['val'], requires_grad=True)
nodes[3].meta['norm_val'] = RMS_NormTensor(3, elem_dims=(-1,), backing_tensor=nodes[3].meta['val'], requires_grad=True)

In [355]:
import torch.utils._pytree as pytree

# with NormPropagateDispatchMode():
with norm_propagate_dispatch(global_dispatch=True):
    for node in nodes[3:]:
        if isinstance(node.target, str): continue
        flat_out, out_tree_spec = pytree.tree_flatten(node.meta['val'])

        flat_unfinalized_normed_out, unfinalized_normed_out_tree_spec = pytree.tree_flatten(
            node.target(
                *tree_map(lambda x: x.meta['norm_val'] if isinstance(x, torch.fx.Node) else x, node.args),
                **tree_map(lambda x: x.meta['norm_val'] if isinstance(x, torch.fx.Node) else x, node.kwargs)
            )
        )
        # print(flat_unfinalized_normed_out[0]._norm_size)
        assert pytree.treespec_dumps(out_tree_spec) == pytree.treespec_dumps(unfinalized_normed_out_tree_spec), f"Tree spec mismatch"
        node.meta['norm_val'] = pytree.tree_unflatten(
            [
                normed.finalize(out) for normed, out in zip(flat_unfinalized_normed_out, flat_out)
            ],
            out_tree_spec,
        )
        print(node.meta['norm_val'])
        print('--'*10)



Dispatch Log: modula.op__constant_scaler_mul____main___ConstantScaler__mul_with_scaler__13506153504.default, (<class '__main__.RMS_NormTensor'>,)
Dispatch Log: aten.mul.Tensor, ()
Dispatch Log: aten.empty.memory_format, ()
Dispatch Log: modula.op__constant_scaler_mul____main___ConstantScaler__mul_with_scaler__13506153504.default, (<class '__main__.RMS_NormTensor'>,)


AssertionError: 

In [52]:
node.target

<OpOverload(op='modula.op__wrapper__aten_randn__5682461136', overload='default')>

In [26]:
net.scaler.scale

tensor(2., requires_grad=True)

In [112]:
torch.autograd.grad(nodes[-2].meta['norm_val'].norm_size, [net.scaler.scale])

(tensor(12.),)

In [28]:
REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP

{<OpOverload(op='modula.op__constant_scaler_mul__aten_mul_Tensor__5206687248', overload='default')>: <__main__.RegFakeNormOp at 0x13690ca50>,
 <OpOverload(op='modula.op__wrapper__linear__4899357776', overload='default')>: <__main__.RegFakeNormOp at 0x30e135310>,
 <OpOverload(op='modula.op__wrapper__aten_randn__5211397008', overload='default')>: <__main__.RegFakeNormOp at 0x30e137650>,
 <OpOverload(op='modula.op__wrapper__aten_add_Tensor__5202966352', overload='default')>: <__main__.RegFakeNormOp at 0x177e30d10>}

In [32]:
torch.ops.modula.op__constant_scaler_mul__aten_mul_Tensor__5206687248.default in REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP

False

In [30]:
list(REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP.keys())[0]

<OpOverload(op='modula.op__constant_scaler_mul__aten_mul_Tensor__5206687248', overload='default')>

In [566]:
nodes[-2].meta['norm_val']

RMS_NormTensor(
    norm_size=tensor(18.5000),
    elem_dims=(1,),
    unwrapped=FakeTensor(..., size=(s0, 16)),
)

In [537]:
REG_FAKE_NORM_OP_LOOKUP_VIA_CUSTOM_OP

{<OpOverloadPacket(op='modula.op_linear__4760847760')>: <__main__.RegFakeNormOp at 0x32bafd010>,
 <OpOverloadPacket(op='modula.op_aten_randn__5081374544')>: <__main__.RegFakeNormOp at 0x32baff090>,
 <OpOverloadPacket(op='modula.op_aten_add_Tensor__5074027472')>: <__main__.RegFakeNormOp at 0x32bd52350>}

In [379]:
nodes[0].meta['norm_val']

RMS_RMS_NormTensor(
    norm_size=tensor(4., requires_grad=True),
    elem_dims=(0, 1),
    unwrapped=FakeTensor(..., size=(16, 30), requires_grad=True),
)

In [380]:
nodes[1].meta['norm_val']

RMS_NormTensor(
    norm_size=tensor(2.5000, requires_grad=True),
    elem_dims=(0,),
    unwrapped=FakeTensor(..., size=(16,), requires_grad=True),
)

In [381]:
node = nodes[3]
node.meta['norm_val'] = node.target(
    *[arg.meta['norm_val'] if isinstance(arg, torch.fx.Node) else arg for arg in node.args],
    **{k: v.meta['norm_val'] for k, v in node.kwargs.items()}
).finalize(node.meta['val'])

RMS_NormTensor(
    norm_size=tensor(3., requires_grad=True),
    elem_dims=(1,),
    unwrapped=FakeTensor(..., size=(s0, 30)),
) RMS_RMS_NormTensor(
    norm_size=tensor(4., requires_grad=True),
    elem_dims=(0, 1),
    unwrapped=FakeTensor(..., size=(16, 30), requires_grad=True),
) RMS_NormTensor(
    norm_size=tensor(2.5000, requires_grad=True),
    elem_dims=(0,),
    unwrapped=FakeTensor(..., size=(16,), requires_grad=True),
)


In [382]:
nodes[3].meta['norm_val']

RMS_NormTensor(
    norm_size=tensor(14.5000, grad_fn=<AddBackward0>),
    elem_dims=(1,),
    unwrapped=FakeTensor(..., size=(s0, 16)),
)

In [383]:
torch.autograd.grad(nodes[3].meta['norm_val'].norm_size, [nodes[0].meta['norm_val'].norm_size, nodes[1].meta['norm_val'].norm_size, nodes[2].meta['norm_val'].norm_size], allow_unused=True)

(tensor(3.), tensor(1.), tensor(4.))

In [123]:
torch.ops.modula.op_linear__4760847760(torch.randn(10, 15, requires_grad=True), torch.randn(16, 15, requires_grad=True)).shape

torch.Size([10, 16])

In [83]:
torch._C.parse_schema("linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor")

<Signature (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor>

In [66]:
torch.nn.functional.linear

<function torch._C._nn.linear>

In [60]:
inspect.signature(torch.nn.functional.gelu)

ValueError: no signature found for builtin <built-in function gelu>

In [56]:
de@get_fake_norm_op().norm_register
def unsqueeze(input: RMS_NormTensor, dim: int) -> RMS_NormTensor:
    return RMS_NormTensor(input.norm_size, export_graph_node=get_current_export_graph_node())

In [57]:
get_fake_norm_op(test).aten_or_custom_op

<OpOverload(op='module.op_test__6275085504', overload='default')>

In [50]:
torch.library.infer_schema(torch.nn.functional.linear, mutates_args=())

ValueError: no signature found for builtin <built-in function linear>

In [30]:
get_fake_norm_op(torch.ops.aten.unsqueeze.default).aten_or_custom_op

<OpOverload(op='aten.unsqueeze', overload='default')>

In [21]:
@torch.library.custom_op("mylib::foo", mutates_args={})
def foo(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

foo.register_fake(foo)
foo.register_fake(foo)

<CustomOpDef(mylib::foo)>

In [69]:
import functools
import inspect
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from torch.overrides import enable_reentrant_dispatch, TorchFunctionMode
from collections import defaultdict
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map
from typeguard import typechecked
from torch.fx.operator_schemas import (
    _torchscript_schema_to_signature,
)
import typing
import numbers



ENABLE_NORM_DISPATCH = True

@contextlib.contextmanager
def disable_norm_dispatch():
    global ENABLE_NORM_DISPATCH
    old_flag = ENABLE_NORM_DISPATCH
    ENABLE_NORM_DISPATCH = False
    yield
    ENABLE_NORM_DISPATCH = old_flag

HANDLED_FUNCTIONS: Dict[Callable, NormedTensorDispatcher] = {}
#defaultdict(lambda: Dispatcher(ignored_specialized_params=('_export_graph_node',)))


def get_output_fake_tensors(func, *args, **kwargs):
    # Create a fake mode
    fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
    def convert_from_real_tensor(x):
        if isinstance(x, torch.Tensor):
            return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, x)
        return x
    # Fakeify some real tensors
    with fake_mode, disable_norm_dispatch():
        args = tree_map(convert_from_real_tensor, args)
        kwargs = tree_map(convert_from_real_tensor, kwargs)
        return func(*args, **kwargs)


def implements(torch_op):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        # sig = inspect.signature(func)
        # func = typechecked(func)
        if torch_op not in HANDLED_FUNCTIONS:
            HANDLED_FUNCTIONS[torch_op] = NormedTensorDispatcher(torch_op, ignored_params=('_export_graph_node',))
        return HANDLED_FUNCTIONS[torch_op].register(func)
        # assert_specialized(torch_function, func)
        # dispatch_key = get_dispatch_key_func(torch_function, func)
        # functools.update_wrapper(func, torch_function)
        # assert torch_function not in HANDLED_FUNCTIONS
        # assert dispatch_key not in HANDLED_FUNCTIONS[torch_function]
        # HANDLED_FUNCTIONS[torch_function][dispatch_key] = func
        return func
    return decorator


class NormedOpRegularTensorMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if ENABLE_NORM_DISPATCH:
            print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
            if any(issubclass(t, NormTensorBase) for t in types):
                return NotImplemented
        return func(*args, **(kwargs or {}))


class NormedTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        if ENABLE_NORM_DISPATCH:
            print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
            if any(issubclass(t, NormTensorBase) for t in types):
                return NotImplemented
        return func(*args, **(kwargs or {}))




# class SubclassOnce(type):
#     _subclass_depth: int

#     def __new__(cls, name, bases, classdict):
#         restricted_bases = [b for b in bases if isinstance(b, SubclassOnce)]
#         if len(restricted_bases) > 0:
#             subclass_depth = max(b._subclass_depth for b in restricted_bases) + 1
#         else:
#             subclass_depth = 0
#         if subclass_depth > 1:
#             raise TypeError(f"Type {cls.__name__} has multiple base types {bases}")
#         classdict = dict(classdict)
#         assert '_subclass_depth' not in classdict
#         classdict['_subclass_depth'] = subclass_depth
#         return type.__new__(cls, name, bases, classdict)

#     @property
#     def subclassable(cls):
#         return cls._subclass_depth < 1


CURRENT_EXPORT_GRAPH_NODE: Optional[torch.fx.Node] = None

@contextlib.contextmanager
def processing_export_graph_node(node: torch.fx.Node):
    global CURRENT_EXPORT_GRAPH_NODE
    old_node = CURRENT_EXPORT_GRAPH_NODE
    CURRENT_EXPORT_GRAPH_NODE = node
    yield node
    CURRENT_EXPORT_GRAPH_NODE = old_node

def get_current_export_graph_node():
    assert CURRENT_EXPORT_GRAPH_NODE is not None, "No export graph node is being processed"
    return CURRENT_EXPORT_GRAPH_NODE

@implements(torch.ops.aten.unsqueeze.default)
def unsqueeze(input: RMS_NormTensor, dim: int) -> RMS_NormTensor:
    return RMS_NormTensor(input.norm_size, export_graph_node=get_current_export_graph_node())

@implements(torch.ops.aten.squeeze_.dim)
def squeeze(input: RMS_NormTensor, dim: int) -> RMS_NormTensor:
    return RMS_NormTensor(input.norm_size, export_graph_node=get_current_export_graph_node())

@implements(torch.ops.aten.permute.default)
def permute(input: RMS_RMS_NormTensor, dims: List[int]) -> RMS_RMS_NormTensor:
    assert input.ndim == 2
    return RMS_RMS_NormTensor(input.norm_size, export_graph_node=get_current_export_graph_node())

# @t.register
# def t(input: RMSNormTensor) -> RMSRMSNormTensor:
#     assert input.ndim == 2
#     print(input.size(), 'l')
#     return RMSRMSNormTensor(input.norm_size + 2, size=input.size()[::-1], dtype=input.dtype, device=input.device)

@implements(torch.ops.aten.addmm.default)
def addmm(input: RMS_NormTensor, mat1: RMS_NormTensor, mat2: RMS_RMS_NormTensor, *, beta: float = 1, alpha: float = 1) -> RMS_NormTensor:
    # output = input * beta + mat1 @ mat2 * alpha
    final_norm_size = input.norm_size * beta + mat1.norm_size * mat2.norm_size * alpha
    return RMS_NormTensor(final_norm_size, export_graph_node=get_current_export_graph_node())

@addmm.register
def _(input: RMS_NormTensor, mat1: RMS_RMS_NormTensor, mat2: RMS_NormTensor, *, beta: float = 1, alpha: float = 1) -> RMS_NormTensor:
    # output = input * beta + mat1 @ mat2 * alpha
    final_norm_size = input.norm_size * beta + mat1.norm_size * mat2.norm_size * alpha
    return RMS_NormTensor(final_norm_size, export_graph_node=get_current_export_graph_node())

@implements(torch.ops.aten.mm.default)
def mm(input: RMS_NormTensor, mat2: RMS_RMS_NormTensor) -> RMS_NormTensor:
    final_norm_size = input.norm_size * mat2.norm_size
    return RMS_NormTensor(final_norm_size, export_graph_node=get_current_export_graph_node())



In [None]:
# real
# fake
# normed

In [71]:
from torch._functorch.aot_autograd import aot_module_simplified
from torch.export import Dim

batch = Dim('batch', min=10)

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(15, 16)

    def forward(self, x):
        return self.linear(x.split(15, dim=1)[0])

ep = torch.export.export(
    nn.Linear(30, 16),
    # MyNet(),
    (torch.randn(10, 30, requires_grad=True),),
    dynamic_shapes=[{0: batch}]
)

# ep = ep.run_decompositions()

gm = ep.module()
# https://pytorch.org/docs/stable/export.ir_spec.html#node
# gm = aot_module(gm, (torch.randn(2, 15, requires_grad=True),))


In [72]:
print(gm)

GraphModule()



def forward(self, input):
    input, = fx_pytree.tree_flatten_spec(([input], {}), self._in_spec)
    weight = self.weight
    bias = self.bias
    input_1 = input
    linear = torch.ops.aten.linear.default(input_1, weight, bias);  input_1 = weight = bias = None
    return pytree.tree_unflatten((linear,), self._out_spec)
    
# To see more debug info, please use `graph_module.print_readable()`


In [73]:
nodes = list(ep.graph.nodes)
[
    node.format_node()
    for node in nodes
]


['%p_fn_weight : [num_users=1] = placeholder[target=p_fn_weight]',
 '%p_fn_bias : [num_users=1] = placeholder[target=p_fn_bias]',
 '%input : [num_users=1] = placeholder[target=input]',
 '%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%input, %p_fn_weight, %p_fn_bias), kwargs = {})',
 'return (linear,)']

In [74]:
_torchscript_schema_to_signature(
    torch.ops.aten.addmm.default._schema
)

<Signature (input: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta: numbers.Number = 1, alpha: numbers.Number = 1) -> torch.Tensor>

In [79]:
nodes[0].meta['norm_val'] = RMS_RMS_NormTensor(4, backing_tensor=nodes[0].meta['val'])
nodes[1].meta['norm_val'] = RMS_NormTensor(2.5, backing_tensor=nodes[1].meta['val'])
nodes[2].meta['norm_val'] = RMS_NormTensor(3, backing_tensor=nodes[2].meta['val'])

In [88]:
nodes[2].meta['val'].numel()

300

In [81]:
with NormedOpRegularTensorMode(), NormedTensorMode(), processing_export_graph_node(nodes[3]) as node:
    node.meta['norm_val'] = node.target(
        *[arg.meta['norm_val'] if isinstance(arg, torch.fx.Node) else arg for arg in node.args],
        **{k: v.meta['norm_val'] for k, v in node.kwargs.items()}
    )


Dispatch Log: aten.linear.default(*(RMS_NormTensor(
    norm_size=tensor(3.),
    unwrapped=FakeTensor(..., size=(s0, 30), requires_grad=True),
), RMS_RMS_NormTensor(
    norm_size=tensor(4.),
    unwrapped=FakeTensor(..., size=(16, 30), requires_grad=True),
), RMS_NormTensor(
    norm_size=tensor(2.5000),
    unwrapped=FakeTensor(..., size=(16,), requires_grad=True),
)), **{})
Dispatch Log: aten.t.default(*(RMS_RMS_NormTensor(
    norm_size=tensor(4.),
    unwrapped=FakeTensor(..., size=(16, 30), requires_grad=True),
),), **{})


TypeError: Multiple dispatch failed for 'torch._ops.aten.t.default'; all __torch_dispatch__ handlers returned NotImplemented:

  - tensor subclass <class '__main__.RMS_RMS_NormTensor'>

For more information, try re-running with TORCH_LOGS=not_implemented

In [509]:
node.meta['norm_val']

RMS_RMS_NormTensor(
    norm_size=tensor(4.),
    fake=FakeTensor(..., size=(15, 16)),
    node=(
        %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_linear_weight, [1, 0]), kwargs = {})
    ),
)

In [338]:
node.args

(p_fn_bias, input, permute)

In [331]:
print(list(ep.graph.nodes))

node: torch.fx.node.Node
for node in ep.graph.nodes:
    print(node.format_node())

[p_fn_weight, p_fn_bias, input, permute, addmm, output]
%p_fn_weight : [num_users=1] = placeholder[target=p_fn_weight]
%p_fn_bias : [num_users=1] = placeholder[target=p_fn_bias]
%input : [num_users=1] = placeholder[target=input]
%permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_fn_weight, [1, 0]), kwargs = {})
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_fn_bias, %input, %permute), kwargs = {})
return (addmm,)


In [322]:
node.args[0][0].meta['val'] = 34

In [332]:
list(ep.graph.nodes)[-2].meta['val']

FakeTensor(..., size=(s0, 16))

In [302]:
list(ep.graph.nodes)[-3].meta['val']

FakeTensor(..., size=(15, 16))

Dispatch Log: aten.full.default(*([], 3), **{'dtype': torch.float32, 'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.full.default(*([], 4), **{'dtype': torch.float32, 'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.t.default(*(RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True)),), **{})
Dispatch Log: aten.t.default(*(RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True)),), **{})
Dispatch Log: aten.unsqueeze.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), 0), **{})
Dispatch Log: aten.unsqueeze.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), 0), **{})
Dispatch Log: aten.mm.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True))), **{})
Dispatch Log: aten.mm.default(*(RMS_NormTensor(norm_size=tensor(3., requires_grad=True)), RMS_RMS_NormTensor(norm_size=tensor(4., requires_grad=True))), **{})
Dispatch Log: aten.mul.T

In [127]:
y = (x:=RMSNormTensor(3, requires_grad=True)) + RMSRMSNormTensor(4)
print(y, y.norm_size)

Dispatch Log: aten.add.Tensor(*(RMSNormTensor(norm_size=3, requires_grad=False), RMSRMSNormTensor(norm_size=4, requires_grad=False)), **{})
tensor(7., grad_fn=<AddBackward0>)
z: RMSNormTensor(norm_size=7, requires_grad=False) tensor(7., grad_fn=<AddBackward0>) [tensor(3., requires_grad=True), tensor(4.)] tensor(7., grad_fn=<AddBackward0>)
RMSNormTensor(norm_size=7, requires_grad=False) tensor(7., grad_fn=<AddBackward0>)


In [96]:
torch.autograd.grad(y.norm_size, [x.norm_size])

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [72]:
y = (x:=RMSNormTensor(3, requires_grad=True)) + RMSRMSNormTensor(4)
print(y, y.norm_size)
torch.autograd.backward([y])
y, x.grad



Dispatch Log: aten.add.Tensor(*(RMSNormTensor(norm_size=3.0), RMSRMSNormTensor(norm_size=4.0)), **{})
RMSNormTensor(norm_size=7.0) tensor(7.)
Dispatch Log: aten.ones_like.default(*(RMSNormTensor(norm_size=7.0),), **{'pin_memory': False, 'memory_format': torch.preserve_format})


ZeroDivisionError: division by zero

In [60]:
y%debu

In [33]:

torch.ops.aten.linear(torch.randn(10, 10), torch.randn(10, 10), torch.randn(10))

tensor([[-2.9366, -2.1385,  0.2499,  0.1769,  3.3125,  0.5611,  0.2110, -0.9544,
         -1.4054, -1.0516],
        [-1.3193,  2.8552, -2.4230,  2.6803,  4.9918,  3.8054, -1.6162, -3.6952,
         -3.6532, -1.4474],
        [-1.9989,  3.0092, -8.1097,  0.5183,  2.0837, -0.9222,  0.3195, -5.3470,
         -3.0957,  2.3705],
        [ 5.2813,  2.8777,  0.0169, -0.6770, -3.8537,  2.0995, -3.8720, -1.3664,
         -2.2780,  2.0936],
        [-0.3890, -0.7412,  2.1590,  0.5803, -0.2978,  0.4573, -0.6891,  0.1847,
          0.4717, -1.0158],
        [-4.2937,  4.5814,  0.6280,  0.9427, -4.7592, -3.3714, -5.1919, -6.9244,
         -0.7811,  4.4841],
        [-3.6631, -4.5173,  1.1486, -3.5798, -0.8339, -1.9358,  3.3141,  1.8897,
          1.1224,  1.5577],
        [ 2.7387, -2.0768,  0.9836, -0.7142,  3.7925, -3.0117, -0.3672, -1.0644,
          4.0226, -2.9036],
        [-3.9206,  0.3760,  3.8509, -0.3813, -1.1643,  0.2692, -2.0817, -1.4092,
         -0.0249,  0.6661],
        [ 7.7065,  