In [25]:
import numpy as np
from rl_envs.grid_world_env import GridWorldEnv
from agents.value_iteration_agent import ValueIterationAgent
from agents.policy_iteration_agent import TruncatedPolicyIterationAgent
# rl_envs.grid_world_env import GridWorldEnv

%load_ext autoreload 
# %aimport rl_envs.grid_world_env

%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
def value_iteration(env, theta=0.0001, discount_factor=1.0, max_iterations=None):
    """
    Value Iteration Algorithm (in-place sweeps over a tabular model).

    Args:
        env: Environment exposing a tabular transition model.
            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).
            env.width * env.height is the number of states in the environment.
            env.possible_actions is the number of actions in the environment.
        theta: We stop sweeping once the value function change is less than theta
            for all states.
        discount_factor: Gamma discount factor. With a value of 1.0 the sweeps
            do not converge when some state can collect reward forever (for
            example a rewarding self-loop); use a value below 1.0 or pass
            max_iterations in that case.
        max_iterations: Optional upper bound on the number of full sweeps.
            None (the default) sweeps until the theta criterion is met.

    Returns:
        A tuple (policy, V). policy is a one-hot array of shape
        (number of states, number of actions) marking the greedy action of each
        state; V is a vector holding the estimated value of each state.
    """
    # Both sizes are loop invariants, so read them from the env once.
    n_states = env.width * env.height
    n_actions = env.possible_actions

    def one_step_lookahead(state, V):
        """
        Helper function to calculate the value of every action in a given state.

        Args:
            state: The state to consider (int)
            V: The value to use as an estimator, vector with one entry per state

        Returns:
            A vector with one entry per action containing its expected value.
        """
        A = np.zeros(n_actions)
        for a in range(n_actions):
            # The done flag of each transition is not used by the backup.
            for prob, next_state, reward, _done in env.P[state][a]:
                A[a] += prob * (reward + discount_factor * V[next_state])
        return A

    V = np.zeros(n_states)
    sweeps = 0
    while max_iterations is None or sweeps < max_iterations:
        # Largest value change observed during this sweep
        delta = 0
        # Update each state...
        for s in range(n_states):
            # Do a one-step lookahead to find the best action
            best_action_value = np.max(one_step_lookahead(s, V))
            # Calculate delta across all states seen so far
            delta = max(delta, np.abs(best_action_value - V[s]))
            # Update the value function in place, so later states in the same
            # sweep already see this new value. Ref: Sutton book eq. 4.10.
            V[s] = best_action_value
        sweeps += 1
        # Check if we can stop
        if delta < theta:
            break

    # Create a deterministic policy using the final value function
    policy = np.zeros([n_states, n_actions])
    for s in range(n_states):
        # One step lookahead to find the best action for this state
        best_action = np.argmax(one_step_lookahead(s, V))
        # Always take the best action
        policy[s, best_action] = 1.0

    return policy, V

In [27]:
# Small 2x2 sanity-check environment; the larger 5x5 configuration is kept
# commented out below for reference.
# env = GridWorldEnv(5, 5, forbidden_grids=[(1,1),(1,2), (2,2),(3,1),(3,3),(4,1)], target_grids=[(3,2)], forbidden_reward=-1, hit_wall_reward=-1)
env = GridWorldEnv(2, 2, forbidden_grids=[(0,1)], target_grids=[(1,1)])
# NOTE(review): presumably builds the env.P transition table that
# value_iteration reads — confirm in GridWorldEnv.
env.init_model_based_transitions()


In [28]:
# Solve the current env with the notebook's value_iteration (gamma = 0.9)
# and show the resulting state values.
policy, V = value_iteration(env, theta=0.0001, discount_factor=0.9)
print(V)

[8.99915359 9.99915359 9.99915359 9.99915359]


Value Iteration 部分完全一致

我的结果为 {(0, 0): 8.999153585021714, (0, 1): 9.999153585021714, (1, 0): 9.999153585021714, (1, 1): 9.999153585021714}

In [29]:
def print_actions(policy, env):
    """
    Print the greedy action of every grid cell, one bracketed line per grid row.

    Args:
        policy: Table indexable by state index; the greedy action of a state is
            the argmax of its row. States are read row by row, env.width per row.
        env: Environment providing height, width and action_mappings
            (action index -> printable symbol).
    """
    for row in range(env.height):
        # Index of the first state belonging to this grid row.
        first_state = row * env.width
        symbols = [
            env.action_mappings[np.argmax(policy[first_state + col])]
            for col in range(env.width)
        ]
        # The default single-space separator reproduces "[ a b ... ]".
        print("[", *symbols, "]")

# Show the greedy action chosen in each grid cell.
print_actions(policy, env)

[  ↓   ↓  ]
[  →   ↺  ]


测试更复杂的环境下的情况

In [30]:
# Larger 5x5 grid: six forbidden cells, one target cell, and explicit
# penalties (forbidden_reward=-10, hit_wall_reward=-1).
env = GridWorldEnv(5, 5, forbidden_grids=[(1,1),(1,2), (2,2),(3,1),(3,3),(4,1)], target_grids=[(3,2)], forbidden_reward=-10, hit_wall_reward=-1)
# env = GridWorldEnv(2, 2, forbidden_grids=[(0,1)], target_grids=[(1,1)])
# NOTE(review): presumably (re)builds env.P for the new grid — confirm in GridWorldEnv.
env.init_model_based_transitions()

