In [1]:
import torch
from torch import nn
from deeplotx import MultiHeadFeedForward
from stable_baselines3.common.policies import ActorCriticPolicy


class MyActorCriticPolicy(nn.Module):
    def __init__(self, feature_dim: int, policy_output_dim: int, value_output_dim: int, device: str = 'cpu', dtype: torch.dtype = torch.float32):
        super().__init__()  
        self.latent_dim_pi = policy_output_dim  
        self.latent_dim_vf = value_output_dim  
        self.policy_net = nn.Sequential(  
            MultiHeadFeedForward(feature_dim=feature_dim, num_heads=3, device=device, dtype=dtype), nn.Linear(in_features=feature_dim, out_features=policy_output_dim, device=torch.device(device), dtype=dtype)
        )  
        self.value_net = nn.Sequential(  
            MultiHeadFeedForward(feature_dim=feature_dim, num_heads=3, device=device, dtype=dtype), nn.Linear(in_features=feature_dim, out_features=value_output_dim, device=torch.device(device), dtype=dtype)
        )
    
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return self.policy_net.forward(x), self.value_net.forward(x)
    
    def forward_actor(self, x: torch.Tensor):  
        return self.policy_net.forward(x)  
  
    def forward_critic(self, x: torch.Tensor):  
        return self.value_net.forward(x)


class MyRLModel(ActorCriticPolicy):
    def _build_mlp_extractor(self) -> None:  
        self.mlp_extractor = MyActorCriticPolicy(self.features_dim, 64, 64)

In [2]:
import gymnasium
from stable_baselines3 import PPO

env = gymnasium.make("CartPole-v1", render_mode="human")
ppo = PPO(MyRLModel, env, verbose=1)
ppo.learn(total_timesteps=3000, progress_bar=True)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


Output()

  from pkg_resources import resource_stream, resource_exists


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 21       |
|    ep_rew_mean     | 21       |
| time/              |          |
|    fps             | 46       |
|    iterations      | 1        |
|    time_elapsed    | 44       |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 27.8        |
|    ep_rew_mean          | 27.8        |
| time/                   |             |
|    fps                  | 45          |
|    iterations           | 2           |
|    time_elapsed         | 89          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.016178887 |
|    clip_fraction        | 0.17        |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.681      |
|    explained_variance   | 0.00107     |
|    learning_rate        | 0.

<stable_baselines3.ppo.ppo.PPO at 0x2bf9589d3a0>

In [3]:
vec_env = ppo.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = ppo.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()
    # VecEnv resets automatically
    if done:
      obs = vec_env.reset()