In [2]:
import copy
import glob
import os
import time

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from arguments import get_args

#from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
#from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
#from baselines.common.vec_env.vec_normalize import VecNormalize
from all_stuff import * # this has the above modules consolidated into a single file. god this was a bitch

from envs import make_env # had to manually add some files into directory for env to reference bc baselines 
# modules not working right

from kfac import KFACOptimizer
from model import CNNPolicy, MLPPolicy
from storage import RolloutStorage
from visualize import visdom_plot

In [14]:
from visdom import Visdom

In [17]:
class Args:
    """Hyperparameters and paths for this training notebook (stands in for `get_args()`).

    Every attribute can be overridden by keyword, e.g. ``Args(num_processes=8, cuda=False)``.
    Unknown keywords raise TypeError so a typo is never silently ignored.
    """

    def __init__(self, **overrides):
        self.env_name = 'PongNoFrameskip-v4'
        self.seed = 1
        self.log_dir = ''               # '' means the current working directory
        self.save_dir = 'saved_models'
        self.cuda = True
        self.algo = 'a2c'               # only 'a2c' and 'acktr' have an update step in main()
        self.num_stack = 4              # consecutive observations stacked into one input
        self.num_steps = 5              # environment steps collected per update
        self.num_processes = 16         # number of parallel environments
        self.recurrent_policy = False
        self.vis = False
        self.lr = 7e-4
        self.eps = 1e-5
        self.alpha = .99
        self.max_grad_norm = .5
        self.value_loss_coef = .5
        # NOTE(review): 10x the 0.01 default used by the reference A2C implementations
        # (pytorch-a2c-ppo-acktr, OpenAI baselines). In the logged run below the policy
        # entropy stays near its initial ~1.79 for millions of frames -- consider lowering.
        self.entropy_coef = .1
        self.num_frames = 8e6
        self.use_gae = False
        self.gamma = .99
        self.tau = .95
        self.save_interval = 1000
        self.log_interval = 100
        self.vis_interval = 100
        self.from_saved = True          # resume from SAVE_PATH when main() starts

        # Apply keyword overrides, rejecting names that are not known settings.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError("Args() got an unexpected keyword argument {!r}".format(name))
            setattr(self, name, value)


# Module-level config instance used by every later cell.
args = Args()

save_path = os.path.join(args.save_dir, args.algo)
SAVE_PATH = os.path.join(save_path, args.env_name + ".pt")

In [18]:
# Number of optimizer updates needed to consume `num_frames` environment frames.
num_updates = int(args.num_frames) // args.num_steps // args.num_processes

# Fall back to CPU instead of crashing later on .cuda() when no GPU is present.
if args.cuda and not torch.cuda.is_available():
    print("CUDA requested but not available; running on CPU")
    args.cuda = False

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Make sure the log directory exists ('' means the current directory, which always exists),
# then clear stale '*.monitor.csv' files left over from a previous run.
if args.log_dir:
    os.makedirs(args.log_dir, exist_ok=True)
for stale_file in glob.glob(os.path.join(args.log_dir, '*.monitor.csv')):
    os.remove(stale_file)