# Solve with the notebook's value_iteration, then print the env's own text
# rendering, the greedy policy, and the state values.
policy, V = value_iteration(env, theta=0.0001, discount_factor=0.9)
print(env)
print_actions(policy, env)
print(V)

[ 0.000000 0.000000 0.000000 0.000000 0.000000 ]
[ 0.000000 -10.000000 -10.000000 0.000000 0.000000 ]
[ 0.000000 0.000000 -10.000000 0.000000 0.000000 ]
[ 0.000000 -10.000000 1.000000 -10.000000 0.000000 ]
[ 0.000000 -10.000000 0.000000 0.000000 0.000000 ]

[  →   →   →   →   ↓  ]
[  ↑   ↑   →   →   ↓  ]
[  ↑   ←   ↓   →   ↓  ]
[  ↑   →   ↺   ←   ↓  ]
[  ↑   →   ↑   ←   ←  ]
[3.48616736 3.87358785 4.30405506 4.78235196 5.31379296 3.13755063
 3.48622907 4.78235196 5.31379296 5.90428296 2.82379557 2.54141601
 9.99915359 5.90428296 6.56038296 2.54141601 9.99915359 9.99915359
 9.99923823 7.28938296 2.28727441 8.99923823 9.99923823 8.9993144
 8.09938296]


In [31]:
def print_actions(agent, env):
    """
    Print the greedy action of every grid cell, one bracketed line per grid row.

    This definition replaces any earlier print_actions in the notebook, so it
    accepts both kinds of argument; cells that pass a policy table keep working
    when they are re-run after this cell.

    Args:
        agent: Either an object exposing get_action(state_index), or a policy
            table indexable by state index whose greedy action is the argmax of
            the state's row.
        env: Environment providing height, width and action_mappings
            (action index -> printable symbol).
    """
    # Decide once how actions are looked up: agent method or policy-table argmax.
    has_get_action = hasattr(agent, "get_action")
    index = 0
    for i in range(env.height):
        print("[", end=" ")
        for j in range(env.width):
            if has_get_action:
                action = agent.get_action(index)
            else:
                action = np.argmax(agent[index])
            print(env.action_mappings[action], end=" ")
            index += 1
        print("]")

In [39]:
# Run the project's ValueIterationAgent on the current env with the same
# discount (0.9) and threshold (0.0001) used for value_iteration above,
# so the two results can be compared.
agent = ValueIterationAgent(action_space_n=env.possible_actions, discounted_factor=0.9, threshold=0.0001)
agent.run(env)
print_actions(agent, env)
# agent.v is a mapping; only its values (the state-value estimates) are printed.
print(agent.v.values())

[  →   →   →   →   ↓  ]
[  ↑   ↑   →   →   ↓  ]
[  ↑   ←   ↓   →   ↓  ]
[  ↑   →   ↺   ←   ↓  ]
[  ↑   →   ↑   ←   ←  ]
dict_values([3.485937986021714, 3.8733584750217145, 4.303825685021715, 4.782122585021715, 5.313563585021715, 3.1372595459217134, 3.485937986021714, 4.782122585021715, 5.313563585021715, 5.904053585021714, 2.8234489498317146, 2.5410194133507136, 9.999153585021714, 5.904053585021714, 6.560153585021716, 2.5410194133507136, 9.999153585021714, 9.999153585021714, 9.999153585021714, 7.289153585021714, 2.286832830517813, 8.999153585021714, 9.999153585021714, 8.999153585021714, 8.099153585021716])


In [40]:
# List every state whose value differs between the agent (myV) and the
# notebook's value_iteration result (V).
print("\nUnmatch state value below:\n")
# NOTE(review): assumes agent.v yields values in the same state order as the
# indices of V — confirm against ValueIterationAgent.
myV = list(agent.v.values())
for i in range(len(V)):
    # Exact float comparison: any difference, however small, is reported.
    if V[i] != myV[i]:
        # Columns: agent value, value_iteration value, difference (V - agent).
        print(myV[i], V[i], V[i]-myV[i])


Unmatch state value below:

3.485937986021714 3.48616736448083 0.0002293784591160808
3.8733584750217145 3.8735878534808297 0.0002293784591151926
4.303825685021715 4.3040550634808294 0.00022937845911474852
4.782122585021715 4.782351963480831 0.00022937845911652488
5.313563585021715 5.31379296348083 0.00022937845911474852
3.1372595459217134 3.137550628032747 0.00029108211103379134
3.485937986021714 3.486229068132747 0.00029108211103290316
4.782122585021715 4.782351963480831 0.00022937845911652488
5.313563585021715 5.31379296348083 0.00022937845911474852
5.904053585021714 5.904282963480831 0.00022937845911741306
2.8234489498317146 2.8237955652294726 0.0003466153977580433
2.5410194133507136 2.5414160087065256 0.0003965953558120461
5.904053585021714 5.904282963480831 0.00022937845911741306
6.560153585021716 6.56038296348083 0.00022937845911386034
2.5410194133507136 2.5414160087065256 0.0003965953558120461
9.999153585021714 9.999238226519543 8.464149782838604e-05
7.289153585021714 7.2893829

存在微小差异, 主要原因在于更新 Q 值时所使用的 V 值不同: 按照 Shiyu Zhao 书中的做法, 应该使用上一个 iteration 的 V 值来计算当前 iteration 的 Q 值中的 future reward (同步更新); 但是这份代码直接使用本次 iteration 中已经更新过的 V 值 (原地更新, 更接近“圣经”即 Sutton 书中的伪码)

两种方法应该都能收敛到同一个最优值函数, 区别不大