### Synchronous parallel environment

In our previous implementation, we ran one episode at a time. To accelerate learning, we now leverage synchronous parallel environment execution, utilizing GPU parallelism to run multiple episodes simultaneously. This approach allows us to gather diverse experiences more rapidly within the same time frame.

The pseudocode for the code below is as follows:
1. Execute `num_envs` (default: 16) environments in parallel using JAX's vectorization (vmap) on the GPU.

2. Collect and store the experiences (states, actions, rewards, next states) from all parallel environments at each step.

3. Once `batch_size` parallel steps have been collected (i.e. `batch_size × num_envs` transitions in total), compute the Temporal Difference (TD) loss by averaging the squared error across the entire batch.

4. Perform a gradient-based update using this averaged loss and repeat the process.

This method provides two key advantages:
- Faster experience collection: By running multiple environments in parallel, we generate a larger and more diverse set of experiences without increasing wall-clock time.

- More stable and accurate updates: Calculating the TD error over a batch of experiences, rather than a single instance, provides a better estimate of the true gradient, leading to more robust and effective policy updates.

In [None]:
import jax
import jax.numpy as jnp
import optax
import gymnax
from gymnax.wrappers.purerl import FlattenObservationWrapper
import jax.tree_util as jtu

# Create the CartPole-v1 environment with its default parameters, then wrap it
# so that observations are flattened to a consistent shape.
env, env_params = gymnax.make("CartPole-v1")
env = FlattenObservationWrapper(env)

def rbf_features(x, centers, sigma=0.5):
    """Compute Gaussian RBF features of (scaled) CartPole observations.

    Args:
        x: Observation(s) of shape ``(d,)`` or ``(..., d)``; any number of
            leading batch dimensions is accepted. The last axis must match the
            4-element scaling vector used below.
        centers: RBF centers of shape ``(n_centers, d)``.
        sigma: Width of the Gaussian kernels.

    Returns:
        Array of shape ``(n_centers,)`` for a single observation, or
        ``(..., n_centers)`` for batched input, holding
        ``exp(-||x_scaled - c||^2 / (2 * sigma^2))`` for every center ``c``.
    """
    # Normalize input to appropriate range for CartPole
    obs_scale = jnp.array([2.4, 3.0, 0.2, 3.0])  # CartPole observation scaling
    x = x / obs_scale

    # Insert a "centers" axis just before the feature axis so the subtraction
    # broadcasts for unbatched and arbitrarily batched inputs alike.
    diffs = x[..., None, :] - centers
    sq_dist = jnp.sum(diffs**2, axis=-1)
    return jnp.exp(-sq_dist / (2 * sigma**2))

def init_params(rng, n_features, n_actions):
    """Draw the initial linear weight matrix.

    Args:
        rng: JAX PRNG key used for sampling.
        n_features: Number of rows of the weight matrix.
        n_actions: Number of columns of the weight matrix.

    Returns:
        Array of shape ``(n_features, n_actions)`` sampled from a standard
        normal distribution and scaled by 0.1.
    """
    weight_shape = (n_features, n_actions)
    init_scale = 0.1
    return init_scale * jax.random.normal(rng, weight_shape)

def q_values(W, obs, centers, sigma=0.5):
    """Evaluate the linear Q-function on RBF features of ``obs``.

    Args:
        W: Weight matrix of shape ``(n_features, n_actions)``.
        obs: Observation(s), shape ``(d,)`` or ``(batch_size, d)``.
        centers: RBF centers forwarded to ``rbf_features``.
        sigma: RBF width forwarded to ``rbf_features``.

    Returns:
        Q-values of shape ``(n_actions,)`` or ``(batch_size, n_actions)``.
    """
    # features: (n_features,) or (batch_size, n_features)
    features = rbf_features(obs, centers, sigma)
    return features @ W

def select_action(W, obs, rng, centers, sigma=0.5, epsilon=0.1):
    """Choose epsilon-greedy action(s) from the current Q-function.

    Args:
        W: Weight matrix passed to ``q_values``.
        obs: Observation(s), shape ``(d,)`` or ``(batch_size, d)``.
        rng: JAX PRNG key; it is split internally so it must not be reused by
            the caller for other draws.
        centers: RBF centers passed to ``q_values``.
        sigma: RBF width passed to ``q_values``.
        epsilon: Probability of taking a uniformly random action instead of
            the greedy one.

    Returns:
        Integer action(s) with the same shape as the leading (batch) shape of
        ``obs``: a scalar for a single observation.
    """
    q = q_values(W, obs, centers, sigma)
    greedy = jnp.argmax(q, axis=-1)

    # Use independent keys for the two draws: feeding the same key to both
    # samplers would tie the "explore?" decision to the random action drawn.
    explore_rng, action_rng = jax.random.split(rng)
    explore = jax.random.bernoulli(explore_rng, epsilon, shape=greedy.shape)
    random_actions = jax.random.randint(action_rng, greedy.shape, 0, q.shape[-1])
    return jnp.where(explore, random_actions, greedy)