In [None]:
def main():
    """Train `actor_critic` on `args.env_name` with A2C or ACKTR, logging and checkpointing.

    Reads the module-level `args`, `num_updates`, `save_path` and `SAVE_PATH`. The policy is
    bound to the module-level global `actor_critic` so later cells can use it.

    Raises:
        NotImplementedError: if `args.algo` is neither 'a2c' nor 'acktr' (no other update
            step is implemented in this notebook).
    """
    print("#######")
    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
    print("#######")

    # Fail fast, before any environment subprocesses are spawned. Only the A2C/ACKTR update
    # exists below; any other algo would never train and then hit NameError when logging.
    if args.algo not in ('a2c', 'acktr'):
        raise NotImplementedError(
            "algo {!r} is not supported here; use 'a2c' or 'acktr'".format(args.algo))

    os.environ['OMP_NUM_THREADS'] = '1'

    if args.vis:
        from visdom import Visdom
        viz = Visdom()
        win = None

    envs = [make_env(args.env_name, args.seed, i, args.log_dir)
            for i in range(args.num_processes)]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Only 1-D (vector) observations are wrapped in VecNormalize.
    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs)

    # The policy input is `num_stack` observations concatenated along the first dimension.
    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    global actor_critic
    if len(envs.observation_space.shape) == 3:
        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
    else:
        assert not args.recurrent_policy, \
            "Recurrent policy is not implemented for the MLP controller"
        actor_critic = MLPPolicy(obs_shape[0], envs.action_space)

    if args.from_saved:
        if os.path.exists(SAVE_PATH):
            # Map every storage to CPU so a checkpoint written from a GPU run loads on any
            # machine; the model is moved to the GPU below when args.cuda is set.
            actor_critic.load_state_dict(
                torch.load(SAVE_PATH, map_location=lambda storage, loc: storage))
        else:
            print("No saved model found at {}; training from scratch".format(SAVE_PATH))

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        optimizer = optim.RMSprop(actor_critic.parameters(), args.lr, eps=args.eps, alpha=args.alpha)
    else:  # 'acktr' -- validated at the top of main()
        optimizer = KFACOptimizer(actor_critic)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space,
                              actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        """Shift the frame stack left by one observation and append `obs` at the end."""
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            value, action, action_log_prob, states = actor_critic.act(Variable(rollouts.observations[step], volatile=True),
                                                                      Variable(rollouts.states[step], volatile=True),
                                                                      Variable(rollouts.masks[step], volatile=True))
            cpu_actions = action.data.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
            episode_rewards += reward

            # mask = 0 for environments whose episode just ended: record their episode return
            # in final_rewards, reset their running sum, and clear their frame history.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(step, current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

        # Bootstrap the return from the value of the last observation in the rollout.
        next_value = actor_critic(Variable(rollouts.observations[-1], volatile=True),
                                  Variable(rollouts.states[-1], volatile=True),
                                  Variable(rollouts.masks[-1], volatile=True))[0].data

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

        values, action_log_probs, dist_entropy, states = actor_critic.evaluate_actions(Variable(rollouts.observations[:-1].view(-1, *obs_shape)),
                                                                                       Variable(rollouts.states[0].view(-1, actor_critic.state_size)),
                                                                                       Variable(rollouts.masks[:-1].view(-1, 1)),
                                                                                       Variable(rollouts.actions.view(-1, action_shape)))

        values = values.view(args.num_steps, args.num_processes, 1)
        action_log_probs = action_log_probs.view(args.num_steps, args.num_processes, 1)

        advantages = Variable(rollouts.returns[:-1]) - values
        value_loss = advantages.pow(2).mean()

        # Advantages are detached (re-wrapped .data) so the policy loss does not backprop
        # into the value head.
        action_loss = -(Variable(advantages.data) * action_log_probs).mean()

        if args.algo == 'acktr' and optimizer.steps % optimizer.Ts == 0:
            # Sampled fisher, see Martens 2014
            actor_critic.zero_grad()
            pg_fisher_loss = -action_log_probs.mean()

            value_noise = Variable(torch.randn(values.size()))
            if args.cuda:
                value_noise = value_noise.cuda()

            sample_values = values + value_noise
            vf_fisher_loss = -(values - Variable(sample_values.data)).pow(2).mean()

            fisher_loss = pg_fisher_loss + vf_fisher_loss
            optimizer.acc_stats = True
            fisher_loss.backward(retain_graph=True)
            optimizer.acc_stats = False

        optimizer.zero_grad()
        (value_loss * args.value_loss_coef + action_loss - dist_entropy * args.entropy_coef).backward()

        if args.algo == 'a2c':
            nn.utils.clip_grad_norm(actor_critic.parameters(), args.max_grad_norm)

        optimizer.step()

        rollouts.after_update()

        # Checkpoint periodically and after the final update so the last policy is not lost.
        if args.save_dir != "" and (j % args.save_interval == 0 or j == num_updates - 1):
            os.makedirs(save_path, exist_ok=True)

            # Save a CPU copy so the checkpoint loads without CUDA; deep-copy first so the live
            # training model stays on the GPU. The copy (not actor_critic) must be saved.
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            torch.save(save_model.state_dict(), SAVE_PATH)

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                  format(j, total_num_steps,
                         int(total_num_steps / (end - start)),
                         final_rewards.mean(),
                         final_rewards.median(),
                         final_rewards.min(),
                         final_rewards.max(), dist_entropy.data[0],
                         value_loss.data[0], action_loss.data[0]))
        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name, args.algo)
            except IOError:
                pass

