In [1]:
import gym
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from itertools import count

In [2]:
env = gym.make('CartPole-v0')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
#循环写入的数组
class ReplayMemory(object):
    def __init__(self, capacity):
        # cap : 10000
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, trans):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = trans
        #loop [0 : cap)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        #从A中随机抽取B个. B > A 时抛异常
        return random.sample(self.memory, batch_size) 
    
    def __len__(self):
        return len(self.memory)


memory = ReplayMemory(10000)

In [4]:
HIDDEN_LAYER = 256  # NN hidden layer size
LR = 0.001


class Network(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.l1 = nn.Linear(4, HIDDEN_LAYER)
        self.l2 = nn.Linear(HIDDEN_LAYER, 2)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = self.l2(x)
        return x
    

model = Network().to(device)
optimizer = optim.Adam(model.parameters(), LR)

In [5]:
EPS_START = 0.9  # e-greedy threshold start value
EPS_END = 0.05  # e-greedy threshold end value
EPS_DECAY = 200  # e-greedy threshold decay
steps_done = 0

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return model(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)

In [6]:
episode_durations = []
PLOT_MEAN = 10

def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.FloatTensor(episode_durations)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # take 100 episode averages and plot them too
    
    if len(durations_t) >= PLOT_MEAN:
        means = durations_t.unfold(0, PLOT_MEAN, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(PLOT_MEAN-1), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated

In [7]:
def learn():
    BATCH_SIZE = 128  # Q-learning batch size
    GAMMA = 0.7
    ALPHA = 0.7
    
    if len(memory) < BATCH_SIZE:
        return

    trans = memory.sample(BATCH_SIZE)
    batch = list(zip(*trans))
    
    batch_state = torch.cat(batch[0])
    batch_action = torch.cat(batch[1])
    batch_nstate = torch.cat(batch[2])
    batch_reward = torch.cat(batch[3])
  
    current_q_values = model(batch_state).gather(1, batch_action)
    max_next_q_values = model(batch_nstate).detach().max(1)[0]
    
    expected_q_values = batch_reward + (max_next_q_values * GAMMA)
    expected_q_values = expected_q_values.unsqueeze(1)

    # loss is measured from error between current and newly expected Q values
    loss = F.smooth_l1_loss(current_q_values, expected_q_values)

    # backpropagation of loss to NN
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [8]:
EPISODES = 50  # number of episodes

def run_episode():
    plt.ion()
    
    for e in range(EPISODES):
        state = env.reset()
        state = torch.tensor([state], device=device, dtype=torch.float)

        for t in count(): 
            #env.render()
            action = select_action(state)
            next_state, reward, done, _ = env.step(action.item())
            reward = -200 if done else reward

            next_state = torch.tensor([next_state], device=device, dtype=torch.float)
            reward = torch.tensor([reward], device=device, dtype=torch.float)
            
            memory.push((state, action, next_state, reward))
            state = next_state

            learn()

            if done:
                print("{2} Episode {0} finished after {1} steps"
                      .format(e, t, '\033[92m' if t >= 195 else '\033[0m'))
                episode_durations.append(t)
                plot_durations()
                break
                
    print('Complete')
    #env.render()
    env.close()
    plt.ioff()
    plt.show()

In [9]:
%matplotlib
run_episode()
%matplotlib inline

Using matplotlib backend: TkAgg
[0m Episode 0 finished after 13 steps
[0m Episode 1 finished after 11 steps
[0m Episode 2 finished after 9 steps
[0m Episode 3 finished after 11 steps
[0m Episode 4 finished after 11 steps
[0m Episode 5 finished after 13 steps
[0m Episode 6 finished after 13 steps
[0m Episode 7 finished after 9 steps
[0m Episode 8 finished after 11 steps
[0m Episode 9 finished after 10 steps
[0m Episode 10 finished after 13 steps
[0m Episode 11 finished after 11 steps
[0m Episode 12 finished after 8 steps
[0m Episode 13 finished after 9 steps
[0m Episode 14 finished after 10 steps
[0m Episode 15 finished after 10 steps
[0m Episode 16 finished after 9 steps
[0m Episode 17 finished after 9 steps
[0m Episode 18 finished after 10 steps
[0m Episode 19 finished after 9 steps
[0m Episode 20 finished after 8 steps
[0m Episode 21 finished after 9 steps
[0m Episode 22 finished after 9 steps
[0m Episode 23 finished after 13 steps
[0m Episode 24 finished aft

In [10]:
from itertools import count
rewards = []
for trails in range(100):
    state = env.reset()
    reward = 0
    for t in count():
        env.render()
        with torch.no_grad():
              action = model(torch.tensor([state], device=device, dtype=torch.float)).max(1)[1].view(1, 1)
        state, r, done, _ = env.step(action.item())
        reward += r
        if done:
            rewards.append(reward)
            print(reward)
            break

env.close()
print("rewards:", np.average(rewards))


NameError: name 'FloatTensor' is not defined