# Single-task PPO

In [None]:
# jax is already installed on the colab, uncomment only if needed
# !pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# !pip install 'xminigrid[baselines]'
!pip install "xminigrid[baselines] @ git+https://github.com/corl-team/xland-minigrid.git"

Collecting xminigrid[baselines]@ git+https://github.com/corl-team/xland-minigrid.git
  Cloning https://github.com/corl-team/xland-minigrid.git to /tmp/pip-install-j09355fw/xminigrid_49a68358bba44f51bedfedf85d836ea7
  Running command git clone --filter=blob:none --quiet https://github.com/corl-team/xland-minigrid.git /tmp/pip-install-j09355fw/xminigrid_49a68358bba44f51bedfedf85d836ea7
  Resolved https://github.com/corl-team/xland-minigrid.git to commit 991f13c7885c24c82302a1ee3a68a24a29801a94
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Collecting matplotlib>=3.7.2 (from xminigrid[baselines]@ git+https://github.com/corl-team/xland-minigrid.git)
  Downloading matplotlib-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.

In [None]:
import time
import math
from typing import TypedDict, Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import flax
import flax.linen as nn
import distrax
import optax
import imageio
import wandb
import matplotlib.pyplot as plt

from flax import struct
from flax.linen.initializers import glorot_normal, orthogonal, zeros_init
from flax.training.train_state import TrainState
from flax.jax_utils import replicate, unreplicate
from dataclasses import asdict, dataclass
from functools import partial

import xminigrid
from xminigrid.environment import Environment, EnvParams
from xminigrid.wrappers import GymAutoResetWrapper

### Networks

In [None]:
# Model adapted from minigrid baselines:
# https://github.com/lcswillems/rl-starter-files/blob/master/model.py

# custom RNN cell, which is more convenient than default in flax
# GRU helps solve the vanishing gradient problem and can capture dependencies from long input seq
class GRU(nn.Module):
    """Single GRU layer that unrolls over a whole sequence with `jax.lax.scan`.

    The reset, update and candidate gates share one stacked input matrix
    (`Wi`) and one stacked hidden matrix (`Wh`); the stacked pre-activations
    are split into three equal chunks inside the step function.
    """

    hidden_dim: int  # width of the recurrent state

    @nn.compact
    def __call__(self, xs, init_state):
        """Run the cell over `xs`, starting from `init_state`.

        Args:
            xs: input sequence, unpacked as `(seq_len, input_dim)`.
            init_state: hidden state fed to the first step; it is multiplied
                by `Wh` of shape `(3 * hidden_dim, hidden_dim)`.

        Returns:
            `(all_states, last_state)` - the hidden state after every step
            (stacked along the sequence axis) and the final hidden state.
        """
        # Unpacking doubles as a check that `xs` is exactly two-dimensional.
        _, in_features = xs.shape
        stacked_dim = 3 * self.hidden_dim

        # Author's note kept from the original: this initialisation might not
        # be optimal (e.g. the reset-gate bias could start at -1).
        # Parameters are declared in a fixed order: Wi, Wh, bi, bn.
        w_input = self.param("Wi", glorot_normal(in_axis=1, out_axis=0), (stacked_dim, in_features))
        w_hidden = self.param("Wh", orthogonal(column_axis=0), (stacked_dim, self.hidden_dim))
        b_input = self.param("bi", zeros_init(), (stacked_dim,))
        # Extra bias that only enters the candidate, inside the reset-gated term.
        b_candidate = self.param("bn", zeros_init(), (self.hidden_dim,))

        def _cell(prev_h, x_t):
            # Pre-activations, each split into (reset, update, candidate).
            in_reset, in_update, in_cand = jnp.split(w_input @ x_t + b_input, 3)
            hid_reset, hid_update, hid_cand = jnp.split(w_hidden @ prev_h, 3)

            reset_gate = nn.sigmoid(in_reset + hid_reset)
            update_gate = nn.sigmoid(in_update + hid_update)
            # The reset gate scales how much of the previous state feeds the candidate.
            candidate = nn.tanh(in_cand + reset_gate * (hid_cand + b_candidate))
            # update_gate close to 1 keeps the old state, close to 0 takes the candidate.
            h_t = (1 - update_gate) * candidate + update_gate * prev_h

            # `lax.scan` expects (new_carry, per_step_output); both are the new state.
            return h_t, h_t

        final_h, per_step_h = jax.lax.scan(_cell, init=init_state, xs=xs)
        return per_step_h, final_h

In [None]:
# construct a multi-layer RNN using our custom GRU cells
class RNNModel(nn.Module):
    """Stack of `GRU` layers whose per-layer output sequences are summed.

    Each layer consumes the output sequence of the layer below it and starts
    from its own row of `init_state`.
    """

    hidden_dim: int  # hidden-state width shared by every GRU layer
    num_layers: int  # number of stacked GRU layers

    @nn.compact
    def __call__(self, xs, init_state):
        """Apply all layers to one sequence.

        Args:
            xs: `[seq_len, input_dim]` input sequence.
            init_state: `[num_layers, hidden_dim]` initial state per layer.

        Returns:
            `(summed_outputs, final_states)` - the element-wise sum of every
            layer's output sequence, and the stacked final state of each layer.
        """
        layer_outputs = []
        final_states = []

        features = xs
        for layer_idx in range(self.num_layers):
            features, last_h = GRU(hidden_dim=self.hidden_dim)(features, init_state[layer_idx])
            layer_outputs.append(features)
            final_states.append(last_h)

        # Summing over the layer axis mixes information from every depth,
        # loosely similar to residual connections.
        summed_outputs = jnp.array(layer_outputs).sum(0)
        return summed_outputs, jnp.array(final_states)

In [None]:
# Batched version of `RNNModel`: the module is vectorised over a leading batch axis
# so that many sequences are processed in parallel.
# - `variable_axes={"params": None}`: parameters get no batch dimension, i.e. one
#   shared set of weights is used for every item in the batch.
# - `split_rngs={"params": False}`: the "params" RNG is not split along the batch
#   axis, so initialisation is identical for every batch item.
# - `axis_name="batch"`: name given to the mapped axis.
BatchedRNNModel = nn.vmap(
    RNNModel,
    variable_axes={"params": None},
    split_rngs={"params": False},
    axis_name="batch",
)

In [None]:
class ActorCriticInput(TypedDict):
    """Inputs consumed by `ActorCriticRNN.__call__`.

    The network reads the first two axes of `observation` as
    `(batch_size, seq_len)`; the other entries are built with the same two
    leading axes by the callers in this file.
    """

    observation: jax.Array  # [batch_size, seq_len, *observation_shape]
    prev_action: jax.Array  # [batch_size, seq_len] integer action indices (fed to an embedding)
    prev_reward: jax.Array  # [batch_size, seq_len] reward received for the previous action

class ActorCriticRNN(nn.Module):
    """Recurrent actor-critic network.

    Observations go through a small convolutional encoder, are concatenated
    with an embedding of the previous action and the previous reward, and are
    fed to a batched GRU core. Two MLP heads on top of the core output produce
    the policy logits and the state-value estimate.
    """

    num_actions: int
    action_emb_dim: int = 16
    rnn_hidden_dim: int = 64
    rnn_num_layers: int = 1
    head_hidden_dim: int = 64
    img_obs: bool = False

    @nn.compact
    def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:
        """Evaluate policy and value for a batch of sequences.

        Args:
            inputs: dict with "observation", "prev_action" and "prev_reward",
                all laid out as `[batch_size, seq_len, ...]`.
            hidden: recurrent state handed to the RNN core.

        Returns:
            `(dist, values, new_hidden)` - a categorical distribution over
            actions, value estimates with the trailing unit axis removed, and
            the recurrent state returned by the core.
        """
        batch_size, seq_len = inputs["observation"].shape[:2]

        def _conv(features, kernel, **extra):
            # Every encoder convolution shares padding and initialiser.
            return nn.Conv(features, kernel, padding="VALID", kernel_init=orthogonal(math.sqrt(2)), **extra)

        # Encoder adapted from https://github.com/lcswillems/rl-starter-files/blob/master/model.py
        # NOTE: submodules are built in a fixed order (encoder, embedding,
        # core, actor, critic); keep it when editing.
        if self.img_obs:
            # Image observations: strided 3x3 convolutions, no ReLU after the last one.
            encoder_layers = [
                _conv(16, (3, 3), strides=2),
                nn.relu,
                _conv(32, (3, 3), strides=2),
                nn.relu,
                _conv(32, (3, 3), strides=2),
                nn.relu,
                _conv(32, (3, 3), strides=2),
            ]
        else:
            # Non-image observations: unstrided 2x2 convolutions.
            encoder_layers = [
                _conv(16, (2, 2)),
                nn.relu,
                _conv(32, (2, 2)),
                nn.relu,
                _conv(64, (2, 2)),
                nn.relu,
            ]
        obs_encoder = nn.Sequential(encoder_layers)

        # Maps a discrete action index to a dense vector of size `action_emb_dim`.
        prev_action_embed = nn.Embed(self.num_actions, self.action_emb_dim)

        # Recurrent core, vectorised over the batch axis.
        recurrent_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)

        def _head(out_features, out_scale):
            # Two-layer tanh MLP; only the output size and output init scale differ.
            return nn.Sequential(
                [
                    nn.Dense(self.head_hidden_dim, kernel_init=orthogonal(2)),
                    nn.tanh,
                    nn.Dense(out_features, kernel_init=orthogonal(out_scale)),
                ]
            )

        # Actor: one logit per action, with a small-scale output initialiser.
        policy_head = _head(self.num_actions, 0.01)
        # Critic: a single scalar value per timestep.
        value_head = _head(1, 1.0)

        # Flatten each encoded observation into one feature vector per timestep.
        obs_features = obs_encoder(inputs["observation"]).reshape(batch_size, seq_len, -1)
        action_features = prev_action_embed(inputs["prev_action"])
        # Give the reward a trailing feature axis so it can be concatenated.
        reward_features = inputs["prev_reward"][..., None]

        # [batch_size, seq_len, obs_features + action_emb_dim + 1]
        core_input = jnp.concatenate([obs_features, action_features, reward_features], axis=-1)

        core_output, new_hidden = recurrent_core(core_input, hidden)

        dist = distrax.Categorical(logits=policy_head(core_output))
        values = jnp.squeeze(value_head(core_output), axis=-1)

        return dist, values, new_hidden

    def initialize_carry(self, batch_size):
        """Return an all-zero recurrent state of shape `(batch_size, rnn_num_layers, rnn_hidden_dim)`."""
        carry_shape = (batch_size, self.rnn_num_layers, self.rnn_hidden_dim)
        return jnp.zeros(carry_shape)


