In [4]:
import torch
from torch import nn
from deeplotx import MultiHeadFeedForward
from stable_baselines3.common.policies import ActorCriticPolicy


class MyActorCritic(nn.Module):
    """Latent extractor with two independent heads: one for the actor, one for the critic.

    Both heads share the same layout: a ``MultiHeadFeedForward`` block followed
    by a linear projection from ``feature_dim`` to the head's output width.
    The heads do not share parameters.

    Args:
        feature_dim: Width of the incoming feature vector.
        policy_output_dim: Width of the actor latent, exposed as ``latent_dim_pi``.
        value_output_dim: Width of the critic latent, exposed as ``latent_dim_vf``.
        device: Device on which the layers are created.
        dtype: Parameter dtype of the layers.
        num_heads: Number of heads passed to each ``MultiHeadFeedForward`` block
            (defaults to 50, the value that used to be hard-coded).
    """

    def __init__(self, feature_dim: int, policy_output_dim: int, value_output_dim: int, device: str = 'cpu', dtype: torch.dtype = torch.float32, num_heads: int = 50):
        super().__init__()
        # NOTE(review): stable-baselines3's ActorCriticPolicy is expected to read these
        # two attributes to size its action / value output layers - confirm against the
        # SB3 custom-policy guide if the attribute names ever change.
        self.latent_dim_pi = policy_output_dim
        self.latent_dim_vf = value_output_dim
        self.policy_net = self._build_head(feature_dim, policy_output_dim, num_heads, device, dtype)
        self.value_net = self._build_head(feature_dim, value_output_dim, num_heads, device, dtype)

    @staticmethod
    def _build_head(feature_dim: int, output_dim: int, num_heads: int, device: str, dtype: torch.dtype) -> nn.Sequential:
        """Build one head: a multi-head feed-forward block followed by a linear projection."""
        # Assumes MultiHeadFeedForward keeps the last dimension at `feature_dim`,
        # since the linear layer that follows consumes `feature_dim` inputs - TODO confirm.
        return nn.Sequential(
            MultiHeadFeedForward(feature_dim=feature_dim, num_heads=num_heads, device=device, dtype=dtype),
            nn.Linear(in_features=feature_dim, out_features=output_dim, device=torch.device(device), dtype=dtype),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(actor_latent, critic_latent)`` computed from the same input ``x``."""
        return self.forward_actor(x), self.forward_critic(x)

    def forward_actor(self, x: torch.Tensor) -> torch.Tensor:
        """Return the actor latent for ``x``."""
        # Call the module itself rather than `.forward` so registered hooks are run.
        return self.policy_net(x)

    def forward_critic(self, x: torch.Tensor) -> torch.Tensor:
        """Return the critic latent for ``x``."""
        # Call the module itself rather than `.forward` so registered hooks are run.
        return self.value_net(x)


class MyPolicy(ActorCriticPolicy):
    """Actor-critic policy whose latent extractor is a ``MyActorCritic``."""

    def _build_mlp_extractor(self) -> None:
        """Assign a ``MyActorCritic`` with 64-wide actor and critic latents to ``self.mlp_extractor``."""
        latent_width = 64
        self.mlp_extractor = MyActorCritic(
            feature_dim=self.features_dim,
            policy_output_dim=latent_width,
            value_output_dim=latent_width,
        )

In [5]:
import numpy as np
from gymnasium import spaces, Env


class VectorClassificationEnv(Env):
    """Binary classification over a fixed dataset, posed as an episodic environment.

    An episode walks through the feature rows in order. The observation is the
    current row, the action is the predicted label (0 or 1), and the reward is
    ``+1.0`` when the action equals the row's label and ``-1.0`` otherwise.
    The episode terminates after the last row has been answered.

    Args:
        features: Array-like of shape ``(n_samples, n_features)``. It is stored as
            ``float32`` so that observations match the declared observation space.
        labels: Sequence with at least one label per feature row.

    Raises:
        ValueError: If ``features`` is not a non-empty 2-D array, or if ``labels``
            has fewer entries than ``features`` has rows.
    """

    def __init__(self, features: np.ndarray, labels: np.ndarray):
        super().__init__()
        # Cast once up front: the observation space below is declared as float32,
        # while typical inputs (e.g. np.random.randn) are float64.
        features = np.asarray(features, dtype=np.float32)
        if features.ndim != 2 or features.shape[0] == 0:
            raise ValueError(
                f'features must be a non-empty 2-D array, got shape {features.shape}'
            )
        if len(labels) < features.shape[0]:
            raise ValueError(
                f'need one label per feature row: got {len(labels)} labels '
                f'for {features.shape[0]} rows'
            )
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(features.shape[1],),
            dtype=np.float32
        )
        self.action_space = spaces.Discrete(2)
        self.features = features
        self.labels = labels
        self.ptr = 0  # index of the row currently shown to the agent

    def step(self, action: int):
        """Score ``action`` against the current row's label and advance to the next row.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``truncated`` is
            always ``False`` and ``info`` is always empty. On termination the
            observation is an all-zero placeholder. Call ``reset()`` before
            stepping again.
        """
        true_label = self.labels[self.ptr]
        reward = 1.0 if action == true_label else -1.0
        self.ptr += 1
        terminated = self.ptr >= len(self.features)
        if terminated:
            # No row is left to show; return zeros in the space's own dtype.
            observation = np.zeros(
                self.observation_space.shape, dtype=self.observation_space.dtype
            )
        else:
            observation = self.features[self.ptr]
        return observation, reward, terminated, False, {}

    def reset(self, seed=None, options=None):
        """Rewind to the first row and return ``(observation, info)``."""
        super().reset(seed=seed)
        self.ptr = 0
        observation = self.features[self.ptr]
        return observation, {}

In [3]:
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback  

# Build the RL environments from one synthetic dataset.
# Features are cast to float32 to match the observation space declared by the env.
features = np.random.randn(1000, 128).astype(np.float32)
labels = np.random.randint(0, 2, 1000)
env = VectorClassificationEnv(features, labels)
# Evaluation gets its own env instance: sharing the training env with EvalCallback
# would reset its row pointer in the middle of a training rollout.
eval_env = VectorClassificationEnv(features, labels)

# Configure the PPO algorithm
ppo = PPO(
    policy=MyPolicy,                            # policy network class
    env=env,                                    # environment instance
    learning_rate=2e-6,                         # learning rate
    n_steps=2048,                               # time steps sampled per rollout
    batch_size=64,                              # minibatch size
    n_epochs=10,                                # training epochs over each rollout buffer
    gamma=0.99,                                 # discount factor
    gae_lambda=0.95,                            # GAE lambda
    clip_range=0.2,                             # PPO clipping range
    clip_range_vf=None,                         # value-function clipping range
    normalize_advantage=True,                   # whether to normalise advantages
    ent_coef=0.0,                               # entropy coefficient
    vf_coef=0.5,                                # value-function loss coefficient
    max_grad_norm=0.5,                          # maximum gradient norm for clipping
    use_sde=False,                              # whether to use state-dependent exploration (SDE)
    sde_sample_freq=-1,                         # SDE sampling frequency
    rollout_buffer_class=None,                  # rollout buffer class
    rollout_buffer_kwargs=None,                 # rollout buffer keyword arguments
    target_kl=None,                             # target KL divergence
    stats_window_size=100,                      # statistics window size
    tensorboard_log=None,                       # TensorBoard log path; None disables logging
    policy_kwargs=None,                         # extra policy arguments
    verbose=2,                                  # logging verbosity
    seed=None,                                  # random seed
    device="auto",                              # compute device
    _init_setup_model=True                      # whether to build the model immediately
)

# Callbacks used during training
eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/', log_path='./logs/', eval_freq=500)
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')

# Start training
ppo.learn(
    total_timesteps=50000,
    callback=[eval_callback, checkpoint_callback],
    log_interval=10,
    tb_log_name="ppo_run",
    progress_bar=True
)

NameError: name 'VectorClassificationEnv' is not defined

In [1]:
# Roll the trained policy through its vectorised env and summarise the outcome,
# instead of printing every observation for every step.
N_EVAL_STEPS = 1000

vec_env = ppo.get_env()
obs = vec_env.reset()
step_rewards = []
for _ in range(N_EVAL_STEPS):
    action, _states = ppo.predict(obs, deterministic=True)
    # A VecEnv resets a finished sub-environment by itself and returns the first
    # observation of the next episode, so no manual reset is needed on `done`.
    obs, reward, done, info = vec_env.step(action)
    # `reward` holds one entry per sub-environment.
    step_rewards.extend(reward.tolist())

total_reward = sum(step_rewards)
# The env pays +1 for a correct label and -1 otherwise, so a positive reward is a hit.
accuracy = sum(r > 0 for r in step_rewards) / len(step_rewards)
print('Steps:', len(step_rewards))
print('Total reward:', total_reward)
print('Accuracy:', accuracy)

NameError: name 'ppo' is not defined