In [1]:
import torchrl
import torchrl.network as network
from torchrl.a2c import A2C
from torchrl.a2c import ActorCritic

# later incorporate this into torchrl
from torchrl.utils import make_atari_env
from torchrl.utils import weights_init
from torchrl.parallel import SubprocVecEnv
from torchrl.parallel import VecFrameStack

import time
import torch
import torch.optim as optim
import numpy as np
from tensorboardX import SummaryWriter


In [2]:
params = {}
params["num_in_channels"] = 4
params["num_latent_nodes"] = 512

# control the speed of training
params["num_workers"] = 10
params["rollout_len"] = 5
params["use_cuda"] = True

# log interval steps
params["log_interval"] = 2000

In [3]:
env = VecFrameStack(make_atari_env("BreakoutNoFrameskip-v4", params["num_workers"]), 4)

Process Process-3:
Process Process-4:
Process Process-2:
Process Process-1:
Process Process-9:
Process Process-8:
Process Process-10:
Traceback (most recent call last):
Process Process-6:
Traceback (most recent call last):
  File "/home/will/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process Process-5:
Traceback (most recent call last):
Process Process-7:
Traceback (most recent call last):
  File "/home/will/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/will/anaconda3/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/will/Desktop/pytorchrl/torchrl/parallel/subproc_vec_env.py", line 14, in _worker
    cmd,

In [4]:
params["num_actions"] = env.action_space.n

# hyperparaters for traing
params["discount_gamma"] = 0.9
params["entropy_coef"] = 0.01
params["value_coef"] = 0.5
params["use_gae"] = True
params["gae_tau"] = 0.95
params["max_grad_norm"] = 0.5
params["learning_rate"] = 7e-4
params["RMS_alpha"] = 0.99
params["RMS_eps"] = 1e-5
params["resume"] = False

In [5]:
# define a model and its optimizer
model = ActorCritic(params)
model.apply(weights_init)

# RMSprop better for RL
optimizer = optim.RMSprop(model.parameters(),
                           params["learning_rate"],
                           alpha=params["RMS_alpha"],
                           eps=params["RMS_eps"]
                          )

In [6]:
agent = A2C(env, model, optimizer, params)

In [7]:
if params["resume"]:
    checkpoint = torch.load("Breakout/model0.pth")
    model.load_state_dict(checkpoint["state_dict"])
    agent.num_steps = checkpoint["num_steps"]

In [8]:
writer = SummaryWriter()

