In [3]:
from jax import make_jaxpr, jit
import jax.numpy as jnp
from jax.lax import dynamic_update_slice

In [37]:
import time
from stable_baselines3 import PPO
import numpy as np

In [9]:
@jit
def test(obs, act, obs_buf, act_buf, buf_idx):
    obs_buf = dynamic_update_slice(obs_buf, obs, (buf_idx,))
    act_buf = dynamic_update_slice(act_buf, act, (buf_idx,))
    return obs_buf, act_buf

In [25]:
@jit
def _add_samples_cm(obs, act, obs_buf, act_buf, buf_idx):
    obs_buf = dynamic_update_slice(obs_buf, obs, (buf_idx, 0))
    act_buf = dynamic_update_slice(act_buf, act, (buf_idx, 0))
    return obs_buf, act_buf


In [11]:
print(make_jaxpr(test)(jnp.zeros(8), jnp.ones(8), jnp.ones(20), jnp.ones(20), 0))

{ lambda ; a:f32[8] b:f32[8] c:f32[20] d:f32[20] e:i32[]. let
    f:f32[20] g:f32[20] = xla_call[
      call_jaxpr={ lambda ; h:f32[8] i:f32[8] j:f32[20] k:f32[20] l:i32[]. let
          m:bool[] = lt l 0
          n:i32[] = add l 20
          o:i32[] = select m n l
          p:f32[20] = dynamic_update_slice j h o
          q:bool[] = lt l 0
          r:i32[] = add l 20
          s:i32[] = select q r l
          t:f32[20] = dynamic_update_slice k i s
        in (p, t) }
      name=test
    ] a b c d e
  in (f, g) }


In [12]:
obs_buf = jnp.zeros(100)
act_buf = jnp.zeros(100)

In [13]:
obs = jnp.ones(20)
act = jnp.ones(20)

In [14]:
test(obs, act, obs_buf, act_buf, 0)

(DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
              1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
              1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., 0.,

In [24]:
t = time.time()
buf_idx = 0
for _ in range(100_000):
    obs_buf, act_buf = test(obs, act, obs_buf, act_buf, buf_idx)
    buf_idx += 20
    buf_idx %= 100
print(f"seconds: {time.time() - t}")

seconds: 0.4274144172668457


In [105]:
from functools import partial
class CAMOC:
    def __init__(self):
        obs_size = 1*2 + (9)*2
        act_size = 2
        prealloc_size = 1009 * 200 * 20
        self._obs = jnp.empty((prealloc_size, obs_size))
        self._act = jnp.empty((prealloc_size, act_size))
        self._obs_idx = self._act_idx = 0
    @partial(jit, static_argnums=(0,))
    def add_samples(self, observations, actions):
        self._obs, self._act = _add_samples_cm(observations, actions, self._obs, self._act, self._obs_idx)
        #self._obs[self._obs_idx: self._obs_idx + observations.shape[0], :] = observations
        #self._act[self._act_idx: self._act_idx + actions.shape[0], :] = actions
        self._obs_idx += observations.shape[0]
        self._act_idx += actions.shape[0]

In [106]:
class Container:
    def __init__(self):
        self.cagent = CAMOC()

In [107]:
cagent = Container()

In [33]:
model = PPO.load("./policies/rotator_coverage_v0_2022_01_26_23_36")

In [34]:
from envs.rotator_coverage import rotator_coverage_v0
env = rotator_coverage_v0.env_eval()

[0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 1.  ]


  logger.warn(


In [108]:
s = time.time()
for tidx in range(50):
    if tidx % 10 == 0:
        print("Sampling trajectory {}".format(tidx))
    env.reset()
    for i, agent in enumerate(env.agent_iter()):
        obs, reward, done, info = env.last()
        act = model.predict(obs, deterministic=True)[0] if not done else None
        env.step(act)
        #print(f"Step: {i}")
        if not done:
            #pass
            #o = np.array([obs[-20:]])
            #a = np.array([act])
            #print(f"o: {o.shape} a: {a.shape}")
            #cagent.cagent.add_samples(jnp.asarray(np.array([obs[-20:]])), jnp.asarray(np.array([act])))
            #x = np.array([obs[-20:]])
            #y = np.array([act])
            cagent.cagent.add_samples(np.array([obs[-20:]]), np.array([act]))
print(f"time: {time.time() - s}")

Sampling trajectory 0
Sampling trajectory 10
Sampling trajectory 20
Sampling trajectory 30
Sampling trajectory 40
time: 25.102649211883545


In [51]:
import traceback
# 25.577797889709473 if not: pass
# 25.835766553878784 if not: make arrays
# 25.788559436798096 if not: call func (which passes)
# 24.67454433441162 if not: execute in numpy - yea wtf why is it faster
# 25.102649211883545 if not: excute in jax

In [54]:
env.reset()
for i, agent in enumerate(env.agent_iter()):
    obs, reward, done, info = env.last()
    act = model.predict(obs, deterministic=True)[0] if not done else None
    env.step(act)
    print(f"Step: {i}")
    if not done:
        o = np.array([obs[-20:]])
        a = np.array([act])
        print(f"o: {o.shape} a: {a.shape}")
        cagent.cagent.add_samples(jnp.asarray(np.array([obs[-20:]])), jnp.asarray(np.array([act])))
        break

Step: 0
o: (1, 20) a: (1, 2)
