# Environment Solver

In [None]:
# Enable autoreloading of modules.
# With mode 2, IPython reloads all imported modules before executing each cell,
# so edits to the local `utils` package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Run selection. ENV_ID and ALGORITHM are passed to load_config() and are also
# used for the trainer's project/run names further down the notebook.
# To switch runs, swap the active line with its commented-out alternative.
ENV_ID = "CartPole-v1"
#ENV_ID = "LunarLander-v3"
#ALGORITHM = "reinforce"
ALGORITHM = "ppo"

Load secrets:

In [None]:
from tsilva_notebook_utils.colab import load_secrets_into_env

# Request only the Weights & Biases API key (presumably exported as an
# environment variable by the helper — confirm against tsilva_notebook_utils).
# The return value is bound to `_` so the cell does not echo it as output.
_ = load_secrets_into_env(['WANDB_API_KEY'])

In [None]:
from utils.config import load_config

# Load the configuration (from YAML files) for the selected environment and
# algorithm, then print a header line followed by the config itself as a
# quick sanity check.
CONFIG = load_config(ENV_ID, ALGORITHM)
print(
    f"Loaded config for {ENV_ID} with {ALGORITHM} algorithm:",
    CONFIG,
    sep="\n",
)

Build environment:

In [None]:
from tsilva_notebook_utils.gymnasium import log_env_info
from utils.environment import setup_environment

# Setup environment with configuration.
# setup_environment(CONFIG) returns a callable; the same callable is later
# handed to create_agent() and evaluate_agent() so they can build environments.
build_env_fn = setup_environment(CONFIG)

# Test building env: create one instance, passing CONFIG.seed, and log its info.
# The next cell also reads the observation/action sizes from this instance.
env = build_env_fn(CONFIG.seed)
log_env_info(env)

Define models and train:

In [None]:
import numpy as np
from utils.training import create_agent, create_trainer
from tsilva_notebook_utils.torch import get_default_device

# Get environment dimensions from the probe environment built in the previous cell.
# Assumes a 1-D (vector) observation space — TODO confirm before using image observations.
obs_dim = env.observation_space.shape[0]
# Use `.n` when the action space exposes it (e.g. a discrete space); otherwise
# fall back to the first dimension of the action space's shape.
act_dim = env.action_space.n if hasattr(env.action_space, 'n') else env.action_space.shape[0]

# The probe environment was only needed for logging and for the sizes above.
# Release it now so it does not stay open for the whole training run; the agent
# and the evaluation step build their own environments through build_env_fn.
env.close()

# Debug device information
print(f"Default device: {get_default_device()}")

# Create agent using utility function
agent = create_agent(CONFIG, build_env_fn, obs_dim, act_dim, algorithm=ALGORITHM)

# Debug model devices
print(f"Policy model device: {next(agent.policy_model.parameters()).device}")
# Not every agent has a value model, so report its device only when present.
if hasattr(agent, 'value_model') and agent.value_model is not None:
    print(f"Value model device: {next(agent.value_model.parameters()).device}")
print(f"Rollout collector type: {type(agent.rollout_collector)}")

# Create trainer with W&B logging; runs are grouped per environment and named
# after the algorithm and seed.
trainer = create_trainer(CONFIG, project_name=ENV_ID, run_name=f"{ALGORITHM}-{CONFIG.seed}")

# Fit the model
trainer.fit(agent)

In [None]:
from utils.evaluation import evaluate_agent

# Evaluate the trained agent over 8 episodes with deterministic actions,
# rendering the episodes and writing output under ./tmp.
# TODO: evaluate agent should receive rollout collector as an argument, not the agent and build_env_fn
# NOTE(review): grid=(2, 2) has four cells while n_episodes=8 — confirm how
# evaluate_agent arranges the rendered episodes.
results = evaluate_agent(
    agent, 
    build_env_fn, 
    n_episodes=8, 
    deterministic=True, 
    render=True,
    grid=(2, 2), 
    text_color=(0, 0, 0), 
    out_dir="./tmp"
)

# `results` is indexed by 'mean_reward'; report it with two decimals.
print(f"Mean reward: {results['mean_reward']:.2f}")