# In [3]:
import torch                                   
import torch.nn as nn                          
import torch.nn.functional as F               
import numpy as np                             
import gym                                     
from collections import deque

# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Hyperparameters
BATCH_SIZE = 32                 # number of transitions sampled per learning step
LR = 0.01                       # learning rate passed to the Adam optimizer
EPSILON = 0.9                   # probability of picking the greedy action (random action otherwise)
GAMMA = 0.9                     # discount factor applied to the next-state value in the TD target
TARGET_REPLACE_ITER = 100       # learning steps between copies of eval-net weights into the target net
MEMORY_CAPACITY = 2000          # number of rows (transitions) in the replay memory
# `.unwrapped` uses the underlying environment rather than the wrapped one
# (presumably to drop the episode step limit -- confirm for the installed gym version).
env = gym.make('CartPole-v0').unwrapped
N_ACTIONS = env.action_space.n                  # number of discrete actions
N_STATES = env.observation_space.shape[0]       # length of the observation vector

# Q-network definition
class Net(nn.Module):
    """Fully connected network mapping an observation to one value per action.

    Architecture: ``N_STATES -> 50 -> 50 -> N_ACTIONS`` with ReLU after each
    hidden layer and a linear output layer. Every layer's weight matrix is
    re-drawn from a normal distribution (mean 0, std 0.1); biases keep the
    default ``nn.Linear`` initialisation.
    """

    @staticmethod
    def _make_layer(n_in, n_out):
        """Create a linear layer and re-draw its weights from N(0, 0.1)."""
        layer = nn.Linear(n_in, n_out)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def __init__(self):
        # A subclass of nn.Module must run the parent constructor first.
        super().__init__()

        hidden_units = 50
        # Layers are built one after another (create, then re-initialise) so
        # the random number stream is consumed in a fixed order.
        self.fc1 = self._make_layer(N_STATES, hidden_units)
        self.fc2 = self._make_layer(hidden_units, hidden_units)
        self.out = self._make_layer(hidden_units, N_ACTIONS)

    def forward(self, x):
        """Return the output-layer activations for input ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.out(hidden)

# DQN agent (holds the evaluation network and the target network)
class DQN(object):
    """Deep Q-learning agent with a target network and a ring-buffer replay memory.

    Attributes:
        eval_net: network being trained and used to pick actions.
        target_net: network used to compute TD targets; its weights are
            overwritten with ``eval_net``'s every ``TARGET_REPLACE_ITER``
            learning steps.
        learn_step_counter: number of completed learning steps.
        memory_counter: total number of transitions stored so far.
        memory: array of shape ``(MEMORY_CAPACITY, 2 * N_STATES + 2)``; each
            row is ``[s, a, r, s_]`` flattened.
    """

    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()
        self.learn_step_counter = 0
        self.memory_counter = 0
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()

    def choose_action(self, x):
        """Pick an action for observation ``x``.

        With probability ``EPSILON`` the action with the largest ``eval_net``
        output is returned; otherwise an action is drawn uniformly from
        ``[0, N_ACTIONS)``.

        Args:
            x: a single observation (anything ``torch.FloatTensor`` accepts).

        Returns:
            The chosen action index (a NumPy integer scalar for the greedy
            branch, the result of ``np.random.randint`` otherwise).
        """
        x = torch.unsqueeze(torch.FloatTensor(x), 0)  # add a batch dimension
        if np.random.uniform() < EPSILON:
            # Action selection needs no gradients, so skip autograd bookkeeping.
            # Call the module itself rather than .forward() so hooks are honoured.
            with torch.no_grad():
                actions_value = self.eval_net(x)
            action = torch.max(actions_value, 1)[1].numpy()[0]
        else:
            action = np.random.randint(0, N_ACTIONS)
        return action

    def store_transition(self, s, a, r, s_):
        """Write the transition ``(s, a, r, s_)`` into the replay memory.

        Once the memory is full, the oldest rows are overwritten.
        """
        transition = np.hstack((s, [a, r], s_))
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        """Run one optimisation step of ``eval_net`` on a sampled mini-batch.

        Does nothing when no transition has been stored yet. Only rows that
        have actually been written are sampled, so calling this before the
        memory is full no longer trains on all-zero placeholder rows.
        """
        stored = min(self.memory_counter, MEMORY_CAPACITY)
        if stored == 0:
            return

        # Periodically sync the target network with the evaluation network.
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # Sample a mini-batch (with replacement) from the filled part of memory.
        sample_index = np.random.choice(stored, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))
        b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])
        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

        # Value of the action actually taken, according to the evaluation net.
        q_eval = self.eval_net(b_s).gather(1, b_a)
        # The target is treated as a constant: no gradient flows into target_net.
        with torch.no_grad():
            q_next = self.target_net(b_s_)
        # NOTE(review): transitions carry no terminal flag, so the target
        # bootstraps from s_ for every sample, including episode-ending ones.
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)
        loss = self.loss_func(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# Training driver: run episodes until the recent average shaped return is high
# enough (after at least 500 episodes) or 40000 episodes have been played.
dqn = DQN()
score = deque(maxlen=50)  # shaped returns of the most recent 50 episodes
for i in range(40000):
    s = env.reset()
    episode_reward_sum = 0
    done = False
    while not done:
        # env.render() can be called here to watch the agent.
        a = dqn.choose_action(s)
        s_, r, done, info = env.step(a)

        # Replace the environment reward with a shaped one: it grows as the
        # cart stays near the centre and the pole stays near vertical. This is
        # optional and only meant to speed up training.
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        new_r = r1 + r2

        dqn.store_transition(s, a, new_r, s_)
        episode_reward_sum += new_r
        s = s_

        # Start learning only once the replay memory has been filled.
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()

    # Episode finished: record its return and report the running mean.
    # NOTE(review): the label says "mean100" but the window holds 50 episodes.
    score.append(episode_reward_sum)
    progress = f'\r==> episode: {i}---reward_mean100: {round(np.mean(score), 2)}'
    print(progress, end="")
    if i % 100 == 0:
        # Re-print with a newline so this line is kept in the output.
        print(progress)

    if i > 500 and np.mean(score) > 550:
        torch.save(dqn.eval_net, 'dqn_eval_v3.pth')
        break
print('------FINISHED------')

# Recorded output of the cell above:
# ==> episode: 0---reward_mean100: 2.87
# ==> episode: 100---reward_mean100: 2.47
# ==> episode: 200---reward_mean100: 2.58
# ==> episode: 300---reward_mean100: 350.95
# ==> episode: 400---reward_mean100: 403.62
# ==> episode: 500---reward_mean100: 524.56
# ==> episode: 518---reward_mean100: 555.81
# ------FINISHED------
