# **一个Prioritized Replay DQN算法的简单实现**
## **算法概述**
- 普通DQN算法的改进版
- 使用了两个网络，目标网络和策略网络
    - 目标网络用于提供策略网络需要学习的Q值目标(Qmax(S',a))，策略网络则用于学习更新
    - 两个网络的结构完全相同，目标网络参数固定，每隔X步将策略网络的参数更新到目标网络
- 传统DQN使用单个网络同时进行价值评估和网络参数更新，这会导致两者之间的相关性过强，不利于算法收敛；使用两个网络可以解耦这种关联性
- off-policy算法，value-based算法

论文链接：*https://docs.popo.netease.com/docs/cbb9344e20ca49578e9d107a26c2ba17*

In [8]:
import random
import numpy as np

In [9]:
class SumTree:
    """Binary sum tree stored in a flat array, used for proportional sampling.

    The last ``capacity`` entries of ``tree`` are leaves holding item
    priorities; every internal node holds the sum of its two children, so
    ``tree[0]`` is the total priority. Items are written in ring order:
    once ``capacity`` items have been added, the next ``add`` overwrites
    the oldest slot.
    """

    def __init__(self, capacity):
        """Create an empty tree.

        Args:
            capacity: Maximum number of stored items; must be at least 1.

        Raises:
            ValueError: If ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        self.writer = 0  # data slot that the next add() will fill
        self.tree = np.zeros(capacity * 2 - 1)  # internal nodes followed by leaves
        self.data = np.zeros(capacity, dtype=object)  # payload stored per leaf

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx`` up to the root."""
        # Iterative walk; when idx is already the root (capacity == 1) there
        # is no ancestor and nothing to do.
        while idx > 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Return the tree index of the leaf selected by cumulative value ``s``.

        The search starts at node ``idx``. At each internal node it goes
        left when ``s`` fits inside the left subtree's sum, otherwise it
        subtracts that sum and goes right.
        """
        size = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= size:
                return idx
            right = left + 1
            # A right subtree whose sum is exactly 0 holds only unwritten
            # leaves; stay left so rounding in ``s`` cannot select one.
            if s <= self.tree[left] or self.tree[right] == 0:
                idx = left
            else:
                s -= self.tree[left]
                idx = right

    def total(self):
        """Return the sum of all leaf priorities (the root value)."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the next ring slot."""
        idx = self.capacity - 1 + self.writer

        self.data[self.writer] = data
        self.update(idx, p)

        # Wrap around so the oldest entry is overwritten next.
        self.writer += 1
        if self.writer >= self.capacity:
            self.writer = 0

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` to ``p`` and fix ancestor sums."""
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Look up the leaf selected by cumulative value ``s``.

        Returns:
            A tuple ``(tree_index, priority, data)``.
        """
        idx = self._retrieve(0, s)
        # The first ``capacity - 1`` tree nodes are internal (non-leaf) nodes.
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])

In [10]:
class Memery:
    """Prioritized replay memory that stores samples in a ``SumTree``.

    The priority of a sample is ``(|error| + e) ** a``.
    """

    e = 0.01  # added to |error| so that no stored priority is zero
    a = 0.6  # exponent applied to (|error| + e)

    def __init__(self, capacity):
        """Create a memory holding at most ``capacity`` samples."""
        self.tree = SumTree(capacity)

    def _getPriority(self, error):
        """Convert an error into a priority.

        The magnitude of ``error`` is used: raising a negative base to the
        fractional power ``a`` would give a complex number or NaN.
        """
        return (abs(error) + self.e) ** self.a

    def add(self, error, sample):
        """Store ``sample`` with the priority derived from ``error``."""
        p = self._getPriority(error)
        self.tree.add(p, sample)

    def sample(self, n, return_index=False):
        """Draw ``n`` samples, one from each of ``n`` equal priority segments.

        Args:
            n: Number of samples to draw.
            return_index: When true, each entry also carries the tree index
                that ``update`` expects.

        Returns:
            A list of ``(priority, data)`` tuples, or
            ``(tree_index, priority, data)`` tuples when ``return_index``
            is true.
        """
        batch = []
        segment = self.tree.total() / n

        for i in range(n):
            low = segment * i
            high = segment * (i + 1)

            s = random.uniform(low, high)
            idx, p, data = self.tree.get(s)
            batch.append((idx, p, data) if return_index else (p, data))

        return batch

    def update(self, idx, error):
        """Recompute the priority of tree node ``idx`` from a new ``error``."""
        p = self._getPriority(error)
        self.tree.update(idx, p)

In [11]:
import torch
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import copy

In [12]:
# Discount factor applied to the next-state Q-value in PRDQN.train_loop.
GAMMA = 0.9
# Threshold compared against random.random() in PRDQN.epsilon_greedy to
# choose between a greedy and a random action.
epsilon = 0.5
# start_epsilon / end_epsilon are only referenced by a commented-out decay
# line in PRDQN.epsilon_greedy, so epsilon is not changed by that method.
start_epsilon = 0.5
end_epsilon = 0.01
# Maximum number of transitions kept by PRDQN.perceive; also the capacity
# passed to Memery in PRDQN.create_sumtree_memery.
replay_size = 10000
# Number of transitions sampled per PRDQN.train_loop call.
batch_size = 32
# Use CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
class PRDQN(nn.Module):
    """Two-layer Q-network together with the replay and training logic of a DQN agent.

    Transitions are kept in ``replay_buffer`` and sampled uniformly by
    ``train_loop``. ``create_sumtree_memery`` builds a prioritized
    ``Memery``, but ``train_loop`` does not read from it.
    """

    def __init__(self, env):
        """Build the network and optimizer for ``env`` and move them to ``device``.

        Args:
            env: Environment exposing ``observation_space.shape`` and
                ``action_space.n``.
        """
        super(PRDQN, self).__init__()
        self.replay_buffer = deque()
        # No target network until get_target_network() is called.
        self.target_network = None
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.create_training_network()
        self.create_training_method()
        self.to(device)

    def forward(self, state):
        """Return the Q-values of every action for ``state``.

        The result is also kept on ``self.state_action_value``.
        """
        x = F.relu(self.fc1(state))
        self.state_action_value = self.fc2(x)
        return self.state_action_value

    def create_training_network(self):
        """Create the two linear layers (state_dim -> 20 -> action_dim)."""
        self.fc1 = nn.Linear(self.state_dim, 20)
        self.fc2 = nn.Linear(20, self.action_dim)

    def create_training_method(self):
        """Create the Adam optimizer and select MSE as the loss."""
        self.optimizer = optim.Adam(self.parameters(), lr=0.0001)
        self.loss_cal = F.mse_loss

    def get_target_network(self, target_network):
        """Attach the network used to compute bootstrapped Q targets.

        NOTE(review): if ``target_network`` is an ``nn.Module``, assigning it
        as an attribute registers it as a submodule, so its parameters appear
        in ``self.state_dict()`` under the ``target_network.`` prefix.
        """
        self.target_network = target_network

    def train_loop(self):
        """Run one optimization step on a uniformly sampled minibatch.

        Implements Q(S,A) <- Q(S,A) + alpha * (R + gamma * max_a Q(S',a) - Q(S,A)),
        with the max taken from the target network. When no target network
        has been attached, the online network itself provides the target.
        """
        minibatch = random.sample(self.replay_buffer, batch_size)
        state_batch = [data[0] for data in minibatch]
        action_batch = [data[1] for data in minibatch]
        reward_batch = [data[2] for data in minibatch]
        next_state_batch = [data[3] for data in minibatch]
        done_batch = [data[4] for data in minibatch]

        target_net = self.target_network if self.target_network is not None else self

        with torch.no_grad():
            next_states = torch.tensor(np.stack(next_state_batch), device=device)
            # Largest Q-value of the next state according to the target network.
            Q_max_value_batch = torch.max(target_net(next_states), dim=1)[0]
        action = torch.tensor(action_batch, device=device).unsqueeze(-1)
        reward = torch.tensor(reward_batch, device=device)
        done = torch.tensor(done_batch, device=device)
        # Finished transitions use the reward alone; the others add the
        # discounted max Q-value of the next state.
        y_batch = torch.where(done, reward, reward + GAMMA * Q_max_value_batch)
        states = torch.tensor(np.stack(state_batch), device=device)
        # Q-value of the action that was actually taken in each transition.
        Q_batch = torch.gather(self(states), 1, action).squeeze(-1)
        loss = self.loss_cal(Q_batch, y_batch)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def epsilon_greedy(self, state):
        """Pick the greedy action, or a uniformly random one.

        A random action is taken when ``random.random() <= epsilon``; the
        module-level ``epsilon`` is read as is and not decayed here.
        """
        with torch.no_grad():
            state = torch.from_numpy(state).to(device)
            if random.random() > epsilon:
                state_action_value = self(state)
                action = torch.argmax(state_action_value).item()
            else:
                # Draw from the environment's real action count rather than
                # a hard-coded 2.
                action = np.random.randint(0, self.action_dim)
            return action

    def action(self, state):
        """Return the greedy action for ``state``."""
        with torch.no_grad():
            state = torch.from_numpy(state).to(device)
            state_action_value = self(state)
            return torch.argmax(state_action_value).item()

    def perceive(self, state, action, reward, next_state, done):
        """Store one transition and train once enough transitions exist.

        The oldest transition is dropped when the buffer grows beyond
        ``replay_size``; training runs whenever the buffer holds more than
        ``batch_size`` transitions.
        """
        self.replay_buffer.append((state, action, reward, next_state, done))
        if len(self.replay_buffer) > replay_size:
            self.replay_buffer.popleft()

        if len(self.replay_buffer) > batch_size:
            self.train_loop()

    def create_sumtree_memery(self):
        """Create the prioritized replay memory with ``replay_size`` slots."""
        self.memery = Memery(replay_size)

    def update_target_network(self):
        """Copy this network's parameters into the attached target network.

        ``strict=False`` lets the load ignore keys of ``self.state_dict()``
        that the target network does not have.
        """
        self.target_network.load_state_dict(self.state_dict(), strict=False)

In [14]:
import gym
# Build the CartPole environment and an agent sized from its spaces.
env_name = "CartPole-v1"
env = gym.make(env_name)
agent = PRDQN(env)
# PRDQN defines no train() of its own, so this resolves to nn.Module.train():
# it switches the module to training mode and runs no optimization step.
# NOTE(review): get_target_network is not called in this cell, so no target
# network is attached to the agent here.
agent.train()

In [16]:
def main():
    """Run the interaction loop between the module-level ``agent`` and ``env``.

    Runs 3000 episodes of at most 300 steps each. Every transition is handed
    to ``agent.perceive``, and an episode stops as soon as the environment
    reports termination or truncation.
    """
    for episode in range(3000):
        state, _ = env.reset()
        for step in range(300):
            action = agent.epsilon_greedy(state)
            next_state, reward, terminated, truncated, _ = env.step(action)

            # Only true termination is passed on as "done"; truncation merely
            # ends the episode.
            agent.perceive(state, action, reward, next_state, terminated)

            # Stepping an environment after it has finished is not valid.
            if terminated or truncated:
                break

            state = next_state

In [None]:
# Start the interaction loop only when this file is executed as a script.
if __name__ == "__main__":
    main()