# Reinforcement Learning: Q-learning

## 1. The Q-learning algorithm

> Initialize $Q(s,a)$ arbitrarily  
> Repeat (for each episode):  
> &ensp;&ensp; Initialize $s$  
> &ensp;&ensp; Repeat (for each step of episode):  
> &ensp;&ensp;&ensp;&ensp; Choose $a$ from $s$ using a policy derived from $Q$ (e.g., $\epsilon$-greedy)  
> &ensp;&ensp;&ensp;&ensp; Take action $a$, observe $r$, $s'$  
> &ensp;&ensp;&ensp;&ensp; $Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right]$  
> &ensp;&ensp;&ensp;&ensp; $s \leftarrow s'$  
> &ensp;&ensp; until $s$ is terminal


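### 1.1 Notes on the Q-learning algorithm

> In Q-learning, $s$ is the state of the environment the agent is in and $a$ is the action the agent takes. After acting, the agent receives a reward $r$ from the environment and observes the new state $s'$. $\alpha$ is the step size (learning rate). $\gamma$ is the discount factor, with values in $[0, 1)$; it expresses how much weight is given to long-term future reward, and with $\gamma = 0$ the agent only cares about the immediate reward. Actions are chosen with an $\epsilon$-greedy policy derived from $Q$. In this notebook, when $rand() < \epsilon$ the agent picks the action $a$ with the largest $Q(s,a)$, and otherwise it picks an action uniformly at random, so $\epsilon$ here is the probability of acting greedily. (Many references define $\epsilon$ the other way around, as the probability of exploring.)  
> $Q(s,a)$ is the current estimate, and $r + \gamma \max_{a'} Q(s',a')$ plays the role of the "actual" (target) value. In fact the target is itself built from an estimate, $\max_{a'} Q(s',a')$. Quite remarkable!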
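
As a small illustration of the update rule, the sketch below applies one Q-learning step by hand. The function name `q_update` is hypothetical (it is not part of the notebook), and the numbers are assumptions chosen for the example: $\alpha = 0.1$, $\gamma = 0.1$, a reward of 2, and a next state whose Q-values are all still 0.

```python
def q_update(q, reward, max_next_q, alpha, gamma):
    """Return the updated Q(s, a) after one Q-learning step."""
    target = reward + gamma * max_next_q   # bootstrapped target value
    return q + alpha * (target - q)        # move the estimate toward the target

# One update from Q(s, a) = 0 with r = 2 and an all-zero next state (assumed values).
new_q = q_update(q=0.0, reward=2.0, max_next_q=0.0, alpha=0.1, gamma=0.1)
print(new_q)  # 0.2
```

With a small step size the estimate moves only a tenth of the way toward the target on each visit, which is why the learned values grow slowly over episodes.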

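## 2. An agent hunts for treasure with Q-learning
> The agent searches for treasure in a 4x4 two-dimensional grid world. The treasure map is shown below.  

> ![Treasure map](http://localhost:8888/tree/RL/img/寻宝图.png) 

> The top-left cell is the starting point. The red circle marks the agent's position, the yellow square marks the treasure, and the black squares mark the traps.   
> Entering a black square gives a reward of $r=-1$, entering the yellow square gives $r=+2$ (the value returned by `get_feed_back` in the code below), and all other squares give $r=0$. A move that runs into the boundary leaves the agent where it is and is also given $r=-1$.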
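
Because the map image is served from a local notebook server, a plain-text rendering may help. The sketch below prints the grid using the terminal-cell coordinates from the `GridEenviroment` class defined later in the notebook. The helper name `render_text_map` and the symbols `S`, `X`, `T` are assumptions made for illustration.

```python
def render_text_map(n, traps, treasures, start=(0, 0)):
    """Return the grid as a list of strings, one per row."""
    rows = []
    for r in range(n):
        cells = []
        for c in range(n):
            if (r, c) in traps:
                cells.append('X')      # trap (black square)
            elif (r, c) in treasures:
                cells.append('T')      # treasure (yellow square)
            elif (r, c) == start:
                cells.append('S')      # starting cell (top-left)
            else:
                cells.append('.')      # ordinary cell
        rows.append(' '.join(cells))
    return rows

# (row, col) coordinates as used by the environment class in this notebook.
grid = render_text_map(4, traps={(2, 1), (2, 2)}, treasures={(3, 2)})
print('\n'.join(grid))
```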

### 2.1 Import the required packages

```python
import time
import random
import os
import copy
import tkinter as tk
```

### 2.2 Define the 4x4 two-dimensional grid world

```python
class GridEenviroment(object):
    """A 4x4 grid world with two traps and one treasure, drawn with tkinter."""

    def __init__(self):
        self.n = 4  # the grid is n x n
        self.action = ['East', 'South', 'West', 'North']
        # Terminal cells, stored as {'rol': row, 'col': column} dicts.
        self.terminal = {'death':[{'rol':2, 'col':1}, {'rol':2, 'col':2}],
                         'target':[{'rol':3, 'col':2}]}
        self.cav = None     # tkinter canvas, created in init_env()
        self.window = None  # tkinter root window, created in init_env()

    def get_feed_back(self, state, last_state):
        """Return the reward for arriving in `state` from `last_state`."""
        rol = state['rol']
        col = state['col']
        r = 0
        # The agent did not move, i.e. it ran into the boundary.
        if last_state['rol'] == rol and last_state['col'] == col:
            return -1
        for t in self.terminal['death']:
            if rol == t['rol'] and col == t['col']:
                r = -1
        for t in self.terminal['target']:
            if rol == t['rol'] and col == t['col']:
                r = 2
        return r

    def is_done(self, state):
        """Return True if `state` is a trap or the treasure."""
        rol = state['rol']
        col = state['col']
        done = False
        for t in self.terminal['death']:
            if rol == t['rol'] and col == t['col']:
                done = True
                break
        for t in self.terminal['target']:
            if rol == t['rol'] and col == t['col']:
                done = True
                break
        return done

    def get_next_state(self, state, action):
        """Return the state reached by taking `action` in `state`."""
        rol = state['rol']
        col = state['col']
        next_state = {'rol':rol, 'col':col}
        if action == 'East':
            col += 1
        elif action == 'South':
            rol += 1
        elif action == 'West':
            col -= 1
        elif action == 'North':
            rol -= 1
        else:
            pass
        # A move that leaves the grid keeps the agent where it is.
        if rol < 0 or rol >= self.n:
            return next_state
        if col < 0 or col >= self.n:
            return next_state
        next_state['rol'] = rol
        next_state['col'] = col
        return next_state

    def init_env(self):
        """Create the window and the canvas."""
        self.window=tk.Tk()
        self.window.title('Grid world')
        self.window.geometry('300x300')  # width x height

        self.cav=tk.Canvas(self.window,bg='green',height=300,width=300)

    def reset_world(self):
        """Redraw the empty map with the agent at the start cell."""
        self.cav.delete('all')  # delete() without a tag removes nothing
        x0,y0 = 50,50  # top-left corner of the grid, in pixels
        step = 50      # side length of one cell, in pixels

        for rol in range(self.n):
            for col in range(self.n):
                self.cav.create_rectangle(x0+col*step, y0+rol*step, x0+(col+1)*step, y0+(rol+1)*step,fill='gray')

        for t in self.terminal['death']:
            rol = t['rol']
            col = t['col']
            self.cav.create_rectangle(x0+col*step, y0+rol*step, x0+(col+1)*step, y0+(rol+1)*step,fill='black')

        for t in self.terminal['target']:
            rol = t['rol']
            col = t['col']
            self.cav.create_rectangle(x0+col*step, y0+rol*step, x0+(col+1)*step, y0+(rol+1)*step,fill='yellow')

        rol = 0
        col = 0
        self.cav.create_oval(x0+col*step, y0+rol*step, x0+(col+1)*step, y0+(rol+1)*step,fill='red')

        self.cav.pack()

    def render_world(self, state):
        """Redraw the map and draw the agent at `state`."""
        self.reset_world()
        x0,y0 = 50,50
        step = 50

        # Paint over the start-cell marker drawn by reset_world().
        rol = 0
        col = 0
        self.cav.create_rectangle(x0+col*step, y0+rol*step, x0+(col+1)*step, y0+(rol+1)*step,fill='gray')

        rol = state['rol']
        col = state['col']
        self.cav.create_oval(x0+col*step, y0+rol*step, x0+(col+1)*step, y0+(rol+1)*step,fill='red')

        self.cav.pack()
        self.window.update_idletasks()
#         time.sleep(0.1)
```
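
To make the boundary rule in `get_next_state` concrete, here is a minimal stand-alone sketch of the same transition logic without any tkinter dependency. The function name `step`, the `MOVES` table, and the tuple-based state are assumptions for illustration; the notebook itself stores states as dicts with `'rol'` and `'col'` keys.

```python
# Row/column offsets for the four actions used in the notebook.
MOVES = {'East': (0, 1), 'South': (1, 0), 'West': (0, -1), 'North': (-1, 0)}

def step(state, action, n=4):
    """Return the next (row, col); moves that leave the grid keep the agent in place."""
    d_row, d_col = MOVES[action]
    row, col = state[0] + d_row, state[1] + d_col
    if 0 <= row < n and 0 <= col < n:
        return (row, col)
    return state  # bumped into the boundary

print(step((0, 0), 'East'))   # (0, 1)
print(step((0, 0), 'North'))  # (0, 0), blocked by the top edge
```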
        

### 2.3 Define the agent

```python
class Agent(object):
    """Tabular Q-learning agent for the grid world."""

    def __init__(self):
        self.init_state = {'rol':0, 'col':0}
        self.current_state = {'rol':0, 'col':0}
        self.current_action = 'East'
        self.action = ['East', 'South', 'West', 'North']
        self.feed_back = 0
        self.q_value = {}   # maps 'row-col' -> {action: Q-value}
        self.alpha = 0.1    # step size
        self.epsilon = 0.6  # probability of taking the greedy action
        self.gamma = 0      # discount factor

    def reset_state(self):
        """Put the agent back on the start cell."""
        self.current_state = {'rol':0, 'col':0}

    def init_q_value(self, env):
        """Create a Q-table entry of 0 for every state-action pair."""
        for rol in range(env.n):
            for col in range(env.n):
                state = str(rol) + '-' + str(col)
                self.q_value.update({state:{}})
                for a in self.action:
                    self.q_value[state].update({a:0})

    def get_action_of_max_q_value(self, state):
        """Return (greedy action, its Q-value); ties are broken at random."""
        rol = state['rol']
        col = state['col']
        state = str(rol) + '-' + str(col)
        action_q_value = self.q_value[state]
        max_q_value = max(action_q_value.values())
        best_actions = [a for a, q in action_q_value.items() if q == max_q_value]
        max_action = random.choice(best_actions)
        return max_action, max_q_value

    def choose_action(self):
        """Greedy action with probability epsilon, otherwise a random action."""
        max_action, _ = self.get_action_of_max_q_value(self.current_state)
        if random.random() < self.epsilon:
            self.current_action = max_action
        else:
            self.current_action = random.choice(self.action)

    def take_action(self, env):
        """Apply the current action and return (next_state, reward, done)."""
        next_state = env.get_next_state(self.current_state, self.current_action)
        feed_back = env.get_feed_back(next_state, self.current_state)
        # NOTE: `done` is evaluated on the state *before* the move, so an
        # episode ends one step after the agent enters a terminal cell.
        # Pass `next_state` instead to stop on arrival.
        done = env.is_done(self.current_state)
        return next_state, feed_back, done

    def update_q_value(self, next_state,feed_back):
        """Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))."""
        rol = self.current_state['rol']
        col = self.current_state['col']
        current_state = str(rol) + '-' + str(col)
        q = self.q_value[current_state][self.current_action]
        _, max_q = self.get_action_of_max_q_value(next_state)
        q = q + self.alpha * (feed_back + self.gamma * max_q - q)
        self.q_value[current_state][self.current_action] = q
```
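
The action-selection rule used by `choose_action` can be exercised on its own. The sketch below is a stand-alone version under this notebook's convention that `epsilon` is the probability of acting greedily. The function name `epsilon_greedy`, the optional `rng` parameter, and the sample Q-values are made up for illustration.

```python
import random

def epsilon_greedy(action_q_value, epsilon, rng=random):
    """Pick a greedy action with probability epsilon, else a uniformly random one."""
    actions = list(action_q_value)
    if rng.random() < epsilon:
        best = max(action_q_value.values())
        # Break ties between equally good actions at random.
        return rng.choice([a for a in actions if action_q_value[a] == best])
    return rng.choice(actions)

q_row = {'East': 0.0, 'South': 0.5, 'West': -0.1, 'North': 0.0}  # hypothetical values
print(epsilon_greedy(q_row, epsilon=1.0))  # South (always greedy when epsilon is 1)
```

Setting `epsilon=0.0` in this convention gives a purely random walk, which is a quick way to check that the environment terminates episodes correctly before any learning happens.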
          

### 2.4 Treasure hunting with Q-learning

```python
env = GridEenviroment()
agent = Agent()

# Override the defaults set in Agent.__init__.
agent.epsilon = 0.9  # probability of taking the greedy action
agent.gamma = 0.1    # discount factor

agent.init_q_value(env)
env.init_env()

max_episode = 20
for episode in range(max_episode):
    agent.reset_state()
    print('---------------  Episode %d -----------' %  episode)
    while True:
        env.render_world(agent.current_state)
        agent.choose_action()
        next_state, feed_back, done = agent.take_action(env)
        agent.update_q_value(next_state, feed_back)
        agent.current_state = next_state
        if done:
            break

print(agent.q_value)

env.window.mainloop()  # keep the window open after training
```
        
        

```
---------------  Episode 0 -----------
---------------  Episode 1 -----------
---------------  Episode 2 -----------
---------------  Episode 3 -----------
---------------  Episode 4 -----------
---------------  Episode 5 -----------
---------------  Episode 6 -----------
---------------  Episode 7 -----------
---------------  Episode 8 -----------
---------------  Episode 9 -----------
---------------  Episode 10 -----------
---------------  Episode 11 -----------
---------------  Episode 12 -----------
---------------  Episode 13 -----------
---------------  Episode 14 -----------
---------------  Episode 15 -----------
---------------  Episode 16 -----------
---------------  Episode 17 -----------
---------------  Episode 18 -----------
---------------  Episode 19 -----------
{'0-0': {'East': 0.0, 'South': 8.660665053070482e-07, 'West': -0.1, 'North': -0.19}, '0-1': {'East': 0.0, 'South': 0.0, 'West': 0.0, 'North': -0.271}, '0-2': {'East': 0.0, 'South': 0.0, 'West': 0.0, 'North': -0.1}, '0-3': {'East': -0.1, 'South': 0.0, 'West': 0.0, 'North
```
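
The training loop in this section needs a display because of tkinter. As a rough, GUI-free sketch of the same procedure, the block below re-implements the environment rules with plain tuples. Everything here is an assumption layered on the notebook: the names `train`, `step`, and `reward`, the fixed random seed, and the per-episode step cap are not in the original, and unlike the notebook this sketch ends an episode as soon as a terminal cell is reached. Its Q-values will therefore not match the printed output exactly.

```python
import random

N = 4
ACTIONS = {'East': (0, 1), 'South': (1, 0), 'West': (0, -1), 'North': (-1, 0)}
TRAPS = {(2, 1), (2, 2)}
TREASURE = {(3, 2)}

def step(state, action):
    """Move on the grid; a move off the grid leaves the state unchanged."""
    row, col = state[0] + ACTIONS[action][0], state[1] + ACTIONS[action][1]
    return (row, col) if 0 <= row < N and 0 <= col < N else state

def reward(next_state, state):
    """Same rules as get_feed_back: wall bump or trap -1, treasure 2, else 0."""
    if next_state == state or next_state in TRAPS:
        return -1
    return 2 if next_state in TREASURE else 0

def train(episodes=20, alpha=0.1, gamma=0.1, epsilon=0.9, max_steps=1000, seed=0):
    rng = random.Random(seed)  # fixed seed so the sketch is repeatable
    q = {(r, c): {a: 0.0 for a in ACTIONS} for r in range(N) for c in range(N)}
    for _ in range(episodes):
        state = (0, 0)
        for _ in range(max_steps):  # step cap is a safety assumption
            if rng.random() < epsilon:  # epsilon = probability of the greedy action
                best = max(q[state].values())
                action = rng.choice([a for a in ACTIONS if q[state][a] == best])
            else:
                action = rng.choice(list(ACTIONS))
            nxt = step(state, action)
            target = reward(nxt, state) + gamma * max(q[nxt].values())
            q[state][action] += alpha * (target - q[state][action])
            state = nxt
            if state in TRAPS or state in TREASURE:
                break  # stop on arrival at a terminal cell
    return q

q_table = train()
print(q_table[(0, 0)])
```

Keeping the environment as pure functions makes it easy to run many more episodes, or to try other values of `gamma`, without opening a window.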

### 2.5 End of training and the best policy
> The best policy:   
> ![Best treasure-hunting route](http://localhost:8888/tree/RL/img/最优寻宝路线.jpg) 


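> How the agent learns:
> <video id="video" controls="" preload="none" poster="http://localhost:8888/tree/RL/img/寻宝图.png">
    <source id="mp4" src="http://localhost:8888/tree/RL/img/Gridworld.mp4" type="video/mp4">
    <p>How the agent learns</p>
  </video>
> After 10 episodes the agent has learned the best policy for this environment, shown in the figure above. Following this policy, it finds the treasure quickly.   
> The best policy is obtained by taking, in each state, the action with the largest $Q(s,a)$. The learned $Q(s,a)$ values are: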

```python
{
'0-0': {'East': 0.0, 'South': 1.7820526329856689e-07, 'West': -0.271, 'North': -0.271},
'0-1': {'East': 0.0, 'South': 0.0, 'West': 0.0, 'North': -0.19},
'0-2': {'East': 0.0, 'South': 0.0, 'West': 0.0, 'North': -0.1},
'0-3': {'East': -0.1, 'South': 0.0, 'West': 0.0, 'North': -0.1},
'1-0': {'East': 0.0, 'South': 1.7304008947845952e-05, 'West': -0.19, 'North': 0.0},
'1-1': {'East': 0.0, 'South': -0.1, 'West': 0.0, 'North': 0.0},
'1-2': {'East': 0.0, 'South': -0.34390000000000004, 'West': 0.0, 'North': 0.0},
'1-3': {'East': -0.19, 'South': 2.000000000000001e-05, 'West': 0.0, 'North': 0.0},
'2-0': {'East': -0.1, 'South': 0.0011747059721527013, 'West': -0.1, 'North': 0.0},
'2-1': {'East': -0.1, 'South': 0.0, 'West': 0, 'North': 0.0},
'2-2': {'East': 0.0, 'South': 0.2, 'West': -0.19, 'North': 0},
'2-3': {'East': -0.19, 'South': 0.005600200000000001, 'West': -0.1, 'North': 0.0},
'3-0': {'East': 0.050228083114677295, 'South': -0.19, 'West': -0.189944, 'North': 0.0},
'3-1': {'East': 1.2263596317505943, 'South': -0.1, 'West': 0.00029951146164151703, 'North': -0.1},
'3-2': {'East': 0.03328359434399456, 'South': -0.1, 'West': 0, 'North': -0.098},
'3-3': {'East': -0.1, 'South': -0.18619980000000003, 'West': 0.542074002, 'North': 0.0}
}
```
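
Reading the best policy off the table means taking the arg-max action in each state. The sketch below does this for the states along the route from the start cell, using values copied from the table above. The helper names `greedy_action` and `greedy_path` are hypothetical, the treasure cell `'3-2'` is taken from the environment definition, and the sketch assumes the greedy route never leaves the grid (true for these values).

```python
# Q-values for the states on the learned route, copied from the table above.
q_value = {
    '0-0': {'East': 0.0, 'South': 1.7820526329856689e-07, 'West': -0.271, 'North': -0.271},
    '1-0': {'East': 0.0, 'South': 1.7304008947845952e-05, 'West': -0.19, 'North': 0.0},
    '2-0': {'East': -0.1, 'South': 0.0011747059721527013, 'West': -0.1, 'North': 0.0},
    '3-0': {'East': 0.050228083114677295, 'South': -0.19, 'West': -0.189944, 'North': 0.0},
    '3-1': {'East': 1.2263596317505943, 'South': -0.1, 'West': 0.00029951146164151703, 'North': -0.1},
}
MOVES = {'East': (0, 1), 'South': (1, 0), 'West': (0, -1), 'North': (-1, 0)}

def greedy_action(q_value, state):
    """Return the action with the largest Q-value in the given state."""
    return max(q_value[state], key=q_value[state].get)

def greedy_path(q_value, start='0-0', goal='3-2', max_steps=16):
    """Follow greedy actions from start until the goal (or the step limit) is reached."""
    path, state = [], start
    for _ in range(max_steps):
        if state == goal:
            break
        action = greedy_action(q_value, state)
        path.append(action)
        row, col = map(int, state.split('-'))
        d_row, d_col = MOVES[action]
        state = f'{row + d_row}-{col + d_col}'
    return path

print(greedy_path(q_value))  # ['South', 'South', 'South', 'East', 'East']
```

The route goes down the left column and then east along the bottom row, which avoids both trap cells at `2-1` and `2-2`.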