In [None]:
# import argparse
import gym
import numpy as np
from itertools import count

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable


# parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
# parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
#                     help='discount factor (default: 0.99)')
# parser.add_argument('--seed', type=int, default=543, metavar='N',
#                     help='random seed (default: 543)')
# parser.add_argument('--render', action='store_true',
#                     help='render the environment')
# parser.add_argument('--log_interval', type=int, default=10, metavar='N',
#                     help='interval between training status logs (default: 10)')
# args = parser.parse_args()

# Hyper-parameters, hard-coded here in place of the commented-out argparse
# flags above (same default values).
gamma = 0.99  # discount factor applied to future rewards
seed = 543  # random seed passed to both the environment and torch
log_interval = 10  # print a status line every this many episodes

render_opt = False  # when True, the training loop calls env.render() each step

env = gym.make('CartPole-v0')
# Seed the environment and torch before any network is constructed or any
# action is sampled.
env.seed(seed)
torch.manual_seed(seed)


class Policy(nn.Module):
    """Two-layer policy network: 4 input features -> 128 hidden units -> 2 outputs.

    Besides its parameters, the module carries two plain Python lists that
    act as per-episode buffers filled by code outside the class.
    """

    def __init__(self):
        super().__init__()
        # Episode buffers: sampled actions and the rewards received for them.
        self.saved_actions = []
        self.rewards = []

        # Layers are created in the same order as before (affine1, then
        # affine2) so seeded weight initialisation is unaffected.
        self.affine1 = nn.Linear(4, 128)
        self.affine2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return the softmax of the output layer applied to ``x``.

        ``x`` passes through ``affine1``, a ReLU, and ``affine2``; softmax is
        called without an explicit dimension, exactly as in the legacy API.
        """
        hidden = self.affine1(x)
        hidden = F.relu(hidden)
        return F.softmax(self.affine2(hidden))


# Single module-level policy network and its Adam optimizer (learning rate
# 1e-2). select_action and finish_episode use these globals directly.
policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)


def select_action(state):
    """Sample an action for ``state`` from the module-level ``policy``.

    ``state`` is a numpy array; it is converted to a float tensor, given a
    leading dimension of size 1 and wrapped in a ``Variable`` before being
    fed to the network. The sampled action is appended to
    ``policy.saved_actions`` and its underlying ``.data`` tensor is returned.
    """
    observation = torch.from_numpy(state).float().unsqueeze(0)
    action_probs = policy(Variable(observation))
    # Draw from the predicted probabilities (legacy no-argument multinomial).
    sampled = action_probs.multinomial()
    policy.saved_actions.append(sampled)
    return sampled.data


def finish_episode():
    """Run one REINFORCE update from the episode buffered on ``policy``.

    Reads ``policy.rewards`` and ``policy.saved_actions``, converts the
    per-step rewards into discounted returns using the module-level
    ``gamma``, standardises them, attaches one return to each saved action
    via ``reinforce`` and takes a single ``optimizer`` step. Both buffers on
    ``policy`` are emptied before returning.
    """
    # Walk the rewards back-to-front, appending each discounted return and
    # flipping the list once at the end. This is O(n); inserting at the head
    # of the list on every step is O(n^2).
    running_return = 0
    returns = []
    for r in reversed(policy.rewards):
        running_return = r + gamma * running_return
        returns.append(running_return)
    returns.reverse()

    returns = torch.Tensor(returns)
    centered = returns - returns.mean()
    if len(policy.rewards) > 1:
        # eps keeps the division finite when all returns are equal.
        returns = centered / (returns.std() + np.finfo(np.float32).eps)
    else:
        # The sample standard deviation of a single value is NaN; dividing by
        # it would feed NaN into the gradients and corrupt the weights and
        # optimizer state. Only centre the return in that case.
        returns = centered

    for action, r in zip(policy.saved_actions, returns):
        action.reinforce(r)
    optimizer.zero_grad()
    # Stochastic nodes carry their own reinforcement signal, so no explicit
    # gradient is supplied for them.
    autograd.backward(policy.saved_actions, [None for _ in policy.saved_actions])
    optimizer.step()

    # Clear the buffers in place so the same list objects are reused.
    del policy.rewards[:]
    del policy.saved_actions[:]


# Exponentially weighted moving average of episode length, seeded at 10.
running_reward = 10
for i_episode in count(1):
    state = env.reset()
    for t in range(10000):  # Don't infinite loop while learning
        action = select_action(state)
        state, reward, done, _ = env.step(action[0, 0])
        if render_opt:
            env.render()
        policy.rewards.append(reward)
        if done:
            break

    # t is the zero-based index of the final step, so the episode lasted
    # t + 1 steps.
    episode_length = t + 1
    running_reward = running_reward * 0.99 + episode_length * 0.01
    finish_episode()
    if i_episode % log_interval == 0:
        print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
            i_episode, episode_length, running_reward))
    # Stop at the environment's own "solved" threshold. The previous fixed
    # limit of 200 could never be exceeded by an average of episode lengths
    # that are themselves capped by the environment, so training never ended.
    if running_reward > env.spec.reward_threshold:
        print("Solved! Running reward is now {} and "
              "the last episode runs to {} time steps!".format(running_reward, episode_length))
        break

[2017-07-14 06:35:27,021] Making new env: CartPole-v0


Episode 10	Last length:    13	Average length: 10.64
Episode 20	Last length:    24	Average length: 11.37
Episode 30	Last length:   115	Average length: 15.63
Episode 40	Last length:    17	Average length: 19.16
Episode 50	Last length:    77	Average length: 22.33
Episode 60	Last length:    52	Average length: 24.56
Episode 70	Last length:    67	Average length: 28.63
Episode 80	Last length:   199	Average length: 40.23
Episode 90	Last length:   167	Average length: 54.04
Episode 100	Last length:   184	Average length: 63.67
Episode 110	Last length:   199	Average length: 73.99
Episode 120	Last length:    11	Average length: 70.85
Episode 130	Last length:     8	Average length: 64.98
Episode 140	Last length:    10	Average length: 59.67
Episode 150	Last length:    12	Average length: 55.06
Episode 160	Last length:    13	Average length: 51.02
Episode 170	Last length:   199	Average length: 58.26
Episode 180	Last length:   199	Average length: 71.72
Episode 190	Last length:   199	Average length: 83.89
Ep

Episode 1530	Last length:   199	Average length: 171.08
Episode 1540	Last length:   199	Average length: 173.75
Episode 1550	Last length:   199	Average length: 176.16
Episode 1560	Last length:   199	Average length: 178.35
Episode 1570	Last length:   199	Average length: 180.32
Episode 1580	Last length:   199	Average length: 182.11
Episode 1590	Last length:   199	Average length: 183.72
Episode 1600	Last length:   199	Average length: 185.18
Episode 1610	Last length:   199	Average length: 186.50
Episode 1620	Last length:   199	Average length: 187.70
Episode 1630	Last length:   199	Average length: 185.27
Episode 1640	Last length:   199	Average length: 179.46
Episode 1650	Last length:   199	Average length: 177.78
Episode 1660	Last length:   199	Average length: 179.81
Episode 1670	Last length:   199	Average length: 181.64
Episode 1680	Last length:   199	Average length: 183.30
Episode 1690	Last length:   199	Average length: 184.80
Episode 1700	Last length:   199	Average length: 186.16
Episode 17

KeyboardInterrupt: 

[2017-07-14 06:39:11,630] Uncaught exception, closing connection.
Traceback (most recent call last):
  File "/users/aditya.a/Libraries/miniconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/users/aditya.a/Libraries/miniconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/users/aditya.a/Libraries/miniconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/users/aditya.a/Libraries/miniconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/users/aditya.a/Libraries/miniconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 421, in execute_request
    self._abort_queues()
  File "/users/aditya.a/Libraries/miniconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line