In [101]:
import argparse
import gym
import numpy as np
from itertools import count
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable


class args:
    def __init__(self):
        self.seed = 1
        self.gamma = .99
        self.render = False
        self.log_interval = 50

args = args()
        
env = gym.make('CartPole-v0')
env.seed(args.seed)
torch.manual_seed(args.seed)

<torch._C.Generator at 0x7f8df00adf60>

In [102]:
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.action_head = nn.Linear(128, 2)
        self.value_head = nn.Linear(128, 1)

        self.actions = []
        self.action_log_probs = []
        self.entropies = []
        self.state_values = []
        self.rewards = []
        

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.action_head(x)
        state_values = self.value_head(x)
        return F.softmax(action_scores), state_values


model = Policy()
optimizer = optim.Adam(model.parameters(), lr=3e-2)

def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    
    probs, state_value = model(Variable(state))
    action = probs.multinomial().data[0][0]
    
    prob = probs[:, action].view(1, -1)
    action_log_prob = prob.log()
    entropy = - (probs*probs.log()).sum()
    
    return action, action_log_prob, entropy, state_value


In [103]:

running_reward = 10
for i_episode in count(1):
    state = env.reset()
    for t in range(1000): # Don't infinite loop while learning
        action, action_log_prob, entropy, state_value = select_action(state)
        #print(action)
        state, reward, done, _ = env.step(action)
        #print(state, reward)

        model.rewards.append(reward)
        model.actions.append(action)
        model.action_log_probs.append(action_log_prob)
        model.entropies.append(entropy)
        model.state_values.append(state_value)
        
        if done:
            break

    running_reward = running_reward * 0.99 + t * 0.01
    
    #########################
    # Finish episode
    R = 0
    value_loss = 0
    rewards = []
    for r in model.rewards[::-1]:
        R = r + args.gamma * R
        rewards.insert(0, R)
        
        loss = loss - (model.action_log_probs[i]\
                       *(Variable(R).expand_as(log_probs[i])).cuda()).sum()\
            - (0.0001*entropies[i].cuda()).sum()
    
    global advantages, value_loss, action_gain
    
    advantages = Variable(torch.FloatTensor(rewards), requires_grad=True)
    - Variable(torch.FloatTensor([i.data[0][0] for i in model.state_values]), requires_grad=True)
    
    value_loss = advantages.pow(2).mean()
    action_gain = (advantages *\
                   Variable(torch.FloatTensor([i.data[0][0] for i in model.action_log_probs]), requires_grad=True)).mean()
    
    total_loss = value_loss\
    - action_gain - Variable(torch.FloatTensor([i.data[0] for i in model.entropies]), requires_grad=True).mean()
    optimizer.zero_grad()
    
    total_loss.backward()
    
    optimizer.step()
    del model.rewards[:]
    del model.actions[:]
    del model.action_log_probs[:]
    del model.entropies[:]
    del model.state_values[:]
    ########################################
    
    if i_episode % args.log_interval == 0:
        print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
            i_episode, t, running_reward))
    if running_reward > env.spec.reward_threshold:
        print("Solved! Running reward is now {} and "
              "the last episode runs to {} time steps!".format(running_reward, t))
        break

  global advantages, value_loss, action_gain


Episode 50	Last length:    17	Average length: 15.17
Episode 100	Last length:    19	Average length: 17.16
Episode 150	Last length:    17	Average length: 19.06
Episode 200	Last length:    13	Average length: 19.94
Episode 250	Last length:    16	Average length: 20.62
Episode 300	Last length:    36	Average length: 20.64
Episode 350	Last length:    15	Average length: 19.83
Episode 400	Last length:    28	Average length: 20.71
Episode 450	Last length:    26	Average length: 20.71
Episode 500	Last length:    12	Average length: 20.69
Episode 550	Last length:    39	Average length: 20.65
Episode 600	Last length:    19	Average length: 20.95
Episode 650	Last length:    16	Average length: 21.16
Episode 700	Last length:    46	Average length: 21.23
Episode 750	Last length:    27	Average length: 21.91
Episode 800	Last length:    28	Average length: 21.77
Episode 850	Last length:    12	Average length: 21.56
Episode 900	Last length:    19	Average length: 21.87
Episode 950	Last length:    21	Average length: 

Episode 7700	Last length:    11	Average length: 21.31
Episode 7750	Last length:    15	Average length: 22.07
Episode 7800	Last length:    42	Average length: 22.32
Episode 7850	Last length:    31	Average length: 24.42
Episode 7900	Last length:    22	Average length: 24.67
Episode 7950	Last length:    15	Average length: 22.80
Episode 8000	Last length:    13	Average length: 22.21
Episode 8050	Last length:    13	Average length: 21.58
Episode 8100	Last length:    36	Average length: 21.33
Episode 8150	Last length:    25	Average length: 20.28
Episode 8200	Last length:    12	Average length: 21.76
Episode 8250	Last length:    16	Average length: 22.70
Episode 8300	Last length:    25	Average length: 21.95
Episode 8350	Last length:    12	Average length: 21.28
Episode 8400	Last length:    14	Average length: 22.12
Episode 8450	Last length:    23	Average length: 22.78
Episode 8500	Last length:    35	Average length: 22.67
Episode 8550	Last length:    11	Average length: 22.86
Episode 8600	Last length:   

Episode 15200	Last length:    19	Average length: 21.77
Episode 15250	Last length:    13	Average length: 21.55
Episode 15300	Last length:    18	Average length: 22.72
Episode 15350	Last length:    20	Average length: 23.29
Episode 15400	Last length:    20	Average length: 22.26
Episode 15450	Last length:    78	Average length: 21.28
Episode 15500	Last length:    16	Average length: 21.79
Episode 15550	Last length:    37	Average length: 22.56
Episode 15600	Last length:    52	Average length: 21.34
Episode 15650	Last length:    24	Average length: 20.88
Episode 15700	Last length:    10	Average length: 21.57
Episode 15750	Last length:    29	Average length: 20.84
Episode 15800	Last length:    12	Average length: 20.79
Episode 15850	Last length:    30	Average length: 21.54
Episode 15900	Last length:    47	Average length: 22.26


KeyboardInterrupt: 

In [93]:
advantages

Variable containing:
 22.1621
 21.3444
 20.5303
 19.7040
 18.8690
 18.1379
 17.3436
 16.5237
 15.6649
 14.7808
 13.9262
 13.0271
 12.1357
 11.2294
 10.3136
  9.4931
  8.5530
  7.6574
  6.6987
  5.7394
  4.8496
  3.8581
  2.8703
  1.8782
  0.8645
[torch.FloatTensor of size 25]