In [51]:
import gym
import numpy as np

In [52]:
def learn_terminal_states(env):
    
    terminal_states = []
    
    for s in range(env.nS):
        if s in terminal_states:
            continue
        for a in env.P[s]:
            for prob,s_prim,reward,done in env.P[s][a]:
                if done:
                    if not s_prim in terminal_states:
                        terminal_states.append(s_prim)
    return terminal_states
    

In [53]:

def policy_evaluation(env, terminal_states, V, policy, gamma):
    loop_counter = 0
    while True:
#    for count in range(10000):
        delta = 0
        for s in range(env.nS):
            if s in terminal_states:
                V[s] = 0.0
                continue
                
            new_v = 0.0
            a = policy[s]
            for prob,s_prim,reward,done in env.P[s][a]:
                new_v += prob * (reward + gamma * V[s_prim])
            v = V[s]
            V[s] = new_v
            delta = max(delta, abs(v - V[s]))
            loop_counter +=1
        if delta < 0.0001:
            break
    return V, loop_counter
            



In [54]:
def policy_iteration(env, gamma):

    policy_stable = False
    V = np.zeros(env.nS)
    policy = np.zeros(env.nS)
    terminal_states = learn_terminal_states(env)
    loop_counter = 0
    while not policy_stable:
        V, counter = policy_evaluation(env, terminal_states, V, policy, gamma)
        
        loop_counter += counter

        #Policy improvment
        policy_stable = True

        for s in range(env.nS):
            old_p = policy[s]

            actions = np.zeros(env.nA)
            for a in range(env.nA):
                for prob,s_prim,reward,done in env.P[s][a]:
                    if done:
                        actions[a] += prob * reward 
                    else:
                        actions[a] += prob * (reward + gamma * V[s_prim])

            policy[s] = np.argmax(actions)

            if policy[s] != old_p:
                policy_stable = False
    print('Number of loops:', loop_counter)
    return V, policy

In [55]:
env = gym.make('CliffWalking-v0')
env.reset()

V, policy = policy_iteration(env, 0.9)

steps = 0
total_reward = 0

for c in range(1000):
    done = False
    state = env.reset()
    while not done:
        state, reward, done, info = env.step(policy[state])
        total_reward += reward
        steps += 1
    
env.close()
print('Mean reward: ', total_reward/1000, 'Mean steps:', steps/1000)


Number of loops: 5499
Mean reward:  -13.0 Mean steps: 13.0


In [56]:
env = gym.make('FrozenLake-v0')
env.reset()

V, policy = policy_iteration(env, 0.9)

steps = 0
total_reward = 0

for c in range(1000):
    done = False
    state = env.reset()
    while not done:
        state, reward, done, info = env.step(policy[state])
        total_reward += reward
        steps += 1
    
env.close()
print('Mean reward: ', total_reward/1000, 'Mean steps:', steps/1000)


Number of loops: 968
Mean reward:  0.758 Mean steps: 39.516


In [57]:
env = gym.make('Taxi-v3')
env.reset()

V, policy = policy_iteration(env, 0.9)

steps = 0
total_reward = 0

for c in range(1000):
    done = False
    state = env.reset()
    while not done:
        state, reward, done, info = env.step(policy[state])
        total_reward += reward
        steps += 1
    
env.close()
print('Mean reward: ', total_reward/1000, 'Mean steps:', steps/1000)


Number of loops: 77872
Mean reward:  7.834 Mean steps: 13.166
