本代码实现了如下功能：
- 环境为CartPole-v1，详见https://gymnasium.farama.org/environments/classic_control/cart_pole/
- 带基线的REINFORCE算法：价值头作为基线；目标网络（EMA更新）的代码已预留，但目前在构造函数中被注释掉，尚未启用
- 策略梯度中加入熵正则项
- 改进的奖励计算：对杆子角度和小车偏离中心进行惩罚，倒下时给予强惩罚

In [1]:
import torch

import gymnasium as gym
import torch.nn as nn
import random
import torch
from tqdm.notebook import tqdm
import numpy as np
import time
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

class Policy_ValueModel(nn.Module):
    """Actor-critic network with a shared trunk and two heads.

    A single Linear+LeakyReLU trunk feeds:
      * ``valuehead``  -> one scalar state-value estimate per input row;
      * ``actionhead`` -> per-action scores, turned into a probability
        distribution with a softmax over the last axis.

    Args:
        input_dim: length of the observation vector (default 4).
        hidden_dim: width of the trunk and of each head's hidden layer.
        action_dim: number of discrete actions.
    """

    def __init__(self, input_dim=4, hidden_dim=128, action_dim=2):
        super().__init__()
        # Submodules are created trunk -> value head -> action head so that
        # parameter initialisation under a fixed seed and the state_dict keys
        # stay stable.
        self.hiddennet = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.LeakyReLU())
        self.valuehead = self._head(hidden_dim, 1)
        self.actionhead = self._head(hidden_dim, action_dim)

    @staticmethod
    def _head(hidden_dim, out_dim):
        # Two-layer MLP head: hidden -> hidden -> out_dim.
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, state):
        """Return ``(value, action_prob)`` for a batch of states.

        ``value`` has a trailing axis of size 1; ``action_prob`` sums to 1
        along its last axis.
        """
        features = self.hiddennet(state)
        value = self.valuehead(features)
        scores = self.actionhead(features)
        return value, torch.softmax(scores, dim=-1)


#定义Agent
class REINFORCEAgent:
    def __init__(self, device='cuda', hidden_dim = 128, gamma = 0.98,game_name='CartPole-v1'):
        self.model = Policy_ValueModel(hidden_dim=hidden_dim).to(device)
        # self.target_model = Policy_ValueModel(hidden_dim=hidden_dim).to(device) #目标网络，缓解自举问题
        # self.target_model.load_state_dict(self.model.state_dict())
        # self.target_model.requires_grad_(False)
        
        self.env = gym.make(game_name)
        
        self.game_name = game_name
        self.device = device
        self.gamma = gamma # 回报率
        self.loss = torch.nn.MSELoss()
        
    def reset(self):
        observation, info = self.env.reset()
        observation = torch.from_numpy(observation).unsqueeze(0)
        return observation
    
    #策略网络
    @torch.no_grad()
    def take_action(self, state: torch.Tensor | np.ndarray, epsilon = 0) -> torch.Tensor:
        bs = 1
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state).unsqueeze(0)
        state = state.to(self.device)
        _, action_prob = self.model(state)
        if random.random() < epsilon:
            # 随机动作
            action = torch.randint(0,2,(bs,))
        else:
            #最优动作
            action = torch.argmax(action_prob,dim=1)
            action = action.cpu()
        return action

    def interact(self, action):
        next_state, reward, terminated,truncated, _ = self.env.step(action.item())
        done = terminated or truncated
        next_state = torch.from_numpy(next_state).unsqueeze(0)
        # if done:
        #     reward = -1
        reward = self.compute_reward(next_state, done)
        return next_state, reward, done
        
    @torch.no_grad()
    def test(self):
        env_test = gym.make('CartPole-v1',render_mode = 'human')
        state, _ = env_test.reset()
        done = False
        step = 0 
        input('TESTING, press enter to continue')
        while not done:
            action = self.take_action(state)
            next_state, reward, terminated,truncated, _ = env_test.step(action.item())
            done = terminated or truncated
            state = next_state
            step += 1
            time.sleep(0.01)
        print(f"Running {step} steps, last_state:{state}")
        
    @torch.no_grad()
    def valid(self, ):
        self.model.eval()
        env_test = gym.make(self.game_name)
        result = 0
        for i in range(5):
            state, _ = env_test.reset()
            done = False
            step = 0 
            total_reward = 0
            while not done:
                action = self.take_action(state)
                next_state, reward, terminated,truncated, _ = env_test.step(action.item())
                done = terminated or truncated
                state = next_state
                total_reward += reward
                step += 1
            result += total_reward
        self.model.train()
        return result/5
    
    #EMA更新目标网络
    @torch.no_grad()
    def update_target_model(self, decay = 0.99):
        for param, target_param in zip(self.model.parameters(), self.target_model.parameters()):
            newparam = decay * target_param.data + (1 - decay) * param.data
            target_param.data.copy_(newparam)
    
    #改进奖励算法
    def compute_reward(self, state, done):
        x, x_dot, theta, theta_dot = state[0]
        angle_penalty = 20 * theta ** 2  # 杆子角度惩罚（直立时θ≈0）
        position_penalty = 0.2 * x ** 2  # 小车位置惩罚（中心时x=0）
        # v_penality = abs(x_dot) + abs(theta_dot)
        if done:
            return -10  # 倒下强惩罚
        return 1.0 - angle_penalty - position_penalty #- v_penality  # 奖励函数`

    #计算汇报
    def compute_u(self, rewards):
        us = []
        reversed_r = reversed(rewards)
        for idx, r in enumerate(reversed_r):
            if idx == 0:
                us.append(r)#u_n = r_n
            else:
                us.append(r + self.gamma * us[-1])
        return list(reversed(us))
    

