In [1]:
import torchrl
from torchrl.utils import make_parallel_env

import torch
import numpy as np

from torchrl.a2c import RolloutStorage
from torchrl.a2c import ActorCritic
from collections import deque

import time

import pdb

In [2]:
# Runtime configuration.
# Limit PyTorch intra-op CPU parallelism to a single thread.
torch.set_num_threads(1)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing later when tensors are moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Build one Breakout environment through torchrl's atari_wrappers helper.
# NOTE(review): `gym` is imported here but not referenced in the visible cells.
# NOTE(review): `env` is rebound to the parallel environment in a later cell,
# so this instance is only used for the observation-shape check that follows.
import gym
from torchrl.utils import atari_wrappers
env = atari_wrappers.make_atari("Breakout-v4")

In [4]:
# Reset `env` and display the shape of the observation it returns.
env.reset().shape

(210, 160, 3)

In [5]:
# Hyperparameters and the vectorised Breakout environment used for training.

SEED = 1              # shared by the environment and the torch / numpy RNGs
NUM_WORKERS = 32      # number of parallel environment workers
NUM_FRAME_STACK = 4   # frames stacked per observation

# The environment below is seeded; seed torch and numpy as well so that code
# drawing from those global RNGs in later cells is repeatable across runs.
torch.manual_seed(SEED)
np.random.seed(SEED)

# NOTE: this rebinds `env`; any environment previously bound to that name is
# not closed here.
env = make_parallel_env("Breakout-v4",
                        seed=SEED, num_workers=NUM_WORKERS,
                        num_frame_stack=NUM_FRAME_STACK, device=device, log_dir="results")

params = {
    # rollout / environment layout
    "rollout_len": 5,
    "num_workers": NUM_WORKERS,
    "obs_shape": env.observation_space.shape,
    "action_shape": env.action_space.n,

    # training budget and return estimation
    "total_steps": 1e7,
    "discount_gamma": 0.99,
    "gae_tau": 0.95,
    "use_gae": False,

    # loss weights
    "entropy_coef": 0.01,
    "value_coef": 0.5,

    # Kept equal to the frame-stack count via one constant so the two cannot
    # drift apart (presumably the stacked frames form the channel axis the
    # network consumes -- TODO confirm against ActorCritic).
    "num_in_channels": NUM_FRAME_STACK,

    # optimisation
    "max_grad_norm": 0.5,
    "learning_rate": 7e-4,
    "RMS_alpha": 0.99,
    "RMS_eps": 1e-5,

    "resume": False,
}

In [6]:
# Reset the parallel environment and keep the initial observation; it is
# copied into the first slot of the rollout storage further down.
obs = env.reset()

In [7]:
# Display the shape of the initial observation `obs`.
obs.shape

torch.Size([32, 4, 84, 84])

In [None]:
# Create the rollout storage from the hyperparameter dict.
rollouts = RolloutStorage(params)
# Copy the initial observation into the first observation slot.
rollouts.obs[0].copy_(obs)
# presumably moves the storage's tensors onto `device` -- TODO confirm in RolloutStorage.to
rollouts.to(device)

In [None]:
# Build the ActorCritic network from `params` and move it to `device`.
model = ActorCritic(params).to(device)

In [None]:
# RMSprop over every parameter of `model`; the learning rate, smoothing
# constant and epsilon are read from the hyperparameter dict.
optimizer = torch.optim.RMSprop(
    model.parameters(),
    lr=params["learning_rate"],
    alpha=params["RMS_alpha"],
    eps=params["RMS_eps"],
)

In [None]:
# Running count of environment steps; the training loop advances it once per iteration.
num_steps = 0

In [None]:
# A2C training loop.
#
# Each iteration:
#   1. collects `rollout_len` steps from the parallel environment,
#   2. asks the rollout storage to compute returns from a bootstrap value,
#   3. takes one gradient step on policy + value + entropy losses,
#   4. periodically prints statistics over recently finished episodes.

# Returns of the 50 most recently finished episodes.
episode_rewards = deque(maxlen=50)

batch_size = int(params["num_workers"] * params["rollout_len"])
num_iter = int(params["total_steps"] // batch_size)
log_interval = 1000
# Clamp to at least 1: when batch_size exceeds log_interval the floor division
# yields 0 and `e % log_iter` below would raise ZeroDivisionError.
log_iter = max(1, int(log_interval // batch_size))

rollouts.obs[0].copy_(obs)

# Throughput bookkeeping: wall-clock time and step count at the last log line.
t0 = time.time()
steps_at_last_log = num_steps

for e in range(num_iter):
    #-------------------
    # Evaluation Network
    #-------------------

    # collect rollout
    for step in range(params["rollout_len"]):
        # Sample actions (no gradients are needed while acting).
        with torch.no_grad():
            actions, values = model.act(rollouts.obs[step])

        # move a step
        states, rewards, dones, infos = env.step(actions)

        # An info entry carries an 'episode' record when it reports a
        # finished episode; keep that episode's return for the statistics.
        for info in infos:
            if 'episode' in info:
                episode_rewards.append(info['episode']['r'])

        # 1.0 where the episode continues, 0.0 where it just ended.
        masks = 1 - torch.from_numpy(dones.astype(np.float32))
        rollouts.insert(states, rewards, masks, actions, values.squeeze())

    # calculate the next value (bootstrap from the last stored observation)
    with torch.no_grad():
        next_value = model.get_value(rollouts.obs[-1]).squeeze()

    rollouts.compute_returns(next_value)
    # to cpu or gpu
    rollouts.to(device)

    #---------------
    # Update Network
    #---------------

    # Flatten the leading dimensions of the stored observations/actions into
    # one batch axis; the final observation is excluded from the batch.
    batch_states = rollouts.obs[:-1].view(-1, *params["obs_shape"])
    batch_actions = rollouts.actions.view(-1)

    log_probs, values, entropy = model.eval_action(batch_states, batch_actions)

    # NOTE(review): assumes `log_probs` is 1-D and aligned with `advs`; a
    # trailing singleton dimension would silently broadcast in the product
    # below -- confirm against ActorCritic.eval_action.
    advs = rollouts.returns[:-1].view(-1) - values.squeeze()
    value_loss = advs.pow(2).mean()
    # Advantages are detached so the policy term does not backpropagate
    # through the value estimate.
    policy_loss = -(advs.detach() * log_probs).mean()

    loss = policy_loss + params["value_coef"]*value_loss - params["entropy_coef"]*entropy

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(
            model.parameters(), params["max_grad_norm"])
    optimizer.step()

    rollouts.after_update()

    # Count this iteration's steps before logging so the printed total
    # includes the batch that was just processed.
    num_steps += batch_size

    # time for logging
    if e % log_iter == 0 and len(episode_rewards) > 0:
        now = time.time()
        # Rate = steps actually taken since the previous log line divided by
        # the time they took. This stays correct for the first line and when
        # earlier log points were skipped because no episode had finished.
        elapsed = max(now - t0, 1e-8)
        print('Total steps: %d Returns %.2f/%.2f/%.2f/%.2f (mean/median/min/max)  %.2f steps/s' % (
            num_steps,
            np.mean(episode_rewards),
            np.median(episode_rewards),
            np.min(episode_rewards),
            np.max(episode_rewards),
            (num_steps - steps_at_last_log) / elapsed
        ))
        t0 = now
        steps_at_last_log = num_steps
