In [1]:
import torch
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math

from collections.abc import Callable

import warnings

In [56]:
class S4Layer(torch.nn.Module):
    """
    Efficient layer for S4Ms. (Structured State Space Sequence Models).
    Implements initialization of A as a NPLR matrix, enabling fast 
    matrix vector multiplication. 

    Several parameters, such as the projection matrix, are learned.

    In this case, the C matrix is actually learned as C(1-A^L). 
    This is fairly easy to undo, and is done in the calc of Cbar.

    Parameters:
        N_input : dimension of input,
        latent_dim : int, dimensions of latent space,
        
    """
    def __init__(
        self,
        N_input : int,
        latent_dim : int,
        dt_min  : torch.Tensor = torch.tensor(0.001),
        dt_max  : torch.Tensor = torch.tensor(0.1),
        step_grad : bool = True
    ):
        super().__init__()
        assert N_input==1


        self.latent_dim = latent_dim
        self.Lambda, self.P, self.B, _ = self.make_DPLR_HiPPO(self.latent_dim)
        
        self.Lambda = torch.autograd.Variable(self.Lambda, requires_grad = True)
        self.P = torch.autograd.Variable(self.P, requires_grad = True)
        self.B = torch.autograd.Variable(self.B, requires_grad = True)
        
        self.log_dt = self.log_step_initializer(dt_min, dt_max, step_grad)
        
        Ctilde = torch.nn.init.normal_(torch.empty(self.latent_dim, 2), mean=0, std=0.5**0.5)
        self.Ctilde = torch.autograd.Variable(Ctilde[:,0] + Ctilde[:,1]*1j, requires_grad=True)

        self.D = torch.autograd.Variable(torch.tensor(1.), requires_grad = True)

    def make_HiPPO(self, N : int) -> torch.Tensor:
        """
        creates HiPPO matrix for legendre polynomials up to order N
        parameters:
            N: int
        """
        P = torch.sqrt(1+2*torch.arange(N))
        A = P.unsqueeze(1) * P.unsqueeze(0)
        A = torch.tril(A) - torch.diag(torch.arange(N))
        return -A
        
    def make_NPLR_HiPPO(self, N : int) -> torch.Tensor:
        """
        creating hippo matrix and associated low rank additive component, P
        and the B matrix associated, as hippo forces it
    
        parameters:
            N : int, degree of legendre polynomial coefficient
        """
        nhippo = self.make_HiPPO(N)
    
        P = torch.sqrt(torch.arange(N)+0.5).to(torch.complex64)
        B = torch.sqrt(2*torch.arange(N)+1.0).to(torch.complex64)
    
        return nhippo.to(torch.complex64), P, B

    def make_DPLR_HiPPO(self, N : int) -> torch.Tensor:
        """
        convert matrices to DPLR representation
        parameters:
            N : int, degree of legendre polynomials
        """
        A, P, B = self.make_NPLR_HiPPO(N)
    
        S = A + torch.outer(P, P)
    
        S_diag = torch.diagonal(S)
        Lambda_real = torch.mean(S_diag) * torch.ones_like(S_diag)
    
        Lambda_imag, V = torch.linalg.eigh(S * -1j)
        P = V.T.conj() @ P
        B = V.T.conj() @ B
        return Lambda_real + 1j * Lambda_imag, P, B, V

    
    def log_step_initializer(self, dt_min = torch.tensor(0.001), dt_max = torch.tensor(0.1), requires_grad = True):
        """
        initial guess for dt, from random number generator. to be learned.
    
        parameters:
            dt_min
            dt_max
        """
        return torch.autograd.Variable(torch.rand(1) * (torch.log(dt_max) - torch.log(dt_min)) + torch.log(dt_min), requires_grad = requires_grad)

    
    def K_gen_DPLR(
        self,
        Lambda : torch.Tensor, 
        P : torch.Tensor, 
        Q : torch.Tensor, 
        B: torch.Tensor, 
        C : torch.Tensor, 
        delta : torch.Tensor, 
        L : int
    )-> torch.Tensor:
        """
        computes convolution kernel from generating function using DPLR representation and
        the cauchy kernel
    
        Parameters:
            Lambda : diagonal part of DPLR
            P : N matrix, rank 1 representation to A
            Q : N matrix, rank 1 representation to A
            C : N matrix, projection from latent to input
            B : N matrix, projection from input to latent
        """
        Omega_L = torch.exp(-2j*torch.pi * (torch.arange(L))/L)
    
        aterm = (torch.conj(C), torch.conj(Q))
        bterm = (B, P)
    
        g = (2.0/delta) * ((1.0-Omega_L)/(1.0+Omega_L))
        c = 2.0 / (1.0+Omega_L)
    
        k00 = self.cauchy((aterm[0] * bterm[0]).unsqueeze(0), g.unsqueeze(1), Lambda)
        k01 = self.cauchy((aterm[0] * bterm[1]).unsqueeze(0), g.unsqueeze(1), Lambda)
        k10 = self.cauchy((aterm[1] * bterm[0]).unsqueeze(0), g.unsqueeze(1), Lambda)
        k11 = self.cauchy((aterm[1] * bterm[1]).unsqueeze(0), g.unsqueeze(1), Lambda)
    
        atRoots = c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)
        out = torch.fft.ifft(atRoots, L)
        return out

    
    def cauchy(self, k : torch.Tensor, omega : torch.Tensor, lambd : torch.Tensor):
        """
        computes cauchy kernel 
        sum(c_i * b_i/(z - lambda_i)

        Parameters:
            k : term by term dot product of vectors
            omega : function of the roots of unity
            lambd: diagonal parts of the DPLR matrix
        """
        return torch.sum(k/(omega-lambd), axis=1)

    def discrete_DPLR(
        self,
        Lambda : torch.Tensor,
        P : torch.Tensor,
        Q : torch.Tensor,
        B : torch.Tensor,
        C : torch.Tensor,
        delta : torch.Tensor,
        L : int
    )->(torch.Tensor, torch.Tensor, torch.Tensor):
        """
        computes the discretized version of the state space model,
        assuming the DPLR form
    
        Parameters:
            Lambda : Nx1, represents the diagonal values of the A matrix
            P : Nx1, represents part of the low rank aspect of the A matrix
            Q : Nx1, represents the other part of the low rank aspect of the A matrix
            B : N, projection from input to latent
            C : N, projection from latent to input
            delta : step size
            L : length of window
        """
        Bt = B.unsqueeze(1)
        Ct = C.unsqueeze(0)
    
        A = (torch.diag(Lambda) - torch.outer(P, torch.conj(Q)))
        A0 = 2.0/delta * torch.eye(A.shape[0]) + A
        
        Qdagger = Q.conj().unsqueeze(0)
        P_1 = P.unsqueeze(1)
        
        D = torch.diag(1.0/(2.0/delta - Lambda))
        A1 = (D -  (1.0/(1.0 + Qdagger @ D @ P_1)) * D@P_1@Qdagger@D)
        Ab = A1@A0
        Bb = 2 * A1@B.unsqueeze(1)
        Cb = Ct @ torch.conj(torch.linalg.inv(torch.eye(A.shape[0]) - torch.matrix_power(Ab, L)))
        return Ab, Bb, Cb.conj()

    def scan_SSM(
        self,
        Ab : torch.Tensor,
        Bb : torch.Tensor,
        Cb : torch.Tensor,
        u  : torch.Tensor,
        x0 : torch.Tensor,
    ):
        x = torch.zeros((u.shape[0], Ab.shape[0]))
        y = torch.zeros_like(u)

        for index in range(u.shape[0]):
            x[index, :] = 

    def forward(
        self,
        u : torch.Tensor,
        mode : str,
        x0 : torch.Tensor = torch.Tensor(0),
    ):
        dt = torch.exp(self.log_dt)
        L = u.shape[0]
        if mode == "recurrent":
            Ab, Bb, Cb = self.discrete_DPLR(self.Lambda, self.P, self.P, self.B, self.C, dt, L)
            return 
        pass

