In [None]:
# Notebook setup: output display, figure format and module auto-reloading.
# 'last_expr_or_assign' also displays the value of a trailing assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Re-import edited modules (e.g. linodenet) automatically before executing code.
%load_ext autoreload
%autoreload 2
%matplotlib inline

## LinODEnet v2


We add special treatment of covariates $u$, which enter the dynamics as an inhomogeneous (forcing) term.

Inhomogeneous Linear Equation

$$ \dot{x}(t) = Ax(t) + f(t) $$


E.g. Linear State Space System

$$ \dot{x}(t) = Ax(t) + Bu(t) $$

Solution:

$$\begin{aligned} 
     x(t) &=  e^{A(t-t_0)}x_{t_0} + \int_{t_0}^t e^{A(t-s)} f(s) ds 
\\   x(t+∆t) &=  e^{A∆t}x_t + \int_{t}^{t+∆t} e^{A(t+∆t-s)} f(s) ds 
\\   x(t+∆t) &=  e^{A∆t}x_t + \int_{0}^{∆t} e^{A(∆t-∆τ)} f(t+{∆τ}) d{∆τ}  \qquad s=t+∆τ
\end{aligned}$$


Special cases:

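1. $f(t) = b$ constant in $t$: then $x(t+∆t)=e^{A∆t}x_t + \frac{e^{A∆t}-𝕀}{A}b$, where $\frac{M}{A}$ denotes $A^{-1}M$.
2. $f(t) = a⋅t + b$ linear in $t$: then $x(t+∆t)=e^{A∆t}x_t + ∆t\,φ_1(A∆t)\,(a⋅t+b) + ∆t^2\,φ_2(A∆t)\,a$, with the $φ_k$-functions defined below.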

Generalized Exponential Integral:


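$$ E_n(x) = ∫_1^∞ \frac{e^{-xt}}{t^n}\, dt = x^{n-1} Γ(1-n,x)$$

Misra function: $φ_m(x) = E_{-m}(x)$. It is not to be confused with the matrix $φ_k$-functions of exponential integrators used below.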

Block matrix trick:

$$ \exp \bigg(\begin{bmatrix}A& B \\ 0 & 0 \end{bmatrix}⋅t\bigg) 
= \begin{bmatrix}e^{At} & ∫_0^t e^{Aτ}Bdτ \\ 0 & 𝕀 \end{bmatrix}
= \begin{bmatrix}e^{At} & t\,φ_1(At)B \\ 0 & 𝕀 \end{bmatrix}
$$ 



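Formula (Higham): Taylor-expand the forcing term as $f(t+∆τ) = ∑_{k≥1} \frac{∆τ^{k-1}}{(k-1)!}𝐃^{k-1}f(t)$, where $𝐃 = \frac{d}{dt}$. Then use $∫_0^{∆t} e^{A(∆t-∆τ)}\frac{∆τ^{k-1}}{(k-1)!}\,d∆τ = ∆t^k φ_k(A∆t)$ to obtain


$$ x(t+∆t) = e^{A∆t}x_t + ∑_{k=1}^∞ φ_k(A∆t) u_k ∆t^k \qquad  u_k = 𝐃^{k-1}f(t)$$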

Recursion:  $$ \varphi_{\ell}(z)=z \varphi_{\ell+1}(z)+\frac{1}{\ell !}, \quad \varphi_{0}(z)=e^{z}$$

So: $$φ_1(At) = \frac{e^{At}-𝕀}{At} \qquad φ_2(At) = \frac{e^{At} -At -𝕀}{(At)^2} \qquad φ_3(At) = \frac{e^{At}-½(At)^2 -At -𝕀}{(At)^3} $$

Truncate after $k=2$, i.e. assume $f$ is affine on the step, $f(t+∆τ) = b + ω⋅∆τ$. Then $u_1 = b$ and $u_2 = ω$:


$$ x(t+∆t) = e^{A∆t}x_t + \frac{e^{A∆t}-𝕀}{A {∆t}}{∆t} b +  \frac{e^{A∆t}-A{∆t}-𝕀}{(A {∆t})^2}{∆t}^2ω$$

## Remark

The block matrix formula shows that, in some sense, we do not need to encode the predictions and the covariates separately.
A larger latent space suffices to model the same dynamics as an inhomogeneous linear ODE, including polynomial forcing terms.


### Theorem:

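Let $p(t)$ be a polynomial of degree $d$. Define the augmented state $z = (x, p, \dot p, …, p^{(d)})$, so that $x$ is the first block of $z$ and $z(t_0)$ is built from $x(t_0)$ and the derivatives of $p$ at $t_0$. Then:

$$ \dot{x} = Ax(t) + p(t) ⟺ \dot{z} = \tilde{A} z(t), \qquad \tilde{A} = \begin{bmatrix} A & 𝕀 & & & \\ & 0 & 𝕀 & & \\ & & ⋱ & ⋱ & \\ & & & 0 & 𝕀 \\ & & & & 0 \end{bmatrix}$$ 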



In [None]:
import linodenet
from linodenet.models.system import LinODECell
from torchinfo import summary

In [None]:
import logging
from typing import Any, Final, Optional, TypeVar

import torch
from linodenet.initializations.functional import FunctionalInitialization
from linodenet.models.embeddings import ConcatEmbedding, ConcatProjection
from linodenet.models.encoders import iResNet
from linodenet.models.filters import Filter, RecurrentCellFilter
from linodenet.models.system import LinODECell
from linodenet.projections import Projection
from linodenet.util import autojit, deep_dict_update, initialize_from_config
from torch import Tensor, jit, nn

__logger__ = logging.getLogger(__name__)


# @autojit
class LinODEnet(nn.Module):
    r"""Linear ODE Network is a FESD model.

    +---------------------------------------------------+--------------------------------------+
    | Component                                         | Formula                              |
    +===================================================+======================================+
    | Filter  `F` (default: :class:`~torch.nn.GRUCell`) | `\hat x_i' = F(\hat x_i, x_i)`       |
    +---------------------------------------------------+--------------------------------------+
    | Encoder `ϕ` (default: :class:`~iResNet`)          | `\hat z_i' = ϕ(\hat x_i')`           |
    +---------------------------------------------------+--------------------------------------+
    | System  `S` (default: :class:`~LinODECell`)       | `\hat z_{i+1} = S(\hat z_i', Δ t_i)` |
    +---------------------------------------------------+--------------------------------------+
    | Decoder `π` (default: :class:`~iResNet`)          | `\hat x_{i+1}  =  π(\hat z_{i+1})`   |
    +---------------------------------------------------+--------------------------------------+

    Attributes
    ----------
    input_size:  int
        The dimensionality of the input space.
    hidden_size: int
        The dimensionality of the latent space.
    output_size: int
        The dimensionality of the output space (always equal to ``input_size``).
    zero: Tensor
        BUFFER: A constant tensor of value float(0.0)
    timedeltas: Tensor
        BUFFER: Stores the timedelta values of the last forward pass.
    xhat_pre: Tensor
        BUFFER: Stores pre-jump values.
    xhat_post: Tensor
        BUFFER: Stores post-jump values.
    zhat_pre: Tensor
        BUFFER: Stores pre-jump latent values.
    zhat_post: Tensor
        BUFFER: Stores post-jump latent values.
    kernel: Tensor
        PARAM: The system matrix of the linear ODE component (taken from ``system.kernel``).
    z0: Tensor
        PARAM: The initial latent state.
    encoder: nn.Module
        MODULE: Responsible for encoding `x̂→ẑ` (applied after ``embedding``).
    embedding: nn.Module
        MODULE: Responsible for embedding `x̂→ẑ`.
    system: nn.Module
        MODULE: Responsible for propagating `ẑ_t→ẑ_{t+∆t}`.
    decoder: nn.Module
        MODULE: Responsible for decoding `ẑ→x̂` (applied before ``projection``).
    projection: nn.Module
        MODULE: Responsible for projecting `ẑ→x̂`.
    filter: nn.Module
        MODULE: Responsible for updating `(x̂, x_obs) →x̂'`.
    """

    name: Final[str] = __name__
    """str: The name of the model."""

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__doc__": __doc__,
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": int,
        "hidden_size": int,
        "output_size": int,
        "System": LinODECell.HP,
        "Embedding": ConcatEmbedding.HP,
        "Projection": ConcatProjection.HP,
        "Filter": RecurrentCellFilter.HP | {"autoregressive": True},
        "Encoder": iResNet.HP,
        "Decoder": iResNet.HP,
    }
    r"""Dictionary of Hyperparameters."""

    # Constants
    input_size: Final[int]
    r"""CONST: The dimensionality of the inputs."""
    hidden_size: Final[int]
    r"""CONST: The dimensionality of the linear ODE."""
    output_size: Final[int]
    r"""CONST: The dimensionality of the outputs."""

    # Buffers
    zero: Tensor
    r"""BUFFER: A tensor of value float(0.0)"""
    xhat_pre: Tensor
    r"""BUFFER: Stores pre-jump values."""
    xhat_post: Tensor
    r"""BUFFER: Stores post-jump values."""
    zhat_pre: Tensor
    r"""BUFFER: Stores pre-jump latent values."""
    zhat_post: Tensor
    r"""BUFFER: Stores post-jump latent values."""
    timedeltas: Tensor
    """BUFFER: Stores the timedelta values."""

    # Parameters:
    kernel: Tensor
    r"""PARAM: The system matrix of the linear ODE component."""
    z0: Tensor
    r"""PARAM: The initial latent state."""

    # Sub-Modules
    # encoder: Any
    # r"""MODULE: Responsible for embedding `x̂→ẑ`."""
    # embedding: nn.Module
    # r"""MODULE: Responsible for embedding `x̂→ẑ`."""
    # system: nn.Module
    # r"""MODULE: Responsible for propagating `ẑ_t→ẑ_{t+∆t}`."""
    # decoder: nn.Module
    # r"""MODULE: Responsible for projecting `ẑ→x̂`."""
    # projection: nn.Module
    # r"""MODULE: Responsible for projecting `ẑ→x̂`."""
    # filter: nn.Module
    # r"""MODULE: Responsible for updating `(x̂, x_obs) →x̂'`."""

    def __init__(self, input_size: int, hidden_size: int, **HP: Any):
        r"""Build all components from the merged hyperparameter configuration.

        Parameters
        ----------
        input_size: int
            Dimensionality of the observations (also used as ``output_size``).
        hidden_size: int
            Dimensionality of the latent linear ODE.
        HP: Any
            Overrides that are merged into :attr:`HP` via ``deep_dict_update``.
        """
        super().__init__()
        self.CFG = HP = deep_dict_update(self.HP, HP)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = input_size

        # Wire the component sizes together: encoder/decoder/system live in latent
        # space, filter in observation space, embedding/projection bridge the two.
        HP["Encoder"]["input_size"] = hidden_size
        HP["Decoder"]["input_size"] = hidden_size
        HP["System"]["input_size"] = hidden_size
        HP["Filter"]["hidden_size"] = input_size
        HP["Filter"]["input_size"] = input_size
        HP["Embedding"]["input_size"] = input_size
        HP["Embedding"]["hidden_size"] = hidden_size
        HP["Projection"]["input_size"] = input_size
        HP["Projection"]["hidden_size"] = hidden_size

        # if HP["embedding_type"] == "linear":
        #     _embedding: nn.Module = nn.Linear(input_size, hidden_size)
        #     _projection: nn.Module = nn.Linear(hidden_size, input_size)
        # elif HP["embedding_type"] == "concat":
        #     _embedding = ConcatEmbedding(input_size, hidden_size)
        #     _projection = ConcatProjection(input_size, hidden_size)
        # else:
        #     raise NotImplementedError(
        #         f"{HP['embedding_type']=}" + "not in {'linear', 'concat'}"
        #     )

        # TODO: replace with add_module once supported!
        # self.add_module("embedding", _embedding)
        # self.add_module("encoder", HP["Encoder"](**HP["Encoder_cfg"]))
        # self.add_module("system", HP["System"](**HP["System_cfg"]))
        # self.add_module("decoder", HP["Decoder"](**HP["Decoder_cfg"]))
        # self.add_module("projection", _projection)
        # self.add_module("filter", HP["Filter"](**HP["Filter_cfg"]))

        # Each log line names the component being built and its own config.
        __logger__.debug("%s Initializing Embedding %s", self.name, HP["Embedding"])
        self.embedding: nn.Module = initialize_from_config(HP["Embedding"])
        __logger__.debug("%s Initializing Projection %s", self.name, HP["Projection"])
        self.projection: nn.Module = initialize_from_config(HP["Projection"])
        __logger__.debug("%s Initializing Encoder %s", self.name, HP["Encoder"])
        self.encoder: nn.Module = initialize_from_config(HP["Encoder"])
        __logger__.debug("%s Initializing System %s", self.name, HP["System"])
        self.system: nn.Module = initialize_from_config(HP["System"])
        __logger__.debug("%s Initializing Decoder %s", self.name, HP["Decoder"])
        self.decoder: nn.Module = initialize_from_config(HP["Decoder"])
        __logger__.debug("%s Initializing Filter %s", self.name, HP["Filter"])
        self.filter: Filter = initialize_from_config(HP["Filter"])

        # Expose the system matrix directly on the model.
        assert isinstance(self.system.kernel, Tensor)
        self.kernel = self.system.kernel
        self.z0 = nn.Parameter(torch.randn(self.hidden_size))

        # Buffers (non-persistent: not saved in the state_dict).
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)
        self.register_buffer("timedeltas", torch.tensor(()), persistent=False)
        self.register_buffer("xhat_pre", torch.tensor(()), persistent=False)
        self.register_buffer("xhat_post", torch.tensor(()), persistent=False)
        self.register_buffer("zhat_pre", torch.tensor(()), persistent=False)
        self.register_buffer("zhat_post", torch.tensor(()), persistent=False)

    @jit.export
    def forward(self, T: Tensor, X: Tensor) -> Tensor:
        r"""Signature: `[...,N]×[...,N,d] ⟶ [...,N,d]`.

        **Model Sketch**::

            ⟶ [ODE] ⟶ (ẑᵢ)                (ẑᵢ') ⟶ [ODE] ⟶
                       ↓                   ↑
                      [Ψ]                 [Φ]
                       ↓                   ↑
                      (x̂ᵢ) → [ filter ] → (x̂ᵢ')
                                 ↑
                              (tᵢ, xᵢ)

        Parameters
        ----------
        T: Tensor, shape=(...,LEN)
            The timestamps of the observations. The first step propagates ``z0``
            from time 0 to ``T[..., 0]``.
        X: Tensor, shape=(...,LEN,DIM)
            The observed, noisy values at times `t∈T`. Use ``NaN`` to indicate missing values.

        Returns
        -------
        X̂_post: Tensor, shape=(...,LEN,DIM)
            The estimated true state of the system at the times `t⁺∈T` (post-update).
            The pre-update estimates, the latent trajectories and the time deltas are
            not returned. They are stored in the buffers ``xhat_pre``, ``zhat_pre``,
            ``zhat_post`` and ``timedeltas``.

        References
        ----------
        - https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
        """
        BATCH_SIZE = X.shape[:-2]
        # Prepend t=0, so the first time delta is T[..., 0] - 0.
        pad_dim = list(BATCH_SIZE) + [1]
        pad = torch.zeros(pad_dim, device=T.device, dtype=T.dtype)
        DT = torch.diff(T, prepend=pad, dim=-1)  # (..., LEN) → (..., LEN)
        DT = DT.moveaxis(-1, 0)  # (..., LEN) → (LEN, ...)
        X = torch.moveaxis(X, -2, 0)  # (...,LEN,DIM) → (LEN,...,DIM)

        Zhat_pre: list[Tensor] = []
        Xhat_pre: list[Tensor] = []
        Xhat_post: list[Tensor] = []
        Zhat_post: list[Tensor] = []

        ẑ_post = self.z0

        for dt, x_obs in zip(DT, X):
            # Propagate the latent state forward in time.
            ẑ_pre = self.system(dt, ẑ_post)  # (...,), (...,LAT) -> (...,LAT)

            # Decode the latent state at the observation time.
            x̂_pre = self.projection(self.decoder(ẑ_pre))  # (...,LAT) -> (...,DIM)

            # Update the state estimate by filtering the observation.
            x̂_post = self.filter(x_obs, x̂_pre)  # (...,DIM), (..., DIM) → (...,DIM)

            # Encode the latent state at the observation time.
            ẑ_post = self.encoder(self.embedding(x̂_post))  # (...,DIM) → (...,LAT)

            # Save all tensors for later.
            Zhat_pre.append(ẑ_pre)
            Xhat_pre.append(x̂_pre)
            Xhat_post.append(x̂_post)
            Zhat_post.append(ẑ_post)

        self.xhat_pre = torch.stack(Xhat_pre, dim=-2)
        self.xhat_post = torch.stack(Xhat_post, dim=-2)
        self.zhat_pre = torch.stack(Zhat_pre, dim=-2)
        self.zhat_post = torch.stack(Zhat_post, dim=-2)
        self.timedeltas = DT.moveaxis(0, -1)

        return self.xhat_post

        # TODO: Control variables
        # xhat = self.control(xhat, u)
        # u: possible controls:
        #  1. set to value
        #  2. add to value
        # do these via indicator variable
        # u = (time, value, mode-indicator, col-indicator)
        # => apply control to specific column.

        # TODO: Smarter initialization
        # IDEA: The problem is the initial state of RNNCell is not defined and typically put equal
        # to zero. Staying with the idea that the Cell acts as a filter, that is updates the state
        # estimation given an observation, we could "trust" the original observation in the sense
        # that we solve the fixed point equation h0 = g(x0, h0) and put the solution as the initial
        # state.
        # issue: if x0 is really sparse this is useless.
        # better idea: we probably should go back and forth.
        # other idea: use a set-based model and put h = g(T,X), including the whole TS.
        # This set model can use triplet notation.
        # bias weighting towards close time points

In [None]:
# Layer/parameter summary of the v1 model defined above (input_size=2, hidden_size=3).
summary(LinODEnet(2, 3))

## LinODEnet v2

In [None]:
import logging
from typing import Any, Final, overload

import torch
from linodenet.initializations.functional import FunctionalInitialization
from linodenet.models.embeddings import ConcatEmbedding, ConcatProjection
from linodenet.models.encoders import iResNet
from linodenet.models.filters import Filter, RecurrentCellFilter
from linodenet.models.system import LinODECell
from linodenet.projections import Projection
from linodenet.util import autojit, deep_dict_update, initialize_from_config
from torch import Tensor, jit, nn

__logger__ = logging.getLogger(__name__)


# @autojit
class LinODEnet(nn.Module):
    r"""Linear ODE Network (v2 draft).

    Same architecture as v1, but construction is split into two steps:

    1. :meth:`__init__` (or :meth:`from_config`) resolves the hyperparameters and
       builds every component from its config via ``initialize_from_config``.
    2. :meth:`_register_components` registers the already-built sub-modules, the
       non-persistent buffers and the initial latent state ``z0``.
    """

    HP = {
        "__name__": __qualname__,  # type: ignore[name-defined]
        "__doc__": __doc__,
        "__module__": __module__,  # type: ignore[name-defined]
        "input_size": int,
        "hidden_size": int,
        "output_size": int,
        "System": LinODECell.HP,
        "Embedding": ConcatEmbedding.HP,
        "Projection": ConcatProjection.HP,
        "Filter": RecurrentCellFilter.HP | {"autoregressive": True},
        "Encoder": iResNet.HP,
        "Decoder": iResNet.HP,
    }
    r"""Dictionary of Hyperparameters."""

    # Constants
    input_size: Final[int]
    r"""CONST: The dimensionality of the inputs."""
    hidden_size: Final[int]
    r"""CONST: The dimensionality of the linear ODE."""
    output_size: Final[int]
    r"""CONST: The dimensionality of the outputs."""

    # Buffers
    zero: Tensor
    r"""BUFFER: A tensor of value float(0.0)"""
    xhat_pre: Tensor
    r"""BUFFER: Stores pre-jump values."""
    xhat_post: Tensor
    r"""BUFFER: Stores post-jump values."""
    zhat_pre: Tensor
    r"""BUFFER: Stores pre-jump latent values."""
    zhat_post: Tensor
    r"""BUFFER: Stores post-jump latent values."""
    timedeltas: Tensor
    """BUFFER: Stores the timedelta values."""

    # Parameters:
    kernel: Tensor
    r"""PARAM: The system matrix of the linear ODE component."""
    z0: Tensor
    r"""PARAM: The initial latent state."""

    @classmethod
    def from_config(
        cls,
        *,
        input_size: int,
        hidden_size: int,
        output_size: int,
        System: dict[str, Any],
        Embedding: dict[str, Any],
        Projection: dict[str, Any],
        Filter: dict[str, Any],
        Encoder: dict[str, Any],
        Decoder: dict[str, Any],
    ) -> "LinODEnet":
        r"""Build a model from explicit sizes and per-component configs.

        The component configs are merged into :attr:`HP` exactly like keyword
        overrides passed to :meth:`__init__`.

        Raises
        ------
        ValueError
            If ``output_size != input_size``. The model always predicts in the
            observation space.
        """
        if output_size != input_size:
            raise ValueError(
                f"LinODEnet requires {output_size=} to equal {input_size=}!"
            )
        return cls(
            input_size,
            hidden_size,
            System=System,
            Embedding=Embedding,
            Projection=Projection,
            Filter=Filter,
            Encoder=Encoder,
            Decoder=Decoder,
        )

    def __init__(self, input_size: int, hidden_size: int, **HP: Any) -> None:
        r"""Resolve the hyperparameters, build all components and register them.

        Parameters
        ----------
        input_size: int
            Dimensionality of the observations (also used as ``output_size``).
        hidden_size: int
            Dimensionality of the latent linear ODE.
        HP: Any
            Overrides that are merged into :attr:`HP` via ``deep_dict_update``.
        """
        super().__init__()
        self.CFG = HP = deep_dict_update(self.HP, HP)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = input_size

        # Wire the component sizes together: encoder/decoder/system live in latent
        # space, filter in observation space, embedding/projection bridge the two.
        HP["Encoder"]["input_size"] = hidden_size
        HP["Decoder"]["input_size"] = hidden_size
        HP["System"]["input_size"] = hidden_size
        HP["Filter"]["hidden_size"] = input_size
        HP["Filter"]["input_size"] = input_size
        HP["Embedding"]["input_size"] = input_size
        HP["Embedding"]["hidden_size"] = hidden_size
        HP["Projection"]["input_size"] = input_size
        HP["Projection"]["hidden_size"] = hidden_size

        # Build in the same order as v1 (relevant for random initialization).
        embedding = initialize_from_config(HP["Embedding"])
        projection = initialize_from_config(HP["Projection"])
        encoder = initialize_from_config(HP["Encoder"])
        system = initialize_from_config(HP["System"])
        decoder = initialize_from_config(HP["Decoder"])
        filter_ = initialize_from_config(HP["Filter"])

        self._register_components(
            System=system,
            Filter=filter_,
            Encoder=encoder,
            Embedding=embedding,
            Decoder=decoder,
            Projection=projection,
        )

    def _register_components(
        self,
        *,
        System: nn.Module,
        Filter: nn.Module,
        Encoder: nn.Module,
        Embedding: nn.Module,
        Decoder: nn.Module,
        Projection: nn.Module,
    ) -> None:
        r"""Register the sub-modules, buffers and initial state ``z0``.

        Requires ``self.hidden_size`` to be set already (it sizes ``z0``).
        """
        # TODO: allow Projection=None and fall back to Embedding.inverse.

        # Register Modules
        self.register_module("embedding", Embedding)
        self.register_module("projection", Projection)
        self.register_module("encoder", Encoder)
        self.register_module("system", System)
        self.register_module("decoder", Decoder)
        self.register_module("filter", Filter)

        # Register Buffers (non-persistent: not saved in the state_dict).
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)
        self.register_buffer("timedeltas", torch.tensor(()), persistent=False)
        self.register_buffer("xhat_pre", torch.tensor(()), persistent=False)
        self.register_buffer("xhat_post", torch.tensor(()), persistent=False)
        self.register_buffer("zhat_pre", torch.tensor(()), persistent=False)
        self.register_buffer("zhat_post", torch.tensor(()), persistent=False)

        # Register Parameters
        self.register_parameter("z0", nn.Parameter(torch.randn(self.hidden_size)))

        # TODO: expose the system matrix, e.g.
        # assert isinstance(self.system.kernel, Tensor)
        # self.kernel = self.system.kernel

    @jit.export
    def forward(self, T: Tensor, X: Tensor) -> Tensor:
        r"""Signature: `[...,N]×[...,N,d] ⟶ [...,N,d]`.

        Runs propagate → decode → filter → encode for every observation.
        It returns the post-update estimates. Pre-update estimates, latent
        trajectories and time deltas are stored in the corresponding buffers.
        """
        BATCH_SIZE = X.shape[:-2]
        # Prepend t=0, so the first time delta is T[..., 0] - 0.
        pad_dim = list(BATCH_SIZE) + [1]
        pad = torch.zeros(pad_dim, device=T.device, dtype=T.dtype)
        DT = torch.diff(T, prepend=pad, dim=-1)  # (..., LEN) → (..., LEN)
        DT = DT.moveaxis(-1, 0)  # (..., LEN) → (LEN, ...)
        X = torch.moveaxis(X, -2, 0)  # (...,LEN,DIM) → (LEN,...,DIM)

        Zhat_pre: list[Tensor] = []
        Xhat_pre: list[Tensor] = []
        Xhat_post: list[Tensor] = []
        Zhat_post: list[Tensor] = []

        ẑ_post = self.z0

        for dt, x_obs in zip(DT, X):
            # Propagate the latent state forward in time.
            ẑ_pre = self.system(dt, ẑ_post)  # (...,), (...,LAT) -> (...,LAT)

            # Decode the latent state at the observation time.
            x̂_pre = self.projection(self.decoder(ẑ_pre))  # (...,LAT) -> (...,DIM)

            # Update the state estimate by filtering the observation.
            x̂_post = self.filter(x_obs, x̂_pre)  # (...,DIM), (..., DIM) → (...,DIM)

            # Encode the latent state at the observation time.
            ẑ_post = self.encoder(self.embedding(x̂_post))  # (...,DIM) → (...,LAT)

            # Save all tensors for later.
            Zhat_pre.append(ẑ_pre)
            Xhat_pre.append(x̂_pre)
            Xhat_post.append(x̂_post)
            Zhat_post.append(ẑ_post)

        self.xhat_pre = torch.stack(Xhat_pre, dim=-2)
        self.xhat_post = torch.stack(Xhat_post, dim=-2)
        self.zhat_pre = torch.stack(Zhat_pre, dim=-2)
        self.zhat_post = torch.stack(Zhat_post, dim=-2)
        self.timedeltas = DT.moveaxis(0, -1)

        return self.xhat_post

In [None]:
from typing import Final, Optional

import torch
from torch import Tensor, jit, nn


# @autojit
class ConcatProjection(nn.Module):
    r"""Maps `z = [x,w] ⟼ x`."""

    # Constants
    input_size: Final[int]
    r"""CONST: The dimensionality of the inputs."""
    hidden_size: Final[int]
    r"""CONST: The dimensionality of the outputs."""
    pad_size: Final[int]
    r"""CONST: The size of the padding."""

    # Parameters
    padding: Tensor
    r"""PARAM: The padding vector."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        padding: Optional[Tensor] = None,
        # inverted: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        if not input_size >= hidden_size:
            raise ValueError(
                f"ConcatProjection requires {input_size=} ≥ {hidden_size=}!"
            )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.pad_size = input_size - hidden_size

        if padding is None:
            padding = nn.Parameter(torch.randn(self.pad_size))
        elif not isinstance(padding, nn.Parameter):
            padding = nn.Parameter(padding)

        self.register_parameter("padding", padding)

    #         if inverted is None:
    #             inverted = ConcatEmbedding(
    #                 input_size=self.hidden_size,
    #                 hidden_size=self.input_size,
    #                 padding=self.padding,
    #                 inverted=self,
    #             )
    #         self.inverted = inverted

    @property
    def inverted(self):
        return ConcatEmbedding(
            input_size=self.hidden_size,
            hidden_size=self.input_size,
            padding=self.padding,
        )

    @jit.export
    def __invert__(self):
        return self.inverted

    @jit.export
    def forward(self, Z: Tensor) -> Tensor:
        r"""Signature: `[..., d+e] ⟶ [..., d]`."""
        return Z[..., : self.hidden_size]

    @jit.export
    def inverse(self, X: Tensor) -> Tensor:
        r"""Signature: `[..., d] ⟶ [..., d+e]`."""
        shape = list(X.shape[:-1]) + [self.pad_size]
        return torch.cat([X, self.padding.expand(shape)], dim=-1)


# @autojit
class ConcatEmbedding(nn.Module):
    r"""Maps `x ⟼ [x,w]`."""

    # Constants|
    input_size: Final[int]
    r"""CONST: The dimensionality of the inputs."""
    hidden_size: Final[int]
    r"""CONST: The dimensionality of the outputs."""
    pad_size: Final[int]
    r"""CONST: The size of the padding."""

    # Parameters
    padding: Tensor
    r"""PARAM: The padding vector."""

    def __init__(
        self,
        *,
        input_size: int,
        hidden_size: int,
        padding: Optional[Tensor] = None,
        # inverted: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()

        if not input_size <= hidden_size:
            raise ValueError(
                f"ConcatEmbedding requires {input_size=} ≤ {hidden_size=}!"
            )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.pad_size = hidden_size - input_size

        padding = None
        if padding is None:
            padding = nn.Parameter(torch.randn(self.pad_size))
        elif not isinstance(padding, nn.Parameter):
            padding = nn.Parameter(padding)

        self.register_parameter("padding", padding)

        # if inverted is None:
        #     inverted = ConcatProjection(
        #         input_size=self.hidden_size,
        #         hidden_size=self.input_size,
        #         padding=self.padding,
        #         inverted=self,
        #     )
        # self.inverted = inverted

    @property
    def inverted(self):
        return ConcatProjection(
            input_size=self.hidden_size,
            hidden_size=self.input_size,
            padding=self.padding,
        )

    @jit.export
    def __invert__(self):
        return self.inverted

    @jit.export
    def forward(self, X: Tensor) -> Tensor:
        r"""Signature: `[..., d] ⟶ [..., d+e]`."""
        shape = list(X.shape[:-1]) + [self.pad_size]
        return torch.cat([X, self.padding.expand(shape)], dim=-1)

    @jit.export
    def inverse(self, Z: Tensor) -> Tensor:
        r"""Signature: `[..., d+e] ⟶ [..., d]`."""
        return Z[..., : self.input_size]

In [None]:
model = ConcatEmbedding(input_size=4, hidden_size=16)
# NOTE(review): `~model` goes through the `inverted` property, which builds a new
# module on every call, so this identity check would fail as written.
# assert model is ~(~model)

model = ConcatProjection(input_size=16, hidden_size=4)
# assert model is ~(~model)

print(model)
# NOTE(review): presumably caused by TorchScript compiling the `inverted` property,
# which constructs the other Concat* class (whose own `inverted` refers back).
jit.script(model)  # RecursionError: maximum recursion depth exceeded

In [None]:
# Re-run the scripting attempt; the trailing `;` suppresses the cell's output display.
jit.script(model);

In [None]:
# Assuming the cells above ran in order, `model` is ConcatProjection(16, 4):
# forward maps [..., 16] → [..., 4], inverse maps [..., 4] → [..., 16].
model(torch.randn(8, 16))
model.inverse(torch.randn(3, 4))

In [None]:
# The Concat* modules have no `invert()` method; `~module` dispatches to `__invert__`.
inverted = ~model
original = ~inverted

In [None]:
# Positional sizes: input_size=16, hidden_size=4, so pad_size=12.
model = ConcatProjection(16, 4)

# inverse appends the 12-dim padding: [..., 4] → [..., 16].
model.inverse(torch.randn(8, 4))
# forward keeps the first 4 channels: [..., 16] → [..., 4].
model.forward(torch.randn(3, 16))
# model.inverse(torch.randn(8, 4))