In [1]:
import torch

In [9]:
def test(x : torch.Tensor):
    return x

In [11]:
test(x=10)  # identity helper: the cell displays the argument back (10)

10

In [None]:
class SSM(torch.nn.Module):
    """
    Basic diagonal state-space model, discretized with the bilinear transform
    and applied to an input sequence as a causal convolution computed via FFT.

    The continuous-time state matrix is diagonal with N//2 complex entries
    A_n = -0.5 + i*pi*n; B is all ones and C is drawn from a complex normal.
    A, B and C are registered as trainable ``torch.nn.Parameter``s.

    parameters:
        N: state dimension, int (N//2 complex modes are stored)
        dt: discretization step size, float
    """
    def __init__(self, N : int, dt : float):
        super().__init__()
        # The original line was ``N = self.N``, which reads an attribute that
        # was never set and raises AttributeError; store the argument instead.
        self.N = N
        self.dt = dt

        half = N // 2
        A = - 0.5 + 1j * torch.pi * torch.arange(half)
        B = torch.ones(half) + 0j
        C = torch.randn(half) + 1j * torch.randn(half)
        # nn.Parameter (instead of the deprecated autograd.Variable) registers
        # the tensors with the Module, so .parameters(), optimizers,
        # .to(device) and state_dict() all see them; requires_grad defaults
        # to True.
        self.A = torch.nn.Parameter(A)
        self.B = torch.nn.Parameter(B)
        self.C = torch.nn.Parameter(C)

    def kernel(self, L : int):
        """
        Compute the length-L convolution kernel
        K[k] = Re(sum_n C_n * dB_n * dA_n ** k).

        Side effect: the discretized parameters are stored in ``self.dA`` and
        ``self.dB``.

        parameters:
            L: kernel length, int
        returns:
            real tensor of shape (L,)
        """
        # Bilinear transform of the diagonal continuous-time system:
        #   dA = (1 + dt*A/2) / (1 - dt*A/2),  dB = dt*B / (1 - dt*A/2)
        denom = 1 - self.dt * self.A / 2
        self.dA = (1 + self.dt * self.A / 2) / denom
        self.dB = self.dt * self.B / denom
        # Matrix of powers dA_n ** k, shape (N//2, L). The exponents are built
        # on the parameters' device so the model still works after .to(device).
        powers = self.dA[:, None] ** torch.arange(L, device=self.dA.device)
        # NOTE(review): only Re(...) is taken. Diagonal models that keep one
        # mode per conjugate pair often use 2 * Re(...) — confirm the intended
        # kernel scale.
        return torch.real(torch.matmul(self.C * self.dB, powers))

    def forward(self, u : torch.Tensor):
        """
        Apply the SSM as a causal convolution of ``u`` with the kernel.

        parameters:
            u: input tensor whose last dimension is the sequence length L;
               leading dimensions are treated as batch dimensions
        returns:
            tensor with the same shape as ``u``; real for real ``u``
        """
        L = u.shape[-1]
        K = self.kernel(L)
        n = 2 * L  # zero-pad so the circular FFT product equals linear convolution

        if torch.is_complex(u):
            # Complex input: keep the full complex convolution.
            K_f, u_f = torch.fft.fft(K, n=n), torch.fft.fft(u, n=n)
            return torch.fft.ifft(K_f * u_f, n=n)[..., :L]

        # Real kernel and real input: rfft/irfft give a real output. Plain
        # fft/ifft would return a complex tensor whose imaginary part is
        # numerically zero.
        K_f, u_f = torch.fft.rfft(K, n=n), torch.fft.rfft(u, n=n)
        return torch.fft.irfft(K_f * u_f, n=n)[..., :L]
