# 构建世界模型
1. state space: 状态空间，包括所有可能的状态
2. action space: 动作空间，包括所有可能的动作，默认包括`up`, `down`, `left`, `right`以及`still`五个动作
3. p(r | s, a): 奖励概率，即在状态s下执行动作a之后，环境给出奖励r的概率；这里的奖励r实际上是对所有可能的后继状态s'对应的奖励取期望得到的值
4. p(s' | s, a): 状态转移概率，即在状态s下执行动作a之后，环境转移到状态s'的概率。可以越界，越界的reward设置为-100




In [None]:
from utils.grids import build_models, actions, plot_values_and_policy, plot_values_and_policy_gif
# Grid-world parameters.
grid_size = 5

target_areas = [(3, 2)]  # target cell(s); passed to the plot helper as target_cells
forbidden_areas = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)] # forbidden cells

# Build the model with the project helper. As used further down in this file,
# p_r[s][a] maps reward -> probability and p_s_prime[s][a] maps
# next state -> probability.
# NOTE(review): success_prob=1 presumably makes transitions deterministic --
# confirm against utils.grids.build_models.
states, p_r, p_s_prime = build_models(grid_size, target_areas, forbidden_areas, success_prob=1)
print(len(states), states[:2])
# Starting point for the plot: zero value everywhere and a policy that puts
# probability 1 on "still" in every state.
v_initial = {s: 0 for s in states}
p_initial = {s: {"still": 1.0} for s in states}
plot_values_and_policy(
    value_dict=v_initial, 
    policy_dict=p_initial, 
    target_cells=target_areas, 
    forbidden_cells=forbidden_areas, 
    title="Initial State Value Function")
# Debug output: dump the model tables to inspect their structure.
print(p_r)
print(p_s_prime)
print(p_r[(0, 0)])
print(p_s_prime.keys())

In [None]:
import numpy as np
import random

def policy_evaluation(states, policy, p_r, p_s_prime, gamma, threshold=1e-5, max_iter=100):
    """Estimate the state-value function of ``policy`` by iterative sweeps.

    Each sweep applies the Bellman expectation backup to every state, using
    the values of the previous sweep (synchronous update):

        v_new(s) = sum_a policy(s, a) * (sum_r p(r | s, a) * r
                                         + gamma * sum_s' p(s' | s, a) * v(s'))

    Args:
        states: Iterable of hashable states.
        policy: Mapping ``policy[s][a]`` -> probability of taking ``a`` in ``s``.
        p_r: Mapping ``p_r[s][a][r]`` -> probability of receiving reward ``r``.
        p_s_prime: Mapping ``p_s_prime[s][a][s']`` -> transition probability.
        gamma: Discount factor.
        threshold: Stop once the Euclidean norm of the change between two
            consecutive sweeps drops below this value.
        max_iter: Maximum number of sweeps.

    Returns:
        Dict mapping each state to its estimated value. This is the result of
        the most recent sweep (all zeros only when ``max_iter`` is 0).
    """
    v = {s: 0.0 for s in states}
    for _ in range(max_iter):
        value_new = {}
        for s in states:
            total = 0.0
            for a, action_prob in policy[s].items():
                if action_prob == 0:
                    continue  # zero-probability actions contribute nothing
                # Expected immediate reward for (s, a).
                expected_r = sum(prob * r for r, prob in p_r[s][a].items())
                # Expected value of the successor state under the previous sweep.
                expected_next_v = sum(prob * v[s_prime] for s_prime, prob in p_s_prime[s][a].items())
                total += action_prob * (expected_r + gamma * expected_next_v)
            value_new[s] = total
        delta = sum((v[s] - value_new[s]) ** 2 for s in states) ** 0.5
        # Adopt the fresh sweep *before* testing convergence so that the most
        # recent estimate is returned instead of being discarded on break.
        v = value_new
        if delta < threshold:
            break
    return v

def policy_improvement(states, actions, v, p_r, p_s_prime, gamma):
    """Build the deterministic policy that is greedy with respect to ``v``.

    For every state the action value
    ``q(s, a) = sum_r p(r | s, a) * r + gamma * sum_s' p(s' | s, a) * v(s')``
    is computed for each entry of ``actions``; the maximising action receives
    probability 1.0. When several actions share the maximum, the one that
    appears last in ``actions`` is selected.

    Returns:
        Dict mapping each state to a single-entry dict ``{best_action: 1.0}``.
    """
    greedy = {}
    for state in states:
        best_value = float('-inf')
        best_action = None
        for action in actions:
            reward_term = sum(p * reward for reward, p in p_r[state][action].items())
            future_term = sum(p * v[nxt] for nxt, p in p_s_prime[state][action].items())
            candidate = reward_term + gamma * future_term
            # ">=" (not ">") so that a later action replaces an earlier one on ties.
            if candidate >= best_value:
                best_value = candidate
                best_action = action
        greedy[state] = {best_action: 1.0}
    return greedy

