In [3]:
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.buffers.trajectory_buffer_v2 import TrajectoryBuffer, Transition

2025-11-22 15:33:33.473374: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [4]:
obs_dim = 3
action_dim = 2
buffer_size = 100
episode_length = 10
env_batch_size = 5

init_transition = Transition.init_dummy(observation_dim=obs_dim, action_dim=action_dim)
replay_buffer = TrajectoryBuffer.init(
    buffer_size=buffer_size,
    transition=init_transition,
    env_batch_size=env_batch_size,
    episode_length=episode_length
)
print(f"Buffer can hold {replay_buffer.num_trajectories} complete episodes")

Buffer can hold 10 complete episodes


In [3]:
# Insert episodes with alternating observations IN THE SAME BATCH
# This way zeros and ones are in different parallel environments
for i in range(2):
    # Create a batch where first half of envs have zeros, second half have ones
    obs_pattern = jnp.concatenate([
        jnp.zeros((episode_length, env_batch_size // 2 + 1, obs_dim)),  # First 3 envs: zeros
        jnp.ones((episode_length, env_batch_size // 2, obs_dim))  # Last 2 envs: ones
    ], axis=1)
    
    actions_pattern = jnp.concatenate([
        jnp.zeros((episode_length, env_batch_size // 2 + 1, action_dim)),
        jnp.ones((episode_length, env_batch_size // 2, action_dim))
    ], axis=1)
    
    rewards_pattern = jnp.concatenate([
        jnp.ones((episode_length, env_batch_size // 2 + 1)) * 0.1,
        jnp.ones((episode_length, env_batch_size // 2)) * 0.2
    ], axis=1)
    
    # Mark last step as done for all envs
    dones = jnp.zeros((episode_length, env_batch_size))
    dones = dones.at[-1, :].set(1.0)
    
    transitions = Transition(
        obs=obs_pattern,
        next_obs=obs_pattern + 0.5,
        rewards=rewards_pattern,
        dones=dones,
        truncations=jnp.zeros((episode_length, env_batch_size)),
        actions=actions_pattern
    )
    
    # Flatten to (episode_length * env_batch_size, ...)
    transitions = jax.tree_map(
        lambda x: x.reshape(-1, *x.shape[2:]) if x.ndim > 2 else x.reshape(-1),
        transitions
    )
    replay_buffer = replay_buffer.insert(transitions)

print(f"Buffer filled: {replay_buffer.current_size}/{replay_buffer.buffer_size}")
print(f"Complete episodes: {replay_buffer.current_episodic_data_size}/{replay_buffer.num_trajectories}")

Buffer filled: 100/100
Complete episodes: 10/10


  transitions = jax.tree_map(
  transitions = jax.tree_map(


In [4]:
random_key = jax.random.PRNGKey(0)

samples, random_key = replay_buffer.sample_n_step(
    random_key=random_key,
    sample_size=10,
    n_step=3,
)

  transitions = jax.tree_map(reshape_field, transitions)


In [5]:
print("Sampled observations shape:", samples.obs.shape)
print("\nFirst sampled n-step sequence (should show zeros or ones):")
print(samples.obs[0])
print("\nSecond sampled n-step sequence:")
print(samples.obs[1])
print("\nRewards from first sequence:")
print(samples.rewards[0])

Sampled observations shape: (10, 3, 3)

First sampled n-step sequence (should show zeros or ones):
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

Second sampled n-step sequence:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

Rewards from first sequence:
[0.1 0.1 0.1]


In [6]:
print(f"Current buffer size: {replay_buffer.current_size}")
print(f"Current episodic data size: {replay_buffer.current_episodic_data_size}")

Current buffer size: 100
Current episodic data size: 10


In [7]:
# Check first 20 entries in the data buffer
print("First 20 transitions in buffer (should show mix of zeros and ones):")
print("Observation column 0:", replay_buffer.data[:20, 0])

First 20 transitions in buffer (should show mix of zeros and ones):
Observation column 0: [0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1.]


In [8]:
# Check episodic data structure
print("Episodic data (indices of transitions belonging to each episode):")
print(replay_buffer.episodic_data[:10])

Episodic data (indices of transitions belonging to each episode):
[[ 0.  5. 10. 15. 20. 25. 30. 35. 40. 45.]
 [ 1.  6. 11. 16. 21. 26. 31. 36. 41. 46.]
 [ 2.  7. 12. 17. 22. 27. 32. 37. 42. 47.]
 [ 3.  8. 13. 18. 23. 28. 33. 38. 43. 48.]
 [ 4.  9. 14. 19. 24. 29. 34. 39. 44. 49.]
 [50. 55. 60. 65. 70. 75. 80. 85. 90. 95.]
 [51. 56. 61. 66. 71. 76. 81. 86. 91. 96.]
 [52. 57. 62. 67. 72. 77. 82. 87. 92. 97.]
 [53. 58. 63. 68. 73. 78. 83. 88. 93. 98.]
 [54. 59. 64. 69. 74. 79. 84. 89. 94. 99.]]


In [9]:
# Verify pattern by checking unique values
print("Checking data integrity:")
obs_first_dim = replay_buffer.data[:replay_buffer.current_size, 0]
unique_vals = jnp.unique(obs_first_dim[~jnp.isnan(obs_first_dim)])
print(f"Unique observation values in buffer: {unique_vals}")
print("Expected: [0., 1.] for mixed pattern")

# Show distribution
zeros_count = jnp.sum(obs_first_dim == 0.0)
ones_count = jnp.sum(obs_first_dim == 1.0)
print(f"\nZeros: {zeros_count}, Ones: {ones_count}")

Checking data integrity:
Unique observation values in buffer: [0. 1.]
Expected: [0., 1.] for mixed pattern

Zeros: 60, Ones: 40
Unique observation values in buffer: [0. 1.]
Expected: [0., 1.] for mixed pattern

Zeros: 60, Ones: 40


In [1]:
import jax
import jax.numpy as jnp
from brax.v1 import envs
from qdax.core.neuroevolution.buffers.trajectory_buffer_v2 import TrajectoryBuffer, Transition

buffer_size = 1000
episode_length = 50
env_batch_size = 5

env = envs.create("hopper", episode_length=episode_length)
random_key = jax.random.PRNGKey(1)
keys = jax.random.split(random_key, env_batch_size)

env_states = jax.vmap(env.reset)(keys)

2025-11-22 16:49:27.283144: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
init_transition = Transition.init_dummy(observation_dim=env.observation_size, action_dim=env.action_size)
replay_buffer = TrajectoryBuffer.init(
    buffer_size=buffer_size,
    transition=init_transition,
    env_batch_size=env_batch_size,
    episode_length=episode_length
)

In [3]:
def play_step_function(random_key, env_state):
    random_key, subkey = jax.random.split(random_key)
    rand_action = jax.random.uniform(subkey, (env.action_size,), minval=0.5, maxval=1.0)
    next_state = env.step(env_state, rand_action)
    transition = Transition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        actions=rand_action,
        truncations=next_state.info["truncation"],
    )
    return random_key, next_state, transition


play_step_fn = jax.vmap(play_step_function)

def scan_play_step_fn(
    carry,
    unused,
):
    random_key, env_state, transitions = play_step_fn(*carry)
    
    return (random_key, env_state), transitions


In [4]:
keys, env_states, transitions = play_step_fn(keys, env_states)

In [37]:
for i in range(1000):
    (keys, env_states), transitions = jax.lax.scan(
        scan_play_step_fn,
        (keys, env_states),
        (),
        length=10
    )
    replay_buffer = replay_buffer.insert(transitions)

In [38]:
replay_buffer.episodic_data

Array([[550., 555., 560., 565., 570., 575., 580., 585., 590., 595., 600.,
        605., 610., 615., 620., 625., 630., 635., 640., 645., 650., 655.,
        660., 665., 670., 675., 680., 685.,  nan,  nan,  nan,  nan,  nan,
         nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
         nan,  nan,  nan,  nan,  nan,  nan],
       [646., 651., 656., 661., 666., 671., 676., 681., 686., 691., 696.,
        701., 706., 711., 716., 721., 726., 731., 736., 741., 746., 751.,
        756., 761., 766., 771., 776., 781.,  nan,  nan,  nan,  nan,  nan,
         nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
         nan,  nan,  nan,  nan,  nan,  nan],
       [447., 452., 457., 462., 467., 472., 477., 482., 487., 492., 497.,
        502., 507., 512., 517., 522., 527., 532., 537., 542., 547., 552.,
        557., 562., 567., 572., 577., 582., 587., 592.,  nan,  nan,  nan,
         nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,
         nan,  nan,  n

In [59]:
samples, random_key = replay_buffer.sample_n_step(random_key, sample_size=32, n_step=5)

  transitions = jax.tree_map(reshape_field, transitions)


In [60]:
print(samples.obs.shape)

(32, 5, 11)


In [64]:
jnp.sum(samples.obs[:, 1:] != samples.next_obs[:, :-1])

Array(0, dtype=int32)

In [62]:
mask = samples.obs[:, 1:] != samples.next_obs[:, :-1]
mask.shape

(32, 4, 11)

In [57]:
samples.dones[samples.obs[:, 1:] != samples.next_obs[:, :-1]]

IndexError: boolean index did not match shape of indexed array in index 0: got (32, 4, 11), expected (32, 5)

In [55]:
jnp.sum(samples.obs[:, 1:] == samples.next_obs[:, :-1])

Array(1309, dtype=int32)