In [4]:
import torch
import gym
import torch.nn as nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import math
import random
import os

In [22]:
STATE_DIM = 4
ACTION_DIM = 2
STEP = 2000
SAMPLE_NUMS = 30

In [19]:
class ActorNetwork(nn.Module):
    
    def __init__(self, input_size, hidden_size, action_size):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)
        
    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = F.log_softmax(self.fc3(out))
        return out
    
class ValueNetwork(nn.Module):
    
    def __init__(self, input_size, hidden_size, output_size):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

In [20]:
def roll_out(actor_network, task, sample_nums, value_network, init_state):
    states = []
    actions = []
    rewards = []
    is_done = False
    final_r = 0
    state = init_state
    
    for j in range(sample_nums):
        states.append(state)
        log_softmax_action = actor_network(Variable(torch.Tensor([state])))
        softmax_action = torch.exp(log_softmax_action)
        action = np.random.choice(ACTION_DIM, p=softmax_action.cpu().data.numpy()[0])
        one_hot_action = [int(k == action) for k in range(ACTION_DIM)]
        next_state, reward, done, _ = task.step(action)
        actions.append(one_hot_action)
        rewards.append(reward)
        final_state = next_state
        state = next_state
        if done:
            is_done = True
            state = task.reset()
            break
    if not is_done:
        final_r = value_network(Variable(torch.Tensor([final_state]))).cpu().data.numpy()
    
    return states, actions, rewards, final_r, state

def discount_reward(r, gamma, final_r):
    discounted_r = np.zeros_like(r)
    running_add = final_r
    for t in reversed(range(0, len(r))):
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r

In [23]:

# init a task generator for data fetching
task = gym.make("CartPole-v0")
init_state = task.reset()

# init network
value_network = ValueNetwork(input_size=STATE_DIM, hidden_size=50, output_size=1)
value_network_optim = torch.optim.Adam(value_network.parameters(), lr=0.01)

actor_network = ActorNetwork(STATE_DIM, 50, ACTION_DIM)
actor_network_optim = torch.optim.Adam(actor_network.parameters(), lr=0.01)

steps = []
task_episodes = []
test_results = []

for step in range(STEP):
    states, actions, rewards, final_r, current_state = roll_out(actor_network, task, SAMPLE_NUMS, value_network, init_state)
    init_state = current_state
    actions_var = Variable(torch.Tensor(actions).view(-1, ACTION_DIM))
    states_var = Variable(torch.Tensor(states).view(-1, STATE_DIM))

    actor_network_optim.zero_grad()
    log_softmax_actions = actor_network(states_var)

    vs = value_network(states_var).detach()
    qs = Variable(torch.Tensor(discount_reward(rewards, 0.99, final_r)))

    advantages = qs - vs
    actor_network_loss = torch.mean(torch.sum(log_softmax_actions * actions_var, 1) * advantages)
    actor_network_loss.backward()
    torch.nn.utils.clip_grad_norm(actor_network.parameters(), 0.5)
    actor_network_optim.step()

    value_network_optim.zero_grad()
    target_values = qs
    values = value_network(states_var)
    criterion = nn.MSELoss()
    value_network_loss = criterion(values, target_values)
    value_network_loss.backward()
    torch.nn.utils.clip_grad_norm(value_network.parameters(), 0.5)
    value_network_optim.step()

    # Testing
    if (step + 1) % 50 == 0:
        result = 0
        test_task = gym.make("CartPole-v0")
        for test_epi in range(10):
            state = test_task.reset()
            for test_step in range(200):
                softmax_action = torch.exp(actor_network(Variable(torch.Tensor([state]))))
                action = np.argmax(softmax_action.data.numpy([0]))
                next_state, reward, done, _ = test_task.step(action)
                result += reward
                state = next_state
                if done:
                    break
        print("step: ", step + 1, "test result: ", result / 10.0)
        steps.append(step + 1)
        test_results.append(result / 10)
            
    


[2017-07-13 21:28:26,924] Making new env: CartPole-v0
[2017-07-13 21:28:27,342] Making new env: CartPole-v0


('step: ', 50, 'test result: ', 9.7)


[2017-07-13 21:28:27,670] Making new env: CartPole-v0


('step: ', 100, 'test result: ', 9.3)


[2017-07-13 21:28:27,977] Making new env: CartPole-v0


('step: ', 150, 'test result: ', 9.6)


[2017-07-13 21:28:28,258] Making new env: CartPole-v0


('step: ', 200, 'test result: ', 9.2)


[2017-07-13 21:28:28,543] Making new env: CartPole-v0


