In [1]:
import torchrl
from torchrl.utils import make_parallel_env

import torch
import numpy as np

from torchrl.a2c import RolloutStorage
from torchrl.a2c import ActorCritic
from collections import deque

import time

import pdb

In [2]:
# Restrict PyTorch to a single intra-op CPU thread in this process.
torch.set_num_threads(1)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA; every later `.to(device)` call follows this choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [43]:
params = {}
params["rollout_len"] = 5
params["num_workers"] = 32
params["obs_shape"] = env.observation_space.shape
params["action_shape"] = env.action_space.n

params["total_steps"] = 1e7
params["discount_gamma"] = 0.99
params["gae_tau"] = 0.95
params["use_gae"] = False
params["entropy_coef"] = 0.01
params["value_coef"] = 0.5
params['num_in_channels'] = 4
params["max_grad_norm"] = 0.5
params["learning_rate"] = 7e-4
params["RMS_alpha"] = 0.99
params["RMS_eps"] = 1e-5
params["resume"] = False

In [44]:
rollouts = RolloutStorage(params)
rollouts.to(device)

In [45]:
env = make_parallel_env("BreakoutNoFrameskip-v4", 
                        seed=1, num_workers=params["num_workers"], 
                        num_frame_stack=4, device=device, log_dir="results")
curr_states = env.reset()

In [39]:
# Build the actor-critic network from the shared hyper-parameter dict and
# move it to the training device.
model = ActorCritic(params).to(device)

In [40]:
# RMSprop over every parameter of the actor-critic model; the step size,
# smoothing constant and epsilon all come from the hyper-parameter dict.
optimizer = torch.optim.RMSprop(
    model.parameters(),
    lr=params["learning_rate"],
    alpha=params["RMS_alpha"],
    eps=params["RMS_eps"],
)

In [41]:
# Running total of environment steps taken; the training loop increases it
# by rollout_len * num_workers on every iteration.
num_steps = 0

In [42]:
# Advantage actor-critic training loop: every iteration collects
# `rollout_len` steps from all workers, then performs one gradient update.
episode_rewards = deque(maxlen=50)  # returns of the 50 most recently finished episodes