def policy_iteration(states, actions, p_r, p_s_prime, gamma, initial_policy=None, threshold=1e-5, max_iter=1000):
    """Run policy iteration: alternate evaluation and greedy improvement.

    Starting from ``initial_policy``, the current policy is evaluated with
    ``policy_evaluation`` and then replaced by the greedy policy returned by
    ``policy_improvement``. Iteration stops when the Euclidean norm of the
    difference between the value functions of two consecutive policies falls
    below ``threshold``, or after ``max_iter`` improvement steps.

    Args:
        states: Iterable of hashable states.
        actions: Iterable of candidate actions considered during improvement.
        p_r: Mapping ``p_r[s][a][r]`` -> reward probability.
        p_s_prime: Mapping ``p_s_prime[s][a][s']`` -> transition probability.
        gamma: Discount factor.
        initial_policy: Mapping ``policy[s][a]`` -> probability. Defaults to
            the deterministic policy that chooses ``"still"`` in every state.
        threshold: Convergence tolerance on the change of the value function
            between consecutive policies.
        max_iter: Maximum number of improvement steps. With ``0`` the initial
            policy is evaluated and returned unchanged.

    Returns:
        ``(policy, v)`` where ``v`` is always the evaluated value function of
        the returned ``policy``.
    """
    if initial_policy is None:
        # Default start: stay in place with probability 1 in every state.
        initial_policy = {s: {"still": 1.0} for s in states}

    policy_k = initial_policy.copy()
    # Evaluate up front so a value function exists even when max_iter == 0.
    v_policy_k = policy_evaluation(states, policy_k, p_r, p_s_prime, gamma)
    for _ in range(max_iter):
        # 1. Improvement: greedy policy with respect to the current values.
        policy_next = policy_improvement(states, actions, v_policy_k, p_r, p_s_prime, gamma)
        # 2. Evaluation of the improved policy.
        v_next = policy_evaluation(states, policy_next, p_r, p_s_prime, gamma)
        # Convergence is judged on the value function, not on the policy itself.
        delta_v = sum((v_next[s] - v_policy_k[s]) ** 2 for s in states) ** 0.5
        # Advance both together so the returned pair always matches.
        policy_k, v_policy_k = policy_next, v_next
        if delta_v < threshold:
            break
    return policy_k, v_policy_k
# Discount factor for the run below.
gamma = 0.9
# Solve the grid model built in the first cell; the results are plotted in the next cell.
optimal_policy, optimal_v = policy_iteration(states, actions, p_r, p_s_prime, gamma)


In [None]:
# Plot the value function and policy returned by policy_iteration.
plot_values_and_policy(
    value_dict=optimal_v,
    policy_dict=optimal_policy,
    forbidden_cells=forbidden_areas,
    target_cells=target_areas,
    title="State Value and Policy at Final Iteration",
)

In [None]:

# NOTE(review): v_history and p_history are not defined in any cell visible in
# this file (policy_iteration returns only the final policy and values), so
# this cell raises NameError unless they are defined elsewhere -- confirm.
# NOTE(review): presumably these are per-iteration lists of value / policy
# dicts; the last element of each is plotted below.
plot_values_and_policy_gif(
    v_history,
    p_history,
    forbidden_areas,
    target_areas,
    gif_save_path='./value_iteration.gif'
)
# Static plot of the last recorded iteration.
plot_values_and_policy(
    value_dict=v_history[-1],
    policy_dict=p_history[-1],
    forbidden_cells=forbidden_areas,
    target_cells=target_areas,
    title="State Value and Policy at Final Iteration",
)

In [None]:

# Invoke the workflow again to ensure the probability model respects bounds and prevents KeyErrors
# This run uses a 4x4 grid with success_prob=0.8 instead of the 5x5 / success_prob=1 setup above.
grid_size = 4
terminals = [(3, 3)]           # terminal state (goal)
forbiddens = [(1, 1), (2, 2)]  # danger cells (original note says penalty -10; the value is set inside build_models -- TODO confirm)
# NOTE: this rebinds `actions`, shadowing the name imported from utils.grids.
actions = ['up', 'down', 'left', 'right', 'still']
gamma = 0.9

states, p_r, p_s_prime = build_models(grid_size, terminals, forbiddens, success_prob=0.8)

# Run policy iteration on the new model.
optimal_policy, optimal_v = policy_iteration(states, actions, p_r, p_s_prime, gamma)

# Print the resulting state values and policy as text grids.
print("\nOptimal Value Function:")
for i in range(grid_size):
    # Assumes states are (row, col) tuples covering the full grid -- TODO confirm.
    row = [optimal_v[(i, j)] for j in range(grid_size)]
    print("  ".join([f"{val:.1f}" for val in row]))

print("\nOptimal Policy:")
action_symbols = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→', 'still': '⏺'}
for i in range(grid_size):
    row = []
    for j in range(grid_size):
        s = (i, j)
        if s in terminals:
            row.append('G')      # goal
        elif s in forbiddens:
            row.append('D')      # danger
        else:
            # The policy is deterministic: pick the action whose probability is exactly 1.0.
            best_action = [a for a, prob in optimal_policy[s].items() if prob == 1.0][0]
            row.append(action_symbols[best_action])
    print("  ".join(row))