Skip to content

Commit

Permalink
[PyTorch] Reapply D25687465: Devirtualize TensorImpl::dim() with macro (
Browse files Browse the repository at this point in the history
#50290)

Summary:
Pull Request resolved: #50290

This was reverted because it landed after D24772023 (b73c018), which
changed the implementation of `dim()`,  without rebasing on top of it,
and thus broke the build.
ghstack-source-id: 119608505

Test Plan: CI

Reviewed By: ezyang

Differential Revision: D25852810

fbshipit-source-id: 9735a095d539a3a6dc530b7b3bb758d4872d05a8
  • Loading branch information
swolchok authored and facebook-github-bot committed Jan 13, 2021
1 parent 21542b4 commit 9ebea77
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 12 deletions.
4 changes: 0 additions & 4 deletions aten/src/ATen/SparseTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) {
AT_ERROR("sparse tensors do not have set_storage_offset");
}

int64_t SparseTensorImpl::dim() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sparse_dim_ + dense_dim_ == TensorImpl::dim());
return sparse_dim_ + dense_dim_;
}
bool SparseTensorImpl::has_storage() const {
return false;
}
Expand Down
1 change: 0 additions & 1 deletion aten/src/ATen/SparseTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ struct TORCH_API SparseTensorImpl : public TensorImpl {
void set_stride(int64_t dim, int64_t new_stride) override;
void set_storage_offset(int64_t storage_offset) override;

int64_t dim() const override;
bool has_storage() const override;
const Storage& storage() const override;
int64_t storage_offset() const override;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/test/undefined_tensor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ TEST(TestUndefined, UndefinedTest) {
ASSERT_EQ(std::string("UndefinedType"), und.toString());

ASSERT_ANY_THROW(und.strides());
ASSERT_ANY_THROW(und.dim());
ASSERT_EQ(und.dim(), 1);
ASSERT_ANY_THROW([]() { return Tensor(); }() = Scalar(5));
ASSERT_ANY_THROW(und.add(und));
ASSERT_ANY_THROW(und.add(ft));
Expand Down
2 changes: 2 additions & 0 deletions c10/core/TensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,11 @@ void TensorImpl::release_resources() {
}
}

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
int64_t TensorImpl::dim() const {
return sizes_and_strides_.size();
}
#endif

int64_t TensorImpl::size(int64_t d) const {
d = at::maybe_wrap_dim(d, dim(), false);
Expand Down
9 changes: 8 additions & 1 deletion c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* Return the number of dimensions of this tensor. Note that 0-dimension
* represents a Tensor that is a Scalar, e.g., one that has a single element.
*/
virtual int64_t dim() const;
TENSORIMPL_MAYBE_VIRTUAL int64_t dim() const
#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
{
return sizes_and_strides_.size();
}
#else
;
#endif

/**
* True if this tensor has storage. See storage() for details.
Expand Down
4 changes: 0 additions & 4 deletions c10/core/UndefinedTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ int64_t UndefinedTensorImpl::stride(int64_t d) const {
AT_ERROR("stride(dim) called on an undefined Tensor");
}

int64_t UndefinedTensorImpl::dim() const {
AT_ERROR("dim() called on undefined Tensor");
}

bool UndefinedTensorImpl::has_storage() const {
AT_ERROR("has_storage() called on undefined Tensor");
}
Expand Down
1 change: 0 additions & 1 deletion c10/core/UndefinedTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
IntArrayRef strides() const override;
int64_t size(int64_t d) const override;
int64_t stride(int64_t d) const override;
int64_t dim() const override;
bool has_storage() const override;
const Storage& storage() const override;
int64_t storage_offset() const override;
Expand Down

0 comments on commit 9ebea77

Please sign in to comment.