In [None]:
from __future__ import annotations

import abc
import os
import copy
import math
import time
import itertools
import functools
import random
import warnings
import pathlib
from typing import Iterable, Callable, Any, NamedTuple

import numpy as np
import pandas as pd
import torch
import torch.utils.data
import matplotlib.pyplot as plt
import gymnasium as gym
import lightning
import lightning.pytorch
import lightning.pytorch.loggers
import lightning.pytorch.callbacks
import gymnasium as gym
import lovely_tensors
import lovely_numpy
import wandb
import wandb.wandb_run

from tqdm.auto import tqdm


def lovely(x):
    """
    Return a compact text summary of the important properties of an array or tensor.

    NumPy arrays are summarized with ``lovely_numpy`` and torch tensors with
    ``lovely_tensors``. Any other object triggers a warning and falls back to ``str(x)``.
    """
    # Each lambda defers the attribute lookup until its branch is actually taken.
    summarizers = (
        (np.ndarray, lambda value: lovely_numpy.lovely(value)),
        (torch.Tensor, lambda value: lovely_tensors.lovely(value)),
    )
    for kind, summarize in summarizers:
        if isinstance(x, kind):
            return summarize(x)
    warnings.warn(f"lovely: unknown type {type(x)}")
    return str(x)

# Replay buffer

In [None]:
# Copied and adapted from previous assignment
class Transition(NamedTuple):
    state: torch.Tensor | np.ndarray
    action: torch.Tensor | np.ndarray
    reward: torch.Tensor | np.ndarray
    next_state: torch.Tensor | np.ndarray
    action_log_prob: torch.Tensor | np.ndarray
    done: torch.Tensor | np.ndarray | bool

In [None]:
# Copied and adapted from previous assignment
class ReplayBuffer(torch.utils.data.IterableDataset):
    """
    Fixed-capacity transition store that doubles as an endless IterableDataset.

    ``add`` inserts a transition. Once ``capacity`` is reached, the oldest stored
    transition is overwritten (ring buffer, O(1) per insert). Iterating yields transitions
    sampled uniformly with replacement, with every field converted to a torch tensor
    (float32, except ``done`` which is bool), ready for batching by a DataLoader.

    Sampling reads the live ``buffer`` list. The consuming DataLoader must therefore run
    in the process that calls ``add`` (``num_workers=0``); worker processes only see a
    frozen copy of the buffer.

    Args:
        capacity: maximum number of stored transitions (must be positive).
        seed: seed for the sampling RNG (``None`` for nondeterministic sampling).

    Raises:
        ValueError: if ``capacity`` is not positive.
    """

    def __init__(self, capacity: int, seed = None):
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.buffer: list[Transition] = []
        self.seed = seed
        self.rng = random.Random(seed)
        # Slot holding the oldest transition, i.e. the one overwritten next once full.
        self._next_slot = 0

    def __next__(self) -> "Transition":
        # random.choice raises IndexError while the buffer is still empty.
        transition = self.rng.choice(self.buffer)
        # _replace keeps the stored NamedTuple type while swapping in tensor fields.
        return transition._replace(
            state=torch.as_tensor(transition.state, dtype=torch.float32),
            # Actions are continuous vectors; an integer dtype would truncate them.
            action=torch.as_tensor(transition.action, dtype=torch.float32),
            reward=torch.as_tensor(transition.reward, dtype=torch.float32),
            next_state=torch.as_tensor(transition.next_state, dtype=torch.float32),
            action_log_prob=torch.as_tensor(transition.action_log_prob, dtype=torch.float32),
            # bool() accepts plain Python bools as well as NumPy / torch scalars.
            done=torch.tensor(bool(transition.done), dtype=torch.bool),
        )

    def __iter__(self):
        return self

    def add(self, t: "Transition"):
        """Store ``t``; when full, overwrite the oldest stored transition."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(t)
        else:
            self.buffer[self._next_slot] = t
            self._next_slot = (self._next_slot + 1) % self.capacity

# Environment prep

In [None]:
class ResetArgsWrapper(gym.Wrapper):
    """
    Environment wrapper that supplies default keyword arguments to every ``reset`` call.

    Keyword arguments passed explicitly to ``reset`` override the stored defaults.
    """

    def __init__(self, env: gym.Env, reset_kwargs: dict[str, Any]):
        super().__init__(env)
        self.reset_kwargs = reset_kwargs

    def reset(self, *args, **kwargs):
        merged_kwargs = dict(self.reset_kwargs)
        merged_kwargs.update(kwargs)
        return self.env.reset(*args, **merged_kwargs)

    def __repr__(self):
        return f"ResetArgsWrapper({self.env}, {self.reset_kwargs})"


def make_bipedal_walker_env(hardcore=False, seed: int | None = None, render_mode: str = "rgb_array", **kwargs):
    """
    Create a ``BipedalWalker-v3`` environment.

    Args:
        hardcore: use the hardcore variant of the environment.
        seed: if given, seeds the action and observation spaces and passes
            ``seed=seed`` to every ``reset`` call. Every episode of this env then starts
            from the same seeded state.
        render_mode: forwarded to ``gym.make``.
        **kwargs: further keyword arguments for ``gym.make``.

    Returns:
        The (possibly wrapped) environment.
    """
    env: gym.Env = gym.make(
        'BipedalWalker-v3',
        render_mode=render_mode,
        hardcore=hardcore,
        **kwargs
    )

    if seed is not None:
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        env = ResetArgsWrapper(env, {"seed": seed})

    return env

In [None]:
# Render the first frame for several seeds. The top row repeats one seed (the frames
# should be identical); the bottom row uses different seeds.
seeds = [[0, 0, 0], [1, 2, 3]]

for is_hardcore in [False, True]:
    num_rows, num_cols = len(seeds), max(len(row_seeds) for row_seeds in seeds)
    fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
    for row, row_seeds in enumerate(seeds):
        for col, seed in enumerate(row_seeds):
            preview_env = make_bipedal_walker_env(seed=seed, hardcore=is_hardcore)
            try:
                preview_env.reset()
                img = preview_env.render()
            finally:
                preview_env.close()
            ax = axes[row][col]
            ax.imshow(img)
            ax.set_title(f"seed = {seed}")
            ax.axis('off')
        # Hide grid cells that have no seed when rows have different lengths.
        for col in range(len(row_seeds), num_cols):
            axes[row][col].axis('off')
    # Add the suptitle before tight_layout so the layout reserves space for it.
    fig.suptitle(f"hardcore = {is_hardcore}")
    fig.tight_layout()
    plt.show()

# Architecture

In [None]:
class Residual(torch.nn.Module):
    """Skip connection that returns ``x + around(x)``; ``around``'s output must be addable to ``x``."""

    def __init__(self, around: torch.nn.Module) -> None:
        super().__init__()
        self.around = around

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.around(x)


