[fix] vmap: fix segfault on data access #97237
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97237
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1a23bb4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/functorch/test_vmap.py (Outdated)

def test_data_attribute(self):
    def foo(x):
        y = x.data
        y.sum()
Without the fix, this line fails with
RuntimeError: batched == nullptr INTERNAL ASSERT FAILED at "aten/src/ATen/functorch/Interpreter.cpp":98, please report a bug to PyTorch.
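A minimal repro sketch based on the test above: before this patch, the call below hit the internal assert; with the final patch it raises a clean RuntimeError instead.

import torch

def foo(x):
    y = x.data  # accessing `.data` under vmap used to hit the internal assert
    return y.sum()

torch.func.vmap(foo)(torch.randn(3, 3))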
c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  DispatchKeySet key_set = getKeysToPropagateToWrapper(value());
  auto impl = c10::make_intrusive<BatchedTensorImpl>(key_set, value(), bdim(), level());
  impl->set_version_counter(version_counter);
  return impl;
}

c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  DispatchKeySet key_set = getKeysToPropagateToWrapper(value());
  auto impl = c10::make_intrusive<BatchedTensorImpl>(key_set, value(), bdim(), level());
  impl->set_version_counter(version_counter);
  return impl;
}

void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  TORCH_CHECK(false, "mutating directly with `.data` inside functorch transform is not allowed.");
}
Do you have an explanation of why the segfault happened without these? (and do we know the full implications of adding shallow_copy_and_detach to BatchedTensorImpl?)
In shallow_copy_and_detach of the base TensorImpl, we propagate the key_set, which includes the FuncTorchBatched key. So we create a Tensor that pretends to be a batched tensor but in reality isn't, which leads to problems downstream when we try to access a field or call a method exclusive to BatchedTensor.
pytorch/c10/core/TensorImpl.cpp
Lines 782 to 796 in 517a432
auto impl = c10::make_intrusive<TensorImpl>(
    // No need to populate Storage; copy_tensor_metadata will do it for us.
    key_set_,
    data_type_,
    device_opt_);
copy_tensor_metadata(
    /*src_impl=*/this,
    /*dest_impl=*/impl.get(),
    /*version_counter=*/std::forward<VariableVersion>(version_counter),
    /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
impl->refresh_contiguous();
return impl;
}
For the example in the issue, the code was failing when trying to call batched->value():
pytorch/torch/csrc/functorch/init.cpp
Lines 312 to 317 in 517a432
static Tensor get_unwrapped(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->value();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);
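To illustrate the mechanism, a toy Python sketch (not PyTorch internals; all names here are made up): a detach that copies the dispatch-key tag but not the subclass payload produces an object that claims to be batched yet cannot honor batched-only accesses.

class PlainImpl:
    def __init__(self, keys):
        self.keys = keys  # stand-in for TensorImpl's key_set_

class BatchedImpl(PlainImpl):
    def __init__(self, value, bdim):
        super().__init__({"FuncTorchBatched"})
        self.value = value  # the wrapped per-example tensor
        self.bdim = bdim    # the batch dimension

def shallow_copy_and_detach(impl):
    # Base-class behavior: propagates the keys but constructs a plain impl,
    # dropping the BatchedImpl payload.
    return PlainImpl(impl.keys)

copy = shallow_copy_and_detach(BatchedImpl(value=[1.0, 2.0], bdim=0))
assert "FuncTorchBatched" in copy.keys  # still claims to be batched...
copy.value  # ...but raises AttributeError: the payload was never copied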
As for the implications, I don't fully understand the complete semantics of these functions (and would like to discuss offline).
As for why I thought it was OK to support shallow_copy_and_detach (the getter for .data): it worked with GradTrackingTensor, so for consistency I thought it made sense to allow it for BatchedTensor as well.
Though on second thought, I think we should disallow it for both of them.
import torch

def foo(x):
    y = x.data
    print(y)
    y.sum()
    return x.sum()

# torch.func.vmap(foo)(torch.randn(3, 3))
torch.func.grad(foo)(torch.randn(3, 3))

Output:
GradTrackingTensor(lvl=1, value=
tensor([[-1.6977, 0.6374, 0.0781],
[-0.4140, 1.5172, 0.0473],
[ 0.8435, -0.2261, 0.0345]])
)
accessing .data under vmap
Hmm, I think we should disallow accessing .data under vmap. Under your PR, I get the following:

import torch

def f(x):
    return x.data

x = torch.randn([3], requires_grad=True)
y = torch.vmap(f)(x)

This returns a Tensor y that has requires_grad=True, which is not correct -- if we ran f in a for loop over the slices of x (the semantics vmap models), it would return Tensors that do not require grad.
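For comparison, a sketch of that for-loop model:

import torch

def f(x):
    return x.data

x = torch.randn([3], requires_grad=True)
ys = [f(x[i]) for i in range(x.shape[0])]  # the per-example loop vmap models
print([y.requires_grad for y in ys])       # [False, False, False]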
setting .data under vmap
We should probably also disallow setting .data under vmap? Under your PR I get the following:

import torch

def f(x, y):
    x.data = y
    return x

x = torch.randn([3])
y = torch.randn([3], requires_grad=True)
res = torch.vmap(f)(x, y)
print(res)

RuntimeError: Batching rule not implemented for aten::_has_compatible_shallow_copy_type. We could not generate a fallback.

So we probably want to improve the error message somehow.
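For reference, a sketch of what eager mode does with the same program: the .data assignment swaps in y's values but keeps x's autograd metadata, so the result still does not require grad.

import torch

def f(x, y):
    x.data = y  # replaces x's data; x's autograd metadata is kept
    return x

x = torch.randn([3])
y = torch.randn([3], requires_grad=True)
res = f(x, y)
print(res.requires_grad)  # False: the assignment did not make x require grad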
setting .data under grad
The output of your script seems fine -- what is wrong with it?
Regarding .data under grad:
I had the following example in mind, where there seem to be conflicting semantics around .data. We disable directly updating .data (the second function) but allow access to .data, which returns a shallow copy. If we mutate the shallow copy, then x is also updated (and this may produce silently incorrect results for x, since the mutation doesn't go through autograd).
.data allowing mutation that is opaque to autograd is the same semantics as PyTorch eager, but currently we have conflicting semantics under grad, vjp, etc.
import torch

# This works.
def foo(x):
    y = x.data
    y.copy_(torch.zeros(3, 3))
    return (x * x).sum()

# FAILS: RuntimeError: false INTERNAL ASSERT FAILED at "aten/src/ATen/functorch/TensorWrapper.cpp":137,
# please report a bug to PyTorch. NYI
def foo(x):
    x.data = torch.zeros(3, 3)
    return (x * x).sum()
print(torch.func.grad(foo)(torch.randn(3, 3)))
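For reference, eager PyTorch already exhibits the mutation-opaque-to-autograd behavior of the first foo; a minimal sketch:

import torch

x = torch.randn(3, 3, requires_grad=True)
loss = (x * x).sum()             # forward saves x for the backward pass
x.data.copy_(torch.zeros(3, 3))  # mutation via .data: autograd cannot see it
loss.backward()
print(x.grad)                    # all zeros: backward silently used the mutated values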
bool allow_tensor_metadata_change) const {
  DispatchKeySet key_set = getKeysToPropagateToWrapper(value());
  auto impl = c10::make_intrusive<BatchedTensorImpl>(key_set, value(), bdim(), level());
  impl->set_version_counter(version_counter);
Suggested change:
-  impl->set_version_counter(version_counter);
+  impl->set_version_counter(std::move(version_counter));
Do we need std::move, given that version_counter is already c10::VariableVersion&&?
@zou3519 As discussed offline, have disabled accessing `data` under the vmap transform.
with self.assertRaisesRegex(RuntimeError, "accessing `data` under vmap transform"):
    torch.func.vmap(foo)(torch.randn(3, 3))
We should add a test for the set_data case and assert that it raises the nice error message.
Please add a test for the mutating-data case; otherwise, this LGTM.
Done. I had to add a batch rule for `aten::_has_compatible_shallow_copy_type`.
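For illustration, a hypothetical sketch of such a test (the test name is made up; the asserted message follows the TORCH_CHECK added above):

def test_data_attribute_mutation(self):
    def foo(x):
        x.data = torch.randn(3, 3)  # mutating `.data` under vmap should error
        return x

    with self.assertRaisesRegex(
        RuntimeError, "mutating directly with `.data` inside functorch transform"
    ):
        torch.func.vmap(foo)(torch.randn(3, 3))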
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few are: trunk / linux-bionic-cuda11.8-py3.10-gcc7 / test (functorch, 1, 1, linux.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #97161