本代码实现了如下功能：
- 环境为CartPole-v1，详见 https://gymnasium.farama.org/environments/classic_control/cart_pole/
- 原始SARSA算法、原始SARSA算法+目标网络（EMA更新）
- 加噪网络（Noisy Network，用于探索）
- 改进的奖励计算（reward shaping）

In [None]:
import torch

import gymnasium as gym
import torch.nn as nn
import random
import torch
from tqdm.notebook import tqdm
import numpy as np
import time
from torch.utils.tensorboard import SummaryWriter
# Noisy linear layer (NoisyNet-style, independent Gaussian noise per parameter).
class NoiseLinear(nn.Linear):
    """Linear layer with learnable per-parameter Gaussian noise scales.

    In training mode the output is

        x @ (W + sigma_W * eps_W).T + (b + sigma_b * eps_b)

    with ``eps`` drawn from N(0, 1) on every forward call. In eval mode the
    layer behaves exactly like ``nn.Linear``. The noise scales start at 0, so
    the layer initially matches ``nn.Linear``; any noise magnitude is learned.
    """

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # Learnable noise scales (sigma), same shapes as weight / bias.
        self.p_noise_weight = nn.Parameter(torch.full(self.weight.size(), 0.))
        self.p_noise_bias = nn.Parameter(torch.full(self.bias.size(), 0.))

    def forward(self, x):
        out = super().forward(x)
        if not self.training:
            # Evaluation: deterministic output, no noise is sampled at all.
            return out
        weight_noise = torch.randn_like(self.p_noise_weight)
        bias_noise = torch.randn_like(self.p_noise_bias)
        # nn.functional.linear handles any leading batch dims (torch.mm is 2-D only).
        noise_out = nn.functional.linear(
            x, self.p_noise_weight * weight_noise, self.p_noise_bias * bias_noise
        )
        return out + noise_out
        
class SARSANet(nn.Module):
    """Q-network mapping a state to one Q-value per action.

    Two ``NoiseLinear`` layers with a LeakyReLU in between.
    """

    def __init__(self, input_dim=4, hidden_dim = 128, action_dim = 2):
        super().__init__()
        self.model = nn.Sequential(
            NoiseLinear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            NoiseLinear(hidden_dim, action_dim),
        )

    def forward(self, state: torch.Tensor | np.ndarray, device=None):
        """Return Q-values for ``state``.

        Args:
            state: observation tensor or numpy array. A 1-D numpy array is
                treated as a single observation and given a batch dimension.
            device: device to run on. ``None`` (default) means the device the
                network's parameters live on, so CPU-only models work too.

        Returns:
            Tensor of Q-values with last dimension ``action_dim``.
        """
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state)
            if state.dim() == 1:
                # Single observation -> batch of one.
                state = state.unsqueeze(0)
        if device is None:
            device = next(self.parameters()).device
        return self.model(state.to(device))

#定义Agent
class SARSAAgent:
    def __init__(self, device, hidden_dim = 128, gamma = 0.98,game_name='CartPole-v1'):
        self.model = SARSANet(hidden_dim=hidden_dim).to(device)
        self.target_model = SARSANet(hidden_dim=hidden_dim).to(device) # Q学习+目标网络，缓解自举问题
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.requires_grad_(False)
        
        self.env = gym.make(game_name)
        
        self.game_name = game_name
        self.device = device
        self.gamma = gamma # 回报率
        self.loss = torch.nn.MSELoss()
        
    def reset(self):
        observation, info = self.env.reset()
        observation = torch.from_numpy(observation).unsqueeze(0)
        return observation
    
    #epsilon-greedy策略函数
    @torch.no_grad()
    def take_action(self, state: torch.Tensor | np.ndarray, epsilon: float) -> torch.Tensor:
        bs = 1
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state).unsqueeze(0)
        state = state.to(self.device)
        q_value = self.model(state)
        if random.random() < epsilon:
            # 随机动作
            action = torch.randint(0,2,(bs,))
        else:
            #最优动作
            action = torch.argmax(q_value,dim=1)
        return action

    def interact(self, action):
        next_state, reward, terminated,truncated, _ = self.env.step(action.item())
        done = terminated or truncated
        next_state = torch.from_numpy(next_state).unsqueeze(0)
        reward = self.compute_reward(next_state, done)
        return next_state, reward, done
        
    @torch.no_grad()
    def test(self):
        env_test = gym.make('CartPole-v1',render_mode = 'human')
        state, _ = env_test.reset()
        done = False
        step = 0 
        input('TESTING, press enter to continue')
        while not done:
            action = self.take_action(state, epsilon=0)
            next_state, reward, terminated,truncated, _ = env_test.step(action.item())
            done = terminated or truncated
            state = next_state
            step += 1
            time.sleep(0.01)
        print(f"Running {step} steps, last_state:{state}")
        
    @torch.no_grad()
    def valid(self, ):
        self.model.eval()
        env_test = gym.make(self.game_name)
        result = 0
        for i in range(5):
            state, _ = env_test.reset()
            done = False
            step = 0 
            total_reward = 0
            while not done:
                action = self.take_action(state, epsilon=0)
                next_state, reward, terminated,truncated, _ = env_test.step(action.item())
                done = terminated or truncated
                state = next_state
                total_reward += reward
                step += 1
            result += total_reward
        self.model.train()
        return result/5
    
    #EMA更新目标网络
    @torch.no_grad()
    def update_target_model(self, decay = 0.99):
        for param, target_param in zip(self.model.parameters(), self.target_model.parameters()):
            newparam = decay * target_param.data + (1 - decay) * param.data
            target_param.data.copy_(newparam)
    
    #改进奖励算法
    def compute_reward(self, state, done):
        x, x_dot, theta, theta_dot = state[0]
        angle_penalty = theta ** 2  # 杆子角度惩罚（直立时θ≈0）
        position_penalty = 0.1 * x ** 2  # 小车位置惩罚（中心时x=0）
        if done:
            return -10.0  # 倒下强惩罚
        return 1.0 - angle_penalty - position_penalty  # 奖励函数`


