# SARSA
- Sampling
- On-Policy
- 정책 평가(학습) : TD-Learning(큐함수에 대한 벨만기대방정식)
    
$$q(s, a) \leftarrow q(s, a) + \alpha(r + \gamma q(s', a') - q(s, a))$$<br>

- 정책 발전(행동선택) : $\epsilon$-greedy
$$\pi(s) = \begin{cases} 
    a^* = \arg\max_{a \in A} q(s, a), & \text{확률 } 1 - \epsilon \\
    \text{random action}, & \text{확률 } \epsilon 
\end{cases}$$

## 기억할 점
- TD(Temporal-difference) : online learning
- Bootstrap : 다음 상태의 큐함수 추정값 $q(s', a')$을 이용해 현재 큐함수 $q(s, a)$를 업데이트 (추정값으로 추정값을 갱신)
- 1-step TD인 TD(0) 사용, n-step TD에서 $n \to \infty$ 이면 MC와 같아짐
    - n-step에서 n이 증가하면 **Bias는 감소, Variance는 증가**

In [1]:
import numpy as np
import random
from collections import defaultdict
from environment import Env
import time

In [2]:
class Sarsa_Agent:
    """Tabular SARSA agent with an epsilon-greedy behaviour policy.

    Q-values live in ``q_table``, keyed by whatever hashable state key the
    caller passes; each value is a list holding one Q-value per action index.
    """

    def __init__(self, actions):
        """Create an agent.

        Args:
            actions: sequence of action indices. Exploration samples from it
                with ``np.random.choice``, and its length sets how many
                Q-values are stored per state.
        """
        self.actions = actions
        self.learning_rate = 0.01   # step size (alpha)
        self.discount_factor = 0.9  # discount (gamma)
        self.epsilon = 0.1          # probability of taking a random action

        # defaultdict lazily creates a fresh all-zero row for any state key it
        # has not seen yet. The row is sized from `actions` (not hard-coded to
        # 4) so Q-rows line up with the action space of any environment.
        n_actions = len(actions)
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    # Policy evaluation
    def learn(self, state, action, reward, next_state, next_action):
        """Apply one SARSA (on-policy TD(0)) update to ``q(state, action)``.

        q(s, a) <- q(s, a) + alpha * (r + gamma * q(s', a') - q(s, a))

        NOTE(review): there is no terminal flag, so the update always
        bootstraps from q(s', a'). This relies on terminal-state rows staying
        0.0; confirm with the environment/caller.
        """
        current_q = self.q_table[state][action]
        next_state_q = self.q_table[next_state][next_action]
        td_target = reward + self.discount_factor * next_state_q
        self.q_table[state][action] = (
            current_q + self.learning_rate * (td_target - current_q))

    # Policy improvement (epsilon-greedy)
    def get_action(self, state):
        """Choose an action for ``state`` epsilon-greedily.

        With probability ``epsilon``, return a uniformly random element of
        ``self.actions``. Otherwise return the greedy action under the current
        Q-values, with ties broken at random.
        """
        if np.random.rand() < self.epsilon:
            # Explore: return a random action.
            return np.random.choice(self.actions)
        # Exploit: return the action with the highest Q-value.
        return self.arg_max(self.q_table[state])

    @staticmethod
    def arg_max(state_action):
        """Return an index of the maximum value, breaking ties uniformly at random."""
        max_value = max(state_action)
        max_index_list = [index for index, value in enumerate(state_action)
                          if value == max_value]
        return random.choice(max_index_list)

In [3]:
# Number of training episodes run by the `for episode in range(EPISODES)` loop.
EPISODES = 100

In [4]:
# SARSA : Online Learning

if __name__ == '__main__':
    # At the end of this episode, dump the Q-table to stdout and pause for
    # DUMP_PAUSE_SEC seconds so the values / GUI can be inspected.
    DUMP_EPISODE = 30
    DUMP_PAUSE_SEC = 100
    # NOTE(review): assumes a 5x5 grid world with states keyed as str([x, y]);
    # confirm against Env.
    GRID_SIZE = 5

    env = Env()
    sarsa_agent = Sarsa_Agent(actions=list(range(env.n_actions)))

    for episode in range(EPISODES):
        # Reset the environment and get the initial state.
        state = env.reset()

        # Choose the first action for this episode (epsilon-greedy exploration).
        action = sarsa_agent.get_action(str(state))
        while True:
            env.render()

            # Take the action; receive the next state, reward and done flag.
            next_state, reward, done = env.step(action)

            # On-policy: pick next_action from next_state with the same policy.
            next_action = sarsa_agent.get_action(str(next_state))

            # Update the Q-function with the <s, a, r, s', a'> tuple.
            sarsa_agent.learn(str(state), action, reward, str(next_state), next_action)
            state = next_state
            action = next_action

            # Draw every Q-value on screen.
            env.print_value_all(sarsa_agent.q_table)
            if done:
                if episode == DUMP_EPISODE:
                    # Use .get() rather than [] so the dump does not insert
                    # zero rows for unvisited cells into the defaultdict.
                    empty_row = [0.0] * env.n_actions
                    for y in range(GRID_SIZE):
                        for x in range(GRID_SIZE):
                            print("x : ", x, " y : ", y,
                                  sarsa_agent.q_table.get(str([x, y]), empty_row))
                        print("")
                    time.sleep(DUMP_PAUSE_SEC)
                break
    env.destroy()

x :  0  y :  0 [2.7912842919680606e-10, 0.0, 0.0, 3.0123560069871326e-06]
x :  1  y :  0 [0.0, 2.4862551949245376e-10, 8.17496419217527e-09, 0.00013283272885585064]
x :  2  y :  0 [-0.01791, -2.9701, 0.0, 0.004100423772283158]
x :  3  y :  0 [0.0, 0.08718789530875552, -0.009000000000000001, 0.0]
x :  4  y :  0 [0.0, 7.290000000000002e-07, 0.0, 0.0]

x :  0  y :  1 [0.0, 0.0, 0.0, 0.0]
x :  1  y :  1 [2.1286241743220398e-07, -1.0, 0.0, -1.0]
x :  2  y :  1 [0.0, 0.0, 0.0, 0.0]
x :  3  y :  1 [0.0, 1.7105314189877845, -1.0, 1.5528003518712602e-05]
x :  4  y :  1 [0.0, 0.0, 0.014353090320814582, 0.0]

x :  0  y :  2 [0.0, 0.0, 0.0, -1.0]
x :  1  y :  2 [0.0, 0.0, 0.0, 0.0]
x :  2  y :  2 [0.0, 0.0, 0.0, 0.0]
x :  3  y :  2 [0.0, 0.0, 19.8369410460954, 0.0]
x :  4  y :  2 [0.0, 0.0, 0.11812876850791953, 0.0]

x :  0  y :  3 [0.0, 0.0, 0.0, 0.0]
x :  1  y :  3 [-1.0, 0.0, 0.0, 0.0]
x :  2  y :  3 [1.0, 0.0, 0.0, 0.0]
x :  3  y :  3 [0.0, 0.0, 0.0, 0.0]
x :  4  y :  3 [0.0, 0.0, 0.0, 0.0]

x