634 changes: 574 additions & 60 deletions test/test_cost.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions test/test_rb.py
@@ -1804,9 +1804,9 @@ def test_batch_errors():

@pytest.mark.skipif(not torchrl._utils.RL_WARNINGS, reason="RL_WARNINGS is not set")
def test_add_warning():
from torchrl._utils import RL_WARNINGS
from torchrl._utils import rl_warnings

if not RL_WARNINGS:
if not rl_warnings():
return
rb = ReplayBuffer(storage=ListStorage(10), batch_size=3)
with pytest.warns(
5 changes: 5 additions & 0 deletions torchrl/_utils.py
@@ -1062,3 +1062,8 @@ def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]:
runtime_env["env_vars"] = dict(runtime_env["env_vars"])

return ray_init_config


def rl_warnings():
"""Checks the status of the RL_WARNINGS env varioble."""
return RL_WARNINGS
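
Not part of the diff: a minimal sketch of how call sites adopt the new helper, following the pattern applied throughout this PR. The function name check_frames_divisibility is hypothetical.

import warnings

from torchrl._utils import rl_warnings


def check_frames_divisibility(frames_per_batch: int, n_env: int) -> None:
    # rl_warnings() resolves torchrl._utils.RL_WARNINGS at call time, so
    # patching that module attribute (e.g. in tests) is honoured, unlike a
    # value imported once at module load.
    if frames_per_batch % n_env != 0 and rl_warnings():
        warnings.warn(
            f"frames_per_batch ({frames_per_batch}) is not exactly divisible "
            f"by the number of batched environments ({n_env})."
        )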
12 changes: 6 additions & 6 deletions torchrl/collectors/collectors.py
@@ -47,7 +47,7 @@
compile_with_warmup,
logger as torchrl_logger,
prod,
RL_WARNINGS,
rl_warnings,
VERBOSE,
)
from torchrl.collectors.utils import split_trajectories
@@ -1218,7 +1218,7 @@ def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None:
total_frames = float("inf")
else:
remainder = total_frames % frames_per_batch
if remainder != 0 and RL_WARNINGS:
if remainder != 0 and rl_warnings():
warnings.warn(
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
f"This means {frames_per_batch - remainder} additional frames will be collected."
@@ -1238,7 +1238,7 @@ def _setup_init_random_frames(
if (
init_random_frames not in (-1, None, 0)
and init_random_frames % frames_per_batch != 0
and RL_WARNINGS
and rl_warnings()
):
warnings.warn(
f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
@@ -1261,7 +1261,7 @@ def _setup_postproc(self, postproc: Callable | None) -> None:

def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
"""Calculate and validate frames per batch."""
if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
if frames_per_batch % self.n_env != 0 and rl_warnings():
warnings.warn(
f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
f" this results in more frames_per_batch per iteration that requested"
@@ -2809,7 +2809,7 @@ def _setup_multi_total_frames(
total_frames = float("inf")
else:
remainder = total_frames % total_frames_per_batch
if remainder != 0 and RL_WARNINGS:
if remainder != 0 and rl_warnings():
warnings.warn(
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
f"This means {total_frames_per_batch - remainder} additional frames will be collected. "
@@ -3741,7 +3741,7 @@ def update_policy_weights_(
def frames_per_batch_worker(self, worker_idx: int | None) -> int:
if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
return self._frames_per_batch[worker_idx]
if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
if self.requested_frames_per_batch % self.num_workers != 0 and rl_warnings():
warnings.warn(
f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
f" this results in more frames_per_batch per iteration that requested."
12 changes: 6 additions & 6 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -41,7 +41,7 @@
from torch import Tensor
from torch.utils._pytree import tree_map

from torchrl._utils import accept_remote_rref_udf_invocation, RL_WARNINGS
from torchrl._utils import accept_remote_rref_udf_invocation, rl_warnings
from torchrl.data.replay_buffers.samplers import (
PrioritizedSampler,
RandomSampler,
@@ -871,7 +871,7 @@ def add(self, data: Any) -> int:
data = None
if data is None:
return torch.zeros((0, self._storage.ndim), dtype=torch.long)
if RL_WARNINGS and is_tensor_collection(data) and data.ndim:
if rl_warnings() and is_tensor_collection(data) and data.ndim:
warnings.warn(
f"Using `add()` with a TensorDict that has batch_size={data.batch_size}. "
f"Use `extend()` to add multiple elements, or `add()` with a single element (batch_size=torch.Size([])). "
@@ -1319,14 +1319,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
>>> # get the info to find what the indices are
>>> sample, info = rb.sample(5, return_info=True)
>>> print(sample, info)
tensor([2, 7, 4, 3, 5]) {'_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
tensor([2, 7, 4, 3, 5]) {'priority_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])}
>>> # update priority
>>> priority = torch.ones(5) * 5
>>> rb.update_priority(info["index"], priority)
>>> # and now a new sample, the weights should be updated
>>> sample, info = rb.sample(5, return_info=True)
>>> print(sample, info)
tensor([2, 5, 2, 2, 5]) {'_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
tensor([2, 5, 2, 2, 5]) {'priority_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465],
dtype=float32), 'index': array([2, 5, 2, 2, 5])}

"""
@@ -1861,7 +1861,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
>>> print(sample)
TensorDict(
fields={
_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
@@ -1884,7 +1884,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
>>> print(sample)
TensorDict(
fields={
_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False),
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
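
Not part of the diff: a short usage sketch of the renamed info key when sampling with return_info=True, mirroring the updated docstring above.

import torch

from torchrl.data import ListStorage, PrioritizedReplayBuffer

rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=ListStorage(10))
rb.extend(torch.arange(10))
sample, info = rb.sample(5, return_info=True)
weights = info["priority_weight"]  # previously exposed under "_weight"
rb.update_priority(info["index"], torch.ones(5) * 5)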
20 changes: 12 additions & 8 deletions torchrl/data/replay_buffers/samplers.py
@@ -21,7 +21,7 @@
from tensordict.utils import NestedKey
from torch.utils._pytree import tree_map
from torchrl._extension import EXTENSION_WARNING
from torchrl._utils import _replace_last, logger, RL_WARNINGS
from torchrl._utils import _replace_last, logger, rl_warnings
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index

@@ -373,7 +373,7 @@ class PrioritizedSampler(Sampler):
device=cpu,
is_shared=False)
>>> print(info)
{'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
{'priority_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11,
1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

.. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the
@@ -423,7 +423,7 @@ def __init__(
self.dtype = dtype
self._max_priority_within_buffer = max_priority_within_buffer
self._init()
if RL_WARNINGS and SumSegmentTreeFp32 is None:
if rl_warnings() and SumSegmentTreeFp32 is None:
logger.warning(EXTENSION_WARNING)

def __repr__(self):
@@ -588,7 +588,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
weight = torch.pow(weight / p_min, -self._beta)
if storage.ndim > 1:
index = unravel_index(index, storage.shape)
return index, {"_weight": weight}
return index, {"priority_weight": weight}

def add(self, index: torch.Tensor | int) -> None:
super().add(index)
@@ -2068,7 +2068,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
episode [2, 2, 2, 2, 1, 1]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 1, 2]
>>> print("weight", info["_weight"].tolist())
>>> print("weight", info["priority_weight"].tolist())
weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
>>> priority = torch.tensor([0,3,3,0,0,0,1,1,1])
>>> rb.update_priority(torch.arange(0,9,1), priority=priority)
@@ -2077,7 +2077,7 @@
episode [2, 2, 2, 2, 2, 2]
>>> print("steps", sample["steps"].tolist())
steps [1, 2, 0, 1, 0, 1]
>>> print("weight", info["_weight"].tolist())
>>> print("weight", info["priority_weight"].tolist())
weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06]
"""

@@ -2294,15 +2294,19 @@ def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]
if isinstance(starts, tuple):
starts = torch.stack(starts, -1)
# starts = torch.as_tensor(starts, device=lengths.device)
info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device)
info["priority_weight"] = torch.as_tensor(
info["priority_weight"], device=lengths.device
)

# extends starting indices of each slice with sequence_length to get indices of all steps
index = self._tensor_slices_from_startend(
seq_length, starts, storage_length=storage.shape[0]
)

# repeat the weight of each slice to match the number of steps
info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length)
info["priority_weight"] = torch.repeat_interleave(
info["priority_weight"], seq_length
)

if self.truncated_key is not None:
# following logics borrowed from SliceSampler
26 changes: 23 additions & 3 deletions torchrl/objectives/common.py
@@ -19,7 +19,7 @@
from torch import nn
from torch.nn import Parameter

from torchrl._utils import RL_WARNINGS
from torchrl._utils import rl_warnings
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules.tensordict_module.rnn import set_recurrent_mode
from torchrl.objectives.utils import ValueEstimators
@@ -34,7 +34,7 @@
def _updater_check_forward_prehook(module, *args, **kwargs):
if (
not all(module._has_update_associated.values())
and RL_WARNINGS
and rl_warnings()
and not is_compiling()
):
warnings.warn(
@@ -128,6 +128,7 @@ class _AcceptedKeys:
tensor_keys: _AcceptedKeys
_vmap_randomness = None
default_value_estimator: ValueEstimators = None
use_prioritized_weights: str | bool = "auto"

deterministic_sampling_mode: ExplorationType = ExplorationType.DETERMINISTIC

@@ -449,7 +450,7 @@ def __getattr__(self, item):
params = params.data
elif (
not self._has_update_associated[item[7:-7]]
and RL_WARNINGS
and rl_warnings()
and not is_compiling()
):
# no updater associated
@@ -491,6 +492,25 @@ def reset(self) -> None:
# mainly used for PPO with KL target
pass

def _maybe_get_priority_weight(
self, tensordict: TensorDictBase
) -> torch.Tensor | None:
"""Extract priority weights from tensordict if prioritized replay is enabled.

Args:
tensordict (TensorDictBase): The input tensordict that may contain priority weights.

Returns:
torch.Tensor | None: The priority weights if available and enabled, None otherwise.
"""
weights = None
if (
self.use_prioritized_weights in (True, "auto")
and self.tensor_keys.priority_weight in tensordict.keys()
):
weights = tensordict.get(self.tensor_keys.priority_weight)
return weights

def _reset_module_parameters(self, module_name, module):
params_name = f"{module_name}_params"
target_name = f"target_{module_name}_params"
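
Not part of the diff: an illustrative sketch of how a LossModule subclass is expected to consume the new helper, mirroring the DDPG changes below. The free-standing function and its name are hypothetical; _reduce is assumed to live in torchrl.objectives.utils and to accept the weights keyword used in this PR, and the loss module is assumed to define a reduction attribute as DDPGLoss does.

import torch
from tensordict import TensorDictBase

from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import _reduce


def reduce_with_priority_weights(
    loss_module: LossModule,
    tensordict: TensorDictBase,
    per_sample_loss: torch.Tensor,
) -> torch.Tensor:
    # _maybe_get_priority_weight returns None when the priority_weight key is
    # absent from the batch or use_prioritized_weights is disabled.
    weights = loss_module._maybe_get_priority_weight(tensordict)
    # The weights, when present, are applied before the "mean"/"sum"/"none"
    # reduction, as in ddpg.py in this PR.
    return _reduce(per_sample_loss, loss_module.reduction, weights=weights)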
11 changes: 9 additions & 2 deletions torchrl/objectives/ddpg.py
@@ -171,6 +171,7 @@ class _AcceptedKeys:
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
priority_weight: NestedKey = "priority_weight"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys
@@ -202,11 +203,13 @@ def __init__(
gamma: float | None = None,
separate_losses: bool = False,
reduction: str | None = None,
use_prioritized_weights: str | bool = "auto",
) -> None:
self._in_keys = None
if reduction is None:
reduction = "mean"
super().__init__()
self.use_prioritized_weights = use_prioritized_weights
self.delay_actor = delay_actor
self.delay_value = delay_value

@@ -268,6 +271,8 @@ def _set_in_keys(self):
*self.value_network.in_keys,
*[unravel_key(("next", key)) for key in self.value_network.in_keys],
}
if self.use_prioritized_weights:
in_keys.add(unravel_key(self.tensor_keys.priority_weight))
self._in_keys = sorted(in_keys, key=str)

@property
@@ -316,6 +321,7 @@ def loss_actor(
self,
tensordict: TensorDictBase,
) -> [torch.Tensor, dict]:
weights = self._maybe_get_priority_weight(tensordict)
td_copy = tensordict.select(
*self.actor_in_keys, *self.value_exclusive_keys, strict=False
).detach()
@@ -325,7 +331,7 @@
td_copy = self.value_network(td_copy)
loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
metadata = {}
loss_actor = _reduce(loss_actor, self.reduction)
loss_actor = _reduce(loss_actor, self.reduction, weights=weights)
self._clear_weakrefs(
tensordict,
loss_actor,
Expand All @@ -340,6 +346,7 @@ def loss_value(
self,
tensordict: TensorDictBase,
) -> tuple[torch.Tensor, dict]:
weights = self._maybe_get_priority_weight(tensordict)
# value loss
td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach()
with self.value_network_params.to_module(self.value_network):
@@ -372,7 +379,7 @@
"target_value_max": target_value.max(),
"pred_value_max": pred_val.max(),
}
loss_value = _reduce(loss_value, self.reduction)
loss_value = _reduce(loss_value, self.reduction, weights=weights)
self._clear_weakrefs(
tensordict,
"value_network_params",
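
Not part of the diff: an end-to-end sketch of how the renamed key and the new use_prioritized_weights flag fit together. The buffer configuration and the fields of the dummy data are illustrative.

import torch
from tensordict import TensorDict

from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.9,
    priority_key="td_error",
    storage=LazyTensorStorage(100),
    batch_size=8,
)
rb.extend(TensorDict({"obs": torch.randn(20, 4)}, batch_size=[20]))
batch = rb.sample()
assert "priority_weight" in batch.keys()  # renamed from "_weight" in this PR
# A loss constructed with use_prioritized_weights="auto" (the default), e.g.
# DDPGLoss(actor, value, use_prioritized_weights="auto"), picks these weights
# up via _maybe_get_priority_weight() and passes them to _reduce().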