15 changes: 2 additions & 13 deletions tensordict/_lazy.py
@@ -27,7 +27,6 @@
Tuple,
Type,
)
from warnings import warn

import numpy as np

@@ -1909,18 +1908,8 @@ def _apply_nest(
)
for i, (td, *oth) in enumerate(_zip_strict(self.tensordicts, *others))
]
if all(r is None for r in results):
if filter_empty is None:
warn(
"Your resulting tensordict has no leaves but you did not specify filter_empty=True. "
"This now returns None (filter_empty=True). "
"To silence this warning, set filter_empty to the desired value in your call to `apply`. "
"This warning will be removed in v0.6.",
category=DeprecationWarning,
)
return
elif filter_empty:
return
if all(r is None for r in results) and filter_empty in (None, True):
return
if not inplace:
out = type(self)(
*results,
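A minimal sketch of the new `apply` semantics for lazy stacks (illustrative only, not part of the diff; it assumes the public `apply`/`filter_empty` API and `LazyStackedTensorDict.lazy_stack`):

import torch
from tensordict import LazyStackedTensorDict, TensorDict

td = LazyStackedTensorDict.lazy_stack(
    [TensorDict({"a": torch.zeros(3)}, [3]) for _ in range(2)]
)
# When every leaf maps to None, apply() now returns None silently for
# filter_empty in (None, True); no DeprecationWarning is emitted anymore.
assert td.apply(lambda x: None) is None
assert td.apply(lambda x: None, filter_empty=True) is None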
7 changes: 0 additions & 7 deletions tensordict/_td.py
@@ -1401,13 +1401,6 @@ def make_result(names=names, batch_size=batch_size):
# we raise the deprecation warning only if the tensordict wasn't already empty.
# After we introduce the new behaviour, we will have to consider what happens
# to empty tensordicts by default: will they disappear or stay?
warn(
"Your resulting tensordict has no leaves but you did not specify filter_empty=True. "
"This now returns None (filter_empty=True). "
"To silence this warning, set filter_empty to the desired value in your call to `apply`. "
"This warning will be removed in v0.6.",
category=DeprecationWarning,
)
return
if result is None:
result = make_result()
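The same convention applies to a plain TensorDict; a hedged sketch (assuming the documented `filter_empty` keyword) showing how to keep the empty container instead:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, [3])
assert td.apply(lambda x: None) is None              # empty result collapses to None
out = td.apply(lambda x: None, filter_empty=False)   # opt out: keep the empty container
assert out is not None and out.is_empty()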
59 changes: 38 additions & 21 deletions tensordict/base.py
@@ -5178,6 +5178,14 @@ def update_(
if input_dict_or_td is self:
# no op
return self

if not _is_tensor_collection(type(input_dict_or_td)):
from tensordict import TensorDict

input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)

if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
@@ -5193,29 +5201,35 @@ def inplace_update(name, dest, source):
if key == name[: len(key)]:
dest.copy_(source, non_blocking=non_blocking)

self._apply_nest(
inplace_update,
input_dict_or_td,
nested_keys=True,
default=None,
filter_empty=True,
named=named,
is_leaf=_is_leaf_nontensor,
)
return self
else:
if not _is_tensor_collection(type(input_dict_or_td)):
from tensordict import TensorDict

input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)

# Fastest route using _foreach_copy_
keys, vals = self._items_list(True, True)
other_val = input_dict_or_td._values_list(True, True, sorting_keys=keys)
torch._foreach_copy_(vals, other_val)
return self
new_keys, other_val = input_dict_or_td._items_list(
True, True, sorting_keys=keys, default="intersection"
)
if len(new_keys):
if len(other_val) != len(vals):
vals = dict(zip(keys, vals))
vals = [vals[k] for k in new_keys]
torch._foreach_copy_(vals, other_val)
return self
named = False

def inplace_update(dest, source):
if source is None:
return None
dest.copy_(source, non_blocking=non_blocking)

self._apply_nest(
inplace_update,
input_dict_or_td,
nested_keys=True,
default=None,
filter_empty=True,
named=named,
is_leaf=_is_leaf_nontensor,
)
return self

def update_at_(
self,
@@ -5638,7 +5652,10 @@ def _items_list(
leaves_only=leaves_only,
is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else None,
)
keys, vals = zip(*items)
keys_vals = tuple(zip(*items))
if not keys_vals:
return (), ()
keys, vals = keys_vals
if sorting_keys is None:
return list(keys), list(vals)
if default is None:
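A usage sketch of the reworked `update_` (illustrative; it assumes the public `update_`/`keys_to_update` signature): plain dicts are converted up front, and the fast `_foreach_copy_` path now accepts inputs covering only a subset of the destination keys:

import torch
from tensordict import TensorDict

dest = TensorDict({"a": torch.zeros(3), "b": torch.zeros(3)}, [3])
dest.update_({"a": torch.ones(3)})   # partial plain dict: only the intersection is copied
dest.update_({"a": torch.ones(3), "b": 2 * torch.ones(3)}, keys_to_update=["b"])
assert (dest["a"] == 1).all() and (dest["b"] == 2).all()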
7 changes: 0 additions & 7 deletions tensordict/functional.py
@@ -99,7 +99,6 @@ def pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0) -> T:

def pad_sequence(
list_of_tensordicts: Sequence[T],
batch_first: bool | None = None,
pad_dim: int = 0,
padding_value: float = 0.0,
out: T | None = None,
@@ -146,12 +145,6 @@ def pad_sequence(
"The device argument is ignored by this function and will be removed in v0.5. To cast your"
" result to a different device, call `tensordict.to(device)` instead."
)
if batch_first is not None:
warnings.warn(
"The batch_first argument is deprecated and will be removed in v0.6. "
"The output will always be batch_first.",
category=DeprecationWarning,
)

if not len(list_of_tensordicts):
raise RuntimeError("list_of_tensordicts cannot be empty")
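A short sketch of the trimmed `pad_sequence` API (illustrative; shapes assume the default `pad_dim=0`): the `batch_first` flag is gone and the output is always batch-first:

import torch
from tensordict import TensorDict, pad_sequence

tds = [TensorDict({"obs": torch.ones(n, 4)}, [n]) for n in (2, 3)]
padded = pad_sequence(tds, padding_value=0.0)
print(padded["obs"].shape)  # batch dimension first, e.g. torch.Size([2, 3, 4])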
1 change: 0 additions & 1 deletion tensordict/nn/__init__.py
@@ -30,7 +30,6 @@
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
set_interaction_mode,
set_interaction_type,
)
from tensordict.nn.sequence import TensorDictSequential
15 changes: 0 additions & 15 deletions tensordict/nn/common.py
@@ -1085,21 +1085,6 @@ def forward(
) -> TensorDictBase:
"""When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict."""
try:
if len(args):
tensordict_out = args[0]
args = args[1:]
# we will get rid of tensordict_out as a regular arg, because it
# blocks us when using vmap
# with stateful but functional modules: the functional module checks if
# it still contains parameters. If so it considers that only a "params" kwarg
# is indicative of what the params are, when we could potentially make a
# special rule for TensorDictModule that states that the second arg is
# likely to be the module params.
warnings.warn(
"tensordict_out will be deprecated in v0.6. "
"Make sure you have removed any such arg by then.",
category=DeprecationWarning,
)
if len(args):
raise ValueError(
"Got a non-empty list of extra agruments, when none was expected."
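A hedged sketch of the call convention after removing the positional `tensordict_out` handling (the commented call is illustrative of the now-rejected pattern):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

module = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
td = TensorDict({"x": torch.randn(2, 3)}, [2])
module(td)               # outputs are written into td, which is returned
# module(td, other_td)   # extra positional args now raise ValueError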
82 changes: 17 additions & 65 deletions tensordict/nn/probabilistic.py
@@ -7,10 +7,13 @@

import re
import warnings
from enum import auto, IntEnum

try:
from enum import StrEnum
except ImportError:
from .utils import StrEnum
from textwrap import indent
from typing import Any, Callable, Dict, List, Optional
from warnings import warn
from typing import Any, Dict, List, Optional

from tensordict._nestedkey import NestedKey

@@ -30,7 +33,7 @@
__all__ = ["ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential"]


class InteractionType(IntEnum):
class InteractionType(StrEnum):
"""A list of possible interaction types with a distribution.

MODE, MEDIAN and MEAN point to the property / attribute with the same name.
@@ -44,11 +47,11 @@ class InteractionType(IntEnum):

"""

MODE = auto()
MEDIAN = auto()
MEAN = auto()
RANDOM = auto()
DETERMINISTIC = auto()
MODE = "mode"
MEDIAN = "median"
MEAN = "mean"
RANDOM = "random"
DETERMINISTIC = "deterministic"

@classmethod
def from_str(cls, type_str: str) -> InteractionType:
@@ -62,57 +65,11 @@ def from_str(cls, type_str: str) -> InteractionType:
_INTERACTION_TYPE: InteractionType | None = None


def _insert_interaction_mode_deprecation_warning(
prefix: str = "",
) -> Callable[[str, Warning, int], None]:
return warn(
(
f"{prefix}interaction_mode is deprecated for naming clarity and will be removed in v0.6. "
f"Please use {prefix}interaction_type with InteractionType enum instead."
),
DeprecationWarning,
stacklevel=2,
)


def interaction_type() -> InteractionType | None:
"""Returns the current sampling type."""
return _INTERACTION_TYPE


def interaction_mode() -> str | None:
"""*Deprecated* Returns the current sampling mode."""
_insert_interaction_mode_deprecation_warning()
type = interaction_type()
return type.name.lower() if type else None


class set_interaction_mode(_DecoratorContextManager):
"""*Deprecated* Sets the sampling mode of all ProbabilisticTDModules to the desired mode.

Args:
mode (str): mode to use when the policy is being called.
"""

def __init__(self, mode: str | None = "mode") -> None:
_insert_interaction_mode_deprecation_warning("set_")
super().__init__()
self.mode = InteractionType.from_str(mode) if mode else None

def clone(self) -> set_interaction_mode:
# override this method if your children class takes __init__ parameters
return type(self)(self.mode)

def __enter__(self) -> None:
global _INTERACTION_TYPE
self.prev = _INTERACTION_TYPE
_INTERACTION_TYPE = self.mode

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _INTERACTION_TYPE
_INTERACTION_TYPE = self.prev


class set_interaction_type(_DecoratorContextManager):
"""Sets all ProbabilisticTDModules sampling to the desired type.

@@ -366,12 +323,10 @@ def __init__(
self.log_prob_key = log_prob_key

if default_interaction_mode is not None:
_insert_interaction_mode_deprecation_warning("default_")
self.default_interaction_type = InteractionType.from_str(
default_interaction_mode
raise ValueError(
"default_interaction_mode is deprecated, use default_interaction_type instead."
)
else:
self.default_interaction_type = default_interaction_type
self.default_interaction_type = default_interaction_type

if isinstance(distribution_class, str):
distribution_class = distributions_maps.get(distribution_class.lower())
@@ -418,12 +373,9 @@ def log_prob(self, tensordict):

@property
def SAMPLE_LOG_PROB_KEY(self):
warnings.warn(
"SAMPLE_LOG_PROB_KEY will be deprecated in v0.6."
"Use 'obj.log_prob_key' instead",
category=DeprecationWarning,
raise RuntimeError(
"SAMPLE_LOG_PROB_KEY is fully deprecated. Use `obj.log_prob_key` instead."
)
return self.log_prob_key

@dispatch(auto_batch_size=False)
@_set_skip_existing_None()
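With `InteractionType` now a `StrEnum`, members compare equal to their lowercase string values; a small sketch (assuming the exported `set_interaction_type` context manager and the module-level `interaction_type()` getter shown above):

from tensordict.nn import InteractionType, set_interaction_type
from tensordict.nn.probabilistic import interaction_type

assert InteractionType.MEAN == "mean"   # StrEnum members behave like their string value
with set_interaction_type(InteractionType.DETERMINISTIC):
    assert interaction_type() == InteractionType.DETERMINISTIC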
29 changes: 29 additions & 0 deletions tensordict/nn/utils.py
@@ -8,6 +8,7 @@
import functools
import inspect
import os
try:
    from enum import ReprEnum
except ImportError:  # Python < 3.11: ReprEnum is unavailable, fall back to Enum
    from enum import Enum as ReprEnum
from typing import Any, Callable

import torch
@@ -444,3 +445,31 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
global DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self._saved_mode


# Reproduce StrEnum for python<3.11


class StrEnum(str, ReprEnum): # noqa
def __new__(cls, *values):
if len(values) > 3:
raise TypeError("too many arguments for str(): %r" % (values,))
if len(values) == 1:
# it must be a string
if not isinstance(values[0], str):
raise TypeError("%r is not a string" % (values[0],))
if len(values) >= 2:
# check that encoding argument is a string
if not isinstance(values[1], str):
raise TypeError("encoding must be a string, not %r" % (values[1],))
if len(values) == 3:
# check that errors argument is a string
if not isinstance(values[2], str):
raise TypeError("errors must be a string, not %r" % (values[2]))
value = str(*values)
member = str.__new__(cls, value)
member._value_ = value
return member

def _generate_next_value_(name, start, count, last_values):
return name.lower()
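A quick, hedged check of the backported `StrEnum` (the import path below assumes the class lands in `tensordict.nn.utils` as added here; behaviour mirrors `enum.StrEnum` on Python 3.11+):

try:
    from enum import StrEnum                  # Python 3.11+
except ImportError:
    from tensordict.nn.utils import StrEnum   # backport added in this file

class Color(StrEnum):
    RED = "red"

assert Color.RED == "red" and isinstance(Color.RED, str)
assert Color("red") is Color.RED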
6 changes: 2 additions & 4 deletions tensordict/utils.py
@@ -705,11 +705,9 @@ def _get_item(tensor: Tensor, index: IndexType) -> Tensor:
if _is_lis_of_list_of_bools(index):
index = torch.tensor(index, device=tensor.device)
if index.dtype is torch.bool:
warnings.warn(
raise RuntimeError(
"Indexing a tensor with a nested list of boolean values is "
"going to be deprecated in v0.6 as this functionality is not supported "
f"by PyTorch. (follows error: {err})",
category=DeprecationWarning,
"not supported by PyTorch.",
)
return tensor[index]
raise err
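A final sketch of the indexing change (illustrative): nested Python lists of booleans now raise instead of warning, while boolean tensor masks keep working:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.arange(6).reshape(2, 3)}, [2, 3])
mask = torch.tensor([[True, False, True], [False, True, False]])
print(td[mask].batch_size)   # boolean tensor masks are fine, e.g. torch.Size([3])
# td[[[True, False, True], [False, True, False]]]   # nested bool lists -> RuntimeError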