In [2]:
writer = SummaryWriter()  # TensorBoard writer for training/validation curves
lr = 1e-3
num_round = 10000  # maximum number of training episodes (one episode per round)
epsilon = 1.0  # initial exploration rate, decayed once per round
hidden_dim = 64
lambda_reg = 1e-2  # entropy-regularisation weight
# Fall back to the CPU so the notebook also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
agent = REINFORCEAgent(device, hidden_dim=hidden_dim)
optimizer = torch.optim.AdamW(agent.model.parameters(), lr=lr)
# Epsilon decay schedule
def compute_epsilon(cur_value, decay=0.999, min_value=0.01):
    """Return the next epsilon of a multiplicative decay schedule.

    Args:
        cur_value: current epsilon.
        decay: multiplicative factor applied on each call.
        min_value: floor; the returned epsilon never drops below it.

    Returns:
        ``max(min_value, decay * cur_value)``.
    """
    # Clamp after decaying: checking the bound before multiplying would let
    # the value dip just below ``min_value`` for one step.
    return max(min_value, decay * cur_value)

In [3]:
# Training loop: one full episode per round, then a single REINFORCE-with-baseline
# update on that episode, a greedy validation run, and TensorBoard logging.
tqdm_bar = tqdm(total=num_round)
train_step = 0
try:
    for cur_round in range(num_round):
        epsilon = compute_epsilon(epsilon)
        state = agent.reset()
        done = False
        trajectories = []

        # Play one episode, recording (state, action, shaped reward) per step.
        while not done:
            action = agent.take_action(state, epsilon)
            next_state, reward, done = agent.interact(action)
            trajectories.append((state, action, reward))
            state = next_state

        # Discounted returns-to-go. float() accepts both Python numbers and
        # 0-dim tensors, whichever compute_reward happens to return.
        rewards = [t[2] for t in trajectories]
        us = torch.tensor([float(u) for u in agent.compute_u(rewards)],
                          dtype=torch.float32, device=device)

        # Value head (baseline) regresses onto the observed returns.
        states = torch.cat([t[0] for t in trajectories]).to(device)
        value_predict, action_probs = agent.model(states)
        value_predict = value_predict.squeeze(-1)
        diff = value_predict - us  # baseline minus return, i.e. the negative advantage
        loss_value = F.mse_loss(value_predict, us)

        # Policy gradient. The model already returns softmax probabilities, so
        # no second softmax here: applying one again would flatten the
        # distribution and distort both the log-probs and their gradients.
        gamma_loss = torch.pow(agent.gamma, torch.arange(0, states.shape[0], device=device))
        actions = torch.cat([t[1] for t in trajectories]).to(device)
        # squeeze(1), not squeeze(): keeps a 1-D tensor for length-1 episodes.
        taken_probs = action_probs.gather(1, actions[:, None]).squeeze(1)
        # Clamp before the log so a saturated softmax cannot yield log(0) = -inf.
        log_action_probs = taken_probs.clamp_min(1e-8).log()
        # Minimising diff * log pi raises the log-likelihood of better-than-baseline
        # actions. lambda_reg * (1 + log pi) is the score-function gradient weight of
        # sum(pi * log pi) (negative entropy), so minimising it raises entropy.
        # NOTE(review): actions come from epsilon-greedy argmax in take_action, not
        # from sampling pi, so this is only an approximation of on-policy REINFORCE.
        loss_policy = gamma_loss * (diff + lambda_reg * (1 + log_action_probs)).detach() * log_action_probs
        loss_policy = loss_policy.mean()

        # Joint update of the value and policy heads.
        loss = loss_value + loss_policy
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(agent.model.parameters(), max_norm=1.0)
        optimizer.step()
        train_step += 1

        # Logging and validation.
        cur_reward = float(sum(rewards))
        valid_reward = agent.valid()

        writer.add_scalar('train/value_loss', loss_value.item(), cur_round)
        writer.add_scalar('train/policy_loss', loss_policy.item(), cur_round)
        writer.add_scalar('train/epsilon', epsilon, cur_round)
        writer.add_scalar('valid/reward', valid_reward, cur_round)
        tqdm_bar.set_postfix_str(f'Round {cur_round}/{num_round}, Reward {cur_reward}, Time {len(rewards)}, loss {loss.item()}, traing step {train_step}')
        tqdm_bar.update(1)
        # Stop once greedy validation averages the CartPole-v1 cap of 500 and
        # exploration has decayed below 0.5.
        if valid_reward == 500 and epsilon < 0.5:
            break
finally:
    # Flush logs and close the progress bar even if training is interrupted.
    tqdm_bar.close()
    writer.close()

  0%|          | 0/10000 [00:00<?, ?it/s]

In [4]:
# Watch the trained agent: waits for Enter, then plays one greedy (epsilon = 0)
# episode in a human-rendered window and prints the number of steps survived.
agent.test()

Running 500 steps, last_state:[-0.01002345 -0.01475086 -0.00035427  0.01736287]
