# Temporal Difference

## SARSA

$q_t(s, a) = q_{t-1}(s, a) + \frac{1}{N}[r(s^{\prime}) + \gamma \times q(s^{\prime}, a^{\prime}) - q_{t-1}(s, a)]$

In [1]:
import numpy as np
from snake import *

In [13]:
class SARSA(object):
    @staticmethod
    def _sarsa_eval(agent, env, epsilon):
        '''update action-value function'''
        # move on
        state = env.reset()
        prev_act = -1
        prev_state = -1
        while True:
            act = agent.play(state, epsilon)
            state, reward, done, _ = env.step(act) 
            # reward: future reward regarding with action-value function
            # agent.value_sa[state][act]: future action value
            # agent.value_sa[prev_state][prev_act]: previous action value
            
            # update action-value function
            if prev_act != -1:
                future_ret = reward + agent.gamma * (0 if done else agent.value_sa[state][act])
                agent.value_n[prev_state, prev_act] += 1
                N = agent.value_n[prev_state, prev_act]
                agent.value_sa[prev_state, prev_act] += (future_ret - agent.value_sa[prev_state, prev_act]) / N
        
            prev_act = act
            prev_state = state
        
            if done:
                break
        
        
    @staticmethod
    def _policy_improvement(agent):
        '''update policy'''
        new_policy = np.zeros_like(agent.pi) # (action_size, state_size, state_size)
        for s in range(1, agent.state_size):
            new_policy[s] = np.argmax(agent.value_sa[s, :]) # select the max action !!! not [s, a], but [s, :]
        if np.all(np.equal(new_policy, agent.pi)):
            return True # converge
        else:
            agent.pi = new_policy
            return False # not converge
    
    @staticmethod
    def sarsa_opt(agent, env, epsilon=0.0):
        for i in range(10):
            for j in range(2000): # more evaluation iterations than Monte Carlo !!!
                SARSA._sarsa_eval(agent, env, epsilon=epsilon)
            SARSA._policy_improvement(agent)

In [29]:
np.random.seed(3)
env = SnakeEnv(10, [3, 6])
agent = ModelFreeAgent(env)
SARSA.sarsa_opt(agent, env)
print('return:', eval_game(env, agent))
print(agent.pi)

return: 89
[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0]


## Q Learning

$q_t(s, a) = q_{t-1}(s, a) + \frac{1}{N}[r(s^{\prime}) + \gamma \times \max_{a^{\prime}}q(s^{\prime}, a^{\prime}) - q_{t-1}(s, a)]$

In [30]:
class QLearning(object):
    @staticmethod
    def _q_learning_eval(agent, env, epsilon):
        '''update action-value function'''
        # move on
        state = env.reset()
        prev_act = -1
        prev_state = -1
        while True:
            act = agent.play(state, epsilon)
            state, reward, done, _ = env.step(act) 
            # reward: future reward regarding with action-value function
            # agent.value_sa[state][act]: future action value
            # agent.value_sa[prev_state][prev_act]: previous action value
            
            # update action-value function
            if prev_act != -1:
                future_ret = reward + agent.gamma * (0 if done else np.max(agent.value_sa[state, :]))

                agent.value_n[prev_state, prev_act] += 1
                N = agent.value_n[prev_state, prev_act]
                agent.value_sa[prev_state, prev_act] += (future_ret - agent.value_sa[prev_state, prev_act]) / N
        
            prev_act = act
            prev_state = state
        
            if done:
                break
        
        
    @staticmethod
    def _policy_improvement(agent):
        '''update policy'''
        new_policy = np.zeros_like(agent.pi) # (action_size, state_size, state_size)
        for s in range(1, agent.state_size):
            new_policy[s] = np.argmax(agent.value_sa[s, :]) # select the max action !!! not [s, a], but [s, :]
        if np.all(np.equal(new_policy, agent.pi)):
            return True # converge
        else:
            agent.pi = new_policy
            return False # not converge
    
    @staticmethod
    def q_learning_opt(agent, env, epsilon=0.0):
        for i in range(10):
            for j in range(4000): # more evaluation iterations than Monte Carlo !!!
                QLearning._q_learning_eval(agent, env, epsilon=epsilon)
            QLearning._policy_improvement(agent)

In [31]:
np.random.seed(3)
env = SnakeEnv(10, [3, 6])
agent = ModelFreeAgent(env)
QLearning.q_learning_opt(agent, env)
print('return:', eval_game(env, agent))
print(agent.pi)

return: 89
[0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
 1 1 0 1 1 1 1 1 0 0 0 0 1 1 1 0 0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 1 1 0 0 0 1
 0 0 0 0 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0]
