In [302]:
import torch
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math

from collections.abc import Callable

In [143]:
def random_SSM(N : int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Draws a random single-input/single-output SSM with trainable leaf tensors.

    All entries are sampled uniformly from [0, 1) with ``torch.rand``. The
    tensors are created directly with ``requires_grad=True`` instead of going
    through the deprecated ``torch.autograd.Variable`` wrapper.

    parameters:
        N: size of the latent state

    returns:
        A: (NxN) transition matrix in latent
        B: (Nx1) projection matrix to latent
        C: (1xN) projection matrix from latent to output
        D: (1x1) skip connection from input to output
    """
    # Sampling order (A, B, C, D) is kept so a fixed seed yields the same draws.
    A = torch.rand(size=(N, N), requires_grad=True)
    B = torch.rand(size=(N, 1), requires_grad=True)
    C = torch.rand(size=(1, N), requires_grad=True)
    D = torch.rand(size=(1, 1), requires_grad=True)
    return A, B, C, D

In [144]:
# Draw a random SSM with a 10-dimensional latent state; these globals feed the cells below.
A, B, C, D = random_SSM(10)


In [160]:
# Overwrite the random D returned by random_SSM with a zero (1x1) tensor.
D = torch.zeros((1,1))

In [161]:
# Discretization step size passed to discretize() below.
delta = torch.tensor(0.01)

In [162]:
def discretize(
    A : torch.Tensor, B : torch.Tensor, C : torch.Tensor, D : torch.Tensor, delta : torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Discretizes SSM using bilinear model

    Computes
        Abar = (I - delta/2 * A)^-1 (I + delta/2 * A)
        Bbar = (I - delta/2 * A)^-1 (delta * B)
    while C and D are passed through unchanged.

    parameters:
        A: (NxN) transition matrix in latent
        B: (Nx1) projection matrix to latent
        C: (1xN) projection matrix from latent to output
        D: (1x1) skip connection from input to output
        delta: time step, ensure sufficient smallness

    returns:
        (Abar, Bbar, Cbar, Dbar); Cbar and Dbar are the very same objects as C and D.
    """
    N = A.shape[0]
    # Build the identity with A's dtype/device so non-default inputs do not mismatch.
    identity = torch.eye(N, dtype=A.dtype, device=A.device)
    half_step = (delta / 2) * A
    lhs = identity - half_step
    # Solve the linear systems rather than forming the inverse explicitly:
    # same result, better conditioned and no intermediate inverse matrix.
    Abar = torch.linalg.solve(lhs, identity + half_step)
    Bbar = torch.linalg.solve(lhs, delta * B)
    return Abar, Bbar, C, D

In [163]:
# Discretize the global continuous-time matrices with step `delta`.
Abar, Bbar, Cbar, Dbar = discretize(A, B, C, D, delta)

In [164]:
# Horizon T and the number of steps of size `delta` that fit into it.
T = 100
num_steps = int(T/delta)

# Test input: cosine evaluated at the integer step indices 0..num_steps-1
# (the argument is the sample index, not index * delta).
u = torch.cos(torch.arange(num_steps))

In [204]:
def scan_SSM(
    Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, Db : torch.Tensor,  u : torch.Tensor, x0 : torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    computes steps of the SSM going forward.

    Runs the recurrence
        x_i = Ab @ x_{i-1} + Bb * u_i
        y_i = Cb @ x_i
    over the whole input, starting from the supplied initial state. The
    initial state and the latent size are taken from the arguments (nothing is
    hard-coded), and the full input is processed rather than a fixed prefix.

    parameters:
        Ab : (NxN) transition matrix in discrete space of latent to latent
        Bb : (Nx1) projection matrix from input to latent space
        Cb : (1xN) projection matrix from latent to output
        Db : (1x1) skip connection input to output; accepted for signature
             compatibility but NOT applied here (callers add Db * u themselves)
        u  : (L,)  trajectory we are trying to track
        x0 : (Nx1) initial condition of latent

    returns:
        x : (N, L) latent state after each input sample
        y : (L,)   output at each step
    """
    N = Ab.shape[0]
    state = x0.reshape(N, 1)

    if u.shape[0] == 0:
        # torch.stack cannot handle an empty list; return correctly shaped empties.
        return Ab.new_zeros((N, 0)), Ab.new_zeros((0,))

    # Collect per-step results in lists and stack once at the end. Writing into
    # a preallocated tensor in place while views of it feed later steps breaks
    # autograd's backward pass.
    states, outputs = [], []
    for u_i in u:
        state = Ab @ state + Bb * u_i
        states.append(state.squeeze(-1))
        outputs.append((Cb @ state).reshape(()))
    return torch.stack(states, dim=1), torch.stack(outputs)

In [205]:
def K_conv(Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, L : int) -> torch.Tensor:
    """
    computes convolution window given L time steps using equation K_t = Cb @ (Ab^t) @ Bb.
    Needs to be flipped for correct causal convolution, but can be used as is in fft mode

    The powers of Ab are accumulated iteratively (one matrix-vector product per
    tap) instead of calling ``matrix_power`` from scratch for every t.

    parameters:
        Ab : transition matrix
        Bb : projection matrix from input to latent
        Cb : projection matrix from latent to output
        L  : length over which we want convolutional window

    returns:
        tensor of L kernel taps, K[t] = Cb @ Ab^t @ Bb (empty when L <= 0)
    """
    if L <= 0:
        # torch.stack rejects an empty list; return an empty kernel instead.
        return Cb.new_zeros((0,))

    taps = []
    state = Bb  # holds Ab^t @ Bb, starting at t = 0
    for _ in range(L):
        taps.append((Cb @ state).squeeze())
        state = Ab @ state
    return torch.stack(taps)

In [206]:
def causal_conv(u : torch.Tensor, K : torch.Tensor, notfft : bool = False) -> torch.Tensor:
    """
    1-d causal convolution y[i] = sum_k K[k] * u[i-k], truncated to len(u).

    Two equivalent evaluation paths are available: a direct sliding-window sum
    and a zero-padded real FFT product.

    parameters:
        u : 1-d trajectory to convolve
        K : convolutional filter, same shape as u
        notfft: True selects the direct sliding-window sum, False (default) the FFT path
    """
    assert u.dim() == 1
    assert K.shape == u.shape

    seq_len = u.shape[0]
    # Smallest power of two that is >= 2 * seq_len, used as the FFT length.
    fft_len = 2**int(math.ceil(math.log2(2*seq_len)))

    if not notfft:
        pad = torch.nn.functional.pad
        kernel_spec = torch.fft.rfft(pad(K, (0, seq_len)), n = fft_len)
        signal_spec = torch.fft.rfft(pad(u, (0, seq_len)), n = fft_len)
        linear_conv = torch.fft.irfft(kernel_spec * signal_spec, n = fft_len)
        # Only the first seq_len samples form the causal output.
        return linear_conv[:seq_len]

    # Direct path: slide a window over the zero-padded signal and dot it with
    # the time-reversed kernel.
    padded_signal = torch.nn.functional.pad(u, (seq_len-1, seq_len-1))
    out = torch.zeros_like(u)
    for start in range(seq_len):
        window = padded_signal[start:start+seq_len]
        out[start] = torch.sum(window*K.flip(dims=[0]))
    return out

In [207]:
# Convolution kernel with 100 taps built from the discretized system.
K = K_conv(Abar, Bbar, Cbar, 100)

In [208]:
# Convolve the first 100 input samples with K using both evaluation paths.
conv_fft = causal_conv(u[:100], K)
conv_notfft = causal_conv(u[:100],K , notfft=True)

In [212]:
# Recurrent rollout over the same 100 samples, starting from a zero latent state.
x, y = scan_SSM(Abar, Bbar, Cbar, Dbar, u[:100], torch.zeros((10,1)))

In [224]:
# Consistency checks (elementwise tolerance 1e-5): FFT vs direct convolution,
# and convolution vs the recurrent scan output.
print((abs(conv_fft - conv_notfft)<1e-5).all())
print((abs(conv_fft - y)<1e-5).all())

tensor(True)
tensor(True)


In [241]:
# Scratch cell: shows what a single torch.rand(1) draw looks like.
torch.rand(1)

tensor([0.8434])

In [242]:
def log_step_initializer(dt_min = 0.001, dt_max = 0.1):
    """
    initial guess for log(dt), from random number generator. to be learned.

    Samples uniformly in log space between log(dt_min) and log(dt_max), so the
    returned value is the LOG of the step size; apply ``torch.exp`` to obtain dt.

    parameters:
        dt_min : smallest step size (positive number)
        dt_max : largest step size (positive number)

    returns:
        leaf tensor of shape (1,) with requires_grad=True
    """
    # The bounds are plain Python numbers, and torch.log only accepts tensors,
    # so the logarithms are taken with math.log.
    log_min = math.log(dt_min)
    log_max = math.log(dt_max)
    log_step = torch.rand(1) * (log_max - log_min) + log_min
    return log_step.requires_grad_()

In [249]:
class SSMLayer(torch.nn.Module):
    """
    Simple layer that does SSMing. Assumes single input, single output. 
    Could be made multi-dimensional either by stacking and decorrelating,
    or by playing with the code to allow for multi input, multioutput. Should be relatively easy, 
    but need to carefully think a little about convolution of multi dim inputs.
    """
    def __init__(
        self,
        latent_dim,
        L_max,
        dt_min = 0.001,
        dt_max = 0.1,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.A, self.B, self.C, self.D = self.random_SSM(latent_dim)
        self.Abar, self.Bbar, self.Cbar, self.Dbar = self.discretize(self.A, self.B, self.C, self.D, self.dt)
        self.dt = self.log_step_initializer(dt_min, dt_max)

    def random_SSM(
        self, 
        N : int
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """
        initializing SSM parameters given latent dim
        
        parameters:
            N : size of latent dimension
        """
        A = torch.autograd.Variable(torch.rand(size=(N,N)), requires_grad = True)
        B = torch.autograd.Variable(torch.rand(size=(N,1)), requires_grad = True)
        C = torch.autograd.Variable(torch.rand(size=(1,N)), requires_grad = True)
        D = torch.autograd.Variable(torch.rand(size=(1,1)), requires_grad = True)
        return A, B, C, D

    def log_step_initializer(self, dt_min = 0.001, dt_max = 0.1):
        """
        initial guess for dt, from random number generator. to be learned.
    
        parameters:
            dt_min
            dt_max
        """
        return torch.autograd.Variable(torch.rand(1) * (torch.log(dt_max) - torch.log(dt_min)) + torch.log(dt_min), requires_grad = True)

    def discretize(
        self, A : torch.Tensor, B : torch.Tensor, C : torch.Tensor, D : torch.Tensor, delta : torch.Tensor
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """Discretizes SSM using bilinear model
    
        parameters:
            A: (NxN) transition matrix in latent
            B: (Nx1) projection matrix to latent
            C: (1xN) projection matrix from latent to output
            D: (1x1) skip connection from input to output
            delta: time step, ensure sufficient smallness
        """
        Cbar = C
        Dbar = D
        N = A.shape[0]
        Bl = torch.linalg.inv(torch.eye(N) - delta / 2 * A)
        Abar = Bl@(torch.eye(N) + delta/2 * A)
        Bbar = Bl@(delta*B)
        return Abar, Bbar, Cbar, Dbar

    def scan_SSM(
        self, Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, Db : torch.Tensor,  u : torch.Tensor, x0 : torch.Tensor
    ) -> torch.Tensor:
        """
        computes steps of the SSM going forward.
    
        parameters:
            Ab : (NxN) transition matrix in discrete space of latent to latent
            Bb : (Nx1) projcetion matrix from input to latent space
            Cb : (1xN) projection matrix from latent to output
            Db : (1x1) skip connection input to output
            u  : (L,)  trajectory we are trying to track
            x0 : (Nx1) initial condition of latent
        """
        x0 = torch.zeros((10,1))
        x = torch.zeros((Ab.shape[0], len(u[:100])))
        y = torch.zeros_like(u[:100])
        for i in range(u[:100].shape[0]):
            x[:,i] = (Ab@x0 + Bb*u[i]).squeeze()
            y[i] = (Cb@x[:,i]).squeeze()
            x0 = x[:,i].unsqueeze(-1)
        return x, y
        
    def K_conv(self, Ab : torch.Tensor, Bb : torch.Tensor, Cb : torch.Tensor, L : int) -> torch.Tensor:
        """
        computes convolution window given L time steps using equation K_t = Cb @ (Ab^t) @ Bb. 
        Needs to be flipped for correct causal convolution, but can be used as is in fft mode
    
        parameters:
            Ab : transition matrix
            Bb : projection matrix from input to latent
            Cb : projection matrix from latent to input
            Db : skip connection
            L  : length over which we want convolutional window
        """
        return torch.stack([(Cb @ torch.matrix_power(Ab, l) @ Bb).squeeze() for l in range(L)])

    def causal_conv(u : torch.Tensor, K : torch.Tensor, notfft : bool = False) -> torch.Tensor:
        """
        computes 1-d causal convolution either using standard method or fft transform.
    
        parameters:
            u : trajectory to convolve
            K : convolutional filter
            notfft: boolean, for whether or not we use fft mode or not.
        """
        assert K.shape==u.shape
        
        L = u.shape[0]
        powers_of_2 = 2**int(math.ceil(math.log2(2*L)))
    
        if notfft:
            padded_u = torch.nn.functional.pad(u, (L-1,L-1))
            convolve = torch.zeros_like(u)
            for i in range(L):
                convolve[i] = torch.sum(padded_u[i:i+L]*K.flip(dims=[0]))
            return convolve
        else:
    
            K_pad = torch.nn.functional.pad(K, (0, L))
            u_pad = torch.nn.functional.pad(u, (0, L))
            
            K_f, u_f = torch.fft.rfft(K_pad, n = powers_of_2), torch.fft.rfft(u_pad, n = powers_of_2)
            return torch.fft.irfft(K_f * u_f, n = powers_of_2)[:L]

    def forward(
        self,
        u : torch.Tensor,
        x0 : torch.Tensor = torch.zeros((1,1)),
        mode : bool | str = False
    ) -> torch.Tensor:
        """
        forward pass of model

        Parameters:
            u  : input time series
            x0 : initial condition, only used in recurrent mode
            mode: recurrent mode ("recurrent"), or convolution mode (True : direct convolution, False : fourier transform)
        """
        if mode == "recurrent":
            return self.scan_SSM(self.Abar, self.Bbar, self.Cbar, u, x0)[1]
        else:
            K = self.K_conv(self.Abar, self.Bbar, self.Cbar, u.shape[0])
            return self.causal_conv(u, K, mode)

In [300]:
def make_HiPPO(N : int) -> torch.Tensor:
    """
    creates HiPPO matrix for legendre polynomials up to order N

    Entries (0-based n, k):
        A[n, k] = -sqrt(2n+1) * sqrt(2k+1)   if n > k
        A[n, n] = -(n + 1)
        A[n, k] = 0                          if n < k

    parameters:
        N: int, size of the (NxN) matrix

    returns:
        (NxN) lower-triangular tensor with strictly negative diagonal
    """
    orders = torch.arange(N)
    P = torch.sqrt(1 + 2 * orders)
    # Outer product gives sqrt(2n+1) * sqrt(2k+1); its diagonal is 2n+1, and
    # subtracting diag(n) leaves n+1 there.
    A = torch.tril(P.unsqueeze(1) * P.unsqueeze(0)) - torch.diag(orders)
    # The HiPPO matrix is the NEGATIVE of this construction; without the sign
    # flip all eigenvalues are positive and the continuous system is unstable.
    return -A

In [307]:
def K_gen_inverse(
    Abar : torch.Tensor, Bbar : torch.Tensor, Cbar : torch.Tensor, L : int
) -> torch.Tensor:
    """
    builds the truncated generating function of the convolution kernel

    The returned callable maps z to
        conj(Cbar @ (I - Abar^L)) @ (I - Abar * z)^-1 @ Bbar
    and is meant to be evaluated at roots of unity. Note that a callable is
    returned, not a tensor.

    parameters:
        Abar : discretized A matrix
        Bbar : discretized B matrix
        Cbar : discretized C matrix
        L    : length of convolutional window
    """
    complex_dtype = torch.complex64
    A_c, B_c, C_c = (mat.to(complex_dtype) for mat in (Abar, Bbar, Cbar))

    identity = torch.eye(A_c.shape[0]).to(complex_dtype)
    # Truncation factor (I - Abar^L) folded into the output projection once.
    C_truncated = C_c @ (identity - torch.matrix_power(A_c, L))

    def generating_function(z):
        resolvent = torch.linalg.inv(identity - A_c * z)
        return (torch.conj(C_truncated) @ resolvent @ B_c).squeeze()

    return generating_function

In [308]:
# Generating function of the length-10 kernel for the discretized system.
blah = K_gen_inverse(Abar, Bbar, Cbar, 10)

In [309]:
# Smoke test: evaluate the generating function at z = 10.
# NOTE(review): the saved output is real-valued although K_gen_inverse casts to
# complex64 — the output looks stale; re-run to confirm.
blah(10)

tensor(0.0019, grad_fn=<SqueezeBackward0>)

In [314]:
def conv_from_gen(gen : Callable, L : int):
    """
    returns convolution from generating function by evaluating at roots of unity

    The generating function is sampled at omega_k = exp(-2*pi*i*k/L) and the
    samples are passed to an inverse real FFT of output length L. The samples
    are combined with ``torch.stack`` so gradients flow back through ``gen``.

    parameters:
        gen : generating function; called with a 0-dim complex tensor and
              expected to return a single (scalar) value
        L   : int, length of the convolution kernel (must be positive)

    returns:
        real tensor of length L (squeezed, so 0-dim when L == 1)

    raises:
        ValueError: if L is not positive
    """
    if L <= 0:
        raise ValueError(f"L must be a positive integer, got {L}")

    # irfft with n=L only consumes the first L//2 + 1 frequency samples, so
    # only those roots of unity are evaluated.
    n_freq = L // 2 + 1
    omega_L = torch.exp(-2j * torch.pi * torch.arange(n_freq) / L)
    atRoots = torch.stack(
        [torch.as_tensor(gen(omega)).reshape(()) for omega in omega_L]
    )
    return torch.fft.irfft(atRoots, L).squeeze()
    

In [315]:
# Rebuild the generating function for a length-100 kernel (rebinds `blah`).
blah = K_gen_inverse(Abar, Bbar, Cbar, 100)

In [316]:
# Recover the length-100 kernel from the generating function.
# NOTE(review): the saved output of this cell is a RuntimeError (float vs complex
# dtype mismatch); re-run on a fresh kernel to check whether it still reproduces.
conv_from_gen(blah, 100)

RuntimeError: expected m1 and m2 to have the same dtype, but got: float != c10::complex<float>

In [311]:
# Display the FFT-mode convolution result for visual comparison.
conv_fft

tensor([ 3.2026e-02,  5.0786e-02,  3.9775e-02,  9.8948e-03, -1.0565e-02,
        -1.9345e-03,  2.8752e-02,  5.4227e-02,  5.2064e-02,  2.5292e-02,
        -3.7849e-04, -1.9813e-04,  2.6874e-02,  5.7214e-02,  6.4257e-02,
         4.2923e-02,  1.4292e-02,  6.2244e-03,  2.7753e-02,  6.0778e-02,
         7.6717e-02,  6.2783e-02,  3.3747e-02,  1.8364e-02,  3.2938e-02,
         6.6338e-02,  9.0239e-02,  8.5166e-02,  5.8408e-02,  3.7322e-02,
         4.4187e-02,  7.5728e-02,  1.0613e-01,  1.1080e-01,  8.8944e-02,
         6.4354e-02,  6.3508e-02,  9.1249e-02,  1.2634e-01,  1.4100e-01,
         1.2645e-01,  1.0101e-01,  9.3246e-02,  1.1574e-01,  1.5352e-01,
         1.7785e-01,  1.7265e-01,  1.4932e-01,  1.3624e-01,  1.5272e-01,
         1.9126e-01,  2.2445e-01,  2.3020e-01,  2.1207e-01,  1.9601e-01,
         2.0654e-01,  2.4420e-01,  2.8512e-01,  3.0295e-01,  2.9313e-01,
         2.7713e-01,  2.8270e-01,  3.1842e-01,  3.6582e-01,  3.9643e-01,
         3.9794e-01,  3.8560e-01,  3.8822e-01,  4.2

In [280]:
# NOTE: rebinds the global `A` (previously the random SSM matrix); Abar/Bbar
# computed in earlier cells are not refreshed by this assignment.
A = make_HiPPO(10)

In [283]:
# Spot-check one below-diagonal entry (row 6, column 3).
A[6,3]

tensor(-9.5394)

In [284]:
# Display the full matrix to inspect its lower-triangular structure.
A

tensor([[ -1.0000,  -0.0000,  -0.0000,  -0.0000,  -0.0000,  -0.0000,  -0.0000,
          -0.0000,  -0.0000,  -0.0000],
        [ -1.7321,  -2.0000,  -0.0000,  -0.0000,  -0.0000,  -0.0000,  -0.0000,
          -0.0000,  -0.0000,  -0.0000],
        [ -2.2361,  -3.8730,  -3.0000,  -0.0000,  -0.0000,  -0.0000,  -0.0000,
          -0.0000,  -0.0000,  -0.0000],
        [ -2.6458,  -4.5826,  -5.9161,  -4.0000,  -0.0000,  -0.0000,  -0.0000,
          -0.0000,  -0.0000,  -0.0000],
        [ -3.0000,  -5.1962,  -6.7082,  -7.9373,  -5.0000,  -0.0000,  -0.0000,
          -0.0000,  -0.0000,  -0.0000],
        [ -3.3166,  -5.7446,  -7.4162,  -8.7750,  -9.9499,  -6.0000,  -0.0000,
          -0.0000,  -0.0000,  -0.0000],
        [ -3.6056,  -6.2450,  -8.0623,  -9.5394, -10.8167, -11.9583,  -7.0000,
          -0.0000,  -0.0000,  -0.0000],
        [ -3.8730,  -6.7082,  -8.6603, -10.2470, -11.6189, -12.8452, -13.9642,
          -8.0000,  -0.0000,  -0.0000],
        [ -4.1231,  -7.1414,  -9.2195, -10.9087,

In [275]:
# P[n] = sqrt(2n + 1) for n = 0..9, the factor used inside make_HiPPO.
P = torch.sqrt(1+2*torch.arange(10))

In [294]:
# Inline rebuild of the lower-triangular outer product of P minus diag(0..9),
# with no sign flip applied, for entry-by-entry inspection.
test = torch.tril(P.unsqueeze(1) * P.unsqueeze(0)) - torch.diag(torch.arange(10))

In [295]:
# Spot-check entry (9, 8) against the closed form computed in the next cell.
test[9,8]

tensor(17.9722)

In [292]:
# Closed form for entry (9, 8): sqrt(2*9 + 1) * sqrt(2*8 + 1).
(2*9+1)**(1/2) * (2*8+1)**(1/2)

17.972200755611432

In [293]:
# Display the un-negated matrix built above.
test

tensor([[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 1.7321,  2.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 2.2361,  3.8730,  3.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 2.6458,  4.5826,  5.9161,  4.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 3.0000,  5.1962,  6.7082,  7.9373,  5.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 3.3166,  5.7446,  7.4162,  8.7750,  9.9499,  6.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 3.6056,  6.2450,  8.0623,  9.5394, 10.8167, 11.9583,  7.0000,  0.0000,
          0.0000,  0.0000],
        [ 3.8730,  6.7082,  8.6603, 10.2470, 11.6189, 12.8452, 13.9642,  8.0000,
          0.0000,  0.0000],
        [ 4.1231,  7.1414,  9.2195, 10.9087, 12.3693, 13.6748, 14.8661, 15.9687,
          9.0000,  0.0000],
        [ 4.3589,  