# Environment Solver

## Setup
We will use gymnasium and torch for this implementation.

In [None]:
# Environment id to solve; swap the commented line to switch environments
ENV_ID = "CartPole-v1"
#ENV_ID = "LunarLander-v3"
ALGO_ID = "REINFORCE"  # Supported values: "PPO" or "REINFORCE"
#ALGO_ID = "PPO"

: 

Install packages:

In [None]:
#%CXX=clang++ pip install "gymnasium[box2d]" --quiet
%pip install stable-baselines3 wandb swig tsilva-notebook-utils==0.0.121 --quiet

In [None]:
import warnings
# Suppress warnings whose message starts with "pkg_resources is deprecated as an API"
warnings.filterwarnings("ignore", message="pkg_resources is deprecated as an API.*")

Load secrets:

In [None]:
from tsilva_notebook_utils.colab import load_secrets_into_env

# Load the Weights & Biases API key via the notebook helper (presumably exported
# to the process environment, per the helper's name — confirm). The return value
# is bound to `_` because it is not used.
_ = load_secrets_into_env([
    'WANDB_API_KEY'
])

Retrieve environment variables:

In [None]:
import torch
import numpy as np
from tsilva_notebook_utils.torch import get_default_device

# Resolve the device once for the whole notebook; the selection logic lives in
# get_default_device
DEVICE = get_default_device()
# Bare expression: shows the chosen device as the cell output
DEVICE

In [None]:
import torch.nn as nn
from tsilva_notebook_utils.gymnasium import build_env as _build_env, set_random_seed
from dataclasses import dataclass
from typing import Union, Tuple

@dataclass
class RLConfig:
    """Hyperparameter bundle for one (environment, algorithm) training run.

    Instances are normally built with :meth:`create`, which layers the
    per-environment presets in ``ENV_CONFIGS`` on top of the field defaults
    declared below.
    """

    # Environment
    env_id: str
    seed: int = 42

    # Training
    max_epochs: int = -1
    gamma: float = 0.99
    lam: float = 0.95
    clip_epsilon: float = 0.2
    batch_size: int = 64
    train_rollout_steps: int = 2048

    # Evaluation
    eval_interval: int = 10
    eval_episodes: int = 32
    reward_threshold: float = 200

    # Networks
    policy_lr: float = 3e-4
    value_lr: float = 1e-3
    hidden_dim: Union[int, Tuple[int, ...]] = 64
    entropy_coef: float = 0.01

    # Other
    normalize: bool = False
    mean_reward_window: int = 100
    rollout_interval: int = 10
    n_envs: Union[str, int] = "auto"
    async_rollouts: bool = True

    # Environment-specific configurations with flat structure.
    # Deliberately left un-annotated so the dataclass machinery treats it as a
    # plain class attribute shared by all instances, not as an instance field.
    # Algorithm keys are lowercase; create() lowercases its argument to match.
    # Use these as reference: https://github.com/DLR-RM/rl-baselines3-zoo/tree/master/hyperparams
    ENV_CONFIGS = {
        "CartPole-v1": {
            # Default config for this environment (applies to all algorithms unless overridden)
            "default": dict(
                train_rollout_steps=512,
                batch_size=256,
                rollout_interval=1,
                eval_interval=20,
                eval_episodes=5,
                reward_threshold=475,
                policy_lr=1e-3,
                value_lr=1e-3,
                hidden_dim=32,
            ),
            # Algorithm-specific overrides
            "reinforce": dict(
                train_rollout_steps=2048,
                batch_size=512,
                policy_lr=1e-3,
                entropy_coef=0.02,
            ),
        },
        "Acrobot-v1": {
            "default": dict(
                gamma=0.99,
                lam=0.98,
                clip_epsilon=0.2,
                batch_size=128,
                train_rollout_steps=2048,
                eval_interval=5,
                reward_threshold=-100,
                policy_lr=3e-4,
                value_lr=3e-4,
                hidden_dim=(128, 64),
                entropy_coef=0.01,
                rollout_interval=1
            ),
            "reinforce": dict(
                train_rollout_steps=4096,
                batch_size=256,
                policy_lr=5e-4,
                entropy_coef=0.05,
            ),
        },
        # TODO:
        #n_envs: 16
        #n_epochs: 4
        #n_steps: 1024
        "LunarLander-v3": {
            "default": dict(
                reward_threshold=200,
                # NOTE: total_timesteps is not a declared dataclass field;
                # create() attaches it with setattr as a plain instance
                # attribute, so it is absent from the generated repr.
                total_timesteps=1e6, # TODO: call this n_timesteps
                gamma=0.999,
                # TODO: this is not being propagated to collect_rollouts
                lam=0.98, # gae_lambda: 0.98
                clip_epsilon=0.2,
                batch_size=64,
                eval_interval=2,
                policy_lr=1e-4,
                value_lr=5e-4,
                hidden_dim=32,
                entropy_coef=0.01
            ),
            "reinforce": dict(
                entropy_coef=0.03,
                batch_size=128,
            ),
        },
        "Pendulum-v1": {
            "default": dict(
                gamma=0.99,
                lam=0.95,
                clip_epsilon=0.2,
                batch_size=64,
                eval_interval=2,
                eval_episodes=5,
                reward_threshold=-200,
                policy_lr=3e-4,
                value_lr=1e-3,
                hidden_dim=(128, 64),
                entropy_coef=0.0
            ),
            "reinforce": dict(
                entropy_coef=0.02,
                batch_size=128,
            ),
        },
        "MountainCar-v0": {
            "default": dict(
                gamma=0.99,
                lam=0.97,
                clip_epsilon=0.15,
                batch_size=16,
                eval_interval=2,
                eval_episodes=10,
                reward_threshold=-110,
                policy_lr=1e-4,
                value_lr=5e-4,
                hidden_dim=(128, 64),
                entropy_coef=0.05
            ),
            "reinforce": dict(
                entropy_coef=0.08,
                batch_size=32,
            ),
        },
    }

    @classmethod
    def create(cls, env_id: str, algorithm: str = "ppo") -> 'RLConfig':
        """
        Create config with hierarchical overrides:
        1. Start with dataclass defaults
        2. Apply environment default config
        3. Apply algorithm-specific config for that environment

        Args:
            env_id: Key into ``ENV_CONFIGS``. An unknown id yields a config
                holding only the dataclass defaults.
            algorithm: Algorithm name, matched case-insensitively (so "PPO",
                "ppo", "REINFORCE" and "reinforce" are all accepted). An
                algorithm without an entry for the environment simply adds no
                overrides.

        Returns:
            A new ``RLConfig`` with the overrides applied in the order above.
        """
        config = cls(env_id=env_id)

        # Preset keys are lowercase, callers commonly pass "PPO"/"REINFORCE";
        # without normalising, the algorithm overrides would be skipped silently.
        algo_key = algorithm.strip().lower()

        env_config = cls.ENV_CONFIGS.get(env_id, {})

        # Later sections win: environment defaults first, then algorithm overrides
        for section in ("default", algo_key):
            for key, value in env_config.get(section, {}).items():
                setattr(config, key, value)

        return config

