In [1]:
import copy
import glob
import os
import time

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from envs import make_env 

This model was made by @ikostrikov. I simply streamlined it and removed the dependency on OpenAI baselines package in an effort to:
1) See what was going on under the hood
2) Have an easy-to-use model for experiments in transfer learning
3) Remove dependency on Baselines, which was causing some problem

In [2]:
class args:
    def __init__(self):
        self.env_name='PongNoFrameskip-v4'
        self.seed=1
        self.log_dir=''
        self.save_dir='saved_models'
        self.cuda=True
        self.num_stack=4
        self.num_steps=5
        self.num_processes=16
        self.lr=7e-4
        self.eps=1e-5
        self.alpha=.99
        self.max_grad_norm=.5
        self.value_loss_coef=.5
        self.entropy_coef=.1
        self.num_frames=8e6
        self.use_gae=False
        self.gamma=.99
        self.tau=.95
        self.save_interval=1000
        self.log_interval=100
        self.vis_interval=100
        self.load_model=True
        self.save_model=False
        
args = args()

SAVE_PATH = "saved_models/a2c_121717.pt"
LOAD_PATH = "saved_models/a2c_121717.pt"

In [3]:
num_updates = int(args.num_frames) // args.num_steps // args.num_processes

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

In [7]:
def main():
    os.environ['OMP_NUM_THREADS'] = '1'

    envs = [make_env(args.env_name, args.seed, i, args.log_dir)
                for i in range(args.num_processes)]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
    
    actor_critic = CNNPolicy(obs_shape[0], envs.action_space)
  
    if args.load_model:
        actor_critic.load_state_dict(torch.load(LOAD_PATH))

    action_shape = 1

    if args.cuda:
        actor_critic.cuda()

    optimizer = optim.RMSprop(actor_critic.parameters(), args.lr, eps=args.eps, alpha=args.alpha)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space,\
                              actor_critic.state_size)
    
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)
    
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            
            value, action, action_log_prob, states = actor_critic.act(Variable(rollouts.observations[step], volatile=True),
                                                                      Variable(rollouts.states[step], volatile=True),
                                                                      Variable(rollouts.masks[step], volatile=True))
            cpu_actions = action.data.squeeze(1).cpu().numpy()

            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

        next_value = actor_critic(Variable(rollouts.observations[-1], volatile=True),
                                  Variable(rollouts.states[-1], volatile=True),
                                  Variable(rollouts.masks[-1], volatile=True))[0].data

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

        values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
                                                                                       Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
                                                                                       Variable(rollouts.masks[:-1].view(-1, 1)),
                                                                                       Variable(rollouts.actions.view(-1, action_shape)))

        values = values.view(args.num_steps, args.num_processes, 1)
        action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)

        advantages = Variable(rollouts.returns[:-1]) - values
        value_loss = advantages.pow(2).mean()

        action_loss = -(Variable(advantages.data) * action_log_probs).mean()

        optimizer.zero_grad()
        (value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()

        nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)

        optimizer.step()
 
        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_model:
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()
                
            torch.save(actor_critic.state_dict(), SAVE_PATH)

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                format(j, total_num_steps,
                       int(total_num_steps / (end - start)),
                       final_rewards.mean(),
                       final_rewards.median(),
                       final_rewards.min(),
                       final_rewards.max(), dist_entropy.data[0],
                       value_loss.data[0], action_loss.data[0]))
            
main()

Updates 0, num timesteps 80, FPS 44, mean/median reward 0.0/0.0, min/max reward 0.0/0.0, entropy 0.22002, value loss 276.99185, policy loss -2.29684
Updates 100, num timesteps 8080, FPS 808, mean/median reward 0.0/0.0, min/max reward 0.0/0.0, entropy 1.78344, value loss 0.12969, policy loss -0.02170
Updates 200, num timesteps 16080, FPS 872, mean/median reward -17.9/-21.0, min/max reward -21.0/0.0, entropy 1.78924, value loss 0.01250, policy loss 0.04140
Updates 300, num timesteps 24080, FPS 902, mean/median reward -20.4/-21.0, min/max reward -21.0/-19.0, entropy 1.78281, value loss 0.08495, policy loss -0.04526
Updates 400, num timesteps 32080, FPS 911, mean/median reward -20.4/-21.0, min/max reward -21.0/-19.0, entropy 1.78493, value loss 0.08840, policy loss -0.09418
Updates 500, num timesteps 40080, FPS 922, mean/median reward -20.4/-21.0, min/max reward -21.0/-19.0, entropy 1.78297, value loss 0.06992, policy loss 0.17489
Updates 600, num timesteps 48080, FPS 926, mean/median rewa