### Utils

In [None]:
class Transition(struct.PyTreeNode):
    """One recorded environment step, stored as a pytree.

    Instances are produced by the rollout step and stacked by `jax.lax.scan`,
    so every field gains a leading time axis in the collected batch.
    """

    done: jax.Array      # `timestep.last()` of the step produced by `action`
    action: jax.Array    # action sampled from the policy
    value: jax.Array     # critic estimate at the time the action was taken
    reward: jax.Array    # reward returned by the environment for `action`
    log_prob: jax.Array  # log-probability of `action` under the sampling policy
    obs: jax.Array       # observation the policy saw before acting
    # for rnn policy
    prev_action: jax.Array  # action fed to the network as input at this step
    prev_reward: jax.Array  # reward fed to the network as input at this step

In [None]:
# calculate the Generalized Advantage Estimation
# the advantage function helps indicate how much better taking a particular action is compared to the policy's average
def calculate_gae(
    transitions: Transition,
    last_val: jax.Array,
    gamma: float,
    gae_lambda: float,
) -> tuple[jax.Array, jax.Array]:
    """Compute Generalized Advantage Estimation over a rollout.

    Args:
        transitions: rollout data with a leading time axis; only `reward`,
            `value` and `done` are read.
        last_val: value estimate for the state following the final
            transition, used to bootstrap the last step.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor (closer to 1 means lower bias and
            higher variance).

    Returns:
        `(advantages, targets)` where `targets = advantages + transitions.value`
        are the regression targets for the critic.
    """

    def _backward_step(carry, step):
        # carry = (advantage of the following step, value of the following state)
        running_gae, value_after = carry
        # Zero when the episode ended at this step, which cuts the bootstrap.
        not_done = 1 - step.done
        td_error = step.reward + gamma * value_after * not_done - step.value
        running_gae = td_error + gamma * gae_lambda * not_done * running_gae
        return (running_gae, step.value), running_gae

    # Walk the rollout from the last step to the first.
    initial_carry = (jnp.zeros_like(last_val), last_val)
    _, advantages = jax.lax.scan(_backward_step, initial_carry, transitions, reverse=True)

    targets = advantages + transitions.value
    return advantages, targets



