# Advantage Actor-Critic


For each minibatch:
- We do `n_steps_per_update` in `n_envs` environments in parallel.
    - (This gives `n_steps_per_update * n_envs` steps in total per minibatch.)

To calculate the advantages, we are using _Generalized Advantage Estimation_ (GAE).

- Balances tradeoff between variance and bias of advantage estimates

In [1]:
from __future__ import annotations
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import einx
from torch import nn, optim, Tensor
from tqdm.auto import tqdm, trange
import gymnasium as gym

In [None]:
class NoopTerminateWrapper(gym.Wrapper):
    """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.

    When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
    and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
    with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
     - ``new_obs`` is the first observation after calling :meth:`self.env.reset`
     - ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
     - ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`.
     - ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. Both `final_terminated` and `final_truncated` cannot be False.
     - ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
       with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
       and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.

    Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
        new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
        final reward, terminated and truncated state from the previous episode.
        If you need the final state from the previous episode, you need to retrieve it via the
        "final_observation" key in the info dict.
        Make sure you know what you're doing if you use this wrapper!
    """

    def __init__(self, env: gym.Env):
        """A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.

        Args:
            env (gym.Env): The environment to apply the wrapper
        """
        super().__init__(env)

    def step(self, action, mask=None):
        """Steps through only the environments specified in the mask returning the results.
        
        Args:
            actions: Actions to take in each environment
            mask: Boolean array indicating which environments to step (True) and which to skip (False)
        
        Returns:
            The batched environment step results (observations, rewards, terminations, truncations, infos)
        """
        if mask is None:
            mask = np.full_like(actions, True, dtype=np.bool)
        
        actions = gym.vector.utils.iterate(self.env.action_space, actions)
        
        # Make a copy of the current state to avoid modifying environments we're not stepping
        masked_rewards = np.copy(self.env._rewards)
        masked_terminations = np.copy(self.env._terminations)
        masked_truncations = np.copy(self.env._truncations)
        
        infos = {}
        for i, action in enumerate(actions):
            # Skip environments not in mask
            if not mask[i]:
                continue
                
            if self.env.autoreset_mode == AutoresetMode.NEXT_STEP:
                if self.env._autoreset_envs[i]:
                    self.env._env_obs[i], env_info = self.env.envs[i].reset()
    
                    masked_rewards[i] = 0.0
                    masked_terminations[i] = False
                    masked_truncations[i] = False
                else:
                    (
                        self.env._env_obs[i],
                        masked_rewards[i],
                        masked_terminations[i],
                        masked_truncations[i],
                        env_info,
                    ) = self.env.envs[i].step(action)
            elif self.env.autoreset_mode == AutoresetMode.DISABLED:
                # assumes that the user has correctly autoreset
                assert not self.env._autoreset_envs[i], f"{self.env._autoreset_envs=}"
                (
                    self.env._env_obs[i],
                    masked_rewards[i],
                    masked_terminations[i],
                    masked_truncations[i],
                    env_info,
                ) = self.env.envs[i].step(action)
            elif self.env.autoreset_mode == AutoresetMode.SAME_STEP:
                (
                    self.env._env_obs[i],
                    masked_rewards[i],
                    masked_terminations[i],
                    masked_truncations[i],
                    env_info,
                ) = self.env.envs[i].step(action)
    
                if masked_terminations[i] or masked_truncations[i]:
                    infos = self.env._add_info(
                        infos,
                        {"final_obs": self.env._env_obs[i], "final_info": env_info},
                        i,
                    )
    
                    self.env._env_obs[i], env_info = self.env.envs[i].reset()
            else:
                raise ValueError(f"Unexpected autoreset mode, {self.env.autoreset_mode}")
    
            infos = self.env._add_info(infos, env_info, i)
    
        # Update the internal state variables
        self.env._rewards = masked_rewards
        self.env._terminations = masked_terminations
        self.env._truncations = masked_truncations
        
        # Only update _autoreset_envs for the masked environments
        masked_autoreset = np.logical_or(masked_terminations, masked_truncations)
        self.env._autoreset_envs[mask] = masked_autoreset[mask]
    
        # Concatenate the observations
        self.env._observations = concatenate(
            self.env.single_observation_space, self.env._env_obs, self.env._observations
        )
    
        return (
            deepcopy(self.env._observations) if self.env.copy else self.env._observations,
            np.copy(self.env._rewards),
            np.copy(self.env._terminations),
            np.copy(self.env._truncations),
            infos,
        )
    

