# **一个Prioritized Replay DQN算法的简单实现**
## **算法概述**
- Double DQN算法的改进版。
- 除了使用目标网络和策略网络两个网络外，还使用了Prioritized Replay Buffer。
- Prioritized Replay Buffer的核心功能由SumTree提供，SumTree可以存储带有权重数据，并且可以依照权重的大小重新采样这些数据。在实际存储时，会先以相同的权重将一部分样本存储在SumTree种，然后采样一组样本进行训练，计算这组样本中各项的TD error然后以TD error为新权重更新SumTree。
- off-policy算法，value-based算法。

论文链接：*https://docs.popo.netease.com/ofedit/0000cf9bc83b4ce087d1ecc32b8c7090*

In [22]:
import random
import numpy as np

In [23]:
class SumTree:
    writer = 0
    
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(capacity * 2 - 1) # 树中总的节点数
        self.data = np.zeros(capacity, dtype=object) # 存储权重数据的节点
    
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.tree[parent] += change
        
        if parent != 0:
            self._propagate(parent, change)
    
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1
        
        if left >= len(self.tree):
            return idx
        
        if self.tree[left] >= s:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s-self.tree[left])
    @property
    def total(self):
        return self.tree[0]
    
    def add(self, p, data):
        idx = self.capacity - 1 + self.writer
        
        self.data[self.writer] = data
        self.update(idx, p)
        
        self.writer += 1
        if self.writer >= self.capacity:
            self.writer = 0
    
    def update(self, idx, p):
        change = p - self.tree[idx]
        
        self.tree[idx] = p
        self._propagate(idx, change)
    
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1 # idx - number_of_none_leafnode
        
        return (idx, self.tree[idx], self.data[dataIdx])

In [24]:
class Memory:
    e = 0.01
    a = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.
    
    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        
    def _getPriority(self, error):
        return (error + self.e) ** self.a
    
    def add(self, sample):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        # p = self._getPriority(error)
        self.tree.add(max_p, sample)
        
    
    def sample(self, n):
        b_idx, batch, ISWeights = np.empty((n,), dtype=np.int32), [], np.empty((n, 1))
        segment = self.tree.total / n
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])
        
        min_prob = np.max([np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total, 0.00001])
        for i in range(n):
            a = segment * i
            b = segment * (i + 1)
            
            s = random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            prob = p / self.tree.total
            ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
            b_idx[i] = idx
            batch.append(data)
    
        return  batch, b_idx, ISWeights
        
    
    def update(self, idx, abs_err):
        abs_err += 0.1
        clipped_err = np.minimum(abs_err,self.abs_err_upper)
        ps = np.power(clipped_err,self.a)
        for ti, p in zip(idx, ps):
            self.tree.update(ti, p)

In [25]:
import torch
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import copy

