In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to library source take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from deluca.agents import BPC, LQR, BCOMC
from deluca.envs import LDS
import jax.numpy as jnp
import jax
import numpy as np

# NOTE(review): newer JAX releases appear to have removed the `jax.config`
# submodule import in favour of `jax.config.update(...)` — confirm against the
# installed JAX version before upgrading.
from jax.config import config
# Turn on JAX's NaN debugging flag so NaNs surface as errors instead of
# silently propagating through the rollouts.
config.update("jax_debug_nans", True)
import matplotlib.pyplot as plt

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [3]:
def get_err(T, lds, controller, noise, random_key=jax.random.PRNGKey(42), log_every=100):
    """Roll out ``controller`` on ``lds`` for ``T`` steps and measure quadratic cost.

    At every step the controller is queried, the system is stepped, the chosen
    disturbance is added directly to ``lds.state``, and the instantaneous cost
    ``||state||^2 + ||action||^2`` is recorded.

    Args:
        T: Number of simulation steps.
        lds: Environment exposing ``reset()``, ``step(action)``, ``obs`` and a
            writable ``state`` attribute.
        controller: Callable invoked as ``controller(obs, err)``; if that call
            raises ``TypeError`` (e.g. the controller takes no cost argument)
            it is invoked as ``controller(obs)`` instead.
        noise: One of ``"Gaussian"``, ``"Sinusoidal"``, ``"GaussianWalk"`` or
            ``"None"``.
        random_key: JAX PRNG key used to draw the Gaussian disturbances.
        log_every: Print the mean cost of the last ``log_every`` steps each
            time that many steps have elapsed. ``None`` or ``0`` disables
            logging.

    Returns:
        A tuple ``(avg_err, mavg_err)`` where ``avg_err`` is the sum of
        ``err / T`` over all steps and ``mavg_err`` is a NumPy array with one
        entry ``err / (i + 1)`` per step.

    Raises:
        ValueError: If ``noise`` is not a recognised noise type. Raised before
            the environment is reset or stepped.
    """
    # Fail fast: reject a bad noise name before touching the environment.
    if noise not in ("Gaussian", "Sinusoidal", "GaussianWalk", "None"):
        raise ValueError("Noise type unrecognized!")

    lds.reset()
    avg_err = 0
    err = 0
    mavg_err = []
    prev_noise = jnp.zeros(shape=lds.state.shape)
    key = random_key
    window_err = 0  # cost accumulated since the last progress print
    for i in range(T):
        key, subkey = jax.random.split(key)

        # Some controllers take the previous cost as feedback, others only the
        # observation. Only an arity/type mismatch triggers the fallback; any
        # other exception (including KeyboardInterrupt) propagates unchanged.
        try:
            action = controller(lds.obs, err)
        except TypeError:
            action = controller(lds.obs)
        lds.step(action)

        if noise == "Gaussian":
            lds.state += 0.03 * jax.random.normal(subkey, shape=lds.state.shape)
        elif noise == "Sinusoidal":
            lds.state += 0.03 * jnp.sin(i/(20 * np.pi))
        elif noise == "GaussianWalk":
            # Random walk: accumulate Gaussian increments, scaled by 1/sqrt(T).
            prev_noise = prev_noise + jax.random.normal(subkey, shape=lds.state.shape)
            lds.state += 0.03 * prev_noise / np.sqrt(T)
        # noise == "None": no disturbance is added.

        err = (jnp.linalg.norm(lds.state)**2+jnp.linalg.norm(action)**2)
        window_err += err
        if log_every and (i+1) % log_every == 0:
            print(str(i+1) + " avg err:", window_err/log_every)
            window_err = 0
        avg_err += err/T
        # NOTE(review): this stores the instantaneous cost divided by the step
        # index, not a cumulative running mean — confirm this is intended.
        mavg_err += [err / (i+1)]
    return avg_err, np.array(mavg_err)

In [17]:
# Experiment configuration shared by the controller cells that follow.
T = 10000  # number of simulation steps per rollout

system = "DI" # one of "DI", "LargeSparse"
noise = "Gaussian" # one of "Gaussian", "Sinusoidal", "GaussianWalk", "None"

# Validate the noise name here, as is done for `system` below, so a typo
# fails immediately instead of part-way into a rollout.
if noise not in ("Gaussian", "Sinusoidal", "GaussianWalk", "None"):
    raise ValueError("Noise type unrecognized!")

