From d49d9fb8362a0f7f456926adb96523da96a4ca42 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 21 Mar 2023 10:41:16 +0000 Subject: [PATCH 1/6] [fix] vmap: fix segfault on data access --- aten/src/ATen/functorch/BatchedTensorImpl.cpp | 22 +++++++++++++++++++ aten/src/ATen/functorch/BatchedTensorImpl.h | 7 ++++++ 2 files changed, 29 insertions(+) diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index c5d6eb34030d..93f696b5ed9d 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -103,6 +103,28 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const { return "BatchedTensorImpl"; } +c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + DispatchKeySet key_set = getKeysToPropagateToWrapper(value()); + auto impl = c10::make_intrusive(key_set, value(), bdim(), level()); + impl->set_version_counter(version_counter); + return impl; +} + +c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + DispatchKeySet key_set = getKeysToPropagateToWrapper(value()); + auto impl = c10::make_intrusive(key_set, value(), bdim(), level()); + impl->set_version_counter(version_counter); + return impl; +} + +void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { + TORCH_CHECK(false, "mutating directly with `.data` inside functorch transform is not allowed."); +} + Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) { DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor); auto* batched = maybeGetBatchedImpl(tensor); diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h index b61edd986580..82173e071987 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.h +++ b/aten/src/ATen/functorch/BatchedTensorImpl.h @@ -71,6 +71,13 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; #ifdef DEBUG bool has_storage() const override; #endif From 2020d6bbc5a030e3dc4925a943afb892f4516692 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 21 Mar 2023 10:41:49 +0000 Subject: [PATCH 2/6] add test --- test/functorch/test_vmap.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 24ad586a19a1..ec659d10530d 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -1168,6 +1168,13 @@ def f(x, y): self.assertEqual(out, f(None, y)) self.assertEqual(out_dims, (None, None, None)) + def test_data_attribute(self): + def foo(x): + y = x.data + return x + + torch.func.vmap(foo)(torch.randn(3, 3)) + def slice_inputs(inputs, bdims, i): result = [] From 68710fe5b2452db3b9ffdf33fe9c4260774b65b5 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 21 Mar 2023 11:05:57 +0000 Subject: [PATCH 3/6] add op to make sure test fails --- test/functorch/test_vmap.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index ec659d10530d..cb4ff918b5c8 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -1171,6 +1171,7 @@ def f(x, y): def test_data_attribute(self): def foo(x): y = x.data + y.sum() return x torch.func.vmap(foo)(torch.randn(3, 3)) From d5c87886f9a6e3fc0067ccd920881639b891baeb Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 23 Mar 2023 17:06:43 +0000 Subject: [PATCH 4/6] disable data under vmap transform --- aten/src/ATen/functorch/BatchedTensorImpl.cpp | 12 ++++-------- test/functorch/test_vmap.py | 3 ++- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index 93f696b5ed9d..ac8b93bf124f 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -106,19 +106,15 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const { c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { - DispatchKeySet key_set = getKeysToPropagateToWrapper(value()); - auto impl = c10::make_intrusive(key_set, value(), bdim(), level()); - impl->set_version_counter(version_counter); - return impl; + TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed"); + return nullptr; } c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) const { - DispatchKeySet key_set = getKeysToPropagateToWrapper(value()); - auto impl = c10::make_intrusive(key_set, value(), bdim(), level()); - impl->set_version_counter(version_counter); - return impl; + TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed"); + return nullptr; } void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index cb4ff918b5c8..e38cb61ff694 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -1174,7 +1174,8 @@ def foo(x): y.sum() return x - torch.func.vmap(foo)(torch.randn(3, 3)) + with self.assertRaisesRegex(RuntimeError, "accessing `data` under vmap transform"): + torch.func.vmap(foo)(torch.randn(3, 3)) def slice_inputs(inputs, bdims, i): From 0c2126f559cf23659744c9bc2096d1771ab4dbcf Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 24 Mar 2023 05:47:08 +0000 Subject: [PATCH 5/6] update test and changes --- aten/src/ATen/functorch/BatchRulesDynamic.cpp | 5 +++++ aten/src/ATen/functorch/BatchedTensorImpl.cpp | 2 +- test/functorch/test_vmap.py | 8 +++++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDynamic.cpp b/aten/src/ATen/functorch/BatchRulesDynamic.cpp index a85d7f18953f..491d85483d82 100644 --- a/aten/src/ATen/functorch/BatchRulesDynamic.cpp +++ b/aten/src/ATen/functorch/BatchRulesDynamic.cpp @@ -61,6 +61,10 @@ void unsupportedAllclose(const c10::OperatorHandle& op, torch::jit::Stack* stack "support over at github.com/pytorch/functorch/issues/275"); } +void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed."); +} + TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { UNSUPPORTED_DYNAMIC(nonzero); UNSUPPORTED_DYNAMIC(where); @@ -72,6 +76,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { m.impl("item", torch::CppFunction::makeFromBoxedFunction<&unsupportedItem>()); m.impl("is_nonzero", torch::CppFunction::makeFromBoxedFunction<&unsupportedIsNonzero>()); m.impl("allclose", torch::CppFunction::makeFromBoxedFunction<&unsupportedAllclose>()); + m.impl("_has_compatible_shallow_copy_type", torch::CppFunction::makeFromBoxedFunction<&unsupportedData>()); } }} diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index ac8b93bf124f..0b0f58e23bc6 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -118,7 +118,7 @@ c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( } void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { - TORCH_CHECK(false, "mutating directly with `.data` inside functorch transform is not allowed."); + TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed."); } Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) { diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index e38cb61ff694..edb2cc38b13c 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -1171,12 +1171,18 @@ def f(x, y): def test_data_attribute(self): def foo(x): y = x.data - y.sum() return x with self.assertRaisesRegex(RuntimeError, "accessing `data` under vmap transform"): torch.func.vmap(foo)(torch.randn(3, 3)) + def foo(x): + x.data = torch.ones(3, 3) + return x + + with self.assertRaisesRegex(RuntimeError, "mutating directly with `.data` under vmap"): + torch.func.vmap(foo)(torch.randn(3, 3)) + def slice_inputs(inputs, bdims, i): result = [] From 1a23bb46b2c10ab2881be6db2e43a801930c6fbb Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 27 Mar 2023 18:12:42 +0000 Subject: [PATCH 6/6] fix --- aten/src/ATen/functorch/BatchRulesDecompositions.cpp | 5 +++++ aten/src/ATen/functorch/BatchRulesDynamic.cpp | 5 ----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index d0662890789b..aad5d4a62791 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -25,6 +25,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { OP_DECOMPOSE(feature_dropout_); } +void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed."); +} + TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE2(__and__, Scalar); OP_DECOMPOSE2(__and__, Tensor); @@ -332,6 +336,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE2(less_equal, Scalar); OP_DECOMPOSE2(less, Scalar); OP_DECOMPOSE2(not_equal, Scalar); + m.impl("_has_compatible_shallow_copy_type", torch::CppFunction::makeFromBoxedFunction<&unsupportedData>()); } }} diff --git a/aten/src/ATen/functorch/BatchRulesDynamic.cpp b/aten/src/ATen/functorch/BatchRulesDynamic.cpp index 491d85483d82..a85d7f18953f 100644 --- a/aten/src/ATen/functorch/BatchRulesDynamic.cpp +++ b/aten/src/ATen/functorch/BatchRulesDynamic.cpp @@ -61,10 +61,6 @@ void unsupportedAllclose(const c10::OperatorHandle& op, torch::jit::Stack* stack "support over at github.com/pytorch/functorch/issues/275"); } -void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) { - TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed."); -} - TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { UNSUPPORTED_DYNAMIC(nonzero); UNSUPPORTED_DYNAMIC(where); @@ -76,7 +72,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { m.impl("item", torch::CppFunction::makeFromBoxedFunction<&unsupportedItem>()); m.impl("is_nonzero", torch::CppFunction::makeFromBoxedFunction<&unsupportedIsNonzero>()); m.impl("allclose", torch::CppFunction::makeFromBoxedFunction<&unsupportedAllclose>()); - m.impl("_has_compatible_shallow_copy_type", torch::CppFunction::makeFromBoxedFunction<&unsupportedData>()); } }}