In [1]:
import torch
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math

from collections.abc import Callable

import warnings

In [57]:
class S4Layer(torch.nn.Module):
    """
    Efficient layer for S4Ms. (Structured State Space Sequence Models).
    Implements initialization of A as a NPLR matrix, enabling fast 
    matrix vector multiplication. 

    Several parameters, such as the projection matrix, are learned.

    In this case, the C matrix is actually learned as C(1-A^L). 
    This is fairly easy to undo, and is done in the calc of Cbar.

    Parameters:
        N_input : dimension of input,
        latent_dim : int, dimensions of latent space,
        
    """
    def __init__(
        self,
        N_input : int,
        latent_dim : int,
        dt_min  : torch.Tensor = torch.tensor(0.001),
        dt_max  : torch.Tensor = torch.tensor(0.1),
        step_grad : bool = True
    ):
        super().__init__()
        assert N_input==1


        self.latent_dim = latent_dim
        self.Lambda, self.P, self.B, _ = self.make_DPLR_HiPPO(self.latent_dim)
        
        self.Lambda = torch.autograd.Variable(self.Lambda, requires_grad = True)
        self.P = torch.autograd.Variable(self.P, requires_grad = True)
        self.B = torch.autograd.Variable(self.B, requires_grad = True)
        
        self.dt = torch.exp(self.log_step_initializer(dt_min, dt_max, step_grad))
        
        Ctilde = torch.nn.init.normal_(torch.empty(self.latent_dim, 2), mean=0, std=0.5**0.5)
        self.Ctilde = torch.autograd.Variable(Ctilde[:,0] + Ctilde[:,1]*1j, requires_grad=True)

        self.D = torch.autograd.Variable(torch.tensor(1.), requires_grad = True)

    def make_HiPPO(self, N : int) -> torch.Tensor:
        """
        creates HiPPO matrix for legendre polynomials up to order N
        parameters:
            N: int
        """
        P = torch.sqrt(1+2*torch.arange(N))
        A = P.unsqueeze(1) * P.unsqueeze(0)
        A = torch.tril(A) - torch.diag(torch.arange(N))
        return -A
        
    def make_NPLR_HiPPO(self, N : int) -> torch.Tensor:
        """
        creating hippo matrix and associated low rank additive component, P
        and the B matrix associated, as hippo forces it
    
        parameters:
            N : int, degree of legendre polynomial coefficient
        """
        nhippo = self.make_HiPPO(N)
    
        P = torch.sqrt(torch.arange(N)+0.5).to(torch.complex64)
        B = torch.sqrt(2*torch.arange(N)+1.0).to(torch.complex64)
    
        return nhippo.to(torch.complex64), P, B

    def make_DPLR_HiPPO(self, N : int) -> torch.Tensor:
        """
        convert matrices to DPLR representation
        parameters:
            N : int, degree of legendre polynomials
        """
        A, P, B = self.make_NPLR_HiPPO(N)
    
        S = A + torch.outer(P, P)
    
        S_diag = torch.diagonal(S)
        Lambda_real = torch.mean(S_diag) * torch.ones_like(S_diag)
    
        Lambda_imag, V = torch.linalg.eigh(S * -1j)
        P = V.T.conj() @ P
        B = V.T.conj() @ B
        return Lambda_real + 1j * Lambda_imag, P, B, V

    
    def log_step_initializer(self, dt_min = torch.tensor(0.001), dt_max = torch.tensor(0.1), requires_grad = True):
        """
        initial guess for dt, from random number generator. to be learned.
    
        parameters:
            dt_min
            dt_max
        """
        return torch.autograd.Variable(torch.rand(1) * (torch.log(dt_max) - torch.log(dt_min)) + torch.log(dt_min), requires_grad = requires_grad)

    
    def K_gen_DPLR(
        self,
        Lambda : torch.Tensor, 
        P : torch.Tensor, 
        Q : torch.Tensor, 
        B: torch.Tensor, 
        C : torch.Tensor, 
        delta : torch.Tensor, 
        L : int
    )-> torch.Tensor:
        """
        computes convolution kernel from generating function using DPLR representation and
        the cauchy kernel
    
        Parameters:
            Lambda : diagonal part of DPLR
            P : N matrix, rank 1 representation to A
            Q : N matrix, rank 1 representation to A
            C : N matrix, projection from latent to input
            B : N matrix, projection from input to latent
        """
        Omega_L = torch.exp(-2j*torch.pi * (torch.arange(L))/L)
    
        aterm = (torch.conj(C), torch.conj(Q))
        bterm = (B, P)
    
        g = (2.0/delta) * ((1.0-Omega_L)/(1.0+Omega_L))
        c = 2.0 / (1.0+Omega_L)
    
        k00 = self.cauchy(aterm[0] * bterm[0].unsqueeze(0), g.unsqueeze(1), Lambda)
        k01 = self.cauchy(aterm[0] * bterm[1].unsqueeze(0), g.unsqueeze(1), Lambda)
        k10 = self.cauchy(aterm[1] * bterm[0].unsqueeze(0), g.unsqueeze(1), Lambda)
        k11 = self.cauchy(aterm[1] * bterm[1].unsqueeze(0), g.unsqueeze(1), Lambda)
    
        atRoots = c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)
        out = torch.fft.irfft(atRoots, L)
        return out

    
    def cauchy(self, k : torch.Tensor, omega : torch.Tensor, lambd : torch.Tensor):
        """
        computes cauchy kernel 
        sum(c_i * b_i/(z - lambda_i)

        Parameters:
            k : term by term dot product of vectors
            omega : function of the roots of unity
            lambd: diagonal parts of the DPLR matrix
        """
        return torch.sum(k/(omega-lambd), axis=1)

    def discrete_DPLR(
        self,
        Lambda : torch.Tensor,
        P : torch.Tensor,
        Q : torch.Tensor,
        B : torch.Tensor,
        C : torch.Tensor,
        delta : torch.Tensor,
        L : int
    )->(torch.Tensor, torch.Tensor, torch.Tensor):
        """
        computes the discretized version of the state space model,
        assuming the DPLR form
    
        Parameters:
            Lambda : Nx1, represents the diagonal values of the A matrix
            P : Nx1, represents part of the low rank aspect of the A matrix
            Q : Nx1, represents the other part of the low rank aspect of the A matrix
            B : N, projection from input to latent
            C : N, projection from latent to input
            delta : step size
            L : length of window
        """
        Bt = B.unsqueeze(1)
        Ct = C.unsqueeze(0)
    
        A = (torch.diag(Lambda) - torch.outer(P, torch.conj(Q)))
        A0 = 2.0/delta * torch.eye(A.shape[0]) + A
    
        Qdagger = torch.conj(torch.transpose(Q))
        
        D = torch.diag(1.0/(2.0/delta - Lambda))
        A1 = (D -  (1.0/(1.0 + Qdagger @ D @ P)) * D@P@Qdagger@D)
        Ab = A1@A0
        Bb = 2 * A1
        Cb = Ct @ torch.conj(torch.linalg.inv(torch.eye(A.shape[0]) - torch.matrix_power(Ab, L)))
        return Ab, Bb, Cb.conj()

    def forward(self):
        pass