# Dynamics matrix A and input matrix B for the selected system.
if system == "DI":
    A = jnp.array([[.9, .9], [-0.01, .9]])
    B = jnp.array([[0], [1]])
elif system == "LargeSparse":
    A = jnp.array([[.3, 0, 0, .4, .1], [0, .5, .5, .5, 0], [.05, .05, .05, .05, 0], [.3, 0, 0, 0, 0], [4, 0, 0, 0, .1]])
    B = jnp.array([[2, 0, 0], [0, .3, .1], [0, .1, .3], [0, 0, 0], [0, 0, 0]])
else:
    raise ValueError("System type unrecognized!")
# Alternative 2-state system kept for reference:
#A,B = jnp.array([[.8,.5], [0,.8]]), jnp.array([[0],[0.8]])


In [5]:
# Evaluate the BPC controller on a freshly constructed system.
bpc = BPC(A, B, lr_scale=1e-3, delta=1/T)
bpc_lds = LDS(
    state_size=B.shape[0],
    action_size=B.shape[1],
    A=A,
    B=B,
)
bpc_avg_err, bpc_mavg_err = get_err(T, bpc_lds, bpc, noise)
print("BPC incurs ", bpc_avg_err, " loss")

100 avg err: 0.030828321366530704
200 avg err: 0.013754223640200083
300 avg err: 0.010965379175152197
400 avg err: 0.013279678605238529
500 avg err: 0.013094804825300187
600 avg err: 0.010551015907132264
700 avg err: 0.01211077690174486
800 avg err: 0.019386534927830212
900 avg err: 0.011542501056410935
1000 avg err: 0.01506529791432944
BPC incurs  0.015057853431986954  loss


In [None]:
# Evaluate the LQR controller on a freshly constructed system.
lqr = LQR(A, B)
lqr_lds = LDS(
    state_size=B.shape[0],
    action_size=B.shape[1],
    A=A,
    B=B,
)
lqr_avg_err, lqr_mavg_err = get_err(T, lqr_lds, lqr, noise)
print("LQR incurs ", lqr_avg_err, " loss")

In [19]:
# Hyperparameters for the BCOMC controller, gathered in one place so they can
# be inspected or tweaked before the controller is built.
bcomc_settings = dict(
    A=A,
    B=B,
    C=jnp.identity(A.shape[0]),
    T=T,
    H=5,
    cost_bound=1e8,
    R=1,
    beta=5.0,
    sigma=0.2,
    eta_mul=2e8,
    grad_mul=1e-1,  # Typically wants to increase with eta_mul to offset its effect
)
bcomc = BCOMC(**bcomc_settings)

# Evaluate the controller on a freshly constructed system.
bcomc_lds = LDS(
    state_size=B.shape[0],
    action_size=B.shape[1],
    A=A,
    B=B,
)
bcomc_avg_err, bcomc_mavg_err = get_err(T, bcomc_lds, bcomc, noise)
print("BCOMC incurs ", bcomc_avg_err, " loss")

eta: 0.00019999999999999998
g: [[[ 26.77901475 -22.70760958]]

 [[-44.95172476  35.98327219]]

 [[ 36.1068592  -31.919599  ]]

 [[ 25.93522728  97.77551218]]

 [[ 53.71376152 -36.30892413]]]
M: [[[ 0.02614009 -0.02659531]]

 [[ 0.00714975  0.07246782]]

 [[ 0.16204514  0.1547062 ]]

 [[ 0.15214833 -0.0621554 ]]

 [[-0.28529282 -0.09367642]]]
~M: [[[-0.04823428 -0.2171221 ]]

 [[ 0.03690505  0.34292375]]

 [[ 0.47894869  0.06929011]]

 [[-0.03166391 -0.26492609]]

 [[-0.50999738 -0.25378726]]]
y_nat: [[-0.1969698  -0.02743688]
 [-0.20458575  0.00233902]
 [-0.2048792  -0.00582721]
 [-0.13398146 -0.02204664]
 [-0.18888266 -0.11731445]]
100 avg err: 11.297239664148591
g: [[[ 3.43876412 -1.19852205]]

 [[-4.28127355  3.12726819]]

 [[ 2.01918204  5.05965247]]

 [[ 9.09704845 -3.87404619]]

 [[-2.17839065 -3.09290888]]]
M: [[[ 0.08244678  0.03652106]]

 [[ 0.03127728 -0.02709484]]

 [[ 0.13281522  0.1121101 ]]

 [[ 0.19581704 -0.26116443]]

 [[-0.31126749 -0.09151832]]]
~M: [[[ 0.01385163 -0