In [57]:
test = S4Layer(1,10)

In [55]:
def cauchy(k : torch.Tensor, omega : torch.Tensor, lambd : torch.Tensor):
    """
    computes cauchy kernel 
    sum(c_i * b_i/(z - lambda_i)

    Parameters:
        k : term by term dot product of vectors
        omega : function of the roots of unity
        lambd: diagonal parts of the DPLR matrix
    """
    return torch.sum(k/(omega-lambd), axis=1)

def K_gen_DPLR(
    Lambda : torch.Tensor, 
    P : torch.Tensor, 
    Q : torch.Tensor, 
    B: torch.Tensor, 
    C : torch.Tensor, 
    delta : torch.Tensor, 
    L : int
)-> torch.Tensor:
    """
    computes convolution kernel from generating function using DPLR representation and
    the cauchy kernel

    Parameters:
        Lambda : diagonal part of DPLR
        P : N matrix, rank 1 representation to A
        Q : N matrix, rank 1 representation to A
        C : N matrix, projection from latent to input
        B : N matrix, projection from input to latent
    """
    Omega_L = torch.exp(-2j*torch.pi * (torch.arange(L))/L)

    aterm = (torch.conj(C), torch.conj(Q))
    bterm = (B, P)

    g = (2.0/delta) * ((1.0-Omega_L)/(1.0+Omega_L))
    c = 2.0 / (1.0+Omega_L)

    k00 = cauchy((aterm[0] * bterm[0]).unsqueeze(0), g.unsqueeze(1), Lambda)
    k01 = cauchy((aterm[0] * bterm[1]).unsqueeze(0), g.unsqueeze(1), Lambda)
    k10 = cauchy((aterm[1] * bterm[0]).unsqueeze(0), g.unsqueeze(1), Lambda)
    k11 = cauchy((aterm[1] * bterm[1]).unsqueeze(0), g.unsqueeze(1), Lambda)

    atRoots = c * (k00 - k01 * (1.0 / (1.0 + k11)) * k10)
    out = torch.fft.ifft(atRoots, L)[:L]
    return out.real

