44 changes: 44 additions & 0 deletions tensordict/_lazy.py
@@ -1303,6 +1303,49 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
result.lock_()
return result

@cache # noqa: B019
def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim):
if self.hook_out is not None:
# This is the hooked ("hacked") version: we only need to drop hook_out and
# rebuild the stack with a proper batch size.
result = LazyStackedTensorDict(
*self.tensordicts,
stack_dim=out_dim,
)
# return self._cache_remove_batch_dim(vmap_level=vmap_level, batch_size=batch_size, out_dim=out_dim)
else:
# We must call _maybe_remove_batch_dim on every stacked tensordict.
# batch_size: size of the batch dim once it is un-hidden.
# out_dim: dimension where the output batch dim will be found.
new_batch_size = list(self.batch_size)
new_batch_size.insert(out_dim, batch_size)
new_names = list(self.names)
new_names.insert(out_dim, None)
# Rebuild the lazy stack: the stack dim is unchanged if out_dim falls past it
# (in which case out_dim must be decremented by one before recursing into the
# stacked tensordicts); otherwise the stack dim must be incremented by one.
if out_dim > self.stack_dim:
stack_dim = self.stack_dim
out_dim = out_dim - 1
else:
stack_dim = self.stack_dim + 1
result = LazyStackedTensorDict(
*[
td._maybe_remove_batch_dim(
funcname,
vmap_level=vmap_level,
batch_size=batch_size,
out_dim=out_dim,
)
for td in self.tensordicts
],
stack_dim=stack_dim,
)
if self.is_locked:
result.lock_()
return result

def get_nestedtensor(
self,
key: NestedKey,
@@ -3724,6 +3767,7 @@ def _cast_reduction(
_multithread_rebuild = TensorDict._multithread_rebuild

_remove_batch_dim = TensorDict._remove_batch_dim
_maybe_remove_batch_dim = TensorDict._maybe_remove_batch_dim
all = TensorDict.all
any = TensorDict.any
expand = TensorDict.expand
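The comments in _maybe_remove_batch_dim above describe how the requested output dim interacts with the lazy stack dim. Below is a standalone sketch of that index arithmetic; the helper name adjust_dims is illustrative and not part of the tensordict API.

def adjust_dims(stack_dim: int, out_dim: int) -> tuple[int, int]:
    # Mirrors the bookkeeping in LazyStackedTensorDict._maybe_remove_batch_dim:
    # if the new batch dim lands past the stack dim, the stack dim is unchanged
    # and the out_dim forwarded to each stacked tensordict shifts down by one;
    # otherwise the stack dim shifts up by one to make room.
    if out_dim > stack_dim:
        return stack_dim, out_dim - 1
    return stack_dim + 1, out_dim

assert adjust_dims(1, 0) == (2, 0)  # batch dim inserted before the stack dim
assert adjust_dims(0, 2) == (0, 1)  # batch dim inserted after the stack dim
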
68 changes: 55 additions & 13 deletions tensordict/_td.py
@@ -82,10 +82,12 @@
unravel_key,
unravel_key_list,
)
from torch import Tensor
from torch import nn, Tensor
from torch._dynamo import graph_break
from torch._functorch.vmap import _maybe_remove_batch_dim
from torch.jit._shape_functions import infer_size_impl
from torch.nn.parameter import UninitializedTensorMixin
from torch.nn.utils._named_member_accessor import swap_tensor
from torch.utils._pytree import tree_map

try:
@@ -447,16 +449,18 @@ def is_empty(self):

def _to_module(
self,
module,
module: nn.Module,
*,
inplace: bool | None = None,
return_swap: bool = True,
swap_dest=None,
memo=None,
use_state_dict: bool = False,
non_blocking: bool = False,
is_dynamo: bool | None = None,
):
is_dynamo = is_dynamo_compiling()
if is_dynamo is None:
is_dynamo = torch.compiler.is_dynamo_compiling()
if is_dynamo:
_check_inbuild()

@@ -500,8 +504,8 @@ def _to_module(
)

def convert_type(x, y):
if isinstance(y, torch.nn.Parameter):
return torch.nn.Parameter(x)
if isinstance(y, nn.Parameter):
return nn.Parameter(x)
if isinstance(y, Buffer):
return Buffer(x)
return x
@@ -514,7 +518,8 @@ def convert_type(x, y):
inplace = bool(inplace)

# we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
if type(module).__setattr__ is __base__setattr__:
if not is_dynamo and type(module).__setattr__ is __base__setattr__:
# if type(module).__setattr__ is __base__setattr__:
__dict__ = module.__dict__
_parameters = __dict__["_parameters"]
_buffers = __dict__["_buffers"]
@@ -539,12 +544,8 @@ def convert_type(x, y):
inplace,
)
else:
if return_swap:
local_out = getattr(module, key)
if not inplace:
# use specialized __setattr__ if needed
delattr(module, key)
setattr(module, key, value)
local_out = swap_tensor(module, key, value)
else:
new_val = local_out
if return_swap:
@@ -568,6 +569,7 @@ def convert_type(x, y):
memo=memo,
use_state_dict=use_state_dict,
non_blocking=non_blocking,
is_dynamo=is_dynamo,
)

if return_swap:
@@ -1432,8 +1434,12 @@ def _add_batch_dim_wrapper(key, value):
def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
new_batch_size = list(self.batch_size)
new_batch_size.insert(out_dim, batch_size)
new_names = list(self.names)
new_names.insert(out_dim, None)
names = self._maybe_names()
if names:
new_names = list(names)
new_names.insert(out_dim, None)
else:
new_names = None
out = TensorDict(
{
key: (
@@ -1451,6 +1457,38 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
)
return out

@cache # noqa: B019
def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim):
new_batch_size = list(self.batch_size)
new_batch_size.insert(out_dim, batch_size)
names = self._maybe_names()
if names:
new_names = list(names)
new_names.insert(out_dim, None)
else:
new_names = None
out = TensorDict(
{
key: (
value._maybe_remove_batch_dim(
funcname=funcname,
vmap_level=vmap_level,
batch_size=batch_size,
out_dim=out_dim,
)
if is_tensor_collection(value)
else _maybe_remove_batch_dim(
funcname, value, vmap_level, batch_size, out_dim
)
)
for key, value in self.items()
},
batch_size=new_batch_size,
names=new_names,
lock=self.is_locked,
)
return out

def _convert_to_tensordict(
self, dict_value: dict[str, Any], non_blocking: bool | None = None
) -> T:
@@ -4064,6 +4102,9 @@ def _index_tensordict(self, index, new_batch_size=None, names=None):
def _remove_batch_dim(self, *args, **kwargs):
raise NotImplementedError

def _maybe_remove_batch_dim(self, *args, **kwargs):
raise NotImplementedError


###########################
# Keys utils
@@ -4253,6 +4294,7 @@ def _set_tensor_dict( # noqa: F811
out = _buffers.pop(name, None)
was_buffer = out is not None
if out is None:
# dynamo doesn't like pop...
out = __dict__.pop(name)
if inplace:
# swap tensor and out after updating out
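TensorDict._maybe_remove_batch_dim above rebuilds the output shape by inserting the vmapped batch size at out_dim, and only propagates dimension names when the tensordict actually carries them. A minimal sketch of that bookkeeping, with made-up sizes and names:

batch_size = [4, 5]
names = ["time", "agents"]
vmap_batch_size, out_dim = 3, 1

new_batch_size = list(batch_size)
new_batch_size.insert(out_dim, vmap_batch_size)
assert new_batch_size == [4, 3, 5]

# The newly exposed dim is unnamed; if the tensordict had no names at all,
# new_names would simply stay None.
new_names = list(names)
new_names.insert(out_dim, None)
assert new_names == ["time", None, "agents"]
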
4 changes: 4 additions & 0 deletions tensordict/base.py
@@ -8476,6 +8476,10 @@ def _add_batch_dim(self, *, in_dim, vmap_level): ...
@cache # noqa: B019
def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ...

@abc.abstractmethod
@cache # noqa: B019
def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ...

# Validation and checks
def _convert_to_tensor(self, array: np.ndarray) -> Tensor:
if isinstance(array, (float, int, bool)):
69 changes: 25 additions & 44 deletions tensordict/nn/functional_modules.py
@@ -112,37 +112,19 @@ def set_tensor_dict( # noqa: F811


_RESET_OLD_TENSORDICT = True
try:
import torch._functorch.vmap as vmap_src # @manual=fbcode//caffe2:torch
from torch._functorch.vmap import ( # @manual=fbcode//caffe2:torch
_add_batch_dim,
_broadcast_to_and_flatten,
_get_name,
_remove_batch_dim,
_validate_and_get_batch_size,
Tensor,
tree_flatten,
tree_unflatten,
)

_has_functorch = True
except ImportError:
try:
from functorch._src.vmap import ( # @manual=fbcode//caffe2/functorch:functorch_src
_add_batch_dim,
_broadcast_to_and_flatten,
_get_name,
_remove_batch_dim,
_validate_and_get_batch_size,
Tensor,
tree_flatten,
tree_unflatten,
)
import torch._functorch.vmap as vmap_src # @manual=fbcode//caffe2:torch
from torch._functorch.vmap import ( # @manual=fbcode//caffe2:torch
_add_batch_dim,
_broadcast_to_and_flatten,
_get_name,
_maybe_remove_batch_dim,
_validate_and_get_batch_size,
Tensor,
tree_flatten,
tree_unflatten,
)

_has_functorch = True
import functorch._src.vmap as vmap_src # @manual=fbcode//caffe2/functorch:functorch_src
except ImportError:
_has_functorch = False
_has_functorch = True


class _exclude_td_from_pytree:
@@ -210,7 +192,7 @@ def _process_batched_inputs(
)
if (
isinstance(in_dim, int)
and not isinstance(arg, (Tensor,))
and not isinstance(arg, Tensor)
and not is_tensor_collection(arg)
):
raise ValueError(
@@ -262,10 +244,7 @@ def _create_batched_inputs(
else:
batched_input = _add_batch_dim(arg, in_dim, vmap_level)
batched_inputs.append(batched_input)
if PYTREE_HAS_ISLEAF:
return tree_unflatten(batched_inputs, args_spec)
with _exclude_td_from_pytree():
return tree_unflatten(batched_inputs, args_spec)
return tree_unflatten(batched_inputs, args_spec)

vmap_src._create_batched_inputs = _create_batched_inputs

@@ -301,7 +280,6 @@ def incompatible_error():
f"has structure {output_spec}."
)

# Here:
if isinstance(batched_outputs, torch.Tensor) or is_tensor_collection(
batched_outputs
):
@@ -311,7 +289,8 @@ def incompatible_error():
flat_out_dims = [out_dims]
elif isinstance(out_dims, tuple) and len(out_dims) == 1:
flat_out_dims = out_dims
out_dims = out_dims[0]
elif out_dims is None:
flat_out_dims = [out_dims]
else:
incompatible_error()
else:
@@ -321,16 +300,18 @@ def incompatible_error():
flat_outputs = []
for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims):
if not is_tensor_collection(batched_output):
out = _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
out = _maybe_remove_batch_dim(
_get_name(func), batched_output, vmap_level, batch_size, out_dim
)
else:
out = batched_output._remove_batch_dim(
vmap_level=vmap_level, batch_size=batch_size, out_dim=out_dim
out = batched_output._maybe_remove_batch_dim(
_get_name(func),
vmap_level=vmap_level,
batch_size=batch_size,
out_dim=out_dim,
)
flat_outputs.append(out)
if PYTREE_HAS_ISLEAF:
return tree_unflatten(flat_outputs, output_spec)
with _exclude_td_from_pytree():
return tree_unflatten(flat_outputs, output_spec)
return tree_unflatten(flat_outputs, output_spec)

vmap_src._unwrap_batched = _unwrap_batched

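_unwrap_batched is the hook through which vmap hands batched outputs back to tensordict, so the usual way to exercise this path end to end is to vmap a call whose parameters live in a stacked TensorDict and whose output is itself a TensorDict. A minimal sketch, assuming the monkeypatches in this module are installed once tensordict.nn is imported; the ensemble size of 10 and the layer shapes are made up:

import torch
from torch import nn, vmap
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
# Extract the parameters as a TensorDict and stack 10 copies to emulate an ensemble.
params = TensorDict.from_module(module)
stacked_params = params.expand(10).clone()

def call(x, p):
    # Swap the (batched) parameters into the module for the duration of the call.
    with p.to_module(module):
        return module(TensorDict({"x": x}, batch_size=[]))

out = vmap(call, in_dims=(None, 0))(torch.randn(3), stacked_params)
print(out["y"].shape)  # expected: torch.Size([10, 4])

The TensorDict returned by call is what reaches the patched _unwrap_batched, which rebuilds a regular TensorDict with the ensemble dim at position 0 via _maybe_remove_batch_dim.
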
3 changes: 3 additions & 0 deletions tensordict/nn/params.py
@@ -709,6 +709,9 @@ def _index_tensordict(self, *args, **kwargs): ...
@_fallback
def _remove_batch_dim(self, *args, **kwargs): ...

@_fallback
def _maybe_remove_batch_dim(self, *args, **kwargs): ...

@_fallback
def _has_names(self, *args, **kwargs): ...

6 changes: 3 additions & 3 deletions tensordict/persistent.py
@@ -1335,9 +1335,9 @@ def __setstate__(self, state):
def _add_batch_dim(self, *, in_dim, vmap_level):
raise RuntimeError("Persistent tensordicts cannot be used with vmap.")

def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
# not accessible
...
def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ...

def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ...

def _view(self, *args, **kwargs):
raise RuntimeError(
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
@@ -154,6 +154,7 @@ def __subclasscheck__(self, subclass):
"_exclude", # TODO: must be specialized
"_fast_apply",
"_get_sub_tensordict",
"_maybe_remove_batch_dim",
"_multithread_apply_flat",
"_remove_batch_dim",
"_select", # TODO: must be specialized