In [47]:
from typing import Tuple
import jax
import jax.numpy as jnp
import numpy as np
import torch

In [48]:
def compute_gae(
    rewards,
    values,
    dones,
    last_value,
    gamma: float,
    lambd: float,
):
    """Generalized advantage estimation via a reverse `jax.lax.scan`.

    The scan walks the leading axis of `values`, `rewards` and `dones` from the
    last step back to the first, carrying the running advantage together with
    the value of the step that follows. `last_value` seeds that "following
    value" for the final step and also fixes the shape of the carried advantage.

    gamma: the discount factor.
    lambd: the GAE lambda parameter.

    Returns the tuple `(advantages, advantages + values)`.
    """
    # Running advantage starts at zero; the "next value" starts at `last_value`.
    initial_carry = (jnp.zeros_like(last_value), last_value)

    def backward_step(carry, step_inputs):
        running_gae, value_after = carry
        value_now, reward_now, done_now = step_inputs
        # Zero wherever done == 1: drops both the bootstrap term and the
        # advantage accumulated from later steps.
        not_done = 1 - done_now
        td_error = reward_now + gamma * value_after * not_done - value_now
        running_gae = td_error + gamma * lambd * not_done * running_gae
        # The current value becomes the "next value" of the preceding step.
        return (running_gae, value_now), running_gae

    scan_result = jax.lax.scan(
        backward_step,
        initial_carry,
        (values, rewards, dones),
        reverse=True,
        unroll=16,
    )
    advantages = scan_result[1]
    returns = advantages + values
    return advantages, returns

In [49]:
def get_returns_advantages(
    rewards: torch.Tensor,
    values: torch.Tensor,
    dones: torch.Tensor,
    last_value: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.95,
    normalize_returns: bool = False,
    normalize_advantages: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute returns and generalized advantage estimates (GAE).

    Args:
        rewards: reward at each step, with time as the leading axis
            (e.g. shape (T, N); any trailing shape is accepted).
        values: value estimate of the state at each step, same shape as
            `rewards`.
        dones: same shape as `rewards`, with 1 (or True) if the episode ended
            at that step and 0 (or False) otherwise.
        last_value: value estimate for the state after the final step, shaped
            like one time slice of `rewards` (e.g. (N,)).
        gamma: the discount factor.
        lam: the GAE lambda parameter.
        normalize_returns: if True, standardize the returns using the mean and
            std taken over all elements.
        normalize_advantages: if True, standardize the advantages the same way.

    Returns:
        `(returns, advantages)`, in that order, where `returns` is
        `advantages + values` computed before any normalization. Both are
        allocated on `rewards.device`; the other inputs are assumed to live on
        the same device.
    """
    with torch.no_grad():
        values = values.detach()

        # Work in the floating dtype implied by the inputs instead of silently
        # falling back to the default dtype; integer inputs still get floats.
        out_dtype = torch.promote_types(rewards.dtype, values.dtype)
        if not out_dtype.is_floating_point:
            out_dtype = torch.get_default_dtype()
        device = rewards.device

        # 1 where the episode continues, 0 where it ended. Casting first lets
        # boolean `dones` through (`1 - bool_tensor` is rejected by torch).
        not_done = 1.0 - dones.to(dtype=out_dtype)

        num_steps = rewards.shape[0]
        advantages = torch.zeros(rewards.shape, dtype=out_dtype, device=device)
        gae_step = torch.zeros(rewards.shape[1:], dtype=out_dtype, device=device)

        # Walk backwards so each step can reuse the advantage accumulated from
        # the steps that follow it.
        for t in reversed(range(num_steps)):
            next_value = last_value if t == num_steps - 1 else values[t + 1]
            delta = rewards[t] + gamma * next_value * not_done[t] - values[t]
            gae_step = delta + gamma * lam * not_done[t] * gae_step
            advantages[t] = gae_step

        returns = advantages + values
        if normalize_returns:
            # Statistics are taken over every element (all steps and envs).
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)

        if normalize_advantages:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
        return returns, advantages


In [50]:
# Synthetic rollout as numpy arrays: 256 steps, one environment.

# Rewards: 2.0 at step 50 and 1.0 at the final step, zero elsewhere.
rewards = np.zeros(shape=(256, 1))
rewards[50] = 2.0
rewards[-1] = 1.0

# Episode-end flags at the same two steps.
dones = np.zeros(shape=(256, 1))
dones[50] = 1
dones[-1] = 1

# Value estimates are zero except at the final step.
values = np.zeros(shape=(256, 1))
values[-1] = 0.5

# Value estimate passed as `last_value` to the GAE functions.
last_value = np.ones(1)

In [51]:
# Discount factor and GAE lambda passed to both GAE implementations.
gamma, lambd = 0.99, 0.95

# Rollout dimensions: axis 0 of `rewards` is read as steps, axis 1 as envs.
num_steps, num_envs = rewards.shape[0], rewards.shape[1]

In [52]:
# Run the JAX implementation; element 0 of its result holds the advantages.
jax_gae = compute_gae(
    rewards=rewards,
    values=values,
    dones=dones,
    last_value=last_value,
    gamma=gamma,
    lambd=lambd,
)
jax_advantages = np.array(jax_gae[0])  # copy into a numpy array
print(jax_advantages[-1])  # advantage at the final step
print(jax_advantages.mean())

[0.5]
0.19087662


In [53]:
# Torch version of the rollout built earlier with numpy (same shapes and values).
# NOTE: this rebinds `values`, `rewards`, `dones` and `last_value` from numpy
# arrays to torch tensors, so re-running the JAX cell after this one would hand
# tensors to `compute_gae`.
values = torch.zeros((256, 1))
rewards = torch.zeros((256, 1))
rewards[50,:] = 2.0
dones = torch.zeros((256, 1))
dones[50, :] = 1
rewards[-1] = 1.0
dones[-1] = 1
values[-1] = 0.5
last_value = torch.ones((1,))

# `get_returns_advantages` returns (returns, advantages); unpack by name
# instead of picking the advantages out with a positional [-1].
torch_returns, torch_advantages = get_returns_advantages(
    rewards, values, dones, last_value, gamma, lambd
)
print(torch_advantages[-1])
print(torch.mean(torch_advantages))

# Check agreement with the JAX result explicitly rather than by comparing the
# printed numbers by eye.
np.testing.assert_allclose(
    torch_advantages.numpy(), jax_advantages, rtol=1e-4, atol=1e-6
)

tensor([0.5000])
tensor(0.1909)
