# 构建世界模型
1. state space: 状态空间，包括所有可能的状态
2. action space: 动作空间，包括所有可能的动作，默认包括`up`, `down`, `left`, `right`以及`still`五个动作
3. p(r | s, a): 奖励概率，即在状态 s 下执行动作 a 之后，环境给出奖励 r 的概率。这里的奖励 r 实际上是对所有可能的下一状态 s' 的奖励取期望后得到的值
4. p(s' | s, a): 状态转移概率，即在状态 s 下执行动作 a 之后，环境转移到状态 s' 的概率。允许执行越界的动作，越界时获得的奖励设置为 -100




In [None]:
from utils.grids import build_models, actions, plot_values_and_policy
# Grid parameters for the 5x5 value-iteration world.
grid_size = 5

target_areas = [(3, 2)]  # target (goal) cells
forbidden_areas = [(1, 1), (1, 2), (2, 2), (3, 1), (3, 3), (4, 1)]  # forbidden cells

# Build the state list and the reward / transition models (success_prob=1).
states, p_r, p_s_prime = build_models(
    grid_size, target_areas, forbidden_areas, success_prob=1
)
print(len(states), states[:2])

# Start from all-zero values and a policy that always chooses "still".
v_initial = dict.fromkeys(states, 0)
p_initial = {state: dict(still=1.0) for state in states}
plot_values_and_policy(
    title="Initial State Value Function",
    value_dict=v_initial,
    policy_dict=p_initial,
    forbidden_cells=forbidden_areas,
    target_cells=target_areas,
)

# Inspect the generated models.
print(p_r)
print(p_s_prime)
print(p_r[(0, 0)])
print(p_s_prime.keys())

In [None]:


def value_iteration(p_r, p_s_prime, v_init, states, target_areas, actions, gamma, threshold=1e-5, max_iter=100, save_history=False):
    """Run synchronous value iteration; return the final values and greedy policy.

    Each sweep applies the Bellman optimality update
        v(s) <- max_a [ E[r | s, a] + gamma * sum_{s'} p(s' | s, a) * v(s') ]
    to every non-target state. Target states are pinned to value 0 with the
    policy {"still": 1.0}. Iteration stops once the Euclidean (L2) norm of
    the change between consecutive value tables drops below ``threshold``,
    or after ``max_iter`` sweeps.

    Args:
        p_r: ``p_r[s][a]`` maps reward -> probability.
        p_s_prime: ``p_s_prime[s][a]`` maps successor state -> probability;
            every successor must be a key of the value table (``v_init`` /
            ``states``).
        v_init: initial value table (copied, not mutated).
        states: states to update, in sweep order.
        target_areas: terminal states whose value is fixed at 0.
        actions: candidate actions; on equal q-values the later action wins.
        gamma: discount factor.
        threshold: stopping tolerance on the L2 norm of the value change.
        max_iter: maximum number of sweeps.
        save_history: when True, also return a snapshot of every sweep.

    Returns:
        dict with 'v' (state -> value) and 'p' (state -> {action: 1.0}, or
        None if no sweep ran), plus 'v_history' / 'p_history' (one entry
        per sweep) when ``save_history`` is True.
    """
    values = v_init.copy()
    policy = None
    value_trace, policy_trace = [], []
    for _ in range(max_iter):
        next_values, next_policy = {}, {}
        for state in states:
            if state in target_areas:
                # Terminal state: its value stays 0 and the agent stays put.
                next_values[state] = 0
                next_policy[state] = {"still": 1.0}
                continue
            best_q, best_action = -float('inf'), None
            for action in actions:
                mean_reward = sum(prob * r for r, prob in p_r[state][action].items())
                mean_next_value = sum(
                    prob * values[nxt] for nxt, prob in p_s_prime[state][action].items()
                )
                q = mean_reward + gamma * mean_next_value
                # ">=" rather than ">" so that, among equal q-values, the later action is kept.
                if q >= best_q:
                    best_q, best_action = q, action
            next_policy[state] = {best_action: 1.0}
            next_values[state] = best_q
        # L2 distance between the previous and the freshly computed value table.
        change = sum((values[s] - next_values[s]) ** 2 for s in states) ** 0.5
        if save_history:
            value_trace.append(dict(next_values))
            policy_trace.append(dict(next_policy))
        values, policy = dict(next_values), dict(next_policy)
        if change < threshold:
            break
    result = {'v': values, 'p': policy}
    if save_history:
        result['v_history'] = value_trace
        result['p_history'] = policy_trace
    return result

# Run value iteration on the grid world built in the first cell, recording
# every sweep so it can be animated below.
gamma = 0.9
return_dict = value_iteration(
    p_r, p_s_prime, v_initial, states, target_areas, actions, gamma,
    save_history=True,
)
v_optimal, policy_optimal = return_dict['v'], return_dict['p']
v_history, p_history = return_dict['v_history'], return_dict['p_history']

In [None]:
import os
import tempfile
from PIL import Image


file_paths = []
images = []
with tempfile.TemporaryDirectory() as tmpdir:
    # Render one PNG frame per recorded value-iteration sweep.
    for idx, (v_h, p_h) in enumerate(zip(v_history, p_history)):
        temp_file = os.path.join(tmpdir, "plot_" + str(idx) + ".png")
        plot_values_and_policy(
            value_dict=v_h,
            policy_dict=p_h,
            forbidden_cells=forbidden_areas,
            target_cells=target_areas,
            forbidden_color='coral',       # coral for forbidden cells
            target_color='lightgreen',      # light green for target cells
            bg_color='whitesmoke',          # whitesmoke background for ordinary cells
            title="State Value Function at Iteration " + str(idx+1),
            save_path=temp_file
        )
        file_paths.append(temp_file)
    # Image.open() is lazy: it keeps the file open and defers reading pixel
    # data. Force each frame into memory with copy() and close its file while
    # the temporary directory still exists, so the GIF save below does not
    # depend on open handles into a directory that has already been removed
    # (on Windows the cleanup can fail outright while the PNGs are open).
    for path in file_paths:
        with Image.open(path) as frame:
            images.append(frame.copy())

# Save the frames as an animated GIF (frame duration and loop count below).
gif_save_path = "value_iteration.gif"
images[0].save(
    gif_save_path,
    save_all=True,
    append_images=images[1:],
    duration=500,  # per-frame duration in milliseconds
    loop=0,        # 0 = loop forever
    optimize=True
)
plot_values_and_policy(
    value_dict=v_history[-1],
    policy_dict=p_history[-1],
    forbidden_cells=forbidden_areas,
    target_cells=target_areas,
    title="State Value and Policy at Final Iteration",
)

In [None]:
import numpy as np
import random

def _q_value(s, a, v, p_r, p_s_prime, gamma):
    """Return q(s, a) = E[r | s, a] + gamma * E[v(s') | s, a] under value table ``v``.

    ``p_r[s][a]`` maps reward -> probability and ``p_s_prime[s][a]`` maps
    successor state -> probability. A successor that is not a key of ``v``
    (i.e. lies outside the known state set) is valued as ``v[s]``: the agent
    bounces back to ``s``. NOTE(review): this is the standard grid-world
    boundary rule; for one-step moves it equals clamping the coordinates back
    into the grid. For in-grid successors it is a plain ``v[s_prime]`` lookup.
    """
    expected_r = sum(prob * r for r, prob in p_r[s][a].items())
    expected_next_v = sum(
        prob * v.get(s_prime, v[s]) for s_prime, prob in p_s_prime[s][a].items()
    )
    return expected_r + gamma * expected_next_v


def _policies_match(old, new, states, atol=1e-8):
    """Return True if ``old`` and ``new`` give every state the same action probabilities.

    An action missing from one mapping counts as probability 0, so policies
    whose dicts list actions in a different order (or omit zero entries)
    still compare equal.
    """
    return all(
        abs(old[s].get(a, 0.0) - new[s].get(a, 0.0)) <= atol
        for s in states
        for a in old[s].keys() | new[s].keys()
    )


def policy_evaluation(states, policy, p_r, p_s_prime, gamma, threshold=1e-5, max_iter=100):
    """Iteratively evaluate ``policy`` and return its state-value function.

    Performs synchronous sweeps of the Bellman expectation equation starting
    from v = 0, until the largest per-state change of a sweep falls below
    ``threshold`` or ``max_iter`` sweeps have run.

    Args:
        states: states to evaluate.
        policy: mapping state -> {action: probability}.
        p_r: ``p_r[s][a]`` maps reward -> probability.
        p_s_prime: ``p_s_prime[s][a]`` maps successor state -> probability.
        gamma: discount factor.
        threshold: convergence tolerance on the max-norm of a sweep's change.
        max_iter: maximum number of sweeps.

    Returns:
        dict mapping state -> value under ``policy``.
    """
    v = {s: 0.0 for s in states}
    for _ in range(max_iter):
        value_new = {
            s: sum(
                prob_a * _q_value(s, a, v, p_r, p_s_prime, gamma)
                for a, prob_a in policy[s].items()
                if prob_a != 0  # zero-probability actions need no model entries
            )
            for s in states
        }
        delta = max((abs(value_new[s] - v[s]) for s in states), default=0.0)
        # Keep the latest sweep even when it converged: it is the most accurate one.
        v = value_new
        if delta < threshold:
            break
    return v


def policy_improvement(states, actions, v, p_r, p_s_prime, gamma, current_policy=None, tie_tol=1e-9):
    """Return the deterministic policy that is greedy with respect to ``v``.

    Actions whose q-value is within ``tie_tol`` of the maximum count as tied.
    Ties are broken deterministically. If ``current_policy`` takes one action
    in ``s`` with probability 1 and that action is among the best, it is
    kept. Otherwise the first best action in ``actions`` order is chosen.
    Keeping the incumbent stops policy iteration from cycling between equally
    good actions, and the fixed order makes results reproducible.

    Returns:
        dict mapping state -> {action: 0.0 or 1.0} over every action in ``actions``.
    """
    new_policy = {}
    for s in states:
        q_values = {a: _q_value(s, a, v, p_r, p_s_prime, gamma) for a in actions}
        max_q = max(q_values.values())
        best_actions = [a for a in actions if q_values[a] >= max_q - tie_tol]
        incumbent = None
        if current_policy is not None:
            incumbent = next(
                (a for a, prob in current_policy.get(s, {}).items() if prob == 1.0),
                None,
            )
        best_action = incumbent if incumbent in best_actions else best_actions[0]
        new_policy[s] = {a: 0.0 for a in actions}
        new_policy[s][best_action] = 1.0
    return new_policy


def policy_iteration(states, actions, p_r, p_s_prime, gamma, initial_policy=None, threshold=1e-5, max_iter=1000):
    """Run policy iteration and return ``(policy, v)``.

    Alternates policy evaluation and greedy improvement until improvement
    leaves the policy unchanged or ``max_iter`` improvement steps have run.

    Args:
        states: states of the MDP.
        actions: candidate actions.
        p_r: ``p_r[s][a]`` maps reward -> probability.
        p_s_prime: ``p_s_prime[s][a]`` maps successor state -> probability.
        gamma: discount factor.
        initial_policy: mapping state -> {action: probability}; defaults to
            the uniform random policy over ``actions``.
        threshold: convergence tolerance forwarded to ``policy_evaluation``.
        max_iter: maximum number of improvement steps.

    Returns:
        ``(policy, v)`` where ``v`` is the evaluated value of ``policy``. This
        holds even when ``max_iter`` runs out before the policy stabilises.
    """
    if initial_policy is None:
        initial_policy = {s: {a: 1 / len(actions) for a in actions} for s in states}

    policy = dict(initial_policy)
    v = policy_evaluation(states, policy, p_r, p_s_prime, gamma, threshold=threshold)
    for _ in range(max_iter):
        new_policy = policy_improvement(
            states, actions, v, p_r, p_s_prime, gamma, current_policy=policy
        )
        converged = _policies_match(policy, new_policy, states)
        policy = new_policy
        if converged:
            break  # v already evaluates this (unchanged) policy
        v = policy_evaluation(states, policy, p_r, p_s_prime, gamma, threshold=threshold)
    return policy, v

In [None]:

# Policy-iteration demo on a separate 4x4 grid with stochastic moves
# (success_prob=0.8). NOTE: these assignments overwrite the globals
# `grid_size`, `states`, `p_r`, `p_s_prime`, `actions` and `gamma` set by the
# earlier cells; re-run those cells before re-running value iteration.
grid_size = 4
terminals = [(3, 3)]           # terminal (goal) state
forbiddens = [(1, 1), (2, 2)]  # danger cells (penalty noted as -10; NOTE(review): confirm in build_models)
actions = ['up', 'down', 'left', 'right', 'still']
gamma = 0.9

states, p_r, p_s_prime = build_models(grid_size, terminals, forbiddens, success_prob=0.8)

# Solve for the optimal policy and its state values.
optimal_policy, optimal_v = policy_iteration(states, actions, p_r, p_s_prime, gamma)

# Print the optimal state values row by row.
print("\nOptimal Value Function:")
for i in range(grid_size):
    row = [optimal_v[(i, j)] for j in range(grid_size)]
    print("  ".join([f"{val:.1f}" for val in row]))

print("\nOptimal Policy:")
action_symbols = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→', 'still': '⏺'}
for i in range(grid_size):
    row = []
    for j in range(grid_size):
        s = (i, j)
        if s in terminals:
            row.append('G')      # goal
        elif s in forbiddens:
            row.append('D')      # danger
        else:
            # Show the most probable action. Unlike filtering for prob == 1.0,
            # this cannot raise IndexError when no entry is exactly 1.0
            # (e.g. a stochastic policy).
            dist = optimal_policy[s]
            best_action = max(dist, key=dist.get)
            row.append(action_symbols[best_action])
    print("  ".join(row))