In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from deluca.agents import BPC, LQR, BCOMC
from deluca.envs import LDS
import jax.numpy as jnp
import jax
import numpy as np

from jax.config import config
config.update("jax_debug_nans", True)
import matplotlib.pyplot as plt

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [3]:
def get_err(T, lds, controller, noise, random_key=jax.random.PRNGKey(42)):
    lds.reset()
    avg_err = 0
    err = 0
    err_list = []
    prev_noise = jnp.zeros(shape=lds.state.shape)
    key = random_key
    last_avg_err = 0
    for i in range(T):

        key, subkey = jax.random.split(key)
        try:
            action = controller(lds.obs, err)
        except:
            action = controller(lds.obs)
        lds.step(action)
        if noise == "Gaussian":
            lds.state += 0.03 * jax.random.normal(subkey, shape=lds.state.shape)
        elif noise == "Sinusoidal":
            lds.state += 0.03 * jnp.sin(i/(20 * np.pi))
        elif noise == "GaussianWalk":
            prev_noise = prev_noise + jax.random.normal(subkey, shape=lds.state.shape)
            lds.state += 0.03 * prev_noise / np.sqrt(T)
        elif noise == "None":
            pass
        else:
            raise ValueError("Noise type unrecognized!")
        err = (jnp.linalg.norm(lds.state)**2+jnp.linalg.norm(action)**2)
        last_avg_err += err
        if (i+1) % 1000 == 0:
            print(str(i+1) + " avg err:", last_avg_err/100)
            last_avg_err = 0
        avg_err += err/T
        err_list += [err]
    return avg_err, np.array(mavg_err)

In [4]:
T = 30000

system = "DI" # one of "DI", "LargeSparse"
noise = "Gaussian" # one of "Gaussian", "GaussianWalk", "SinusoidalMult", "None"

if system == "DI":
    A = jnp.array([[.9, .9], [-0.01, .9]])
    B = jnp.array([[0], [1]])
    print(j)
elif system == "LargeSparse":
    A = jnp.array([[.3, 0, 0, .4, .1], [0, .5, .5, .5, 0], [.05, .05, .05, .05, 0], [.3, 0, 0, 0, 0], [4, 0, 0, 0, .1]])
    B = jnp.array([[2, 0, 0], [0, .3, .1], [0, .1, .3], [0, 0, 0], [0, 0, 0]])
else:
    raise ValueError("System type unrecognized!")
print("System norm:", )
#A,B = jnp.array([[.8,.5], [0,.8]]), jnp.array([[0],[0.8]])


NameError: name 'j' is not defined

In [11]:
bpc = BPC(A, B, lr_scale=1e-3, delta=1/T)
bpc_avg_err, bpc_errs = get_err(T, LDS(state_size= B.shape[0], action_size=B.shape[1], A=A, B=B), bpc, noise)
print("BPC incurs ", bpc_avg_err, " loss")

100 avg err: 0.771041270066844
200 avg err: 0.19414951006885012
300 avg err: 0.12073958253389992
400 avg err: 0.13725981515007962
500 avg err: 0.1422750765421913
600 avg err: 0.06425218170610322
700 avg err: 0.2560788158713332
800 avg err: 0.1375199891661202
900 avg err: 0.10497498466008158
1000 avg err: 0.08665424675935621
1100 avg err: 0.31052416522201354
1200 avg err: 0.19337453960780898
1300 avg err: 0.21027251134945954
1400 avg err: 0.1546169000922108
1500 avg err: 0.17928499528917247
1600 avg err: 0.16648399606095546
1700 avg err: 0.03143529289477646
1800 avg err: 0.03191903249288464
1900 avg err: 0.21983672885945826
2000 avg err: 0.13803569947885408
2100 avg err: 0.10747226932825331
2200 avg err: 0.09379352343373826
2300 avg err: 0.2774129280962635
2400 avg err: 0.06153195548376792
2500 avg err: 0.10587976498090428
2600 avg err: 0.06116887320920914
2700 avg err: 0.19431309781936407
2800 avg err: 0.06694583134398441
2900 avg err: 0.10226941462341418
3000 avg err: 0.07981022697097

KeyboardInterrupt: 

In [None]:
lqr = LQR(A, B)
lqr_avg_err, lqr_errs = get_err(T, LDS(state_size= B.shape[0], action_size=B.shape[1], A=A, B=B), lqr, noise)
print("LQR incurs ", lqr_avg_err, " loss")

In [None]:
bcomc = BCOMC(
    A = A,
    B = B,
    C = jnp.identity(A.shape[0]),
    T = T,
    H = 5,   # To change, but good for debug
    cost_bound = 1e5,
    R = 3,
    beta = 5.0,
    sigma = 0.2,
    eta_mul = 1e5,
    grad_mul = 1e-4, # Typically wants to decrease as eta_mul increases to offset its effect
)
bcomc_avg_err, bcomc_errs = get_err(T, LDS(state_size=B.shape[0], action_size=B.shape[1], A=A, B=B), bcomc, noise)
print("BCOMC incurs ", bcomc_avg_err, " loss")

eta: 0.04142526122144524
Step 100:
g: [[[ 0.09123193 -0.07614101]]

 [[-0.15030236  0.12236363]]

 [[ 0.12308322 -0.10191705]]

 [[ 0.08834518  0.32783431]]

 [[ 0.17312289 -0.12263036]]]
M: [[[ 0.12239645 -0.03141811]]

 [[-0.08970305  0.20401401]]

 [[ 0.39711544  0.31057542]]

 [[ 0.24793973 -0.1402913 ]]

 [[-0.46198872 -0.1698254 ]]]
~M: [[[ 1.46734599e-02 -3.22580796e-01]]

 [[-4.24792674e-02  6.22590352e-01]]

 [[ 9.01625607e-01  2.10592675e-01]]

 [[ 4.53652940e-04 -4.55881751e-01]]

 [[-8.52772710e-01 -4.27530348e-01]]]
y_nat: [[-0.1969698  -0.02743688]
 [-0.20458575  0.00233902]
 [-0.2048792  -0.00582721]
 [-0.13398146 -0.02204664]
 [-0.18888266 -0.11731445]]
100 avg err: 58.62587429863519
Step 200:
g: [[[ 0.00076924 -0.00033191]]

 [[-0.00108821  0.00079924]]

 [[ 0.00037185  0.0011395 ]]

 [[ 0.00204654 -0.0006921 ]]

 [[-0.00022659 -0.00067406]]]
M: [[[ 0.13048073  0.06519609]]

 [[-0.11485747  0.09047858]]

 [[ 0.3033751   0.32577187]]

 [[ 0.311596   -0.34332298]]

 [[-0

KeyboardInterrupt: 