# A2C Demo
This notebook trains and tests an Advantage Actor-Critic (A2C) agent on the `SuperMarioBros-v0` environment.

In [1]:
import random

import gym_super_mario_bros
import numpy as np
import torch
import torch.optim as optim
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace

from a2c.model import ACNetwork
from a2c.agent import A2CAgent
from utils.config import Config
from utils.helper import set_device

In [2]:
# Set hyperparameters
ENV_NAME = 'SuperMarioBros-v0'
SEED = 42 # Fixed RNG seed so Restart & Run All reproduces the same run
GAMMA = 0.99 # Stored as config.gamma (presumably the reward discount factor)
LEARNING_RATE = 0.001 # Stored as config.lr; used as the Adam learning rate
EPSILON = 1e-3 # Stored as config.epsilon; used as Adam's `eps` term (not an exploration rate)
ENTROPY_WEIGHT = 0.01 # Stored as config.entropy_weight (presumably scales the entropy term of the loss)
N_STEPS = 4 # TD bootstrapping; stored as config.rollout_size
GRAD_CLIP = 0.1 # Prevents gradients from being too large
NUM_EPISODES = 100
SAVE_MODEL_FILENAME = 'a2c'

# Seed every global RNG the run may draw from *before* the environment and
# network are created, so weight initialisation and sampling are repeatable.
# NOTE(review): the environment itself is not seeded here - confirm whether
# the installed gym / nes_py version needs an explicit env seed as well.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create environment
env = gym_super_mario_bros.make(ENV_NAME)
# Restrict the controller to the RIGHT_ONLY action set
env = JoypadSpace(env, RIGHT_ONLY)

# Set cuda device
device = set_device()

# Add core items to config
config = Config()
config.add(
    env=env,
    env_name=ENV_NAME,
    gamma=GAMMA,
    lr=LEARNING_RATE,
    epsilon=EPSILON,
    entropy_weight=ENTROPY_WEIGHT,
    rollout_size=N_STEPS,
    grad_clip=GRAD_CLIP,
    device=device,
    num_episodes=NUM_EPISODES,
    filename=SAVE_MODEL_FILENAME
)

# Store environment parameters
config.set_env_params()

  logger.warn(


CUDA unavailable. Device set to CPU.


In [3]:
# Action set reference:
# https://github.com/Kautenja/gym-super-mario-bros/blob/master/gym_super_mario_bros/actions.py
# Report the action space and observation shape exposed by config.
for label, attr_name in (
    ('Available actions:', 'action_space'),
    ('Obs space shape: ', 'input_shape'),
):
    print(label, getattr(config, attr_name))

Available actions: Discrete(5)
Obs space shape:  (240, 256, 3)


In [4]:
# Build the network from the environment's input shape / action count and
# move it to the selected device
a2c = ACNetwork(config.input_shape, config.n_actions).to(device)


def make_optimizer(params):
    """Return an Adam optimizer over `params`, reading lr and eps from config at call time."""
    return optim.Adam(params, lr=config.lr, eps=config.epsilon)


def get_network():
    """Return the network instance created in this cell."""
    return a2c


# Register the optimizer and network factories on the config
config.add(optimizer_fn=make_optimizer, network_fn=get_network)

In [5]:
# Train agent
# The agent is built from the config populated in the earlier cells.
# train() prints the N-step setting followed by a progress line per episode
# (actions, average return, total loss), as shown in the output below.
agent = A2CAgent(config)
agent.train()

Running training with N-Steps: 4
(1/100) Actions: [3, 0, 3, 4], Avg return: -0.024, Total Loss: 0.016675731167197227
(2/100) Actions: [0, 4, 2, 4], Avg return: 0.080, Total Loss: 0.014151601120829582
(3/100) Actions: [3, 4, 2, 3], Avg return: 0.170, Total Loss: 0.011996699497103691
(4/100) Actions: [0, 2, 3, 4], Avg return: 0.271, Total Loss: 0.009534895420074463
(5/100) Actions: [3, 1, 1, 0], Avg return: 0.386, Total Loss: 0.006642619147896767
(6/100) Actions: [1, 4, 0, 2], Avg return: 0.514, Total Loss: 0.0036561237648129463
(7/100) Actions: [3, 0, 3, 2], Avg return: 0.658, Total Loss: 0.00023069418966770172
(8/100) Actions: [0, 4, 4, 2], Avg return: 0.817, Total Loss: -0.0034129414707422256
(9/100) Actions: [2, 0, 2, 1], Avg return: 0.994, Total Loss: -0.007785201072692871
(10/100) Actions: [3, 1, 2, 0], Avg return: 1.189, Total Loss: -0.012473942711949348
(11/100) Actions: [2, 1, 1, 4], Avg return: 1.403, Total Loss: -0.017377033829689026
(12/100) Actions: [2, 0, 4, 2], Avg return: