[PyTorch] Use .sizes() instead of .size() in cat_serial_kernel_impl
Pull Request resolved: #49371

As with the previous diff, indexing .sizes() directly is strictly more efficient than calling .size(dim).
ghstack-source-id: 118627223

Differential Revision: [D25546409](https://our.internmc.facebook.com/intern/diff/D25546409/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25546409/)!
swolchok committed Dec 15, 2020
1 parent df4353a commit c755fee
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions aten/src/ATen/native/cpu/CatKernel.cpp
@@ -15,18 +15,20 @@ struct InputMeta {

   InputMeta(const Tensor& t, int64_t dim, int64_t inner)
     : data_ptr(t.data_ptr())
-    , inner_size(t.size(dim) * inner) {}
+    , inner_size(t.sizes()[dim] * inner) {}
 };

 template <typename scalar_t>
 void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
-  int64_t outer = result.numel() / (result.size(dim) * result.stride(dim));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl");
+  int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]);
   scalar_t* result_data = result.data_ptr<scalar_t>();
   int64_t ninputs = tensors.size();
   std::vector<InputMeta> inputs;
   inputs.reserve(ninputs);
   for (auto const &tensor : tensors) {
-    inputs.emplace_back(tensor, dim, result.stride(dim));
+    inputs.emplace_back(tensor, dim, result.strides()[dim]);
   }

   using Vec = vec256::Vec256<scalar_t>;
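
For readers without the ATen sources at hand, the trade-off in this change can be sketched in isolation. The snippet below is a minimal illustration, not the ATen implementation: `MiniTensor` and `outer_count` are hypothetical names, and it assumes that a `size(dim)`-style accessor normalizes and range-checks `dim` on every call, while indexing the array returned by `sizes()` or `strides()` does neither. Under that assumption the caller has to guarantee a valid, non-negative `dim` itself, which is presumably why the diff adds a debug-only assertion.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a tensor's shape metadata; this is NOT the ATen API.
struct MiniTensor {
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  int64_t dim() const { return static_cast<int64_t>(sizes_.size()); }

  // Checked accessor (assumed behavior): validates the range and wraps
  // negative dims on every call.
  int64_t size(int64_t d) const {
    const int64_t n = dim();
    if (d < -n || d >= n) {
      throw std::out_of_range("dim out of range");
    }
    if (d < 0) {
      d += n;
    }
    return sizes_[static_cast<std::size_t>(d)];
  }

  // Unchecked accessors: return the whole array; indexing it does no wrapping.
  const std::vector<int64_t>& sizes() const { return sizes_; }
  const std::vector<int64_t>& strides() const { return strides_; }

  // Total number of elements: the product of all sizes.
  int64_t numel() const {
    int64_t n = 1;
    for (int64_t s : sizes_) {
      n *= s;
    }
    return n;
  }
};

// Mirrors the `outer` computation from the diff: the number of outer slices
// of a contiguous result along `dim`. The assert plays the role of the
// debug-only check, since direct indexing no longer validates `dim`.
int64_t outer_count(const MiniTensor& result, int64_t dim) {
  assert(dim >= 0 && dim < result.dim());
  const std::size_t d = static_cast<std::size_t>(dim);
  return result.numel() / (result.sizes()[d] * result.strides()[d]);
}
```

For a contiguous result of shape `{2, 3, 4}` with strides `{12, 4, 1}`, `outer_count(result, 1)` divides 24 elements by `3 * 4`, giving 2 outer slices. Both accessors return the same value for a valid non-negative `dim`; only the per-call bookkeeping differs.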