# Build the config for the selected environment and algorithm
CONFIG = RLConfig.create(ENV_ID, ALGO_ID)
# Bare expression: shows the resolved config as the cell output
CONFIG

Build environment:

In [None]:
from tsilva_notebook_utils.gymnasium import log_env_info

# Set random seed for reproducibility
set_random_seed(CONFIG.seed)


def build_env(seed, n_envs=None):
    """Build the CONFIG.env_id environment using the notebook's settings.

    Thin wrapper around the library ``_build_env`` that fills in the
    environment id and observation-normalisation flag from ``CONFIG``.

    Args:
        seed: Seed forwarded to ``_build_env``.
        n_envs: Number of environments to request; when ``None`` the value
            of ``CONFIG.n_envs`` is used instead.

    Returns:
        Whatever ``_build_env`` returns for these arguments.
    """
    return _build_env(
        CONFIG.env_id,
        norm_obs=CONFIG.normalize,
        n_envs=CONFIG.n_envs if n_envs is None else n_envs,
        seed=seed,
    )


# Test building env
env = build_env(CONFIG.seed)
log_env_info(env)

## Build Agent

Define models:

In [None]:
class MLPNet(nn.Module):
    """Feed-forward network: (Linear -> activation) per hidden layer, then a final Linear.

    Args:
        input_dim: Size of the input feature vector.
        output_dim: Size of the output vector.
        hidden_dim: Either a single width (int or float, truncated with
            ``int``) for one hidden layer, or an iterable of widths for
            several hidden layers.
        activation: Zero-argument callable returning the activation module
            inserted after every hidden Linear layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64, activation=nn.ReLU):
        super().__init__()

        # Normalise hidden_dim into a list of integer layer widths
        is_scalar = isinstance(hidden_dim, (int, float))
        widths = [int(hidden_dim)] if is_scalar else [int(width) for width in hidden_dim]

        modules = []
        in_features = input_dim
        for width in widths:
            modules.append(nn.Linear(in_features, width))
            modules.append(activation())
            in_features = width

        # Output head has no activation
        modules.append(nn.Linear(in_features, output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.net(x)

class PolicyNet(MLPNet):
    """MLP that maps an obs_dim-sized input to act_dim outputs."""

    def __init__(self, obs_dim, act_dim, hidden_dim=64):
        super().__init__(input_dim=obs_dim, output_dim=act_dim, hidden_dim=hidden_dim)

class ValueNet(MLPNet):
    """MLP that maps an obs_dim-sized input to a single scalar output."""

    def __init__(self, obs_dim, hidden_dim=64):
        super().__init__(input_dim=obs_dim, output_dim=1, hidden_dim=hidden_dim)

In [None]:

import wandb
from pytorch_lightning.loggers import WandbLogger
from tsilva_notebook_utils.lightning import WandbCleanup

# `pl.Trainer` is used below; import it here so this cell does not rely on
# `pl` having been bound somewhere else.
import pytorch_lightning as pl

# Network input/output sizes derived from the environment spaces:
# use `.n` when the action space exposes it, otherwise the first shape entry.
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n if hasattr(env.action_space, 'n') else env.action_space.shape[0]

# Pick the agent class matching ALGO_ID (case-insensitive) and build only that agent
algo_key = ALGO_ID.upper()
if algo_key == "PPO":
    agent_cls = PPOAgent
elif algo_key == "REINFORCE":
    agent_cls = REINFORCEAgent
else:
    raise ValueError(f"Unsupported algorithm: {ALGO_ID}. Choose 'PPO' or 'REINFORCE'")
agent = agent_cls(obs_dim, act_dim, CONFIG, build_env)

# Backward-compatible alias: the evaluation cell reads `ppo_agent.policy_model`,
# so keep that name pointing at the agent that is actually trained.
# NOTE(review): assumes every agent class exposes `policy_model` — confirm.
ppo_agent = agent

wandb_logger = WandbLogger(
    project=f"{ENV_ID}",
    name=f"{ALGO_ID}-{wandb.util.generate_id()[:5]}",
    log_model=True
)

# Print W&B run URL explicitly
print(f"🔗 W&B Run: {wandb_logger.experiment.url}")

trainer = pl.Trainer(
    logger=wandb_logger,
    log_every_n_steps=10,
    max_epochs=CONFIG.max_epochs,
    enable_progress_bar=True,
    enable_checkpointing=False,  # Disable checkpointing for speed
    accelerator="auto",
    callbacks=[WandbCleanup()]
)

# Fit the agent selected by ALGO_ID (previously a PPO agent was always trained,
# regardless of the selection)
trainer.fit(agent)

## Evaluate

In [None]:
import random
from tsilva_notebook_utils.gymnasium import render_episode_frames

# Roll out the trained policy deterministically, requesting as many
# environments as episodes, with frame capture enabled.
n_episodes = 8
eval_seed = random.randint(0, 1_000_000)
eval_env = build_env(eval_seed, n_envs=n_episodes)
trajectories, _ = collect_rollouts(
    eval_env,
    ppo_agent.policy_model,
    n_episodes=n_episodes,
    deterministic=True,
    collect_frames=True,
)
episodes = group_trajectories_by_episode(trajectories) # something is wrong in frame collection

# Index 2 of each step is summed as the episode reward; the last element of
# each step is taken as its frame (layout defined by collect_rollouts).
episode_returns = [sum(step[2] for step in episode) for episode in episodes]
mean_reward = np.mean(episode_returns)
episode_frames = [[step[-1] for step in episode] for episode in episodes]

print(f"Mean reward: {mean_reward:.2f}")
render_episode_frames(
    episode_frames,
    out_dir="./tmp",
    grid=(2, 2),
    text_color=(0, 0, 0),
)

In [None]:
# Reference card for reading the W&B dashboard: which metrics to watch and
# what to adjust when one of them drifts out of range.
primary_metrics = [
    # Main success indicator
    "eval/mean_reward",
    # Training progress
    "train/mean_reward",
    # Value function quality
    "epoch/explained_var",
    # Exploration level
    "epoch/entropy",
    # Policy update stability
    "epoch/clip_fraction",
]

# (condition, suggested remedy) pairs, kept in display order
_warning_rules = (
    ("epoch/clip_fraction > 0.5", "Reduce policy_lr or clip_epsilon"),
    ("epoch/approx_kl > 0.1", "Reduce policy_lr"),
    ("epoch/explained_var < 0.3", "Increase value_lr or network size"),
    ("epoch/entropy < 0.01", "Increase entropy_coef"),
    ("rollout/queue_miss > rollout/queue_updated", "Check async collection"),
)
warning_conditions = dict(_warning_rules)