From c6c5e498535a6278ec0d42b3dc31715fe8ac7426 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 10 Aug 2021 14:41:50 -0700 Subject: [PATCH 1/5] add a functionalize() transform --- functorch/__init__.py | 15 +++++- functorch/_src/vmap.py | 34 ++++++++++++ functorch/csrc/BatchedTensorImpl.cpp | 17 ++++++ functorch/csrc/BatchedTensorImpl.h | 1 + functorch/csrc/BatchingRegistrations.cpp | 2 + functorch/csrc/DynamicLayer.cpp | 39 +++++++++++++- functorch/csrc/TensorWrapper.cpp | 19 +++++++ functorch/csrc/TensorWrapper.h | 2 + functorch/csrc/VmapModeRegistrations.cpp | 2 + functorch/csrc/init.cpp | 68 ++++++++++++++++++++++-- 10 files changed, 191 insertions(+), 8 deletions(-) diff --git a/functorch/__init__.py b/functorch/__init__.py index d3522ed25..6698fa512 100644 --- a/functorch/__init__.py +++ b/functorch/__init__.py @@ -7,9 +7,13 @@ import torch import functools import textwrap -from . import _C +# from . import _C +from functorch._C import ( + _func_decrement_nesting, + _func_increment_nesting, +) -from ._src.vmap import vmap +from ._src.vmap import vmap, functionalize from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev, vjpfull from ._src.make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1 from ._src.make_functional import ( @@ -68,6 +72,8 @@ def _backward(*args, **kwargs): def _functorch_str(tensor): level = _C.maybe_get_level(tensor) if level == -1: + if _C.is_functionaltensor(tensor): + raise Exception("HELP") return _old_str(tensor) value = _C.get_unwrapped(tensor) @@ -79,6 +85,11 @@ def _functorch_str(tensor): return f'BatchedTensor(lvl={level}, bdim={bdim}, value=\\\n{value_repr})' if _C.is_gradtrackingtensor(tensor): return f'GradTrackingTensor(lvl={level}, value=\\\n{value_repr})' + if _C.is_functionaltensor(tensor): + # NOTE: functional tensor's only have a notion of "level" when + # you use the functionalize() API. Otherwise their level is set to -1. 
+ # That shouldn't matter, since this print function is only used by functorch + return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})' raise ValueError("We don't know how to print this, please file us an issue") diff --git a/functorch/_src/vmap.py b/functorch/_src/vmap.py index b260c0cfa..885fa7a4d 100644 --- a/functorch/_src/vmap.py +++ b/functorch/_src/vmap.py @@ -16,8 +16,12 @@ from functorch._C import ( _add_batch_dim, _remove_batch_dim, + _wrap_functional_tensor, + _unwrap_functional_tensor, _vmap_decrement_nesting, _vmap_increment_nesting, + _func_decrement_nesting, + _func_increment_nesting, ) in_dims_t = Union[int, Tuple] @@ -278,3 +282,33 @@ def wrapped(*args, **kwargs): finally: _vmap_decrement_nesting() return wrapped + +class functionalizer(object): + def __enter__(self): + _func_increment_nesting() + + def __exit__(self, *args): + _func_decrement_nesting() + + def __call__(self, func): + @functools.wraps(func) + def decorate_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + return decorate_func + +def functionalize(func: Callable) -> Callable: + @functools.wraps(func) + def wrapped(*args, **kwargs): + try: + # INVARIANT: when _func_increment_nesting is active, every tensor + # that's exposed to python is wrapped in a FunctionalTensorImpl + # TODO I don't think the above comment is true once functionalize() uses the DynamicLayer machinery + func_level = _func_increment_nesting() + func_args = [_wrap_functional_tensor(x, func_level) if isinstance(x, Tensor) else x for x in args] + func_outputs = func(*func_args, **kwargs) + flattened_outputs, _ = tree_flatten(func_outputs) + return [_unwrap_functional_tensor(x) if isinstance(x, Tensor) else x for x in flattened_outputs] + finally: + _func_decrement_nesting() + return wrapped diff --git a/functorch/csrc/BatchedTensorImpl.cpp b/functorch/csrc/BatchedTensorImpl.cpp index 46893bd3b..57b5f6af1 100644 --- a/functorch/csrc/BatchedTensorImpl.cpp +++ b/functorch/csrc/BatchedTensorImpl.cpp @@ -142,6 +142,23 @@ bool BatchedTensorImpl::has_storage() const { } #endif +void BatchedTensorImpl::replace_(const TensorImpl* other_impl) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl->key_set().has(DispatchKey::Batched)); + auto batched_impl = static_cast(other_impl); + + auto unwrapped_impl_self = value().unsafeGetTensorImpl(); + auto unwrapped_impl_other = batched_impl->value().unsafeGetTensorImpl(); + if (typeid(*unwrapped_impl_self) == typeid(*unwrapped_impl_other)) { + // This allows us to retain the program semantic of mutating inputs + unwrapped_impl_self->replace_(unwrapped_impl_other); + } else { + value_ = batched_impl->value(); + } + bdims_ = batched_impl->bdims(); + checkInvariants(); + refreshSizesAndStrides(); +} + const char* BatchedTensorImpl::tensorimpl_type_name() const { return "BatchedTensorImpl"; } diff --git a/functorch/csrc/BatchedTensorImpl.h b/functorch/csrc/BatchedTensorImpl.h index 67444876e..cf6b880d3 100644 --- a/functorch/csrc/BatchedTensorImpl.h +++ b/functorch/csrc/BatchedTensorImpl.h @@ -92,6 +92,7 @@ struct BatchedTensorImpl : public c10::TensorImpl { #ifdef DEBUG bool has_storage() const override; #endif + void replace_(const TensorImpl* other_impl) override; void refreshSizesAndStrides(); diff --git a/functorch/csrc/BatchingRegistrations.cpp b/functorch/csrc/BatchingRegistrations.cpp index dac7101e4..2aa6be9c3 100644 --- a/functorch/csrc/BatchingRegistrations.cpp +++ b/functorch/csrc/BatchingRegistrations.cpp @@ -1418,6 +1418,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { // // 
m.impl("new_zeros", new_zeros_batching_rule); // // m.impl("contiguous", contiguous_batching_rule); + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); } } diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp index 0444f65d2..83d576d18 100644 --- a/functorch/csrc/DynamicLayer.cpp +++ b/functorch/csrc/DynamicLayer.cpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace at { namespace functorch { @@ -273,6 +274,7 @@ constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({ kDynamicLayerFrontModeKey, kDynamicLayerBackModeKey, kGradWrapperKey, + DispatchKey::Functionalize, // DispatchKey::Batched, kBatchedKey, DispatchKey::ADInplaceOrView @@ -304,12 +306,25 @@ static bool batchedAtCurrentLevel(const Tensor& tensor) { return batched_at_level == level; } +static bool functionalAtCurrentLevel(const Tensor& tensor) { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + auto layer = dynamicLayerStack.back(); + auto level = layer.layerId(); + + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (!functional) { + return false; + } + auto functional_level = functional->level(); + return functional_level == level; +} + void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { auto& dynamicLayerStack = dynamicLayerStackAccessor(); #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE - if (c10::show_dispatch_trace_enabled()) { + //if (c10::show_dispatch_trace_enabled()) { std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl; - } + //} #endif if (dynamicLayerStack.size() == 0) { sanityCheckStack(op, stack); @@ -342,6 +357,16 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* exclude = exclude.remove(kBatchedKey); } include = include.add(kVmapModeKey); + } else if (layer.key() == DispatchKey::Functionalize) { + // The only reason we need to do this is to print properly - see `_tensor_str` in functorch/__init__.py + // I think things would break if any factory functions were called inside of it though. + // Since we need to explicitly call the Func boxed fallback kernel for factory functions. + // (it looks like anyTensors() does what I want in this case, and returns True if there are no tensors args). 
+ const auto args = torch::jit::last(stack, op.schema().arguments().size()); + if (anyTensors(args, functionalAtCurrentLevel)) { + exclude = exclude.remove(DispatchKey::Functionalize); + } + include = include.add(DispatchKey::Functionalize); } else { TORCH_INTERNAL_ASSERT(false); } @@ -474,10 +499,20 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>()); } +TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); +} + TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>()); } +TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) { + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); +} + // TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) { // m.impl("_unwrap_for_grad", native::_unwrap_for_grad); // m.impl("dump_tensor", native::dump_tensor); diff --git a/functorch/csrc/TensorWrapper.cpp b/functorch/csrc/TensorWrapper.cpp index 049607785..6a521f8d5 100644 --- a/functorch/csrc/TensorWrapper.cpp +++ b/functorch/csrc/TensorWrapper.cpp @@ -163,6 +163,21 @@ void TensorWrapper::set_storage_offset(int64_t storage_offset) { TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper"); } +void TensorWrapper::replace_(const TensorImpl* other_impl) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl->key_set().has(kGradWrapperKey)); + auto wrapper_impl = static_cast(other_impl); + + auto unwrapped_impl_self = value().unsafeGetTensorImpl(); + auto unwrapped_impl_other = wrapper_impl->value().unsafeGetTensorImpl(); + if (typeid(*unwrapped_impl_self) == typeid(*unwrapped_impl_other)) { + // This allows us to retain the program semantic of mutating inputs + unwrapped_impl_self->replace_(unwrapped_impl_other); + } else { + value_ = wrapper_impl->value(); + } + refreshMetadata(); +} + const char* TensorWrapper::tensorimpl_type_name() const { return "TensorWrapper"; } @@ -224,6 +239,10 @@ void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Sta TORCH_LIBRARY_IMPL(_, FT_GRAD_WRAPPER_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>()); } +TORCH_LIBRARY_IMPL(aten, FT_GRAD_WRAPPER_KEY, m) { + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); } +} // namespace functorch } // namespace at diff --git a/functorch/csrc/TensorWrapper.h b/functorch/csrc/TensorWrapper.h index 767ee0e82..2fde0dd6a 100644 --- a/functorch/csrc/TensorWrapper.h +++ b/functorch/csrc/TensorWrapper.h @@ -23,6 +23,8 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl { void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; + // Need to override for functionalization to work. 
+ void replace_(const TensorImpl* other_impl) override; void refreshMetadata(); diff --git a/functorch/csrc/VmapModeRegistrations.cpp b/functorch/csrc/VmapModeRegistrations.cpp index bb455cc46..eb5f13db6 100644 --- a/functorch/csrc/VmapModeRegistrations.cpp +++ b/functorch/csrc/VmapModeRegistrations.cpp @@ -113,6 +113,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { m.impl("randn", randn_mbatching_rule); + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); } diff --git a/functorch/csrc/init.cpp b/functorch/csrc/init.cpp index b22d85cd0..db3e25c05 100644 --- a/functorch/csrc/init.cpp +++ b/functorch/csrc/init.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -18,7 +19,7 @@ namespace at { namespace functorch { -static bool has_level(const Tensor& self, int64_t level) { +static bool has_batched_level(const Tensor& self, int64_t level) { const auto* batched = maybeGetBatchedImpl(self); if (!batched) { return false; @@ -27,10 +28,22 @@ static bool has_level(const Tensor& self, int64_t level) { return bdims.back().level() >= level; } +static bool has_functional_level(const Tensor& self, int64_t level) { + const auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); + if (!functional) { + return false; + } + return functional->level() >= level; +} + Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) { return addBatchDim(self, level, batch_dim); } +Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) { + return at::functionalization::makeFunctional(self, level); +} + static std::pair remove_existing_batch_dim( const BatchedTensorImpl* batched, int64_t level) { auto bdims = batched->bdims(); @@ -102,7 +115,7 @@ static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) { // // `out_dim` controls where we should put the batch dimension in the output tensor. Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) { - if (!has_level(self, level)) { + if (!has_batched_level(self, level)) { auto self_sizes = self.sizes(); VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end()); expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size); @@ -110,7 +123,7 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, return result; } - // Must be batched if has_level(self, /*any_level*/) + // Must be batched if has_batched_level(self, /*any_level*/) const auto* batched = maybeGetBatchedImpl(self); TORCH_INTERNAL_ASSERT(batched != nullptr); @@ -121,6 +134,14 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, return result; } +Tensor _unwrap_functional_tensor(const Tensor& self) { + const auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); + // We only ever call that after popping out of a functionalize() call, in which case the current tensors + // should always be wrapped in a FunctionalTensorImpl. + TORCH_INTERNAL_ASSERT(functional != nullptr); + return functional->value(); +} + Tensor _wrap_for_grad(const Tensor& self, int64_t level) { // NB: different behavior inside?? 
// return self; @@ -177,6 +198,18 @@ int64_t _vmap_decrement_nesting() { return layer.layerId(); } +int64_t _func_increment_nesting() { + //c10::impl::tls_set_dispatch_key_included(DispatchKey::Functionalize, true); + return initAndPushDynamicLayer(DispatchKey::Functionalize); +} + +int64_t _func_decrement_nesting() { + //c10::impl::tls_set_dispatch_key_included(DispatchKey::Functionalize, false); + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == DispatchKey::Functionalize); + return layer.layerId(); +} + static bool is_batchedtensor(const Tensor& tensor) { auto* batched = maybeGetBatchedImpl(tensor); return batched != nullptr; @@ -187,6 +220,10 @@ static bool is_gradtrackingtensor(const Tensor& tensor) { return wrapped != nullptr; } +static bool is_functionaltensor(const Tensor& tensor) { + return dynamic_cast(tensor.unsafeGetTensorImpl()) != nullptr; +} + static Tensor get_unwrapped(const Tensor& tensor) { auto* batched = maybeGetBatchedImpl(tensor); if (batched) { @@ -196,13 +233,21 @@ static Tensor get_unwrapped(const Tensor& tensor) { if (wrapped) { return wrapped->value(); } + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (functional) { + return functional->value(); + } TORCH_CHECK(false, "No wrappers present!"); } static int64_t maybe_get_level(const Tensor& tensor) { auto* batched = maybeGetBatchedImpl(tensor); if (batched) { - return batched->bdims().back().level(); + auto tmp = batched->bdims().back().level(); + if (tmp == -1) { + TORCH_INTERNAL_ASSERT(!dynamic_cast(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor)); + } + return tmp; } auto* wrapped = maybeGetTensorWrapper(tensor); if (wrapped) { @@ -212,6 +257,16 @@ static int64_t maybe_get_level(const Tensor& tensor) { // TODO: this is a weird special case... return -2; } + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (functional) { + return functional->level(); + // TODO: make this less hacky? + // The functionalization pass isn't like other functorch passes, + // and has no concept of a level. + // We still need to convey some info here in order to properly print functional tensors. 
+ //return -2; + } + TORCH_INTERNAL_ASSERT(!dynamic_cast(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor)); return -1; } @@ -229,8 +284,12 @@ static int64_t maybe_get_bdim(const Tensor& tensor) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim"); m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim"); + m.def("_wrap_functional_tensor", &at::functorch::_wrap_functional_tensor, "add functional tensor"); + m.def("_unwrap_functional_tensor", &at::functorch::_unwrap_functional_tensor, "remove functional tensor"); m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "remove batch dim"); m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "remove batch dim"); + m.def("_func_increment_nesting", &at::functorch::_func_increment_nesting, "functionalization start"); + m.def("_func_decrement_nesting", &at::functorch::_func_decrement_nesting, "functionalization end"); m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim"); m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim"); m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim"); @@ -245,6 +304,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // on Tensors? m.def("is_batchedtensor", &at::functorch::is_batchedtensor); m.def("is_gradtrackingtensor", &at::functorch::is_gradtrackingtensor); + m.def("is_functionaltensor", &at::functorch::is_functionaltensor); m.def("get_unwrapped", &at::functorch::get_unwrapped); m.def("maybe_get_level", &at::functorch::maybe_get_level); m.def("maybe_get_bdim", &at::functorch::maybe_get_bdim); From dcabf132df8892265fe66031ff06b070a5173135 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 10 Aug 2021 14:41:50 -0700 Subject: [PATCH 2/5] add a functionalize() transform --- functorch/__init__.py | 15 +++++- functorch/_src/vmap.py | 34 ++++++++++++ functorch/csrc/BatchedTensorImpl.cpp | 17 ++++++ functorch/csrc/BatchedTensorImpl.h | 1 + functorch/csrc/BatchingRegistrations.cpp | 2 + functorch/csrc/DynamicLayer.cpp | 39 +++++++++++++- functorch/csrc/TensorWrapper.cpp | 19 +++++++ functorch/csrc/TensorWrapper.h | 2 + functorch/csrc/VmapModeRegistrations.cpp | 2 + functorch/csrc/init.cpp | 68 ++++++++++++++++++++++-- 10 files changed, 191 insertions(+), 8 deletions(-) diff --git a/functorch/__init__.py b/functorch/__init__.py index d3522ed25..6698fa512 100644 --- a/functorch/__init__.py +++ b/functorch/__init__.py @@ -7,9 +7,13 @@ import torch import functools import textwrap -from . import _C +# from . 
import _C +from functorch._C import ( + _func_decrement_nesting, + _func_increment_nesting, +) -from ._src.vmap import vmap +from ._src.vmap import vmap, functionalize from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev, vjpfull from ._src.make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1 from ._src.make_functional import ( @@ -68,6 +72,8 @@ def _backward(*args, **kwargs): def _functorch_str(tensor): level = _C.maybe_get_level(tensor) if level == -1: + if _C.is_functionaltensor(tensor): + raise Exception("HELP") return _old_str(tensor) value = _C.get_unwrapped(tensor) @@ -79,6 +85,11 @@ def _functorch_str(tensor): return f'BatchedTensor(lvl={level}, bdim={bdim}, value=\\\n{value_repr})' if _C.is_gradtrackingtensor(tensor): return f'GradTrackingTensor(lvl={level}, value=\\\n{value_repr})' + if _C.is_functionaltensor(tensor): + # NOTE: functional tensor's only have a notion of "level" when + # you use the functionalize() API. Otherwise their level is set to -1. + # That shouldn't matter, since this print function is only used by functorch + return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})' raise ValueError("We don't know how to print this, please file us an issue") diff --git a/functorch/_src/vmap.py b/functorch/_src/vmap.py index b260c0cfa..885fa7a4d 100644 --- a/functorch/_src/vmap.py +++ b/functorch/_src/vmap.py @@ -16,8 +16,12 @@ from functorch._C import ( _add_batch_dim, _remove_batch_dim, + _wrap_functional_tensor, + _unwrap_functional_tensor, _vmap_decrement_nesting, _vmap_increment_nesting, + _func_decrement_nesting, + _func_increment_nesting, ) in_dims_t = Union[int, Tuple] @@ -278,3 +282,33 @@ def wrapped(*args, **kwargs): finally: _vmap_decrement_nesting() return wrapped + +class functionalizer(object): + def __enter__(self): + _func_increment_nesting() + + def __exit__(self, *args): + _func_decrement_nesting() + + def __call__(self, func): + @functools.wraps(func) + def decorate_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + return decorate_func + +def functionalize(func: Callable) -> Callable: + @functools.wraps(func) + def wrapped(*args, **kwargs): + try: + # INVARIANT: when _func_increment_nesting is active, every tensor + # that's exposed to python is wrapped in a FunctionalTensorImpl + # TODO I don't think the above comment is true once functionalize() uses the DynamicLayer machinery + func_level = _func_increment_nesting() + func_args = [_wrap_functional_tensor(x, func_level) if isinstance(x, Tensor) else x for x in args] + func_outputs = func(*func_args, **kwargs) + flattened_outputs, _ = tree_flatten(func_outputs) + return [_unwrap_functional_tensor(x) if isinstance(x, Tensor) else x for x in flattened_outputs] + finally: + _func_decrement_nesting() + return wrapped diff --git a/functorch/csrc/BatchedTensorImpl.cpp b/functorch/csrc/BatchedTensorImpl.cpp index 46893bd3b..57b5f6af1 100644 --- a/functorch/csrc/BatchedTensorImpl.cpp +++ b/functorch/csrc/BatchedTensorImpl.cpp @@ -142,6 +142,23 @@ bool BatchedTensorImpl::has_storage() const { } #endif +void BatchedTensorImpl::replace_(const TensorImpl* other_impl) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl->key_set().has(DispatchKey::Batched)); + auto batched_impl = static_cast(other_impl); + + auto unwrapped_impl_self = value().unsafeGetTensorImpl(); + auto unwrapped_impl_other = batched_impl->value().unsafeGetTensorImpl(); + if (typeid(*unwrapped_impl_self) == typeid(*unwrapped_impl_other)) { + // This allows us to retain 
the program semantic of mutating inputs + unwrapped_impl_self->replace_(unwrapped_impl_other); + } else { + value_ = batched_impl->value(); + } + bdims_ = batched_impl->bdims(); + checkInvariants(); + refreshSizesAndStrides(); +} + const char* BatchedTensorImpl::tensorimpl_type_name() const { return "BatchedTensorImpl"; } diff --git a/functorch/csrc/BatchedTensorImpl.h b/functorch/csrc/BatchedTensorImpl.h index 67444876e..cf6b880d3 100644 --- a/functorch/csrc/BatchedTensorImpl.h +++ b/functorch/csrc/BatchedTensorImpl.h @@ -92,6 +92,7 @@ struct BatchedTensorImpl : public c10::TensorImpl { #ifdef DEBUG bool has_storage() const override; #endif + void replace_(const TensorImpl* other_impl) override; void refreshSizesAndStrides(); diff --git a/functorch/csrc/BatchingRegistrations.cpp b/functorch/csrc/BatchingRegistrations.cpp index dac7101e4..2aa6be9c3 100644 --- a/functorch/csrc/BatchingRegistrations.cpp +++ b/functorch/csrc/BatchingRegistrations.cpp @@ -1418,6 +1418,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { // // m.impl("new_zeros", new_zeros_batching_rule); // // m.impl("contiguous", contiguous_batching_rule); + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); } } diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp index 0444f65d2..83d576d18 100644 --- a/functorch/csrc/DynamicLayer.cpp +++ b/functorch/csrc/DynamicLayer.cpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace at { namespace functorch { @@ -273,6 +274,7 @@ constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({ kDynamicLayerFrontModeKey, kDynamicLayerBackModeKey, kGradWrapperKey, + DispatchKey::Functionalize, // DispatchKey::Batched, kBatchedKey, DispatchKey::ADInplaceOrView @@ -304,12 +306,25 @@ static bool batchedAtCurrentLevel(const Tensor& tensor) { return batched_at_level == level; } +static bool functionalAtCurrentLevel(const Tensor& tensor) { + auto& dynamicLayerStack = dynamicLayerStackAccessor(); + auto layer = dynamicLayerStack.back(); + auto level = layer.layerId(); + + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (!functional) { + return false; + } + auto functional_level = functional->level(); + return functional_level == level; +} + void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { auto& dynamicLayerStack = dynamicLayerStackAccessor(); #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE - if (c10::show_dispatch_trace_enabled()) { + //if (c10::show_dispatch_trace_enabled()) { std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl; - } + //} #endif if (dynamicLayerStack.size() == 0) { sanityCheckStack(op, stack); @@ -342,6 +357,16 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* exclude = exclude.remove(kBatchedKey); } include = include.add(kVmapModeKey); + } else if (layer.key() == DispatchKey::Functionalize) { + // The only reason we need to do this is to print properly - see `_tensor_str` in functorch/__init__.py + // I think things would break if any factory functions were called inside of it though. + // Since we need to explicitly call the Func boxed fallback kernel for factory functions. + // (it looks like anyTensors() does what I want in this case, and returns True if there are no tensors args). 
+ const auto args = torch::jit::last(stack, op.schema().arguments().size()); + if (anyTensors(args, functionalAtCurrentLevel)) { + exclude = exclude.remove(DispatchKey::Functionalize); + } + include = include.add(DispatchKey::Functionalize); } else { TORCH_INTERNAL_ASSERT(false); } @@ -474,10 +499,20 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>()); } +TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) { + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); +} + TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>()); } +TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) { + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); +} + // TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) { // m.impl("_unwrap_for_grad", native::_unwrap_for_grad); // m.impl("dump_tensor", native::dump_tensor); diff --git a/functorch/csrc/TensorWrapper.cpp b/functorch/csrc/TensorWrapper.cpp index 049607785..6a521f8d5 100644 --- a/functorch/csrc/TensorWrapper.cpp +++ b/functorch/csrc/TensorWrapper.cpp @@ -163,6 +163,21 @@ void TensorWrapper::set_storage_offset(int64_t storage_offset) { TORCH_INTERNAL_ASSERT(false, "Can't set_storage_offset for TensorWrapper"); } +void TensorWrapper::replace_(const TensorImpl* other_impl) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl->key_set().has(kGradWrapperKey)); + auto wrapper_impl = static_cast(other_impl); + + auto unwrapped_impl_self = value().unsafeGetTensorImpl(); + auto unwrapped_impl_other = wrapper_impl->value().unsafeGetTensorImpl(); + if (typeid(*unwrapped_impl_self) == typeid(*unwrapped_impl_other)) { + // This allows us to retain the program semantic of mutating inputs + unwrapped_impl_self->replace_(unwrapped_impl_other); + } else { + value_ = wrapper_impl->value(); + } + refreshMetadata(); +} + const char* TensorWrapper::tensorimpl_type_name() const { return "TensorWrapper"; } @@ -224,6 +239,10 @@ void dead_tensor_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Sta TORCH_LIBRARY_IMPL(_, FT_GRAD_WRAPPER_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction<&dead_tensor_wrapper_fallback>()); } +TORCH_LIBRARY_IMPL(aten, FT_GRAD_WRAPPER_KEY, m) { + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); } +} // namespace functorch } // namespace at diff --git a/functorch/csrc/TensorWrapper.h b/functorch/csrc/TensorWrapper.h index 767ee0e82..2fde0dd6a 100644 --- a/functorch/csrc/TensorWrapper.h +++ b/functorch/csrc/TensorWrapper.h @@ -23,6 +23,8 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl { void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; + // Need to override for functionalization to work. 
+ void replace_(const TensorImpl* other_impl) override; void refreshMetadata(); diff --git a/functorch/csrc/VmapModeRegistrations.cpp b/functorch/csrc/VmapModeRegistrations.cpp index bb455cc46..eb5f13db6 100644 --- a/functorch/csrc/VmapModeRegistrations.cpp +++ b/functorch/csrc/VmapModeRegistrations.cpp @@ -113,6 +113,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { m.impl("randn", randn_mbatching_rule); + // We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback + m.impl("replace_", torch::CppFunction::makeFallthrough()); } diff --git a/functorch/csrc/init.cpp b/functorch/csrc/init.cpp index b22d85cd0..db3e25c05 100644 --- a/functorch/csrc/init.cpp +++ b/functorch/csrc/init.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -18,7 +19,7 @@ namespace at { namespace functorch { -static bool has_level(const Tensor& self, int64_t level) { +static bool has_batched_level(const Tensor& self, int64_t level) { const auto* batched = maybeGetBatchedImpl(self); if (!batched) { return false; @@ -27,10 +28,22 @@ static bool has_level(const Tensor& self, int64_t level) { return bdims.back().level() >= level; } +static bool has_functional_level(const Tensor& self, int64_t level) { + const auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); + if (!functional) { + return false; + } + return functional->level() >= level; +} + Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) { return addBatchDim(self, level, batch_dim); } +Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) { + return at::functionalization::makeFunctional(self, level); +} + static std::pair remove_existing_batch_dim( const BatchedTensorImpl* batched, int64_t level) { auto bdims = batched->bdims(); @@ -102,7 +115,7 @@ static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) { // // `out_dim` controls where we should put the batch dimension in the output tensor. Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, int64_t out_dim) { - if (!has_level(self, level)) { + if (!has_batched_level(self, level)) { auto self_sizes = self.sizes(); VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end()); expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size); @@ -110,7 +123,7 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, return result; } - // Must be batched if has_level(self, /*any_level*/) + // Must be batched if has_batched_level(self, /*any_level*/) const auto* batched = maybeGetBatchedImpl(self); TORCH_INTERNAL_ASSERT(batched != nullptr); @@ -121,6 +134,14 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, return result; } +Tensor _unwrap_functional_tensor(const Tensor& self) { + const auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); + // We only ever call that after popping out of a functionalize() call, in which case the current tensors + // should always be wrapped in a FunctionalTensorImpl. + TORCH_INTERNAL_ASSERT(functional != nullptr); + return functional->value(); +} + Tensor _wrap_for_grad(const Tensor& self, int64_t level) { // NB: different behavior inside?? 
// return self; @@ -177,6 +198,18 @@ int64_t _vmap_decrement_nesting() { return layer.layerId(); } +int64_t _func_increment_nesting() { + //c10::impl::tls_set_dispatch_key_included(DispatchKey::Functionalize, true); + return initAndPushDynamicLayer(DispatchKey::Functionalize); +} + +int64_t _func_decrement_nesting() { + //c10::impl::tls_set_dispatch_key_included(DispatchKey::Functionalize, false); + auto layer = popDynamicLayerAndDeleteMetadata(); + TORCH_INTERNAL_ASSERT(layer.key() == DispatchKey::Functionalize); + return layer.layerId(); +} + static bool is_batchedtensor(const Tensor& tensor) { auto* batched = maybeGetBatchedImpl(tensor); return batched != nullptr; @@ -187,6 +220,10 @@ static bool is_gradtrackingtensor(const Tensor& tensor) { return wrapped != nullptr; } +static bool is_functionaltensor(const Tensor& tensor) { + return dynamic_cast(tensor.unsafeGetTensorImpl()) != nullptr; +} + static Tensor get_unwrapped(const Tensor& tensor) { auto* batched = maybeGetBatchedImpl(tensor); if (batched) { @@ -196,13 +233,21 @@ static Tensor get_unwrapped(const Tensor& tensor) { if (wrapped) { return wrapped->value(); } + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (functional) { + return functional->value(); + } TORCH_CHECK(false, "No wrappers present!"); } static int64_t maybe_get_level(const Tensor& tensor) { auto* batched = maybeGetBatchedImpl(tensor); if (batched) { - return batched->bdims().back().level(); + auto tmp = batched->bdims().back().level(); + if (tmp == -1) { + TORCH_INTERNAL_ASSERT(!dynamic_cast(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor)); + } + return tmp; } auto* wrapped = maybeGetTensorWrapper(tensor); if (wrapped) { @@ -212,6 +257,16 @@ static int64_t maybe_get_level(const Tensor& tensor) { // TODO: this is a weird special case... return -2; } + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (functional) { + return functional->level(); + // TODO: make this less hacky? + // The functionalization pass isn't like other functorch passes, + // and has no concept of a level. + // We still need to convey some info here in order to properly print functional tensors. 
+ //return -2; + } + TORCH_INTERNAL_ASSERT(!dynamic_cast(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor)); return -1; } @@ -229,8 +284,12 @@ static int64_t maybe_get_bdim(const Tensor& tensor) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("_add_batch_dim", &at::functorch::_add_batch_dim, "add batch dim"); m.def("_remove_batch_dim", &at::functorch::_remove_batch_dim, "remove batch dim"); + m.def("_wrap_functional_tensor", &at::functorch::_wrap_functional_tensor, "add functional tensor"); + m.def("_unwrap_functional_tensor", &at::functorch::_unwrap_functional_tensor, "remove functional tensor"); m.def("_vmap_increment_nesting", &at::functorch::_vmap_increment_nesting, "remove batch dim"); m.def("_vmap_decrement_nesting", &at::functorch::_vmap_decrement_nesting, "remove batch dim"); + m.def("_func_increment_nesting", &at::functorch::_func_increment_nesting, "functionalization start"); + m.def("_func_decrement_nesting", &at::functorch::_func_decrement_nesting, "functionalization end"); m.def("_grad_increment_nesting", &at::functorch::_grad_increment_nesting, "remove batch dim"); m.def("_grad_decrement_nesting", &at::functorch::_grad_decrement_nesting, "remove batch dim"); m.def("_wrap_for_grad", &at::functorch::_wrap_for_grad, "add batch dim"); @@ -245,6 +304,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // on Tensors? m.def("is_batchedtensor", &at::functorch::is_batchedtensor); m.def("is_gradtrackingtensor", &at::functorch::is_gradtrackingtensor); + m.def("is_functionaltensor", &at::functorch::is_functionaltensor); m.def("get_unwrapped", &at::functorch::get_unwrapped); m.def("maybe_get_level", &at::functorch::maybe_get_level); m.def("maybe_get_bdim", &at::functorch::maybe_get_bdim); From 3ab36b68a8897fc159e18645ab9058080aa9a41d Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 13 Aug 2021 11:14:04 -0700 Subject: [PATCH 3/5] move FunctionalTensorWrapper into functorch, add unwrapping and wrapping logic to the LayerBackFallback --- functorch/__init__.py | 5 + functorch/_src/vmap.py | 8 +- functorch/csrc/DynamicLayer.cpp | 63 ++++++++- functorch/csrc/FunctionalTensorWrapper.cpp | 141 +++++++++++++++++++++ functorch/csrc/FunctionalTensorWrapper.h | 73 +++++++++++ functorch/csrc/init.cpp | 25 ++-- 6 files changed, 291 insertions(+), 24 deletions(-) create mode 100644 functorch/csrc/FunctionalTensorWrapper.cpp create mode 100644 functorch/csrc/FunctionalTensorWrapper.h diff --git a/functorch/__init__.py b/functorch/__init__.py index 6698fa512..38cc0e139 100644 --- a/functorch/__init__.py +++ b/functorch/__init__.py @@ -76,6 +76,11 @@ def _functorch_str(tensor): raise Exception("HELP") return _old_str(tensor) + if _C.is_functionaltensor(tensor): + # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure + # that it's up to date first + tensor.sync_() + value = _C.get_unwrapped(tensor) value_repr = repr(value) value_repr = textwrap.indent(value_repr, ' ') diff --git a/functorch/_src/vmap.py b/functorch/_src/vmap.py index 885fa7a4d..de462d45e 100644 --- a/functorch/_src/vmap.py +++ b/functorch/_src/vmap.py @@ -301,13 +301,15 @@ def functionalize(func: Callable) -> Callable: @functools.wraps(func) def wrapped(*args, **kwargs): try: - # INVARIANT: when _func_increment_nesting is active, every tensor - # that's exposed to python is wrapped in a FunctionalTensorImpl - # TODO I don't think the above comment is true once functionalize() uses the DynamicLayer machinery func_level = _func_increment_nesting() func_args = [_wrap_functional_tensor(x, 
func_level) if isinstance(x, Tensor) else x for x in args] func_outputs = func(*func_args, **kwargs) flattened_outputs, _ = tree_flatten(func_outputs) + for a in func_args: + if isinstance(a, Tensor): + # Call sync_() on the inputs, to ensure that they still get mutated inplace if the original + # program mutated its inputs + a.sync_() return [_unwrap_functional_tensor(x) if isinstance(x, Tensor) else x for x in flattened_outputs] finally: _func_decrement_nesting() diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp index 83d576d18..27deca4ae 100644 --- a/functorch/csrc/DynamicLayer.cpp +++ b/functorch/csrc/DynamicLayer.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #include +#include #include #include @@ -14,7 +15,6 @@ #include #include #include -#include namespace at { namespace functorch { @@ -311,7 +311,7 @@ static bool functionalAtCurrentLevel(const Tensor& tensor) { auto layer = dynamicLayerStack.back(); auto level = layer.layerId(); - auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); if (!functional) { return false; } @@ -358,10 +358,6 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* } include = include.add(kVmapModeKey); } else if (layer.key() == DispatchKey::Functionalize) { - // The only reason we need to do this is to print properly - see `_tensor_str` in functorch/__init__.py - // I think things would break if any factory functions were called inside of it though. - // Since we need to explicitly call the Func boxed fallback kernel for factory functions. - // (it looks like anyTensors() does what I want in this case, and returns True if there are no tensors args). const auto args = torch::jit::last(stack, op.schema().arguments().size()); if (anyTensors(args, functionalAtCurrentLevel)) { exclude = exclude.remove(DispatchKey::Functionalize); @@ -432,6 +428,30 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* return makeTensorWrapper(tensor, cur_level); }; + auto unwrap_functional = [&](const Tensor& tensor) { + if (!tensor.defined()) { + return tensor; + } + auto* maybe_functional_wrapper = dynamic_cast(tensor.unsafeGetTensorImpl()); + if (!maybe_functional_wrapper) { + return tensor; + } + auto tensor_wrapper_level = maybe_functional_wrapper->level(); + TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level); + if (tensor_wrapper_level == cur_level) { + return maybe_functional_wrapper->value(); + } + return tensor; + }; + + auto wrap_functional = [&](const Tensor& tensor) { + if (!tensor.defined()) { + return tensor; + } + return at::functionalization::impl::makeFunctional(tensor, cur_level); + }; + + // TODO: we only need to do the following (marked with !) on in-place functions // that modify sizes or strides. There aren't many of them. // If autograd dispatch key: @@ -454,6 +474,31 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap); } + bool should_wrap_functional_outputs = false; + if (cur_key == DispatchKey::Functionalize) { + // Step 1: Detect if we'll need to wrap output tensors + // I really don't like this. + // The functional pass should wrap all output tensors in a FunctionalTensorWrapper + // But it shouldn't performing the wrapping when we print - i.e. 
if none of the inputs are functional tensors + // HOWEVER, we want the wrapping to trigger on factory functions. + // So we're out of luck if a factory function is triggered during printing. + const auto args = torch::jit::last(stack, op.schema().arguments().size()); + bool any_tensor_args = anyTensors(args, [&](const Tensor& tensor) { return true; }); + bool any_tensor_args_are_functional = anyTensors(args, [&](const Tensor& t) { return functionalAtCurrentLevel(t); }); + if (!any_tensor_args) { + // factory op - hope that we're not printing, and wrap the output + should_wrap_functional_outputs = true; + } + if (any_tensor_args_are_functional) { + // if at least one tensor input is wrapped, that means we're in the functionalization pass. wrap the outputs. + should_wrap_functional_outputs = true; + } + + // Step 2: Unwrap any functional tensor wrappers. + auto args_size = op.schema().arguments().size(); + foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap_functional); + } + // pop the top layer. Put it back on dtor. WithoutTop guard; @@ -469,6 +514,12 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* // Re-dispatch op.callBoxed(stack); + if (should_wrap_functional_outputs) { + auto ret_size = op.schema().returns().size(); + foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap_functional); + } + + // Step 4, 5, 6 if (cur_key == DispatchKey::Autograd) { // Step 4 diff --git a/functorch/csrc/FunctionalTensorWrapper.cpp b/functorch/csrc/FunctionalTensorWrapper.cpp new file mode 100644 index 000000000..e85dd9c29 --- /dev/null +++ b/functorch/csrc/FunctionalTensorWrapper.cpp @@ -0,0 +1,141 @@ + +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include +#include + +#include + +namespace at { + +FunctionalTensorWrapper::FunctionalTensorWrapper(Tensor value, int64_t level) + : FunctionalTensorImplBase(value.dtype(), value.device()), + value_(value), + level_(level) +{ + TORCH_INTERNAL_ASSERT(value_.defined()); +} + +void FunctionalTensorWrapper::replace_(const Tensor& other) { + auto self_impl = value_.unsafeGetTensorImpl(); + auto other_functional = dynamic_cast(other.unsafeGetTensorImpl()); + // new invariant: every time the fucntionalization pass redispatches during functionalize() calls, + // we'll hit the DynamicLayerModeBackFallback which should wrap outputs in a FunctionalTensorWrapper + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_functional != nullptr); + auto other_impl = other_functional->value().unsafeGetTensorImpl(); + if (typeid(*self_impl) == typeid(*other_impl)) { + // It is valid to swap out the metadata on the tensorImpl + // but we can only do that if the two tensor's we're swapping have the same type. + // This allows us to ensure that programs that mutate their inputs + // preserve their semantics under a functionalization pass. 
+ self_impl->replace_(other_impl); + } else { + value_ = other_functional->value(); + } +} + +void FunctionalTensorWrapper::set_size(int64_t dim, int64_t new_size) { + value_.unsafeGetTensorImpl()->set_size(dim, new_size); +} +void FunctionalTensorWrapper::set_stride(int64_t dim, int64_t new_stride) { + value_.unsafeGetTensorImpl()->set_stride(dim, new_stride); +} +void FunctionalTensorWrapper::set_storage_offset(int64_t storage_offset) { + value_.unsafeGetTensorImpl()->set_storage_offset(storage_offset); +} +bool FunctionalTensorWrapper::has_storage() const { + return value_.unsafeGetTensorImpl()->has_storage(); +} +IntArrayRef FunctionalTensorWrapper::sizes() const { + return value_.unsafeGetTensorImpl()->sizes(); +} +int64_t FunctionalTensorWrapper::dim() const { + return value_.unsafeGetTensorImpl()->dim(); +} +const Storage& FunctionalTensorWrapper::storage() const { + return value_.unsafeGetTensorImpl()->storage(); +} +int64_t FunctionalTensorWrapper::numel() const { + return value_.unsafeGetTensorImpl()->numel(); +} +bool FunctionalTensorWrapper::is_contiguous(at::MemoryFormat memory_format) const { + return value_.unsafeGetTensorImpl()->is_contiguous(memory_format); +} +int64_t FunctionalTensorWrapper::storage_offset() const { + return value_.unsafeGetTensorImpl()->storage_offset(); +} +int64_t FunctionalTensorWrapper::size(int64_t d) const { + return value_.unsafeGetTensorImpl()->size(d); +} +int64_t FunctionalTensorWrapper::stride(int64_t d) const { + return value_.unsafeGetTensorImpl()->stride(d); +} +c10::intrusive_ptr FunctionalTensorWrapper::shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + // TODO: maybe just don't allow this + return value_.unsafeGetTensorImpl()->shallow_copy_and_detach(version_counter, allow_tensor_metadata_change); +} +c10::intrusive_ptr FunctionalTensorWrapper::shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + // TODO: maybe just don't allow this + return value_.unsafeGetTensorImpl()->shallow_copy_and_detach(version_counter, allow_tensor_metadata_change); +} +void FunctionalTensorWrapper::shallow_copy_from(const c10::intrusive_ptr& impl) { + // TODO: maybe just don't allow this + value_.unsafeGetTensorImpl()->shallow_copy_from(impl); +} +const char* FunctionalTensorWrapper::tensorimpl_type_name() const { + return "FunctionalTensorWrapper"; +} + +namespace functionalization { +namespace impl { + +Tensor makeFunctional(const Tensor& tensor, int64_t level) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!dynamic_cast(tensor.unsafeGetTensorImpl())); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!tensor.key_set().has(c10::DispatchKey::Functionalize)); + return at::detail::make_tensor(tensor, level); +} + +c10::optional makeFunctional(const c10::optional& tensor, int64_t level) { + if (tensor.has_value()) { + return makeFunctional(*tensor, level); + } + return c10::nullopt; +} + +c10::List makeFunctional(const c10::List& t_list, int64_t level) { + std::vector functional_tensors; + for (auto& t: t_list.vec()) { + functional_tensors.push_back(makeFunctional(t, level)); + } + return c10::List(functional_tensors); +} + +std::vector makeFunctional(const at::TensorList t_list, int64_t level) { + std::vector functional_tensors; + for (auto& t: t_list) { + functional_tensors.push_back(makeFunctional(t, level)); + } + return functional_tensors; +} + +c10::List> makeFunctional(const c10::List>& t_list, int64_t level) { + std::vector> functional_tensors; + for 
(auto& t: t_list.vec()) { + functional_tensors.push_back(makeFunctional(t, level)); + } + return c10::List>(functional_tensors); +} + +} // namespace impl +} // namespace functionalization +} // namespace at diff --git a/functorch/csrc/FunctionalTensorWrapper.h b/functorch/csrc/FunctionalTensorWrapper.h new file mode 100644 index 000000000..ecd43cb35 --- /dev/null +++ b/functorch/csrc/FunctionalTensorWrapper.h @@ -0,0 +1,73 @@ + +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace at { + +struct TORCH_API FunctionalTensorWrapper : public at::FunctionalTensorImplBase { + explicit FunctionalTensorWrapper(Tensor value, int64_t level); + + const Tensor& value() const { return value_; }; + int64_t level() const { return level_; }; + + // Override the FunctionalTensorImplBase method describing how to re-use a tensor in the functionalization pass. + void replace_(const Tensor& other) override; + + // Override ALL virtual functions on the TensorImpl to call into the wrapped value's implementation + IntArrayRef sizes() const override; + int64_t dim() const override; + bool has_storage() const override; + const Storage& storage() const override; + int64_t numel() const override; + bool is_contiguous(at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const override; + //bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; + int64_t storage_offset() const override; + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; + int64_t size(int64_t d) const override; + int64_t stride(int64_t d) const override; + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + + private: + const char* tensorimpl_type_name() const override; + + Tensor value_; + int64_t level_; +}; + +TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(const Tensor& tensor) { + auto functional_impl = static_cast(tensor.unsafeGetTensorImpl()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr); + return functional_impl; +} + +namespace functionalization { +namespace impl { + +// Utility functions for the functionalization pass. 
+ +TORCH_API Tensor makeFunctional(const Tensor& tensor, int64_t level); +TORCH_API c10::optional makeFunctional(const c10::optional& tensor); +TORCH_API c10::List makeFunctional(const c10::List& t_list); +TORCH_API std::vector makeFunctional(const TensorList t_list); +TORCH_API c10::List> makeFunctional(const c10::List>& tensor); + +} // namespace impl +} // namespace functionalization +} // namespace at diff --git a/functorch/csrc/init.cpp b/functorch/csrc/init.cpp index db3e25c05..bad9a5c45 100644 --- a/functorch/csrc/init.cpp +++ b/functorch/csrc/init.cpp @@ -6,8 +6,8 @@ #include #include -#include +#include #include #include #include @@ -29,7 +29,7 @@ static bool has_batched_level(const Tensor& self, int64_t level) { } static bool has_functional_level(const Tensor& self, int64_t level) { - const auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); + const auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); if (!functional) { return false; } @@ -41,7 +41,7 @@ Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) { } Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) { - return at::functionalization::makeFunctional(self, level); + return at::functionalization::impl::makeFunctional(self, level); } static std::pair remove_existing_batch_dim( @@ -135,10 +135,11 @@ Tensor _remove_batch_dim(const Tensor& self, int64_t level, int64_t batch_size, } Tensor _unwrap_functional_tensor(const Tensor& self) { - const auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); + auto* functional = dynamic_cast(self.unsafeGetTensorImpl()); // We only ever call that after popping out of a functionalize() call, in which case the current tensors - // should always be wrapped in a FunctionalTensorImpl. + // should always be wrapped in a FunctionalTensorWrapper. TORCH_INTERNAL_ASSERT(functional != nullptr); + functional->sync_(); return functional->value(); } @@ -221,7 +222,7 @@ static bool is_gradtrackingtensor(const Tensor& tensor) { } static bool is_functionaltensor(const Tensor& tensor) { - return dynamic_cast(tensor.unsafeGetTensorImpl()) != nullptr; + return dynamic_cast(tensor.unsafeGetTensorImpl()) != nullptr; } static Tensor get_unwrapped(const Tensor& tensor) { @@ -233,7 +234,7 @@ static Tensor get_unwrapped(const Tensor& tensor) { if (wrapped) { return wrapped->value(); } - auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); if (functional) { return functional->value(); } @@ -245,7 +246,7 @@ static int64_t maybe_get_level(const Tensor& tensor) { if (batched) { auto tmp = batched->bdims().back().level(); if (tmp == -1) { - TORCH_INTERNAL_ASSERT(!dynamic_cast(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor)); + TORCH_INTERNAL_ASSERT(!dynamic_cast(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor)); } return tmp; } @@ -257,16 +258,10 @@ static int64_t maybe_get_level(const Tensor& tensor) { // TODO: this is a weird special case... return -2; } - auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); + auto* functional = dynamic_cast(tensor.unsafeGetTensorImpl()); if (functional) { return functional->level(); - // TODO: make this less hacky? - // The functionalization pass isn't like other functorch passes, - // and has no concept of a level. - // We still need to convey some info here in order to properly print functional tensors. 
- //return -2; } - TORCH_INTERNAL_ASSERT(!dynamic_cast(tensor.unsafeGetTensorImpl()) && !at::functorch::isBatchedTensor(tensor)); return -1; } From a7b541b591eb543036719b85b93d76e0477c5728 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 13 Aug 2021 11:58:21 -0700 Subject: [PATCH 4/5] some cleanup --- .gitignore | 7 +++++++ functorch/__init__.py | 6 ------ functorch/_src/vmap.py | 14 -------------- functorch/csrc/DynamicLayer.cpp | 4 ++-- functorch/csrc/init.cpp | 2 -- 5 files changed, 9 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 523f619db..a2ea649d3 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,10 @@ functorch/_C.so t.py .vscode/ ccache.sh + +# Editor temporaries +*.swn +*.swo +*.swp +*.swm +*~ diff --git a/functorch/__init__.py b/functorch/__init__.py index 38cc0e139..ac61ce0f5 100644 --- a/functorch/__init__.py +++ b/functorch/__init__.py @@ -7,7 +7,6 @@ import torch import functools import textwrap -# from . import _C from functorch._C import ( _func_decrement_nesting, _func_increment_nesting, @@ -72,8 +71,6 @@ def _backward(*args, **kwargs): def _functorch_str(tensor): level = _C.maybe_get_level(tensor) if level == -1: - if _C.is_functionaltensor(tensor): - raise Exception("HELP") return _old_str(tensor) if _C.is_functionaltensor(tensor): @@ -91,9 +88,6 @@ def _functorch_str(tensor): if _C.is_gradtrackingtensor(tensor): return f'GradTrackingTensor(lvl={level}, value=\\\n{value_repr})' if _C.is_functionaltensor(tensor): - # NOTE: functional tensor's only have a notion of "level" when - # you use the functionalize() API. Otherwise their level is set to -1. - # That shouldn't matter, since this print function is only used by functorch return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})' raise ValueError("We don't know how to print this, please file us an issue") diff --git a/functorch/_src/vmap.py b/functorch/_src/vmap.py index de462d45e..ed70e6473 100644 --- a/functorch/_src/vmap.py +++ b/functorch/_src/vmap.py @@ -283,20 +283,6 @@ def wrapped(*args, **kwargs): _vmap_decrement_nesting() return wrapped -class functionalizer(object): - def __enter__(self): - _func_increment_nesting() - - def __exit__(self, *args): - _func_decrement_nesting() - - def __call__(self, func): - @functools.wraps(func) - def decorate_func(*args, **kwargs): - with self: - return func(*args, **kwargs) - return decorate_func - def functionalize(func: Callable) -> Callable: @functools.wraps(func) def wrapped(*args, **kwargs): diff --git a/functorch/csrc/DynamicLayer.cpp b/functorch/csrc/DynamicLayer.cpp index 27deca4ae..38203cf8c 100644 --- a/functorch/csrc/DynamicLayer.cpp +++ b/functorch/csrc/DynamicLayer.cpp @@ -322,9 +322,9 @@ static bool functionalAtCurrentLevel(const Tensor& tensor) { void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { auto& dynamicLayerStack = dynamicLayerStackAccessor(); #ifdef HAS_TORCH_SHOW_DISPATCH_TRACE - //if (c10::show_dispatch_trace_enabled()) { + if (c10::show_dispatch_trace_enabled()) { std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl; - //} + } #endif if (dynamicLayerStack.size() == 0) { sanityCheckStack(op, stack); diff --git a/functorch/csrc/init.cpp b/functorch/csrc/init.cpp index bad9a5c45..0510989c4 100644 --- a/functorch/csrc/init.cpp +++ b/functorch/csrc/init.cpp @@ -200,12 +200,10 @@ int64_t _vmap_decrement_nesting() { } int64_t _func_increment_nesting() { - //c10::impl::tls_set_dispatch_key_included(DispatchKey::Functionalize, true); return 
initAndPushDynamicLayer(DispatchKey::Functionalize); } int64_t _func_decrement_nesting() { - //c10::impl::tls_set_dispatch_key_included(DispatchKey::Functionalize, false); auto layer = popDynamicLayerAndDeleteMetadata(); TORCH_INTERNAL_ASSERT(layer.key() == DispatchKey::Functionalize); return layer.layerId(); From 75261bce2d0d1babc56671538b7592ed7658927d Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 13 Aug 2021 17:23:16 -0700 Subject: [PATCH 5/5] just use one replace_() method --- functorch/csrc/FunctionalTensorWrapper.cpp | 16 ++++++++-------- functorch/csrc/FunctionalTensorWrapper.h | 3 ++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/functorch/csrc/FunctionalTensorWrapper.cpp b/functorch/csrc/FunctionalTensorWrapper.cpp index e85dd9c29..faa4be40b 100644 --- a/functorch/csrc/FunctionalTensorWrapper.cpp +++ b/functorch/csrc/FunctionalTensorWrapper.cpp @@ -22,19 +22,19 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(Tensor value, int64_t level) TORCH_INTERNAL_ASSERT(value_.defined()); } -void FunctionalTensorWrapper::replace_(const Tensor& other) { - auto self_impl = value_.unsafeGetTensorImpl(); - auto other_functional = dynamic_cast(other.unsafeGetTensorImpl()); - // new invariant: every time the fucntionalization pass redispatches during functionalize() calls, - // we'll hit the DynamicLayerModeBackFallback which should wrap outputs in a FunctionalTensorWrapper +void FunctionalTensorWrapper::replace_(const TensorImpl* other_impl) { + auto other_functional = dynamic_cast(other_impl); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_functional != nullptr); - auto other_impl = other_functional->value().unsafeGetTensorImpl(); - if (typeid(*self_impl) == typeid(*other_impl)) { + + auto self_unwrapped_impl = value_.unsafeGetTensorImpl(); + auto other_unwrapped_impl = other_functional->value().unsafeGetTensorImpl(); + + if (typeid(*self_unwrapped_impl) == typeid(*other_unwrapped_impl)) { // It is valid to swap out the metadata on the tensorImpl // but we can only do that if the two tensor's we're swapping have the same type. // This allows us to ensure that programs that mutate their inputs // preserve their semantics under a functionalization pass. - self_impl->replace_(other_impl); + self_unwrapped_impl->replace_(other_unwrapped_impl); } else { value_ = other_functional->value(); } diff --git a/functorch/csrc/FunctionalTensorWrapper.h b/functorch/csrc/FunctionalTensorWrapper.h index ecd43cb35..d64df583e 100644 --- a/functorch/csrc/FunctionalTensorWrapper.h +++ b/functorch/csrc/FunctionalTensorWrapper.h @@ -20,7 +20,8 @@ struct TORCH_API FunctionalTensorWrapper : public at::FunctionalTensorImplBase { int64_t level() const { return level_; }; // Override the FunctionalTensorImplBase method describing how to re-use a tensor in the functionalization pass. - void replace_(const Tensor& other) override; + //void replace_(const Tensor& other) override; + void replace_(const TensorImpl* other_impl) override; // Override ALL virtual functions on the TensorImpl to call into the wrapped value's implementation IntArrayRef sizes() const override;
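For context, here is a minimal, illustrative sketch of the user-facing API this series adds. It is not part of the patches; it only assumes the names the series itself introduces (functorch.functionalize, backed by the wrapper in functorch/_src/vmap.py above) and that the wrapper behaves as written there: tensor arguments are wrapped in a FunctionalTensorWrapper, the function runs under the Functionalize dynamic layer, mutated inputs are synced back via sync_(), and the outputs come back as a flattened list of unwrapped tensors.

    import torch
    from functorch import functionalize  # exported by functorch/__init__.py in this series

    def f(x):
        # f mutates its input and returns a view; functionalization should
        # execute an equivalent mutation-free program, while the sync_()
        # calls in the Python wrapper keep the visible in-place update of `x`.
        tmp = torch.ones_like(x)
        x.add_(tmp)
        return x.view(-1)

    x = torch.zeros(2, 2)
    outs = functionalize(f)(x)  # per the wrapper above, a flattened list of unwrapped tensors
    print(outs)                 # the (unwrapped) outputs of f
    print(x)                    # should reflect the in-place add_ if input mutation semantics are preserved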