In [26]:
GAMMA = 0.9
epsilon = 0.5
start_epsilon = 0.5
end_epsilon = 0.01
replay_size = 10000
batch_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [27]:
class PRDQN(nn.Module):
    def __init__(self, env):
        super(PRDQN, self).__init__()
        self.prioritized_replay_buffer = Memory(replay_size)
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.create_training_network()
        self.create_training_method()
        self.replay_total = 0
        self.to(device)
        
    def forward(self, state):
        x = F.relu(self.fc1(state))
        self.state_action_value = self.fc2(x)
        return self.state_action_value
    
    def create_training_network(self):
        self.fc1 = nn.Linear(self.state_dim, 20)
        self.fc2 = nn.Linear(20, self.action_dim)
        
    def create_training_method(self):
        self.optimizer = optim.Adam(self.parameters(), lr=0.0001)
        self.loss_cal = F.mse_loss
        
    def get_target_network(self, target_network):
        self.target_network = target_network
        
    def train_loop(self):
        # self.eval()
        minibatch, tree_idx, ISWeights = self.prioritized_replay_buffer.sample(batch_size)
        state = [data[0] for data in minibatch] 
        action = [data[1] for data in minibatch]
        reward = [data[2] for data in minibatch]
        next_state = [data[3] for data in minibatch]
        done = [data[4] for data in minibatch]
        
        # Q(S,A) = Q(S,A) + alpha*(R+gamma*Qmax(S',a) - Q(S,A)) 
        with torch.no_grad():
            next_state = torch.tensor(np.stack(next_state), device=device)
            action_max_Q = torch.argmax(self(next_state), dim=1).unsqueeze(-1) # 得到策略网络输出Q值最大的动作
            Q_max_value_batch = self.target_network(next_state).gather(1, action_max_Q).squeeze(-1) # 从目标网络中获取该动作对应的Q值        action = torch.tensor(action_batch, device=device).unsqueeze(-1) 
        action = torch.tensor(action, device=device).unsqueeze(-1)
        reward = torch.tensor(reward, device=device) # 
        done = torch.tensor(done, device=device)
        
        state = torch.tensor(np.stack(state), device=device)
        y_batch = torch.where(done,reward,reward + GAMMA * Q_max_value_batch) # 计算目标Q值
        Q_batch = self(state).gather(1, action).squeeze(-1)
        y_batch_with_W = y_batch * torch.tensor(ISWeights, device=device).squeeze(-1)
        Q_batch_with_W = Q_batch * torch.tensor(ISWeights, device=device).squeeze(-1)
        
        loss = self.loss_cal(Q_batch_with_W, y_batch_with_W)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        abs_err = torch.abs(y_batch-Q_batch).detach().cpu().numpy()
        self.prioritized_replay_buffer.update(tree_idx,abs_err)
        
    def epsilon_greedy(self, state):
        with torch.no_grad():
            state = torch.from_numpy(state).to(device)
            if random.random() > epsilon:
                state_action_value = self(state)
                action = torch.argmax(state_action_value).item()
            else:
                action = np.random.randint(0, 2)
            # epsilon = epsilon - (start_epsilon - end_epsilon) / 10000
            return action
    
    def action(self, state):
        with torch.no_grad():
            state = torch.from_numpy(state).to(device)
            state_action_value = self(state)
            return torch.argmax(state_action_value).item()
    
    def perceive(self, state, action, reward, next_state, done):
        self.prioritized_replay_buffer.add((state, action, reward, next_state, done))
        self.replay_total += 1
        if self.replay_total > batch_size:
            self.train_loop()
            
    def update_target_network(self):
        self.target_network.load_state_dict(self.state_dict(),strict=False)

In [28]:
import gym
env_name = "CartPole-v1"
env = gym.make(env_name)
agent = PRDQN(env)
target_network = PRDQN(env)
agent.get_target_network(target_network)
agent.train()

PRDQN(
  (fc1): Linear(in_features=4, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=2, bias=True)
  (target_network): PRDQN(
    (fc1): Linear(in_features=4, out_features=20, bias=True)
    (fc2): Linear(in_features=20, out_features=2, bias=True)
  )
)

In [29]:
def main():
    for episode in range(3000):
        state, _ = env.reset()
        for step in range(300):
            action = agent.epsilon_greedy(state)
            next_state, reward, done, _, _ = env.step(action)
            reward = -1 if done else 0.01
            agent.perceive(state,action,reward,next_state,done)
            if done:
                break
            state = next_state
            
        if episode % 100 == 0:
            total_reward = 0
            for i in range(10):
                state, _ = env.reset()
                for step in range(300):
                    action = agent.action(state)
                    next_state, reward, done, _, _ = env.step(action)
                    total_reward += reward
                    if done:
                        break
                    state = next_state
            total_reward /= 10
            print(f"average reward is {total_reward}")
            agent.update_target_network()

In [30]:
if __name__ == "__main__":
    main()

average reward is 12.4
average reward is 9.3
average reward is 13.1
average reward is 14.4
average reward is 147.7
average reward is 209.0
average reward is 210.9
average reward is 187.1
average reward is 186.3
average reward is 203.1
average reward is 215.7
average reward is 196.5
average reward is 203.3
average reward is 189.8
average reward is 209.7
average reward is 204.4
average reward is 232.4
average reward is 213.5


KeyboardInterrupt: 