In [1]:
import torch
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math

from collections.abc import Callable

import warnings

In [2]:
def random_SSM(N : int) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
    A = torch.autograd.Variable(torch.rand(size=(N,N)), requires_grad = True)
    B = torch.autograd.Variable(torch.rand(size=(N,1)), requires_grad = True)
    C = torch.autograd.Variable(torch.rand(size=(1,N)), requires_grad = True)
    D = torch.autograd.Variable(torch.rand(size=(1,1)), requires_grad = True)
    return A, B, C, D

In [3]:
A, B, C, D = random_SSM(10)


In [4]:
D = torch.zeros((1,1))

In [5]:
delta = torch.tensor(0.01)

In [6]:
def discretize(
    A : torch.Tensor, B : torch.Tensor, C : torch.Tensor, D : torch.Tensor, delta : torch.Tensor
) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
    """Discretizes SSM using bilinear model

    parameters:
        A: (NxN) transition matrix in latent
        B: (Nx1) projection matrix to latent
        C: (1xN) projection matrix from latent to output
        D: (1x1) skip connection from input to output
        delta: time step, ensure sufficient smallness
    """
    Cbar = C
    Dbar = D
    N = A.shape[0]
    Bl = torch.linalg.inv(torch.eye(N) - delta / 2 * A)
    Abar = Bl@(torch.eye(N) + delta/2 * A)
    Bbar = Bl@(delta*B)
    return Abar, Bbar, Cbar, Dbar

In [7]:
Abar, Bbar, Cbar, Dbar = discretize(A, B, C, D, delta)

In [8]:
T = 100
num_steps = int(T/delta)

u = torch.cos(torch.arange(num_steps))

In [9]:
def scan_SSM(
    Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, Db : torch.Tensor,  u : torch.Tensor, x0 : torch.Tensor
) -> torch.Tensor:
    """
    computes steps of the SSM going forward.

    parameters:
        Ab : (NxN) transition matrix in discrete space of latent to latent
        Bb : (Nx1) projcetion matrix from input to latent space
        Cb : (1xN) projection matrix from latent to output
        Db : (1x1) skip connection input to output
        u  : (L,)  trajectory we are trying to track
        x0 : (Nx1) initial condition of latent
    """
    x0 = torch.zeros((10,1))
    x = torch.zeros((Ab.shape[0], len(u[:100])))
    y = torch.zeros_like(u[:100])
    for i in range(u[:100].shape[0]):
        x[:,i] = (Ab@x0 + Bb*u[i]).squeeze()
        y[i] = (Cb@x[:,i]).squeeze()
        x0 = x[:,i].unsqueeze(-1)
    return x, y

In [10]:
def K_conv(Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, L : int) -> torch.Tensor:
    """
    computes convolution window given L time steps using equation K_t = Cb @ (Ab^t) @ Bb. 
    Needs to be flipped for correct causal convolution, but can be used as is in fft mode

    parameters:
        Ab : transition matrix
        Bb : projection matrix from input to latent
        Cb : projection matrix from latent to input
        Db : skip connection
        L  : length over which we want convolutional window
    """
    return torch.stack([(Cb @ torch.matrix_power(Ab, l) @ Bb).squeeze() for l in range(L)])

In [11]:
def causal_conv(u : torch.Tensor, K : torch.Tensor, notfft : bool = False) -> torch.Tensor:
    """
    computes 1-d causal convolution either using standard method or fft transform.

    parameters:
        u : trajectory to convolve
        K : convolutional filter
        notfft: boolean, for whether or not we use fft mode or not.
    """
    assert len(u.shape)==1
    assert K.shape==u.shape
    
    L = u.shape[0]
    powers_of_2 = 2**int(math.ceil(math.log2(2*L)))

    if notfft:
        padded_u = torch.nn.functional.pad(u, (L-1,L-1))
        convolve = torch.zeros_like(u)
        for i in range(L):
            convolve[i] = torch.sum(padded_u[i:i+L]*K.flip(dims=[0]))
        return convolve
    else:

        K_pad = torch.nn.functional.pad(K, (0, L))
        u_pad = torch.nn.functional.pad(u, (0, L))
        
        K_f, u_f = torch.fft.rfft(K_pad, n = powers_of_2), torch.fft.rfft(u_pad, n = powers_of_2)
        return torch.fft.irfft(K_f * u_f, n = powers_of_2)[:L]

