Conversation

@kshitij12345
Collaborator

Fixes #97161

@pytorch-bot

pytorch-bot bot commented Mar 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97237

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1a23bb4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

def test_data_attribute(self):
    def foo(x):
        y = x.data
        y.sum()
Collaborator Author

Without the fix, this line fails with

RuntimeError: batched == nullptr INTERNAL ASSERT FAILED at "aten/src/ATen/functorch/Interpreter.cpp":98, please report a bug to PyTorch.
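
For context, a minimal standalone sketch of the failing pattern (based on the test above and the linked issue; the exact reproducer in #97161 may differ):

import torch

def foo(x):
    y = x.data   # under vmap, x is a BatchedTensor; this access hit the assert/segfault
    return y.sum()

# Before this fix, this raised the INTERNAL ASSERT above (or segfaulted).
torch.func.vmap(foo)(torch.randn(3, 3))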

@kshitij12345 kshitij12345 marked this pull request as ready for review March 21, 2023 14:22
Comment on lines 106 to 127
c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  DispatchKeySet key_set = getKeysToPropagateToWrapper(value());
  auto impl = c10::make_intrusive<BatchedTensorImpl>(key_set, value(), bdim(), level());
  impl->set_version_counter(version_counter);
  return impl;
}

c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  DispatchKeySet key_set = getKeysToPropagateToWrapper(value());
  auto impl = c10::make_intrusive<BatchedTensorImpl>(key_set, value(), bdim(), level());
  impl->set_version_counter(version_counter);
  return impl;
}

void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  TORCH_CHECK(false, "mutating directly with `.data` inside functorch transform is not allowed.");
}

Contributor

Do you have an explanation of why the segfault happened without these? (and do we know the full implications of adding shallow_copy_and_detach to BatchedTensorImpl?)

Collaborator Author

@kshitij12345 kshitij12345 Mar 22, 2023

In shallow_copy_and_detach of the base TensorImpl, we propagate the key_set, which includes the FunctorchBatched key. So we create a Tensor that pretends to be a batched tensor but in reality isn't, which leads to problems downstream when we try to get a field or call a method exclusive to BatchedTensor.

  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_,
      data_type_,
      device_opt_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  impl->refresh_numel();
  impl->refresh_contiguous();
  return impl;
}

For the example in the issue, the code was failing when trying to do batched->value().

static Tensor get_unwrapped(const Tensor& tensor) {
  auto* batched = maybeGetBatchedImpl(tensor);
  if (batched) {
    return batched->value();
  }
  auto* wrapped = maybeGetTensorWrapper(tensor);


As for the implications, I don't fully understand the semantics of these functions (and would like to discuss offline).

As for why I thought it is ok to support shallow_copy_from (the getter path for .data): it already worked with GradTrackingTensor, so to be consistent I thought it made sense to allow it for BatchedTensor as well.

Though on second thought, I think we should disallow it for both of them.

import torch

def foo(x):
    y = x.data
    print(y)
    y.sum()
    return x.sum()

# torch.func.vmap(foo)(torch.randn(3, 3))
torch.func.grad(foo)(torch.randn(3, 3))

Output:

GradTrackingTensor(lvl=1, value=
    tensor([[-1.6977,  0.6374,  0.0781],
            [-0.4140,  1.5172,  0.0473],
            [ 0.8435, -0.2261,  0.0345]])
)

Contributor

accessing .data under vmap

Hmm, I think we should disallow accessing .data under vmap. Under your PR, I get the following:

import torch
def f(x):
    return x.data

x = torch.randn([3], requires_grad=True)

y = torch.vmap(f)(x)

This returns a Tensor y that has requires_grad=True, which is not correct -- if we did this per-example in a plain loop (the semantics vmap is supposed to match), it would return Tensors that do not require grad.
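
For reference, a quick eager-mode check (no vmap involved; this is standard PyTorch behaviour, not code from the PR) shows that .data drops requires_grad, which is the per-example behaviour one would expect vmap to match:

import torch

x = torch.randn(3, requires_grad=True)
print(x.data.requires_grad)  # False in eager mode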

setting .data under vmap

We should probably also disallow setting .data under vmap? Under your PR I get the following:

import torch
def f(x, y):
    x.data = y
    return x

x = torch.randn([3])
y = torch.randn([3], requires_grad=True)

res = torch.vmap(f)(x, y)
print(res)

RuntimeError: Batching rule not implemented for aten::_has_compatible_shallow_copy_type. We could not generate a fallback.

So we probably want to improve the error message somehow.

setting .data under grad

The output of your script seems fine -- what is wrong with it?

Collaborator Author

Regarding .data under grad

I had the following example in mind, where there seem to be conflicting semantics around .data. We disable directly updating .data (second function) but allow access to .data, which returns a shallow copy. If we mutate the shallow copy, then x is also updated (and this may result in silently incorrect results for x, since that mutation doesn't go through autograd).

.data allowing mutation that is opaque to autograd is the same semantics as PyTorch eager, but currently we have conflicting semantics under grad, vjp, etc.

import torch

# This works.
def foo(x):
    y = x.data
    y.copy_(torch.zeros(3, 3))
    return (x * x).sum()

# FAILS: RuntimeError: false INTERNAL ASSERT FAILED at "aten/src/ATen/functorch/TensorWrapper.cpp":137,
# please report a bug to PyTorch. NYI
def foo(x):
    x.data = torch.zeros(3, 3)
    return (x * x).sum()

print(torch.func.grad(foo)(torch.randn(3, 3)))
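
For comparison, the analogous eager-mode script (no functorch transform; a sketch of standard autograd behaviour, not code from this PR) shows the same "mutation is invisible to autograd" semantics that the first function above relies on:

import torch

x = torch.randn(3, 3, requires_grad=True)
y = x.data                   # shares storage with x, but detached from autograd
y.copy_(torch.zeros(3, 3))   # mutates x's storage; autograd does not see this write
(x * x).sum().backward()
print(x.grad)                # all zeros: the forward pass already saw the zeroed x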

    bool allow_tensor_metadata_change) const {
  DispatchKeySet key_set = getKeysToPropagateToWrapper(value());
  auto impl = c10::make_intrusive<BatchedTensorImpl>(key_set, value(), bdim(), level());
  impl->set_version_counter(version_counter);
Collaborator

Suggested change
impl->set_version_counter(version_counter);
impl->set_version_counter(std::move(version_counter));

Collaborator Author

Do we need std::move, as version_counter is already c10::VariableVersion&&?

@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 21, 2023
@kshitij12345
Collaborator Author

@zou3519 as discussed offline, I have disabled .data for BatchedTensor.

@kshitij12345 kshitij12345 requested a review from zou3519 March 23, 2023 17:09

with self.assertRaisesRegex(RuntimeError, "accessing `data` under vmap transform"):
    torch.func.vmap(foo)(torch.randn(3, 3))

Contributor

We should add a test for the set_data case and assert that it raises the nice error message
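
A sketch of what such a test could look like; the test name and the exact error regex are assumptions (the regex below reuses the message from BatchedTensorImpl::shallow_copy_from shown earlier, but the set_data path may surface a differently worded error):

def test_data_attribute_mutation(self):
    # Hypothetical test name; the regex is assumed to match whatever error the
    # set_data path raises under vmap after this PR.
    def foo(x, y):
        x.data = y
        return x.sum()

    with self.assertRaisesRegex(RuntimeError, "mutating directly with `.data`"):
        torch.func.vmap(foo)(torch.randn(3, 3), torch.randn(3, 3))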

Contributor

@zou3519 zou3519 left a comment

Please add a test for the mutating-data case; otherwise, this LGTM.

@kshitij12345
Collaborator Author

Done. I had to add a batch rule for _has_compatible_shallow_copy_type which just errors out. PTAL :)

@kshitij12345 kshitij12345 added the release notes: torch.func release notes category for torch.vmap or torch.func.* APIs label Mar 25, 2023
@kshitij12345
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 25, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-bionic-cuda11.8-py3.10-gcc7 / test (functorch, 1, 1, linux.4xlarge.nvidia.gpu)


@kshitij12345
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Successfully merging this pull request may close these issues.

Segmentation Fault for vmaped function accessing BatchedTensor.data