# Value Iteration

We can merge a truncated policy-evaluation step and a policy-improvement step into a single update. Instead of improving the policy explicitly (taking the argmax over actions to get a better policy, and then evaluating that improved policy to obtain a new value function), we directly take the maximum (max, instead of argmax) action value in each state and use it as that state's value in the next sweep over the states:

$$V_{k+1}(s) = \max_a \sum_{s', r} p(s', r \mid s, a)\,\bigl[r + \gamma V_k(s')\bigr]$$

In [1]:
import numpy as np

def value_iteration(P, gamma=1.0, theta=1e-10):
    """Run value iteration and return the state values plus a greedy policy.

    Args:
        P: transition model indexed as ``P[s][a]`` with integer states
            ``0..len(P)-1`` and integer actions ``0..len(P[s])-1``. Each
            entry is an iterable of ``(prob, next_state, reward, done)``
            tuples. The number of actions is taken from ``P[0]``.
        gamma: discount factor applied to the value of the next state.
        theta: convergence threshold; sweeping stops once the largest
            per-state change of the value estimate is below it.

    Returns:
        A tuple ``(V, pi)`` where ``V`` is a 1-D array of state values and
        ``pi`` is a callable mapping a state index to its greedy action.
        ``pi`` raises ``KeyError`` for a state that is not in ``P``.
    """
    n_states = len(P)

    # first thing is to initialize a state-value function.
    V = np.zeros(n_states)
    n_actions = len(P[0])

    while True:
        # every sweep starts from a fresh, zeroed Q-function.
        Q = np.zeros((n_states, n_actions))

        # then, for every transition of every action in every state, we
        # accumulate the expected one-step return. Note that V here is the
        # estimate from the previous sweep, and that `(not done)` zeroes the
        # bootstrap term for transitions that end the episode.
        for s in range(n_states):
            for a in range(len(P[s])):
                for prob, next_state, reward, done in P[s][a]:
                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))

        # the max over actions merges policy improvement and evaluation
        # into one step; compute it once and reuse it below.
        new_V = np.max(Q, axis=1)

        # once a sweep no longer changes the state-value function by at
        # least theta, we stop and keep the current estimate.
        if np.max(np.abs(V - new_V)) < theta:
            break
        V = new_V

    # only at the end do we extract the greedy policy. The lookup table is
    # built once here rather than on every call of the returned policy.
    greedy_action = dict(enumerate(np.argmax(Q, axis=1)))

    def pi(s):
        return greedy_action[s]

    return V, pi

And again, we can solve the Frozen Lake problem:

In [2]:
import gym

# Create the Frozen Lake environment by its registered gym id.
env = gym.make('FrozenLake-v0')
# Transition table handed to value_iteration and the print helpers, which
# index it as P[state][action] -> iterable of (prob, next_state, reward, done).
P = env.env.P
# Start state returned by the first reset (not used by the cells shown here).
init_state = env.reset()
# State index that probability_success compares the final state against.
# NOTE(review): presumably the bottom-right cell of the 4x4 grid -- confirm.
goal_state = 15

In [3]:
# Solve the environment with a discount factor of 0.99; V_best holds the
# state values and pi_best is a callable mapping a state to an action.
V_best, pi_best = value_iteration(env.env.P, gamma=0.99)

Next, we define helpers to print the policy and the state-value function, and to estimate the policy's success rate and mean return:

In [4]:
def print_policy(pi, P, action_symbols=('<', 'v', '>', '^'), n_cols=4, title='Policy:'):
    """Print the policy as a grid of ``state symbol`` cells.

    Args:
        pi: callable mapping a state index to an action index.
        P: transition model; ``P[s]`` is a dict from action to an iterable
            of ``(prob, next_state, reward, done)`` tuples.
        action_symbols: symbol shown for action 0, 1, 2, ... in that order.
        n_cols: number of cells printed per row.
        title: heading printed before the grid.

    States whose every transition is flagged ``done`` are printed as a
    blank cell.
    """
    print(title)
    symbol_for = dict(enumerate(action_symbols))
    for state in range(len(P)):
        action = pi(state)
        print("| ", end="")

        # Collect the `done` flag of every transition leaving this state.
        done_flags = []
        for transitions in P[state].values():
            for _, _, _, done in transitions:
                done_flags.append(done)

        if not np.all(done_flags):
            label = str(state).zfill(2)
            symbol = symbol_for[action].rjust(6)
            print(label, symbol, end=" ")
        else:
            # Blank cell as wide as a labelled one (2 + 1 + 6 characters).
            print("".rjust(9), end=" ")

        # Close the row after every n_cols cells.
        if (state + 1) % n_cols == 0:
            print("|")

