In [1]:
import torch
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math

from collections.abc import Callable

import warnings

In [16]:
import torch
import math

class SSMLayer(torch.nn.Module):
    """
    Simple layer that does SSMing. Assumes single input, single output. 
    Could be made multi-dimensional either by stacking and decorrelating,
    or by playing with the code to allow for multi input, multioutput. Should be relatively easy, 
    but need to carefully think a little about convolution of multi dim inputs.
    """
    def __init__(
        self,
        latent_dim,
        dt_min = torch.tensor(0.001),
        dt_max = torch.tensor(0.1),
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.A, self.B, self.C, self.D = self.random_SSM(latent_dim)
        self.dt = torch.exp(self.log_step_initializer(dt_min, dt_max))
        self.Abar, self.Bbar, self.Cbar, self.Dbar = self.discretize(self.A, self.B, self.C, self.D, self.dt)


    def random_SSM(
        self, 
        N : int
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """
        initializing SSM parameters given latent dim
        
        parameters:
            N : size of latent dimension
        """
        A = torch.autograd.Variable(torch.rand(size=(N,N)), requires_grad = True)
        B = torch.autograd.Variable(torch.rand(size=(N,1)), requires_grad = True)
        C = torch.autograd.Variable(torch.rand(size=(1,N)), requires_grad = True)
        D = torch.autograd.Variable(torch.rand(size=(1,1)), requires_grad = True)
        return A, B, C, D

    def log_step_initializer(self, dt_min = 0.001, dt_max = 0.1):
        """
        initial guess for dt, from random number generator. to be learned.
    
        parameters:
            dt_min
            dt_max
        """
        return torch.autograd.Variable(torch.rand(1) * (torch.log(dt_max) - torch.log(dt_min)) + torch.log(dt_min), requires_grad = True)

    def discretize(
        self, A : torch.Tensor, B : torch.Tensor, C : torch.Tensor, D : torch.Tensor, delta : torch.Tensor
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """Discretizes SSM using bilinear model
    
        parameters:
            A: (NxN) transition matrix in latent
            B: (Nx1) projection matrix to latent
            C: (1xN) projection matrix from latent to output
            D: (1x1) skip connection from input to output
            delta: time step, ensure sufficient smallness
        """
        Cbar = C
        Dbar = D
        N = A.shape[0]
        Bl = torch.linalg.inv(torch.eye(N) - delta / 2 * A)
        Abar = Bl@(torch.eye(N) + delta/2 * A)
        Bbar = Bl@(delta*B)
        return Abar, Bbar, Cbar, Dbar

    def scan_SSM(
        self, Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, Db : torch.Tensor,  u : torch.Tensor, x0 : torch.Tensor
    ) -> torch.Tensor:
        """
        computes steps of the SSM going forward.
    
        parameters:
            Ab : (NxN) transition matrix in discrete space of latent to latent
            Bb : (Nx1) projcetion matrix from input to latent space
            Cb : (1xN) projection matrix from latent to output
            Db : (1x1) skip connection input to output
            u  : (L,)  trajectory we are trying to track
            x0 : (Nx1) initial condition of latent
        """
        x0 = torch.zeros((10,1))
        x = torch.zeros((Ab.shape[0], len(u[:100])))
        y = torch.zeros_like(u[:100])
        for i in range(u[:100].shape[0]):
            x[:,i] = (Ab@x0 + Bb*u[i]).squeeze()
            y[i] = (Cb@x[:,i]).squeeze()
            x0 = x[:,i].unsqueeze(-1)
        return x, y + Db*u
        
    def K_conv(self, Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, L : int) -> torch.Tensor:
        """
        computes convolution window given L time steps using equation K_t = Cb @ (Ab^t) @ Bb. 
        Needs to be flipped for correct causal convolution, but can be used as is in fft mode
    
        parameters:
            Ab : transition matrix
            Bb : projection matrix from input to latent
            Cb : projection matrix from latent to input
            Db : skip connection
            L  : length over which we want convolutional window
        """
        return torch.stack([(Cb @ torch.matrix_power(Ab, l) @ Bb).squeeze() for l in range(L)])

    def causal_conv(self, u : torch.Tensor, K : torch.Tensor, notfft : bool = False) -> torch.Tensor:
        """
        computes 1-d causal convolution either using standard method or fft transform.
    
        parameters:
            u : trajectory to convolve
            K : convolutional filter
            notfft: boolean, for whether or not we use fft mode or not.
        """
        assert K.shape==u.shape
        
        L = u.shape[0]
        powers_of_2 = 2**int(math.ceil(math.log2(2*L)))
    
        if notfft:
            padded_u = torch.nn.functional.pad(u, (L-1,L-1))
            convolve = torch.zeros_like(u)
            for i in range(L):
                convolve[i] = torch.sum(padded_u[i:i+L]*K.flip(dims=[0]))
            return convolve
        else:
    
            K_pad = torch.nn.functional.pad(K, (0, L))
            u_pad = torch.nn.functional.pad(u, (0, L))
            
            K_f, u_f = torch.fft.rfft(K_pad, n = powers_of_2), torch.fft.rfft(u_pad, n = powers_of_2)
            return torch.fft.irfft(K_f * u_f, n = powers_of_2)[:L]

    def forward(
        self,
        u : torch.Tensor,
        x0 : torch.Tensor = torch.zeros((1,1)),
        mode : bool | str = False
    ) -> torch.Tensor:
        """
        forward pass of model

        Parameters:
            u  : input time series
            x0 : initial condition, only used in recurrent mode
            mode: recurrent mode ("recurrent"), or convolution mode (True : direct convolution, False : fourier transform)
        """
        if mode == "recurrent":
            return self.scan_SSM(self.Abar, self.Bbar, self.Cbar, self.Dbar, u, x0)[1]
        else:
            K = self.K_conv(self.Abar, self.Bbar, self.Cbar, u.shape[0])
            return self.causal_conv(u, K, mode) + self.D*u
        

In [7]:
ssm = SSMLayer(latent_dim=10)

In [8]:
T = 100
num_steps = int(T/torch.exp(ssm.dt))


u = torch.cos(torch.arange(num_steps))

In [15]:
ssm(u[:100], torch.tensor(0), mode = "recurrent") - ssm(u[:100], torch.tensor(0), mode = True)

tensor([[ 0.0000e+00,  0.0000e+00, -3.7253e-09, -1.4901e-08,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -7.4506e-09,  0.0000e+00,
         -1.4901e-08, -1.1176e-08,  0.0000e+00, -2.9802e-08, -4.4703e-08,
         -7.4506e-09, -1.4901e-08,  3.7253e-09,  0.0000e+00,  0.0000e+00,
         -2.9802e-08, -5.9605e-08, -5.9605e-08,  0.0000e+00, -1.4901e-08,
          5.9605e-08, -5.9605e-08, -1.3411e-07, -8.9407e-08,  1.4901e-08,
         -5.9605e-08, -1.1921e-07, -1.1921e-07, -1.7881e-07,  0.0000e+00,
         -8.9407e-08, -1.4901e-07, -5.9605e-08, -1.7881e-07, -1.7881e-07,
         -1.1921e-07, -2.3842e-07, -1.1921e-07, -2.9802e-07,  0.0000e+00,
          1.1921e-07, -2.3842e-07, -1.1921e-07, -1.7881e-07, -3.5763e-07,
          0.0000e+00, -4.7684e-07, -2.3842e-07,  0.0000e+00,  0.0000e+00,
         -7.1526e-07, -7.1526e-07, -4.7684e-07, -7.1526e-07, -2.3842e-07,
          2.3842e-07, -7.1526e-07, -1.4305e-06, -1.9073e-06, -2.3842e-06,
         -9.5367e-07,  0.0000e+00,  0.

In [14]:
ssm(u[:100], torch.tensor(0), mode = False) - ssm(u[:100], torch.tensor(0), mode = True)

tensor([[-8.1658e-06, -4.2766e-06, -1.6805e-05, -1.6332e-05, -2.3991e-06,
         -6.5565e-07,  6.2287e-06,  3.3379e-06, -1.2256e-05, -6.4373e-06,
          6.6757e-06,  1.3277e-05,  2.6822e-06,  5.6326e-06,  2.2948e-06,
          8.8215e-06,  1.8612e-05,  1.7304e-05,  1.4305e-05,  7.0930e-06,
         -6.5267e-06, -3.2037e-06,  4.2319e-06,  1.0133e-05,  1.6272e-05,
          1.3500e-05, -9.1195e-06, -7.7039e-06,  2.2352e-06,  2.6524e-06,
          9.6858e-06,  6.0797e-06,  5.1856e-06,  2.3842e-07, -9.2089e-06,
          3.6955e-06,  1.1057e-05,  1.0252e-05,  7.9274e-06,  8.2254e-06,
          0.0000e+00,  1.1563e-05,  1.0788e-05,  1.9014e-05,  1.8716e-05,
          1.5855e-05,  1.1325e-05,  1.3590e-05,  1.2457e-05,  1.0848e-05,
          1.4544e-05,  8.5831e-06,  0.0000e+00, -2.6226e-06,  3.3379e-06,
          8.5831e-06,  1.7405e-05,  1.6212e-05, -3.0994e-06, -4.5300e-06,
         -5.0068e-06,  3.0994e-06,  6.1989e-06,  1.3351e-05,  3.5763e-06,
         -2.8610e-06, -5.2452e-06,  8.

In [13]:
ssm(u[:100], torch.tensor(0), mode = True)

tensor([[ 3.0553e-01,  2.0511e-01, -6.2344e-02, -2.4922e-01, -1.8186e-01,
          7.9826e-02,  2.9742e-01,  2.7322e-01,  3.2046e-02, -2.0160e-01,
         -2.0990e-01,  1.8045e-02,  2.7619e-01,  3.3103e-01,  1.3630e-01,
         -1.2446e-01, -2.0664e-01, -2.9377e-02,  2.5008e-01,  3.8102e-01,
          2.4980e-01, -1.5628e-02, -1.6331e-01, -4.8881e-02,  2.3177e-01,
          4.3070e-01,  3.7596e-01,  1.2974e-01, -6.8727e-02, -2.3030e-02,
          2.3993e-01,  4.9478e-01,  5.2498e-01,  3.2201e-01,  9.3377e-02,
          7.1912e-02,  3.0189e-01,  5.9849e-01,  7.1785e-01,  5.8151e-01,
          3.4873e-01,  2.7027e-01,  4.5813e-01,  7.8279e-01,  9.9260e-01,
          9.4544e-01,  7.3973e-01,  6.2429e-01,  7.6996e-01,  1.1130e+00,
          1.4140e+00,  1.4788e+00,  1.3372e+00,  1.2163e+00,  1.3323e+00,
          1.6926e+00,  2.0894e+00,  2.2918e+00,  2.2590e+00,  2.1785e+00,
          2.2950e+00,  2.6865e+00,  3.1937e+00,  3.5677e+00,  3.7005e+00,
          3.7258e+00,  3.8974e+00,  4.

In [11]:
ssm.Cbar@(ssm.Bbar * u[0]) + ssm.D * u[0]

tensor([[0.3055]], grad_fn=<AddBackward0>)

torch.Size([1, 10])