### Replay Buffer

In our previous implementation of parallelization, during the update step, we made use of experiences from only the most recent episodes. The above approach may lead to unstable training as the most recent episodes are likely to be correlated. To solve this, we can use a replay buffer (or an experience replay) that stores the past transitions of the form
 $(s, a, r, s')$.

 Briefly, instead of learning directly from the most recent experience, the agent samples a batch of experiences uniformly from this buffer. This mechanism helps to break correlations between consecutive samples and improves the stability of learning. By reusing past transitions multiple times, the agent can also learn more efficiently from limited interaction with the environment, which is especially important when data collection is costly.


In [3]:
#!pip install gymnax

Collecting gymnax
  Downloading gymnax-0.0.9-py3-none-any.whl.metadata (19 kB)
Downloading gymnax-0.0.9-py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.6/86.6 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gymnax
Successfully installed gymnax-0.0.9


In [4]:
import jax
import jax.numpy as jnp
import optax
import gymnax
from gymnax.wrappers.purerl import FlattenObservationWrapper
import jax.tree_util as jtu

In [5]:
class ReplayBuffer:
    def __init__(self, capacity, obs_dim, n_envs):
        self.capacity = capacity
        self.n_envs = n_envs
        self.ptr = 0
        self.size = 0   # number of valid samples

        self.obs_buf = jnp.zeros((capacity, obs_dim))
        self.actions_buf = jnp.zeros((capacity,), dtype=jnp.int32)
        self.rewards_buf = jnp.zeros((capacity,))
        self.next_obs_buf = jnp.zeros((capacity, obs_dim))
        self.dones_buf = jnp.zeros((capacity,))

    def add(self, obs, actions, rewards, next_obs, dones):
        """
        Add a batch of transitions from all environments.
        obs: (n_envs, obs_dim)
        actions: (n_envs,)
        rewards: (n_envs,)
        next_obs: (n_envs, obs_dim)
        dones: (n_envs,)
        """
        n = obs.shape[0]
        idxs = (jnp.arange(n) + self.ptr) % self.capacity

        self.obs_buf = self.obs_buf.at[idxs].set(obs)
        self.actions_buf = self.actions_buf.at[idxs].set(actions)
        self.rewards_buf = self.rewards_buf.at[idxs].set(rewards)
        self.next_obs_buf = self.next_obs_buf.at[idxs].set(next_obs)
        self.dones_buf = self.dones_buf.at[idxs].set(dones)

        # advance pointer
        self.ptr = (self.ptr + n) % self.capacity
        # track current size
        self.size = min(self.size + n, self.capacity)

    def sample(self, rng, batch_size):
        max_size = self.size  # only sample from valid entries
        idxs = jax.random.randint(rng, (batch_size,), 0, max_size)
        return (self.obs_buf[idxs],
                self.actions_buf[idxs],
                self.rewards_buf[idxs],
                self.next_obs_buf[idxs],
                self.dones_buf[idxs])

    def __len__(self):
        return self.size


In [None]:
# Wrap the environment to ensure consistent observation shapes
env, env_params = gymnax.make("CartPole-v1")
env = FlattenObservationWrapper(env)
env_params = env_params

def rbf_features(x, centers, sigma=0.5):
    # x: (d,) or (batch_size, d), centers: (n_centers, d)

    # Normalize input to appropriate range for CartPole
    x= x / jnp.array([2.4, 3.0, 0.2, 3.0])  # CartPole observation scaling

    diffs = x[None] - centers if x.ndim == 1 else x[:, None] - centers
    sq_dist = jnp.sum(diffs**2, axis=-1)
    return jnp.exp(-sq_dist / (2 * sigma**2))

def init_params(rng, n_features, n_actions):
    W = jax.random.normal(rng, (n_features, n_actions)) * 0.1
    return W

def q_values(W, obs, centers, sigma=0.5):
    phi = rbf_features(obs, centers, sigma)  # (batch_size, n_features) or (n_features,)
    return jnp.dot(phi, W)  # (batch_size, n_actions) or (n_actions,)

def select_action(W, obs, rng, centers, sigma=0.5, epsilon=0.1):
    q = q_values(W, obs, centers, sigma)
    greedy = jnp.argmax(q, axis=-1)
    explore = jax.random.bernoulli(rng, epsilon, shape=greedy.shape)
    random_actions = jax.random.randint(rng, greedy.shape, 0, q.shape[-1])
    return jnp.where(explore, random_actions, greedy)

def td_loss(W, obs, action, reward, next_obs, done, gamma, centers, sigma):
    q = q_values(W, obs, centers, sigma)
    q_selected = jnp.take_along_axis(q, action[:, None], axis=-1).squeeze()

    next_q = jnp.max(q_values(W, next_obs, centers, sigma), axis=-1)
    target = reward + gamma * (1 - done) * next_q

    return jnp.mean(0.5 * (q_selected - target) ** 2)

def train_parallel_simple(num_episodes=500, lr=1e-2, gamma=0.99, n_centers=50,
                         sigma=0.5, num_envs=16, batch_size=16):
    obs_dim = env.observation_space(env_params).shape[0]
    n_actions = env.action_space(env_params).n

    rng = jax.random.PRNGKey(0)
    rng, centers_rng, init_rng = jax.random.split(rng, 3)

    # Random RBF centers
    centers = jax.random.uniform(centers_rng, (n_centers, obs_dim), minval=-1, maxval=1)
    W = init_params(init_rng, n_centers, n_actions)

    opt = optax.adam(lr)
    opt_state = opt.init(W)

    # Vectorized functions
    vmap_reset = jax.vmap(env.reset, in_axes=(0, None))
    vmap_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    vmap_select_action = jax.vmap(select_action, in_axes=(None, 0, 0, None, None, None))

    @jax.jit
    def update_batch(W, opt_state, obs, actions, rewards, next_obs, dones):
        def batch_loss(W):
            return td_loss(W, obs, actions, rewards, next_obs, dones, gamma, centers, sigma)

        grads = jax.grad(batch_loss)(W)
        updates, opt_state = opt.update(grads, opt_state, W)
        W = optax.apply_updates(W, updates)
        return W, opt_state

    # Initialize environments
    rng, *env_rngs = jax.random.split(rng, num_envs + 1)
    obs, states = vmap_reset(jnp.array(env_rngs), env_params)

    episode_rewards = jnp.zeros(num_envs)
    all_rewards = []

    buffer = ReplayBuffer(capacity=100000, obs_dim=obs_dim, n_envs=num_envs)   # initialize buffer

    for step in range(100000):  # Large enough to collect required episodes
        rng, action_rng = jax.random.split(rng)
        action_rngs = jax.random.split(action_rng, num_envs)

        # Select actions for all environments
        actions = vmap_select_action(W, obs, jnp.array(action_rngs), centers, sigma, epsilon)

        # Step all environments
        rng, *step_rngs = jax.random.split(rng, num_envs + 1)
        next_obs, next_states, rewards, dones, _ = vmap_step(
            jnp.array(step_rngs), states, actions, env_params
        )

        # Store experience in replay buffer
        buffer.add(obs, actions, rewards, next_obs, dones.astype(jnp.float32))

        # Update episode rewards
        episode_rewards += rewards

        # Record completed episodes and reset counters for done envs
        completed_mask = dones.astype(bool)
        completed_rewards = episode_rewards[completed_mask]
        all_rewards.extend(completed_rewards.tolist())

        episode_rewards = episode_rewards * (1 - dones.astype(jnp.float32))

        # Compute resets for all (unconditionally for simplicity)
        rng, reset_rng = jax.random.split(rng)
        reset_rngs = jax.random.split(reset_rng, num_envs)
        reset_obs, reset_states = vmap_reset(jnp.array(reset_rngs), env_params)

        # Apply resets only to done environments
        obs = jnp.where(dones[:, None], reset_obs, next_obs)
        states = jtu.tree_map(lambda r, n: jnp.where(dones, r, n), reset_states, next_states)

        # Train only if buffer is warm enough
        if buffer.size >= 1000:
            rng, sample_rng = jax.random.split(rng)
            batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = \
                buffer.sample(sample_rng, batch_size)

            W, opt_state = update_batch(
                W, opt_state,
                batch_obs, batch_actions, batch_rewards,
                batch_next_obs, batch_dones
            )

        # Print progress
        if step % 100 == 0:
            if all_rewards:
                recent_rewards = all_rewards[-100:] if len(all_rewards) > 100 else all_rewards
                avg_reward = jnp.mean(jnp.array(recent_rewards))
                print(f"Step {step}, Avg reward: {avg_reward:.1f}, "
                      f"Total episodes: {len(all_rewards)}")

                # Early stopping if solved
                if len(all_rewards) >= 100 and avg_reward >= 475.0:
                    print("CartPole solved!")
                    break
            else:
                print(f"Step {step}, No episodes completed yet")

        # Stop if we've collected enough episodes
        if len(all_rewards) >= num_episodes:
            break

    return W, centers, all_rewards

# Run training
W, centers, rewards = train_parallel_simple(num_episodes=50000, num_envs=16, n_centers=500, lr=5e-3, batch_size=32 )
print(f"Final average reward: {jnp.mean(jnp.array(rewards[-100:])):.1f}")

Step 0, No episodes completed yet
Step 100, Avg reward: 10.2, Total episodes: 153
Step 200, Avg reward: 17.7, Total episodes: 222
Step 300, Avg reward: 27.6, Total episodes: 258
Step 400, Avg reward: 39.0, Total episodes: 289
Step 500, Avg reward: 47.0, Total episodes: 314
Step 600, Avg reward: 53.7, Total episodes: 336
Step 700, Avg reward: 62.2, Total episodes: 353
Step 800, Avg reward: 68.0, Total episodes: 363
Step 900, Avg reward: 78.6, Total episodes: 376
Step 1000, Avg reward: 91.8, Total episodes: 391
Step 1100, Avg reward: 95.4, Total episodes: 396
Step 1200, Avg reward: 109.4, Total episodes: 410
Step 1300, Avg reward: 113.2, Total episodes: 414
Step 1400, Avg reward: 128.4, Total episodes: 425
Step 1500, Avg reward: 137.0, Total episodes: 433
Step 1600, Avg reward: 146.5, Total episodes: 440
Step 1700, Avg reward: 155.3, Total episodes: 447
Step 1800, Avg reward: 159.7, Total episodes: 450
Step 1900, Avg reward: 173.0, Total episodes: 458
Step 2000, Avg reward: 185.1, Total 

KeyboardInterrupt: 

## Epsilon Decay
So far, we have been keeping $\epsilon$ to be fixed (0.1 in the above case). However, we know that while exploration is essential for an agent to discover rewarding states and actions, maintaining a high level of exploration throughout training can prevent the agent from fully exploiting the knowledge it has acquired. To balance exploration and exploitation, we can gradually reduce the probability of taking random actions over time using an exponential decay schedule. Exponential decay decreases the exploration rate rapidly at the beginning when the agent knows very little about the environment, and then more slowly as training progresses. This allows the agent to explore widely in the early stages, while gradually shifting towards exploiting learned strategies once sufficient experience has been gathered.

Using exponential decay for exploration helps improve learning stability and efficiency. By reducing random actions in a controlled manner, the agent avoids being trapped in suboptimal behaviors caused by excessive exploration, while still retaining some chance to discover better actions later in training. This is especially true if the agent has almost "fully learned", in that case, we do not want the agent to explore as much.


In [8]:
# Wrap the environment to ensure consistent observation shapes
env, env_params = gymnax.make("CartPole-v1")
env = FlattenObservationWrapper(env)
env_params = env_params

def rbf_features(x, centers, sigma=0.5):
    # x: (d,) or (batch_size, d), centers: (n_centers, d)

    # Normalize input to appropriate range for CartPole
    x= x / jnp.array([2.4, 3.0, 0.2, 3.0])  # CartPole observation scaling

    diffs = x[None] - centers if x.ndim == 1 else x[:, None] - centers
    sq_dist = jnp.sum(diffs**2, axis=-1)
    return jnp.exp(-sq_dist / (2 * sigma**2))

def init_params(rng, n_features, n_actions):
    W = jax.random.normal(rng, (n_features, n_actions)) * 0.1
    return W

def q_values(W, obs, centers, sigma=0.5):
    phi = rbf_features(obs, centers, sigma)  # (batch_size, n_features) or (n_features,)
    return jnp.dot(phi, W)  # (batch_size, n_actions) or (n_actions,)

def select_action(W, obs, rng, centers, sigma=0.5, epsilon=0.1):
    q = q_values(W, obs, centers, sigma)
    greedy = jnp.argmax(q, axis=-1)
    explore = jax.random.bernoulli(rng, epsilon, shape=greedy.shape)
    random_actions = jax.random.randint(rng, greedy.shape, 0, q.shape[-1])
    return jnp.where(explore, random_actions, greedy)

def td_loss(W, obs, action, reward, next_obs, done, gamma, centers, sigma):
    q = q_values(W, obs, centers, sigma)
    q_selected = jnp.take_along_axis(q, action[:, None], axis=-1).squeeze()

    next_q = jnp.max(q_values(W, next_obs, centers, sigma), axis=-1)
    target = reward + gamma * (1 - done) * next_q

    return jnp.mean(0.5 * (q_selected - target) ** 2)

def epsilon_schedule(step, eps_start=1.0, eps_end=0.05, decay_rate=0.999):
    return jnp.maximum(eps_end, eps_start * (decay_rate ** step))

def train_parallel_simple(num_episodes=500, lr=1e-2, gamma=0.99, n_centers=50,
                         sigma=0.5, num_envs=16, batch_size=16):
    obs_dim = env.observation_space(env_params).shape[0]
    n_actions = env.action_space(env_params).n

    rng = jax.random.PRNGKey(0)
    rng, centers_rng, init_rng = jax.random.split(rng, 3)

    # Random RBF centers
    centers = jax.random.uniform(centers_rng, (n_centers, obs_dim), minval=-1, maxval=1)
    W = init_params(init_rng, n_centers, n_actions)

    opt = optax.adam(lr)
    opt_state = opt.init(W)

    # Vectorized functions
    vmap_reset = jax.vmap(env.reset, in_axes=(0, None))
    vmap_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    vmap_select_action = jax.vmap(select_action, in_axes=(None, 0, 0, None, None, None))

    @jax.jit
    def update_batch(W, opt_state, obs, actions, rewards, next_obs, dones):
        def batch_loss(W):
            return td_loss(W, obs, actions, rewards, next_obs, dones, gamma, centers, sigma)

        grads = jax.grad(batch_loss)(W)
        updates, opt_state = opt.update(grads, opt_state, W)
        W = optax.apply_updates(W, updates)
        return W, opt_state

    # Initialize environments
    rng, *env_rngs = jax.random.split(rng, num_envs + 1)
    obs, states = vmap_reset(jnp.array(env_rngs), env_params)

    episode_rewards = jnp.zeros(num_envs)
    all_rewards = []

    buffer = ReplayBuffer(capacity=100000, obs_dim=obs_dim, n_envs=num_envs)   # initialize buffer

    for step in range(10000):  # Large enough to collect required episodes
        rng, action_rng = jax.random.split(rng)
        action_rngs = jax.random.split(action_rng, num_envs)

        epsilon = epsilon_schedule(step, eps_start=1.0, eps_end=0.01, decay_rate = 0.999)

        # Select actions for all environments
        actions = vmap_select_action(W, obs, jnp.array(action_rngs), centers, sigma, epsilon)

        # Step all environments
        rng, *step_rngs = jax.random.split(rng, num_envs + 1)
        next_obs, next_states, rewards, dones, _ = vmap_step(
            jnp.array(step_rngs), states, actions, env_params
        )

        # Store experience in replay buffer
        buffer.add(obs, actions, rewards, next_obs, dones.astype(jnp.float32))

        # Update episode rewards
        episode_rewards += rewards

        # Record completed episodes and reset counters for done envs
        completed_mask = dones.astype(bool)
        completed_rewards = episode_rewards[completed_mask]
        all_rewards.extend(completed_rewards.tolist())

        episode_rewards = episode_rewards * (1 - dones.astype(jnp.float32))

        # Compute resets for all (unconditionally for simplicity)
        rng, reset_rng = jax.random.split(rng)
        reset_rngs = jax.random.split(reset_rng, num_envs)
        reset_obs, reset_states = vmap_reset(jnp.array(reset_rngs), env_params)

        # Apply resets only to done environments
        obs = jnp.where(dones[:, None], reset_obs, next_obs)
        states = jtu.tree_map(lambda r, n: jnp.where(dones, r, n), reset_states, next_states)

        # Train only if buffer is warm enough
        if buffer.size >= 1000:
            rng, sample_rng = jax.random.split(rng)
            batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = \
                buffer.sample(sample_rng, batch_size)

            W, opt_state = update_batch(
                W, opt_state,
                batch_obs, batch_actions, batch_rewards,
                batch_next_obs, batch_dones
            )

        # Print progress
        if step % 100 == 0:
            if all_rewards:
                recent_rewards = all_rewards[-100:] if len(all_rewards) > 100 else all_rewards
                avg_reward = jnp.mean(jnp.array(recent_rewards))
                print(f"Step {step}, Avg reward: {avg_reward:.1f}, "
                      f"Total episodes: {len(all_rewards)}")

                # Early stopping if solved
                if len(all_rewards) >= 100 and avg_reward >= 475.0:
                    print("CartPole solved!")
                    break
            else:
                print(f"Step {step}, No episodes completed yet")

        # Stop if we've collected enough episodes
        if len(all_rewards) >= num_episodes:
            break

    return W, centers, all_rewards

# Run training
W, centers, rewards = train_parallel_simple(num_episodes=50000, num_envs=16, n_centers=500, lr=5e-3, batch_size=32 )
print(f"Final average reward: {jnp.mean(jnp.array(rewards[-100:])):.1f}")

Step 0, No episodes completed yet
Step 100, Avg reward: 19.6, Total episodes: 72
Step 200, Avg reward: 21.2, Total episodes: 145
Step 300, Avg reward: 23.7, Total episodes: 207
Step 400, Avg reward: 26.3, Total episodes: 267
Step 500, Avg reward: 27.8, Total episodes: 313
Step 600, Avg reward: 32.1, Total episodes: 362
Step 700, Avg reward: 34.6, Total episodes: 398
Step 800, Avg reward: 39.9, Total episodes: 431
Step 900, Avg reward: 46.5, Total episodes: 455
Step 1000, Avg reward: 54.4, Total episodes: 481
Step 1100, Avg reward: 60.4, Total episodes: 502
Step 1200, Avg reward: 67.3, Total episodes: 523
Step 1300, Avg reward: 72.9, Total episodes: 536
Step 1400, Avg reward: 79.0, Total episodes: 548
Step 1500, Avg reward: 83.5, Total episodes: 559
Step 1600, Avg reward: 94.6, Total episodes: 570
Step 1700, Avg reward: 103.7, Total episodes: 577
Step 1800, Avg reward: 110.8, Total episodes: 586
Step 1900, Avg reward: 112.6, Total episodes: 588
Step 2000, Avg reward: 127.8, Total episod