batch_size = int(params["num_workers"] * params["rollout_len"])
num_iter = int(params["total_steps"] // batch_size)
log_interval = 1000
# Clamp to 1: when a single batch holds more than `log_interval` steps the
# integer division gives 0 and `e % log_iter` would raise ZeroDivisionError.
log_iter = max(1, int(log_interval // batch_size))

# Wall-clock time and step count at the previous log point; the throughput
# figure is computed from the difference to these two values.
t0 = time.time()
steps_at_last_log = num_steps

for e in range(num_iter):
    #----------------
    # Collect rollout
    #----------------
    for i in range(params["rollout_len"]):
        # Sample actions (and value estimates) without tracking gradients
        with torch.no_grad():
            actions, values = model.act(curr_states)

        # move a step
        states, rewards, dones, infos = env.step(actions)

        # Record the return 'r' of every info dict that carries an 'episode'
        # entry (presumably set by the env wrapper when an episode ends).
        for info in infos:
            if 'episode' in info.keys():
                episode_rewards.append(info['episode']['r'])

        # The third argument is a mask: 0 where `dones` is set, 1 otherwise.
        rollouts.insert(states, rewards, 1 - torch.from_numpy(dones.astype(np.float32)), actions, values.squeeze())
        curr_states = states

    # Value estimate of the state reached after the last rollout step; handed
    # to compute_returns (presumably as the bootstrap value - TODO confirm).
    with torch.no_grad():
        next_value = model.get_value(curr_states).squeeze()

    rollouts.compute_returns(next_value)
    # to cpu or gpu
    rollouts.to(device)

    #---------------
    # Update Network
    #---------------

    # Flatten the stored steps into one batch; the final obs / returns entry
    # is excluded by the [:-1] slices.
    batch_states = rollouts.obs[:-1].view(-1, *params["obs_shape"])
    batch_actions = rollouts.actions.view(-1)

    log_probs, values, entropy = model.eval_action(batch_states, batch_actions)

    advs = rollouts.returns[:-1].view(-1) - values.squeeze()
    value_loss = advs.pow(2).mean()
    # Advantages are detached so the policy term does not back-propagate
    # through the value estimates.
    policy_loss = -(advs.detach() * log_probs).mean()

    loss = policy_loss + params["value_coef"]*value_loss - params["entropy_coef"]*entropy

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(
            model.parameters(), params["max_grad_norm"])
    optimizer.step()

    # Count this batch before logging so the printed total includes it.
    num_steps += batch_size

    # time for logging
    if e % log_iter == 0:
        now = time.time()
        if len(episode_rewards) > 0:
            print('Total steps: %d Returns %.2f/%.2f/%.2f/%.2f (mean/median/min/max)  %.2f steps/s' % (
                num_steps,
                np.mean(episode_rewards),
                np.median(episode_rewards),
                np.min(episode_rewards),
                np.max(episode_rewards),
                # Steps actually taken since the previous log point.
                (num_steps - steps_at_last_log) / max(now - t0, 1e-8)
            ))
        # Reset the window on every log point, even when nothing was printed,
        # so the next throughput figure covers exactly one window.
        t0 = now
        steps_at_last_log = num_steps


Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  608.00 steps/s
Returns 0.39/0.00/0.00/1.00 (mean/median/min/max)  2359.42 steps/s
Returns 0.45/0.00/0.00/2.00 (mean/median/min/max)  2529.60 steps/s
Returns 0.50/0.00/0.00/2.00 (mean/median/min/max)  2557.51 steps/s
Returns 0.37/0.00/0.00/2.00 (mean/median/min/max)  2472.86 steps/s
Returns 0.50/0.00/0.00/4.00 (mean/median/min/max)  2252.48 steps/s
Returns 0.62/0.00/0.00/4.00 (mean/median/min/max)  2494.79 steps/s
Returns 0.62/0.00/0.00/4.00 (mean/median/min/max)  2604.12 steps/s
Returns 0.58/0.00/0.00/4.00 (mean/median/min/max)  2265.29 steps/s
Returns 0.48/0.00/0.00/4.00 (mean/median/min/max)  2257.49 steps/s
Returns 0.54/0.00/0.00/7.00 (mean/median/min/max)  2428.89 steps/s
Returns 0.62/0.00/0.00/7.00 (mean/median/min/max)  2492.17 steps/s
Returns 0.46/0.00/0.00/7.00 (mean/median/min/max)  2403.62 steps/s
Returns 0.48/0.00/0.00/7.00 (mean/median/min/max)  2486.82 steps/s
Returns 0.62/0.00/0.00/7.00 (mean/median/min/max)  2444.35 step

Returns 2.00/0.00/0.00/11.00 (mean/median/min/max)  2320.42 steps/s
Returns 1.70/0.00/0.00/11.00 (mean/median/min/max)  2260.08 steps/s
Returns 1.92/0.00/0.00/11.00 (mean/median/min/max)  2262.04 steps/s
Returns 2.14/0.00/0.00/11.00 (mean/median/min/max)  2207.71 steps/s
Returns 2.14/0.00/0.00/11.00 (mean/median/min/max)  2338.32 steps/s
Returns 2.14/0.00/0.00/11.00 (mean/median/min/max)  2226.89 steps/s
Returns 2.12/0.00/0.00/11.00 (mean/median/min/max)  2199.83 steps/s
Returns 2.12/0.00/0.00/11.00 (mean/median/min/max)  2289.92 steps/s
Returns 1.90/0.00/0.00/11.00 (mean/median/min/max)  2287.75 steps/s
Returns 2.16/0.00/0.00/11.00 (mean/median/min/max)  2215.31 steps/s
Returns 2.34/0.00/0.00/11.00 (mean/median/min/max)  2296.52 steps/s
Returns 1.90/0.00/0.00/11.00 (mean/median/min/max)  2339.66 steps/s
Returns 2.12/0.00/0.00/11.00 (mean/median/min/max)  2339.69 steps/s
Returns 2.34/0.00/0.00/11.00 (mean/median/min/max)  2455.52 steps/s
Returns 2.78/0.00/0.00/11.00 (mean/median/min/ma

Returns 2.20/0.00/0.00/11.00 (mean/median/min/max)  2124.55 steps/s
Returns 1.98/0.00/0.00/11.00 (mean/median/min/max)  2065.31 steps/s
Returns 1.76/0.00/0.00/11.00 (mean/median/min/max)  2050.07 steps/s
Returns 1.98/0.00/0.00/11.00 (mean/median/min/max)  2098.95 steps/s
Returns 2.42/0.00/0.00/11.00 (mean/median/min/max)  2031.57 steps/s
Returns 2.42/0.00/0.00/11.00 (mean/median/min/max)  2125.84 steps/s
Returns 2.64/0.00/0.00/11.00 (mean/median/min/max)  1973.60 steps/s
Returns 2.42/0.00/0.00/11.00 (mean/median/min/max)  2067.60 steps/s
Returns 2.42/0.00/0.00/11.00 (mean/median/min/max)  2060.40 steps/s
Returns 2.64/0.00/0.00/11.00 (mean/median/min/max)  2086.98 steps/s
Returns 2.64/0.00/0.00/11.00 (mean/median/min/max)  2131.95 steps/s
Returns 3.08/0.00/0.00/11.00 (mean/median/min/max)  2056.51 steps/s
Returns 2.86/0.00/0.00/11.00 (mean/median/min/max)  2069.57 steps/s
Returns 2.64/0.00/0.00/11.00 (mean/median/min/max)  2097.40 steps/s
Returns 2.86/0.00/0.00/11.00 (mean/median/min/ma

Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2053.17 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2182.97 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2182.77 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2324.33 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2077.43 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2182.20 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2196.76 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2364.27 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2004.88 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2177.19 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2165.74 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2305.58 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2205.96 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2274.68 steps/s
Returns 0.00/0.00/0.00/0.00 (mean/median/min/max)  2231.41 ste

KeyboardInterrupt: 