class Block(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        norm_layer: torch.nn.Module,
        hidden_activation: torch.nn.Module | None,
        output_activation: torch.nn.Module | None,
        num_layers: int = 2,
        dropout: float = 0.1,
        is_residual: bool = True,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.is_residual = is_residual

        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        layers = []

        for i in range(num_layers):
            layers.append(torch.nn.Linear(dims[i], dims[i + 1]))

            if i < num_layers - 1:
                if hidden_activation is not None:
                    layers.append(copy.deepcopy(hidden_activation))

            else:
                if norm_layer is not None:
                    layers.append(copy.deepcopy(norm_layer))
                if output_activation is not None:
                    layers.append(copy.deepcopy(output_activation))
                if dropout > 0:
                    layers.append(torch.nn.Dropout(dropout))

        self.nn = torch.nn.Sequential(*layers)

        if self.is_residual:
            if self.input_dim != self.output_dim:
                raise ValueError("Residual block input and output dimensions must match")
            self.nn = Residual(self.nn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.nn(x)


class ResFeedForward(torch.nn.Module):
    """
    MLP backbone made of ``Block``s, mapping ``input_dim`` features to ``output_dim``.

    The stack contains ``num_blocks + 1`` blocks:
    - ``input_dim -> hidden_dim``;
    - ``num_blocks - 1`` blocks ``hidden_dim -> hidden_dim``;
    - a final ``hidden_dim -> output_dim`` block.

    Each block uses ``squeeze_dim`` as its inner width and BatchNorm on its output. Blocks
    whose input and output widths match are residual when ``use_skip`` is set. Every block
    except the final one applies ``activation`` to its output.

    Args:
        input_dim: number of input features.
        hidden_dim: width between blocks.
        num_blocks: number of ``hidden_dim``-wide stages (see above).
        activation: activation module, deep-copied into each block.
        output_dim: number of output features (defaults to ``hidden_dim``).
        squeeze_dim: inner width of each block (defaults to ``hidden_dim``).
        use_skip: enable residual connections where widths allow.
        dropout: dropout probability at the end of each block.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_blocks: int,
        activation: torch.nn.Module,
        output_dim: int | None = None,
        squeeze_dim: int | None = None,
        use_skip: bool = True,
        dropout: float = 0.1,
    ) -> None:
        if squeeze_dim is None:
            squeeze_dim = hidden_dim
        if output_dim is None:
            output_dim = hidden_dim

        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.squeeze_dim = squeeze_dim
        self.output_dim = output_dim
        self.num_blocks = num_blocks
        self.dropout = dropout

        dims = [input_dim] + [hidden_dim] * num_blocks + [output_dim]
        # dims yields len(dims) - 1 == num_blocks + 1 blocks; only the last one
        # (index num_blocks) skips the output activation.
        last_block_index = len(dims) - 2
        blocks = []

        for i, (block_in_dim, block_out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            block_activation = activation if i < last_block_index else None
            block = Block(
                input_dim=block_in_dim,
                hidden_dim=squeeze_dim,
                output_dim=block_out_dim,
                norm_layer=torch.nn.BatchNorm1d(block_out_dim),
                hidden_activation=activation,
                output_activation=block_activation,
                dropout=dropout,
                is_residual=use_skip and block_in_dim == block_out_dim,
            )
            blocks.append(block)

        self.nn = torch.nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.nn(x)
    

class ContinuousActorNet(torch.nn.Module):
    """
    Gaussian policy for continuous actions.

    ``forward`` returns unconstrained ``(mu_logits, sigma_logits)`` from two linear heads
    on top of ``backbone``. ``distribution`` turns them into a diagonal multivariate normal
    with mean ``tanh(mu_logits)`` and per-dimension variance
    ``softplus(sigma_logits) + epsilon``.

    Args:
        backbone: feature extractor applied to the state.
        output_dim: action dimensionality.
        epsilon: lower bound added to the variance, keeping the distribution non-degenerate.
    """

    def __init__(self, backbone: torch.nn.Module, output_dim: int, epsilon: float = 0.001) -> None:
        super().__init__()
        self.backbone = backbone
        self.epsilon = epsilon
        self.head_mu = torch.nn.LazyLinear(output_dim)
        self.head_sigma = torch.nn.LazyLinear(output_dim)

    def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the raw distribution parameters ``(mu_logits, sigma_logits)`` for ``state``."""
        state = self.backbone(state)
        mu_logits = self.head_mu(state)
        sigma_logits = self.head_sigma(state)
        return mu_logits, sigma_logits

    def distribution(self, mu_logits: torch.Tensor, sigma_logits: torch.Tensor) -> torch.distributions.Distribution:
        """Build the action distribution from the raw parameters returned by ``forward``."""
        mean = torch.tanh(mu_logits)
        variance = torch.nn.functional.softplus(sigma_logits) + self.epsilon
        # A diagonal scale_tril of sqrt(variance) gives covariance diag(variance). This avoids
        # the Cholesky factorization that passing covariance_matrix would perform.
        return torch.distributions.MultivariateNormal(mean, scale_tril=torch.diag_embed(variance.sqrt()))

    def act(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample actions for ``state``; returns ``(action, log_prob)``."""
        mu_logits, sigma_logits = self.forward(state)
        return self.sample(mu_logits, sigma_logits)

    def sample(self, mu_logits: torch.Tensor, sigma_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Draw a reparameterized sample and return it with its log-probability."""
        distribution = self.distribution(mu_logits, sigma_logits)
        action = distribution.rsample()
        log_prob = distribution.log_prob(action)
        return action, log_prob


class CriticNet(torch.nn.Module):
    """State-value network: backbone features followed by a single linear output unit (shape ``(..., 1)``)."""

    def __init__(self, backbone: torch.nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        self.head_state_value = torch.nn.LazyLinear(1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        features = self.backbone(state)
        return self.head_state_value(features)

    
class Agent(abc.ABC):
    def select_action(self, observations: torch.Tensor, collecting_data: bool) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns actions and log_probs for a batch of observations
        """


# Measuring score

In [None]:
# Copied and adapted from previous assignment
def measure_scores_parallel(
    agent: Agent,
    make_env: Callable,
    make_env_kwargs: list[dict[str, Any]] | dict[str, Any],
    num_episodes: int | None = None,
    limit_steps: int | None = None,
    return_unfinished: bool = False,
    video_folder: pathlib.Path | str | None = None,
    select_action_kwargs: dict[str, Any] | None = None,
) -> Iterable[float]:
    """
    Play one episode per environment in parallel and yield each episode's total reward.

    One environment is built per entry of ``make_env_kwargs`` (as ``make_env(**kwargs)``)
    inside an ``AsyncVectorEnv``. Scores are yielded in the order the episodes finish,
    not in the order of ``make_env_kwargs``. The vector env is closed even if the caller
    stops iterating early.

    Args:
        agent: policy. ``agent.select_action(states, **select_action_kwargs)`` must return
            ``(actions, log_probs)`` for the whole batch.
        make_env: environment factory.
        make_env_kwargs: one kwargs dict per episode, or a single dict reused for
            ``num_episodes`` episodes.
        num_episodes: number of episodes. Inferred from ``make_env_kwargs`` when it is a
            sequence.
        limit_steps: maximum number of vector steps. Defaults to the env spec's
            ``max_episode_steps``; unbounded if that is unset too.
        return_unfinished: after ``limit_steps``, also yield the partial scores of
            episodes that did not finish.
        video_folder: if set, env ``i`` records videos into ``video_folder / str(i)``.
        select_action_kwargs: extra keyword arguments for ``agent.select_action``.

    Raises:
        ValueError: (when iteration starts) if ``num_episodes`` is missing for a dict
            ``make_env_kwargs`` or does not match the number of kwargs dicts.
    """
    if video_folder is not None:
        video_folder = pathlib.Path(video_folder)
        video_folder.mkdir(parents=True, exist_ok=True)
    if select_action_kwargs is None:
        select_action_kwargs = {}
    if isinstance(make_env_kwargs, dict):
        if num_episodes is None:
            raise ValueError("num_episodes must be set when make_env_kwargs is a dict")
        make_env_kwargs = [make_env_kwargs] * num_episodes
    elif num_episodes is None:
        num_episodes = len(make_env_kwargs)
    if len(make_env_kwargs) != num_episodes:
        raise ValueError("num_episodes does not match length of make_env_kwargs")
    if num_episodes == 0:
        return

    def create_env(i, kwargs):
        env = make_env(**kwargs)
        if video_folder is not None:
            env = gym.wrappers.RecordVideo(env, video_folder/str(i), disable_logger=True)
        return env

    env = gym.vector.AsyncVectorEnv([
        # lambda: create_env(i, kwargs) would not work because of late binding
        functools.partial(create_env, i, kwargs)
        for i, kwargs in enumerate(make_env_kwargs)
    ])

    # try/finally: the worker processes are shut down even when the consumer stops
    # iterating early (GeneratorExit at a yield) or an exception is raised.
    try:
        if limit_steps is None:
            limit_steps = env.get_attr("spec")[0].max_episode_steps
        if limit_steps is None:
            warnings.warn("No limit_steps provided and env does not have max_episode_steps set. This might lead to infinite loops.")

        state, _ = env.reset()
        score = np.zeros(num_episodes)
        done_in_past = np.zeros(num_episodes, dtype=bool)
        done_at_the_moment = np.zeros(num_episodes, dtype=bool)

        steps = itertools.count() if limit_steps is None else range(limit_steps)
        for _ in tqdm(steps, "Measuring score - step"):
            # Envs that finished on the previous step were auto-reset; restart their tally.
            score[done_at_the_moment] = 0
            action, log_prob = agent.select_action(state, **select_action_kwargs)
            state, reward, terminated, truncated, _ = env.step(action)
            done_at_the_moment = terminated | truncated
            score += reward
            # Report each env's score only for its first episode. This differs from
            # done_at_the_moment because auto-reset envs keep running new episodes.
            done = done_in_past | done_at_the_moment
            done_for_first_time = done_in_past != done
            done_in_past = done

            for idx in done_for_first_time.nonzero()[0]:
                yield score[idx].item()

            if done_in_past.all():
                break
    finally:
        env.close()

    if return_unfinished:
        yield from score[~done_in_past].tolist()

In [None]:
# Copied from previous assignment
def cvar(scores: np.ndarray | pd.Series, *, percent: float) -> float:
    """
    Computes conditional value at risk (CVaR) of a given set of scores.
    CVaR is the expected value of the worst percents of the scores.
    """
    assert 0 < percent < 100
    if not isinstance(scores, pd.Series):
        scores = pd.Series(scores)
    quantile = scores.quantile(percent / 100)
    return scores[scores <= quantile].mean()
    
cvar(np.arange(100), percent=25)

# PPO

In [None]:
from lightning.pytorch.utilities.types import STEP_OUTPUT

class PPOCriticLossInfo(NamedTuple):
    """Output of the PPO critic loss."""

    td_loss: torch.Tensor  # scalar regression loss of V(s) towards the TD target
    td_advantage: torch.Tensor  # per-sample (TD target - V(s)), detached, used as the advantage

class PPOActorLossInfo(NamedTuple):
    """Output of the PPO clipped-surrogate actor loss."""

    ppo_loss: torch.Tensor  # scalar: negative mean of the clipped surrogate objective
    entropy: torch.Tensor  # scalar: mean entropy of the current policy distribution
    was_clipped: torch.Tensor  # per-sample bool: the clipped term was selected by the minimum
    importance_ratio: torch.Tensor  # per-sample pi_new(a|s) / pi_old(a|s), unclipped
    

class PPO(lightning.LightningModule, Agent):
    """
    Proximal Policy Optimization with a Gaussian actor and a state-value critic.

    The actor and the critic each own a separate ``ResFeedForward(**backbone_config)``
    backbone. Training batches are ``Transition``s sampled from a replay buffer, and the
    advantage is the critic's one-step TD error. Validation batches are lists of
    ``make_env`` keyword dicts: each list is played out in parallel, and the episode
    scores are logged.

    Args:
        backbone_config: keyword arguments for ``ResFeedForward``.
        output_dim: dimensionality of the action vector.
        optimizer_class: name of an optimizer class in ``torch.optim`` (e.g. ``"AdamW"``).
        optimizer_config: keyword arguments for that optimizer.
        make_env: environment factory used during validation.
        video_root_folder: if given, validation episodes are recorded below this folder.
        importance_sampling_clip_range: ``(min, max)`` range the importance ratio is clipped to.
        discount: discount factor of the TD target.
        entropy_coef: weight of the policy-entropy bonus subtracted from the loss.
            ``0.0`` (the default) disables it.
    """

    def __init__(
        self,
        backbone_config: dict[str, Any],
        output_dim: int,
        optimizer_class: str,
        optimizer_config: dict[str, Any],
        make_env: Callable | None = None,
        video_root_folder: pathlib.Path | str | None = None,
        importance_sampling_clip_range: tuple[float, float] = (0.8, 1.2),
        discount: float = 0.99,
        entropy_coef: float = 0.0,
    ) -> None:
        super().__init__()
        # make_env and video_root_folder are excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=["make_env", "video_root_folder"])
        # Two separate backbone instances: the actor and the critic share no weights.
        self.actor = ContinuousActorNet(ResFeedForward(**backbone_config), output_dim)
        self.critic = CriticNet(ResFeedForward(**backbone_config))
        self.critic_loss = torch.nn.MSELoss()
        self.make_env = make_env
        self.validation_step_scores = []
        self.validation_start_time = None
        if video_root_folder is None:
            self.video_root_folder = None
        else:
            self.video_root_folder = pathlib.Path(video_root_folder)
            self.video_root_folder.mkdir(parents=True, exist_ok=True)

    def configure_optimizers(self):
        optimizer_class = getattr(torch.optim, self.hparams.optimizer_class)
        optimizer = optimizer_class(self.parameters(), **self.hparams.optimizer_config)
        return optimizer

    def forward(self, observations: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return self.actor(observations)

    def select_action(self, observation: torch.Tensor | np.ndarray, collecting_data: bool = False) -> tuple[np.ndarray, np.ndarray]:
        """
        Sample actions and their log-probabilities for a batch of observations.

        Sampling runs without autograd and with the actor in eval mode. Acting therefore
        neither builds a graph nor updates BatchNorm running statistics from the small
        batch of environment states (BatchNorm in train mode also fails on a batch of
        one). The actor's previous train/eval mode is restored afterwards.

        Returns:
            ``(actions, log_probs)`` as float32 NumPy arrays.
        """
        if isinstance(observation, np.ndarray):
            param = next(self.actor.parameters())
            observation = torch.from_numpy(observation).to(param)
        was_training = self.actor.training
        self.actor.eval()
        try:
            with torch.no_grad():
                action, log_prob = self.actor.act(observation)
        finally:
            self.actor.train(was_training)
        action = action.detach().cpu().float().numpy()
        log_prob = log_prob.detach().cpu().float().numpy()
        return action, log_prob

    def _critic_loss(self, batch: Transition) -> PPOCriticLossInfo:
        """One-step TD regression of V(s) towards r + discount * V(s') (V(s') zeroed where done)."""
        # The bootstrap value is detached: the critic is regressed towards it, not through it.
        next_state_value_estimate = self.critic(batch.next_state).detach().flatten()
        curr_state_value_estimate = self.critic(batch.state).flatten()
        td_target = batch.reward + self.hparams.discount * next_state_value_estimate * (~batch.done)
        td_loss = self.critic_loss(curr_state_value_estimate, td_target)
        td_advantage = (td_target - curr_state_value_estimate).detach()
        return PPOCriticLossInfo(td_loss, td_advantage)

    def _actor_loss(self, batch: Transition, td_advantage: torch.Tensor) -> PPOActorLossInfo:
        """PPO clipped surrogate: -mean(min(ratio * A, clip(ratio) * A))."""
        mu_logits, sigma_logits = self.actor(batch.state)
        distribution = self.actor.distribution(mu_logits, sigma_logits)
        action_log_prob_new = distribution.log_prob(batch.action)
        action_log_prob_old = batch.action_log_prob
        clip_min, clip_max = self.hparams.importance_sampling_clip_range
        importance_ratio = torch.exp(action_log_prob_new - action_log_prob_old)
        importance_ratio_clipped = torch.clip(importance_ratio, clip_min, clip_max)
        weighed_td_advantage = td_advantage * importance_ratio
        weighed_td_advantage_clipped = td_advantage * importance_ratio_clipped
        objective = torch.minimum(weighed_td_advantage, weighed_td_advantage_clipped)
        # True where the clipped term is strictly smaller, i.e. where clipping took effect.
        was_clipped = weighed_td_advantage_clipped < weighed_td_advantage
        ppo_loss = -objective.mean()
        entropy = distribution.entropy().mean()
        return PPOActorLossInfo(ppo_loss, entropy, was_clipped, importance_ratio)

    def training_step(self, batch: Transition) -> torch.Tensor:
        critic_info = self._critic_loss(batch)
        actor_info = self._actor_loss(batch, critic_info.td_advantage)

        self.log_dict({
            "critic_td_loss": critic_info.td_loss,
            "actor_ppo_loss": actor_info.ppo_loss,
            "actor_entropy": actor_info.entropy,
            "td_advantage": critic_info.td_advantage.mean(),
            "was_clipped": actor_info.was_clipped.float().mean(),
            "importance_ratio_min": actor_info.importance_ratio.min(),
            "importance_ratio_max": actor_info.importance_ratio.max(),
        })

        # Subtracting the entropy rewards more stochastic policies (exploration bonus).
        entropy_bonus = self.hparams.entropy_coef * actor_info.entropy
        return critic_info.td_loss + actor_info.ppo_loss - entropy_bonus

    def validation_step(self, make_env_kwargs: list[dict[str, Any]], batch_idx):
        """
        Run a batch of validation episodes and record their scores.

        ``make_env_kwargs`` is a list of keyword-argument dicts, each passed to ``make_env``
        to build one environment. When a video folder is configured, the recorded videos
        are also sent to the first WandbLogger.
        """
        if self.video_root_folder is None:
            video_folder = None
        else:
            video_folder = self.video_root_folder / f"step-{self.trainer.global_step}-batch-{batch_idx}"

        scores = measure_scores_parallel(
            agent=self,
            make_env=self.make_env,
            make_env_kwargs=make_env_kwargs,
            video_folder=video_folder,
        )
        scores = list(scores)
        self.validation_step_scores.extend(scores)

        if video_folder is not None:
            for logger in self.loggers:
                if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                    for worker_dir in video_folder.glob("*"):
                        logger.log_metrics({
                            f"video/batch_{batch_idx}_worker_{worker_dir.name}": wandb.Video(str(video))
                            for video in worker_dir.glob("*0.mp4")
                        }, step=self.trainer.global_step)
                    break

        return scores

    def on_validation_epoch_start(self):
        self.validation_step_scores.clear()
        self.validation_start_time = time.time()

    def on_validation_epoch_end(self):
        scores = torch.tensor(self.validation_step_scores)
        validation_num_seconds = time.time() - self.validation_start_time
        self.log_simulation_metrics(scores, validation_num_seconds)

    def simulation_metrics(self, scores: torch.Tensor | np.ndarray) -> dict[str, float]:
        """Summary statistics (mean/std/min/max, percentiles, CVaR) of episode scores."""
        metrics = {
            "score/avg": float(scores.mean()),
            "score/std": float(scores.std()),
            "score/min": float(scores.min()),
            "score/max": float(scores.max()),
            "score/num_episodes": float(len(scores)),
        }
        for p in [25, 50, 75]:
            metrics[f"score/percentile_{p}"] = float(np.percentile(scores, p))
            metrics[f"score/cvar_{p}"] = float(cvar(scores, percent=p))
        return metrics

    def log_simulation_metrics(self, scores: Iterable[float], validation_num_seconds: float, metrics: dict[str, float] | None = None):
        """Log score metrics (computed from ``scores`` unless given) plus a W&B score histogram."""
        if metrics is None:
            metrics = self.simulation_metrics(scores)
        metrics["score/validation_total_num_seconds"] = validation_num_seconds

        for logger in self.loggers:
            if isinstance(logger, lightning.pytorch.loggers.WandbLogger):
                logger.log_metrics({"score/score": wandb.Histogram(scores)}, step=self.trainer.global_step)

        self.log_dict(metrics)

# Experience collection

In [None]:
# Copied and adjusted from previous assignment
class DataGeneratorParallel:
    """
    Endless source of transitions from several environments stepped in parallel.

    Iterating creates an ``AsyncVectorEnv`` with one sub-environment per entry of
    ``make_env_kwargs``. At every vector step it calls
    ``agent.select_action(states, collecting_data=True)`` for the whole batch and yields
    the resulting per-environment ``Transition``s.

    Args:
        agent: policy used to act. Must return ``(actions, log_probs)`` for a batch.
        make_env: environment factory, called as ``make_env(**kwargs)``.
        make_env_kwargs: kwargs for each sub-environment, or one dict shared by all
            ``num_parallel_envs`` of them. ``None`` means ``{}``.
        num_parallel_envs: number of sub-environments. Inferred from a list
            ``make_env_kwargs`` when omitted.

    Raises:
        ValueError: if the number of environments cannot be determined or does not match.
    """

    def __init__(
        self,
        agent: Agent,
        make_env: Callable[..., gym.Env],
        make_env_kwargs: dict[str, Any] | list[dict[str, Any]] | None = None,
        num_parallel_envs: int | None = None,
    ) -> None:
        if num_parallel_envs is None:
            if make_env_kwargs is None:
                raise ValueError("Either make_env_kwargs or num_parallel_envs must be set")
            if isinstance(make_env_kwargs, dict):
                raise ValueError("num_parallel_envs must be set if make_env_kwargs is a dict")

        if make_env_kwargs is None:
            make_env_kwargs = {}
        if isinstance(make_env_kwargs, dict):
            make_env_kwargs = [make_env_kwargs for _ in range(num_parallel_envs)]
        if num_parallel_envs is None:
            num_parallel_envs = len(make_env_kwargs)
        if len(make_env_kwargs) != num_parallel_envs:
            raise ValueError("make_env_kwargs must have the same length as num_parallel_envs")

        self.agent = agent
        self.make_env = make_env
        self.make_env_kwargs = make_env_kwargs
        self.num_parallel_envs = num_parallel_envs

    def __iter__(self) -> Iterable[Transition]:
        envs = gym.vector.AsyncVectorEnv([
            functools.partial(self.make_env, **kwargs)
            for kwargs in self.make_env_kwargs
        ])
        # try/finally: the worker processes are shut down when the generator is closed
        # or garbage-collected.
        try:
            state, _ = envs.reset()
            # Per env: did the previous vector step end an episode? Vector envs auto-reset
            # such envs. Depending on the gymnasium autoreset mode, the following step can
            # be a pure reset step (action ignored) whose "transition" spans two episodes.
            # Such steps are skipped. With same-step autoreset, this drops one valid
            # transition per episode instead.
            ended_last_step = np.zeros(self.num_parallel_envs, dtype=bool)
            while True:
                action, log_prob = self.agent.select_action(state, collecting_data=True)
                next_state, reward, terminated, truncated, _ = envs.step(action)
                done = terminated | truncated
                for i in range(self.num_parallel_envs):
                    if not ended_last_step[i]:
                        yield Transition(state[i], action[i], reward[i], next_state[i], log_prob[i], done[i])
                ended_last_step = done
                state = next_state
        finally:
            envs.close()


class DataCollectionCallback(lightning.Callback):
    """
    Lightning callback that fills a replay buffer with fresh experience.

    Experience is collected once when training starts, and again after every optimizer
    step whose ``global_step`` is a multiple of ``collect_every_n_updates``. Each
    collection adds ``collect_num_steps_in_total`` transitions to ``buffer``, i.e.
    ``collect_num_steps_in_total / num_parallel_envs`` steps of every environment.

    Raises:
        ValueError: if ``collect_num_steps_in_total`` is not divisible by the number of
            parallel environments.
    """

    def __init__(
        self,
        parallel_experience_generator: DataGeneratorParallel,
        buffer: ReplayBuffer,
        collect_every_n_updates: int,
        collect_num_steps_in_total: int,
    ) -> None:
        super().__init__()
        self.parallel_experience_collector = parallel_experience_generator
        self.buffer = buffer
        self.collect_every_n_updates = collect_every_n_updates
        self.collect_num_steps_total = collect_num_steps_in_total
        if collect_num_steps_in_total % parallel_experience_generator.num_parallel_envs != 0:
            raise ValueError("collect_num_steps_in_total must be divisible by num_parallel_envs")
        self.collect_num_batch_steps = collect_num_steps_in_total // parallel_experience_generator.num_parallel_envs
        self.experience_iterator = iter(self.parallel_experience_collector)

    def on_train_start(self, trainer, pl_module: lightning.LightningModule) -> None:
        self.collect_experience(pl_module)

    def on_train_batch_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule, outputs, batch, batch_idx) -> None:
        if trainer.global_step % self.collect_every_n_updates == 0:
            self.collect_experience(pl_module)

    def collect_experience(self, pl_module: lightning.LightningModule):
        """Add ``collect_num_steps_total`` transitions to the buffer and log throughput metrics."""
        start = time.perf_counter()
        # The generator yields individual transitions, not whole vector steps, so the
        # total number of transitions (not the number of batch steps) is drawn here.
        for _ in range(self.collect_num_steps_total):
            self.buffer.add(next(self.experience_iterator))
        num_seconds = time.perf_counter() - start
        num_parallel_envs = self.parallel_experience_collector.num_parallel_envs
        metrics = {
            "collect_experience/duration_seconds": num_seconds,
            "collect_experience/num_steps_total": float(self.collect_num_steps_total),
            "collect_experience/num_steps_in_each_env": float(self.collect_num_batch_steps),
            "collect_experience/num_batch_steps_per_second": self.collect_num_batch_steps / num_seconds,
            "collect_experience/num_seconds_per_batch_step": num_seconds / self.collect_num_batch_steps,
            "collect_experience/num_env_steps_per_second": self.collect_num_batch_steps * num_parallel_envs / num_seconds,
            "collect_experience/num_seconds_per_env_step": num_seconds / (self.collect_num_batch_steps * num_parallel_envs),
            "collect_experience/num_parallel_envs": float(num_parallel_envs),
        }
        pl_module.log_dict(metrics)

# ONNX export

In [None]:
class ModelCheckpointWithActorOnnxExport(lightning.pytorch.callbacks.ModelCheckpoint):
    """
    ``ModelCheckpoint`` that also exports the PPO actor to ONNX next to each checkpoint.

    For ``<name>.ckpt`` the actor is written to ``<name>.actor.onnx``. The ONNX file is
    rewritten whenever the checkpoint file is (re)saved, so ``last.actor.onnx`` stays in
    sync with ``last.ckpt``. It is deleted when the checkpoint is removed.

    Args:
        example_input: example observation batch used to trace the actor for export.
        export_kwargs: extra keyword arguments for ``torch.onnx.export``.
        checkpointing_kwargs: keyword arguments for ``ModelCheckpoint``.
    """

    def __init__(
        self,
        example_input: torch.Tensor,
        export_kwargs: dict[str, Any] | None = None,
        checkpointing_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(**({} if checkpointing_kwargs is None else checkpointing_kwargs))
        self.example_input = example_input
        self.export_kwargs = {} if export_kwargs is None else export_kwargs

    def _actor_path(self, checkpoint_path: pathlib.Path | str) -> pathlib.Path:
        return pathlib.Path(checkpoint_path).with_suffix(".actor.onnx")

    def _export_actor(self, ppo: PPO, checkpoint_path: pathlib.Path | str) -> None:
        # Always (re)export. The same checkpoint path can be saved repeatedly (e.g. the
        # `last` checkpoint), and an existing ONNX file would then be stale.
        onnx_path = self._actor_path(checkpoint_path)
        param = next(ppo.actor.parameters())
        inputs = self.example_input.to(param)
        torch.onnx.export(ppo.actor, inputs, str(onnx_path), **self.export_kwargs)

    def _save_checkpoint(self, trainer: lightning.Trainer, filepath):
        """
        Save the checkpoint, then export the actor ONNX file next to it (global rank 0 only).
        """
        output = super()._save_checkpoint(trainer, filepath)
        filepath = pathlib.Path(filepath)
        if trainer.is_global_zero and filepath.exists():
            # lightning_module is the unwrapped PPO module, even under strategy wrappers.
            self._export_actor(trainer.lightning_module, filepath)
        return output

    def _remove_checkpoint(self, trainer: lightning.Trainer, filepath):
        """
        Remove the actor ONNX file if the corresponding checkpoint file was removed.
        """
        output = super()._remove_checkpoint(trainer, filepath)
        filepath = pathlib.Path(filepath)
        if not filepath.exists():
            onnx_path = self._actor_path(filepath)
            onnx_path.unlink(missing_ok=True)
        return output

# Hyperparameters

In [None]:
# --- Trainer / hardware ---
max_steps = -1  # -1: no limit on the number of optimizer steps
accelerator = "cpu"
# fp16 AMP ("16-mixed") is not supported on CPU, and reduced precision makes the PPO
# importance ratios (exp of log-prob differences) noisy, so use full precision on CPU.
precision = "32-true" if accelerator == "cpu" else "16-mixed"
devices = []  # only used when accelerator != "cpu"

# --- Optimization ---
optimizer_class = "AdamW"  # name of a class in torch.optim
learning_rate = 1e-4
batch_size = 128

# --- Network (same config for the actor and critic backbones) ---
hidden_dim = 512
num_blocks = 4

# --- PPO ---
discount_factor = 0.99
importance_sampling_clip_range = (0.8, 1.2)

# --- Experience collection ---
buffer_capacity = 100_000
collect_data_num_parallel_envs = 10
collect_data_every_n_updates = 10  # collect after every n-th optimizer step
collect_data_num_steps = 20  # transitions per collection; must be divisible by collect_data_num_parallel_envs

# --- Validation ---
val_check_interval = 100  # run validation every 100 training batches
valid_num_runs = 10
valid_num_parallel_envs = 10

# --- Checkpointing ---
checkpoint_monitor = ("score/avg", "max")  # (metric, mode)

# --- Environment ---
hardcore = False

# --- Early stopping ---
use_early_stopping = False
early_stopping_threshold = None
early_stopping_patience = 5
early_stopping_min_delta = 0.0

In [None]:
# On CPU let Lightning choose the device count; otherwise use the explicit device list.
trainer_devices = devices if accelerator != "cpu" else "auto"

In [None]:
# Environment factory shared by data collection and validation.
make_env = functools.partial(make_bipedal_walker_env, hardcore=bool(hardcore))

In [None]:
# Probe a single environment for the flattened observation/action sizes and an example
# observation.
env: gym.Env = make_env()
state, _ = env.reset()
# Batch of one observation, used as the tracing input for the actor's ONNX export.
example_input_for_onnx = torch.from_numpy(state).float()[None]
input_dim = math.prod(env.observation_space.shape)
output_dim = math.prod(env.action_space.shape)
env.close()

In [None]:
# ResFeedForward keyword arguments, used for both the actor and the critic backbone.
backbone_config = {
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "num_blocks": num_blocks,
    "activation": torch.nn.GELU(),
    "use_skip": True,
}

# Training

In [None]:
pathlib.Path("./wandb").mkdir(parents=True, exist_ok=True)
pathlib.Path("./videos").mkdir(parents=True, exist_ok=True)

hardcore_str = f"hardcore_{'on' if hardcore else 'off'}"

wandb_logger = lightning.pytorch.loggers.WandbLogger(
    project="jku-deep-rl_ppo",
    save_dir="./wandb",
    group=hardcore_str,
    tags=["ppo", hardcore_str],
)
# The run's name is used below for the video and checkpoint folders.
wandb_experiment: wandb.wandb_run.Run = wandb_logger.experiment

In [None]:
# Pass every PPO-related setting from the hyperparameter cell, including the clip range,
# so changing it there actually takes effect.
ppo = PPO(
    backbone_config=backbone_config,
    output_dim=output_dim,
    optimizer_class=optimizer_class,
    optimizer_config=dict(lr=learning_rate),
    discount=discount_factor,
    importance_sampling_clip_range=importance_sampling_clip_range,
    make_env=make_env,
    video_root_folder=f"./videos/{wandb_experiment.name}",
)

In [None]:
buffer = ReplayBuffer(capacity=buffer_capacity, seed=42)

# Gathers experience from parallel environments with the current PPO policy.
experience_generator = DataGeneratorParallel(
    agent=ppo,
    make_env=make_env,
    make_env_kwargs=[dict() for _ in range(collect_data_num_parallel_envs)],
    num_parallel_envs=collect_data_num_parallel_envs,
)

# Pushes that experience into `buffer` at the start of training and periodically afterwards.
experience_collection_callback = DataCollectionCallback(
    parallel_experience_generator=experience_generator,
    buffer=buffer,
    collect_every_n_updates=collect_data_every_n_updates,
    collect_num_steps_in_total=collect_data_num_steps,
)

monitor_metric, metric_mode = checkpoint_monitor
checkpointing = ModelCheckpointWithActorOnnxExport(
    example_input=example_input_for_onnx,
    export_kwargs={
        "export_params": True,
        "opset_version": 10,
        "do_constant_folding": True,
    },
    checkpointing_kwargs={
        "monitor": monitor_metric,
        "mode": metric_mode,
        "save_last": True,
        "save_top_k": 10,
        "dirpath": f"checkpoints/{wandb_logger.experiment.name}/",
        # Template placeholders are formatted by ModelCheckpoint from the logged metrics.
        "filename": "step={step}_score_avg={score/avg:.2f}",
        "auto_insert_metric_name": False,
    },
)

callbacks = [experience_collection_callback, checkpointing]

if use_early_stopping:
    early_stopping = lightning.pytorch.callbacks.EarlyStopping(
        monitor=monitor_metric,
        mode=metric_mode,
        patience=early_stopping_patience,
        min_delta=early_stopping_min_delta,
        stopping_threshold=early_stopping_threshold,
    )
    callbacks.append(early_stopping)

In [None]:
valid_num_batches = math.ceil(valid_num_runs / valid_num_parallel_envs)

# DataCollectionCallback fills the replay buffer in this (main) process. DataLoader
# worker processes would each hold a frozen copy of it, with the same RNG seed: they
# would never see new experience and would emit duplicate samples. Sample in-process.
train_loader = torch.utils.data.DataLoader(
    buffer,
    batch_size=batch_size,
    pin_memory=accelerator != "cpu",
    shuffle=False, # ReplayBuffer already samples randomly
    num_workers=0,
)

# One seeded validation episode per run: seeds 0 .. valid_num_runs - 1.
valid_make_env_kwargs = [{"seed": seed} for seed in range(valid_num_runs)]

# The validation loader only yields lists of make_env kwargs (one list per batch of
# parallel envs); PPO.validation_step builds and plays the environments itself.
valid_loader = torch.utils.data.DataLoader(
    valid_make_env_kwargs,
    batch_size=valid_num_parallel_envs,
    shuffle=False,
    drop_last=False,
    collate_fn=lambda x: x,
)

In [None]:
# Record the notebook-level settings that PPO.save_hyperparameters does not capture.
wandb_logger.experiment.config.update(
    {
        "buffer_capacity": buffer_capacity,
        "collect_data": dict(
            num_parallel_envs=collect_data_num_parallel_envs,
            every_n_updates=collect_data_every_n_updates,
            collect_data_num_steps=collect_data_num_steps,
        ),
        "checkpointing": dict(
            monitor=checkpointing.monitor,
            mode=checkpointing.mode,
            save_last=checkpointing.save_last,
            save_top_k=checkpointing.save_top_k,
        ),
        "train_loader": dict(batch_size=train_loader.batch_size),
        "validation": dict(
            num_parallel_envs=valid_num_parallel_envs,
            make_env_kwargs=valid_make_env_kwargs,
            total_runs=len(valid_make_env_kwargs),
        ),
    }
)

In [None]:
trainer = lightning.Trainer(
    accelerator=accelerator,
    devices=trainer_devices,
    precision=precision,
    # No epoch limit: the replay-buffer loader never ends, so training length is
    # governed by max_steps (and early stopping, if enabled).
    max_epochs=-1,
    max_steps=max_steps,
    val_check_interval=val_check_interval,
    num_sanity_val_steps=0,
    logger=wandb_logger,
    callbacks=callbacks,
)

In [None]:
# Validation batches are lists of make_env kwargs, consumed by PPO.validation_step.
trainer.fit(
    model=ppo,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)