In [None]:
# to perform a single update of the network (both actor and critic) in a PPO setup using a batch of transitions collected from the environment
def ppo_update_networks(
    train_state: TrainState,
    transitions: Transition,
    init_hstate: jax.Array,
    advantages: jax.Array,
    targets: jax.Array,
    clip_eps: float,
    vf_coef: float,
    ent_coef: float,
):
    """Perform one PPO gradient step on a minibatch.

    Must be called inside a mapped context that defines the axis name
    "devices", because losses and gradients are averaged with `jax.lax.pmean`.

    Args:
        train_state: current parameters, apply function and optimiser state.
        transitions: minibatch of rollout data laid out `[batch_size, seq_len, ...]`.
        init_hstate: recurrent state at the start of each sequence.
        advantages: advantage estimates; standardised here before use.
        targets: regression targets for the critic.
        clip_eps: clipping range for both the probability ratio and the value update.
        vf_coef: weight of the value loss in the total loss.
        ent_coef: weight of the entropy bonus in the total loss.

    Returns:
        `(train_state, update_info)` - the updated state and a dict with
        "total_loss", "value_loss", "actor_loss" and "entropy".
    """
    # Standardise advantages over the minibatch; the epsilon guards a zero std.
    adv_mean = advantages.mean()
    adv_std = advantages.std()
    advantages = (advantages - adv_mean) / (adv_std + 1e-8)

    # [batch_size, seq_len, ...]
    network_inputs = {
        "observation": transitions.obs,
        "prev_action": transitions.prev_action,
        "prev_reward": transitions.prev_reward,
    }

    def _loss_fn(params):
        # Re-evaluate the stored sequences under the parameters being optimised.
        dist, value, _ = train_state.apply_fn(params, network_inputs, init_hstate)
        new_log_prob = dist.log_prob(transitions.action)

        # Critic loss: the new prediction may move at most `clip_eps` away from
        # the rollout-time estimate; take the worse of clipped/unclipped error.
        value_shift = (value - transitions.value).clip(-clip_eps, clip_eps)
        clipped_value = transitions.value + value_shift
        squared_error = jnp.square(value - targets)
        squared_error_clipped = jnp.square(clipped_value - targets)
        value_loss = 0.5 * jnp.maximum(squared_error, squared_error_clipped).mean()

        # Actor loss: clipped surrogate objective on the probability ratio
        # between the current policy and the one that collected the data.
        ratio = jnp.exp(new_log_prob - transitions.log_prob)
        clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
        surrogate = jnp.minimum(advantages * ratio, advantages * clipped_ratio)
        actor_loss = -surrogate.mean()

        # Mean policy entropy, rewarded in the total loss to keep exploring.
        entropy = dist.entropy().mean()

        total_loss = actor_loss + vf_coef * value_loss - ent_coef * entropy
        return total_loss, (value_loss, actor_loss, entropy)

    loss_and_grad = jax.value_and_grad(_loss_fn, has_aux=True)
    (loss, (vloss, aloss, entropy)), grads = loss_and_grad(train_state.params)

    # Average metrics and gradients across the "devices" axis so that every
    # device applies the same update.
    loss, vloss, aloss, entropy, grads = jax.lax.pmean(
        (loss, vloss, aloss, entropy, grads), axis_name="devices"
    )
    train_state = train_state.apply_gradients(grads=grads)

    update_info = {
        "total_loss": loss,
        "value_loss": vloss,
        "actor_loss": aloss,
        "entropy": entropy,
    }
    return train_state, update_info

