Keep track of ViewMeta with symbolic inputs. #125876

Closed · wants to merge 6 commits
2 changes: 1 addition & 1 deletion aten/src/ATen/FunctionalStorageImpl.cpp
@@ -10,7 +10,7 @@ namespace at::functionalization {

ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
if (out_idx == this->out_index) return *this;
return ViewMeta(forward_fn, reverse_fn, is_multi_output, is_as_strided, out_idx);
return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
}

// Note [Functionalization: Alias Removal Part 2]
7 changes: 6 additions & 1 deletion aten/src/ATen/FunctionalStorageImpl.h
@@ -31,14 +31,16 @@ struct ViewMeta {
ViewMeta(
std::function<Tensor(const Tensor&, int64_t)> forward,
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
bool has_symbolic_inputs,
Contributor review comment: So what, this should be true if the std::function closes over SymInts, or something? Spell it out.

bool is_multi_output = false,
bool is_as_strided = false,
int64_t out_idx = 0)
: forward_fn(std::move(forward)),
reverse_fn(std::move(reverse)),
out_index(out_idx),
is_multi_output(is_multi_output),
is_as_strided(is_as_strided) {}
is_as_strided(is_as_strided),
has_symbolic_inputs(has_symbolic_inputs) {}

std::function<Tensor(const Tensor&, int64_t)> forward_fn;
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
Expand All @@ -50,6 +52,9 @@ struct ViewMeta {

bool is_as_strided;

// Tells us if this view operation has any symbolic inputs
bool has_symbolic_inputs;

// Returns a copy of the current ViewMeta, if out_idx matches the current
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
// functions, but a new out index.
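To answer the review question above: judging from the kernels updated later in this diff (e.g. `_unsafe_view_functionalize`), `has_symbolic_inputs` is meant to be true exactly when the sizes/strides that the forward/reverse lambdas close over contain symbolic SymInts. A minimal Python-side sketch of that predicate (illustrative only; the real check lives in the C++ kernels, and `has_symbolic_inputs` here is a hypothetical standalone helper):

```python
import torch

def has_symbolic_inputs(sizes) -> bool:
    # True when any captured size is a symbolic SymInt rather than a plain int;
    # mirrors the std::any_of(..., is_symbolic()) check used for _unsafe_view below.
    return any(isinstance(s, torch.SymInt) for s in sizes)

print(has_symbolic_inputs([2, 2]))  # False: all sizes are concrete ints
```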
8 changes: 7 additions & 1 deletion aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -137,7 +137,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
),
value_(view_value),
is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
was_storage_changed_(base->was_storage_changed_)
was_storage_changed_(base->was_storage_changed_),
is_symbolic_(base->is_symbolic_)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
Expand All @@ -147,6 +148,7 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
view_metas_ = base->view_metas_; // copy
}
view_metas_.push_back(meta);
maybe_mark_symbolic(meta);
storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}

@@ -178,6 +180,8 @@ void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::View
view_metas_.push_back(meta);
// Manually track the fact that this tensor received a metadata mutation!
has_metadata_mutation_ = true;
// Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
maybe_mark_symbolic(meta);
// Note [Functionalization Pass - Inplace View Ops]
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
@@ -257,6 +261,7 @@ void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
value_ = other->value_;
generation_ = other->generation_;
view_metas_ = other->view_metas_;
is_symbolic_ = other->is_symbolic_;
// FREEZE the old storage, preventing mutations to it.
// this is a huge pain to handle properly in all cases, so we ban it.
functional_storage_impl()->freeze();
@@ -388,6 +393,7 @@ void FunctionalTensorWrapper::copy_tensor_metadata(
dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
dest_impl->is_symbolic_ = src_impl->is_symbolic_;
dest_impl->generation_ = src_impl->generation_;
dest_impl->view_metas_ = src_impl->view_metas_;
}
10 changes: 10 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.h
@@ -97,6 +97,14 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
->are_all_mutations_under_no_grad_or_inference_mode();
}

void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
}

bool is_symbolic() const {
return is_symbolic_;
}

// Runs the forward_fn of every ViewMeta collected in the current instance
// to some other base.
Tensor apply_view_metas(const Tensor& base);
@@ -237,6 +245,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
bool is_multi_output_view_ = false;
// Did the tensor experience a set_() call.
bool was_storage_changed_ = false;
// Did the tensor experience any view operation with symbolic ints.
bool is_symbolic_ = false;

size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_;
8 changes: 6 additions & 2 deletions aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -178,7 +178,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
}
},
/*has_symbolic_inputs=*/false
);
at::functionalization::impl::mutate_view_meta(self, view_meta);
return self;
@@ -298,13 +299,16 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
tmp_output = at::_unsafe_view_symint(self_, size);
}

bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });

at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
return at::_unsafe_view_symint(base, size);
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
}
},
/*has_symbolic_inputs=*/has_symbolic_inputs
);

auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
31 changes: 28 additions & 3 deletions test/functorch/test_aotdispatch.py
@@ -67,6 +67,7 @@
parametrize,
run_tests,
skipIfRocm,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.hop_db import hop_db
@@ -3261,7 +3262,6 @@ def wrapper(g, *args, **kwargs):

return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))

@patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
def test_output_aliases_input_view_meta_replay(self):
@self._compile_and_erase_bases(0)
def f(a):
@@ -3275,7 +3275,6 @@ def f(a):
str(out.grad_fn.__class__), """<class 'ViewBackward0'>"""
)

@patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
def test_output_aliases_intermediate_view_meta_replay(self):
@self._compile_and_erase_bases(0, 1)
def f(a):
@@ -3295,7 +3294,6 @@ def f(a):
str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
)

@patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
def test_output_aliases_output_view_meta_replay(self):
@self._compile_and_erase_bases(1)
def f(a):
@@ -3311,6 +3309,33 @@ def f(a):
str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
)

@skipIfTorchDynamo()
@patch("torch._dynamo.config.assume_static_by_default", False)
def test_dynamic_output_aliases_input_view_meta_replay(self):
# - torch.compile: using it so we can have a SymInt in the FX graph.
# - Compiling with inductor, so that tensor._base isn't tracked.
#
# This should force the use of as_strided in the view reconstruction path.
# The first 2 view-replay paths won't be taken because:
# - target_functional_tensor will be symbolic (_functionalize_is_symbolic call)
# - tensor._base will be None
@torch.compile(backend="inductor")
def f(a, sz):
return a.view(sz), a.view(-1)

inp = torch.ones(2, 2, requires_grad=True)
out1, out2 = f(inp, (4,))

self.assertIsNotNone(out1.grad_fn)
self.assertExpectedInline(
str(out1.grad_fn.__class__), """<class 'AsStridedBackward0'>"""
)

self.assertIsNotNone(out2.grad_fn)
self.assertExpectedInline(
str(out2.grad_fn.__class__), """<class 'ViewBackward0'>"""
)


def extract_graph(fx_g, _, graph_cell):
graph_cell[0] = fx_g
3 changes: 3 additions & 0 deletions tools/pyi/gen_pyi.py
@@ -808,6 +808,9 @@ def gen_pyi(
"_functionalize_apply_view_metas": [
"def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ..."
],
"_functionalize_is_symbolic": [
"def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..."
],
"_enable_functionalization": [
"def _enable_functionalization(*, reapply_views: _bool = False): ..."
],
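A hedged usage sketch for the new binding (these are private helpers — `_enable_functionalization`, `_to_functional_tensor`, `_functionalize_is_symbolic` — so this illustrates intent rather than a stable API):

```python
import torch

base = torch.ones(2, 2)
torch._enable_functionalization(reapply_views=True)
try:
    ft = torch._to_functional_tensor(base)
    v = ft.view(-1)  # records a ViewMeta; sizes are plain ints, so no symbolic inputs
    print(torch._functionalize_is_symbolic(v))  # expected: False
finally:
    torch._disable_functionalization()
```

Under torch.compile with dynamic shapes (the scenario the new test above exercises), the view sizes are SymInts, the recorded ViewMetas are marked symbolic, and the flag comes back True — which is what the functional_utils.py change below keys off of.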
44 changes: 16 additions & 28 deletions torch/_functorch/_aot_autograd/functional_utils.py
@@ -220,39 +220,27 @@ def patch_requires_grad(out):
# In summary, we use the fact that FunctionalTensorWrapper saves the view
# functions applied to itself (collected during functionalization) so as
# to replay them (view functions) on the aliased_base_tensor.
if config.view_replay_for_aliased_outputs and target_functional_tensor is not None:
if (
config.view_replay_for_aliased_outputs
and target_functional_tensor is not None
and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
):
from .schemas import FunctionalTensorMetadataEq

assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq)
functional_tensor = target_functional_tensor.tensor

try:
out = torch._functionalize_apply_view_metas(
functional_tensor, aliased_base_tensor
)
except RuntimeError as e:
# NYI for dynamic shapes.
#
# On functionalization, the ViewMeta lambdas will have symbolic shapes.
# When trying to apply those lambdas on concrete tensors, it will fail.
#
# In order for this to work, we should have a way to replace those
# symbolic shapes with concrete numbers.
aot_joint_log.info(
"could not reconstruct view by re-applying a ViewMeta sequence. "
"Fallbacking to reconstruction using as_strided. "
"Reason: %s",
str(e),
)
else:
# If re-applying the ViewMeta sequence succeeded, there should be no more
# problems going forward. We just check we got to the target shape and
# patch requires_grad flag.
assert out.shape == target_meta_tensor.shape, (
"incorrect out shape after application of ViewMeta sequence: "
f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
)
return patch_requires_grad(out)
out = torch._functionalize_apply_view_metas(
functional_tensor, aliased_base_tensor
)
# If re-applying the ViewMeta sequence succeeded, there should be no more
# problems going forward. We just check we got to the target shape and
# patch requires_grad flag.
assert out.shape == target_meta_tensor.shape, (
"incorrect out shape after application of ViewMeta sequence: "
f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
)
return patch_requires_grad(out)

# Try to do view-replay if possible.
# fall back to .as_strided() if we can't.
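Put differently, view replay is now gated up front instead of catching a RuntimeError after the fact. A rough standalone restatement of the guard (the config flag and helpers match the code above, but `can_replay_views` itself is an illustrative name, not part of the diff):

```python
import torch
from torch._functorch import config


def can_replay_views(target_functional_tensor) -> bool:
    # Replay ViewMetas only when replay is enabled, a target exists, and none of
    # its recorded view ops captured symbolic inputs; otherwise the caller falls
    # back to as_strided (hence AsStridedBackward0 in the new dynamic-shape test).
    return (
        config.view_replay_for_aliased_outputs
        and target_functional_tensor is not None
        and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
    )
```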
29 changes: 28 additions & 1 deletion torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -664,6 +664,29 @@ static PyObject* THPVariable__functionalize_sync(
END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__functionalize_is_symbolic(
PyObject* self,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"_functionalize_is_symbolic(Tensor tensor)"},
/*traceable=*/true);

ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto tensor = r.tensor(0);
TORCH_INTERNAL_ASSERT(
at::functionalization::impl::isFunctionalTensor(tensor));
auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
if (impl->is_symbolic()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable__functionalize_apply_view_metas(
PyObject* self,
PyObject* args,
@@ -673,7 +696,7 @@ static PyObject* THPVariable__functionalize_apply_view_metas(
{"_functionalize_apply_view_metas(Tensor tensor, Tensor base)"},
/*traceable=*/true);

ParsedArgs<4> parsed_args;
ParsedArgs<2> parsed_args;
Contributor review comment: whoops

auto r = parser.parse(args, kwargs, parsed_args);
auto tensor = r.tensor(0);
TORCH_INTERNAL_ASSERT(
@@ -796,6 +819,10 @@ static PyMethodDef torch_functions_manual[] = {
castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_functionalize_is_symbolic",
castPyCFunctionWithKeywords(THPVariable__functionalize_is_symbolic),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,
nullptr},
{"_functionalize_apply_view_metas",
castPyCFunctionWithKeywords(THPVariable__functionalize_apply_view_metas),
METH_VARARGS | METH_KEYWORDS | METH_STATIC,