# Custom autograd Functions and TorchScript

In [None]:
# Display the value of each cell's last expression *or* trailing assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically reload edited modules before executing each cell.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Always print exactly 4 fractional digits; suppress=True forces fixed-point
# (no scientific notation) for small values.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# Unseeded generator: draws differ between runs.
rng = np.random.default_rng()

In [None]:
import torch
from torch import jit, Tensor, nn

from typing import Any

In [None]:
class spectral_norm(torch.autograd.Function):
    r"""Largest singular value :math:`\sigma_{\max}(A)` computed by power iteration.

    Call as ``spectral_norm.apply(A, u, v, atol, rtol, maxiter)``.

    The class is deliberately *not* decorated with ``@jit.script``: TorchScript
    cannot compile subclasses of :class:`torch.autograd.Function`.

    With the final singular-vector estimates :math:`u, v` (unit norm whenever at
    least one iteration ran), :math:`\sigma = u^\top A v` and the gradient used is
    :math:`\partial\sigma/\partial A = u v^\top`. Only ``A`` receives a gradient;
    ``u`` and ``v`` are treated as non-differentiable initial guesses.
    """

    @staticmethod
    def forward(
        ctx: Any,
        A: Tensor,
        u: Tensor,
        v: Tensor,
        atol: float,
        rtol: float,
        maxiter: int,
    ) -> Tensor:
        r"""Run power iteration on ``A`` and return the singular-value estimate.

        Parameters
        ----------
        ctx
            Autograd context; the final ``u`` and ``v`` are saved on it.
        A
            Matrix of shape ``(m, n)``.
        u
            Initial left singular-vector guess, shape ``(m,)``.
        v
            Initial right singular-vector guess, shape ``(n,)``.
        atol, rtol
            Stop once ``|sigma_k - sigma_{k-1}| <= atol + rtol * sigma_k``.
        maxiter
            Maximum number of power iterations. With ``0``, the result is
            ``u @ A @ v`` for the given guesses.

        Returns
        -------
        Tensor
            0-dim tensor with the estimate of the largest singular value.

        Notes
        -----
        If ``A @ v`` or ``A.T @ u`` becomes zero (e.g. ``A == 0``), the
        normalization divides by zero and the result is NaN.
        """
        # Initial estimate from the supplied guesses (returned when maxiter == 0).
        sigma = u @ A @ v
        for _ in range(maxiter):
            sigma_prev = sigma
            u = A @ v
            u = u / torch.linalg.vector_norm(u)
            v = A.T @ u
            # ||A^T u|| equals u @ A @ v_normalized, so sigma and (u, v) stay consistent.
            sigma = torch.linalg.vector_norm(v)
            v = v / sigma
            if torch.abs(sigma - sigma_prev) <= atol + rtol * sigma:
                break
        # Needed by backward (reverse mode) and jvp (forward mode) respectively.
        ctx.save_for_backward(u, v)
        ctx.save_for_forward(u, v)
        return sigma

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Any:
        r"""Backward pass (vector-Jacobian product).

        Parameters
        ----------
        ctx
            Context holding the saved ``u`` and ``v``.
        grad_outputs
            ``grad_outputs[0]`` is the gradient w.r.t. the scalar output.

        Returns
        -------
        tuple
            One entry per forward input ``(A, u, v, atol, rtol, maxiter)``;
            only ``A`` gets a gradient, the rest are ``None``.
        """
        u, v = ctx.saved_tensors
        grad_A = grad_outputs[0] * torch.outer(u, v)
        # autograd requires one gradient per forward input.
        return grad_A, None, None, None, None, None

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Jacobian-vector product (forward-mode AD).

        Only the tangent of ``A`` (``grad_inputs[0]``) contributes:
        :math:`d\sigma = u^\top\, dA\, v`, a scalar matching the output's shape.
        """
        u, v = ctx.saved_tensors
        dA = grad_inputs[0]
        if dA is None:
            # A carries no tangent: the output tangent is zero.
            return u.new_zeros(())
        return u @ dA @ v

In [None]:
class SpectralNorm(nn.Module):
    """Stub module with no parameters or methods of its own yet.

    NOTE(review): presumably meant to wrap the spectral-norm computation as a
    layer — body still to be written.
    """

In [None]:
class MyReLU(torch.autograd.Function):
    """ReLU, ``max(x, 0)``, written as a custom autograd Function.

    Subclasses :class:`torch.autograd.Function` and implements the forward and
    backward passes explicitly. Use it via ``MyReLU.apply(x)``.
    """

    @staticmethod
    def forward(ctx: Any, x: Tensor) -> Tensor:
        """Return ``x`` with negative entries clamped to zero.

        Parameters
        ----------
        ctx
            Context object; the input is stashed on it for the backward pass.
        x
            Input tensor.
        """
        # Save the input tensor itself (not the builtin ``input``): backward
        # needs it to know where the input was negative.
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tensor:
        """Gradient of the loss w.r.t. the input.

        Passes ``grad_output`` through where the input was non-negative
        (including exactly zero) and zeroes it where the input was negative.
        """
        (x,) = ctx.saved_tensors
        # masked_fill is out-of-place, so grad_output itself is left untouched.
        return grad_output.masked_fill(x < 0, 0)

In [None]:
# Expose the Function's entry point as a plain callable to hand to jit.trace.
myrelu = MyReLU.apply
# NOTE(review): presumably gives the callable a `__module__` so jit.trace can
# name it — confirm it is still needed. Where `Function.apply` is a bound
# method (Python classmethod), assigning `__module__` may raise AttributeError.
myrelu.__module__ = "dummy__module__"

# Example input of shape (3, 4, 5); jit.trace records the ops run on it.
x = torch.randn(3, 4, 5)
# Despite its name, `scripted` is produced by tracing, not by jit.script.
scripted = jit.trace(myrelu, x)

In [None]:
class MyNet(nn.Module):
    """Thin module that forwards its input to the module-level ``scripted`` callable."""

    def __init__(self) -> None:
        super(MyNet, self).__init__()
        # `scripted` is looked up in module globals at construction time.
        self.layer = scripted

    def forward(self, x: Tensor) -> Tensor:
        """Apply the wrapped callable to ``x`` and return its result."""
        out = self.layer(x)
        return out

In [None]:
# Instantiate the wrapper and run it eagerly on a (4, 4) input.
# NOTE(review): the trace was recorded on a (3, 4, 5) example; this relies on
# the traced graph generalizing to other shapes — confirm.
model = MyNet()
model(torch.randn(4, 4))

In [None]:
# Compile the module with TorchScript; its forward calls the traced `scripted`
# function held in `self.layer`.
scripted_model = jit.script(model)

In [None]:
# Serialize the scripted module to "here.torch" (relative to the working directory).
scripted_model.save("here.torch")

In [None]:
# Call the traced function directly on a 1-element input.
scripted(torch.randn(1))

In [None]:
# Save the traced function itself to "scripted.torch".
jit.save(scripted, "scripted.torch")