In [None]:
# for evaluation (evaluate for N consecutive episodes, sum rewards)
# N=1 single task, N>1 for meta-RL

# simulates consecutive episodes of interaction between an agent and an environment to evaluate the performance of a trained model

class RolloutStats(struct.PyTreeNode):
    """Running totals accumulated by `rollout`; all fields start at zero."""

    reward: jax.Array = jnp.asarray(0.0)  # sum of rewards over every step of the rollout
    length: jax.Array = jnp.asarray(0)    # total number of environment steps taken
    episodes: jax.Array = jnp.asarray(0)  # number of finished episodes (incremented by `timestep.last()`)

def rollout(
    rng: jax.Array,
    env: Environment,
    env_params: EnvParams,
    train_state: TrainState,
    init_hstate: jax.Array,
    num_consecutive_episodes: int = 1,
) -> RolloutStats:
    """Run the policy until `num_consecutive_episodes` episodes have finished.

    Args:
        rng: PRNG key; one sub-key resets the environment, the rest drives
            action sampling.
        env: environment to step (stepped past episode ends, so presumably
            auto-resetting - confirm against the caller).
        env_params: parameters passed to `env.reset` / `env.step`.
        train_state: provides `apply_fn` and `params` of the policy.
        init_hstate: initial recurrent state, passed to the network unchanged;
            inputs get leading `[1, 1]` batch/sequence axes, so presumably
            shaped `[1, num_layers, hidden_dim]` - confirm against the caller.
        num_consecutive_episodes: number of episodes to accumulate
            (1 for single-task evaluation, more for meta-RL).

    Returns:
        `RolloutStats` with total reward, total steps and finished episodes.
    """

    def _cond_fn(carry):
        # Keep stepping until enough episodes have finished.
        stats = carry[1]
        return jnp.less(stats.episodes, num_consecutive_episodes)

    def _body_fn(carry):
        rng, stats, timestep, prev_action, prev_reward, hstate = carry

        rng, sample_rng = jax.random.split(rng)

        # Query the policy with a batch of one sequence of length one.
        dist, _, hstate = train_state.apply_fn(
            train_state.params,
            {
                "observation": timestep.observation[None, None, ...],
                "prev_action": prev_action[None, None, ...],
                "prev_reward": prev_reward[None, None, ...],
            },
            hstate,
        )

        # Drop the singleton batch/sequence axes to get a scalar action.
        action = dist.sample(seed=sample_rng).squeeze()
        timestep = env.step(env_params, timestep, action)

        stats = stats.replace(
            reward=stats.reward + timestep.reward,
            length=stats.length + 1,
            episodes=stats.episodes + timestep.last(),
        )
        return (rng, stats, timestep, action, timestep.reward, hstate)

    # Use a dedicated key for the reset: a key that has been consumed must not
    # also be split later, otherwise the random streams are not independent.
    rng, reset_rng = jax.random.split(rng)
    timestep = env.reset(env_params, reset_rng)

    prev_action = jnp.asarray(0)
    # Float zero so the loop carry does not rely on weak-type promotion to
    # match the floating-point rewards written by the body.
    prev_reward = jnp.asarray(0.0)
    init_carry = (rng, RolloutStats(), timestep, prev_action, prev_reward, init_hstate)

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)

    # Element 1 of the carry holds the accumulated statistics.
    return final_carry[1]