In [16]:
class GRPO(nn.Module):
    def __init__(
        self,
        n_features,
        n_actions,
        device,
        actor_lr,
        n_envs
    ):
        super().__init__()
        self.device = device
        self.n_envs = n_envs
        
        # Advantage of an action is the difference between
        # the return and state-value
        # A(s, a) = Q(s, a) - V(s)
        #
        # **Q:** how do we know Q(s, a)? is this the discount factor thing
        actor_layers = [
            nn.Linear(n_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            # estimate action logits (will be fed into a softmax later)
            nn.Linear(32, n_actions),
        ]

        self.actor = nn.Sequential(*actor_layers).to(self.device)
        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)
        
        

    def forward(self, x: np.ndarray):
        """
        Args:
            x: Current states (env, feat)

        Returns:
            action_logits: Actor estimated actions (env, action)
        """
        x = torch.from_numpy(x).to(self.device)
        action_logits = self.actor(x)
        return action_logits

    def select_action(self, x):
        """
        Returns:
            actions: Sampled actions to update state (n_envs,)
            action_log_probs: Log softmax of actions (n_envs, n_actions)
            entropy: (n_envs,)
        """
        action_logits = self.forward(x)
        # uses softmax
        action_pd = torch.distributions.Categorical(logits=action_logits)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions, action_log_probs, entropy

    def get_losses(
        self,
        rewards,
        action_log_probs,
        entropy,
        masks,
        gamma: float,
        ent_coef: float,
        n_generations: int,
        n_steps_per_update: int,
        device,
    ):
        """
        Computes loss of a minibatch (transitions collected in one sampling phase)
        for on-policy GRPO.

        Uses standard action_log_probs rather than ratio like in PPO
    
        Args:
            rewards: Rewards for each time step in episode (step, (env, gen))
            action_log_probs: logprob of action taken at each time step in episode (step, (env, gen))
            masks: Masks for each time step in episode
            gamma: Discount factor
            ent_coef: Entropy coefficeint
            n_generations (int): Number of generations per env
            n_steps_per_update (int)
            device

        Returns:
            actor_loss
        """
        rewards = einx.rearrange(
            "step (env gen) -> env gen step",
            rewards,
            gen=n_generations,
        )
        reward_per_generation = einx.sum("env gen [step]", rewards)
        mean_reward = einx.mean("env [gen]", reward_per_generation)
        
        advantages = einx.subtract(
            "env gen, env -> env gen",
            reward_per_generation,
            mean_reward,
        )
        
        # std_rewards = einx.std("n_envs [n_generations]", reward_per_generation) + 1e-8
        # advantages = einx.divide("[n_envs n_generations], [n_envs]", advantages, std_rewards)
        
        discounts = torch.tensor([gamma ** i for i in range(n_steps_per_update)][::-1], device=device)
        per_step_advantages = einx.multiply(
            "env gen, step -> env gen step",
            advantages,
            discounts,
        )

        
        # ratio = torch.exp(action_log_probs - action_log_probs)
        pg_loss1 = -einx.multiply("env gen step, step (env gen) -> env gen step", per_step_advantages, action_log_probs).mean()
        return pg_loss1



        # action_log_probs = einx.rearrange(
        #     "n_steps_per_update, (n_envs, n_generations) n_actions -> n_envs n_generations n_steps_per_update n_actions",
        #     action_log_probs,
        #     n_generations=self.n_generations,
        # )

        # return actor loss at some point


    def update_parameters(self, actor_loss):
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

In [17]:
# envs = gym.vector.VectorEnv(
#     [
#         lambda: gym.make(
#             "LunarLander-v3",
#             gravity=-10.0,
#             enable_wind=True,
#             wind_power=15.0,
#             turbulence_power=1.5,
#             max_episode_steps=600,
#         ),
#         lambda: gym.make(
#             "LunarLander-v3",
#             gravity=-9.8,
#             enable_wind=True,
#             wind_power=10.0,
#             turbulence_power=1.3,
#             max_episode_steps=600,
#         ),
#         lambda: gym.make(
#             "LunarLander-v3", gravity=-7.0, enable_wind=False, max_episode_steps=600
#         ),
#     ]
# )

In [18]:
n_envs = 10
n_updates = 200
n_steps_per_update = 128
n_generations = 4

total_envs = n_envs * n_generations

# hyperparameters
gamma = 0.999
ent_coef = 0.01 # coef for entropy bonus (encourage exploration)
actor_lr = 1e-3

In [19]:
envs = gym.make_vec(
    "LunarLander-v3",
    num_envs=total_envs,
    max_episode_steps=600, 
    vectorization_mode="sync",
    vector_kwargs=dict(autoreset_mode=gym.vector.AutoresetMode.DISABLED),
)
envs_wrapper = gym.wrappers.vector.RecordEpisodeStatistics(
    envs,
    buffer_length=n_envs * n_updates
)

obs_shape = envs.single_observation_space.shape[0]
action_shape = envs.single_action_space.n

device = torch.device("mps")

agent = GRPO(obs_shape, action_shape, device, actor_lr, n_envs)

In [20]:
actor_losses = []
entropies = []

states, info = envs_wrapper.reset(seed=42)

