Skip to content

Commit

Permalink
Update on "[PyTorch] Use .sizes() isntead of .size() in cat_serial_ke…
Browse files Browse the repository at this point in the history
…rnel_impl"

As with previous diff, .sizes() is strictly more efficient.

Differential Revision: [D25546409](https://our.internmc.facebook.com/intern/diff/D25546409/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25546409/)!

[ghstack-poisoned]
  • Loading branch information
swolchok committed Dec 15, 2020
1 parent 55090d2 commit 8abbcf1
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/cpu/CatKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ struct InputMeta {

template <typename scalar_t>
void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl");
int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]);
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t ninputs = tensors.size();
Expand Down

0 comments on commit 8abbcf1

Please sign in to comment.