### Training

In [None]:
@dataclass
class TrainConfig:
    """Hyper-parameters for a single-task PPO run.

    `__post_init__` derives per-device quantities from the number of local
    JAX devices and stores them as extra attributes:
    `num_envs_per_device`, `total_timesteps_per_device`,
    `eval_episodes_per_device` and `num_updates`.
    """

    # environment
    env_id: str = "MiniGrid-Empty-6x6"
    benchmark_id: Optional[str] = None
    ruleset_id: Optional[int] = None
    img_obs: bool = False
    # agent
    action_emb_dim: int = 16
    rnn_hidden_dim: int = 1024
    rnn_num_layers: int = 1
    head_hidden_dim: int = 256
    # training
    num_envs: int = 8192
    num_steps: int = 16
    update_epochs: int = 1
    num_minibatches: int = 16
    total_timesteps: int = 5_000_000
    lr: float = 0.001
    clip_eps: float = 0.2
    gamma: float = 0.99
    gae_lambda: float = 0.95
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    eval_episodes: int = 40
    seed: int = 42

    def __post_init__(self):
        num_devices = jax.local_device_count()

        # Share the workload evenly across the local devices.
        envs_per_device, leftover_envs = divmod(self.num_envs, num_devices)
        self.num_envs_per_device = envs_per_device
        self.total_timesteps_per_device = self.total_timesteps // num_devices
        self.eval_episodes_per_device = self.eval_episodes // num_devices
        # Environments must divide evenly between devices.
        assert leftover_envs == 0

        # One update consumes `num_steps` steps from every per-device environment.
        steps_budget = self.total_timesteps_per_device
        self.num_updates = steps_budget // self.num_steps // self.num_envs_per_device
        print(f"Num devices: {num_devices}, Num updates: {self.num_updates}")


