# Expected Sarsa

In [1]:
import numpy as np
import gym

In [3]:
env = gym.make('FrozenLake8x8-v0')

In [82]:
def epsilon_greedy_probs(Q,state,epsilon):
    #returns probability of each action for a given state
    probs = np.zeros((Q.shape[1]))+epsilon/Q.shape[1]
    best_action = Q[state,:].argmax()
    probs[best_action] += 1 - epsilon
    return probs

def choose_action(pi_probs,n_actions):
    #returns action based on policy
    return np.random.choice(range(n_actions),p=pi_probs)

In [105]:
def expected_sarsa(env,n_episodes,alpha,gamma,epsilon):
    
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    #initialize action-value function
    Q = np.zeros((n_states,n_actions)) + 0.5
    Q[n_states-1,:] = np.zeros(n_actions)
    
    for n in range(n_episodes):
        
        next_state = env.reset()
        done = False
        
        #generate episode
        while not done:
            
            state = next_state
            action = choose_action(epsilon_greedy_probs(Q,state,epsilon),n_actions)
            next_state, reward, done, info = env.step(action)
            
            probs = epsilon_greedy_probs(Q,next_state,epsilon)
            
            #TD update of action-value function
            Q[state,action] += alpha * (reward + gamma * np.sum(probs[a]*Q[next_state,a] for a in range(n_actions)) - Q[state,action])

    return Q

In [106]:
Q = expected_sarsa(env,2000,0.8,0.9,0.1)

In [107]:
Q

array([[0.06512623, 0.11657942, 0.09439198, 0.06617008],
       [0.09859404, 0.1008878 , 0.10385653, 0.08419835],
       [0.10894218, 0.1221738 , 0.13388931, 0.12379043],
       [0.10459233, 0.12730272, 0.15955584, 0.10415731],
       [0.09353305, 0.26161564, 0.09720503, 0.09848714],
       [0.10516636, 0.11171306, 0.23037542, 0.1005726 ],
       [0.10700139, 0.10690943, 0.12162299, 0.11121538],
       [0.11814847, 0.13298698, 0.11114836, 0.09825464],
       [0.09947446, 0.10434635, 0.11841665, 0.08040223],
       [0.09006615, 0.091728  , 0.26154873, 0.09508759],
       [0.14688228, 0.11301591, 0.17827709, 0.12378641],
       [0.1982952 , 0.22846081, 0.18777419, 0.13457142],
       [0.36612532, 0.12398119, 0.14138722, 0.10475301],
       [0.11458951, 0.22587756, 0.10757546, 0.12583371],
       [0.13517924, 0.19125275, 0.11952932, 0.11573972],
       [0.1465823 , 0.12926727, 0.13232908, 0.11861467],
       [0.11227528, 0.09786017, 0.13970567, 0.09984749],
       [0.12623921, 0.23383441,