# Expected Sarsa

Similar to Q_learning, but instead of using the maximum over the next state-action pairs, it takes into account the probability of each action 

In [1]:
import numpy as np
import gym

In [2]:
env = gym.make('FrozenLake8x8-v0')

In [3]:
def epsilon_greedy_probs(Q,state,epsilon):
    #returns probability of each action for a given state
    probs = np.zeros((Q.shape[1]))+epsilon/Q.shape[1]
    best_action = Q[state,:].argmax()
    probs[best_action] += 1 - epsilon
    return probs

def choose_action(pi_probs,n_actions):
    #returns action based on policy
    return np.random.choice(range(n_actions),p=pi_probs)

In [4]:
def expected_sarsa(env,n_episodes,alpha,gamma,epsilon):
    
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    #initialize action-value function
    Q = np.zeros((n_states,n_actions)) + 0.5
    Q[n_states-1,:] = np.zeros(n_actions)
    
    for n in range(n_episodes):
        
        next_state = env.reset()
        done = False
        
        #generate episode
        while not done:
            
            state = next_state
            
            #choose action based on policy
            action = choose_action(epsilon_greedy_probs(Q,state,epsilon),n_actions)
            
            next_state, reward, done, info = env.step(action)
            
            #action probabilities for the update
            probs = epsilon_greedy_probs(Q,next_state,epsilon)
            
            #TD update of action-value function
            Q[state,action] += alpha * (reward + gamma * np.sum(probs[a]*Q[next_state,a] for a in range(n_actions)) - Q[state,action])

    return Q

In [5]:
Q = expected_sarsa(env,2000,0.8,0.9,0.1)

In [6]:
Q

array([[0.08066885, 0.14535178, 0.11825139, 0.0802569 ],
       [0.09663394, 0.08613346, 0.1434457 , 0.08920343],
       [0.09472482, 0.10026158, 0.0959848 , 0.09822764],
       [0.09892794, 0.09373709, 0.11474356, 0.09608107],
       [0.10215154, 0.10320402, 0.09618632, 0.1003031 ],
       [0.09917179, 0.10511685, 0.08573673, 0.106244  ],
       [0.13455549, 0.08882224, 0.08750541, 0.08906736],
       [0.08980127, 0.0891297 , 0.08651857, 0.10817246],
       [0.09296721, 0.14445293, 0.09020543, 0.09364567],
       [0.10404896, 0.15139764, 0.10323301, 0.13322339],
       [0.10603294, 0.10989229, 0.11980438, 0.10274475],
       [0.1111543 , 0.11915072, 0.11932586, 0.13216059],
       [0.1771494 , 0.11768624, 0.11348556, 0.10374275],
       [0.11473333, 0.11435696, 0.11142432, 0.11399737],
       [0.12103909, 0.28247331, 0.20988722, 0.12122316],
       [0.09484069, 0.21142175, 0.0941869 , 0.0984563 ],
       [0.12529295, 0.10827381, 0.21397794, 0.12414673],
       [0.15959288, 0.27412251,