In [58]:
test = S4Layer(1, 10)

In [64]:
bloop = test.K_gen_DPLR(test.Lambda, test.P, test.P, test.B, test.Ctilde, test.dt, 100)

In [60]:
test.Ctilde.shape

torch.Size([10])

In [61]:
def discrete_DPLR(Lambda, P, Q, B, C, step, L):
    # Convert parameters to matrices
    B = B.unsqueeze(1)
    Ct = C.unsqueeze(0)

    N = Lambda.shape[0]
    A = torch.diag(Lambda) - torch.outer(P, Q.conj())
    I = torch.eye(N)

    # Forward Euler
    A0 = (2.0 / step) * I + A

    # Backward Euler
    D = torch.diag(1.0 / ((2.0 / step) - Lambda))
    Qc = Q.conj().T.reshape(1, -1)
    P2 = P.reshape(-1, 1)
    A1 = D - (D @ P2 * (1.0 / (1 + (Qc @ D @ P2))) * Qc @ D)

    # A bar and B bar
    Ab = A1 @ A0
    Bb = 2 * A1 @ B

    # Recover Cbar from Ct
    Cb = Ct @ torch.linalg.inv(I - torch.matrix_power(Ab, L)).conj()
    return Ab, Bb, Cb.conj()

In [62]:
Ab, Bb, Cb = discrete_DPLR(test.Lambda, test.P, test.P, test.B, test.Ctilde, test.dt, 100)

In [69]:
bloop[2]

tensor(0.0012-0.0018j, grad_fn=<SelectBackward0>)

In [68]:
Cb@torch.matrix_power(Ab, 2)@Bb

tensor([[0.0012-0.0018j]], grad_fn=<MmBackward0>)

In [48]:
def K_conv(Ab, Bb, Cb, L):
    return torch.Tensor(
        [(Cb @ torch.matrix_power(Ab, l) @ Bb).squeeze() for l in range(L)]
    )
K2 = K_conv(Ab, Bb, Cb, L=100)


RuntimeError: value cannot be converted to type double without overflow