In [1]:
import os
import torch as th
from stable_baselines3.common.callbacks import BaseCallback


class CustomCallBack(BaseCallback):
    """Stable-Baselines3 callback that logs per-step locals and per-update weight changes.

    Logs are appended to plain-text files under ``logs/<task_name>/``:

    * ``on_step.log`` -- one line per environment step with ``self.n_calls`` and
      the repr of ``self.locals``.  NOTE(review): this repr is written on every
      step and can grow large quickly.
    * ``compare_weights.log`` -- one line per policy parameter per update with the
      mean / max of the signed change and of the absolute change.

    The policy parameters are snapshotted at the start of every rollout. Each
    snapshot is diffed against the live parameters at the start of the next
    rollout (i.e. after the intervening policy update). It is also diffed once
    more when training ends, so the final update is logged as well.

    :param task_name: Sub-directory name under ``logs/`` for this run's log files.
    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, task_name, verbose=0):
        super().__init__(verbose)
        self.log_dir = f"logs/{task_name}"
        os.makedirs(self.log_dir, exist_ok=True)
        # Detached copies of the policy parameters (name -> tensor) taken at the
        # start of the current rollout; None until the first rollout starts.
        self.before_update_param = None
        # Number of finished rollouts; used to label the update being compared.
        self.update_counter = 0

    def _on_rollout_start(self) -> None:
        """Log the change made by the previous update, then snapshot the current weights."""
        if self.before_update_param is not None:
            self.compare_weights()

        self.before_update_param = self._snapshot_params()

    def _on_step(self) -> bool:
        """Log the call counter and the training loop's locals; never stops training."""
        self._write_log("on_step", f"{self.n_calls=}, {self.locals=}")
        return True

    # The policy is updated after _on_rollout_end returns, so presumably no new
    # information becomes available at this point; only the counter is advanced.
    def _on_rollout_end(self) -> None:
        self.update_counter += 1

    def _on_training_end(self) -> None:
        """Log the final update, which no later _on_rollout_start would cover."""
        if self.before_update_param is not None:
            self.compare_weights()
            # Drop the snapshot so a later learn() call reusing this callback
            # does not diff against weights from this run.
            self.before_update_param = None

    def compare_weights(self):
        """Append per-parameter change statistics since the last snapshot to compare_weights.log.

        For every named policy parameter, this logs the mean and max of the
        signed difference (after - before). It also logs the mean and max of the
        absolute difference, because signed values can cancel out and hide
        large updates.
        """
        lines = []
        # Pure bookkeeping: no autograd graph is needed for the differences.
        with th.no_grad():
            for name, param in self.model.policy.named_parameters():
                diff = param - self.before_update_param[name]
                abs_diff = diff.abs()
                mean_diff = th.mean(diff).item()
                max_diff = th.max(diff).item()
                mean_abs_diff = th.mean(abs_diff).item()
                max_abs_diff = th.max(abs_diff).item()
                lines.append(
                    f"Update #{self.update_counter} Layer {name} | {mean_diff=} | {max_diff=}"
                    f" | {mean_abs_diff=} | {max_abs_diff=}"
                )
        # One file open per comparison instead of one per parameter.
        if lines:
            self._write_log("compare_weights", "\n".join(lines))

    def _snapshot_params(self):
        """Return a detached copy of every policy parameter, keyed by parameter name."""
        return {
            name: param.detach().clone()
            for name, param in self.model.policy.named_parameters()
        }

    def _write_log(self, file_name, message):
        """Append ``message`` plus a newline to ``<log_dir>/<file_name>.log``."""
        # Explicit UTF-8 so that writing reprs does not depend on the OS locale encoding.
        with open(f"{self.log_dir}/{file_name}.log", "a", encoding="utf-8") as f:
            f.write(message + "\n")

In [2]:
import gymnasium as gym
from stable_baselines3 import PPO

In [3]:
# Train PPO on LunarLander with the logging callback.
# NOTE: render_mode="human" opens a window and renders every step, which likely
# explains the low fps in the output; set RENDER_MODE = None to train faster.
ENV_ID = "LunarLander-v2"
RENDER_MODE = "human"
SEED = 0                 # fixed seed so reruns are reproducible
N_STEPS = 128            # env steps collected per environment per rollout
N_EPOCHS = 5             # optimisation passes over each rollout buffer
BATCH_SIZE = 16          # minibatch size for each gradient step
TOTAL_TIMESTEPS = 1024   # 8 rollouts of N_STEPS with a single environment

env = gym.make(ENV_ID, render_mode=RENDER_MODE)
try:
    model = PPO(
        "MlpPolicy",
        env,
        n_steps=N_STEPS,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        seed=SEED,
        verbose=1,
    )
    model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=CustomCallBack("luna_lander", verbose=1))
finally:
    # Always release the render window / env resources, even if training is
    # interrupted (e.g. KeyboardInterrupt from the notebook).
    env.close()

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 77       |
|    ep_rew_mean     | -70      |
| time/              |          |
|    fps             | 37       |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 128      |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 92.5         |
|    ep_rew_mean          | -167         |
| time/                   |              |
|    fps                  | 41           |
|    iterations           | 2            |
|    time_elapsed         | 6            |
|    total_timesteps      | 256          |
| train/                  |              |
|    approx_kl            | 0.0004810877 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    en

: 