In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from deluca.agents import BPC, LQR, BCOMC
from deluca.envs import LDS
import jax.numpy as jnp
import jax
import numpy as np

from jax.config import config
config.update("jax_debug_nans", True)
import matplotlib.pyplot as plt

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [3]:
def get_err(T, lds, controller, noise, random_key=jax.random.PRNGKey(42), log_every=100):
    """Roll out `controller` on `lds` for `T` steps and measure the incurred cost.

    At every step the controller picks an action from the current observation
    (and, if it accepts a second argument, the previous step's cost), the
    environment is stepped, and a disturbance is added directly to `lds.state`.
    The per-step cost is ``||state||^2 + ||action||^2``.

    Args:
        T: number of steps to simulate.
        lds: environment exposing `reset()`, `step(action)`, `state` and `obs`.
        controller: callable invoked as `controller(obs, cost)`; if that call
            signature is rejected with a `TypeError`, `controller(obs)` is used.
        noise: disturbance type, one of "Gaussian", "Sinusoidal",
            "GaussianWalk" or "None".
        random_key: JAX PRNG key that seeds the Gaussian disturbances.
        log_every: print the mean cost of the last `log_every` steps each time
            that many steps have elapsed; a falsy value disables logging.

    Returns:
        A pair `(avg_err, mavg_err)`: the mean cost over all `T` steps, and a
        NumPy array whose i-th entry is the running mean cost over steps 0..i.

    Raises:
        ValueError: if `noise` is not one of the recognized types.
    """
    # Fail fast on a bad noise name instead of after the first environment step.
    if noise not in ("Gaussian", "Sinusoidal", "GaussianWalk", "None"):
        raise ValueError("Noise type unrecognized!")

    lds.reset()
    avg_err = 0
    err = 0
    total_err = 0    # cumulative cost, used for the running average
    window_err = 0   # cost accumulated since the last progress print
    mavg_err = []
    prev_noise = jnp.zeros(shape=lds.state.shape)
    key = random_key
    for i in range(T):
        key, subkey = jax.random.split(key)

        # Some controllers take (obs, cost), others only (obs). Only a
        # signature mismatch (TypeError) triggers the one-argument fallback;
        # any other controller failure propagates instead of being swallowed.
        try:
            action = controller(lds.obs, err)
        except TypeError:
            action = controller(lds.obs)
        lds.step(action)

        if noise == "Gaussian":
            lds.state += 0.03 * jax.random.normal(subkey, shape=lds.state.shape)
        elif noise == "Sinusoidal":
            lds.state += 0.03 * jnp.sin(i/(20 * np.pi))
        elif noise == "GaussianWalk":
            # Random walk: accumulate Gaussian increments, scaled by 1/sqrt(T).
            prev_noise = prev_noise + jax.random.normal(subkey, shape=lds.state.shape)
            lds.state += 0.03 * prev_noise / np.sqrt(T)

        err = (jnp.linalg.norm(lds.state)**2+jnp.linalg.norm(action)**2)
        total_err += err
        window_err += err
        if log_every and (i+1) % log_every == 0:
            print(str(i+1) + " avg err:", window_err/log_every)
            window_err = 0
        avg_err += err/T
        # Running mean of the cost so far (cumulative cost / steps taken).
        mavg_err += [total_err / (i+1)]
    return avg_err, np.array(mavg_err)

In [4]:
# Number of simulation steps for every controller run below.
T = 100000

system = "DI"  # one of "DI", "LargeSparse"
noise = "Gaussian"  # one of "Gaussian", "GaussianWalk", "Sinusoidal", "None"

# Reject unknown system names first, then pick the (A, B) dynamics matrices.
if system not in ("DI", "LargeSparse"):
    raise ValueError("System type unrecognized!")

if system == "LargeSparse":
    # 5-state, 3-input system.
    A = jnp.array([
        [.3, 0, 0, .4, .1],
        [0, .5, .5, .5, 0],
        [.05, .05, .05, .05, 0],
        [.3, 0, 0, 0, 0],
        [4, 0, 0, 0, .1],
    ])
    B = jnp.array([
        [2, 0, 0],
        [0, .3, .1],
        [0, .1, .3],
        [0, 0, 0],
        [0, 0, 0],
    ])
else:
    # "DI": 2-state, single-input system.
    A = jnp.array([[.9, .9], [-0.01, .9]])
    B = jnp.array([[0], [1]])
# Alternative 2-state system, left disabled:
#A,B = jnp.array([[.8,.5], [0,.8]]), jnp.array([[0],[0.8]])


In [None]:
# Evaluate the BPC agent on a freshly constructed LDS with dynamics (A, B).
bpc = BPC(A, B, lr_scale=1e-3, delta=1/T)
bpc_env = LDS(A=A, B=B, state_size=B.shape[0], action_size=B.shape[1])
bpc_avg_err, bpc_mavg_err = get_err(T, bpc_env, bpc, noise)
print("BPC incurs ", bpc_avg_err, " loss")

In [None]:
# Evaluate the LQR agent on a freshly constructed LDS with dynamics (A, B).
lqr = LQR(A, B)
lqr_env = LDS(A=A, B=B, state_size=B.shape[0], action_size=B.shape[1])
lqr_avg_err, lqr_mavg_err = get_err(T, lqr_env, lqr, noise)
print("LQR incurs ", lqr_avg_err, " loss")

In [6]:
# Hyper-parameters for the BCOMC agent, gathered in one place for easy tuning.
bcomc_kwargs = dict(
    A=A,
    B=B,
    C=jnp.identity(A.shape[0]),
    T=T,
    H=5,  # To change, but good for debug
    cost_bound=1e5,
    R=5,
    beta=5.0,
    sigma=0.2,
    eta_mul=1e5,
    grad_mul=1e-5,  # Typically wants to decrease as eta_mul increases to offset its effect
)
bcomc = BCOMC(**bcomc_kwargs)

# Evaluate the agent on a freshly constructed LDS with dynamics (A, B).
bcomc_env = LDS(A=A, B=B, state_size=B.shape[0], action_size=B.shape[1])
bcomc_avg_err, bcomc_mavg_err = get_err(T, bcomc_env, bcomc, noise)
print("BCOMC incurs ", bcomc_avg_err, " loss")

eta: 0.041425253237137964
