# Sample Template

Here is some sample text, and a first block of code

In [None]:
# IPython display setting. Going by the option value, a cell ending in an
# expression *or* an assignment gets its value echoed — see the IPython docs
# for `InteractiveShell.ast_node_interactivity` to confirm the exact rules.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.

In [None]:
from typing import Optional

import torch
from torch import Tensor, nn
from torch.autograd import grad

In [None]:
import logging
from abc import ABC, abstractmethod
from dataclasses import KW_ONLY, dataclass, field
from typing import ClassVar, Final, NamedTuple, Optional, Protocol

import torch
from torch import Tensor, dot, norm, rsqrt, sqrt
from torch.linalg import solve

logger = logging.getLogger(__name__)


class Solver(Protocol):
    """Structural interface that solver objects are expected to satisfy."""

    maxiter: int
    """Upper bound on the number of iterations."""
    atol: float
    """Absolute tolerance."""
    rtol: float
    """Relative tolerance."""

    def __call__(self, x0: Tensor) -> Tensor:
        """Solve the linear system for the given ``x0`` and return a tensor."""

    @staticmethod
    def step(state: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
        """Map one solver state to the next.

        Declared as a static, pure state-transition function: it receives a
        tuple of tensors and returns a tuple of tensors.
        """


class State(Protocol):
    """Structural type of a solver state: any object exposing the iterate ``x``."""

    x: Tensor
    """The current iterate."""


@dataclass
class BaseSolver(ABC):
    """Base class for iterative solvers of the linear system ``L(x) = y``.

    Subclasses supply :meth:`initial_state` and :meth:`step`; :meth:`solve`
    drives the iteration and :meth:`condition` decides when to stop.
    """

    requires_transpose: ClassVar[bool]
    """Whether the solver requires the transpose of the operator."""
    requires_symmetric: ClassVar[bool]
    """Whether the solver requires the operator to be symmetric."""
    requires_positive_definite: ClassVar[bool]
    """Whether the solver requires the operator to be positive definite."""
    requires_finite_steps: ClassVar[bool]
    """Whether the solver terminates in a finite number of steps."""

    L: nn.Module
    """Linear operator."""
    y: Tensor
    """Right-hand side of the linear system."""

    _: KW_ONLY

    maxiter: Final[int] = 1000
    """Maximum number of iterations."""
    atol: Final[float] = 1e-8
    """Absolute tolerance."""
    rtol: Final[float] = 1e-5
    """Relative tolerance."""

    @abstractmethod
    def initial_state(self, x0: Optional[Tensor] = None) -> State:
        """Initialize the solver state from the (optional) initial guess ``x0``."""
        ...

    @abstractmethod
    def step(self, state: State) -> State:
        """Perform a single step of the solver and return the new state."""
        ...

    def condition(self, new_state: State, old_state: State) -> bool:
        """Check if the solver has converged.

        Convergence is declared once the last update is small relative to the
        previous iterate::

            ||x_new - x_old|| <= atol + rtol * ||x_old||

        ``atol`` is the additive (absolute) term and ``rtol`` scales with the
        size of the iterate, matching the usual ``isclose`` convention.

        Returns:
            A plain Python ``bool`` (not a 0-dim tensor).
        """
        x_new = new_state.x
        x_old = old_state.x
        update = (x_new - x_old).norm()
        threshold = self.atol + self.rtol * x_old.norm()
        # `<=` so that an exactly stationary iterate counts as converged even
        # when the threshold itself is zero.
        return bool(update <= threshold)

    def solve(self, x0: Optional[Tensor] = None) -> Tensor:
        """Solve the linear system.

        Runs :meth:`step` until :meth:`condition` reports convergence or
        ``maxiter`` steps have been taken, logging which of the two happened.

        Args:
            x0: Initial guess, forwarded to :meth:`initial_state`.

        Returns:
            The iterate ``x`` of the final state.
        """
        state = self.initial_state(x0)

        for it in range(self.maxiter):
            new_state = self.step(state)
            converged = self.condition(new_state, state)
            state = new_state

            if converged:
                # `it` is zero-based; report the number of steps actually taken.
                logger.info("Converged after %s iterations.", it + 1)
                break
        else:
            logger.warning("No convergence after %s iterations.", self.maxiter)

        return state.x


class CGS_STATE(NamedTuple):
    """Immutable record of one conjugate-gradient-squared iteration.

    Field order is ``(x, r, p, u, rho)``.
    """

    x: Tensor
    """Vector: the current iterate."""
    r: Tensor
    """Vector: the residual."""
    p: Tensor
    """Vector: the search direction."""
    u: Tensor
    """Vector: an auxiliary vector of the recurrence."""
    rho: Tensor
    """Scalar: inner product of ``r`` with ``rstar``."""


class CGS_Solver(BaseSolver):
    """Conjugate Gradient Squared solver.

    Only applies the operator ``L`` itself (never a transpose) and keeps a
    fixed "shadow residual" ``rstar`` against which residuals are projected.
    """

    requires_transpose: ClassVar[bool] = NotImplemented
    requires_symmetric: ClassVar[bool] = NotImplemented
    requires_positive_definite: ClassVar[bool] = NotImplemented

    rstar: Optional[Tensor] = None
    """Optional user-chosen shadow residual.

    Assign it on the instance before solving to override the default; when
    left as ``None`` the initial residual ``r0`` is used.
    """

    @torch.no_grad()
    def step(self, state: CGS_STATE) -> CGS_STATE:
        """Perform one CGS iteration and return a *new* state.

        All updates are out of place, so ``state`` (and the caller's initial
        guess it may reference) is left untouched; the convergence check relies
        on comparing two distinct iterates.
        """
        # unpack state
        x, r, p, u, rho_old = state
        rstar = self._rstar  # fixed by `initial_state`

        # An exactly zero residual means `x` already solves the system; the
        # recurrences below would evaluate 0/0, so report a zero-size update.
        if not torch.any(r):
            return state

        # perform iteration
        v = self.L(p)
        alpha = rho_old / dot(v, rstar)
        q = u - alpha * v
        w = u + q
        x = x + alpha * w
        r = r - alpha * self.L(w)
        rho = dot(r, rstar)
        beta = rho / rho_old
        u = r + beta * q
        p = u + beta * (q + beta * p)

        return CGS_STATE(x=x, r=r, p=p, u=u, rho=rho)

    @torch.no_grad()
    def initial_state(self, x0: Optional[Tensor] = None) -> CGS_STATE:
        """Build the starting state from the initial guess ``x0``.

        Args:
            x0: Initial guess. Defaults to a zero vector shaped like ``y``
                (the iterate and the residual are added together in
                :meth:`step`, so they share a shape).

        Returns:
            State with ``r = p = u = y - L(x0)`` and ``rho = <r, rstar>``.
        """
        if x0 is None:
            x0 = torch.zeros_like(self.y)

        r0 = self.y - self.L(x0)
        # Effective shadow residual for this solve; kept separate from the
        # user-facing `rstar` so repeated solves do not inherit a stale default.
        self._rstar = r0.clone() if self.rstar is None else self.rstar
        p0 = r0.clone()
        u0 = r0.clone()
        rho0 = dot(r0, self._rstar)
        return CGS_STATE(x=x0, r=r0, p=p0, u=u0, rho=rho0)

In [None]:
# Smoke test: run CGS on a random 5x5 system L(x) = y from a zero initial guess.
# bias=False keeps the operator linear; with a bias, L(p) would be W p + b and
# the solver's operator applications would no longer be matrix-vector products.
L = nn.Linear(5, 5, bias=False)
y = torch.randn(5)
x0 = torch.zeros(L.in_features)
solver = CGS_Solver(L, y)
solver.solve(x0)

In [None]:
# Run the solver with autograd recording disabled.
# NOTE(review): `g` is not defined in any cell visible here (it only appears as
# a local inside DEQ_Layer.backward) — presumably `y` was intended; confirm.
# NOTE(review): solve() is called without an initial guess, so `None` is what
# reaches CGS_Solver.initial_state — verify that this is supported.
with torch.no_grad():
    CGS_Solver(L, g).solve()

In [None]:
def solve(f: nn.Module, y: Tensor, x0: Optional[Tensor] = None) -> Tensor:
    """Placeholder for solving f(x)=y for a linear function f.

    No iteration is performed yet: the function hands back the initial guess.

    Args:
        f: Linear function; only its ``input_size`` attribute is read.
        y: Right-hand side (currently unused).
        x0: Initial guess. When omitted, a zero vector of length
            ``f.input_size`` is created.

    Returns:
        ``x0`` itself if given, otherwise the freshly created zero vector.
    """
    if x0 is not None:
        return x0
    return torch.zeros(f.input_size)

In [None]:
class DEQ_Layer(torch.autograd.Function):
    """Fixed-point ("deep equilibrium") layer with an implicit backward pass.

    Forward iterates ``z <- f(x, z)`` from zero and returns the result ``z*``.
    Backward treats ``z*`` as a fixed point ``z* = f(x, z*)`` and
    differentiates it implicitly instead of unrolling the iteration:

        solve  (I - J_z^T) u = grad_output,   then   grad_x = J_x^T u

    where ``J_z`` and ``J_x`` are the Jacobians of ``f`` w.r.t. ``z`` and ``x``
    evaluated at ``(x, z*)``.

    NOTE: ``f`` is passed as a non-tensor argument, so this Function returns no
    gradient for it (``None``); only ``x`` receives a gradient here.
    """

    @staticmethod
    def forward(f: nn.Module, x: Tensor) -> Tensor:
        """Run 100 fixed-point iterations of ``z <- f(x, z)`` starting at zero."""
        # Same dtype/device as the input.
        z = x.new_zeros(f.hidden_size)
        with torch.no_grad():
            for _ in range(100):
                z = f(x, z)
        # Whether the output requires grad is decided by autograd from the inputs.
        return z

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        """Record one differentiable evaluation of ``f`` at the fixed point."""
        f, x = inputs
        z = outputs
        # Grad mode is re-enabled explicitly so that this single evaluation of
        # `f` is recorded regardless of the mode this hook is invoked in.
        with torch.enable_grad():
            x0 = x.detach().requires_grad_()
            z0 = z.detach().clone().requires_grad_()
            f0 = f(x0, z0)
        ctx.save_for_backward(x0, z0, f0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(None, grad_x)`` — one entry per argument of ``forward``."""
        x0, z0, f0 = ctx.saved_tensors
        g = grad_output

        def L(u: Tensor) -> Tensor:
            # Linear operator u -> (I - J_z^T) u via a vector-Jacobian product.
            return u - grad(f0, z0, u, retain_graph=True)[0]

        solver = CGS_Solver(L, g)
        # CGS_Solver reads its shadow residual from the `rstar` attribute; with
        # a zero initial guess the initial residual equals `g`, so use that.
        solver.rstar = g.clone()
        u = solver.solve(torch.zeros_like(g))

        # Pull the solution back through the dependence of f on x.
        grad_x = grad(f0, x0, u, retain_graph=True)[0]
        return None, grad_x


# Wrap DEQ_Layer.apply in a plain function so that it is clearer what the output is
def deq_layer(f: nn.Module, x: Tensor):
    """Apply ``DEQ_Layer`` to the module ``f`` and input ``x`` and return its output."""
    return DEQ_Layer.apply(f, x)

In [None]:
# Toy setup for the DEQ experiment: an RNN cell with 5 inputs / 5 hidden units
# serves as the map f(x, z), together with a random length-5 input.
f = nn.RNNCell(input_size=5, hidden_size=5)
x = torch.randn(5)

In [None]:
# Scalar objective: the norm of the layer output. Note that this rebinds `y`,
# which an earlier cell used for the solver's right-hand side.
y = deq_layer(f, x).norm()

In [None]:
# Backpropagate the scalar objective computed in the previous cell.
y.backward()

In [None]:
# Flag `x` (in place) as requiring gradients.
# NOTE(review): this cell sits after the backward cell; presumably it has to be
# run before `deq_layer(f, x)` for a gradient to reach `x` — confirm the
# intended execution order.
x.requires_grad_()

In [None]:
# Manual fixed-point iteration: start from a zero hidden state and apply f once.
z = torch.zeros(f.hidden_size)
z = f(x, z)

In [None]:
# One more application of f to the current z; re-running this cell advances
# the iteration by a step each time. The bare `z` displays the new value.
z = f(x, z)
z