main()

#######
#######
Updates 0, num timesteps 80, FPS 523, mean/median reward 0.0/0.0, min/max reward 0.0/0.0, entropy 1.79029, value loss 0.00267, policy loss 0.01246
Updates 100, num timesteps 8080, FPS 971, mean/median reward 0.0/0.0, min/max reward 0.0/0.0, entropy 1.76807, value loss 0.02701, policy loss -0.02641
Updates 200, num timesteps 16080, FPS 968, mean/median reward -1.3/0.0, min/max reward -21.0/0.0, entropy 1.78006, value loss 0.00998, policy loss -0.01206
Updates 300, num timesteps 24080, FPS 963, mean/median reward -9.4/-16.0, min/max reward -21.0/0.0, entropy 1.76831, value loss 0.01005, policy loss 0.05645
Updates 400, num timesteps 32080, FPS 960, mean/median reward -17.9/-18.0, min/max reward -21.0/-13.0, entropy 1.76969, value loss 0.01870, policy loss -0.02066
Updates 500, num timesteps 40080, FPS 960, mean/median reward -17.9/-18.0, min/max reward -21.0/-13.0, entropy 1.75021, value loss 0.01575, policy loss 0.05616
Updates 600, num timesteps 48080, FPS 961, mean/med

Updates 5000, num timesteps 400080, FPS 948, mean/median reward -15.6/-16.0, min/max reward -19.0/-12.0, entropy 1.76890, value loss 0.02039, policy loss 0.03667
Updates 5100, num timesteps 408080, FPS 947, mean/median reward -15.6/-16.0, min/max reward -19.0/-12.0, entropy 1.76513, value loss 0.02309, policy loss 0.05520
Updates 5200, num timesteps 416080, FPS 947, mean/median reward -15.7/-16.0, min/max reward -18.0/-12.0, entropy 1.75162, value loss 0.01138, policy loss 0.05928
Updates 5300, num timesteps 424080, FPS 948, mean/median reward -16.3/-16.0, min/max reward -20.0/-13.0, entropy 1.76595, value loss 0.01237, policy loss 0.04616
Updates 5400, num timesteps 432080, FPS 948, mean/median reward -16.4/-16.0, min/max reward -20.0/-13.0, entropy 1.74363, value loss 0.01202, policy loss -0.01641
Updates 5500, num timesteps 440080, FPS 948, mean/median reward -16.5/-16.0, min/max reward -21.0/-13.0, entropy 1.73912, value loss 0.02820, policy loss 0.03290
Updates 5600, num timesteps

Updates 10100, num timesteps 808080, FPS 950, mean/median reward -15.2/-15.0, min/max reward -19.0/-12.0, entropy 1.75017, value loss 0.02734, policy loss -0.01731
Updates 10200, num timesteps 816080, FPS 950, mean/median reward -15.8/-16.0, min/max reward -19.0/-12.0, entropy 1.78312, value loss 0.00488, policy loss 0.02538
Updates 10300, num timesteps 824080, FPS 950, mean/median reward -15.5/-16.0, min/max reward -19.0/-11.0, entropy 1.76112, value loss 0.02674, policy loss -0.00508
Updates 10400, num timesteps 832080, FPS 950, mean/median reward -15.4/-15.0, min/max reward -19.0/-11.0, entropy 1.72969, value loss 0.01597, policy loss 0.04706
Updates 10500, num timesteps 840080, FPS 950, mean/median reward -15.1/-15.0, min/max reward -19.0/-11.0, entropy 1.77237, value loss 0.00759, policy loss 0.02207
Updates 10600, num timesteps 848080, FPS 950, mean/median reward -14.5/-15.0, min/max reward -18.0/-11.0, entropy 1.76263, value loss 0.02021, policy loss 0.00143
Updates 10700, num t

