diff --git a/.gitignore b/.gitignore
index 523f619db..a2ea649d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,10 @@ functorch/_C.so
 t.py
 .vscode/
 ccache.sh
+
+# Editor temporaries
+*.swn
+*.swo
+*.swp
+*.swm
+*~
diff --git a/functorch/__init__.py b/functorch/__init__.py
index d3522ed25..ac61ce0f5 100644
--- a/functorch/__init__.py
+++ b/functorch/__init__.py
@@ -7,9 +7,12 @@ import torch
 import functools
 import textwrap
-from . import _C
+from functorch._C import (
+    _func_decrement_nesting,
+    _func_increment_nesting,
+)
-from ._src.vmap import vmap
+from ._src.vmap import vmap, functionalize
 from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev, vjpfull
 from ._src.make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1
 from ._src.make_functional import (
@@ -70,6 +73,11 @@ def _functorch_str(tensor):
     if level == -1:
         return _old_str(tensor)
 
+    if _C.is_functionaltensor(tensor):
+        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
+        # that it's up to date first
+        tensor.sync_()
+
     value = _C.get_unwrapped(tensor)
     value_repr = repr(value)
     value_repr = textwrap.indent(value_repr, ' ')
@@ -79,6 +87,8 @@ def _functorch_str(tensor):
         return f'BatchedTensor(lvl={level}, bdim={bdim}, value=\\\n{value_repr})'
     if _C.is_gradtrackingtensor(tensor):
         return f'GradTrackingTensor(lvl={level}, value=\\\n{value_repr})'
+    if _C.is_functionaltensor(tensor):
+        return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})'
     raise ValueError("We don't know how to print this, please file us an issue")
diff --git a/functorch/_src/vmap.py b/functorch/_src/vmap.py
index b260c0cfa..ed70e6473 100644
--- a/functorch/_src/vmap.py
+++ b/functorch/_src/vmap.py
@@ -16,8 +16,12 @@
 from functorch._C import (
     _add_batch_dim,
     _remove_batch_dim,
+    _wrap_functional_tensor,
+    _unwrap_functional_tensor,
     _vmap_decrement_nesting,
     _vmap_increment_nesting,
+    _func_decrement_nesting,
+    _func_increment_nesting,
 )
 
 in_dims_t = Union[int, Tuple]
@@ -278,3 +282,21 @@ def wrapped(*args, **kwargs):
         finally:
             _vmap_decrement_nesting()
     return wrapped
+
+def functionalize(func: Callable) -> Callable:
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        try:
+            func_level = _func_increment_nesting()
+            func_args = [_wrap_functional_tensor(x, func_level) if isinstance(x, Tensor) else x for x in args]
+            func_outputs = func(*func_args, **kwargs)
+            flattened_outputs, _ = tree_flatten(func_outputs)
+            for a in func_args:
+                if isinstance(a, Tensor):
+                    # Call sync_() on the inputs, to ensure that they still get mutated in place if the original
+                    # program mutated its inputs
+                    a.sync_()
+            return [_unwrap_functional_tensor(x) if isinstance(x, Tensor) else x for x in flattened_outputs]
+        finally:
+            _func_decrement_nesting()
+    return wrapped
diff --git a/functorch/csrc/BatchedTensorImpl.cpp b/functorch/csrc/BatchedTensorImpl.cpp
index 46893bd3b..57b5f6af1 100644
--- a/functorch/csrc/BatchedTensorImpl.cpp
+++ b/functorch/csrc/BatchedTensorImpl.cpp
@@ -142,6 +142,23 @@ bool BatchedTensorImpl::has_storage() const {
 }
 #endif
 
+void BatchedTensorImpl::replace_(const TensorImpl* other_impl) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl->key_set().has(DispatchKey::Batched));
+  auto batched_impl = static_cast<const BatchedTensorImpl*>(other_impl);
+
+  auto unwrapped_impl_self = value().unsafeGetTensorImpl();
+  auto unwrapped_impl_other = batched_impl->value().unsafeGetTensorImpl();
+  if (typeid(*unwrapped_impl_self) == typeid(*unwrapped_impl_other)) {
+    // This allows us to retain the program semantic of mutating inputs
+    unwrapped_impl_self->replace_(unwrapped_impl_other);
+  } else {
+    value_ = batched_impl->value();
+  }
+  bdims_ = batched_impl->bdims();
+  checkInvariants();
+  refreshSizesAndStrides();
+}
+
 const char* BatchedTensorImpl::tensorimpl_type_name() const {
   return "BatchedTensorImpl";
 }
diff --git a/functorch/csrc/BatchedTensorImpl.h b/functorch/csrc/BatchedTensorImpl.h
index 67444876e..cf6b880d3 100644
--- a/functorch/csrc/BatchedTensorImpl.h
+++ b/functorch/csrc/BatchedTensorImpl.h
@@ -92,6 +92,7 @@ struct BatchedTensorImpl : public c10::TensorImpl {
 #ifdef DEBUG
   bool has_storage() const override;
 #endif
+  void replace_(const TensorImpl* other_impl) override;
 
   void refreshSizesAndStrides();
diff --git a/functorch/csrc/BatchingRegistrations.cpp b/functorch/csrc/BatchingRegistrations.cpp
index dac7101e4..2aa6be9c3 100644
--- a/functorch/csrc/BatchingRegistrations.cpp
+++ b/functorch/csrc/BatchingRegistrations.cpp
@@ -1418,6 +1418,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   //
   // m.impl("new_zeros", new_zeros_batching_rule);
   //
   // m.impl("contiguous", contiguous_batching_rule);
+  // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
+  m.impl("replace_", torch::CppFunction::makeFallthrough());
 }
 
 }
diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp
index 0444f65d2..38203cf8c 100644
--- a/functorch/csrc/DynamicLayer.cpp
+++ b/functorch/csrc/DynamicLayer.cpp
@@ -5,6 +5,7 @@
 // LICENSE file in the root directory of this source tree.
 
 #include
+#include
 #include
 #include
@@ -273,6 +274,7 @@ constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
   kDynamicLayerFrontModeKey,
   kDynamicLayerBackModeKey,
   kGradWrapperKey,
+  DispatchKey::Functionalize,
   // DispatchKey::Batched,
   kBatchedKey,
   DispatchKey::ADInplaceOrView
@@ -304,6 +306,19 @@ static bool batchedAtCurrentLevel(const Tensor& tensor) {
   return batched_at_level == level;
 }
 
+static bool functionalAtCurrentLevel(const Tensor& tensor) {
+  auto& dynamicLayerStack = dynamicLayerStackAccessor();
+  auto layer = dynamicLayerStack.back();
+  auto level = layer.layerId();
+
+  auto* functional = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
+  if (!functional) {
+    return false;
+  }
+  auto functional_level = functional->level();
+  return functional_level == level;
+}
+
 void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   auto& dynamicLayerStack = dynamicLayerStackAccessor();
 #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
@@ -342,6 +357,12 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
       exclude = exclude.remove(kBatchedKey);
     }
     include = include.add(kVmapModeKey);
+  } else if (layer.key() == DispatchKey::Functionalize) {
+    const auto args = torch::jit::last(stack, op.schema().arguments().size());
+    if (anyTensors(args, functionalAtCurrentLevel)) {
+      exclude = exclude.remove(DispatchKey::Functionalize);
+    }
+    include = include.add(DispatchKey::Functionalize);
   } else {
     TORCH_INTERNAL_ASSERT(false);
   }
@@ -407,6 +428,30 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
     return makeTensorWrapper(tensor, cur_level);
   };
 
+  auto unwrap_functional = [&](const Tensor& tensor) {
+    if (!tensor.defined()) {
+      return tensor;
+    }
+    auto* maybe_functional_wrapper = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
+    if (!maybe_functional_wrapper) {
+      return tensor;
+    }
+    auto tensor_wrapper_level = maybe_functional_wrapper->level();
+    TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level);
+    if (tensor_wrapper_level == cur_level) {
+      return maybe_functional_wrapper->value();
+    }
+    return tensor;
+  };
+
+  auto wrap_functional = [&](const Tensor& tensor) {
+    if (!tensor.defined()) {
+      return tensor;
+    }
+    return at::functionalization::impl::makeFunctional(tensor, cur_level);
+  };
+
   // TODO: we only need to do the following (marked with !) on in-place functions
   // that modify sizes or strides. There aren't many of them.
   // If autograd dispatch key:
@@ -429,6 +474,31 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
     foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
   }
 
+  bool should_wrap_functional_outputs = false;
+  if (cur_key == DispatchKey::Functionalize) {
+    // Step 1: Detect if we'll need to wrap output tensors
+    // I really don't like this.
+    // The functional pass should wrap all output tensors in a FunctionalTensorWrapper
+    // But it shouldn't perform the wrapping when we print - i.e. if none of the inputs are functional tensors
+    // HOWEVER, we want the wrapping to trigger on factory functions.
+    // So we're out of luck if a factory function is triggered during printing.
+    const auto args = torch::jit::last(stack, op.schema().arguments().size());
+    bool any_tensor_args = anyTensors(args, [&](const Tensor& tensor) { return true; });
+    bool any_tensor_args_are_functional = anyTensors(args, [&](const Tensor& t) { return functionalAtCurrentLevel(t); });
+    if (!any_tensor_args) {
+      // factory op - hope that we're not printing, and wrap the output
+      should_wrap_functional_outputs = true;
+    }
+    if (any_tensor_args_are_functional) {
+      // if at least one tensor input is wrapped, that means we're in the functionalization pass. wrap the outputs.
+      should_wrap_functional_outputs = true;
+    }
+
+    // Step 2: Unwrap any functional tensor wrappers.
+    auto args_size = op.schema().arguments().size();
+    foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap_functional);
+  }
+
   // pop the top layer. Put it back on dtor.
   WithoutTop guard;
@@ -444,6 +514,12 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
   // Re-dispatch
   op.callBoxed(stack);
 
+  if (should_wrap_functional_outputs) {
+    auto ret_size = op.schema().returns().size();
+    foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap_functional);
+  }
+
+  // Step 4, 5, 6
   if (cur_key == DispatchKey::Autograd) {
     // Step 4
@@ -474,10 +550,20 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
 }
 
+TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
+  // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
+  m.impl("replace_", torch::CppFunction::makeFallthrough());
+}
+
 TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
 }
 
+TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
+  // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
+  m.impl("replace_", torch::CppFunction::makeFallthrough());
+}
+
 // TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
 //   m.impl("_unwrap_for_grad", native::_unwrap_for_grad);
 //   m.impl("dump_tensor", native::dump_tensor);
diff --git a/functorch/csrc/FunctionalTensorWrapper.cpp b/functorch/csrc/FunctionalTensorWrapper.cpp
new file mode 100644
index 000000000..faa4be40b
--- /dev/null
+++ b/functorch/csrc/FunctionalTensorWrapper.cpp
@@ -0,0 +1,141 @@
+
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include
+
+#include
+#include
+
+#include
+
+namespace at {
+
+FunctionalTensorWrapper::FunctionalTensorWrapper(Tensor value, int64_t level)
+  : FunctionalTensorImplBase(value.dtype(), value.device()),
+    value_(value),
+    level_(level)
+{
+  TORCH_INTERNAL_ASSERT(value_.defined());
+}
+
+void FunctionalTensorWrapper::replace_(const TensorImpl* other_impl) {
+  auto other_functional = dynamic_cast<const FunctionalTensorWrapper*>(other_impl);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_functional != nullptr);
+
+  auto self_unwrapped_impl = value_.unsafeGetTensorImpl();
+  auto other_unwrapped_impl = other_functional->value().unsafeGetTensorImpl();
+
+  if (typeid(*self_unwrapped_impl) == typeid(*other_unwrapped_impl)) {
+    // It is valid to swap out the metadata on the tensorImpl,
+    // but we can only do that if the two tensors we're swapping have the same type.
+    // This allows us to ensure that programs that mutate their inputs
+    // preserve their semantics under a functionalization pass.
+    self_unwrapped_impl->replace_(other_unwrapped_impl);
+  } else {
+    value_ = other_functional->value();
+  }
+}
+
+void FunctionalTensorWrapper::set_size(int64_t dim, int64_t new_size) {
+  value_.unsafeGetTensorImpl()->set_size(dim, new_size);
+}
+void FunctionalTensorWrapper::set_stride(int64_t dim, int64_t new_stride) {
+  value_.unsafeGetTensorImpl()->set_stride(dim, new_stride);
+}
+void FunctionalTensorWrapper::set_storage_offset(int64_t storage_offset) {
+  value_.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
+}
+bool FunctionalTensorWrapper::has_storage() const {
+  return value_.unsafeGetTensorImpl()->has_storage();
+}
+IntArrayRef FunctionalTensorWrapper::sizes() const {
+  return value_.unsafeGetTensorImpl()->sizes();
+}
+int64_t FunctionalTensorWrapper::dim() const {
+  return value_.unsafeGetTensorImpl()->dim();
+}
+const Storage& FunctionalTensorWrapper::storage() const {
+  return value_.unsafeGetTensorImpl()->storage();
+}
+int64_t FunctionalTensorWrapper::numel() const {
+  return value_.unsafeGetTensorImpl()->numel();
+}
+bool FunctionalTensorWrapper::is_contiguous(at::MemoryFormat memory_format) const {
+  return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
+}
+int64_t FunctionalTensorWrapper::storage_offset() const {
+  return value_.unsafeGetTensorImpl()->storage_offset();
+}
+int64_t FunctionalTensorWrapper::size(int64_t d) const {
+  return value_.unsafeGetTensorImpl()->size(d);
+}
+int64_t FunctionalTensorWrapper::stride(int64_t d) const {
+  return value_.unsafeGetTensorImpl()->stride(d);
+}
+c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
+    const c10::VariableVersion& version_counter,
+    bool allow_tensor_metadata_change) const {
+  // TODO: maybe just don't allow this
+  return value_.unsafeGetTensorImpl()->shallow_copy_and_detach(version_counter, allow_tensor_metadata_change);
+}
+c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
+    c10::VariableVersion&& version_counter,
+    bool allow_tensor_metadata_change) const {
+  // TODO: maybe just don't allow this
+  return value_.unsafeGetTensorImpl()->shallow_copy_and_detach(version_counter, allow_tensor_metadata_change);
+}
+void FunctionalTensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
+  // TODO: maybe just don't allow this
+  value_.unsafeGetTensorImpl()->shallow_copy_from(impl);
+}
+const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
+  return "FunctionalTensorWrapper";
+}
+
+namespace functionalization {
+namespace impl {
+
+Tensor makeFunctional(const Tensor& tensor, int64_t level) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl()));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!tensor.key_set().has(c10::DispatchKey::Functionalize));
+  return at::detail::make_tensor<FunctionalTensorWrapper>(tensor, level);
+}
+
+c10::optional<Tensor> makeFunctional(const c10::optional<Tensor>& tensor, int64_t level) {
+  if (tensor.has_value()) {
+    return makeFunctional(*tensor, level);
+  }
+  return c10::nullopt;
+}
+
+c10::List<Tensor> makeFunctional(const c10::List<Tensor>& t_list, int64_t level) {
+  std::vector<Tensor> functional_tensors;
+  for (auto& t: t_list.vec()) {
+    functional_tensors.push_back(makeFunctional(t, level));
+  }
+  return c10::List<Tensor>(functional_tensors);
+}
+
+std::vector<Tensor> makeFunctional(const at::TensorList t_list, int64_t level) {
+  std::vector<Tensor> functional_tensors;
+  for (auto& t: t_list) {
+    functional_tensors.push_back(makeFunctional(t, level));
+  }
+  return functional_tensors;
+}
+
+c10::List<c10::optional<Tensor>> makeFunctional(const c10::List<c10::optional<Tensor>>& t_list, int64_t level) {
+  std::vector<c10::optional<Tensor>> functional_tensors;
+  for (auto& t: t_list.vec()) {
+    functional_tensors.push_back(makeFunctional(t, level));
+  }
+  return c10::List<c10::optional<Tensor>>(functional_tensors);
+}
+
+} // namespace impl
+} // namespace functionalization
+} // namespace at
diff --git a/functorch/csrc/FunctionalTensorWrapper.h b/functorch/csrc/FunctionalTensorWrapper.h
new file mode 100644
index 000000000..d64df583e
--- /dev/null
+++ b/functorch/csrc/FunctionalTensorWrapper.h
@@ -0,0 +1,74 @@
+
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace at {
+
+struct TORCH_API FunctionalTensorWrapper : public at::FunctionalTensorImplBase {
+  explicit FunctionalTensorWrapper(Tensor value, int64_t level);
+
+  const Tensor& value() const { return value_; };
+  int64_t level() const { return level_; };
+
+  // Override the FunctionalTensorImplBase method describing how to re-use a tensor in the functionalization pass.
+  //void replace_(const Tensor& other) override;
+  void replace_(const TensorImpl* other_impl) override;
+
+  // Override ALL virtual functions on the TensorImpl to call into the wrapped value's implementation
+  IntArrayRef sizes() const override;
+  int64_t dim() const override;
+  bool has_storage() const override;
+  const Storage& storage() const override;
+  int64_t numel() const override;
+  bool is_contiguous(at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const override;
+  //bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
+  int64_t storage_offset() const override;
+  void set_size(int64_t dim, int64_t new_size) override;
+  void set_stride(int64_t dim, int64_t new_stride) override;
+  void set_storage_offset(int64_t storage_offset) override;
+  int64_t size(int64_t d) const override;
+  int64_t stride(int64_t d) const override;
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) const override;
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const override;
+  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
+
+ private:
+  const char* tensorimpl_type_name() const override;
+
+  Tensor value_;
+  int64_t level_;
+};
+
+TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(const Tensor& tensor) {
+  auto functional_impl = static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
+  return functional_impl;
+}
+
+namespace functionalization {
+namespace impl {
+
+// Utility functions for the functionalization pass.
+
+TORCH_API Tensor makeFunctional(const Tensor& tensor, int64_t level);
+TORCH_API c10::optional<Tensor> makeFunctional(const c10::optional<Tensor>& tensor, int64_t level);
+TORCH_API c10::List<Tensor> makeFunctional(const c10::List<Tensor>& t_list, int64_t level);
+TORCH_API std::vector<Tensor> makeFunctional(const TensorList t_list, int64_t level);
+TORCH_API c10::List<c10::optional<Tensor>> makeFunctional(const c10::List<c10::optional<Tensor>>& tensor, int64_t level);
+
+} // namespace impl
+} // namespace functionalization
+} // namespace at
diff --git a/functorch/csrc/TensorWrapper.cpp b/functorch/csrc/TensorWrapper.cpp
index 049607785..6a521f8d5 100644
--- a/functorch/csrc/TensorWrapper.cpp
+++ b/functorch/csrc/TensorWrapper.cpp
@@ -163,6 +163,21 @@ void TensorWrapper::set_storage_offset(int64_t storage_offset) {
   TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper");
 }
 
+void TensorWrapper::replace_(const TensorImpl* other_impl) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl->key_set().has(kGradWrapperKey));
+  auto wrapper_impl = static_cast<const TensorWrapper*>(other_impl);
+
+  auto unwrapped_impl_self = value().unsafeGetTensorImpl();
+  auto unwrapped_impl_other = wrapper_impl->value().unsafeGetTensorImpl();
+  if (typeid(*unwrapped_impl_self) == typeid(*unwrapped_impl_other)) {
+    // This allows us to retain the program semantic of mutating inputs
+    unwrapped_impl_self->replace_(unwrapped_impl_other);
+  } else {
+    value_ = wrapper_impl->value();
+  }
+  refreshMetadata();
+}
+
 const char* TensorWrapper::tensorimpl_type_name() const {
   return "TensorWrapper";
 }
@@ -224,6 +239,10 @@ void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Sta
 TORCH_LIBRARY_IMPL(_, FT_GRAD_WRAPPER_KEY, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>());
 }
+TORCH_LIBRARY_IMPL(aten, FT_GRAD_WRAPPER_KEY, m) {
+  // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
+  m.impl("replace_", torch::CppFunction::makeFallthrough());
 }
+} // namespace functorch
 } // namespace at
diff --git a/functorch/csrc/TensorWrapper.h b/functorch/csrc/TensorWrapper.h
index 767ee0e82..2fde0dd6a 100644
--- a/functorch/csrc/TensorWrapper.h
+++ b/functorch/csrc/TensorWrapper.h
@@ -23,6 +23,8 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
   void set_size(int64_t dim, int64_t new_size) override;
   void set_stride(int64_t dim, int64_t new_stride) override;
   void set_storage_offset(int64_t storage_offset) override;
+  // Need to override for functionalization to work.
+  void replace_(const TensorImpl* other_impl) override;
 
   void refreshMetadata();
diff --git a/functorch/csrc/VmapModeRegistrations.cpp b/functorch/csrc/VmapModeRegistrations.cpp
index bb455cc46..eb5f13db6 100644
--- a/functorch/csrc/VmapModeRegistrations.cpp
+++ b/functorch/csrc/VmapModeRegistrations.cpp
@@ -113,6 +113,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
 
   m.impl("randn", randn_mbatching_rule);
 
+  // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
+  m.impl("replace_", torch::CppFunction::makeFallthrough());
 }
diff --git a/functorch/csrc/init.cpp b/functorch/csrc/init.cpp
index b22d85cd0..0510989c4 100644
--- a/functorch/csrc/init.cpp
+++ b/functorch/csrc/init.cpp
@@ -7,6 +7,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -18,7 +19,7 @@ namespace at {
 namespace functorch {
 
-static bool has_level(const Tensor& self, int64_t level) {
+static bool has_batched_level(const Tensor& self, int64_t level) {
   const auto* batched = maybeGetBatchedImpl(self);
   if (!batched) {
     return false;
   }
@@ -27,10 +28,22 @@
   return bdims.back().level() >= level;
 }
 
+static bool has_functional_level(const Tensor& self, int64_t level) {
+  const auto* functional = dynamic_cast<FunctionalTensorWrapper*>(self.unsafeGetTensorImpl());
+  if (!functional) {
+    return false;
+  }
+  return functional->level() >= level;
+}
+
 Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) {
   return addBatchDim(self, level, batch_dim);
 }
 
+Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) {
+  return at::functionalization::impl::makeFunctional(self, level);
+}
+
 static std::pair<Tensor, int64_t> remove_existing_batch_dim(
     const BatchedTensorImpl* batched, int64_t level) {
   auto bdims = batched->bdims();
@@ -102,7 +115,7 @@ static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) {
 //
 // `out_dim` controls where we should put the batch dimension in the output tensor.
 Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) {
-  if (!has_level(self, level)) {
+  if (!has_batched_level(self, level)) {
     auto self_sizes = self.sizes();
     VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end());
     expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size);
@@ -110,7 +123,7 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size,
     return result;
   }
 
-  // Must be batched if has_level(self, /*any_level*/)
+  // Must be batched if has_batched_level(self, /*any_level*/)
   const auto* batched = maybeGetBatchedImpl(self);
   TORCH_INTERNAL_ASSERT(batched != nullptr);
 
@@ -121,6 +134,15 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size,
   return result;
 }
 
+Tensor _unwrap_functional_tensor(const Tensor& self) {
+  auto* functional = dynamic_cast<FunctionalTensorWrapper*>(self.unsafeGetTensorImpl());
+  // We only ever call this after popping out of a functionalize() call, in which case the current tensors
+  // should always be wrapped in a FunctionalTensorWrapper.
+  TORCH_INTERNAL_ASSERT(functional != nullptr);
+  functional->sync_();
+  return functional->value();
+}
+
 Tensor _wrap_for_grad(const Tensor& self, int64_t level) {
   // NB: different behavior inside??
   // return self;
@@ -177,6 +199,16 @@ int64_t _vmap_decrement_nesting() {
   return layer.layerId();
 }
 
+int64_t _func_increment_nesting() {
+  return initAndPushDynamicLayer(DispatchKey::Functionalize);
+}
+
+int64_t _func_decrement_nesting() {
+  auto layer = popDynamicLayerAndDeleteMetadata();
+  TORCH_INTERNAL_ASSERT(layer.key() == DispatchKey::Functionalize);
+  return layer.layerId();
+}
+
 static bool is_batchedtensor(const Tensor& tensor) {
   auto* batched = maybeGetBatchedImpl(tensor);
   return batched != nullptr;
@@ -187,6 +219,10 @@ static bool is_gradtrackingtensor(const Tensor& tensor) {
   return wrapped != nullptr;
 }
 
+static bool is_functionaltensor(const Tensor& tensor) {
+  return dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl()) != nullptr;
+}
+
 static Tensor get_unwrapped(const Tensor& tensor) {
   auto* batched = maybeGetBatchedImpl(tensor);
   if (batched) {
@@ -196,13 +232,21 @@ static Tensor get_unwrapped(const Tensor& tensor) {
   if (wrapped) {
     return wrapped->value();
   }
+  auto* functional = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
+  if (functional) {
+    return functional->value();
+  }
   TORCH_CHECK(false, "No wrappers present!");
 }
 
 static int64_t maybe_get_level(const Tensor& tensor) {
   auto* batched = maybeGetBatchedImpl(tensor);
   if (batched) {
-    return batched->bdims().back().level();
+    auto tmp = batched->bdims().back().level();
+    if (tmp == -1) {
+      TORCH_INTERNAL_ASSERT(!dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor));
+    }
+    return tmp;
   }
   auto* wrapped = maybeGetTensorWrapper(tensor);
   if (wrapped) {
@@ -212,6 +256,10 @@ static int64_t maybe_get_level(const Tensor& tensor) {
     // TODO: this is a weird special case...
     return -2;
   }
+  auto* functional = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
+  if (functional) {
+    return functional->level();
+  }
   return -1;
 }
 
@@ -229,8 +277,12 @@ static int64_t maybe_get_bdim(const Tensor& tensor) {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim");
   m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim");
+  m.def("_wrap_functional_tensor", &at::functorch::_wrap_functional_tensor, "add functional tensor");
+  m.def("_unwrap_functional_tensor", &at::functorch::_unwrap_functional_tensor, "remove functional tensor");
  m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "remove batch dim");
  m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "remove batch dim");
+  m.def("_func_increment_nesting", &at::functorch::_func_increment_nesting, "functionalization start");
+  m.def("_func_decrement_nesting", &at::functorch::_func_decrement_nesting, "functionalization end");
  m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim");
  m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim");
  m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim");
@@ -245,6 +297,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // on Tensors?
  m.def("is_batchedtensor", &at::functorch::is_batchedtensor);
  m.def("is_gradtrackingtensor", &at::functorch::is_gradtrackingtensor);
+  m.def("is_functionaltensor", &at::functorch::is_functionaltensor);
  m.def("get_unwrapped", &at::functorch::get_unwrapped);
  m.def("maybe_get_level", &at::functorch::maybe_get_level);
  m.def("maybe_get_bdim", &at::functorch::maybe_get_bdim);
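
For reference, a minimal usage sketch of the `functionalize` transform added in `functorch/_src/vmap.py` above; it is not part of the patch. It assumes the companion PyTorch-core functionalization support this patch builds against (`DispatchKey::Functionalize`, `FunctionalTensorImplBase`, `Tensor.sync_()`) is available, and the values shown in comments describe the intended behavior rather than verified output. Note that at this stage the wrapped function returns a flattened list of outputs, and input mutations are replayed onto the caller's tensors via `sync_()` / `replace_()`.

import torch
from functorch import functionalize

def f(x):
    x.add_(1)      # in-place mutation of the input
    return x * 2

x = torch.zeros(3)
outs = functionalize(f)(x)   # run f under the Functionalize dynamic layer
# At this point in the patch, the wrapped function returns a flattened
# list of unwrapped outputs rather than the original output structure.
print(outs[0])   # intended: tensor([2., 2., 2.])
# The input mutation is still observed by the caller, via sync_()/replace_:
print(x)         # intended: tensor([1., 1., 1.])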