Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
OP_DECOMPOSE(feature_dropout_);
}

void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE2(__and__, Scalar);
OP_DECOMPOSE2(__and__, Tensor);
Expand Down Expand Up @@ -332,6 +336,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE2(less_equal, Scalar);
OP_DECOMPOSE2(less, Scalar);
OP_DECOMPOSE2(not_equal, Scalar);
m.impl("_has_compatible_shallow_copy_type", torch::CppFunction::makeFromBoxedFunction<&unsupportedData>());
}

}}
18 changes: 18 additions & 0 deletions aten/src/ATen/functorch/BatchedTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,24 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const {
return "BatchedTensorImpl";
}

// Overridden to reject `tensor.data` access on a vmap-wrapped tensor:
// shallow-copy-and-detach of a BatchedTensorImpl is not supported, so we
// raise unconditionally. Parameters are intentionally unnamed (unused)
// to avoid -Wunused-parameter; the trailing return is unreachable but
// keeps compilers that don't model TORCH_CHECK's [[noreturn]] path happy.
c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& /*version_counter*/,
    bool /*allow_tensor_metadata_change*/) const {
  TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed");
  return nullptr;
}

// Rvalue-version_counter overload; same contract as the lvalue overload
// above: `tensor.data` access under vmap always raises. Parameters are
// intentionally unnamed (unused) to avoid -Wunused-parameter.
c10::intrusive_ptr<TensorImpl> BatchedTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& /*version_counter*/,
    bool /*allow_tensor_metadata_change*/) const {
  TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed");
  return nullptr;
}

// Overridden to reject `tensor.data = other` on a vmap-wrapped tensor:
// copying metadata from another TensorImpl would clobber the batching
// state, so we raise unconditionally. Parameter intentionally unnamed
// (unused) to avoid -Wunused-parameter.
void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& /*impl*/) {
  TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
}

Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) {
DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor);
auto* batched = maybeGetBatchedImpl(tensor);
Expand Down
7 changes: 7 additions & 0 deletions aten/src/ATen/functorch/BatchedTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
void set_size(int64_t dim, int64_t new_size) override;
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;
// Overrides that block `.data` access/mutation on batched tensors: all
// three raise unconditionally (definitions in BatchedTensorImpl.cpp),
// since shallow-copying or detaching a BatchedTensorImpl under a vmap
// transform is not supported.
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
#ifdef DEBUG
bool has_storage() const override;
#endif
Expand Down
15 changes: 15 additions & 0 deletions test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,21 @@ def f(x, y):
self.assertEqual(out, f(None, y))
self.assertEqual(out_dims, (None, None, None))

def test_data_attribute(self):
def foo(x):
y = x.data
return x

with self.assertRaisesRegex(RuntimeError, "accessing `data` under vmap transform"):
torch.func.vmap(foo)(torch.randn(3, 3))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a test for the set_data case and assert that it raises the nice error message

def foo(x):
x.data = torch.ones(3, 3)
return x

with self.assertRaisesRegex(RuntimeError, "mutating directly with `.data` under vmap"):
torch.func.vmap(foo)(torch.randn(3, 3))


def slice_inputs(inputs, bdims, i):
result = []
Expand Down