# Test whether parametrization caching (`P.cached()`) works properly

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # display each cell's last expression *or* the value of its last assignment.

In [None]:
import torch
import torch.nn.utils.parametrize as P
from torch import nn

In [None]:
# Baseline: PyTorch's built-in single-layer RNN (5 input features, 3 hidden units).
model = nn.RNN(input_size=5, hidden_size=3)

In [None]:
# Unbatched sequence: nn.RNN treats a 2-D input as (seq_len, input_size).
x = torch.randn((10, 5))
y = torch.randn((10, 3))  # targets for the MSE loss, one 3-vector per step

In [None]:
# One backward pass of a mean-squared-error loss through the built-in RNN;
# the cell then displays the recurrent weight's gradient.
model.zero_grad(set_to_none=True)
yhat, hn = model(x)
loss = (y - yhat).pow(2).mean()
loss.backward()
model.weight_hh_l0.grad

### With parameterization

In [None]:
class MyRNN(nn.Module):
    """Minimal Elman-style RNN with an optional skew-symmetric recurrent matrix.

    For each step ``x_t`` along axis -2 of the input it computes
    ``h_t = tanh(x_t @ W.T + h_{t-1} @ V_eff.T + bias)``, starting from a zero
    state. ``V_eff`` is ``V``, or its skew-symmetric part ``(V - V.T) / 2``
    when ``parametrize`` is true.

    Args:
        input_size: number of features in each input step.
        hidden_size: number of features in the hidden state.
    """

    # Class-level default; set ``model.parametrize = True`` (or on the class)
    # to use the skew-symmetric part of V in the recurrence.
    parametrize: bool = False

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.W = nn.Parameter(torch.randn(hidden_size, input_size))
        self.V = nn.Parameter(torch.randn(hidden_size, hidden_size))

        self.bias = nn.Parameter(torch.randn(hidden_size))
        self.act = nn.Tanh()
        # Zero initial state; registered as a buffer (not a trainable parameter).
        self.register_buffer("h0", torch.zeros(hidden_size))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run the recurrence over axis -2 of ``X``.

        Args:
            X: input of shape ``(..., T, input_size)``. Axis -2 is iterated as
                time; any leading axes broadcast as batch dimensions.

        Returns:
            Stacked hidden states of shape ``(..., T, hidden_size)``.
        """
        h_list = []
        h = self.h0

        # Read the flag through ``self``: names defined in a class body are not
        # visible as bare names inside its methods.
        if self.parametrize:
            V = (self.V - self.V.T) / 2
        else:
            V = self.V

        for x in torch.moveaxis(X, -2, 0):
            w = torch.einsum("...j, ij -> ...i", x, self.W)  # x @ W.T
            v = torch.einsum("...j, ij -> ...i", h, V)  # h @ V.T
            h = self.act(w + v + self.bias)
            h_list.append(h)

        return torch.stack(h_list, dim=-2)

In [None]:
# Hand-written RNN with the same sizes as the built-in baseline.
model = MyRNN(input_size=5, hidden_size=3)

In [None]:
# Same MSE backward pass through the hand-written RNN (it returns only the
# stacked hidden states); the cell then displays the gradient of V.
model.zero_grad(set_to_none=True)
yhat = model(x)
loss = torch.mean((y - yhat) ** 2)
loss.backward()
model.V.grad

In [None]:
class skew_symmetric(nn.Module):
    """Parametrization returning the skew-symmetric part ``(x - x.T) / 2``.

    For a square matrix input the result ``A`` satisfies ``A.T == -A``.
    """

    def forward(self, x):
        antisym = x.sub(x.T)
        return antisym.div(2)

In [None]:
# `model` currently holds the MyRNN instance from above, which has no
# `weight_hh_l0`, so registering on it would fail. The following cells expect
# a built-in nn.RNN (tuple output, `parametrizations.weight_hh_l0`), so
# parametrize a fresh one: its recurrent weight becomes skew-symmetric.
model = nn.RNN(5, 3)
P.register_parametrization(model, "weight_hh_l0", skew_symmetric())
model

In [None]:
# MSE backward pass through the parametrized RNN. The gradient lands on the
# unconstrained `original` tensor that the parametrization wraps.
model.zero_grad(set_to_none=True)
yhat, hn = model(x)
loss = (y - yhat).pow(2).mean()
loss.backward()
model.parametrizations.weight_hh_l0.original.grad

In [None]:
# Inside `P.cached()`, parametrized tensors (here `weight_hh_l0`) are computed
# on first access and reused until the context exits. A cached forward pass
# must therefore produce the same output as an uncached one.
with P.cached():
    output, hn = model(x)
torch.allclose(output, model(x)[0])