In [1]:
import random
import gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils

ModuleNotFoundError: No module named 'gym'

In [2]:
pip install gym

Collecting gym
  Using cached gym-0.26.2.tar.gz (721 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting gym-notices>=0.0.4
  Using cached gym_notices-0.0.8-py3-none-any.whl (3.0 kB)
Building wheels for collected packages: gym
  Building wheel for gym (pyproject.toml) ... done
  Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827632 sha256=c4af52fc60531303037bcc00c0732e51be432efd987623878c74e82ced9d2c35
  Stored in directory: /Users/jezeehu/Library/Caches/pip/wheels/af/2b/30/5e78b8b9599f2a2286a582b8da80594f654bf0e18d825a4405
Successfully built gym
Installing collected packages: gym-notices, gym
Successfully installed gym-0.26.2 gym-notices-0.0.8
Note: you may need to restart the kernel to use updated packages.


In [3]:
import random
import gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils

In [4]:
class ReplayBuffer:
    '''Fixed-capacity experience replay buffer.

    Stores up to ``capacity`` transitions in a FIFO deque; once full, adding a
    new transition silently evicts the oldest one.
    '''
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)  # FIFO: oldest entries drop off when full

    def add(self, state, action, reward, next_state, done):
        '''Append one (state, action, reward, next_state, done) transition.'''
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        '''Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones): ``states`` and
            ``next_states`` are stacked with ``np.array``; ``actions``,
            ``rewards`` and ``dones`` are tuples with one entry per sample.

        Raises:
            ValueError: if ``batch_size`` exceeds ``size()`` (from ``random.sample``).
        '''
        transitions = random.sample(self.buffer, batch_size)
        # zip(*...) transposes a list of 5-tuples into five per-field tuples.
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):
        '''Return the number of transitions currently stored.'''
        return len(self.buffer)

In [5]:
class Qnet(torch.nn.Module):
    '''Q-network with a single hidden layer.

    Maps a batch of state vectors to one Q-value per discrete action:
    Linear(state_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> action_dim).
    '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # The attribute names fc1/fc2 become state_dict keys; keep them stable
        # so load_state_dict between Qnet instances keeps working.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        '''Return the per-action Q-values for input ``x``.'''
        hidden = F.relu(self.fc1(x))  # ReLU activation on the hidden layer
        q_values = self.fc2(hidden)
        return q_values

In [8]:
class DQN:
    '''Deep Q-Network agent with an epsilon-greedy policy and a target network.

    Args:
        state_dim: size of the state vector fed to the Q-network.
        hidden_dim: width of the Q-network's hidden layer.
        action_dim: number of discrete actions.
        learning_rate: Adam learning rate for the online Q-network.
        gamma: discount factor used in the TD target.
        epsilon: probability that ``take_action`` picks a uniformly random action.
        target_update: the target network is overwritten with the online
            network's weights every ``target_update`` calls to ``update``
            (including the very first call, since the counter starts at 0).
        device: torch device that the networks and batches are placed on.
    '''
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,
                 epsilon, target_update, device):
        self.action_dim = action_dim
        self.q_net = Qnet(state_dim, hidden_dim, self.action_dim).to(device)  # online Q-network
        self.target_q_net = Qnet(state_dim, hidden_dim, self.action_dim).to(device)  # target network

        # Only the online network is optimized; the target is synced by copying weights.
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)

        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # exploration rate of the epsilon-greedy policy
        self.target_update = target_update  # target-network sync period, in update() calls
        self.count = 0  # number of update() calls performed so far
        self.device = device

    def take_action(self, state):
        '''Choose an action for a single ``state`` with the epsilon-greedy policy.

        With probability ``epsilon`` returns a uniformly random action index in
        ``[0, action_dim)``; otherwise returns the index of the largest Q-value
        predicted by the online network. Returns a Python int.
        '''
        if np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)
        # Build the batch-of-one via np.array first: torch.tensor on a list of
        # ndarrays is slow and emits a UserWarning.
        state_t = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        with torch.no_grad():  # pure inference; no autograd graph needed
            return self.q_net(state_t).argmax().item()

    def update(self, transition_dict):
        '''Perform one gradient step on a batch of transitions.

        ``transition_dict`` must provide the keys 'states', 'actions',
        'rewards', 'next_states' and 'dones', each with one entry per
        transition. Minimizes the MSE between Q(s, a) and the TD target
        r + gamma * max_a' Q_target(s', a') * (1 - done), then syncs the
        target network every ``target_update`` calls.
        '''
        states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)
        # gather() requires int64 indices, so make the dtype explicit.
        actions = torch.tensor(transition_dict['actions'], dtype=torch.long).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)  # Q(s, a) for the actions actually taken
        # The TD target is a fixed regression label: don't track gradients
        # through the target network.
        with torch.no_grad():
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)  # max_a' Q_target(s', a')
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # no bootstrap on terminal steps
        dqn_loss = F.mse_loss(q_values, q_targets)  # default reduction already averages over the batch
        self.optimizer.zero_grad()  # PyTorch accumulates gradients by default; clear them first
        dqn_loss.backward()
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())  # sync target network
        self.count += 1