def td_loss(W, obs, action, reward, next_obs, done, gamma, centers, sigma):
    """Mean half-squared one-step TD error for a Q-learning target.

    Works on a batch of transitions (leading batch axis on every argument) as
    well as on a single unbatched transition, matching what ``q_values``
    accepts.

    Args:
        W: Weight matrix passed to ``q_values``.
        obs: Observation(s), shape ``(d,)`` or ``(batch_size, d)``.
        action: Integer action index/indices taken in ``obs``.
        reward: Reward(s) received.
        next_obs: Successor observation(s), same shape as ``obs``.
        done: Episode-end flag(s) as 0/1 values; 1 removes the bootstrap term.
        gamma: Discount factor.
        centers: RBF centers passed to ``q_values``.
        sigma: RBF width passed to ``q_values``.

    Returns:
        Scalar: mean over the batch of ``0.5 * (Q(s, a) - target) ** 2`` with
        ``target = reward + gamma * (1 - done) * max_a' Q(s', a')``.
    """
    q = q_values(W, obs, centers, sigma)
    # Index along the action axis only; squeezing exactly that axis keeps the
    # batch axis intact even when the batch holds a single transition.
    action_idx = jnp.asarray(action)[..., None]
    q_selected = jnp.take_along_axis(q, action_idx, axis=-1).squeeze(axis=-1)

    next_q = jnp.max(q_values(W, next_obs, centers, sigma), axis=-1)
    # NOTE(review): the target is computed from the same W and is not wrapped
    # in stop_gradient, so differentiating this loss w.r.t. W also propagates
    # gradients through the bootstrap term. Confirm this is intended.
    target = reward + gamma * (1 - done) * next_q

    return jnp.mean(0.5 * (q_selected - target) ** 2)