('step: ', 250, 'test result: ', 9.4)


[2017-07-13 21:28:28,816] Making new env: CartPole-v0


('step: ', 300, 'test result: ', 9.4)


[2017-07-13 21:28:29,096] Making new env: CartPole-v0


('step: ', 350, 'test result: ', 9.6)


[2017-07-13 21:28:29,373] Making new env: CartPole-v0


('step: ', 400, 'test result: ', 9.0)


[2017-07-13 21:28:29,652] Making new env: CartPole-v0


('step: ', 450, 'test result: ', 9.6)


[2017-07-13 21:28:29,939] Making new env: CartPole-v0


('step: ', 500, 'test result: ', 9.4)


[2017-07-13 21:28:30,227] Making new env: CartPole-v0


('step: ', 550, 'test result: ', 9.1)


[2017-07-13 21:28:30,513] Making new env: CartPole-v0


('step: ', 600, 'test result: ', 9.9)


[2017-07-13 21:28:30,797] Making new env: CartPole-v0


('step: ', 650, 'test result: ', 8.6)


[2017-07-13 21:28:31,090] Making new env: CartPole-v0


('step: ', 700, 'test result: ', 9.1)


[2017-07-13 21:28:31,389] Making new env: CartPole-v0


('step: ', 750, 'test result: ', 9.1)


[2017-07-13 21:28:31,705] Making new env: CartPole-v0


('step: ', 800, 'test result: ', 9.2)


[2017-07-13 21:28:32,004] Making new env: CartPole-v0


('step: ', 850, 'test result: ', 9.3)


[2017-07-13 21:28:32,299] Making new env: CartPole-v0


('step: ', 900, 'test result: ', 9.1)


[2017-07-13 21:28:32,588] Making new env: CartPole-v0


('step: ', 950, 'test result: ', 9.3)


[2017-07-13 21:28:32,870] Making new env: CartPole-v0


('step: ', 1000, 'test result: ', 9.6)


[2017-07-13 21:28:33,152] Making new env: CartPole-v0


('step: ', 1050, 'test result: ', 9.5)


[2017-07-13 21:28:33,449] Making new env: CartPole-v0


('step: ', 1100, 'test result: ', 9.0)


[2017-07-13 21:28:33,737] Making new env: CartPole-v0


('step: ', 1150, 'test result: ', 9.2)


[2017-07-13 21:28:34,030] Making new env: CartPole-v0


('step: ', 1200, 'test result: ', 9.6)


[2017-07-13 21:28:34,326] Making new env: CartPole-v0


('step: ', 1250, 'test result: ', 9.4)


[2017-07-13 21:28:34,651] Making new env: CartPole-v0


('step: ', 1300, 'test result: ', 9.3)


[2017-07-13 21:28:34,973] Making new env: CartPole-v0


('step: ', 1350, 'test result: ', 9.3)


[2017-07-13 21:28:35,297] Making new env: CartPole-v0


('step: ', 1400, 'test result: ', 9.3)


[2017-07-13 21:28:35,653] Making new env: CartPole-v0


('step: ', 1450, 'test result: ', 9.7)


[2017-07-13 21:28:35,983] Making new env: CartPole-v0


('step: ', 1500, 'test result: ', 9.1)


[2017-07-13 21:28:36,307] Making new env: CartPole-v0


('step: ', 1550, 'test result: ', 9.5)


[2017-07-13 21:28:36,624] Making new env: CartPole-v0


('step: ', 1600, 'test result: ', 9.0)


[2017-07-13 21:28:36,950] Making new env: CartPole-v0


('step: ', 1650, 'test result: ', 9.4)


[2017-07-13 21:28:37,258] Making new env: CartPole-v0


('step: ', 1700, 'test result: ', 9.5)


[2017-07-13 21:28:37,640] Making new env: CartPole-v0


('step: ', 1750, 'test result: ', 9.3)


[2017-07-13 21:28:38,005] Making new env: CartPole-v0


('step: ', 1800, 'test result: ', 9.5)


[2017-07-13 21:28:38,353] Making new env: CartPole-v0


('step: ', 1850, 'test result: ', 9.3)


[2017-07-13 21:28:38,746] Making new env: CartPole-v0


('step: ', 1900, 'test result: ', 9.2)


[2017-07-13 21:28:39,126] Making new env: CartPole-v0


('step: ', 1950, 'test result: ', 9.6)


[2017-07-13 21:28:39,443] Making new env: CartPole-v0


('step: ', 2000, 'test result: ', 9.2)