Updates 5100, num timesteps 408080, FPS 959, mean/median reward -19.9/-20.0, min/max reward -21.0/-17.0, entropy 1.78811, value loss 0.01487, policy loss -0.02550
Updates 5200, num timesteps 416080, FPS 959, mean/median reward -20.2/-20.0, min/max reward -21.0/-18.0, entropy 1.79041, value loss 0.01049, policy loss 0.05789
Updates 5300, num timesteps 424080, FPS 959, mean/median reward -20.4/-21.0, min/max reward -21.0/-19.0, entropy 1.78867, value loss 0.00843, policy loss -0.00056
Updates 5400, num timesteps 432080, FPS 960, mean/median reward -20.0/-20.0, min/max reward -21.0/-17.0, entropy 1.78900, value loss 0.02422, policy loss -0.05720
Updates 5500, num timesteps 440080, FPS 960, mean/median reward -20.0/-20.0, min/max reward -21.0/-18.0, entropy 1.79090, value loss 0.00909, policy loss -0.00639
Updates 5600, num timesteps 448080, FPS 960, mean/median reward -20.0/-20.0, min/max reward -21.0/-18.0, entropy 1.78876, value loss 0.01732, policy loss 0.09902
Updates 5700, num timest

Updates 10200, num timesteps 816080, FPS 965, mean/median reward -19.7/-20.0, min/max reward -21.0/-18.0, entropy 1.78951, value loss 0.01650, policy loss -0.04056
Updates 10300, num timesteps 824080, FPS 965, mean/median reward -19.9/-20.0, min/max reward -21.0/-18.0, entropy 1.78906, value loss 0.03442, policy loss -0.11058
Updates 10400, num timesteps 832080, FPS 965, mean/median reward -20.1/-20.0, min/max reward -21.0/-16.0, entropy 1.78903, value loss 0.01202, policy loss -0.00304
Updates 10500, num timesteps 840080, FPS 965, mean/median reward -20.3/-21.0, min/max reward -21.0/-16.0, entropy 1.78958, value loss 0.01562, policy loss 0.08803
Updates 10600, num timesteps 848080, FPS 965, mean/median reward -20.0/-20.0, min/max reward -21.0/-19.0, entropy 1.79017, value loss 0.01205, policy loss 0.07234
Updates 10700, num timesteps 856080, FPS 965, mean/median reward -19.7/-20.0, min/max reward -21.0/-17.0, entropy 1.78871, value loss 0.02370, policy loss -0.08772
Updates 10800, num

Updates 15200, num timesteps 1216080, FPS 967, mean/median reward -19.4/-20.0, min/max reward -21.0/-17.0, entropy 1.78087, value loss 0.02307, policy loss -0.03950
Updates 15300, num timesteps 1224080, FPS 967, mean/median reward -19.2/-19.0, min/max reward -21.0/-17.0, entropy 1.78090, value loss 0.05348, policy loss 0.07946
Updates 15400, num timesteps 1232080, FPS 967, mean/median reward -19.5/-20.0, min/max reward -21.0/-18.0, entropy 1.77885, value loss 0.02499, policy loss 0.02787
Updates 15500, num timesteps 1240080, FPS 968, mean/median reward -19.6/-19.0, min/max reward -21.0/-19.0, entropy 1.78404, value loss 0.00853, policy loss 0.03214
Updates 15600, num timesteps 1248080, FPS 967, mean/median reward -19.1/-19.0, min/max reward -21.0/-17.0, entropy 1.78098, value loss 0.04706, policy loss -0.00625
Updates 15700, num timesteps 1256080, FPS 968, mean/median reward -18.7/-19.0, min/max reward -20.0/-16.0, entropy 1.78341, value loss 0.06280, policy loss 0.03279
Updates 15800,