In [5]:
import random

def probability_success(env, pi, goal_state, n_episodes=100, max_steps=200):
    """Estimate how often following ``pi`` ends an episode in ``goal_state``.

    Args:
        env: environment exposing ``seed``, ``reset`` and a ``step`` that
            returns ``(state, reward, done, info)``.
        pi: callable mapping a state to the action to take.
        goal_state: state an episode must end in to count as a success.
        n_episodes: number of episodes to run.
        max_steps: cap on the number of steps taken per episode.

    Returns:
        The fraction of episodes whose final state equals ``goal_state``.
    """
    # Fixed seeds so repeated evaluations are comparable.
    random.seed(123)
    np.random.seed(123)
    env.seed(123)

    reached_goal = []
    for _ in range(n_episodes):
        state = env.reset()
        done = False
        steps = 0
        # Roll out until the episode ends or the step budget is spent.
        while not (done or steps >= max_steps):
            action = pi(state)
            state, _, done, _ = env.step(action)
            steps += 1
        reached_goal.append(state == goal_state)

    n_successes = np.sum(reached_goal)
    return n_successes / len(reached_goal)

In [6]:
def mean_return(env, pi, n_episodes=100, max_steps=200):
    """Estimate the average undiscounted return obtained by following ``pi``.

    Args:
        env: environment exposing ``seed``, ``reset`` and a ``step`` that
            returns ``(state, reward, done, info)``.
        pi: callable mapping a state to the action to take.
        n_episodes: number of episodes to run.
        max_steps: cap on the number of steps taken per episode.

    Returns:
        The mean over episodes of the summed per-step rewards.
    """
    # Fixed seeds so repeated evaluations are comparable.
    random.seed(123)
    np.random.seed(123)
    env.seed(123)

    episode_returns = []
    for _ in range(n_episodes):
        state = env.reset()
        done = False
        steps = 0
        total = 0.0
        # Sum the rewards until the episode ends or the step budget is spent.
        while not (done or steps >= max_steps):
            state, reward, done, _ = env.step(pi(state))
            total += reward
            steps += 1
        episode_returns.append(total)

    return np.mean(episode_returns)

In [7]:
# Show the action chosen in every non-terminal state.
print_policy(pi_best, P)

# probability_success returns a fraction; scale it to a percentage.
ps = probability_success(env, pi_best, goal_state=goal_state)*100
# Average of the summed (undiscounted) rewards over the evaluation episodes.
mr = mean_return(env, pi_best)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      < | 01      ^ | 02      ^ | 03      ^ |
| 04      < |           | 06      < |           |
| 08      ^ | 09      v | 10      < |           |
|           | 13      > | 14      v |           |
Reaches goal 74.00%. Obtains an average undiscounted return of 0.7400.


In [8]:
def print_state_value_function(V, P, n_cols=4, prec=3, title='State-value function:'):
    """Print the state values as a grid of ``state value`` cells.

    Args:
        V: sequence of state values indexed by state.
        P: transition model; ``P[s]`` is a dict from action to an iterable
            of ``(prob, next_state, reward, done)`` tuples.
        n_cols: number of cells printed per row.
        prec: number of decimals each value is rounded to and printed with.
        title: heading printed before the grid.

    States whose every transition is flagged ``done`` are printed as a
    blank cell. Values are printed in fixed-point notation so every cell
    shows the same number of decimals (e.g. ``0.5420`` rather than
    ``0.542`` for ``prec=4``).
    """
    print(title)
    # A value below 10 printed with `prec` decimals needs prec + 2
    # characters; never go below the original 6-character column.
    width = max(6, prec + 2)
    # Fixed-point formatting cannot take a negative precision, although
    # np.round can; clamp only the number of printed decimals.
    shown_decimals = max(prec, 0)
    for s in range(len(P)):
        v = V[s]
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            # Blank cell as wide as a labelled one: 2-digit state, a space,
            # and the value column.
            print("".rjust(width + 3), end=" ")
        else:
            text = '{:.{d}f}'.format(np.round(v, prec), d=shown_decimals)
            print(str(s).zfill(2), text.rjust(width), end=" ")
        # Close the row after every n_cols cells.
        if (s + 1) % n_cols == 0:
            print("|")

In [9]:
# Show the value of every non-terminal state, rounded to 4 decimals.
print_state_value_function(V_best, P, prec=4)

State-value function:
| 00  0.542 | 01 0.4988 | 02 0.4707 | 03 0.4569 |
| 04 0.5585 |           | 06 0.3583 |           |
| 08 0.5918 | 09 0.6431 | 10 0.6152 |           |
|           | 13 0.7417 | 14 0.8628 |           |
