7 changes: 7 additions & 0 deletions .gitignore
@@ -8,3 +8,10 @@ functorch/_C.so
t.py
.vscode/
ccache.sh

# Editor temporaries
*.swn
*.swo
*.swp
*.swm
*~
14 changes: 12 additions & 2 deletions functorch/__init__.py
@@ -7,9 +7,12 @@
import torch
import functools
import textwrap
from . import _C
from functorch._C import (
_func_decrement_nesting,
_func_increment_nesting,
)

from ._src.vmap import vmap
from ._src.vmap import vmap, functionalize
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev, vjpfull
from ._src.make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1
from ._src.make_functional import (
@@ -70,6 +73,11 @@ def _functorch_str(tensor):
if level == -1:
return _old_str(tensor)

if _C.is_functionaltensor(tensor):
# Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
# that it's up to date first
tensor.sync_()

value = _C.get_unwrapped(tensor)
value_repr = repr(value)
value_repr = textwrap.indent(value_repr, ' ')
@@ -79,6 +87,8 @@ def _functorch_str(tensor):
return f'BatchedTensor(lvl={level}, bdim={bdim}, value=\\\n{value_repr})'
if _C.is_gradtrackingtensor(tensor):
return f'GradTrackingTensor(lvl={level}, value=\\\n{value_repr})'
if _C.is_functionaltensor(tensor):
return f'FunctionalTensor(lvl={level}, value=\\\n{value_repr})'

raise ValueError("We don't know how to print this, please file us an issue")
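
As a rough illustration of the new repr branch (a sketch, assuming this branch is built; the level number and inner tensor formatting below are illustrative, while the `FunctionalTensor(lvl=..., value=...)` format string comes from the code above):

```python
import torch
from functorch import functionalize

def f(x):
    y = x + 1
    # Under functionalize, y is a FunctionalTensorWrapper; with this patch, printing it
    # first calls sync_() and then shows the unwrapped value, e.g.:
    #   FunctionalTensor(lvl=1, value=\
    #     tensor([[2., 2.],
    #             [2., 2.]]))
    print(y)
    return y

functionalize(f)(torch.ones(2, 2))
```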

22 changes: 22 additions & 0 deletions functorch/_src/vmap.py
@@ -16,8 +16,12 @@
from functorch._C import (
_add_batch_dim,
_remove_batch_dim,
_wrap_functional_tensor,
_unwrap_functional_tensor,
_vmap_decrement_nesting,
_vmap_increment_nesting,
_func_decrement_nesting,
_func_increment_nesting,
)

in_dims_t = Union[int, Tuple]
@@ -278,3 +282,21 @@ def wrapped(*args, **kwargs):
finally:
_vmap_decrement_nesting()
return wrapped

def functionalize(func: Callable) -> Callable:
@functools.wraps(func)
def wrapped(*args, **kwargs):
try:
func_level = _func_increment_nesting()
func_args = [_wrap_functional_tensor(x, func_level) if isinstance(x, Tensor) else x for x in args]
func_outputs = func(*func_args, **kwargs)
flattened_outputs, _ = tree_flatten(func_outputs)
for a in func_args:
if isinstance(a, Tensor):
# Call sync_() on the inputs, to ensure that they still get mutated inplace if the original
# program mutated its inputs
a.sync_()
return [_unwrap_functional_tensor(x) if isinstance(x, Tensor) else x for x in flattened_outputs]
finally:
_func_decrement_nesting()
return wrapped
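
A minimal usage sketch of the new transform (assuming this branch is installed; note that, as written, `functionalize` returns the flattened outputs as a list, and the `sync_()` loop above is what lets the caller still observe input mutations):

```python
import torch
from functorch import functionalize

def f(x):
    y = x.clone()
    y.add_(1)   # in-place ops are rewritten to functional ones under the hood...
    x.mul_(2)   # ...but the input mutation is replayed onto x via sync_()
    return y

x = torch.ones(3)
out = functionalize(f)(x)
print(out)  # expected: [tensor([2., 2., 2.])]
print(x)    # expected: tensor([2., 2., 2.])
```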
17 changes: 17 additions & 0 deletions functorch/csrc/BatchedTensorImpl.cpp
@@ -142,6 +142,23 @@ bool BatchedTensorImpl::has_storage() const {
}
#endif

void BatchedTensorImpl::replace_(const TensorImpl* other_impl) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl->key_set().has(DispatchKey::Batched));
auto batched_impl = static_cast<const BatchedTensorImpl*>(other_impl);

auto unwrapped_impl_self = value().unsafeGetTensorImpl();
auto unwrapped_impl_other = batched_impl->value().unsafeGetTensorImpl();
if (typeid(*unwrapped_impl_self) == typeid(*unwrapped_impl_other)) {
// This allows us to retain the program semantics of mutating inputs
unwrapped_impl_self->replace_(unwrapped_impl_other);
} else {
value_ = batched_impl->value();
}
bdims_ = batched_impl->bdims();
checkInvariants();
refreshSizesAndStrides();
}

const char* BatchedTensorImpl::tensorimpl_type_name() const {
return "BatchedTensorImpl";
}
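
For intuition, this is the kind of composition the new `replace_` override is meant to keep working: when `functionalize` runs under `vmap` and syncs a mutated input, the replacement has to land on the value held inside the `BatchedTensorImpl` rather than swapping out the wrapper itself. A sketch (whether every op here is covered by this prototype is an assumption):

```python
import torch
from functorch import vmap, functionalize

def f(x):
    x.add_(1)      # mutation that functionalize replays onto the batched input
    return x * 2

xs = torch.zeros(4, 3)
out = vmap(functionalize(f))(xs)
print(out)  # expected: a list holding a 4x3 tensor of twos
print(xs)   # expected: a 4x3 tensor of ones -- the mutation reached xs through replace_
```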
1 change: 1 addition & 0 deletions functorch/csrc/BatchedTensorImpl.h
@@ -92,6 +92,7 @@ struct BatchedTensorImpl : public c10::TensorImpl {
#ifdef DEBUG
bool has_storage() const override;
#endif
void replace_(const TensorImpl* other_impl) override;

void refreshSizesAndStrides();

2 changes: 2 additions & 0 deletions functorch/csrc/BatchingRegistrations.cpp
@@ -1418,6 +1418,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
// // m.impl("new_zeros", new_zeros_batching_rule);
// //
m.impl("contiguous", contiguous_batching_rule);
// We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
m.impl("replace_", torch::CppFunction::makeFallthrough());
}

}
86 changes: 86 additions & 0 deletions functorch/csrc/DynamicLayer.cpp
@@ -5,6 +5,7 @@
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/FunctionalTensorWrapper.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/BatchedTensorImpl.h>

@@ -273,6 +274,7 @@ constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
kDynamicLayerFrontModeKey,
kDynamicLayerBackModeKey,
kGradWrapperKey,
DispatchKey::Functionalize,
// DispatchKey::Batched,
kBatchedKey,
DispatchKey::ADInplaceOrView
@@ -304,6 +306,19 @@ static bool batchedAtCurrentLevel(const Tensor& tensor) {
return batched_at_level == level;
}

static bool functionalAtCurrentLevel(const Tensor& tensor) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
auto layer = dynamicLayerStack.back();
auto level = layer.layerId();

auto* functional = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
if (!functional) {
return false;
}
auto functional_level = functional->level();
return functional_level == level;
}

void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
@@ -342,6 +357,12 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
exclude = exclude.remove(kBatchedKey);
}
include = include.add(kVmapModeKey);
} else if (layer.key() == DispatchKey::Functionalize) {
const auto args = torch::jit::last(stack, op.schema().arguments().size());
if (anyTensors(args, functionalAtCurrentLevel)) {
exclude = exclude.remove(DispatchKey::Functionalize);
}
include = include.add(DispatchKey::Functionalize);
} else {
TORCH_INTERNAL_ASSERT(false);
}
@@ -407,6 +428,30 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
return makeTensorWrapper(tensor, cur_level);
};

auto unwrap_functional = [&](const Tensor& tensor) {
if (!tensor.defined()) {
return tensor;
}
auto* maybe_functional_wrapper = dynamic_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
if (!maybe_functional_wrapper) {
return tensor;
}
auto tensor_wrapper_level = maybe_functional_wrapper->level();
TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level);
if (tensor_wrapper_level == cur_level) {
return maybe_functional_wrapper->value();
}
return tensor;
};

auto wrap_functional = [&](const Tensor& tensor) {
if (!tensor.defined()) {
return tensor;
}
return at::functionalization::impl::makeFunctional(tensor, cur_level);
};

// TODO: we only need to do the following (marked with !) on in-place functions
// that modify sizes or strides. There aren't many of them.
// If autograd dispatch key:
@@ -429,6 +474,31 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
}

bool should_wrap_functional_outputs = false;
if (cur_key == DispatchKey::Functionalize) {
// Step 1: Detect if we'll need to wrap output tensors
// I really don't like this.
// The functional pass should wrap all output tensors in a FunctionalTensorWrapper
// But it shouldn't perform the wrapping when we print - i.e. if none of the inputs are functional tensors
// HOWEVER, we want the wrapping to trigger on factory functions.
// So we're out of luck if a factory function is triggered during printing.
const auto args = torch::jit::last(stack, op.schema().arguments().size());
bool any_tensor_args = anyTensors(args, [&](const Tensor& tensor) { return true; });
bool any_tensor_args_are_functional = anyTensors(args, [&](const Tensor& t) { return functionalAtCurrentLevel(t); });
if (!any_tensor_args) {
// factory op - hope that we're not printing, and wrap the output
should_wrap_functional_outputs = true;
}
if (any_tensor_args_are_functional) {
// if at least one tensor input is wrapped, that means we're in the functionalization pass. wrap the outputs.
should_wrap_functional_outputs = true;
}

// Step 2: Unwrap any functional tensor wrappers.
auto args_size = op.schema().arguments().size();
foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap_functional);
}
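
Concretely, the factory-function case the comment above worries about looks like this from the Python side (a sketch of the intended behavior under this prototype, not a tested guarantee):

```python
import torch
from functorch import functionalize

def f(x):
    # torch.ones takes no tensor arguments, so nothing on the dispatcher stack is a
    # functional tensor here; its output still has to be wrapped at the current level
    # so that the add_ below can be functionalized.
    y = torch.ones(3)
    y.add_(x)
    return y

print(functionalize(f)(torch.arange(3.)))  # expected: [tensor([1., 2., 3.])]
```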

// pop the top layer. Put it back on dtor.
WithoutTop guard;

Expand All @@ -444,6 +514,12 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
// Re-dispatch
op.callBoxed(stack);

if (should_wrap_functional_outputs) {
auto ret_size = op.schema().returns().size();
foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap_functional);
}

// Step 4, 5, 6
if (cur_key == DispatchKey::Autograd) {
// Step 4
@@ -474,10 +550,20 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}

TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
// We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
m.impl("replace_", torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}

TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
// We need this for the functionalization pass: replace_ shouldn't enter the boxed fallback
m.impl("replace_", torch::CppFunction::makeFallthrough());
}

// TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
// m.impl("_unwrap_for_grad", native::_unwrap_for_grad);
// m.impl("dump_tensor", native::dump_tensor);