Updates 20200, num timesteps 1616080, FPS 969, mean/median reward -18.5/-19.0, min/max reward -21.0/-16.0, entropy 1.77127, value loss 0.01239, policy loss 0.03854
Updates 20300, num timesteps 1624080, FPS 969, mean/median reward -18.7/-19.0, min/max reward -21.0/-14.0, entropy 1.74376, value loss 0.01250, policy loss -0.04821
Updates 20400, num timesteps 1632080, FPS 969, mean/median reward -18.2/-19.0, min/max reward -21.0/-14.0, entropy 1.78521, value loss 0.05471, policy loss 0.10010
Updates 20500, num timesteps 1640080, FPS 969, mean/median reward -18.1/-19.0, min/max reward -21.0/-14.0, entropy 1.76032, value loss 0.02497, policy loss -0.08266
Updates 20600, num timesteps 1648080, FPS 969, mean/median reward -17.7/-18.0, min/max reward -20.0/-14.0, entropy 1.76023, value loss 0.03277, policy loss -0.02534
Updates 20700, num timesteps 1656080, FPS 969, mean/median reward -17.8/-18.0, min/max reward -20.0/-15.0, entropy 1.78542, value loss 0.01700, policy loss -0.02744
Updates 2080

Updates 25200, num timesteps 2016080, FPS 970, mean/median reward -15.1/-15.0, min/max reward -20.0/-11.0, entropy 1.75234, value loss 0.05111, policy loss -0.04186
Updates 25300, num timesteps 2024080, FPS 970, mean/median reward -15.4/-16.0, min/max reward -18.0/-12.0, entropy 1.75428, value loss 0.00884, policy loss 0.00004
Updates 25400, num timesteps 2032080, FPS 970, mean/median reward -15.9/-16.0, min/max reward -18.0/-12.0, entropy 1.76341, value loss 0.02214, policy loss -0.06472
Updates 25500, num timesteps 2040080, FPS 970, mean/median reward -16.4/-17.0, min/max reward -20.0/-10.0, entropy 1.75968, value loss 0.10450, policy loss -0.03081
Updates 25600, num timesteps 2048080, FPS 970, mean/median reward -16.7/-17.0, min/max reward -20.0/-10.0, entropy 1.75934, value loss 0.02588, policy loss -0.09025
Updates 25700, num timesteps 2056080, FPS 970, mean/median reward -16.4/-17.0, min/max reward -20.0/-10.0, entropy 1.78018, value loss 0.00769, policy loss 0.02153
Updates 2580

Updates 30200, num timesteps 2416080, FPS 971, mean/median reward -15.8/-17.0, min/max reward -20.0/-10.0, entropy 1.77345, value loss 0.01981, policy loss 0.08759
Updates 30300, num timesteps 2424080, FPS 971, mean/median reward -15.6/-15.0, min/max reward -20.0/-10.0, entropy 1.77067, value loss 0.01360, policy loss 0.03163
Updates 30400, num timesteps 2432080, FPS 971, mean/median reward -16.0/-17.0, min/max reward -20.0/-10.0, entropy 1.75098, value loss 0.00983, policy loss -0.00155
Updates 30500, num timesteps 2440080, FPS 971, mean/median reward -15.3/-16.0, min/max reward -19.0/-10.0, entropy 1.76022, value loss 0.02399, policy loss -0.03885
Updates 30600, num timesteps 2448080, FPS 971, mean/median reward -14.9/-15.0, min/max reward -19.0/-10.0, entropy 1.75143, value loss 0.03435, policy loss -0.11808
Updates 30700, num timesteps 2456080, FPS 971, mean/median reward -15.0/-15.0, min/max reward -19.0/-11.0, entropy 1.71839, value loss 0.02559, policy loss 0.04041
Updates 30800