def make_states(config: TrainConfig):
    """Create the environment, the network and the initial training state.

    Args:
        config: training configuration; its derived per-device attributes
            (`num_envs_per_device`, `num_updates`) are used here.

    Returns:
        `(rng, env, env_params, init_hstate, train_state)` - the remaining
        PRNG key, the wrapped environment with its parameters, a zero
        recurrent state for `num_envs_per_device` environments, and the
        `TrainState` holding parameters and optimiser.
    """

    def linear_schedule(count):
        # `count` is the optimiser step counter. Dividing by the number of
        # gradient steps per update gives the index of the current update,
        # so the rate decays linearly from `lr` across `num_updates` updates.
        frac = 1.0 - (count // (config.num_minibatches * config.update_epochs)) / config.num_updates
        return config.lr * frac

    # setup environment
    env, env_params = xminigrid.make(config.env_id)
    # automatically reset the environment when an episode ends
    env = GymAutoResetWrapper(env)

    # for single-task XLand environments: fix the ruleset from a benchmark
    if config.benchmark_id is not None:
        assert "XLand-MiniGrid" in config.env_id, "Benchmarks should be used only with XLand environments."
        assert config.ruleset_id is not None, "Ruleset ID should be specified for benchmarks usage."
        benchmark = xminigrid.load_benchmark(config.benchmark_id)
        env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))

    # enabling image observations if needed
    if config.img_obs:
        from xminigrid.experimental.img_obs import RGBImgObservationWrapper

        env = RGBImgObservationWrapper(env)

    # setup training state
    rng = jax.random.PRNGKey(config.seed)
    rng, _rng = jax.random.split(rng)

    network = ActorCriticRNN(
        num_actions=env.num_actions(env_params),
        action_emb_dim=config.action_emb_dim,
        rnn_hidden_dim=config.rnn_hidden_dim,
        rnn_num_layers=config.rnn_num_layers,
        head_hidden_dim=config.head_hidden_dim,
        img_obs=config.img_obs,
    )
    # Dummy inputs used only to trace shapes during initialisation.
    # [batch_size, seq_len, ...]
    init_obs = {
        "observation": jnp.zeros((config.num_envs_per_device, 1, *env.observation_shape(env_params))),
        "prev_action": jnp.zeros((config.num_envs_per_device, 1), dtype=jnp.int32),
        "prev_reward": jnp.zeros((config.num_envs_per_device, 1)),
    }
    init_hstate = network.initialize_carry(batch_size=config.num_envs_per_device)

    network_params = network.init(_rng, init_obs, init_hstate)

    # optimizer setup: clip the global gradient norm, then Adam with the
    # scheduled learning rate. `inject_hyperparams` keeps the current learning
    # rate in the optimiser state.
    tx = optax.chain(
        optax.clip_by_global_norm(config.max_grad_norm),
        optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-8),
    )
    train_state = TrainState.create(apply_fn=network.apply, params=network_params, tx=tx)

    return rng, env, env_params, init_hstate, train_state

