In [None]:
from typing import Callable

import torch
from torch import nn


class SimpleNN(nn.Module):
    """Fully connected network mapping a scalar input to a scalar output.

    Layer stack stored in ``self.network`` (an ``nn.Sequential``)::

        Linear(1, num_neurons)
        [Linear(num_neurons, num_neurons), activation()] * num_layers
        Linear(num_neurons, 1)

    NOTE(review): no activation follows the input layer, so it composes with
    the first hidden ``Linear`` into a single affine map (and with
    ``num_layers=0`` the whole network is affine). Inserting one would shift
    the ``nn.Sequential`` indices and therefore the parameter names
    (``network.<i>.weight`` / ``network.<i>.bias``) exposed through
    ``named_parameters()`` and ``state_dict()``, so the layout is kept as is.
    """

    def __init__(
        self,
        num_layers: int = 1,
        num_neurons: int = 5,
        activation: Callable[[], nn.Module] = nn.Tanh,
    ) -> None:
        """Basic neural network architecture with linear layers

        Args:
            num_layers (int, optional): number of hidden layers, each a
                ``Linear(num_neurons, num_neurons)`` followed by an activation;
                must be >= 0. Defaults to 1.
            num_neurons (int, optional): neurons for each hidden layer; must
                be >= 1. Defaults to 5.
            activation (Callable[[], nn.Module], optional): zero-argument
                factory called once per hidden layer so every layer gets its
                own module, e.g. ``nn.ReLU`` or ``lambda: nn.LeakyReLU(0.1)``.
                Defaults to ``nn.Tanh``.

        Raises:
            ValueError: if ``num_layers`` is negative or ``num_neurons`` < 1.
        """
        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        if num_neurons < 1:
            raise ValueError(f"num_neurons must be >= 1, got {num_neurons}")

        super().__init__()

        # input layer (no activation, see the class-level note)
        layers = [nn.Linear(1, num_neurons)]

        # hidden layers with linear layer and activation
        for _ in range(num_layers):
            layers.extend([nn.Linear(num_neurons, num_neurons), activation()])

        # output layer
        layers.append(nn.Linear(num_neurons, 1))

        # build the network
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network on every element of ``x``.

        ``x`` is flattened to a column of shape (N, 1), so each element is one
        scalar sample. The (N, 1) result is squeezed to shape (N,), or to a
        0-d tensor when ``x`` holds a single element. ``torch.func.grad``
        relies on that 0-d case, since it needs a single-element output.
        """
        return self.network(x.reshape(-1, 1)).squeeze()


In [None]:
# ten standard-normal scalar inputs, then a network with the default sizes
# (x is drawn first, so the RNG sequence matches the original cell)
x = torch.randn((10,))
model = SimpleNN(num_layers=1, num_neurons=5)


In [None]:
# show the module's repr: the nn.Sequential layer stack built in __init__
model


In [None]:
from torch.func import functional_call


In [None]:
# name -> Parameter mapping, the form functional_call expects
params = {name: tensor for name, tensor in model.named_parameters()}


In [None]:
# inspect the {name: Parameter} dict built from model.named_parameters()
params


In [None]:
# run the module statelessly with the explicitly supplied parameter dict
out = functional_call(model, params, args=(x,))


In [None]:
# one output value per element of x (forward reshapes to (-1, 1) and squeezes)
out


In [None]:
from torch.func import grad


In [None]:
def model_fn(x: torch.Tensor, params) -> torch.Tensor:
    """Evaluate ``model`` at ``x`` using explicitly supplied parameters.

    ``torch.func.grad(f)`` returns a function that takes the same arguments as
    ``f``. ``grad(model)`` therefore cannot be called as
    ``grad_fn(x, params)``, because it would invoke
    ``model.forward(x, params)``, and ``forward`` only accepts ``x``.
    Exposing the module as a pure function of ``(input, parameters)`` fixes
    that.

    Args:
        x: network input. Under ``grad`` the output must be a single element,
            e.g. evaluate at a 0-d tensor such as ``x[0]``.
        params: either a ``{name: tensor}`` dict, as built by
            ``dict(model.named_parameters())``, or a tuple of tensors in
            ``model.parameters()`` order.

    Returns:
        The network output computed by ``functional_call``.

    Raises:
        ValueError: if a tuple of the wrong length is passed as ``params``.
    """
    if not isinstance(params, dict):
        # functional_call needs a name -> tensor mapping; parameters() and
        # named_parameters() iterate in the same order, so zip restores names
        names = [name for name, _ in model.named_parameters()]
        if len(names) != len(params):
            raise ValueError(
                f"expected {len(names)} parameter tensors, got {len(params)}"
            )
        params = dict(zip(names, params))
    return functional_call(model, params, (x,))


# derivative of the scalar output w.r.t. the input x (argnums=0, the default);
# use grad(model_fn, argnums=1) for the derivative w.r.t. the parameters
grad_fn = grad(model_fn)


In [None]:
# show the function object returned by torch.func.grad
grad_fn


In [None]:
# NOTE: rebinds `params` from the named dict above to a plain tuple of the
# parameter tensors, in model.parameters() order
params = (*model.parameters(),)


In [None]:
# Evaluate grad_fn at the first input sample with the parameter tuple above.
# torch.func.grad differentiates w.r.t. the first positional argument by
# default (argnums=0), i.e. w.r.t. x[0] here.
# NOTE(review): if grad_fn wraps the bare module (grad(model)), both arguments
# are forwarded to SimpleNN.forward, which accepts only x and will raise
# TypeError. Confirm that grad_fn wraps a function of (x, params).
grad_values = grad_fn(x[0], params)