In [49]:
Lambda = torch.randn(10)
P = torch.randn(10)

A = torch.diag(Lambda) - torch.outer(P, P)

B = torch.randn(10)
C = torch.randn(10)

In [50]:
test_1 = K_gen_DPLR(Lambda, P, P, B, C, torch.tensor(.01), 100)

In [51]:
Ab, Bb, Cb = discrete_DPLR(Lambda, P, P, B, C, torch.tensor(.01), 100)

torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([1, 10])
torch.Size([10, 10])


In [52]:
def K_conv(Ab, Bb, Cb, L):
        return torch.stack([(Cb @ torch.matrix_power(Ab, l) @ Bb).squeeze() for l in range(L)]
    )

test_2 = K_conv(Ab, Bb, Cb, L=100)


In [53]:
test_2

tensor([-0.1590, -0.1603, -0.1614, -0.1625, -0.1636, -0.1646, -0.1655, -0.1664,
        -0.1672, -0.1680, -0.1687, -0.1694, -0.1701, -0.1707, -0.1712, -0.1718,
        -0.1722, -0.1727, -0.1731, -0.1735, -0.1739, -0.1742, -0.1745, -0.1748,
        -0.1750, -0.1753, -0.1755, -0.1756, -0.1758, -0.1759, -0.1761, -0.1761,
        -0.1762, -0.1763, -0.1763, -0.1764, -0.1764, -0.1764, -0.1763, -0.1763,
        -0.1763, -0.1762, -0.1761, -0.1760, -0.1759, -0.1758, -0.1757, -0.1756,
        -0.1754, -0.1753, -0.1751, -0.1750, -0.1748, -0.1746, -0.1744, -0.1742,
        -0.1740, -0.1738, -0.1736, -0.1733, -0.1731, -0.1729, -0.1726, -0.1724,
        -0.1721, -0.1718, -0.1716, -0.1713, -0.1710, -0.1707, -0.1704, -0.1701,
        -0.1698, -0.1695, -0.1692, -0.1689, -0.1686, -0.1683, -0.1679, -0.1676,
        -0.1673, -0.1669, -0.1666, -0.1662, -0.1659, -0.1655, -0.1652, -0.1648,
        -0.1645, -0.1641, -0.1637, -0.1634, -0.1630, -0.1626, -0.1622, -0.1618,
        -0.1614, -0.1610, -0.1606, -0.16

In [54]:
test_1

tensor([-0.1591, -0.1603, -0.1615, -0.1626, -0.1637, -0.1646, -0.1656, -0.1665,
        -0.1673, -0.1681, -0.1688, -0.1695, -0.1701, -0.1707, -0.1713, -0.1718,
        -0.1723, -0.1728, -0.1732, -0.1736, -0.1740, -0.1743, -0.1746, -0.1749,
        -0.1751, -0.1753, -0.1755, -0.1757, -0.1759, -0.1760, -0.1761, -0.1762,
        -0.1763, -0.1764, -0.1764, -0.1764, -0.1764, -0.1764, -0.1764, -0.1764,
        -0.1763, -0.1763, -0.1762, -0.1761, -0.1760, -0.1759, -0.1758, -0.1757,
        -0.1755, -0.1754, -0.1752, -0.1751, -0.1749, -0.1747, -0.1745, -0.1743,
        -0.1741, -0.1739, -0.1737, -0.1734, -0.1732, -0.1729, -0.1727, -0.1724,
        -0.1722, -0.1719, -0.1716, -0.1714, -0.1711, -0.1708, -0.1705, -0.1702,
        -0.1699, -0.1696, -0.1693, -0.1690, -0.1687, -0.1683, -0.1680, -0.1677,
        -0.1674, -0.1670, -0.1667, -0.1663, -0.1660, -0.1656, -0.1653, -0.1649,
        -0.1645, -0.1642, -0.1638, -0.1634, -0.1631, -0.1627, -0.1623, -0.1619,
        -0.1615, -0.1611, -0.1607, -0.16

In [35]:
torch.conj(test.Ctilde) * test.B.unsqueeze(0)

tensor([[ 5.7949+1.1499j,  0.3463+1.0878j,  1.5644+0.3168j,  1.9604-0.4458j,
          0.4461-0.8018j,  0.5388-0.9901j, -1.0271-0.5473j, -0.9828+0.1719j,
          1.5598-3.8654j,  0.9211+0.7982j]], grad_fn=<MulBackward0>)

In [36]:
(torch.conj(test.Ctilde) * test.B).unsqueeze(0)

tensor([[ 5.7949+1.1499j,  0.3463+1.0878j,  1.5644+0.3168j,  1.9604-0.4458j,
          0.4461-0.8018j,  0.5388-0.9901j, -1.0271-0.5473j, -0.9828+0.1719j,
          1.5598-3.8654j,  0.9211+0.7982j]], grad_fn=<UnsqueezeBackward0>)

In [37]:
K_gen_DPLR(test.Lambda, test.P, test.P, test.B, test.Ctilde, test.dt, 100)

tensor([ 0.0705, -0.0054,  0.0152,  0.0349, -0.0133,  0.0218,  0.0081, -0.0070,
         0.0184, -0.0053,  0.0042,  0.0091, -0.0041,  0.0104,  0.0024,  0.0047,
         0.0097,  0.0039,  0.0119,  0.0075,  0.0114,  0.0132,  0.0100,  0.0181,
         0.0116,  0.0173,  0.0189,  0.0131,  0.0240,  0.0157,  0.0196,  0.0249,
         0.0143,  0.0269,  0.0199,  0.0185,  0.0289,  0.0149,  0.0258,  0.0236,
         0.0153,  0.0297,  0.0155,  0.0214,  0.0258,  0.0117,  0.0269,  0.0163,
         0.0149,  0.0256,  0.0086,  0.0211,  0.0169,  0.0081,  0.0229,  0.0066,
         0.0135,  0.0166,  0.0021,  0.0179,  0.0056,  0.0053,  0.0150, -0.0023,
         0.0112,  0.0052, -0.0021,  0.0120, -0.0048,  0.0037,  0.0050, -0.0080,
         0.0077, -0.0056, -0.0037,  0.0043, -0.0118,  0.0024, -0.0051, -0.0099,
         0.0029, -0.0134, -0.0032, -0.0039, -0.0145,  0.0007, -0.0131, -0.0086,
        -0.0027, -0.0170, -0.0022, -0.0112, -0.0130, -0.0019, -0.0174, -0.0055,
        -0.0084, -0.0161, -0.0018, -0.01

In [38]:
Cb@(torch.matrix_power(Ab, 2))@Bb

tensor([[0.0379-0.0367j]], grad_fn=<MmBackward0>)

In [58]:
bloop = test.K_gen_DPLR(test.Lambda, test.P, test.P, test.B, test.Ctilde, test.dt, 100)

In [59]:
bloop

tensor([-0.0066-1.7582e-02j, -0.0056-1.8130e-02j, -0.0047-1.8441e-02j,
        -0.0039-1.8544e-02j, -0.0032-1.8465e-02j, -0.0027-1.8229e-02j,
        -0.0022-1.7859e-02j, -0.0018-1.7374e-02j, -0.0015-1.6793e-02j,
        -0.0013-1.6132e-02j, -0.0011-1.5407e-02j, -0.0010-1.4632e-02j,
        -0.0009-1.3819e-02j, -0.0009-1.2978e-02j, -0.0009-1.2121e-02j,
        -0.0010-1.1255e-02j, -0.0010-1.0388e-02j, -0.0011-9.5279e-03j,
        -0.0012-8.6801e-03j, -0.0013-7.8499e-03j, -0.0015-7.0419e-03j,
        -0.0016-6.2601e-03j, -0.0018-5.5077e-03j, -0.0020-4.7875e-03j,
        -0.0021-4.1017e-03j, -0.0023-3.4522e-03j, -0.0025-2.8403e-03j,
        -0.0027-2.2670e-03j, -0.0028-1.7330e-03j, -0.0030-1.2388e-03j,
        -0.0032-7.8424e-04j, -0.0034-3.6936e-04j, -0.0035+6.2415e-06j,
        -0.0037+3.4313e-04j, -0.0038+6.4201e-04j, -0.0040+9.0376e-04j,
        -0.0042+1.1294e-03j, -0.0043+1.3199e-03j, -0.0044+1.4767e-03j,
        -0.0046+1.6008e-03j, -0.0047+1.6938e-03j, -0.0048+1.7568e-03j,
      

In [60]:
Ab, Bb, Cb = test.discrete_DPLR(test.Lambda, test.P, test.P, test.B, test.Ctilde, test.dt, 100)

In [15]:
bloop[-1]

tensor(-0.0312, grad_fn=<SelectBackward0>)

In [65]:
test.B.shape

torch.Size([10])

In [43]:
def discrete_DPLR(Lambda, P, Q, B, C, step, L):
    # Convert parameters to matrices
    B = B.unsqueeze(1)
    Ct = C.unsqueeze(0)

    N = Lambda.shape[0]
    A = torch.diag(Lambda) - torch.outer(P, Q.conj())
    I = torch.eye(N)

    # Forward Euler
    A0 = (2.0 / step) * I + A
    print(A0.shape)
    # Backward Euler
    D = torch.diag(1.0 / ((2.0 / step) - Lambda))
    print(D.shape)
    Qc = torch.conj(Q.unsqueeze(0))
    print(Qc.shape)
    P2 = P.reshape(-1, 1)
    A1 = D - (D @ P2 * (1.0 / (1 + (Qc @ D @ P2))) * Qc @ D)
    print(A1.shape)
    # A bar and B bar
    Ab = A1 @ A0
    Bb = 2 * A1 @ B

    # Recover Cbar from Ct
    Cb = Ct @ torch.linalg.inv(I - torch.matrix_power(Ab, L)).conj()
    return Ab, Bb, Cb.conj()

In [16]:
Ab, Bb, Cb = discrete_DPLR(test.Lambda, test.P, test.P, test.B, test.Ctilde, test.dt, 10)

torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([1, 10])
torch.Size([10, 10])


In [17]:
Ab

tensor([[ 6.7959e-01-0.3620j, -7.7162e-02+0.0219j,  5.0469e-02-0.0087j,
         -3.7691e-02-0.0004j,  2.1994e-02+0.0162j,  1.1074e-02-0.0250j,
          3.2512e-02-0.0191j,  4.8135e-02-0.0175j,  7.7887e-02-0.0192j,
          2.2241e-01-0.0071j],
        [-7.5232e-02+0.0279j,  9.5305e-01-0.1394j,  1.8449e-02-0.0007j,
         -1.3449e-02-0.0020j,  7.0893e-03+0.0068j,  5.1457e-03-0.0084j,
          1.2527e-02-0.0053j,  1.8034e-02-0.0040j,  2.8749e-02-0.0032j,
          7.9818e-02+0.0081j],
        [ 4.7530e-02-0.0191j,  1.8192e-02-0.0032j,  9.7820e-01-0.0639j,
          8.6155e-03+0.0010j, -4.6412e-03-0.0042j, -3.1396e-03+0.0054j,
         -7.9033e-03+0.0036j, -1.1440e-02+0.0028j, -1.8291e-02+0.0025j,
         -5.1073e-02-0.0038j],
        [-3.2603e-02+0.0189j, -1.2917e-02+0.0042j,  8.4858e-03-0.0018j,
          9.8532e-01-0.0268j,  3.8287e-03+0.0026j,  1.7110e-03-0.0043j,
          5.3783e-03-0.0034j,  8.0331e-03-0.0033j,  1.3058e-02-0.0038j,
          3.7600e-02-0.0027j],
        [ 1.

In [64]:
bloop[1]

tensor(-0.0056-0.0181j, grad_fn=<SelectBackward0>)

In [63]:
Cb@torch.matrix_power(Ab, 1)@Bb

tensor([[-0.0056-0.0181j]], grad_fn=<MmBackward0>)

In [90]:
def K_conv(Ab, Bb, Cb, L):
        return torch.stack([(Cb @ torch.matrix_power(Ab, l) @ Bb).squeeze() for l in range(L)]
    )

K2 = K_conv(Ab, Bb, Cb, L=100)


In [91]:
K2

tensor([-0.1206+5.0381e-02j, -0.0295+4.9775e-03j,  0.0140-1.2264e-02j,
         0.0300-1.4495e-02j,  0.0314-9.9820e-03j,  0.0259-3.6309e-03j,
         0.0184+1.8998e-03j,  0.0111+5.4070e-03j,  0.0052+6.5853e-03j,
         0.0010+5.6423e-03j, -0.0016+3.0413e-03j, -0.0032-6.6516e-04j,
        -0.0039-4.9371e-03j, -0.0044-9.3039e-03j, -0.0047-1.3391e-02j,
        -0.0053-1.6927e-02j, -0.0061-1.9739e-02j, -0.0073-2.1739e-02j,
        -0.0088-2.2910e-02j, -0.0106-2.3289e-02j, -0.0127-2.2955e-02j,
        -0.0148-2.2010e-02j, -0.0169-2.0571e-02j, -0.0188-1.8759e-02j,
        -0.0206-1.6694e-02j, -0.0220-1.4486e-02j, -0.0231-1.2234e-02j,
        -0.0238-1.0022e-02j, -0.0241-7.9199e-03j, -0.0239-5.9814e-03j,
        -0.0233-4.2460e-03j, -0.0223-2.7393e-03j, -0.0209-1.4747e-03j,
        -0.0191-4.5492e-04j, -0.0171+3.2614e-04j, -0.0148+8.8182e-04j,
        -0.0123+1.2310e-03j, -0.0097+1.3965e-03j, -0.0069+1.4040e-03j,
        -0.0042+1.2803e-03j, -0.0014+1.0528e-03j,  0.0014+7.4817e-04j,
      