# don't have to reset envs, they keep playing until episode over and reset automatically
# how to set n_generations envs to be the same?

pbar = trange(n_updates)
for sample_phase in pbar:
    # reset lists that collect experience of an episode
    ep_action_log_probs = torch.zeros(n_steps_per_update, total_envs, device=device)
    ep_rewards = torch.zeros(n_steps_per_update, total_envs, device=device)
    masks = torch.zeros(n_steps_per_update, total_envs, device=device)
    reset_mask = torch.full((total_envs,), False, device=device, dtype=torch.bool)

    for step in range(n_steps_per_update):
        # sample actions a[t] using s[t] as input
        (
            actions,
            action_log_probs,
            entropy,
        ) = agent.select_action(states)

        # perform actions[t] in env to get s[t+1] and r[t+1]
        states, rewards, terminated, truncated, info = envs_wrapper.step(actions.cpu().numpy())

        ep_action_log_probs[step] = action_log_probs
        ep_rewards[step] = torch.tensor(rewards, device=device, dtype=torch.float32)

        # mask: if env is ongoing 1 else 0
        masks[step] = torch.tensor(~np.logical_or(terminated, truncated))

        # reset envs
        # get all envs that have completed in all their generations
        reset_mask = einx.all("(env [gen])", ~masks[step].to(torch.bool), gen=n_generations)
        reset_mask = einx.rearrange("env -> (env gen)", reset_mask, gen=n_generations)
        if reset_mask.any():
            ep_action_log_probs[:, reset_mask] = 0.
            masks[:, reset_mask] = False
            envs.reset(seed=step, options=dict(reset_mask=reset_mask.cpu().numpy()))
        break

    actor_loss = agent.get_losses(
        ep_rewards,
        ep_action_log_probs,
        entropy,
        masks,
        gamma,
        ent_coef,
        n_generations,
        n_steps_per_update,
        device,
    )

    agent.update_parameters(actor_loss)

    stats = {
        "actor_loss": actor_loss.item(),
        "entropy": entropy.detach().mean().item()
    }
    pbar.set_postfix(stats)
    actor_losses.append(stats["actor_loss"])
    entropies.append(stats["entropy"])

  0%|          | 0/200 [00:00<?, ?it/s]

NameError: name 'AutoresetMode' is not defined

In [None]:
""" plot the results """

# %matplotlib inline

rolling_length = 20
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 5))
fig.suptitle(
    f"Training plots for {agent.__class__.__name__} in the LunarLander-v3 environment \n \
             (n_envs={n_envs}, n_steps_per_update={n_steps_per_update}, randomize_domain={False})"
)

# episode return
axs[0][0].set_title("Episode Returns")
episode_returns_moving_average = (
    np.convolve(
        np.array(envs_wrapper.return_queue).flatten(),
        np.ones(rolling_length),
        mode="valid",
    )
    / rolling_length
)
axs[0][0].plot(
    np.arange(len(episode_returns_moving_average)) / n_envs,
    episode_returns_moving_average,
)
axs[0][0].set_xlabel("Number of episodes")

# entropy
axs[1][0].set_title("Entropy")
entropy_moving_average = (
    np.convolve(np.array(entropies), np.ones(rolling_length), mode="valid")
    / rolling_length
)
axs[1][0].plot(entropy_moving_average)
axs[1][0].set_xlabel("Number of updates")


# critic loss
axs[0][1].set_title("Critic Loss")
critic_losses_moving_average = (
    np.convolve(
        np.array(critic_losses).flatten(), np.ones(rolling_length), mode="valid"
    )
    / rolling_length
)
axs[0][1].plot(critic_losses_moving_average)
axs[0][1].set_xlabel("Number of updates")


# actor loss
axs[1][1].set_title("Actor Loss")
actor_losses_moving_average = (
    np.convolve(np.array(actor_losses).flatten(), np.ones(rolling_length), mode="valid")
    / rolling_length
)
axs[1][1].plot(actor_losses_moving_average)
axs[1][1].set_xlabel("Number of updates")

plt.tight_layout()
plt.show()

In [None]:
n_showcase_episodes = 3

for episode in range(n_showcase_episodes):
    env = gym.make("LunarLander-v3", render_mode="human", max_episode_steps=500)

    # get an initial state
    state, info = env.reset()

    # play one episode
    done = False
    while not done:
        # select an action A_{t} using S_{t} as input for the agent
        with torch.no_grad():
            action, _, _, _ = agent.select_action(state[None, :])

        # perform the action A_{t} in the environment to get S_{t+1} and R_{t+1}
        state, reward, terminated, truncated, info = env.step(action.item())

        # update if the environment is done
        done = terminated or truncated

env.close()

In [None]:
# from gymnasium.utils.play import play

# play(gym.make('LunarLander-v3', render_mode='rgb_array'),
#     keys_to_action={'w': 2, 'a': 1, 'd': 3}, noop=0)