In [1]:
import jax
from functools import partial
import jax.numpy as np
import jax.scipy as sp

In [2]:
# Root PRNG key (fixed seed 1) from which the notebook's random draws are derived.
rng_key = jax.random.PRNGKey(1)

In [30]:
# Sanity check: confirm the PRNG key is a jax.Array, as the `rng` annotation of random_SSM expects.
isinstance(rng_key, jax.Array)

True

In [13]:
def random_SSM(rng : jax.Array, N : int) -> (jax.Array, jax.Array, jax.Array, jax.Array):
    """Draw four independent random (N, N) matrices for a state-space model.

    Args:
        rng: JAX PRNG key; it is split into one subkey per matrix.
        N: Side length of each square matrix.

    Returns:
        The tuple ``(A, B, C, D)``, each an ``(N, N)`` array sampled with
        ``jax.random.uniform`` (default range).
    """
    a_r, b_r, c_r, d_r = jax.random.split(rng, 4)
    # Each matrix must use its own subkey: reusing one key would make
    # A, B, C and D identical, since JAX random draws are pure functions
    # of the key.
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, N))
    C = jax.random.uniform(c_r, (N, N))
    D = jax.random.uniform(d_r, (N, N))
    return A, B, C, D

In [14]:
# Sample the four 10x10 matrices of a random SSM from the root key.
A, B, C, D = random_SSM(rng_key, 10)

In [25]:
# Inspect how a Python float is wrapped by jax.numpy (here: a weakly-typed float32 scalar).
np.array(0.01)

Array(0.01, dtype=float32, weak_type=True)

In [33]:
def discretize(
    A : jax.Array, B : jax.Array, C : jax.Array, D : jax.Array, delta : jax.Array
) -> (jax.Array, jax.Array, jax.Array, jax.Array):
    """Discretize the SSM matrices with step ``delta`` via the bilinear transform.

    Computes::

        Abar = (I - delta/2 * A)^-1 @ (I + delta/2 * A)
        Bbar = (I - delta/2 * A)^-1 @ (delta * B)

    Args:
        A: Square state matrix; its first dimension sets the identity size.
        B: Input matrix, scaled by ``delta`` and left-multiplied by the inverse factor.
        C: Output matrix, returned unchanged.
        D: Feed-through matrix, returned unchanged.
        delta: Discretization step size.

    Returns:
        The tuple ``(Abar, Bbar, C, D)``.
    """
    identity = np.eye(A.shape[0])
    half_step_A = delta/2 * A
    # Shared factor (I - delta/2 * A)^-1 used by both Abar and Bbar.
    inverse_factor = np.linalg.inv(identity - half_step_A)
    return (
        inverse_factor@(identity + half_step_A),
        inverse_factor@(delta*B),
        C,
        D,
    )

In [None]:
# Step size for the discretization (matches the `delta` parameter of discretize).
delta = np.array(0.01)

In [36]:
# Discretize the sampled SSM using the step size `delta` defined in the
# preceding cell, rather than repeating the literal 0.01 here.
Abar, Bbar, Cbar, Dbar = discretize(A, B, C, D, delta)