In [12]:
K = K_conv(Abar, Bbar, Cbar, 100)

In [13]:
conv_fft = causal_conv(u[:100], K)
conv_notfft = causal_conv(u[:100],K , notfft=True)

In [14]:
x, y = scan_SSM(Abar, Bbar, Cbar, Dbar, u[:100], torch.zeros((10,1)))

In [15]:
print((abs(conv_fft - conv_notfft)<1e-5).all())
print((abs(conv_fft - y)<1e-5).all())

tensor(True)
tensor(True)


In [16]:
def log_step_initializer(dt_min = 0.001, dt_max = 0.1):
    """
    initial guess for dt, from random number generator. to be learned.

    parameters:
        dt_min
        dt_max
    """
    return torch.autograd.Variable(torch.rand(1) * (torch.log(dt_max) - torch.log(dt_min)) + torch.log(dt_min), requires_grad = True)

In [19]:
class SSMLayer(torch.nn.Module):
    """
    Simple layer that does SSMing. Assumes single input, single output. 
    Could be made multi-dimensional either by stacking and decorrelating,
    or by playing with the code to allow for multi input, multioutput. Should be relatively easy, 
    but need to carefully think a little about convolution of multi dim inputs.
    """
    def __init__(
        self,
        latent_dim,
        dt_min = 0.001,
        dt_max = 0.1,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.A, self.B, self.C, self.D = self.random_SSM(latent_dim)
        self.dt = self.log_step_initializer(dt_min, dt_max)
        self.Abar, self.Bbar, self.Cbar, self.Dbar = self.discretize(self.A, self.B, self.C, self.D, self.dt)


    def random_SSM(
        self, 
        N : int
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """
        initializing SSM parameters given latent dim
        
        parameters:
            N : size of latent dimension
        """
        A = torch.autograd.Variable(torch.rand(size=(N,N)), requires_grad = True)
        B = torch.autograd.Variable(torch.rand(size=(N,1)), requires_grad = True)
        C = torch.autograd.Variable(torch.rand(size=(1,N)), requires_grad = True)
        D = torch.autograd.Variable(torch.rand(size=(1,1)), requires_grad = True)
        return A, B, C, D

    def log_step_initializer(self, dt_min = 0.001, dt_max = 0.1):
        """
        initial guess for dt, from random number generator. to be learned.
    
        parameters:
            dt_min
            dt_max
        """
        return torch.autograd.Variable(torch.rand(1) * (torch.log(torch.tensor(dt_max)) - torch.log(torch.tensor(dt_min))) + torch.log(torch.tensor(dt_min)), requires_grad = True)

    def discretize(
        self, A : torch.Tensor, B : torch.Tensor, C : torch.Tensor, D : torch.Tensor, delta : torch.Tensor
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """Discretizes SSM using bilinear model
    
        parameters:
            A: (NxN) transition matrix in latent
            B: (Nx1) projection matrix to latent
            C: (1xN) projection matrix from latent to output
            D: (1x1) skip connection from input to output
            delta: time step, ensure sufficient smallness
        """
        Cbar = C
        Dbar = D
        N = A.shape[0]
        Bl = torch.linalg.inv(torch.eye(N) - delta / 2 * A)
        Abar = Bl@(torch.eye(N) + delta/2 * A)
        Bbar = Bl@(delta*B)
        return Abar, Bbar, Cbar, Dbar

    def scan_SSM(
        self, Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, Db : torch.Tensor,  u : torch.Tensor, x0 : torch.Tensor
    ) -> torch.Tensor:
        """
        computes steps of the SSM going forward.
    
        parameters:
            Ab : (NxN) transition matrix in discrete space of latent to latent
            Bb : (Nx1) projcetion matrix from input to latent space
            Cb : (1xN) projection matrix from latent to output
            Db : (1x1) skip connection input to output
            u  : (L,)  trajectory we are trying to track
            x0 : (Nx1) initial condition of latent
        """
        x0 = torch.zeros((10,1))
        x = torch.zeros((Ab.shape[0], len(u)))
        y = torch.zeros_like(u)
        for i in range(u.shape[0]):
            x[:,i] = (Ab@x0 + Bb*u[i]).squeeze()
            y[i] = (Cb@x[:,i]).squeeze()
            x0 = x[:,i].unsqueeze(-1)
        return x, y
        
    def K_conv(self, Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, L : int) -> torch.Tensor:
        """
        computes convolution window given L time steps using equation K_t = Cb @ (Ab^t) @ Bb. 
        Needs to be flipped for correct causal convolution, but can be used as is in fft mode
    
        parameters:
            Ab : transition matrix
            Bb : projection matrix from input to latent
            Cb : projection matrix from latent to input
            Db : skip connection
            L  : length over which we want convolutional window
        """
        return torch.stack([(Cb @ torch.matrix_power(Ab, l) @ Bb).squeeze() for l in range(L)])

    def causal_conv(u : torch.Tensor, K : torch.Tensor, notfft : bool = False) -> torch.Tensor:
        """
        computes 1-d causal convolution either using standard method or fft transform.
    
        parameters:
            u : trajectory to convolve
            K : convolutional filter
            notfft: boolean, for whether or not we use fft mode or not.
        """
        assert K.shape==u.shape
        
        L = u.shape[0]
        powers_of_2 = 2**int(math.ceil(math.log2(2*L)))
    
        if notfft:
            padded_u = torch.nn.functional.pad(u, (L-1,L-1))
            convolve = torch.zeros_like(u)
            for i in range(L):
                convolve[i] = torch.sum(padded_u[i:i+L]*K.flip(dims=[0]))
            return convolve
        else:
    
            K_pad = torch.nn.functional.pad(K, (0, L))
            u_pad = torch.nn.functional.pad(u, (0, L))
            
            K_f, u_f = torch.fft.rfft(K_pad, n = powers_of_2), torch.fft.rfft(u_pad, n = powers_of_2)
            return torch.fft.irfft(K_f * u_f, n = powers_of_2)[:L]

    def forward(
        self,
        u : torch.Tensor,
        x0 : torch.Tensor = torch.zeros((1,1)),
        mode : bool | str = False
    ) -> torch.Tensor:
        """
        forward pass of model

        Parameters:
            u  : input time series
            x0 : initial condition, only used in recurrent mode
            mode: recurrent mode ("recurrent"), or convolution mode (True : direct convolution, False : fourier transform)
        """
        if mode == "recurrent":
            return self.scan_SSM(self.Abar, self.Bbar, self.Cbar, u, x0)[1]
        else:
            K = self.K_conv(self.Abar, self.Bbar, self.Cbar, u.shape[0])
            return self.causal_conv(u, K, mode)

In [20]:
def make_HiPPO(N : int) -> torch.Tensor:
    """
    creates HiPPO matrix for legendre polynomials up to order N
    parameters:
        N: int
    """
    P = torch.sqrt(1+2*torch.arange(N))
    A = P.unsqueeze(1) * P.unsqueeze(0)
    A = torch.tril(A) - torch.diag(torch.arange(N))
    return -A

In [21]:
def K_gen_inverse(
    Abar : torch.Tensor, Bbar : torch.Tensor, Cbar : torch.Tensor, L : int
) -> torch.Tensor:
    """
    creates generating function for convolutional window, to be evaluated at roots of unity
    parameters:
        Abar : discretized A matrix
        Bbar : discretized B matrix
        Cbar : discretized C matrix
        L    : length of convolutional window
    """
    Abar = Abar.to(torch.complex64)
    Bbar = Bbar.to(torch.complex64)
    Cbar = Cbar.to(torch.complex64)
    
    I = torch.eye(Abar.shape[0]).to(torch.complex64)
    Al = torch.matrix_power(Abar, L)
    Ctilde = Cbar @ (I - (Al))
    return lambda z: (torch.conj(Ctilde)@(torch.linalg.inv(I-(Abar * z)))@Bbar).squeeze()

In [22]:
def conv_from_gen(gen : Callable, L : int):
    """
    returns convolution from generating function by evaluating at roots of unity

    parameters:
        gen : generating function
        L   : int
    """
    omega_L = torch.exp(-2j * torch.pi * torch.arange(L)/L)
    atRoots = torch.tensor([gen(omega) for omega in omega_L])
    return torch.fft.irfft(atRoots, L).squeeze()
    

In [23]:
def cauchy(v : torch.Tensor, omega : torch.Tensor, lambd : torch.Tensor) -> torch.Tensor:
    """
    helper function for calculating cauchy kernel in generating function

    parameters:
        v : a dot product vector, relying on DPLR representation of SSM matrices
        omega : complex poles
        lambd : diagonal values of A matrix stand-in
    """
    cauchy_dot = lambda _omega: (v/(omega-lamb)).sum()
    return cauchy_dot
    

In [24]:
def K_gen_DPLR(
    Lambda : torch.Tensor, 
    P : torch.Tensor, 
    Q : torch.Tensor, 
    B: torch.Tensor, 
    C : torch.Tensor, 
    delta : torch.Tensor, 
    L : int
)-> torch.Tensor:
    """
    computes convolution kernel from generating function using DPLR representation and
    the cauchy kernel

    Parameters:
        Lambda : diagonal part of A
        P : Nx1 matrix, rank 1 representation to A
        Q : Nx1 matrix, rank 1 representation to A
        C : 1xN matrix, projection from latent to input
        B : Nx1 matrix, projection from input to latent
    """
    Omega_L = torch.exp(-2j*torch.pi * (torch.arange(L))/L)

    aterm = (torch.conj(C), torch.conj(Q))
    bterm = (B, P)

    g = (2.0/delta) * ((1.0-Omega_L)/(1.0+Omega_L))
    c = 2.0 / (1.0+Omega_L)

    k00 = cauchy(aterm[0] * bterm[0], g, Lambda)
    k01 = cauchy(aterm[0] * bterm[1], g, Lambda)
    k10 = cauchy(aterm[1] * bterm[0], g, Lambda)
    k11 = cauchy(aterm[1] * bterm[1], g, Lambda)

    atRoots = c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)
    out = np.fft.irfft(atRoots, L)
    return out

In [41]:
def random_DPLR(N):
    Lambda = torch.rand((N,))
    P = torch.rand((N,))
    Q = torch.rand((N,))
    B = torch.rand((N, 1))
    C = torch.rand((1, N))
    return Lambda, P, Q, B, C

In [42]:
Lambda, P, Q, B, C = random_DPLR(8)

In [35]:
K_gen_DPLR(Lambda, P, P, B, C, delta, 100)

RuntimeError: The size of tensor a (10) must match the size of tensor b (8) at non-singleton dimension 1

In [43]:
Omega_L = torch.exp(-2j*torch.pi * (torch.arange(100))/100)

aterm = (torch.conj(C).squeeze(), torch.conj(P).squeeze())
bterm = (B.squeeze(), P.squeeze())

g = (2.0/delta) * ((1.0-Omega_L)/(1.0+Omega_L))
c = 2.0 / (1.0+Omega_L)

In [62]:
test[1]

tensor([-0.5348+6.2853j, -0.0200+6.2853j, -0.1802+6.2853j, -0.3887+6.2853j,
        -0.3949+6.2853j, -0.7301+6.2853j, -0.3963+6.2853j, -0.6580+6.2853j])

In [65]:
(aterm[0] * bterm[0]).shape

torch.Size([8])

In [66]:
term1 = (aterm[0] * bterm[0]).unsqueeze(0)

In [63]:
g[1] - Lambda[2]

tensor(-0.1802+6.2853j)

In [55]:
test = g.unsqueeze(1) - Lambda

In [64]:
test.shape

torch.Size([100, 8])

In [69]:
term1[0]/test[0]

tensor([-1.0458-0.j, -2.3463-0.j, -1.6375-0.j, -1.2626-0.j, -0.1732-0.j, -0.0760-0.j,
        -0.2237-0.j, -0.2555-0.j])

In [67]:
term1/test

tensor([[-1.0458e+00-0.0000e+00j, -2.3463e+00-0.0000e+00j,
         -1.6375e+00-0.0000e+00j, -1.2626e+00-0.0000e+00j,
         -1.7315e-01-0.0000e+00j, -7.6041e-02-0.0000e+00j,
         -2.2374e-01-0.0000e+00j, -2.5547e-01-0.0000e+00j],
        [-7.5175e-03-8.8348e-02j, -2.3668e-05-7.4520e-03j,
         -1.3453e-03-4.6915e-02j, -4.8113e-03-7.7792e-02j,
         -6.8067e-04-1.0835e-02j, -1.0123e-03-8.7150e-03j,
         -8.8588e-04-1.4051e-02j, -2.7696e-03-2.6455e-02j],
        [-1.8858e-03-4.4370e-02j, -5.9060e-06-3.7223e-03j,
         -3.3587e-04-2.3449e-02j, -1.2039e-03-3.8969e-02j,
         -1.7033e-04-5.4281e-03j, -2.5512e-04-4.3971e-03j,
         -2.2169e-04-7.0393e-03j, -6.9670e-04-1.3323e-02j],
        [-8.3622e-04-2.9561e-02j, -2.6157e-06-2.4775e-03j,
         -1.4880e-04-1.5609e-02j, -5.3358e-04-2.5950e-02j,
         -7.5495e-05-3.6148e-03j, -1.1322e-04-2.9321e-03j,
         -9.8259e-05-4.6877e-03j, -3.0909e-04-8.8809e-03j],
        [-4.6837e-04-2.2127e-02j, -1.4648e-06-1.8538

In [56]:
test

tensor([[-5.3481e-01+0.0000e+00j, -1.9963e-02+0.0000e+00j,
         -1.8023e-01+0.0000e+00j, -3.8873e-01+0.0000e+00j,
         -3.9485e-01+0.0000e+00j, -7.3006e-01+0.0000e+00j,
         -3.9628e-01+0.0000e+00j, -6.5800e-01+0.0000e+00j],
        [-5.3481e-01+6.2853e+00j, -1.9963e-02+6.2853e+00j,
         -1.8023e-01+6.2853e+00j, -3.8873e-01+6.2853e+00j,
         -3.9485e-01+6.2853e+00j, -7.3006e-01+6.2853e+00j,
         -3.9628e-01+6.2853e+00j, -6.5800e-01+6.2853e+00j],
        [-5.3481e-01+1.2583e+01j, -1.9965e-02+1.2583e+01j,
         -1.8023e-01+1.2583e+01j, -3.8873e-01+1.2583e+01j,
         -3.9485e-01+1.2583e+01j, -7.3006e-01+1.2583e+01j,
         -3.9628e-01+1.2583e+01j, -6.5800e-01+1.2583e+01j],
        [-5.3481e-01+1.8906e+01j, -1.9960e-02+1.8906e+01j,
         -1.8023e-01+1.8906e+01j, -3.8873e-01+1.8906e+01j,
         -3.9485e-01+1.8906e+01j, -7.3006e-01+1.8906e+01j,
         -3.9628e-01+1.8906e+01j, -6.5800e-01+1.8906e+01j],
        [-5.3481e-01+2.5266e+01j, -1.9964e-02+2.5266

In [28]:
B

tensor([[0.9976],
        [0.3359],
        [0.3027],
        [0.2121],
        [0.8180],
        [0.3799],
        [0.3188],
        [0.9348],
        [0.4519],
        [0.7388]], requires_grad=True)

In [29]:
g.shape

NameError: name 'g' is not defined

In [30]:
c.shape

NameError: name 'c' is not defined

In [31]:
def discrete_DPLR(
    Lambda : torch.Tensor,
    P : torch.Tensor,
    Q : torch.Tensor,
    B : torch.Tensor,
    C : torch.Tensor,
    delta : torch.Tensor,
    L : int
)->(torch.Tensor, torch.Tensor, torch.Tensor):
    """
    computes the discretized version of the state space model,
    assuming the DPLR form

    Parameters:
        Lambda : Nx1, represents the diagonal values of the A matrix
        P : Nx1, represents part of the low rank aspect of the A matrix
        Q : Nx1, represents the other part of the low rank aspect of the A matrix
        B : N, projection from input to latent
        C : N, projection from latent to input
        delta : step size
        L : length of window
    """
    Bt = B.unsqueeze(1)
    Ct = C.unsqueeze(0)

    A = (torch.diag(Lambda) - torch.outer(P, torch.conj(Q)))
    A0 = 2.0/delta * torch.eye(A.shape[0]) + A

    Qdagger = torch.conj(torch.transpose(Q))
    
    D = torch.diag(1.0/(2.0/delta - Lambda))
    A1 = (D -  (1.0/(1.0 + Qdagger @ D @ P)) * D@P@Qdagger@D)
    Ab = A1@A0
    Bb = 2 * A1
    Cb = Ct @ torch.conj(torch.linalg.inv(torch.eye(A.shape[0]) - torch.matrix_power(Ab, L)))
    return Ab, Bb, Cb

In [32]:
def make_NPLR_HiPPO(N : int) -> torch.Tensor:
    """
    creating hippo matrix and associated low rank additive component, P
    and the B matrix associated, as hippo forces it

    parameters:
        N : int, degree of legendre polynomial coefficient
    """
    nhippo = make_HiPPO(N)

    P = torch.sqrt(torch.arange(N)+0.5).to(torch.complex64)
    B = torch.sqrt(2*torch.arange(N)+1.0).to(torch.complex64)

    return nhippo.to(torch.complex64), P, B


In [33]:
def make_DPLR_HiPPO(N : int) -> torch.Tensor:
    """
    convert matrices to DPLR representation
    parameters:
        N : int, degree of legendre polynomials
    """
    A, P, B = make_NPLR_HiPPO(N)

    S = A + torch.outer(P, P)

    S_diag = torch.diagonal(S)
    Lambda_real = torch.mean(S_diag) * torch.ones_like(S_diag)

    Lambda_imag, V = torch.linalg.eigh(S * -1j)
    P = V.T.conj() @ P
    B = V.T.conj() @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V

In [70]:
N=8
A2, P, B = make_NPLR_HiPPO(N)
Lambda, Pc, Bc, V = make_DPLR_HiPPO(N)
Vc = V.conj().T
P = P
Pc = Pc
Lambda = torch.diag(Lambda)
A3 = V @ Lambda @ Vc - torch.outer(P,P.conj())  # Test NPLR
A4 = V @ (Lambda - torch.outer(Pc,Pc.conj())) @ Vc  # Test DPLR

assert torch.allclose(A2, A3, atol=1e-4, rtol=1e-4)
assert torch.allclose(A2, A4, atol=1e-4, rtol=1e-4)

In [73]:
B.shape

torch.Size([8])

In [68]:
torch.empty(10,10)
torch.nn.init.normal_(torch.empty(10), 0, 1)

tensor([ 0.1140, -1.3116,  0.7624,  0.3642,  0.8779,  0.8805,  1.0080, -1.9108,
        -0.1149, -1.4994])

In [110]:
class S4Layer(torch.nn.Module):
    """
    Efficient layer for S4Ms. (Structured State Space Sequence Models).
    Implements initialization of A as a NPLR matrix, enabling fast 
    matrix vector multiplication. 

    Several parameters, such as the projection matrix, are learned.

    In this case, the C matrix is actually learned as C(1-A^L). 
    This is fairly easy to undo, and is done in the calc of Cbar.

    Parameters:
        N_input : dimension of input,
        latent_dim : int, dimensions of latent space,
        
    """
    def __init__(
        self,
        N_input : int,
        latent_dim : int,
        dt_min  : float = 0.001,
        dt_max  : float = 0.1,
        step_grad : bool = True
    ):
        super().__init__()
        assert N_input==1


        self.latent_dim = latent_dim
        self.Lambda, self.P, self.B, _ = make_DPLR_HiPPO(self.latent_dim)
        
        self.Lambda = torch.autograd.Variable(self.Lambda, requires_grad = True)
        self.P = torch.autograd.Variable(self.P, requires_grad = True)
        self.B = torch.autograd.Variable(self.B, requires_grad = True)
        
        self.dt = torch.exp(self.log_step_initializer(dt_min, dt_max, step_grad))
        
        Ctilde = torch.nn.init.normal_(torch.empty(self.latent_dim, 2), mean=0, std=0.5**0.5)
        self.Ctilde = torch.autograd.Variable(Ctilde[:,0] + Ctilde[:,1]*1j, requires_grad=True)

        self.D = torch.autograd.Variable(torch.tensor(1), requires_grad = True)

    @staticmethod
    
    @staticmethod
    def make_DPLR_HiPPO(N : int) -> torch.Tensor:
        """
        convert matrices to DPLR representation
        parameters:
            N : int, degree of legendre polynomials
        """
        A, P, B = make_NPLR_HiPPO(N)
    
        S = A + torch.outer(P, P)
    
        S_diag = torch.diagonal(S)
        Lambda_real = torch.mean(S_diag) * torch.ones_like(S_diag)
    
        Lambda_imag, V = torch.linalg.eigh(S * -1j)
        P = V.T.conj() @ P
        B = V.T.conj() @ B
        return Lambda_real + 1j * Lambda_imag, P, B, V

    
    def log_step_initializer(self, dt_min = 0.001, dt_max = 0.1, requires_grad = True):
        """
        initial guess for dt, from random number generator. to be learned.
    
        parameters:
            dt_min
            dt_max
        """
        return torch.autograd.Variable(torch.rand(1) * (torch.log(torch.tensor(dt_max)) - torch.log(torch.tensor(dt_min))) + torch.log(torch.tensor(dt_min)), requires_grad = requires_grad)

    
    @staticmethod
    def cauchy_dot(v, omega, lambd):
        print(v.shape)
        print(omega.shape)
        print(lambd.shape)
        return (v / (omega - lambd)).sum()

    
    def kernel_DPLR(
        self,
        Lambda : torch.Tensor, 
        P : torch.Tensor, 
        Q : torch.Tensor, 
        B: torch.Tensor, 
        C : torch.Tensor, 
        delta : torch.Tensor, 
        L : int
    )-> torch.Tensor:
        """
        computes convolution kernel from generating function using DPLR representation and
        the cauchy kernel with CTilde assumed, not C.
    
        Parameters:
            Lambda : diagonal part of A
            P : Nx1 matrix, rank 1 representation to A
            Q : Nx1 matrix, rank 1 representation to A
            C : 1xN matrix, projection from latent to input
            B : Nx1 matrix, projection from input to latent
        """
        Omega_L = torch.exp(-2j*torch.pi * (torch.arange(L))/L)
    
        aterm = (torch.conj(C), torch.conj(Q))
        bterm = (B, P)
    
        g = (2.0/delta) * ((1.0-Omega_L)/(1.0+Omega_L))
        c = 2.0 / (1.0+Omega_L)

        Lambda = Lambda.unsqueeze(0)
    
        k00 = self.cauchy_dot(aterm[0] * bterm[0], g, Lambda)
        k01 = self.cauchy_dot(aterm[0] * bterm[1], g, Lambda)
        k10 = self.cauchy_dot(aterm[1] * bterm[0], g, Lambda)
        k11 = self.cauchy_dot(aterm[1] * bterm[1], g, Lambda)
    
        atRoots = c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)
        out = np.fft.irfft(atRoots, L)
        return out

    def discrete_DPLR(
        self,
        Lambda : torch.Tensor,
        P : torch.Tensor,
        Q : torch.Tensor,
        B : torch.Tensor,
        C : torch.Tensor,
        delta : torch.Tensor,
        L : int
    )->(torch.Tensor, torch.Tensor, torch.Tensor):
        """
        computes the discretized version of the state space model,
        assuming the DPLR form with Ctilde, not C
    
        Parameters:
            Lambda : Nx1, represents the diagonal values of the A matrix
            P : Nx1, represents part of the low rank aspect of the A matrix
            Q : Nx1, represents the other part of the low rank aspect of the A matrix
            B : N, projection from input to latent
            C : N, projection from latent to input
            delta : step size
            L : length of window
        """
        Bt = B.unsqueeze(1)
        Ct = C.unsqueeze(0)
    
        A = (torch.diag(Lambda) - torch.outer(P, torch.conj(Q)))
        A0 = 2.0/delta * torch.eye(A.shape[0]) + A
    
        Qdagger = torch.conj(Q.T)
        
        D = torch.diag(1.0/(2.0/delta - Lambda))
        A1 = (D -  (1.0/(1.0 + Qdagger @ D @ P)) * D@P@Qdagger@D)
        Ab = A1@A0
        Bb = 2 * A1
        Cb = Ct @ torch.conj(torch.linalg.inv(torch.eye(A.shape[0]) - torch.matrix_power(Ab, L)))
        return Ab, Bb, Cb

    def forward(self, u : torch.Tensor, x0 : torch.Tensor, mode : bool | str):
        L = u.shape[0]
        if mode not in ["recurrent", True, False]:
            raise('mode not valid')
            
        if mode == "recurrent":
            Ab, Bb, Cb = self.discrete_DPLR(self.Lambda, self.P, self.P, self.B, self.C, self.dt, L)
            return self.scan_SSM(Ab, Bb, Cb, self.D, u, x0)[1]
        else:
            if mode:
                warnings.warn("convolving in non-fft mode. this is not recommended, as it is slow for large L")
            K = self.kernel_DPLR(self.Lambda, self.P, self.P, self.B, self.C, self.dt, L)
            return self.causal_conv(u, K, mode)

In [111]:
test_layer = S4Layer(N_input = 1, latent_dim = 10)

In [112]:
test_layer(u, torch.tensor(0), mode=False)

torch.Size([1, 10])
torch.Size([10000])
torch.Size([1, 10])


RuntimeError: The size of tensor a (10000) must match the size of tensor b (10) at non-singleton dimension 1

In [94]:
cauchy

<function __main__.cauchy(v: torch.Tensor, omega: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor>