Skip to content

Commit

Permalink
[PyTorch] Devirtualize TensorImpl::dim() with macro
Browse files Browse the repository at this point in the history
Pull Request resolved: #49770

Seems like the performance cost of making this commonly-called method virtual isn't worth having use of undefined tensors crash a bit earlier (they'll still fail to dispatch).
ghstack-source-id: 119477211

Differential Revision: [D25687465](https://our.internmc.facebook.com/intern/diff/D25687465/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25687465/)!
  • Loading branch information
swolchok committed Jan 6, 2021
1 parent 395f1af commit 24142b1
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 12 deletions.
4 changes: 0 additions & 4 deletions aten/src/ATen/SparseTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) {
AT_ERROR("sparse tensors do not have set_storage_offset");
}

int64_t SparseTensorImpl::dim() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sparse_dim_ + dense_dim_ == TensorImpl::dim());
return sparse_dim_ + dense_dim_;
}
bool SparseTensorImpl::has_storage() const {
return false;
}
Expand Down
1 change: 0 additions & 1 deletion aten/src/ATen/SparseTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ struct TORCH_API SparseTensorImpl : public TensorImpl {
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;

int64_t dim() const override;
bool has_storage() const override;
const Storage& storage() const override;
int64_t storage_offset() const override;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/test/undefined_tensor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ TEST(TestUndefined, UndefinedTest) {
ASSERT_EQ(std::string("UndefinedType"), und.toString());

ASSERT_ANY_THROW(und.strides());
ASSERT_ANY_THROW(und.dim());
ASSERT_EQ(und.dim(), 1);
ASSERT_ANY_THROW([]() { return Tensor(); }() = Scalar(5));
ASSERT_ANY_THROW(und.add(und));
ASSERT_ANY_THROW(und.add(ft));
Expand Down
2 changes: 2 additions & 0 deletions c10/core/TensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ void TensorImpl::release_resources() {
}
}

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
int64_t TensorImpl::dim() const {
return sizes_.size();
}
#endif

int64_t TensorImpl::size(int64_t d) const {
d = at::maybe_wrap_dim(d, dim(), false);
Expand Down
9 changes: 8 additions & 1 deletion c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* Return the number of dimensions of this tensor. Note that 0-dimension
* represents a Tensor that is a Scalar, e.g., one that has a single element.
*/
virtual int64_t dim() const;
TENSORIMPL_MAYBE_VIRTUAL int64_t dim() const
#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
{
return sizes_.size();
}
#else
;
#endif

/**
* True if this tensor has storage. See storage() for details.
Expand Down
4 changes: 0 additions & 4 deletions c10/core/UndefinedTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const {
AT_ERROR("stride(dim) called on an undefined Tensor");
}

int64_t UndefinedTensorImpl::dim() const {
AT_ERROR("dim() called on undefined Tensor");
}

bool UndefinedTensorImpl::has_storage() const {
AT_ERROR("has_storage() called on undefined Tensor");
}
Expand Down
1 change: 0 additions & 1 deletion c10/core/UndefinedTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
IntArrayRef strides() const override;
int64_t size(int64_t d) const override;
int64_t stride(int64_t d) const override;
int64_t dim() const override;
bool has_storage() const override;
const Storage& storage() const override;
int64_t storage_offset() const override;
Expand Down

0 comments on commit 24142b1

Please sign in to comment.