 Compute the value of each state.

 Derive the best (greedy) policy.

 Execute that policy in the environment to collect rewards.

In [2]:
import gym
import numpy as np

<img alt="Alt Text" src="img.png"/>

In [9]:
# For Taxi-v3: env.observation_space.n == 500 (states)
#              env.action_space.n == 6 (actions)

# performs bellman update to improve value function
def iterate_value_function(v_inp, gamma, env):
    """Apply one Bellman optimality backup to every state.

    Args:
        v_inp: Current value estimates, indexable by state id.
        gamma: Discount factor applied to the successor state's value.
        env: Object exposing ``observation_space.n``, ``action_space.n`` and
            a transition table ``P[state][action]`` that yields
            ``(prob, dst_state, reward, is_final)`` tuples.

    Returns:
        A new NumPy array of length ``env.observation_space.n`` holding, for
        each state, the largest expected one-step return over all actions.
        ``v_inp`` itself is not modified.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    refreshed = np.zeros(n_states)

    for s in range(n_states):
        action_values = np.zeros(n_actions)
        for a in range(n_actions):
            expected = 0.0
            for p, nxt, r, terminal in env.P[s][a]:
                # Multiplying by (not terminal) zeroes the bootstrap term, so
                # a transition that ends the episode counts only its reward.
                expected += p * (r + gamma * v_inp[nxt] * (not terminal))
            action_values[a] = expected
        refreshed[s] = max(action_values)

    return refreshed


def build_greedy_policy(v_inp, gamma, env):
    """Derive the greedy policy from a state-value estimate.

    For every state, each action is scored by its expected one-step return
    (immediate reward plus the discounted value of the successor state) and
    the highest-scoring action is selected.

    Args:
        v_inp: Value estimates, indexable by state id.
        gamma: Discount factor applied to the successor state's value.
        env: Object exposing ``observation_space.n``, ``action_space.n`` and
            a transition table ``P[state][action]`` that yields
            ``(prob, dst_state, reward, is_final)`` tuples.

    Returns:
        A NumPy array of length ``env.observation_space.n`` whose entry for
        each state is the index of the greedy action. The array has the
        default float dtype of ``np.zeros``; cast it to an integer type
        before using the entries as actions. Ties go to the lowest action
        index (``np.argmax`` semantics).
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    new_policy = np.zeros(n_states)
    for state_id in range(n_states):
        profits = np.zeros(n_actions)
        for action in range(n_actions):
            for prob, dst_state, reward, is_final in env.P[state_id][action]:
                # Use the same backup as iterate_value_function: when the
                # transition ends the episode there is no future return, so
                # the successor's stored value must not be bootstrapped.
                profits[action] += prob * (
                    reward + gamma * v_inp[dst_state] * (not is_final)
                )
        new_policy[state_id] = np.argmax(profits)
    return new_policy

# Number of evaluation rounds (episodes) and the reward total summed over all
# of them; the average is reported as cum_reward / episodes played so far.
n_rounds = 500
cum_reward = 0

# Discount factor handed to the value-iteration and greedy-policy helpers.
gamma = 0.9

# The helpers read env.P, env.observation_space.n and env.action_space.n, so
# the environment must provide a tabular transition model.
env = gym.make('Taxi-v3')
# NOTE(review): alternative environment kept for reference; presumably it does
# not provide the tabular model the helpers need -- verify before enabling.
# env = gym.make("CartPole-v1")

In [10]:
# --- Planning -------------------------------------------------------------
# The value function and greedy policy depend only on the environment's
# transition table and gamma, neither of which changes between rounds, so
# they are computed once here instead of once per episode.
v = np.zeros(env.observation_space.n)

# At most 100 Bellman sweeps; stop early once two successive estimates agree
# within np.allclose's default tolerances.
for _ in range(100):
    v_old = v.copy()
    v = iterate_value_function(v, gamma, env)
    if np.allclose(v, v_old):
        break

# build_greedy_policy returns a float array; cast so entries index actions.
policy = build_greedy_policy(v, gamma, env).astype(np.int32)

# --- Evaluation -----------------------------------------------------------
for t_rounds in range(n_rounds):
    observation = env.reset()
    if isinstance(observation, tuple):  # For gym>=0.26: reset() -> (obs, info)
        observation = observation[0]

    # Follow the greedy policy for at most 1000 steps per episode.
    for _ in range(1000):
        action = policy[observation]
        step_result = env.step(action)
        if len(step_result) == 5:
            # Newer API: separate terminated / truncated flags.
            observation, reward, done, truncated, info = step_result
            done = done or truncated
        else:
            # Older API: single done flag.
            observation, reward, done, info = step_result
        cum_reward += reward
        if done:
            break

    # t_rounds is zero-based, so t_rounds + 1 episodes have been played here.
    if t_rounds % 50 == 0 and t_rounds > 0:
        print(f"Average reward after {t_rounds + 1} episodes: {cum_reward / (t_rounds + 1):.2f}")

env.close()


Average reward after 51 episodes: 8.27
Average reward after 101 episodes: 7.76
Average reward after 151 episodes: 7.72
Average reward after 201 episodes: 7.92
Average reward after 251 episodes: 7.96
Average reward after 301 episodes: 8.08
Average reward after 351 episodes: 8.04
Average reward after 401 episodes: 8.00
Average reward after 451 episodes: 8.06


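# final value function v (value iteration runs for at most 100 sweeps and
# stops early once successive estimates agree within np.allclose tolerance)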

v
Out[12]:


array([17.        ,  1.62261467,  7.7147    ,  2.9140163 , -4.99684549,
        1.62261467, -4.99684549, -3.13696226,  1.62261467, -2.37440252,
        7.7147    , -1.52711391, -3.82326604, -2.37440252, -3.82326604,
        2.9140163 , 20.        ,  2.9140163 ,  9.683     ,  4.348907  ,
       14.3       ,  0.4603532 ,  5.94323   ,  1.62261467, -4.44093943,
        2.9140163 , -4.44093943, -2.37440252,  0.4603532 , -3.13696226,
        5.94323   , -2.37440252, -3.13696226, -1.52711391, -3.13696226,
        4.348907  , 17.        ,  4.348907  ,  7.7147    ,  5.94323   ,
        4.348907  , -3.82326604, -0.58568212, -3.13696226, -0.58568212,
       11.87      , -0.58568212,  2.9140163 , -0.58568212, -3.82326604,
        4.348907  , -3.13696226, -2.37440252, -0.58568212, -2.37440252,
        5.94323   ,  5.94323   , 14.3       ,  5.94323   ,  7.7147    ,
        2.9140163 , -4.44093943, -1.52711391, -3.82326604,  0.4603532 ,
       14.3       ,  0.4603532 ,  4.348907  , -1.52711391, -4.44093943,
        2.9140163 , -3.82326604, -1.52711391,  0.4603532 , -1.52711391,
        7.7147    ,  4.348907  , 17.        ,  4.348907  ,  9.683     ,
        1.62261467, -4.99684549, -2.37440252, -4.44093943,  1.62261467,
       17.        ,  1.62261467,  5.94323   , -2.37440252, -4.99684549,
        1.62261467, -4.44093943, -2.37440252, -0.58568212, -2.37440252,
        5.94323   ,  2.9140163 , 20.        ,  2.9140163 ,  7.7147    ,
       14.3       ,  0.4603532 ,  5.94323   ,  1.62261467, -4.44093943,
        2.9140163 , -4.44093943, -2.37440252,  2.9140163 , -1.52711391,
        9.683     , -0.58568212, -3.13696226, -1.52711391, -3.13696226,
        4.348907  , 17.        ,  4.348907  , 11.87      ,  5.94323   ,
       11.87      , -0.58568212,  4.348907  ,  0.4603532 , -3.82326604,
        4.348907  , -3.82326604, -1.52711391,  1.62261467, -2.37440252,
        7.7147    , -1.52711391, -2.37440252, -0.58568212, -2.37440252,
        5.94323   , 14.3       ,  5.94323   ,  9.683     ,  7.7147    ,
        5.94323   , -3.13696226,  0.4603532 , -2.37440252, -1.52711391,
        9.683     , -1.52711391,  1.62261467,  0.4603532 , -3.13696226,
        5.94323   , -2.37440252, -1.52711391,  0.4603532 , -1.52711391,
        7.7147    ,  7.7147    , 11.87      ,  7.7147    ,  9.683     ,
        4.348907  , -3.82326604, -0.58568212, -3.13696226, -0.58568212,
       11.87      , -0.58568212,  2.9140163 , -0.58568212, -3.82326604,
        4.348907  , -3.13696226, -0.58568212,  1.62261467, -0.58568212,
        9.683     ,  5.94323   , 14.3       ,  5.94323   , 11.87      ,
        2.9140163 , -4.44093943, -1.52711391, -3.82326604,  0.4603532 ,
       14.3       ,  0.4603532 ,  4.348907  , -1.52711391, -4.44093943,
        2.9140163 , -3.82326604, -1.52711391,  0.4603532 , -1.52711391,
        7.7147    ,  4.348907  , 17.        ,  4.348907  ,  9.683     ,
       11.87      , -0.58568212,  4.348907  ,  0.4603532 , -3.82326604,
        4.348907  , -3.82326604, -1.52711391,  4.348907  , -0.58568212,
       11.87      ,  0.4603532 , -2.37440252, -0.58568212, -2.37440252,
        5.94323   , 14.3       ,  5.94323   , 14.3       ,  7.7147    ,
        9.683     , -1.52711391,  2.9140163 , -0.58568212, -3.13696226,
        5.94323   , -3.13696226, -0.58568212,  2.9140163 , -1.52711391,
        9.683     , -0.58568212, -1.52711391,  0.4603532 , -1.52711391,
        7.7147    , 11.87      ,  7.7147    , 11.87      ,  9.683     ,
        7.7147    , -2.37440252,  1.62261467, -1.52711391, -2.37440252,
        7.7147    , -2.37440252,  0.4603532 ,  1.62261467, -2.37440252,
        7.7147    , -1.52711391, -0.58568212,  1.62261467, -0.58568212,
        9.683     ,  9.683     ,  9.683     ,  9.683     , 11.87      ,
        5.94323   , -3.13696226,  0.4603532 , -2.37440252, -1.52711391,
        9.683     , -1.52711391,  1.62261467,  0.4603532 , -3.13696226,
        5.94323   , -2.37440252,  0.4603532 ,  2.9140163 ,  0.4603532 ,
       11.87      ,  7.7147    , 11.87      ,  7.7147    , 14.3       ,
        4.348907  , -3.82326604, -0.58568212, -3.13696226, -0.58568212,
       11.87      , -0.58568212,  2.9140163 , -0.58568212, -3.82326604,
        4.348907  , -3.13696226, -0.58568212,  1.62261467, -0.58568212,
        9.683     ,  5.94323   , 14.3       ,  5.94323   , 11.87      ,
        9.683     , -1.52711391,  2.9140163 , -0.58568212, -4.44093943,
        2.9140163 , -4.44093943, -2.37440252,  5.94323   ,  0.4603532 ,
       14.3       ,  1.62261467, -3.13696226, -1.52711391, -3.13696226,
        4.348907  , 11.87      ,  4.348907  , 17.        ,  5.94323   ,
        7.7147    , -2.37440252,  1.62261467, -1.52711391, -3.82326604,
        4.348907  , -3.82326604, -1.52711391,  1.62261467, -2.37440252,
        7.7147    , -1.52711391, -2.37440252, -0.58568212, -2.37440252,
        5.94323   ,  9.683     ,  5.94323   ,  9.683     ,  7.7147    ,
        5.94323   , -3.13696226,  0.4603532 , -2.37440252, -3.13696226,
        5.94323   , -3.13696226, -0.58568212,  0.4603532 , -3.13696226,
        5.94323   , -2.37440252, -1.52711391,  0.4603532 , -1.52711391,
        7.7147    ,  7.7147    ,  7.7147    ,  7.7147    ,  9.683     ,
        4.348907  , -3.82326604, -0.58568212, -3.13696226, -2.37440252,
        7.7147    , -2.37440252,  0.4603532 , -0.58568212, -3.82326604,
        4.348907  , -3.13696226,  1.62261467,  4.348907  ,  1.62261467,
       14.3       ,  5.94323   ,  9.683     ,  5.94323   , 17.        ,
        2.9140163 , -4.44093943, -1.52711391, -3.82326604, -1.52711391,
        9.683     , -1.52711391,  1.62261467, -1.52711391, -4.44093943,
        2.9140163 , -3.82326604,  0.4603532 ,  2.9140163 ,  0.4603532 ,
       11.87      ,  4.348907  , 11.87      ,  4.348907  , 14.3       ,
        7.7147    , -2.37440252,  1.62261467, -1.52711391, -4.99684549,
        1.62261467, -4.99684549, -3.13696226,  7.7147    ,  1.62261467,
       17.        ,  2.9140163 , -3.82326604, -2.37440252, -3.82326604,
        2.9140163 ,  9.683     ,  2.9140163 , 20.        ,  4.348907  ,
        5.94323   , -3.13696226,  0.4603532 , -2.37440252, -4.44093943,
        2.9140163 , -4.44093943, -2.37440252,  0.4603532 , -3.13696226,
        5.94323   , -2.37440252, -3.13696226, -1.52711391, -3.13696226,
        4.348907  ,  7.7147    ,  4.348907  ,  7.7147    ,  5.94323   ,
        4.348907  , -3.82326604, -0.58568212, -3.13696226, -3.82326604,
        4.348907  , -3.82326604, -1.52711391, -0.58568212, -3.82326604,
        4.348907  , -3.13696226, -2.37440252, -0.58568212, -2.37440252,
        5.94323   ,  5.94323   ,  5.94323   ,  5.94323   ,  7.7147    ,
        2.9140163 , -4.44093943, -1.52711391, -3.82326604, -3.13696226,
        5.94323   , -3.13696226, -0.58568212, -1.52711391, -4.44093943,
        2.9140163 , -3.82326604,  2.9140163 ,  5.94323   ,  2.9140163 ,
       17.        ,  4.348907  ,  7.7147    ,  4.348907  , 20.        ,
        1.62261467, -4.99684549, -2.37440252, -4.44093943, -2.37440252,
        7.7147    , -2.37440252,  0.4603532 , -2.37440252, -4.99684549,
        1.62261467, -4.44093943,  1.62261467,  4.348907  ,  1.62261467,
       14.3       ,  2.9140163 ,  9.683     ,  2.9140163 , 17.        ])