def train_parallel_simple(num_episodes=500, lr=1e-2, gamma=0.99, n_centers=50,
                         sigma=0.5, num_envs=16, batch_size=16,
                         epsilon=0.1, max_steps=100000, seed=0):
    """Train a linear RBF Q-function on ``num_envs`` synchronous environments.

    Every loop iteration steps all environments once with epsilon-greedy
    actions and stores the resulting transitions. After ``batch_size``
    iterations the stored ``batch_size * num_envs`` transitions are flattened
    and used for a single Adam update on the mean TD loss, then discarded.

    Uses the module-level ``env`` and ``env_params``.

    Args:
        num_episodes: Stop once at least this many episodes have finished.
        lr: Adam learning rate.
        gamma: Discount factor.
        n_centers: Number of random RBF centers (also the feature count).
        sigma: RBF width.
        num_envs: Number of environments stepped in parallel.
        batch_size: Number of parallel steps collected between updates.
        epsilon: Exploration probability for action selection.
        max_steps: Upper bound on the number of loop iterations.
        seed: Seed for the PRNG key that drives all randomness.

    Returns:
        Tuple ``(W, centers, all_rewards)``: the trained weights, the RBF
        centers, and a list with the return of every completed episode in
        order of completion.
    """
    obs_dim = env.observation_space(env_params).shape[0]
    n_actions = env.action_space(env_params).n

    rng = jax.random.PRNGKey(seed)
    rng, centers_rng, init_rng = jax.random.split(rng, 3)

    # Random RBF centers
    centers = jax.random.uniform(centers_rng, (n_centers, obs_dim), minval=-1, maxval=1)
    W = init_params(init_rng, n_centers, n_actions)

    opt = optax.adam(lr)
    opt_state = opt.init(W)

    # Vectorized functions: per-environment keys/states/actions are mapped
    # over axis 0, while parameters and hyper-parameters are shared.
    vmap_reset = jax.vmap(env.reset, in_axes=(0, None))
    vmap_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    vmap_select_action = jax.vmap(select_action, in_axes=(None, 0, 0, None, None, None))

    @jax.jit
    def update_batch(W, opt_state, batch_obs, batch_actions, batch_rewards,
                    batch_next_obs, batch_dones):
        """Apply one optimizer step on the TD loss of the stacked batch."""
        def batch_loss(W):
            # Merge the (steps, envs) leading axes into one flat batch axis.
            obs = batch_obs.reshape(-1, obs_dim)
            actions = batch_actions.reshape(-1)
            rewards = batch_rewards.reshape(-1)
            next_obs = batch_next_obs.reshape(-1, obs_dim)
            dones = batch_dones.reshape(-1)
            return td_loss(W, obs, actions, rewards, next_obs, dones, gamma, centers, sigma)

        grads = jax.grad(batch_loss)(W)
        updates, opt_state = opt.update(grads, opt_state, W)
        W = optax.apply_updates(W, updates)
        return W, opt_state

    # Initialize environments
    init_keys = jax.random.split(rng, num_envs + 1)
    rng, env_rngs = init_keys[0], init_keys[1:]
    obs, states = vmap_reset(env_rngs, env_params)

    episode_rewards = jnp.zeros(num_envs)
    all_rewards = []

    # Initialize batch storage
    batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = [], [], [], [], []

    for step in range(max_steps):  # Large enough to collect required episodes
        rng, action_rng = jax.random.split(rng)
        action_rngs = jax.random.split(action_rng, num_envs)

        # Select actions for all environments
        actions = vmap_select_action(W, obs, action_rngs, centers, sigma, epsilon)

        # Step all environments
        step_keys = jax.random.split(rng, num_envs + 1)
        rng, step_rngs = step_keys[0], step_keys[1:]
        next_obs, next_states, rewards, dones, _ = vmap_step(
            step_rngs, states, actions, env_params
        )

        # Store experience for batch update
        batch_obs.append(obs)
        batch_actions.append(actions)
        batch_rewards.append(rewards)
        batch_next_obs.append(next_obs)
        batch_dones.append(dones.astype(jnp.float32))  # Convert to float for consistency

        # Update episode rewards
        episode_rewards += rewards

        # Record completed episodes and reset counters for done envs
        completed_mask = dones.astype(bool)
        completed_rewards = episode_rewards[completed_mask]
        all_rewards.extend(completed_rewards.tolist())

        episode_rewards = episode_rewards * (1 - dones.astype(jnp.float32))

        # Compute resets for all (unconditionally for simplicity)
        # NOTE(review): gymnax's `step` is understood to auto-reset finished
        # environments already - confirm whether this explicit reset is needed.
        rng, reset_rng = jax.random.split(rng)
        reset_rngs = jax.random.split(reset_rng, num_envs)
        reset_obs, reset_states = vmap_reset(reset_rngs, env_params)

        # Apply resets only to done environments
        obs = jnp.where(dones[:, None], reset_obs, next_obs)
        # assumes every leaf of the env state has shape (num_envs,) so that
        # `dones` broadcasts against it - TODO confirm for other environments
        states = jtu.tree_map(lambda r, n: jnp.where(dones, r, n), reset_states, next_states)

        # Perform batch update when we have enough experience
        if len(batch_obs) >= batch_size:
            batch_obs_arr = jnp.stack(batch_obs)
            batch_actions_arr = jnp.stack(batch_actions)
            batch_rewards_arr = jnp.stack(batch_rewards)
            batch_next_obs_arr = jnp.stack(batch_next_obs)
            batch_dones_arr = jnp.stack(batch_dones)

            W, opt_state = update_batch(
                W, opt_state, batch_obs_arr, batch_actions_arr,
                batch_rewards_arr, batch_next_obs_arr, batch_dones_arr
            )

            # Clear batch
            batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = [], [], [], [], []

        # Print progress
        if step % 100 == 0:
            if all_rewards:
                # Slicing already copes with fewer than 100 recorded episodes.
                recent_rewards = all_rewards[-100:]
                avg_reward = jnp.mean(jnp.array(recent_rewards))
                print(f"Step {step}, Avg reward: {avg_reward:.1f}, "
                      f"Total episodes: {len(all_rewards)}")

                # Early stopping if solved
                if len(all_rewards) >= 100 and avg_reward >= 475.0:
                    print("CartPole solved!")
                    break
            else:
                print(f"Step {step}, No episodes completed yet")

        # Stop if we've collected enough episodes
        if len(all_rewards) >= num_episodes:
            break

    return W, centers, all_rewards

# Launch training with the chosen hyper-parameters, then report the mean
# return over the last (up to) 100 completed episodes.
W, centers, rewards = train_parallel_simple(
    num_episodes=50000,
    num_envs=16,
    n_centers=500,
    lr=5e-3,
    batch_size=16,
)
final_avg_reward = jnp.mean(jnp.array(rewards[-100:]))
print(f"Final average reward: {final_avg_reward:.1f}")

Step 0, No episodes completed yet
Step 100, Avg reward: 20.7, Total episodes: 57
Step 200, Avg reward: 27.3, Total episodes: 99
Step 300, Avg reward: 16.9, Total episodes: 219
Step 400, Avg reward: 21.8, Total episodes: 257
Step 500, Avg reward: 36.8, Total episodes: 306
Step 600, Avg reward: 11.3, Total episodes: 448
Step 700, Avg reward: 16.5, Total episodes: 509
Step 800, Avg reward: 23.8, Total episodes: 519
Step 900, Avg reward: 41.4, Total episodes: 532
Step 1000, Avg reward: 58.5, Total episodes: 565
Step 1100, Avg reward: 67.5, Total episodes: 582
Step 1200, Avg reward: 74.5, Total episodes: 589
Step 1300, Avg reward: 96.4, Total episodes: 606
Step 1400, Avg reward: 99.4, Total episodes: 617
Step 1500, Avg reward: 95.4, Total episodes: 628
Step 1600, Avg reward: 98.7, Total episodes: 653
Step 1700, Avg reward: 104.4, Total episodes: 662
Step 1800, Avg reward: 109.5, Total episodes: 664
Step 1900, Avg reward: 123.6, Total episodes: 676
Step 2000, Avg reward: 134.1, Total episode