# [JIT] Infinite RecursionError with self-referential models (also affects `__repr__`) #76775

## https://github.com/pytorch/pytorch/issues/76775

In [None]:
from typing import Final, Optional

import torch
from torch import Tensor, jit, nn

In [None]:
class ConcatProjection(nn.Module):
    r"""Maps `z = [x,w] ⟼ x`.

    Forward mode keeps the first ``hidden_size`` features of the input
    (``_forward``). Inverse mode broadcasts the ``padding`` vector over the
    leading dimensions and concatenates it along the last one (``_inverse``).
    When the boolean ``inverted`` buffer is set, ``forward`` and ``inverse``
    swap roles.

    Args:
        input_size: Dimensionality ``d+e`` of the concatenated vector ``z``.
        hidden_size: Dimensionality ``d`` of the projected vector ``x``.
        padding: Initial padding. It is wrapped in an ``nn.Parameter`` if it
            is not one already. Defaults to
            ``torch.randn(input_size - hidden_size)``.
        inverted: Initial value of the ``inverted`` buffer.

    Raises:
        ValueError: If ``hidden_size`` is negative or larger than ``input_size``.
    """

    # Constants
    input_size: Final[int]
    r"""CONST: The dimensionality of the inputs."""
    hidden_size: Final[int]
    r"""CONST: The dimensionality of the outputs."""
    pad_size: Final[int]
    r"""CONST: The size of the padding."""

    # Parameters
    padding: Tensor
    r"""PARAM: The padding vector."""

    # Buffers
    inverted: Tensor
    r"""BUFFER: Whether the module runs in forward or reverse mode."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        padding: Optional[Tensor] = None,
        inverted: bool = False,
    ) -> None:
        super().__init__()

        if input_size < hidden_size:
            raise ValueError(
                f"ConcatProjection requires {input_size=} ≥ {hidden_size=}!"
            )
        # With a negative hidden_size, `Z[..., :hidden_size]` would silently
        # mean "drop the last |hidden_size| features" instead of failing.
        if hidden_size < 0:
            raise ValueError(f"ConcatProjection requires {hidden_size=} ≥ 0!")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.pad_size = input_size - hidden_size

        if padding is None:
            padding = nn.Parameter(torch.randn(self.pad_size))
        elif not isinstance(padding, nn.Parameter):
            padding = nn.Parameter(padding)

        self.register_parameter("padding", padding)
        self.register_buffer("inverted", torch.tensor(inverted, dtype=torch.bool))

    # NOTE: kept commented out on purpose. `yaya` returns `self` (the
    # self-referential case from the issue title), and `__invert__` flips the
    # `inverted` buffer. The `~model`, `model.yaya` and
    # `loaded_model.__invert__()` cells need these methods re-enabled.
    #     @jit.export
    #     def yaya(self):
    #         return self

    #     @jit.export
    #     def __invert__(self) -> None:
    #         self.inverted  = ~self.inverted

    @jit.export
    def _forward(self, Z: Tensor) -> Tensor:
        r"""Signature: `[..., d+e] ⟶ [..., d]`. Keeps the first `d` features."""
        return Z[..., : self.hidden_size]

    @jit.export
    def _inverse(self, X: Tensor) -> Tensor:
        r"""Signature: `[..., d] ⟶ [..., d+e]`. Appends the padding vector."""
        # Broadcast the padding over every leading (batch) dimension of X.
        shape = list(X.shape[:-1]) + [self.pad_size]
        return torch.cat([X, self.padding.expand(shape)], dim=-1)

    # `forward` is compiled by TorchScript implicitly, so no `@jit.export`.
    def forward(self, Z: Tensor) -> Tensor:
        r"""Signature: `[..., d+e] ⟶ [..., d]` (reversed when `inverted` is set)."""
        if self.inverted:
            return self._inverse(Z)
        return self._forward(Z)

    @jit.export
    def inverse(self, X: Tensor) -> Tensor:
        r"""Signature: `[..., d] ⟶ [..., d+e]` (reversed when `inverted` is set)."""
        if self.inverted:
            return self._forward(X)
        return self._inverse(X)

In [None]:
model = ConcatProjection(3, 2)
scripted_model = jit.script(model)

# Attributes present on the eager module but missing after scripting.
eager_attrs = set(dir(model))
print(eager_attrs - set(dir(scripted_model)))

# Round-trip through disk, then report what the reloaded module lacks.
jit.save(scripted_model, "model.pt")
loaded_model = jit.load("model.pt")
print(eager_attrs - set(dir(loaded_model)))

In [None]:
from typing import Final

from torch import Tensor, jit, nn


class Foo(nn.Module):
    """Minimal module with a ``Final`` constant and one extra exported method.

    The following cell compares which of its attributes survive ``jit.script``
    and a ``jit.save`` / ``jit.load`` round trip.
    """

    const: Final[bool]
    """Some important CONSTANT"""

    def __init__(self, const: bool) -> None:
        super().__init__()
        self.const = const

    # `forward` is compiled by TorchScript implicitly, so no `@jit.export`.
    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` unchanged."""
        return x

    @jit.export
    def exported_method(self, x: Tensor) -> Tensor:
        """Return ``x`` unchanged; compiled because of ``@jit.export``."""
        return x


model = Foo(const=False)
scripted_model = jit.script(model)
jit.save(scripted_model, "model.pt")
loaded_model = jit.load("model.pt")

# Report the eager-only attributes for the scripted, then the reloaded, module.
for other in (scripted_model, loaded_model):
    print(set(dir(model)) - set(dir(other)))

# Every variant must still expose the constant, both methods, and `training`.
required = ("const", "forward", "exported_method", "training")
for candidate in (model, scripted_model, loaded_model):
    assert all(hasattr(candidate, name) for name in required)

In [None]:
# Display the repr of the module returned by `jit.load`. Per the issue title,
# `__repr__` is also affected by the RecursionError for self-referential models.
loaded_model

In [None]:
# Show the `training` flag of `model`. When the cells run top to bottom,
# `model` is the eager `Foo` instance created in the previous cell.
model.training

In [None]:
# Flip, then inspect, the `inverted` buffer of the reloaded module.
# NOTE(review): this only works when `loaded_model` is a reloaded
# ConcatProjection whose commented-out `__invert__` export has been re-enabled.
# After the `Foo` cell, `loaded_model` is a reloaded `Foo`, which has neither
# `__invert__` nor `inverted`.
loaded_model.__invert__()
loaded_model.inverted

In [None]:
# Unary `~` dispatches to `__invert__`. As written, neither `Foo` nor
# ConcatProjection (whose `__invert__` is commented out) defines it, so this
# raises TypeError.
~model

In [None]:
# Access the `yaya` export, which returns `self` (the self-referential case).
# It is commented out in ConcatProjection and absent from `Foo`, so as written
# this raises AttributeError.
model.yaya