Updates 35200, num timesteps 2816080, FPS 972, mean/median reward -12.0/-13.0, min/max reward -18.0/-6.0, entropy 1.76781, value loss 0.01982, policy loss -0.01758
Updates 35300, num timesteps 2824080, FPS 972, mean/median reward -12.0/-13.0, min/max reward -18.0/-6.0, entropy 1.73327, value loss 0.01379, policy loss -0.09113
Updates 35400, num timesteps 2832080, FPS 972, mean/median reward -12.4/-13.0, min/max reward -18.0/-6.0, entropy 1.76866, value loss 0.01256, policy loss 0.00879
Updates 35500, num timesteps 2840080, FPS 972, mean/median reward -12.2/-13.0, min/max reward -16.0/-6.0, entropy 1.74637, value loss 0.02424, policy loss -0.06033
Updates 35600, num timesteps 2848080, FPS 972, mean/median reward -12.6/-13.0, min/max reward -16.0/-6.0, entropy 1.77000, value loss 0.02239, policy loss 0.02790
Updates 35700, num timesteps 2856080, FPS 972, mean/median reward -12.3/-13.0, min/max reward -17.0/-6.0, entropy 1.66442, value loss 0.06933, policy loss 0.01119
Updates 35800, num 

Updates 40300, num timesteps 3224080, FPS 973, mean/median reward -6.6/-8.0, min/max reward -14.0/4.0, entropy 1.69997, value loss 0.01959, policy loss 0.06930
Updates 40400, num timesteps 3232080, FPS 973, mean/median reward -6.9/-8.0, min/max reward -14.0/4.0, entropy 1.77823, value loss 0.01306, policy loss -0.00873
Updates 40500, num timesteps 3240080, FPS 973, mean/median reward -7.2/-9.0, min/max reward -14.0/4.0, entropy 1.77146, value loss 0.05634, policy loss 0.00381
Updates 40600, num timesteps 3248080, FPS 973, mean/median reward -7.2/-9.0, min/max reward -14.0/4.0, entropy 1.64191, value loss 0.01695, policy loss -0.03295
Updates 40700, num timesteps 3256080, FPS 973, mean/median reward -7.4/-9.0, min/max reward -14.0/4.0, entropy 1.77805, value loss 0.02736, policy loss 0.10404
Updates 40800, num timesteps 3264080, FPS 973, mean/median reward -8.4/-9.0, min/max reward -14.0/-2.0, entropy 1.72492, value loss 0.02657, policy loss 0.05143
Updates 40900, num timesteps 3272080,

Updates 45400, num timesteps 3632080, FPS 974, mean/median reward -6.4/-6.0, min/max reward -12.0/-3.0, entropy 1.77631, value loss 0.07451, policy loss 0.10090
Updates 45500, num timesteps 3640080, FPS 974, mean/median reward -6.6/-6.0, min/max reward -12.0/-3.0, entropy 1.77519, value loss 0.01599, policy loss -0.09302
Updates 45600, num timesteps 3648080, FPS 974, mean/median reward -5.6/-6.0, min/max reward -10.0/2.0, entropy 1.71904, value loss 0.05460, policy loss -0.03105
Updates 45700, num timesteps 3656080, FPS 974, mean/median reward -4.9/-5.0, min/max reward -10.0/2.0, entropy 1.75295, value loss 0.02363, policy loss 0.01933
Updates 45800, num timesteps 3664080, FPS 974, mean/median reward -4.6/-5.0, min/max reward -10.0/2.0, entropy 1.73894, value loss 0.01364, policy loss 0.09037
Updates 45900, num timesteps 3672080, FPS 974, mean/median reward -4.6/-5.0, min/max reward -10.0/2.0, entropy 1.70345, value loss 0.06189, policy loss 0.15229
Updates 46000, num timesteps 3680080

Process Process-9:
Process Process-10:
Process Process-8:
Process Process-2:
Process Process-1:
Process Process-4:
Process Process-11:
Process Process-7:
Process Process-5:
Process Process-6:
Process Process-15:
Process Process-3:
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-12:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Process Process-16:
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ubuntu/anaconda3/lib/python3.6

  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/core.py", line 315, in _step
    observation, reward, done, info = self.env.step(action)
KeyboardInterrupt
  File "/home/ubuntu/anaconda3/lib/python3.6/multiprocessing/c

  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/wrappers/time_limit.py", line 36, in _step
    observation, reward, done, info = self.env.step(action)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/core.py", line 96, in step
    return self._step(action)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/core.py", line 96, in step
    return self._step(action)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/core.py", line 96, in step
    return self._step(action)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/core.py", line 280, in _step
    return self.env.step(action)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/envs/atari/atari_env.py", line 80, in _step
    reward += self.ale.act(action)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/gym/core.py", line 280, in _step
    return self.env.step(action)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/atari_py/ale_python_int

KeyboardInterrupt: 