Updates 15200, num timesteps 1216080, FPS 951, mean/median reward -13.4/-14.0, min/max reward -18.0/-10.0, entropy 1.74420, value loss 0.03319, policy loss 0.01899
Updates 15300, num timesteps 1224080, FPS 951, mean/median reward -14.2/-14.0, min/max reward -17.0/-10.0, entropy 1.78659, value loss 0.01999, policy loss 0.01770
Updates 15400, num timesteps 1232080, FPS 951, mean/median reward -13.5/-14.0, min/max reward -17.0/-6.0, entropy 1.72720, value loss 0.03928, policy loss 0.00455
Updates 15500, num timesteps 1240080, FPS 951, mean/median reward -13.8/-14.0, min/max reward -18.0/-6.0, entropy 1.78256, value loss 0.01599, policy loss 0.00403
Updates 15600, num timesteps 1248080, FPS 951, mean/median reward -13.8/-15.0, min/max reward -18.0/-6.0, entropy 1.77036, value loss 0.00936, policy loss -0.01223
Updates 15700, num timesteps 1256080, FPS 951, mean/median reward -13.6/-14.0, min/max reward -18.0/-6.0, entropy 1.74907, value loss 0.02234, policy loss -0.00887
Updates 15800, num

Updates 20300, num timesteps 1624080, FPS 951, mean/median reward -9.2/-10.0, min/max reward -18.0/-2.0, entropy 1.74636, value loss 0.03698, policy loss 0.09638
Updates 20400, num timesteps 1632080, FPS 951, mean/median reward -9.2/-10.0, min/max reward -18.0/-3.0, entropy 1.75421, value loss 0.02389, policy loss -0.02937
Updates 20500, num timesteps 1640080, FPS 951, mean/median reward -8.5/-9.0, min/max reward -18.0/-2.0, entropy 1.77102, value loss 0.00586, policy loss 0.02259
Updates 20600, num timesteps 1648080, FPS 951, mean/median reward -6.9/-8.0, min/max reward -12.0/-2.0, entropy 1.77270, value loss 0.02521, policy loss -0.00685
Updates 20700, num timesteps 1656080, FPS 951, mean/median reward -6.7/-7.0, min/max reward -12.0/-2.0, entropy 1.74195, value loss 0.00558, policy loss -0.00133
Updates 20800, num timesteps 1664080, FPS 951, mean/median reward -7.6/-8.0, min/max reward -15.0/-2.0, entropy 1.74048, value loss 0.01089, policy loss -0.01343
Updates 20900, num timesteps

Updates 25400, num timesteps 2032080, FPS 951, mean/median reward -2.8/-2.0, min/max reward -11.0/9.0, entropy 1.73753, value loss 0.02376, policy loss 0.03826
Updates 25500, num timesteps 2040080, FPS 951, mean/median reward -3.9/-5.0, min/max reward -13.0/9.0, entropy 1.75275, value loss 0.01684, policy loss -0.02172
Updates 25600, num timesteps 2048080, FPS 951, mean/median reward -6.2/-7.0, min/max reward -13.0/9.0, entropy 1.77328, value loss 0.05090, policy loss -0.10125
Updates 25700, num timesteps 2056080, FPS 951, mean/median reward -6.4/-7.0, min/max reward -13.0/9.0, entropy 1.75442, value loss 0.01173, policy loss -0.00516
Updates 25800, num timesteps 2064080, FPS 951, mean/median reward -7.7/-9.0, min/max reward -13.0/9.0, entropy 1.73971, value loss 0.00581, policy loss 0.02782
Updates 25900, num timesteps 2072080, FPS 951, mean/median reward -8.4/-10.0, min/max reward -14.0/9.0, entropy 1.77311, value loss 0.02909, policy loss -0.07565
Updates 26000, num timesteps 208008