In [9]:
batch_size = int(agent.num_workers * agent.rollout_len)
num_iter = int(4e7 // batch_size)

In [10]:
model.cuda()
model.train()
t0 = time.time()

log_interval = params["log_interval"]
log_iter = int(log_interval // batch_size)

for e in range(num_iter):
    agent.train_step()
    
    if e % log_iter == 0:
        writer.add_scalar('policy loss', agent.policy_loss, agent.num_steps)
        writer.add_scalar('entropy loss', agent.entropy_loss, agent.num_steps)
        writer.add_scalar('value loss', agent.value_loss, agent.num_steps)
        
        if len(agent.episode_rewards) > 0:
            writer.add_scalar('Mean Reward', np.mean(agent.episode_rewards), agent.num_steps)
            writer.add_scalar('Max Reward', np.max(agent.episode_rewards), agent.num_steps)
            
            print('Returns %.2f/%.2f/%.2f/%.2f (mean/median/min/max)  %.2f steps/s' % (
                np.mean(agent.episode_rewards),
                np.median(agent.episode_rewards),
                np.min(agent.episode_rewards),
                np.max(agent.episode_rewards),
                batch_size * log_iter / (time.time() - t0)
            ))
            
            agent.episode_rewards.clear()

            t0 = time.time()
    
    # save the model every 1000000 steps
    if e % (1e6 // batch_size) == 0:
        checkpoint = {
            'num_steps': agent.num_steps,
            'state_dict': model.state_dict(),
            'optimizer' : optimizer.state_dict()
            }
        torch.save(checkpoint, "Breakout/saved_model" + str(agent.num_steps) + ".pth")
        

Returns 0.22/0.00/0.00/2.00 (mean/median/min/max)  670.13 steps/s
Returns 0.20/0.00/0.00/2.00 (mean/median/min/max)  987.90 steps/s
Returns 0.20/0.00/0.00/3.00 (mean/median/min/max)  1020.16 steps/s
Returns 0.30/0.00/0.00/4.00 (mean/median/min/max)  1082.33 steps/s
Returns 0.22/0.00/0.00/3.00 (mean/median/min/max)  1068.01 steps/s
Returns 0.30/0.00/0.00/3.00 (mean/median/min/max)  1063.05 steps/s
Returns 0.40/0.00/0.00/2.00 (mean/median/min/max)  1065.03 steps/s
Returns 0.40/0.00/0.00/3.00 (mean/median/min/max)  1037.69 steps/s
Returns 0.30/0.00/0.00/2.00 (mean/median/min/max)  1039.85 steps/s
Returns 0.20/0.00/0.00/2.00 (mean/median/min/max)  1063.77 steps/s
Returns 0.24/0.00/0.00/2.00 (mean/median/min/max)  1019.06 steps/s
Returns 0.14/0.00/0.00/1.00 (mean/median/min/max)  978.74 steps/s
Returns 0.38/0.00/0.00/2.00 (mean/median/min/max)  1021.25 steps/s
Returns 0.28/0.00/0.00/3.00 (mean/median/min/max)  1044.36 steps/s
Returns 0.14/0.00/0.00/3.00 (mean/median/min/max)  1036.94 steps/

KeyboardInterrupt: 

In [111]:
import copy

In [108]:
torch.manual_seed(1)
temp_rollout = []
for i in range(5):
    log_probs = torch.randn(8)
    values = torch.randn(8)
    rewards = torch.randint(0, 2, (8,)).float()
    terminals = torch.zeros((8,)).float()
    entropys = torch.randn(8)
    temp_rollout.append([log_probs, values, rewards, 1 - terminals, entropys])

In [109]:
last_value = torch.randn(1)

In [110]:
last_value

tensor([0.4728])

In [130]:
%%timeit
rollout = copy.deepcopy(temp_rollout)
rollout.append([None, last_value, None, None, None])

processed_rollout = [None] * (len(rollout) - 1)
advantages = torch.zeros((8, 1))
returns = last_value

for i in reversed(range(len(rollout) - 1)):
    log_prob, value, rewards, terminals, entropy = rollout[i]
    next_value = rollout[i + 1][1]
    returns = rewards + params["discount_gamma"] * terminals * returns
    
    
    td_error = rewards + params["discount_gamma"] * terminals * next_value.detach() - value.detach()
    advantages = advantages * params["gae_tau"] * params["discount_gamma"]  * terminals + td_error
    
    
    processed_rollout[i] = [log_prob, value, returns, advantages, entropy]

437 µs ± 12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [97]:
log_prob, value, returns, advantages, entropy = map(lambda x: torch.cat(x, dim=0), zip(*processed_rollout))
returns

tensor([1.2792, 2.9892, 1.2792, 1.9353, 3.5643, 2.7182, 2.7182, 1.0892, 0.3102,
        2.2102, 0.3102, 1.0392, 2.8492, 3.0202, 3.0202, 1.2102, 0.3447, 1.3447,
        0.3447, 1.1547, 2.0547, 2.2447, 2.2447, 1.3447, 0.3830, 0.3830, 0.3830,
        1.2830, 2.2830, 1.3830, 1.3830, 0.3830, 0.4256, 0.4256, 0.4256, 1.4256,
        1.4256, 0.4256, 0.4256, 0.4256])

In [132]:
%%timeit
rollout = copy.deepcopy(temp_rollout)
returns = last_value
discounted_rollout = []

advantages = torch.zeros((8))

for i in reversed(range(len(rollout))):
    log_prob, values, rewards, terminals, entropy = rollout[i]
    #values = values.squeeze()
    returns = rewards + params["discount_gamma"] * terminals * returns


    next_values = last_value
    gae_returns = rewards + params["discount_gamma"] * terminals * next_values.detach()
    td_error = gae_returns - values.detach()
    advantages = advantages * params["gae_tau"] * params["discount_gamma"] * terminals + td_error

    discounted_rollout.append(
        [log_prob, values, returns, advantages, entropy])


424 µs ± 1.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [102]:
log_prob, value, returns, advantages, entropy = map(lambda x: torch.cat(x, dim=0), zip(*processed_rollout))

In [103]:
returns

tensor([1.2792, 2.9892, 1.2792, 1.9353, 3.5643, 2.7182, 2.7182, 1.0892, 0.3102,
        2.2102, 0.3102, 1.0392, 2.8492, 3.0202, 3.0202, 1.2102, 0.3447, 1.3447,
        0.3447, 1.1547, 2.0547, 2.2447, 2.2447, 1.3447, 0.3830, 0.3830, 0.3830,
        1.2830, 2.2830, 1.3830, 1.3830, 0.3830, 0.4256, 0.4256, 0.4256, 1.4256,
        1.4256, 0.4256, 0.4256, 0.4256])