The following cells are taken from the Baselines package.

In [4]:
from utils import orthogonal

class Categorical(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Categorical, self).__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        x = self.linear(x)
        return x

    def sample(self, x, deterministic):
        x = self(x)

        probs = F.softmax(x)
        if deterministic is False:
            action = probs.multinomial()
        else:
            action = probs.max(1)[1]
        return action

    def logprobs_and_entropy(self, x, actions):
        x = self(x)

        log_probs = F.log_softmax(x)
        probs = F.softmax(x)

        action_log_probs = log_probs.gather(1, actions)

        dist_entropy = -(log_probs * probs).sum(-1).mean()
        return action_log_probs, dist_entropy
    

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        orthogonal(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0)


class CNNPolicy(nn.Module):
    def __init__(self, num_inputs, action_space):
        super(CNNPolicy, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 32, 3, stride=1)

        self.linear1 = nn.Linear(32 * 7 * 7, 512)

        self.critic_linear = nn.Linear(512, 1)

        num_outputs = action_space.n
        self.dist = Categorical(512, num_outputs)

        self.train() # training mode. Only affects dropout, batchnorm etc
        self.reset_parameters()
        
    def act(self, inputs, states, masks, deterministic=False):
        value, x, states = self(inputs, states, masks)
        action = self.dist.sample(x, deterministic=deterministic)
        action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, action)
        return value, action, action_log_probs, states

    def evaluate_actions(self, inputs, states, masks, actions):
        value, x, states = self(inputs, states, masks)
        action_log_probs, dist_entropy = self.dist.logprobs_and_entropy(x, actions)
        return value, action_log_probs, dist_entropy, states

    @property
    def state_size(self):
        return 1

    def reset_parameters(self):
        self.apply(weights_init)

        relu_gain = nn.init.calculate_gain('relu')
        self.conv1.weight.data.mul_(relu_gain)
        self.conv2.weight.data.mul_(relu_gain)
        self.conv3.weight.data.mul_(relu_gain)
        self.linear1.weight.data.mul_(relu_gain)

    def forward(self, inputs, states, masks):
        x = self.conv1(inputs / 255.0)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)

        x = self.conv3(x)
        x = F.relu(x)

        x = x.view(-1, 32 * 7 * 7)
        x = self.linear1(x)
        x = F.relu(x)

        return self.critic_linear(x), x, states

In [5]:
class RolloutStorage(object):
    def __init__(self, num_steps, num_processes, obs_shape, action_space, state_size):
        self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape)
        self.states = torch.zeros(num_steps + 1, num_processes, state_size)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
        
        action_shape = 1

        self.actions = torch.zeros(num_steps, num_processes, action_shape)
            
        self.actions = self.actions.long()
        self.masks = torch.ones(num_steps + 1, num_processes, 1)

    def cuda(self):
        self.observations = self.observations.cuda()
        self.states = self.states.cuda()
        self.rewards = self.rewards.cuda()
        self.value_preds = self.value_preds.cuda()
        self.returns = self.returns.cuda()
        self.action_log_probs = self.action_log_probs.cuda()
        self.actions = self.actions.cuda()
        self.masks = self.masks.cuda()

    def insert(self, step, current_obs, state, action, action_log_prob, value_pred, reward, mask):
        self.observations[step + 1].copy_(current_obs)
        self.states[step + 1].copy_(state)
        self.actions[step].copy_(action)
        self.action_log_probs[step].copy_(action_log_prob)
        self.value_preds[step].copy_(value_pred)
        self.rewards[step].copy_(reward)
        self.masks[step + 1].copy_(mask)

    def after_update(self):
        self.observations[0].copy_(self.observations[-1])
        self.states[0].copy_(self.states[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value, use_gae, gamma, tau):
        self.returns[-1] = next_value
        for step in reversed(range(self.rewards.size(0))):
            self.returns[step] = self.returns[step + 1] * \
                gamma * self.masks[step + 1] + self.rewards[step]


In [6]:
from multiprocessing import Process, Pipe

def worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        else:
            raise NotImplementedError


class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x
    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)
    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)



class SubprocVecEnv(object):
    def __init__(self, env_fns):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
            for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()

        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()


    def step(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return

        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

    @property
    def num_envs(self):
        return len(self.remotes)

