In [15]:
import gym
import numpy as np

In [16]:
def learn_terminal_states(env):
    """Collect the states that end an episode, according to the transition table.

    Walks ``env.P[s][a]`` for every state ``s`` in ``range(env.nS)`` and records
    each successor state whose transition tuple has its ``done`` flag set.

    Args:
        env: Object exposing ``nS`` (number of states) and ``P``, where
            ``P[s]`` maps each action to an iterable of
            ``(prob, next_state, reward, done)`` tuples.

    Returns:
        list: Terminal states, in the order they were first discovered.
    """
    terminal_states = []
    # Mirror of the list kept as a set so every membership check is O(1)
    # instead of a linear scan; the list preserves discovery order.
    seen = set()

    for s in range(env.nS):
        # Transitions out of an already-known terminal state are not inspected.
        if s in seen:
            continue
        for a in env.P[s]:
            for _prob, s_prim, _reward, done in env.P[s][a]:
                if done and s_prim not in seen:
                    seen.add(s_prim)
                    terminal_states.append(s_prim)
    return terminal_states
    

In [34]:

def policy_evaluation(env, terminal_states, V, policy, gamma,
                      theta=0.0001, max_iterations=None):
    """Iteratively evaluate a deterministic policy, updating ``V`` in place.

    Each sweep visits every state in ``range(env.nS)`` and replaces ``V[s]``
    with the expected one-step return of the action ``policy[s]``, using the
    values already updated earlier in the same sweep. Sweeping stops once the
    largest change in a sweep drops below ``theta``.

    Args:
        env: Object exposing ``nS`` and ``P``, where ``P[s][a]`` is an iterable
            of ``(prob, next_state, reward, done)`` tuples.
        terminal_states: Iterable of states whose value is pinned to ``0.0``.
        V: Indexable, mutable value table (one entry per state). It is
            modified in place and also returned.
        policy: Indexable table giving the action index chosen in each state.
        gamma: Discount factor applied to the successor state's value.
        theta: Convergence threshold on the largest per-sweep value change.
        max_iterations: Optional cap on the number of sweeps. ``None`` (the
            default) sweeps until the ``theta`` criterion is met.

    Returns:
        The same ``V`` object that was passed in, after the updates.
    """
    # Converted once so the per-state terminal check is O(1) rather than a
    # linear scan of the list on every state of every sweep.
    terminal = set(terminal_states)

    sweeps = 0
    while max_iterations is None or sweeps < max_iterations:
        sweeps += 1
        delta = 0
        for s in range(env.nS):
            if s in terminal:
                V[s] = 0.0
                continue

            # The policy table may hold floats (e.g. when built with np.zeros);
            # cast so the lookup also works for action-indexed sequences.
            a = int(policy[s])
            new_v = 0.0
            for prob, s_prim, reward, _done in env.P[s][a]:
                new_v += prob * (reward + gamma * V[s_prim])

            v = V[s]
            V[s] = new_v
            delta = max(delta, abs(v - V[s]))

        if delta < theta:
            break
    return V
            



In [35]:
def policy_iteration(env, gamma):
    """Alternate policy evaluation and greedy improvement until the policy is stable.

    Starts from the all-zeros value table and the policy that picks action 0
    everywhere. Each round evaluates the current policy with
    ``policy_evaluation`` and then, for every state, switches to the action
    with the highest one-step look-ahead value. The loop ends after a round
    in which no state's action changed.

    Args:
        env: Object exposing ``nS`` (number of states), ``nA`` (number of
            actions) and ``P``, where ``P[s][a]`` is an iterable of
            ``(prob, next_state, reward, done)`` tuples.
        gamma: Discount factor.

    Returns:
        tuple: ``(V, policy)`` where ``V`` is the value table of the final
        policy and ``policy`` is an integer array holding one action index
        per state.
    """
    V = np.zeros(env.nS)
    # Integer dtype: entries are action indices that callers use as lookup
    # keys and pass to env.step, so they must not be stored as floats.
    policy = np.zeros(env.nS, dtype=int)
    terminal_states = learn_terminal_states(env)

    policy_stable = False
    while not policy_stable:
        V = policy_evaluation(env, terminal_states, V, policy, gamma)

        # Policy improvement
        policy_stable = True

        for s in range(env.nS):
            old_p = policy[s]

            # One-step look-ahead value of every action from state s.
            actions = np.zeros(env.nA)
            for a in range(env.nA):
                for prob, s_prim, reward, done in env.P[s][a]:
                    if done:
                        # Episode ends here: no value is bootstrapped from s_prim.
                        actions[a] += prob * reward
                    else:
                        actions[a] += prob * (reward + gamma * V[s_prim])

            # np.argmax returns the first maximiser, so ties go to the lowest index.
            policy[s] = np.argmax(actions)

            if policy[s] != old_p:
                policy_stable = False
    return V, policy

In [36]:
# Solve CliffWalking with policy iteration, then roll out the resulting policy
# for one episode and report the accumulated reward and step count.
env = gym.make('CliffWalking-v0')
try:
    env.reset()

    V, policy = policy_iteration(env, 0.9)

    done = False
    steps = 0
    total_reward = 0
    state = env.reset()
    while not done:
        # policy is a NumPy array, so policy[state] is a NumPy scalar;
        # hand the environment a plain Python int action instead.
        state, reward, done, info = env.step(int(policy[state]))
        total_reward += reward
        steps += 1
finally:
    # Release the environment even if the rollout raises.
    env.close()

print('Total reward: ', total_reward, 'in steps:', steps)


Total reward:  -13 in steps: 13


In [50]:
# Solve FrozenLake with policy iteration, then roll out the resulting policy
# for one episode and report the accumulated reward and step count.
env = gym.make('FrozenLake-v0')
try:
    env.reset()

    V, policy = policy_iteration(env, 0.9)

    done = False
    steps = 0
    total_reward = 0
    state = env.reset()
    while not done:
        # policy is a NumPy array, so policy[state] is a NumPy scalar;
        # hand the environment a plain Python int action instead.
        state, reward, done, info = env.step(int(policy[state]))
        total_reward += reward
        steps += 1
finally:
    # Release the environment even if the rollout raises.
    env.close()

print('Total reward: ', total_reward, 'in steps:', steps)

Total reward:  1.0 in steps: 41


In [40]:
# Solve Taxi with policy iteration, then roll out the resulting policy
# for one episode and report the accumulated reward and step count.
env = gym.make('Taxi-v3')
try:
    env.reset()

    V, policy = policy_iteration(env, 0.9)

    done = False
    steps = 0
    total_reward = 0
    state = env.reset()
    while not done:
        # policy is a NumPy array, so policy[state] is a NumPy scalar;
        # hand the environment a plain Python int action instead.
        state, reward, done, info = env.step(int(policy[state]))
        total_reward += reward
        steps += 1
finally:
    # Release the environment even if the rollout raises.
    env.close()

print('Total reward: ', total_reward, 'in steps:', steps)

Total reward:  9 in steps: 12