Updates 30600, num timesteps 2448080, FPS 951, mean/median reward -3.7/-4.0, min/max reward -11.0/4.0, entropy 1.73158, value loss 0.03355, policy loss -0.05678
Updates 30700, num timesteps 2456080, FPS 951, mean/median reward -4.3/-6.0, min/max reward -11.0/4.0, entropy 1.75808, value loss 0.01149, policy loss 0.01564
Updates 30800, num timesteps 2464080, FPS 951, mean/median reward -4.1/-4.0, min/max reward -11.0/4.0, entropy 1.75640, value loss 0.02941, policy loss -0.02936
Updates 30900, num timesteps 2472080, FPS 951, mean/median reward -2.3/-3.0, min/max reward -11.0/7.0, entropy 1.76954, value loss 0.01074, policy loss -0.08342
Updates 31000, num timesteps 2480080, FPS 951, mean/median reward 0.7/-1.0, min/max reward -8.0/12.0, entropy 1.78501, value loss 0.01222, policy loss 0.02139
Updates 31100, num timesteps 2488080, FPS 951, mean/median reward 1.3/1.0, min/max reward -6.0/12.0, entropy 1.72220, value loss 0.02326, policy loss 0.00245
Updates 31200, num timesteps 2496080, FP

Updates 35800, num timesteps 2864080, FPS 953, mean/median reward -3.9/-6.0, min/max reward -15.0/9.0, entropy 1.67741, value loss 0.03446, policy loss 0.00441
Updates 35900, num timesteps 2872080, FPS 953, mean/median reward -4.2/-6.0, min/max reward -15.0/7.0, entropy 1.74674, value loss 0.01058, policy loss -0.03855
Updates 36000, num timesteps 2880080, FPS 953, mean/median reward -2.8/-5.0, min/max reward -15.0/7.0, entropy 1.74719, value loss 0.04161, policy loss 0.03043
Updates 36100, num timesteps 2888080, FPS 954, mean/median reward -1.0/-1.0, min/max reward -15.0/11.0, entropy 1.78451, value loss 0.01196, policy loss -0.02384
Updates 36200, num timesteps 2896080, FPS 954, mean/median reward 0.1/3.0, min/max reward -15.0/11.0, entropy 1.72872, value loss 0.01344, policy loss -0.01655
Updates 36300, num timesteps 2904080, FPS 954, mean/median reward 1.0/3.0, min/max reward -15.0/11.0, entropy 1.67733, value loss 0.02535, policy loss 0.03178
Updates 36400, num timesteps 2912080, 

Updates 41000, num timesteps 3280080, FPS 955, mean/median reward 10.1/11.0, min/max reward -2.0/18.0, entropy 1.66653, value loss 0.01354, policy loss 0.05572
Updates 41100, num timesteps 3288080, FPS 955, mean/median reward 11.3/12.0, min/max reward -2.0/18.0, entropy 1.69453, value loss 0.01878, policy loss 0.05319
Updates 41200, num timesteps 3296080, FPS 955, mean/median reward 12.2/12.0, min/max reward 4.0/18.0, entropy 1.72717, value loss 0.02135, policy loss 0.04833
Updates 41300, num timesteps 3304080, FPS 955, mean/median reward 10.1/11.0, min/max reward -8.0/17.0, entropy 1.77991, value loss 0.03691, policy loss -0.12717
Updates 41400, num timesteps 3312080, FPS 955, mean/median reward 8.7/11.0, min/max reward -8.0/15.0, entropy 1.70483, value loss 0.02541, policy loss -0.01995
Updates 41500, num timesteps 3320080, FPS 955, mean/median reward 4.9/5.0, min/max reward -8.0/14.0, entropy 1.77298, value loss 0.01340, policy loss 0.03803
Updates 41600, num timesteps 3328080, FPS 