1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -1123,6 +1123,7 @@ to be able to create this other composition:
ExcludeTransform
FiniteTensorDictCheck
FlattenObservation
FlattenTensorDict
FrameSkipTransform
GrayScale
Hash
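FlattenTensorDict, the transform documented above, is wired up later in this PR as a replay-buffer transform (transform3 in config_async.yaml). A minimal sketch of that usage, not part of the PR, assuming a no-argument constructor and the torchrl.envs.transforms import path implied by the docs entry:

from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.transforms import FlattenTensorDict  # import path assumed

# Attach the transform to the buffer rather than the env, mirroring
# `transform: ${transform3}` in config_async.yaml further down.
rb = ReplayBuffer(
    storage=LazyTensorStorage(10_000),
    transform=FlattenTensorDict(),  # constructor arguments assumed optional
)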
1 change: 1 addition & 0 deletions docs/source/reference/trainers.rst
@@ -184,6 +184,7 @@ Trainer and hooks
TrainerHookBase
UpdateWeights
TargetNetUpdaterHook
UTDRHook


Algorithm-specific trainers (Experimental)
8 changes: 8 additions & 0 deletions sota-implementations/ppo_trainer/config/config.yaml
@@ -5,6 +5,7 @@ defaults:

- transform@transform0: noop_reset
- transform@transform1: step_counter
- transform@transform2: reward_sum

- env@training_env: batched_env
- env@training_env.create_env_fn: transformed_env
@@ -64,6 +65,10 @@ transform1:
max_steps: 200
step_count_key: "step_count"

transform2:
in_keys: ["reward"]
out_keys: ["reward_sum"]

training_env:
num_workers: 1
create_env_fn:
@@ -73,6 +78,7 @@ training_env:
transforms:
- ${transform0}
- ${transform1}
- ${transform2}
_partial_: true

# Loss configuration
@@ -92,6 +98,7 @@ collector:
total_frames: 1_000_000
frames_per_batch: 1024
num_workers: 2
_partial_: true

# Replay buffer configuration
replay_buffer:
@@ -129,3 +136,4 @@ trainer:
save_trainer_file: null
optim_steps_per_batch: null
num_epochs: 2
async_collection: false
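The reward_sum entries added above correspond to a RewardSum transform appended to the training environment's Compose stack. A rough Python equivalent of transform0 through transform2, not part of the PR (the base env name is not shown in this hunk, so a placeholder is used):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import Compose, NoopResetEnv, RewardSum, StepCounter

env = TransformedEnv(
    GymEnv("Pendulum-v1"),  # placeholder; the actual env name is configured elsewhere
    Compose(
        NoopResetEnv(),                                            # transform0: noop_reset
        StepCounter(max_steps=200, step_count_key="step_count"),   # transform1: step_counter
        RewardSum(in_keys=["reward"], out_keys=["reward_sum"]),    # transform2: reward_sum
    ),
)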
18 changes: 14 additions & 4 deletions sota-implementations/sac_trainer/config/config.yaml
@@ -5,6 +5,7 @@ defaults:

- transform@transform0: step_counter
- transform@transform1: double_to_float
- transform@transform2: reward_sum

- env@training_env: batched_env
- env@training_env.create_env_fn: transformed_env
@@ -72,6 +73,11 @@ transform1:
in_keys: null
out_keys: null

transform2:
# RewardSumTransform - sums up the rewards
in_keys: ["reward"]
out_keys: ["reward_sum"]

training_env:
num_workers: 4
create_env_fn:
@@ -81,6 +87,7 @@ training_env:
transforms:
- ${transform0}
- ${transform1}
- ${transform2}
_partial_: true

# Loss configuration
@@ -107,19 +114,21 @@ collector:
total_frames: 1_000_000
frames_per_batch: 1000
num_workers: 4
- init_random_frames: 25000
+ init_random_frames: 2500
track_policy_version: true
_partial_: true

# Replay buffer configuration
replay_buffer:
storage:
- max_size: 1_000_000
+ max_size: 100_000
device: cpu
ndim: 1
sampler:
writer:
compilable: false
- batch_size: 256
+ batch_size: 64
shared: true

logger:
exp_name: sac_halfcheetah_v4
@@ -134,7 +143,7 @@ trainer:
target_net_updater: ${target_net_updater}
loss_module: ${loss}
logger: ${logger}
- total_frames: 1_000_000
+ total_frames: ${collector.total_frames}
frame_skip: 1
clip_grad_norm: false # SAC typically doesn't use gradient clipping
clip_norm: null
@@ -144,3 +153,4 @@ trainer:
log_interval: 25000
save_trainer_file: null
optim_steps_per_batch: 64 # Match SOTA utd_ratio
async_collection: false
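The `_partial_: true` flag added to the collector config is standard Hydra behaviour: `hydra.utils.instantiate` returns a `functools.partial` rather than a fully built object, so the trainer can supply the remaining arguments itself. A toy sketch of that mechanism, not tied to TorchRL:

import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"collector": {"_target_": "collections.Counter", "_partial_": True}}
)
collector_factory = hydra.utils.instantiate(cfg.collector)
# `collector_factory` is a functools.partial; the missing arguments
# (for TorchRL, e.g. the policy built elsewhere in the config) are passed later.
counter = collector_factory(["a", "b", "a"])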
164 changes: 164 additions & 0 deletions sota-implementations/sac_trainer/config/config_async.yaml
@@ -0,0 +1,164 @@
# SAC Trainer Configuration for HalfCheetah-v4
# Run with `python sota-implementations/sac_trainer/train.py --config-name=config_async`
# This configuration uses the new configurable trainer system and matches SOTA SAC implementation

defaults:

- transform@transform0: step_counter
- transform@transform1: double_to_float
- transform@transform2: reward_sum
- transform@transform3: flatten_tensordict

- env@training_env: batched_env
- env@training_env.create_env_fn: transformed_env
- env@training_env.create_env_fn.base_env: gym
- transform@training_env.create_env_fn.transform: compose

- model@models.policy_model: tanh_normal
- model@models.value_model: value
- model@models.qvalue_model: value

- network@networks.policy_network: mlp
- network@networks.value_network: mlp
- network@networks.qvalue_network: mlp

- collector@collector: multi_async

- replay_buffer@replay_buffer: base
- storage@replay_buffer.storage: lazy_tensor
- writer@replay_buffer.writer: round_robin
- sampler@replay_buffer.sampler: random
- trainer@trainer: sac
- optimizer@optimizer: adam
- loss@loss: sac
- target_net_updater@target_net_updater: soft
- logger@logger: wandb
- _self_

# Network configurations
networks:
policy_network:
out_features: 12 # HalfCheetah action space is 6-dimensional (loc + scale) = 2 * 6
in_features: 17 # HalfCheetah observation space is 17-dimensional
num_cells: [256, 256]

value_network:
out_features: 1 # Value output
in_features: 17 # HalfCheetah observation space
num_cells: [256, 256]

qvalue_network:
out_features: 1 # Q-value output
in_features: 23 # HalfCheetah observation space (17) + action space (6)
num_cells: [256, 256]

# Model configurations
models:
policy_model:
return_log_prob: true
in_keys: ["observation"]
param_keys: ["loc", "scale"]
out_keys: ["action"]
network: ${networks.policy_network}
# Configure NormalParamExtractor for higher exploration
scale_mapping: "biased_softplus_2.0" # Higher bias for more exploration (default: 1.0)
scale_lb: 1e-2 # Minimum scale value (default: 1e-4)

qvalue_model:
in_keys: ["observation", "action"]
out_keys: ["state_action_value"]
network: ${networks.qvalue_network}

transform0:
max_steps: 1000
step_count_key: "step_count"

transform1:
# DoubleToFloatTransform - converts double precision to float to fix dtype mismatch
in_keys: null
out_keys: null

transform2:
# RewardSumTransform - sums up the rewards
in_keys: ["reward"]
out_keys: ["reward_sum"]

training_env:
num_workers: 4
create_env_fn:
base_env:
env_name: HalfCheetah-v4
transform:
transforms:
- ${transform0}
- ${transform1}
- ${transform2}
_partial_: true

# Loss configuration
loss:
actor_network: ${models.policy_model}
qvalue_network: ${models.qvalue_model}
target_entropy: "auto"
loss_function: l2
alpha_init: 1.0
delay_qvalue: true
num_qvalue_nets: 2

target_net_updater:
tau: 0.001

# Optimizer configuration
optimizer:
lr: 3.0e-4

# Collector configuration
collector:
create_env_fn: ${training_env}
policy: ${models.policy_model}
total_frames: 5_000_000
frames_per_batch: 1000
num_workers: 8
# Incompatible with async collection
init_random_frames: 0
track_policy_version: true
extend_buffer: true
_partial_: true

# Replay buffer configuration
replay_buffer:
storage:
max_size: 10_000
device: cpu
ndim: 1
sampler:
writer:
compilable: false
batch_size: 256
shared: true
transform: ${transform3}

logger:
exp_name: sac_halfcheetah_v4
offline: false
project: torchrl-sota-implementations

# Trainer configuration
trainer:
collector: ${collector}
optimizer: ${optimizer}
replay_buffer: ${replay_buffer}
target_net_updater: ${target_net_updater}
loss_module: ${loss}
logger: ${logger}
total_frames: ${collector.total_frames}
frame_skip: 1
clip_grad_norm: false # SAC typically doesn't use gradient clipping
clip_norm: null
progress_bar: true
seed: 42
save_trainer_interval: 25000 # Match SOTA eval_iter
log_interval: 25000
save_trainer_file: null
optim_steps_per_batch: 16 # Match SOTA utd_ratio
async_collection: true
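In the async config above, the policy network's `out_features: 12` is twice the 6-dimensional HalfCheetah action space because the tanh-normal head splits the network output into `loc` and `scale`, and `scale_mapping`/`scale_lb` parameterize that split. A standalone sketch of the head this describes, not part of the PR:

import torch
from tensordict.nn import NormalParamExtractor
from torch import nn

mlp = nn.Sequential(nn.Linear(17, 256), nn.Tanh(), nn.Linear(256, 12))
extractor = NormalParamExtractor(
    scale_mapping="biased_softplus_2.0",  # larger bias -> larger initial scale, more exploration
    scale_lb=1e-2,                        # floor on the scale, as in the config
)
loc, scale = extractor(mlp(torch.randn(4, 17)))  # loc and scale are each of shape (4, 6)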
5 changes: 0 additions & 5 deletions sota-implementations/sac_trainer/train.py
@@ -3,17 +3,12 @@
# LICENSE file in the root directory of this source tree.

import hydra
- import torchrl
from torchrl.trainers.algorithms.configs import * # noqa: F401, F403


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
- def print_reward(td):
- torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")
-
trainer = hydra.utils.instantiate(cfg.trainer)
- trainer.register_op(dest="batch_process", op=print_reward)
trainer.train()


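The per-batch reward logging hook is removed from the script, but the `register_op` API it relied on is unchanged, so the same behaviour can be restored from user code if desired. A sketch based on the deleted lines, assuming the `torchrl_logger` import used elsewhere in this PR instead of the removed `torchrl.logger` attribute:

import hydra
from torchrl._utils import logger as torchrl_logger  # import path assumed
from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    def print_reward(td):
        torchrl_logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

    trainer = hydra.utils.instantiate(cfg.trainer)
    # "batch_process" ops run on each collected batch before it is consumed.
    trainer.register_op(dest="batch_process", op=print_reward)
    trainer.train()


if __name__ == "__main__":
    main()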
20 changes: 16 additions & 4 deletions torchrl/collectors/collectors.py
@@ -60,6 +60,8 @@
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.env_creator import EnvCreator

+ from torchrl.envs.llm.transforms.policy_version import PolicyVersion
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
@@ -69,8 +71,6 @@
set_exploration_type,
)

- from torchrl.envs.llm.transforms.policy_version import PolicyVersion

try:
from torch.compiler import cudagraph_mark_step_begin
except ImportError:
@@ -1818,13 +1818,20 @@ def get_policy_version(self) -> str | int | None:
return self.policy_version

def getattr_policy(self, attr):
"""Get an attribute from the policy."""
# send command to policy to return the attr
return getattr(self.policy, attr)

def getattr_env(self, attr):
"""Get an attribute from the environment."""
# send command to env to return the attr
return getattr(self.env, attr)

def getattr_rb(self, attr):
"""Get an attribute from the replay buffer."""
# send command to rb to return the attr
return getattr(self.replay_buffer, attr)


class _MultiDataCollector(DataCollectorBase):
"""Runs a given number of DataCollectors on separate processes.
@@ -2153,6 +2160,7 @@ def __init__(
and hasattr(replay_buffer, "shared")
and not replay_buffer.shared
):
torchrl_logger.warning("Replay buffer is not shared. Sharing it.")
replay_buffer.share()

self._policy_weights_dict = {}
@@ -2306,8 +2314,8 @@ def _check_replay_buffer_init(self):
fake_td["collector", "traj_ids"] = torch.zeros(
fake_td.shape, dtype=torch.long
)

- self.replay_buffer.add(fake_td)
+ # Use extend to avoid time-related transforms to fail
+ self.replay_buffer.extend(fake_td.unsqueeze(-1))
self.replay_buffer.empty()

@classmethod
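Switching from `add` to `extend(fake_td.unsqueeze(-1))` gives the dummy data an explicit trailing time dimension, so buffer transforms that act along the time axis see well-formed length-1 trajectories during this initialization probe. A small sketch of the shape difference, not part of the PR:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100, ndim=2))
fake_td = TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4])
# `rb.add(fake_td)` would register a single element without a time axis;
# extending with a trailing singleton dim stores 4 one-step trajectories instead.
rb.extend(fake_td.unsqueeze(-1))  # batch_size becomes [4, 1]
rb.empty()  # keep the now-initialized storage, drop the fake entries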
@@ -2841,6 +2849,10 @@ def getattr_env(self, attr):

return result

def getattr_rb(self, attr):
"""Get an attribute from the replay buffer."""
return getattr(self.replay_buffer, attr)


@accept_remote_rref_udf_invocation
class MultiSyncDataCollector(_MultiDataCollector):
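The new `getattr_rb` mirrors `getattr_policy` and `getattr_env`, forwarding attribute lookups to the collector's replay buffer (including across processes for the multi collectors). A usage sketch, not part of the PR, assuming the single-process collector also accepts a `replay_buffer` argument and that `write_count` is exposed by the buffer:

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

rb = ReplayBuffer(storage=LazyTensorStorage(1_000))
collector = SyncDataCollector(
    GymEnv("Pendulum-v1"),
    policy=None,            # falls back to a random policy
    frames_per_batch=64,
    total_frames=128,
    replay_buffer=rb,       # collector writes straight into the buffer
)
for _ in collector:
    # attribute name is illustrative; any buffer attribute can be queried this way
    print(collector.getattr_rb("write_count"))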
8 changes: 8 additions & 0 deletions torchrl/data/replay_buffers/storages.py
@@ -1172,6 +1172,10 @@ def max_size_along_dim0(data_shape):

self._storage = out
self.initialized = True
if hasattr(self._storage, "shape"):
torchrl_logger.info(
f"Initialized LazyTensorStorage with {self._storage.shape} shape"
)


class LazyMemmapStorage(LazyTensorStorage):
@@ -1391,6 +1395,10 @@ def max_size_along_dim0(data_shape):
else:
out = _init_pytree(self.scratch_dir, max_size_along_dim0, data)
self._storage = out
if hasattr(self._storage, "shape"):
torchrl_logger.info(
f"Initialized LazyMemmapStorage with {self._storage.shape} shape"
)
self.initialized = True

def get(self, index: int | Sequence[int] | slice) -> Any:
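The new log line fires at the moment the lazy storage actually allocates, i.e. on the first write, which is when the element shapes become known. A small sketch of when that happens, not part of the PR:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1_000))
# Nothing is allocated yet; the storage shape is unknown until data arrives.
rb.extend(TensorDict({"observation": torch.randn(8, 17)}, batch_size=[8]))
# On this first extend the storage is initialized and the INFO line reports its
# shape, e.g. torch.Size([1000]) with an "observation" entry of shape [1000, 17].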