# Algorithm Implementation


**Contents:**
1. Algorithm analysis

2. Agent structure

3. Loss computation

4. Training process

---

## 1) Algorithm Analysis

First, let us compare the algorithms in the 2013 and 2015 papers.

<img src="./imgs/Old DQN.png"  width="550" height="500" align="left" />
<img src="./imgs/Better DQN.png"  width="400" height="500" align="right" />

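An important difference is that the 2015 version introduces a target action-value function $\hat{Q}$ with its own parameters $\theta^-$, and $\theta^-$ is updated much less often than $\theta$ (it is reset to $\theta$ only every $C$ steps). This keeps the training targets more stable and thereby improves training, so we will implement the 2015 algorithm directly.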
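For reference, the 2015 update can be summarized as follows. This is written in this notebook's notation (states as $s$ instead of the paper's preprocessed $\phi$); Algorithm 1 of the 2013 version instead computes the target with $Q(s_{j+1}, a'; \theta)$ itself and has no $\theta^-$:

$$
y_j =
\begin{cases}
r_j & \text{if } s_{j+1} \text{ is terminal} \\
r_j + \gamma \max_{a'} \hat{Q}(s_{j+1}, a'; \theta^-) & \text{otherwise}
\end{cases}
$$

$$
L(\theta) = \left(y_j - Q(s_j, a_j; \theta)\right)^2, \qquad \theta^- \leftarrow \theta \text{ every } C \text{ steps}
$$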

## 2) Agent Structure

### 1) Functional Analysis

First, we should consider what kind of agent we need.

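From the algorithm, we can see that the agent needs to:
  * choose actions, either randomly or non-randomly
  * update the parameters $\theta$ of $Q$
  * update the parameters $\theta^-$ of $\hat{Q}$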
  
Some other parts are related to the agent but do not have to be built into it:
  * data preprocessing
  * the Replay Buffer
  
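In this code we choose to put the **Replay Buffer** inside the agent. This is not required, but it makes the code slightly simpler. With that choice, we can roughly outline the main parts of the agent class: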

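  * Initialization, `agent.__init__(self, *args, **kwargs)`
    * Set hyperparameters such as the learning rate and $\gamma$
    * Initialize the Q-Networks $Q, \hat{Q}$
    * Initialize the Replay Buffer
  * Choosing an action, `agent.act(self, state, eps)`
    * Random choice (with probability $\epsilon$)
    * Non-random (greedy) choice, $\arg\max_a Q(s, a)$
  * Updating network parameters, `agent.learn(self)`
    * Sample a mini-batch from the Replay Buffer
    * Compute the loss
    * Update the parameters $\theta$
    * If needed, update the parameters $\theta^-$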
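The "random or greedy" choice in `agent.act` is $\epsilon$-greedy selection. A minimal sketch, written as a standalone function rather than a method; the name `act` mirrors the outline above, while the `q_network` and `action_size` parameters are assumptions added so the example runs on its own:

```python
import random

import torch
import torch.nn as nn


def act(q_network: nn.Module, state, eps: float, action_size: int) -> int:
    """Epsilon-greedy action selection (a sketch of what agent.act could do).

    With probability eps pick a uniformly random action (exploration);
    otherwise pick argmax_a Q(state, a) (exploitation).
    """
    if random.random() < eps:
        return random.randrange(action_size)
    state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)  # add a batch dim
    with torch.no_grad():  # choosing an action needs no gradients
        q_values = q_network(state_t)  # shape (1, action_size)
    return int(q_values.argmax(dim=1).item())


# Usage with a stand-in network whose Q-values are [0, 5, 1] for input [1.0]
net = nn.Linear(1, 3, bias=False)
with torch.no_grad():
    net.weight.copy_(torch.tensor([[0.0], [5.0], [1.0]]))
print(act(net, [1.0], eps=0.0, action_size=3))  # 1 (greedy: argmax of [0, 5, 1])
```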
    
Below we analyze each part in turn.

### 2) Initialization

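This part is fairly simple: passing the right arguments is enough to initialize the hyperparameters and the networks. The points to consider are:
  * How to make the weights of the two networks identical
    * Method 1: use `state_dict`: `save` the weights of $Q$, then `load` them into $\hat{Q}$
    * Method 2: take the `data` of each layer's parameters, compute the new values, and write them into $\hat{Q}$ with `copy_`
  * How to set up the Replay Buffer
    * Method 1: use `deque` from `collections` directly, purely for storage
    * Method 2: build on Method 1 and wrap it in a class that supports adding and sampling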
    
**Synchronizing network weights**

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Q_Network(nn.Module):

    def __init__(self, state_size, action_size, hidden=[64, 64]):
        super(Q_Network, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden[0])
        self.fc2 = nn.Linear(hidden[0], hidden[1])
        self.fc3 = nn.Linear(hidden[1], action_size)

    def forward(self, state):
        x = state
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

```python
# Method 1: state_dict
net1 = Q_Network(5, 5, [3, 3])
net2 = Q_Network(5, 5, [3, 3])
net1.state_dict()
```

```
OrderedDict([('fc1.weight',
              tensor([[-0.0925,  0.2964,  0.1767, -0.3285, -0.1470],
                      [ 0.3120, -0.4361,  0.1911,  0.3987,  0.4299],
                      [ 0.3245,  0.3840, -0.2999,  0.3200,  0.0381]])),
             ('fc1.bias', tensor([ 0.2561, -0.3150,  0.3420])),
             ('fc2.weight', tensor([[-0.5426, -0.5191,  0.4505],
                      [-0.2780, -0.4256,  0.3189],
                      [ 0.3753,  0.4871,  0.4037]])),
             ('fc2.bias', tensor([ 0.2724, -0.1105, -0.3366])),
             ('fc3.weight', tensor([[ 0.3055,  0.0140, -0.2403],
                      [ 0.3427,  0.4102, -0.0539],
                      [-0.0157, -0.4691, -0.4021],
                      [-0.0481, -0.5255,  0.1780],
                      [-0.1348,  0.2751,  0.1359]])),
             ('fc3.bias',
              tensor([ 0.1009, -0.4235,  0.1802, -0.4351,  0.4830]))])
```

```python
net2.state_dict()
```

```
OrderedDict([('fc1.weight',
              tensor([[-0.0743, -0.0388,  0.2691,  0.2066,  0.3201],
                      [ 0.0948, -0.3126,  0.0798, -0.0250, -0.1170],
                      [-0.1719,  0.2674, -0.0075, -0.1921, -0.4443]])),
             ('fc1.bias', tensor([ 0.0571,  0.2370, -0.2493])),
             ('fc2.weight', tensor([[-0.4953,  0.2223, -0.0297],
                      [-0.5063,  0.0753,  0.3710],
                      [ 0.1346,  0.2840,  0.4755]])),
             ('fc2.bias', tensor([ 0.1267,  0.3081, -0.1986])),
             ('fc3.weight', tensor([[-0.0182, -0.4756,  0.0271],
                      [-0.3871,  0.0410,  0.4056],
                      [ 0.1736, -0.4287,  0.0648],
                      [ 0.4501,  0.1102,  0.4227],
                      [-0.0091,  0.2039,  0.5011]])),
             ('fc3.bias',
              tensor([ 0.2565, -0.2944, -0.3965,  0.3402,  0.2100]))])
```

```python
torch.save(net1.state_dict(), 'sample_weight.pth')
net2.load_state_dict(torch.load('sample_weight.pth'))
net2.state_dict()
```

```
OrderedDict([('fc1.weight',
              tensor([[-0.0925,  0.2964,  0.1767, -0.3285, -0.1470],
                      [ 0.3120, -0.4361,  0.1911,  0.3987,  0.4299],
                      [ 0.3245,  0.3840, -0.2999,  0.3200,  0.0381]])),
             ('fc1.bias', tensor([ 0.2561, -0.3150,  0.3420])),
             ('fc2.weight', tensor([[-0.5426, -0.5191,  0.4505],
                      [-0.2780, -0.4256,  0.3189],
                      [ 0.3753,  0.4871,  0.4037]])),
             ('fc2.bias', tensor([ 0.2724, -0.1105, -0.3366])),
             ('fc3.weight', tensor([[ 0.3055,  0.0140, -0.2403],
                      [ 0.3427,  0.4102, -0.0539],
                      [-0.0157, -0.4691, -0.4021],
                      [-0.0481, -0.5255,  0.1780],
                      [-0.1348,  0.2751,  0.1359]])),
             ('fc3.bias',
              tensor([ 0.1009, -0.4235,  0.1802, -0.4351,  0.4830]))])
```
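The round trip through `sample_weight.pth` is not strictly needed: `load_state_dict` also accepts the dictionary returned by `state_dict()` directly, and `copy.deepcopy` can build $\hat{Q}$ as an exact copy of $Q$. A minimal sketch; `make_net` and `same_weights` are hypothetical helpers, and the small `nn.Sequential` stands in for `Q_Network` so the block runs on its own. It also checks that the copies are independent, i.e. later changes to the local network do not leak into the target:

```python
import copy

import torch
import torch.nn as nn


def make_net():
    # small stand-in for Q_Network(5, 5, [3, 3])
    return nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 5))


def same_weights(a, b):
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))


net_local = make_net()
net_target = make_net()

# Option A: copy the weights in memory, no file needed
net_target.load_state_dict(net_local.state_dict())

# Option B: create the target network as a deep copy of the local one
net_target_b = copy.deepcopy(net_local)

print(same_weights(net_local, net_target), same_weights(net_local, net_target_b))  # True True

# Changing the local network afterwards leaves both copies untouched
with torch.no_grad():
    net_local[0].weight.add_(1.0)
print(same_weights(net_local, net_target), same_weights(net_local, net_target_b))  # False False
```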

```python
# Method 2: use .data and .data.copy_
def soft_update(net1, net2, tau):
    '''
    Update net2's weights in place as a weighted average of net1's and net2's weights:
    net2 new weight = tau * net1 weight + (1 - tau) * net2 weight
    '''
    for target_param, local_param in zip(net2.parameters(), net1.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
```

```python
net1 = Q_Network(5, 5, [3, 3])
net2 = Q_Network(5, 5, [3, 3])
net1.state_dict()
```

```
OrderedDict([('fc1.weight',
              tensor([[-0.2299,  0.1428, -0.1564,  0.1906, -0.3330],
                      [ 0.3900,  0.1510,  0.0532,  0.4311,  0.3918],
                      [ 0.2452,  0.3879, -0.4217, -0.3185,  0.1902]])),
             ('fc1.bias', tensor([0.3736, 0.1483, 0.3985])),
             ('fc2.weight', tensor([[ 0.5117,  0.0028,  0.0388],
                      [ 0.5616,  0.5186, -0.0247],
                      [-0.3029,  0.2964,  0.4281]])),
             ('fc2.bias', tensor([ 0.2233, -0.0437,  0.0111])),
             ('fc3.weight', tensor([[ 0.4599, -0.3969,  0.5516],
                      [-0.2753,  0.3211,  0.1092],
                      [ 0.5005, -0.1821, -0.4430],
                      [ 0.1723,  0.2741, -0.5264],
                      [ 0.2439, -0.5012, -0.5745]])),
             ('fc3.bias',
              tensor([-0.1350,  0.2478,  0.3126,  0.2716, -0.3751]))])
```

```python
net2.state_dict()
```

```
OrderedDict([('fc1.weight',
              tensor([[-0.0014, -0.1165,  0.4389,  0.2044, -0.0413],
                      [ 0.0307, -0.0345, -0.0496,  0.4184,  0.0350],
                      [ 0.3984, -0.2312, -0.3261,  0.4412,  0.1084]])),
             ('fc1.bias', tensor([ 0.1550, -0.0480,  0.3098])),
             ('fc2.weight', tensor([[ 0.2695, -0.5170, -0.5436],
                      [-0.1767, -0.1369, -0.0465],
                      [ 0.4651, -0.3268,  0.5606]])),
             ('fc2.bias', tensor([ 0.2329, -0.1796, -0.1200])),
             ('fc3.weight', tensor([[-0.0162, -0.1983, -0.2579],
                      [-0.1972,  0.2593, -0.3214],
                      [ 0.0624, -0.2514,  0.0549],
                      [-0.4351,  0.4384, -0.0198],
                      [-0.3582, -0.0810, -0.1467]])),
             ('fc3.bias',
              tensor([-0.4678,  0.0820,  0.5584,  0.1007,  0.0040]))])
```

```python
soft_update(net1, net2, 0.5)
net2.state_dict()
```

```
OrderedDict([('fc1.weight',
              tensor([[-0.1157,  0.0132,  0.1412,  0.1975, -0.1871],
                      [ 0.2103,  0.0582,  0.0018,  0.4248,  0.2134],
                      [ 0.3218,  0.0784, -0.3739,  0.0614,  0.1493]])),
             ('fc1.bias', tensor([0.2643, 0.0501, 0.3542])),
             ('fc2.weight', tensor([[ 0.3906, -0.2571, -0.2524],
                      [ 0.1924,  0.1909, -0.0356],
                      [ 0.0811, -0.0152,  0.4943]])),
             ('fc2.bias', tensor([ 0.2281, -0.1116, -0.0544])),
             ('fc3.weight', tensor([[ 0.2218, -0.2976,  0.1468],
                      [-0.2362,  0.2902, -0.1061],
                      [ 0.2814, -0.2167, -0.1941],
                      [-0.1314,  0.3562, -0.2731],
                      [-0.0572, -0.2911, -0.3606]])),
             ('fc3.bias',
              tensor([-0.3014,  0.1649,  0.4355,  0.1861, -0.1856]))])
```

```python
soft_update(net1, net2, 1)
net2.state_dict()
```

```
OrderedDict([('fc1.weight',
              tensor([[-0.2299,  0.1428, -0.1564,  0.1906, -0.3330],
                      [ 0.3900,  0.1510,  0.0532,  0.4311,  0.3918],
                      [ 0.2452,  0.3879, -0.4217, -0.3185,  0.1902]])),
             ('fc1.bias', tensor([0.3736, 0.1483, 0.3985])),
             ('fc2.weight', tensor([[ 0.5117,  0.0028,  0.0388],
                      [ 0.5616,  0.5186, -0.0247],
                      [-0.3029,  0.2964,  0.4281]])),
             ('fc2.bias', tensor([ 0.2233, -0.0437,  0.0111])),
             ('fc3.weight', tensor([[ 0.4599, -0.3969,  0.5516],
                      [-0.2753,  0.3211,  0.1092],
                      [ 0.5005, -0.1821, -0.4430],
                      [ 0.1723,  0.2741, -0.5264],
                      [ 0.2439, -0.5012, -0.5745]])),
             ('fc3.bias',
              tensor([-0.1350,  0.2478,  0.3126,  0.2716, -0.3751]))])
```
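The same update can also be written without touching `.data`, by doing the in-place arithmetic under `torch.no_grad()`. A sketch, assuming plain `nn.Module` networks: `Tensor.lerp_(end, weight)` computes `self + weight * (end - self)`, which equals `tau * local + (1 - tau) * target`. `hard_update` is a hypothetical helper for the 2015-style target update, where $\theta^-$ is simply reset to $\theta$ every $C$ steps; as the `tau = 1` output above shows, `soft_update` with `tau = 1` has the same effect:

```python
import torch
import torch.nn as nn


def soft_update(net_local: nn.Module, net_target: nn.Module, tau: float) -> None:
    """theta_target <- tau * theta_local + (1 - tau) * theta_target, in place."""
    with torch.no_grad():  # in-place ops on parameters must not be tracked by autograd
        for target_param, local_param in zip(net_target.parameters(), net_local.parameters()):
            target_param.lerp_(local_param, tau)


def hard_update(net_local: nn.Module, net_target: nn.Module) -> None:
    """theta_target <- theta_local (the periodic target reset of the 2015 algorithm)."""
    net_target.load_state_dict(net_local.state_dict())


torch.manual_seed(0)
local, target = nn.Linear(3, 2), nn.Linear(3, 2)
w_local, w_target = local.weight.clone(), target.weight.clone()

soft_update(local, target, 0.5)
print(torch.allclose(target.weight, 0.5 * w_local + 0.5 * w_target))  # True
```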

**Replay Buffer**

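Apart from an optional initial iterable, `collections.deque` takes one main parameter, `maxlen`, the maximum number of stored elements.

A `deque` is used much like a `list`, but once it holds `maxlen` elements, each `append` automatically discards the oldest element at the other end. This makes it easy to use and well suited as a Replay Buffer.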

```python
from collections import deque

test = deque(maxlen=3)
for i in range(10):
    test.append(i)
    print(test)
```

```
deque([0], maxlen=3)
deque([0, 1], maxlen=3)
deque([0, 1, 2], maxlen=3)
deque([1, 2, 3], maxlen=3)
deque([2, 3, 4], maxlen=3)
deque([3, 4, 5], maxlen=3)
deque([4, 5, 6], maxlen=3)
deque([5, 6, 7], maxlen=3)
deque([6, 7, 8], maxlen=3)
deque([7, 8, 9], maxlen=3)
```
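Method 2 from the initialization list (a class around the `deque` with add and sample methods) might look like the following. This is a sketch under assumptions: the class name `ReplayBuffer`, the tuple layout `(state, action, reward, next_state, done)` and the returned tensor shapes are our choices, picked so the actions come out as a `(batch_size, 1)` `torch.long` tensor ready for `torch.gather`:

```python
import random
from collections import deque

import numpy as np
import torch


class ReplayBuffer:
    """Fixed-size buffer of (state, action, reward, next_state, done) tuples."""

    def __init__(self, buffer_size: int, batch_size: int):
        self.memory = deque(maxlen=buffer_size)  # oldest transitions are dropped first
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self):
        """Uniformly sample a mini-batch and stack each field into a tensor."""
        batch = random.sample(self.memory, k=self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.as_tensor(np.array(states), dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.long).unsqueeze(1)  # (B, 1) for gather
        rewards = torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(1)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
        dones = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(1)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.memory)


buffer = ReplayBuffer(buffer_size=100, batch_size=4)
for i in range(10):
    buffer.add([i, i], i % 2, 1.0, [i + 1, i + 1], i == 9)
s, a, r, s2, d = buffer.sample()
print(s.shape, a.shape)  # torch.Size([4, 2]) torch.Size([4, 1])
```

Training should only call `sample()` once `len(buffer)` is at least `batch_size`, since `random.sample` raises a `ValueError` otherwise.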


## 3) Loss Computation

### 1) `torch.gather`

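Loss computation deserves special attention because, for an input $s_t$, our network outputs all of the values
$$Q(s_t, a_1), Q(s_t, a_2), ..., Q(s_t, a_n)$$
at once. To conveniently extract the Q-values we need for each sample in a batch, we introduce the function `torch.gather`.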

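The general form is `torch.gather(input, dim, index)`. It extracts values from `input` according to `dim` and `index`. Specifically:
  * `input`: a `torch.Tensor`, the source of the values
  * `dim`: an integer, the dimension whose coordinate gets replaced (explained below)
  * `index`: a `torch.Tensor` of dtype `torch.long`, specifying the coordinates to extract
  * `input` and `index` must have the same number of dimensions; if `input` has size $(d_0, d_1, ..., d_k)$, then in every dimension other than `dim`, `index` may not be larger than `input`
  * `dim` must be a valid dimension, i.e. no larger than $k$
  * Mathematically, the result is a tensor with the same shape as `index` and the same dtype as `input`, whose value at coordinate $[i_0, i_1, ..., i_k]$ is the value of `input` at coordinate $[i_0, ..., i_{dim-1}, index[i_0, ..., i_k], i_{dim+1}, ..., i_k]$
  * In short, the value stored in `index` replaces the coordinate along dimension `dim`, which is why `index` must be of type `torch.long`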
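To make the formula concrete, here is a small sketch comparing `torch.gather` with a direct loop over the definition for `dim=1`, i.e. `out[i][j] = input[i][index[i][j]]`. The tensor `inp` is made up for illustration; note that `index` has shape `(3, 2)`, smaller than `inp` along dimension 1, and the output takes the shape of `index`:

```python
import torch

inp = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # rows [0..3], [4..7], [8..11]
index = torch.tensor([[3, 0], [1, 1], [2, 0]], dtype=torch.long)  # shape (3, 2)

out = torch.gather(inp, 1, index)

# Same result from the definition: for dim=1, out[i][j] = inp[i][index[i][j]]
expected = torch.empty_like(out)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        expected[i, j] = inp[i, index[i, j]]

print(out.tolist())  # [[3.0, 0.0], [5.0, 5.0], [10.0, 8.0]]
```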

```python
# example
test = torch.rand((5, 4))
test
```

```
tensor([[0.7587, 0.2253, 0.1851, 0.1033],
        [0.5584, 0.7686, 0.8005, 0.7403],
        [0.9221, 0.7041, 0.4943, 0.2212],
        [0.7942, 0.6475, 0.0248, 0.5439],
        [0.9903, 0.1416, 0.8074, 0.6079]])
```

```python
index = torch.tensor([[1,2,2,0],
                      [3,1,1,2],
                      [0,0,0,0],
                      [1,1,1,1],
                      [2,2,2,2]], dtype=torch.long)
```

```python
torch.gather(test, 0, index)
```

```
tensor([[0.5584, 0.7041, 0.4943, 0.1033],
        [0.7942, 0.7686, 0.8005, 0.2212],
        [0.7587, 0.2253, 0.1851, 0.1033],
        [0.5584, 0.7686, 0.8005, 0.7403],
        [0.9221, 0.7041, 0.4943, 0.2212]])
```

```python
torch.gather(test, 1, index)
```

```
tensor([[0.2253, 0.1851, 0.1851, 0.7587],
        [0.7403, 0.7686, 0.7686, 0.8005],
        [0.9221, 0.9221, 0.9221, 0.9221],
        [0.6475, 0.6475, 0.6475, 0.6475],
        [0.8074, 0.8074, 0.8074, 0.8074]])
```

### 2) `torch.max`

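`torch.max` is fairly simple. Its main form is `torch.max(input, dim)`, which returns the maximum of each slice along dimension `dim` together with the corresponding indices, as a `(values, indices)` pair. Since the function itself is simple, instead of a standalone example we go straight to two concrete operations:
  * extract the maximum of each row of `test`
  * extract the elements at columns (2, 2, 1, 3, 0) of the corresponding rows of `test`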

```python
# max
values, indices = torch.max(test, dim=1)
values, indices
```

```
(tensor([0.7587, 0.8005, 0.9221, 0.7942, 0.9903]), tensor([0, 2, 0, 0, 0]))
```
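The loss computation later in this section uses `torch.max(..., keepdim=True)`. A small sketch of why that flag matters, assuming rewards are stored as a `(batch, 1)` column, the same shape `torch.gather` returns: without `keepdim`, the `(batch,)` result silently broadcasts against the column into a `(batch, batch)` matrix instead of raising an error. The tensors `q` and `r` are made up for illustration:

```python
import torch

# Q-values of 2 states x 3 actions (values chosen to be exact in float32)
q = torch.tensor([[1.0, 3.0, 2.0],
                  [0.5, 0.25, 0.75]])
r = torch.ones(2, 1)  # rewards as a (batch, 1) column

values, indices = torch.max(q, dim=1)                    # shapes (2,) and (2,)
values_k, indices_k = torch.max(q, dim=1, keepdim=True)  # shapes (2, 1) and (2, 1)

print(values.tolist(), indices.tolist())  # [3.0, 0.75] [1, 2]
print((r + values_k).shape)               # torch.Size([2, 1]), as intended
print((r + values).shape)                 # torch.Size([2, 2]), a silent broadcasting bug
```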

```python
index = torch.tensor([[2], [2], [1], [3], [0]], dtype=torch.long)
torch.gather(test, 1, index)
```

```
tensor([[0.1851],
        [0.8005],
        [0.7041],
        [0.5439],
        [0.9903]])
```

```python
index = torch.tensor([2, 2, 1, 3, 0], dtype=torch.long)
torch.gather(test, 1, index)
```

```
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:638
```

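As the error shows, `index` must have the same number of dimensions as `input`: to pick one element per row of the 2-D `test`, it needs shape `(5, 1)` rather than `(5,)`, which `index.unsqueeze(1)` provides.

With these two functions we can compute the loss. The procedure is:
  * We have the mini-batch `state`, `action`, `next_state`, `reward`, `done`, where `action` is a `torch.long` tensor of shape `(batch_size, 1)`
  * The $Q$ network corresponds to `net_local` and the $\hat{Q}$ network to `net_target`
  * Compute the Q-values `Q_values = net_local(state)`
  * Using `action`, extract the corresponding Q-values `Q_values = torch.gather(Q_values, 1, action)`
  * Compute the maximum target values for `next_state`: `Q_targets, _ = torch.max(net_target(next_state), 1, keepdim=True)`
  * Add the reward and done information, discounting by $\gamma$: `Q_targets = reward + gamma * (1 - done) * Q_targets`
  * `Q_targets` must not carry gradients: compute it under `torch.no_grad()` (or call `.detach()`), so that only $\theta$ is trained
  * Compute the loss `loss = (Q_values - Q_targets).pow(2).mean()`
  * Call `optimizer.zero_grad()`, then `loss.backward()` to compute the gradients, then `optimizer.step()` to update $\theta$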
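The steps above can be collected into one function. A minimal sketch, assuming the mini-batch tensors have the shapes listed in the docstring; `dqn_loss` is a hypothetical helper name, the two `nn.Linear` layers stand in for $Q$ and $\hat{Q}$, and Adam is an arbitrary optimizer choice for illustration (the 2015 paper used RMSProp):

```python
import torch
import torch.nn as nn


def dqn_loss(net_local, net_target, batch, gamma):
    """Mean squared TD error for one mini-batch.

    batch = (state, action, reward, next_state, done) with shapes
    (B, state_size), (B, 1) long, (B, 1), (B, state_size), (B, 1) float.
    """
    state, action, reward, next_state, done = batch
    q_values = torch.gather(net_local(state), 1, action)  # Q(s, a; theta), shape (B, 1)
    with torch.no_grad():  # targets must not carry gradients
        q_next, _ = torch.max(net_target(next_state), 1, keepdim=True)  # (B, 1)
        q_targets = reward + gamma * (1 - done) * q_next  # no bootstrap at terminal states
    return (q_values - q_targets).pow(2).mean()


torch.manual_seed(0)
B, state_size, action_size = 8, 4, 2
net_local = nn.Linear(state_size, action_size)
net_target = nn.Linear(state_size, action_size)
optimizer = torch.optim.Adam(net_local.parameters(), lr=1e-3)

batch = (torch.randn(B, state_size),
         torch.randint(0, action_size, (B, 1)),  # int64 actions
         torch.randn(B, 1),
         torch.randn(B, state_size),
         torch.zeros(B, 1))

loss = dqn_loss(net_local, net_target, batch, gamma=0.99)
optimizer.zero_grad()
loss.backward()   # gradients flow into net_local only
optimizer.step()  # update theta
```

`F.mse_loss(q_values, q_targets)` from `torch.nn.functional` computes the same mean squared error as the explicit `.pow(2).mean()`.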
  
## 4) Training Process

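### 1) Structure Analysis
In the previous notebook we showed how to obtain the average reward of a random policy. We will build on that code and extend it into code that can be used for training.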

```python
# env = gym.make()  # fill in the environment to use
num_episode = 5
max_t = 1000
reward_log = []

for _ in range(num_episode):

    # initialize
    env.reset()
    episodic_reward = 0

    for t in range(max_t):

        # env.render()
        action = env.action_space.sample()  # random action
        # classic gym API (before 0.26): step returns (obs, reward, done, info)
        _, reward, done, _ = env.step(action)
        episodic_reward += reward
        if done:
            break

    reward_log.append(episodic_reward)
```

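The previous code already covers:
  * initializing parameters
  * initializing the environment
  * choosing actions and recording rewards

What we still need to add:
  * initializing and adjusting $\epsilon$ to implement $\epsilon$-greedy
  * the agent choosing actions
  * the agent's training, i.e. updating the parameters of $Q$ and $\hat{Q}$
  * monitoring training progress

We have already analyzed the structure of `agent`. Assuming the design of the `agent` class is complete, the concrete tasks are:
  * initializing and adjusting $\epsilon$ to implement $\epsilon$-greedy
  * inserting `agent.act()` at the right place
  * storing each transition $(s, a, r, s', done)$ in the Replay Buffer
  * inserting `agent.learn()` at the right place
  * monitoring training progress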
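A quick check of how long exploration lasts, assuming $\epsilon$ is multiplied by `eps_decay` once per episode and clipped at `eps_min`, with the values `eps = 1`, `eps_decay = 0.995`, `eps_min = 0.01` used in the code below. $\epsilon$ reaches its floor after $\lceil \ln 0.01 / \ln 0.995 \rceil = 919$ episodes, so with only `num_episode = 5` the agent would still act almost entirely at random:

```python
import math

eps, eps_decay, eps_min = 1.0, 0.995, 0.01

n = 0
while eps > eps_min:
    eps = max(eps_min, eps * eps_decay)  # decay once per episode, never below eps_min
    n += 1

print(n)  # 919
print(math.ceil(math.log(eps_min) / math.log(eps_decay)))  # 919
```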
  
### 2) Implementation

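```python
import numpy as np

# env = gym.make()  # fill in the environment to use (classic gym API, before 0.26)
# agent = ...       # an instance of the agent class designed in Section 2
num_episode = 5
max_t = 1000
reward_log = []
average_log = []  # monitor training process
eps = 1.0
eps_decay = 0.995
eps_min = 0.01
C = 4  # learn and update weights every C steps

for i in range(1, num_episode + 1):

    # initialize
    state = env.reset()
    # preprocessing of state if necessary
    episodic_reward = 0

    for t in range(max_t):

        action = agent.act(state, eps)
        next_state, reward, done, _ = env.step(action)
        # preprocessing of next_state if necessary
        # store the transition in the agent's replay buffer
        # (agent.memory is an assumed attribute name, e.g. the deque from Section 2)
        agent.memory.append((state, action, reward, next_state, done))
        episodic_reward += reward

        if t % C == 0:
            agent.learn()        # assumed to skip until the buffer holds a full mini-batch
            agent.soft_update()  # move theta^- towards theta

        state = next_state.copy()
        if done:
            break

    reward_log.append(episodic_reward)
    eps = max(eps_min, eps * eps_decay)  # decay epsilon once per episode

    # monitor
    average_log.append(np.mean(reward_log[-100:]))
    print('\rEpisode {}, Reward {:.3f}, Average Reward {:.3f}'.format(i, episodic_reward, average_log[-1]), end='')
    if i % 50 == 0:
        print()
```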