In [None]:
def make_train(
    env: Environment,
    env_params: EnvParams,
    config: TrainConfig,
):
    @partial(jax.pmap, axis_name="devices")
    def train(
        rng: jax.Array,
        train_state: TrainState,
        init_hstate: jax.Array,
    ):
        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config.num_envs_per_device)

        timestep = jax.vmap(env.reset, in_axes=(None, 0))(env_params, reset_rng)
        prev_action = jnp.zeros(config.num_envs_per_device, dtype=jnp.int32)
        prev_reward = jnp.zeros(config.num_envs_per_device)

        # TRAIN LOOP
        def _update_step(runner_state, _):
            # COLLECT TRAJECTORIES
            # perform one step of interaction between the agent and the environment, updating the state based on the agent's action and the environment's response
            def _env_step(runner_state, _):
                rng, train_state, prev_timestep, prev_action, prev_reward, prev_hstate = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                dist, value, hstate = train_state.apply_fn(
                    train_state.params,
                    {
                        # [batch_size, seq_len=1, ...]
                        "observation": prev_timestep.observation[:, None],
                        "prev_action": prev_action[:, None],
                        "prev_reward": prev_reward[:, None],
                    },
                    prev_hstate,
                )
                action, log_prob = dist.sample_and_log_prob(seed=_rng)
                # squeeze seq_len where possible
                action, value, log_prob = action.squeeze(1), value.squeeze(1), log_prob.squeeze(1)

                # STEP ENV
                timestep = jax.vmap(env.step, in_axes=(None, 0, 0))(env_params, prev_timestep, action)

                transition = Transition(
                    done=timestep.last(),
                    action=action,
                    value=value,
                    reward=timestep.reward,
                    log_prob=log_prob,
                    obs=prev_timestep.observation,
                    prev_action=prev_action,
                    prev_reward=prev_reward,
                )
                runner_state = (rng, train_state, timestep, action, timestep.reward, hstate)
                return runner_state, transition

            initial_hstate = runner_state[-1]
            # transitions: [seq_len, batch_size, ...]
            # applies `_env_step` across multiple steps
            runner_state, transitions = jax.lax.scan(_env_step, runner_state, None, config.num_steps)

            # CALCULATE ADVANTAGE
            rng, train_state, timestep, prev_action, prev_reward, hstate = runner_state
            # calculate value of the last step for boostrapping
            _, last_val, _ = train_state.apply_fn(
                train_state.params,
                {
                    "observation": timestep.observation[:, None],
                    "prev_action": prev_action[:, None],
                    "prev_reward": prev_reward[:, None],
                },
                hstate,
            )
            advantages, targets = calculate_gae(transitions, last_val.squeeze(1), config.gamma, config.gae_lambda)

            # UPDATE NETWORK
            # handles the training for one epoch by processing several minibatches of data
            def _update_epoch(update_state, _):
                def _update_minbatch(train_state, batch_info):
                    init_hstate, transitions, advantages, targets = batch_info
                    new_train_state, update_info = ppo_update_networks(
                        train_state=train_state,
                        transitions=transitions,
                        init_hstate=init_hstate.squeeze(1),
                        advantages=advantages,
                        targets=targets,
                        clip_eps=config.clip_eps,
                        vf_coef=config.vf_coef,
                        ent_coef=config.ent_coef,
                    )
                    return new_train_state, update_info

                rng, train_state, init_hstate, transitions, advantages, targets = update_state

                # MINIBATCHES PREPARATION
                rng, _rng = jax.random.split(rng)
                permutation = jax.random.permutation(_rng, config.num_envs_per_device)
                # [seq_len, batch_size, ...]
                batch = (init_hstate, transitions, advantages, targets)
                # [batch_size, seq_len, ...] as our model assumes
                batch = jtu.tree_map(lambda x: x.swapaxes(0, 1), batch)

                shuffled_batch = jtu.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
                # [num_minibatches, minibatch_size, ...]
                minibatches = jtu.tree_map(
                    lambda x: jnp.reshape(x, (config.num_minibatches, -1) + x.shape[1:]), shuffled_batch
                )
                train_state, update_info = jax.lax.scan(_update_minbatch, train_state, minibatches)

                update_state = (rng, train_state, init_hstate, transitions, advantages, targets)
                return update_state, update_info

            # []