In [None]:
writer = SummaryWriter()
lr = 1e-3
num_round = 10000  # number of training episodes
target_update_step = 1  # intended interval (in training steps) between target-net updates
epsilon = 1.0  # initial exploration rate
hidden_dim = 128

# Fall back to CPU when no GPU is available instead of failing on 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
agent = SARSAAgent(device, hidden_dim=hidden_dim)
optimizer = torch.optim.AdamW(agent.model.parameters(), lr=lr)
# Epsilon decay schedule
def compute_epsilon(cur_value, decay = 0.999, min_value= 0.01):
    """Return the next exploration rate.

    Multiplies ``cur_value`` by ``decay`` and clamps the result so it never
    drops below ``min_value``. The previous version could return a value just
    under the floor for one step.

    Args:
        cur_value: current epsilon.
        decay: multiplicative decay factor applied per call.
        min_value: lower bound for the returned epsilon.
    """
    return max(min_value, decay * cur_value)

In [None]:
# Train with on-policy SARSA: one gradient step per environment step.
tqdm_bar = tqdm(total=num_round)
train_step = 0
for cur_round in range(num_round):
    epsilon = compute_epsilon(epsilon)
    cur_time = 0
    cur_reward = 0.0
    losses = []
    state = agent.reset()
    done = False
    # SARSA is on-policy: the action used in the TD target must be the one that
    # is actually executed next, so pick the first action here and carry it over.
    action = agent.take_action(state, epsilon=epsilon)
    while not done:
        # Q(s_t, a_t) from the online network (with gradients).
        q_t = agent.model(state, device=agent.device)[:, action.item()]
        # Take the action.
        next_state, reward, done = agent.interact(action)
        # Sample a_{t+1} and build the TD target from the target network.
        with torch.no_grad():
            action_next = agent.take_action(next_state, epsilon=epsilon)
            q_next = agent.target_model(next_state, device=agent.device)[:, action_next.item()]
            # Do not bootstrap past the end of an episode.
            td_target = reward + agent.gamma * q_next * (0.0 if done else 1.0)

        loss = agent.loss(q_t, td_target)
        loss.backward()

        # Update the online model.
        nn.utils.clip_grad_norm_(agent.model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        train_step += 1

        # Update the target network every `target_update_step` training steps.
        if train_step % target_update_step == 0:
            agent.update_target_model()

        state, action = next_state, action_next
        # Logging
        losses.append(loss.item())
        cur_reward += float(reward)
        cur_time += 1

    # Validation
    valid_reward = agent.valid()

    mean_loss = np.mean(losses) if losses else 0
    writer.add_scalar('train/loss', mean_loss, cur_round)
    writer.add_scalar('train/reward', cur_reward, cur_round)
    writer.add_scalar('train/epsilon', epsilon, cur_round)
    writer.add_scalar('valid/reward', valid_reward, cur_round)
    tqdm_bar.set_postfix_str(f'Round {cur_round}/{num_round}, Reward {cur_reward}, Time {cur_time}, loss {mean_loss}, traing step {train_step}')
    tqdm_bar.update(1)
    if valid_reward == 500 and cur_round > 5000:
        break
tqdm_bar.close()
writer.close()

In [1]:
# Render one greedy episode with the trained agent. This needs the cells above
# (which create and train `agent`) to have run in the same kernel; in a fresh
# kernel it raises NameError, as in the recorded output below.
agent.test()

NameError: name 'agent' is not defined