In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

import gymnasium as gym

from ppo import A2C

pygame 2.5.2 (SDL 2.28.2, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# ---------------------------------------------------------------------------
# Rollout / optimisation schedule
# ---------------------------------------------------------------------------
n_envs = 10  # number of environments stepped in parallel
n_updates = 1000  # number of sample phases (outer training iterations)
n_steps_per_update = 128  # environment steps collected per env in each sample phase
n_epochs = 4  # passes over the collected rollout per sample phase

# Size of one full rollout (all envs, all steps of a sample phase).
nbatch = n_envs * n_steps_per_update
# NOTE: despite its name, `batch_size` acts as the *number of minibatches*:
# `nbatch_train` below is the number of transitions in each minibatch.
batch_size = 32
nbatch_train = nbatch // batch_size

# ---------------------------------------------------------------------------
# Agent hyperparameters
# ---------------------------------------------------------------------------
epsilon = 0.2  # presumably the PPO clipping range -- confirm against ppo.A2C.get_losses
gamma = 0.999
lam = 0.95  # hyperparameter for GAE
ent_coef = 0.01  # coefficient for the entropy bonus (to encourage exploration)
actor_lr = 0.001
critic_lr = 0.005

# ---------------------------------------------------------------------------
# Environment, device and agent
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

envs = gym.vector.make("LunarLander-v2", num_envs=n_envs, max_episode_steps=600)
obs_shape = envs.single_observation_space.shape[0]
action_shape = envs.single_action_space.n

agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)

  gym.logger.warn(
  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Wrap the vector env so that finished-episode returns are recorded in
# `envs_wrapper.return_queue` (read later by the plotting cell).
envs_wrapper = gym.wrappers.RecordEpisodeStatistics(envs, deque_size=n_envs * n_updates)

critic_losses = []
actor_losses = []
entropies = []

# Number of transitions collected in one sample phase (all steps, all envs).
rollout_size = n_steps_per_update * n_envs

for sample_phase in tqdm(range(n_updates)):
    # we don't have to reset the envs, they just continue playing
    # until the episode is over and then reset automatically

    # Fresh rollout buffers for this sample phase. They are flat and
    # time-major: the n_envs transitions gathered at rollout step `step`
    # occupy rows [step * n_envs, (step + 1) * n_envs).
    ep_states = torch.zeros(rollout_size, obs_shape, device=device)
    ep_actions = torch.zeros(rollout_size, device=device)
    ep_value_preds = torch.zeros(rollout_size, device=device)
    ep_rewards = torch.zeros(rollout_size, device=device)
    ep_entropies = torch.zeros(rollout_size, device=device)
    ep_action_log_probs = torch.zeros(rollout_size, device=device)
    masks = torch.zeros(rollout_size, device=device)

    # at the start of training reset all envs to get an initial state
    if sample_phase == 0:
        states, info = envs_wrapper.reset(seed=42)

    # play n steps in our parallel environments to collect data
    for step in range(n_steps_per_update):
        # Rows of the buffers that belong to this step. The offset must NOT
        # depend on `sample_phase` (the buffers are re-created every phase)
        # and must advance by n_envs per step so steps do not overwrite
        # each other.
        lo = step * n_envs
        hi = lo + n_envs

        ep_states[lo:hi] = torch.tensor(states, device=device)
        # select an action A_{t} using S_{t} as input for the agent
        actions, action_log_probs, state_value_preds, entropy = agent.select_action(
            states
        )

        # perform the action A_{t} in the environment to get S_{t+1} and R_{t+1}
        states, rewards, terminated, truncated, infos = envs_wrapper.step(
            actions.cpu().numpy()
        )

        ep_value_preds[lo:hi] = torch.squeeze(state_value_preds)
        ep_rewards[lo:hi] = torch.tensor(rewards, device=device)
        ep_action_log_probs[lo:hi] = action_log_probs
        ep_actions[lo:hi] = torch.squeeze(actions)
        ep_entropies[lo:hi] = torch.squeeze(entropy)

        # add a mask (for the return calculation later);
        # for each env the mask is 1 if the episode is ongoing and 0 if it is terminated (not by truncation!)
        masks[lo:hi] = torch.tensor(
            np.logical_not(terminated), dtype=torch.float32, device=device
        )

    # NOTE(review): `agent.get_losses` lives in the `ppo` module and is not
    # visible here. It receives randomly shuffled, flat minibatches; if it
    # computes GAE/returns internally it may need temporally ordered
    # (n_steps, n_envs) data instead -- confirm against its implementation.

    # optimise actor and critic on shuffled minibatches of the rollout
    inds = np.arange(rollout_size)
    for _ in range(n_epochs):
        # Randomize the indexes
        np.random.shuffle(inds)
        # walk over the shuffled indices in chunks of nbatch_train transitions
        for start in range(0, rollout_size, nbatch_train):
            end = start + nbatch_train
            mbinds = inds[start:end]

            # calculate the losses for actor and critic on this minibatch
            critic_loss, actor_loss = agent.get_losses(
                ep_states[mbinds],
                ep_actions[mbinds],
                ep_rewards[mbinds],
                ep_action_log_probs[mbinds],
                ep_value_preds[mbinds],
                ep_entropies[mbinds],
                masks[mbinds],
                gamma,
                lam,
                epsilon,
                ent_coef,
                device,
                n_envs,
            )

            # update the actor and critic networks (once per minibatch)
            agent.update_parameters(critic_loss, actor_loss)

    agent.sync_actor()

    # log the losses and entropy
    # TODO: change to multi-batch (currently the last minibatch's losses and
    # the last rollout step's entropy are logged)
    critic_losses.append(critic_loss.detach().cpu().numpy())
    actor_losses.append(actor_loss.detach().cpu().numpy())
    entropies.append(entropy.detach().mean().cpu().numpy())



  0%|          | 0/1000 [00:00<?, ?it/s]


RuntimeError: shape mismatch: value tensor of shape [40] cannot be broadcast to indexing result of shape [0]

In [None]:
# %%
# Plotting
# --------
#
# Plot the training results: episode returns, policy entropy, and the critic
# and actor losses, each smoothed with a rolling mean.

# %matplotlib inline

rolling_length = 20


def _moving_average(values, window):
    """Return the rolling mean of `values` over `window` samples.

    `values` is flattened to 1-D first. If fewer than `window` samples are
    available an empty array is returned, so plotting early (or after an
    interrupted run) yields an empty curve instead of an error or a
    misleading result.
    """
    flat = np.array(values).flatten()
    if flat.size < window:
        return np.empty(0)
    return np.convolve(flat, np.ones(window), mode="valid") / window


fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 5))
# `randomize_domain` is not defined anywhere in this notebook, so it is not
# part of the title (referencing it raised a NameError).
fig.suptitle(
    f"Training plots for {agent.__class__.__name__} in the LunarLander-v2 environment \n \
             (n_envs={n_envs}, n_steps_per_update={n_steps_per_update})"
)

# episode return
axs[0][0].set_title("Episode Returns")
episode_returns_moving_average = _moving_average(
    envs_wrapper.return_queue, rolling_length
)
axs[0][0].plot(
    np.arange(len(episode_returns_moving_average)) / n_envs,
    episode_returns_moving_average,
)
axs[0][0].set_xlabel("Number of episodes")

# entropy
axs[1][0].set_title("Entropy")
entropy_moving_average = _moving_average(entropies, rolling_length)
axs[1][0].plot(entropy_moving_average)
axs[1][0].set_xlabel("Number of updates")


# critic loss
axs[0][1].set_title("Critic Loss")
critic_losses_moving_average = _moving_average(critic_losses, rolling_length)
axs[0][1].plot(critic_losses_moving_average)
axs[0][1].set_xlabel("Number of updates")


# actor loss
axs[1][1].set_title("Actor Loss")
actor_losses_moving_average = _moving_average(actor_losses, rolling_length)
axs[1][1].plot(actor_losses_moving_average)
axs[1][1].set_xlabel("Number of updates")

plt.tight_layout()
plt.show()