diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 93724c99c..caa166b24 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -1303,6 +1303,49 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
             result.lock_()
         return result
 
+    @cache  # noqa: B019
+    def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim):
+        if self.hook_out is not None:
+            # this is the hacked version. We just need to remove the hook_out and
+            # reset a proper batch size
+            result = LazyStackedTensorDict(
+                *self.tensordicts,
+                stack_dim=out_dim,
+            )
+            # return self._cache_remove_batch_dim(vmap_level=vmap_level, batch_size=batch_size, out_dim=out_dim)
+        else:
+            # we must call _remove_batch_dim on all tensordicts
+            # batch_size: size of the batch when we unhide it.
+            # out_dim: dimension where the output will be found
+            new_batch_size = list(self.batch_size)
+            new_batch_size.insert(out_dim, batch_size)
+            new_names = list(self.names)
+            new_names.insert(out_dim, None)
+            # rebuild the lazy stack
+            # the stack dim is the same if the out_dim is past it, but it
+            # must be incremented by one otherwise.
+            # In the first case, the out_dim must be decremented by one
+            if out_dim > self.stack_dim:
+                stack_dim = self.stack_dim
+                out_dim = out_dim - 1
+            else:
+                stack_dim = self.stack_dim + 1
+            result = LazyStackedTensorDict(
+                *[
+                    td._maybe_remove_batch_dim(
+                        funcname,
+                        vmap_level=vmap_level,
+                        batch_size=batch_size,
+                        out_dim=out_dim,
+                    )
+                    for td in self.tensordicts
+                ],
+                stack_dim=stack_dim,
+            )
+        if self.is_locked:
+            result.lock_()
+        return result
+
     def get_nestedtensor(
         self,
         key: NestedKey,
@@ -3724,6 +3767,7 @@ def _cast_reduction(
     _multithread_rebuild = TensorDict._multithread_rebuild
 
     _remove_batch_dim = TensorDict._remove_batch_dim
+    _maybe_remove_batch_dim = TensorDict._maybe_remove_batch_dim
     all = TensorDict.all
     any = TensorDict.any
     expand = TensorDict.expand
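
# Illustrative sketch (not part of the patch): how the lazy-stack branch above surfaces
# through the public API. Importing tensordict patches torch.vmap's handling of tensor
# collections, so a LazyStackedTensorDict can be vmapped over its stack dimension and the
# batch dimension is re-inserted at `out_dim` on the way out. The shapes and names below
# are assumptions for illustration only.
import torch

from tensordict import LazyStackedTensorDict, TensorDict

td = LazyStackedTensorDict(
    *[TensorDict({"x": torch.randn(3)}, []) for _ in range(4)], stack_dim=0
)


def double_x(sub_td):
    # inside vmap the stack dim is hidden: each call sees batch_size []
    return sub_td["x"] * 2


out = torch.vmap(double_x, (0,), 0)(td)
assert out.shape == (4, 3)  # the vmapped dim is re-inserted at out_dim=0
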
__dict__["_buffers"] @@ -539,12 +544,8 @@ def convert_type(x, y): inplace, ) else: - if return_swap: - local_out = getattr(module, key) if not inplace: - # use specialized __setattr__ if needed - delattr(module, key) - setattr(module, key, value) + local_out = swap_tensor(module, key, value) else: new_val = local_out if return_swap: @@ -568,6 +569,7 @@ def convert_type(x, y): memo=memo, use_state_dict=use_state_dict, non_blocking=non_blocking, + is_dynamo=is_dynamo, ) if return_swap: @@ -1432,8 +1434,12 @@ def _add_batch_dim_wrapper(key, value): def _remove_batch_dim(self, vmap_level, batch_size, out_dim): new_batch_size = list(self.batch_size) new_batch_size.insert(out_dim, batch_size) - new_names = list(self.names) - new_names.insert(out_dim, None) + names = self._maybe_names() + if names: + new_names = list(names) + new_names.insert(out_dim, None) + else: + new_names = None out = TensorDict( { key: ( @@ -1451,6 +1457,38 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ) return out + @cache # noqa: B019 + def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): + new_batch_size = list(self.batch_size) + new_batch_size.insert(out_dim, batch_size) + names = self._maybe_names() + if names: + new_names = list(names) + new_names.insert(out_dim, None) + else: + new_names = None + out = TensorDict( + { + key: ( + value._maybe_remove_batch_dim( + funcname=funcname, + vmap_level=vmap_level, + batch_size=batch_size, + out_dim=out_dim, + ) + if is_tensor_collection(value) + else _maybe_remove_batch_dim( + funcname, value, vmap_level, batch_size, out_dim + ) + ) + for key, value in self.items() + }, + batch_size=new_batch_size, + names=new_names, + lock=self.is_locked, + ) + return out + def _convert_to_tensordict( self, dict_value: dict[str, Any], non_blocking: bool | None = None ) -> T: @@ -4064,6 +4102,9 @@ def _index_tensordict(self, index, new_batch_size=None, names=None): def _remove_batch_dim(self, *args, **kwargs): raise NotImplementedError + def _maybe_remove_batch_dim(self, *args, **kwargs): + raise NotImplementedError + ########################### # Keys utils @@ -4253,6 +4294,7 @@ def _set_tensor_dict( # noqa: F811 out = _buffers.pop(name, None) was_buffer = out is not None if out is None: + # dynamo doesn't like pop... out = __dict__.pop(name) if inplace: # swap tensor and out after updating out diff --git a/tensordict/base.py b/tensordict/base.py index 3710d3bc7..d6e678775 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -8476,6 +8476,10 @@ def _add_batch_dim(self, *, in_dim, vmap_level): ... @cache # noqa: B019 def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ... + @abc.abstractmethod + @cache # noqa: B019 + def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ... 
diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py
index 14e9a95ec..5d84e7914 100644
--- a/tensordict/nn/functional_modules.py
+++ b/tensordict/nn/functional_modules.py
@@ -112,37 +112,19 @@ def set_tensor_dict(  # noqa: F811
 
 _RESET_OLD_TENSORDICT = True
 
-try:
-    import torch._functorch.vmap as vmap_src  # @manual=fbcode//caffe2:torch
-    from torch._functorch.vmap import (  # @manual=fbcode//caffe2:torch
-        _add_batch_dim,
-        _broadcast_to_and_flatten,
-        _get_name,
-        _remove_batch_dim,
-        _validate_and_get_batch_size,
-        Tensor,
-        tree_flatten,
-        tree_unflatten,
-    )
-
-    _has_functorch = True
-except ImportError:
-    try:
-        from functorch._src.vmap import (  # @manual=fbcode//caffe2/functorch:functorch_src
-            _add_batch_dim,
-            _broadcast_to_and_flatten,
-            _get_name,
-            _remove_batch_dim,
-            _validate_and_get_batch_size,
-            Tensor,
-            tree_flatten,
-            tree_unflatten,
-        )
+import torch._functorch.vmap as vmap_src  # @manual=fbcode//caffe2:torch
+from torch._functorch.vmap import (  # @manual=fbcode//caffe2:torch
+    _add_batch_dim,
+    _broadcast_to_and_flatten,
+    _get_name,
+    _maybe_remove_batch_dim,
+    _validate_and_get_batch_size,
+    Tensor,
+    tree_flatten,
+    tree_unflatten,
+)
 
-        _has_functorch = True
-        import functorch._src.vmap as vmap_src  # @manual=fbcode//caffe2/functorch:functorch_src
-    except ImportError:
-        _has_functorch = False
+_has_functorch = True
 
 
 class _exclude_td_from_pytree:
@@ -210,7 +192,7 @@ def _process_batched_inputs(
             )
         if (
             isinstance(in_dim, int)
-            and not isinstance(arg, (Tensor,))
+            and not isinstance(arg, Tensor)
             and not is_tensor_collection(arg)
         ):
             raise ValueError(
@@ -262,10 +244,7 @@ def _create_batched_inputs(
         else:
             batched_input = _add_batch_dim(arg, in_dim, vmap_level)
         batched_inputs.append(batched_input)
-    if PYTREE_HAS_ISLEAF:
-        return tree_unflatten(batched_inputs, args_spec)
-    with _exclude_td_from_pytree():
-        return tree_unflatten(batched_inputs, args_spec)
+    return tree_unflatten(batched_inputs, args_spec)
 
 
 vmap_src._create_batched_inputs = _create_batched_inputs
@@ -301,7 +280,6 @@ def incompatible_error():
                 f"has structure {output_spec}."
             )
 
-    # Here:
     if isinstance(batched_outputs, torch.Tensor) or is_tensor_collection(
         batched_outputs
     ):
@@ -311,7 +289,8 @@ def incompatible_error():
            flat_out_dims = [out_dims]
         elif isinstance(out_dims, tuple) and len(out_dims) == 1:
             flat_out_dims = out_dims
-            out_dims = out_dims[0]
+        elif out_dims is None:
+            flat_out_dims = [out_dims]
         else:
             incompatible_error()
     else:
@@ -321,16 +300,18 @@ def incompatible_error():
     flat_outputs = []
     for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims):
         if not is_tensor_collection(batched_output):
-            out = _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
+            out = _maybe_remove_batch_dim(
+                _get_name(func), batched_output, vmap_level, batch_size, out_dim
+            )
         else:
-            out = batched_output._remove_batch_dim(
-                vmap_level=vmap_level, batch_size=batch_size, out_dim=out_dim
+            out = batched_output._maybe_remove_batch_dim(
+                _get_name(func),
+                vmap_level=vmap_level,
+                batch_size=batch_size,
+                out_dim=out_dim,
             )
         flat_outputs.append(out)
-    if PYTREE_HAS_ISLEAF:
-        return tree_unflatten(flat_outputs, output_spec)
-    with _exclude_td_from_pytree():
-        return tree_unflatten(flat_outputs, output_spec)
+    return tree_unflatten(flat_outputs, output_spec)
 
 
 vmap_src._unwrap_batched = _unwrap_batched
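
# Illustrative sketch (not part of the patch): what switching from _remove_batch_dim to
# _maybe_remove_batch_dim, together with the new `out_dims is None` branch, permits. An
# output that never touched the vmapped input is not a batched tensor, and with
# out_dims=None it is passed through unchanged instead of failing in the unbatching step.
# This assumes the installed PyTorch accepts out_dims=None, which the branch above mirrors.
import torch

from tensordict import TensorDict

td = TensorDict({"x": torch.randn(4, 3)}, batch_size=[4])
w = torch.randn(3)


def closure_only(sub_td):
    # the return value is the unbatched closure tensor, not a function of sub_td
    return w


out = torch.vmap(closure_only, (0,), out_dims=None)(td)
assert out.shape == (3,)
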
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 844a74715..2da21f53f 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -709,6 +709,9 @@ def _index_tensordict(self, *args, **kwargs): ...
     @_fallback
     def _remove_batch_dim(self, *args, **kwargs): ...
 
+    @_fallback
+    def _maybe_remove_batch_dim(self, *args, **kwargs): ...
+
     @_fallback
     def _has_names(self, *args, **kwargs): ...
 
diff --git a/tensordict/persistent.py b/tensordict/persistent.py
index 6e4a27c98..7b0b8473b 100644
--- a/tensordict/persistent.py
+++ b/tensordict/persistent.py
@@ -1335,9 +1335,9 @@ def __setstate__(self, state):
     def _add_batch_dim(self, *, in_dim, vmap_level):
         raise RuntimeError("Persistent tensordicts cannot be used with vmap.")
 
-    def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
-        # not accessible
-        ...
+    def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ...
+
+    def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ...
 
     def _view(self, *args, **kwargs):
         raise RuntimeError(
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index c043e5ad2..84a21f7f1 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -154,6 +154,7 @@ def __subclasscheck__(self, subclass):
     "_exclude",  # TODO: must be specialized
     "_fast_apply",
     "_get_sub_tensordict",
+    "_maybe_remove_batch_dim",
     "_multithread_apply_flat",
     "_remove_batch_dim",
     "_select",  # TODO: must be specialized
diff --git a/test/test_compile.py b/test/test_compile.py
index 7d0ccf90c..0bf8dabf3 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -723,6 +723,37 @@ def call(x, td):
         else:
             assert (td_zero == 0).all()
 
+    # in-place modif raises an error even if fullgraph=False
+    def test_vmap_functional(self, mode):
+        module = torch.nn.Sequential(
+            torch.nn.Linear(3, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 5),
+        )
+
+        td = TensorDict.from_module(module)
+        td_zero = TensorDictParams(td.data.expand(10).clone().zero_())
+
+        def call(x, td):
+            params = td.to_module(module, return_swap=True)
+            result = module(x)
+            params.to_module(module, return_swap=True, swap_dest=td)
+            return result
+
+        vmap_call = torch.vmap(call, (None, 0))
+        call_compile = torch.compile(vmap_call, fullgraph=True, mode=mode)
+        x = torch.randn(2, 3)
+
+        assert (vmap_call(x, td_zero) == 0).all()
+        assert (TensorDict.from_module(module) == td).all()
+        assert (td_zero == 0).all()
+
+        call_compile(x, td_zero)
+        assert (TensorDict.from_module(module) == td).all()
+        assert (call_compile(x, td_zero) == 0).all()
+        assert (TensorDict.from_module(module) == td).all()
+        assert (td_zero == 0).all()
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()