4 changes: 4 additions & 0 deletions aten/src/ATen/TensorIterator.h
@@ -473,6 +473,10 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase {
   }
 
   bool has_contiguous_first_dim() const {
+    if (ndim() == 0) {
+      return true;
+    }
+
     int num_tensors = ntensors();
     for (const auto i : c10::irange(num_tensors)) {
       if (strides(i)[0] != element_size(i)) {
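Why the new guard is needed: for a zero-dimensional (scalar) tensor each operand's stride array is empty, so the `strides(i)[0]` read in the loop below it would index past the end of an empty array; a scalar is trivially contiguous, so returning true early is the correct answer. A minimal standalone sketch of the guarded check — `IterSketch` and its members are hypothetical stand-ins, not the real `TensorIteratorBase` API:

#include <cstdint>
#include <vector>

// Hypothetical stand-in for TensorIteratorBase, to illustrate the guard only.
struct IterSketch {
  // One stride vector per operand; empty when the tensors are 0-dimensional.
  std::vector<std::vector<std::int64_t>> strides;
  std::vector<std::int64_t> element_sizes;

  int ndim() const {
    return strides.empty() ? 0 : static_cast<int>(strides[0].size());
  }

  bool has_contiguous_first_dim() const {
    if (ndim() == 0) {
      return true;  // 0-dim: strides[i] is empty, and a scalar is trivially contiguous
    }
    for (std::size_t i = 0; i < strides.size(); ++i) {
      if (strides[i][0] != element_sizes[i]) {
        return false;  // first-dimension step is not one element: not contiguous
      }
    }
    return true;
  }
};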
3 changes: 2 additions & 1 deletion aten/src/ATen/cpu/vec/vec_base.h
@@ -33,6 +33,7 @@
 #include <c10/util/TypeCast.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/irange.h>
+#include <c10/util/Load.h>
 
 // These macros helped us unify vec_base.h
 #ifdef CPU_CAPABILITY_AVX512
@@ -975,7 +976,7 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) {
 #endif
   for (const auto i : c10::irange(n)) {
     (void)i; //Suppress unused variable warning
-    *dst = c10::static_cast_with_inter_type<dst_T, src_T>::apply(*src);
+    *dst = c10::convert<dst_T>(c10::load(src));
     src++;
     dst++;
   }
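The switch from `*src` to `c10::load(src)` matters most for `bool`: reading a `bool` whose storage byte is something other than 0 or 1 is undefined behavior, and a byte-for-byte copy between dtypes can expose exactly such bytes. A minimal sketch of a byte-safe load under that assumption — this illustrates the idea, not the actual implementation in `c10/util/Load.h`:

#include <cstdint>
#include <cstring>

// Generic case: copy the object representation instead of dereferencing,
// which also sidesteps alignment concerns. Not suitable for bool (see below).
template <typename T>
T load_sketch(const void* src) {
  T value;
  std::memcpy(&value, src, sizeof(T));
  return value;
}

// bool case: read the raw byte and normalize, so that a stored byte like
// 0x02 yields true instead of producing a bool with an invalid representation.
inline bool load_bool_sketch(const void* src) {
  std::uint8_t byte;
  std::memcpy(&byte, src, sizeof(byte));
  return byte != 0;
}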
30 changes: 14 additions & 16 deletions aten/src/ATen/native/cpu/CopyKernel.cpp
@@ -246,22 +246,20 @@ void copy_kernel(TensorIterator& iter, bool /*non_blocking*/) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
     using dest_t = scalar_t;
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(1), "copy_", [&] {
-      // Note (@zasdfgbnm):
-      //
-      // The code below cannot be simplified as
-      //   cpu_kernel(iter, c10::static_cast_with_inter_type<dest_t, scalar_t>::apply);
-      //
-      // because this would force the compiler to instantiate the inline function and generate a function call in the loop
-      // instead of inlining it, making all the optimizations like vectorization impossible.
-      // You can verify this by looking at the symbols of `libtorch_cpu.so`:
-      //
-      //   readelf -Ws libtorch_cpu.so | grep static_cast_with_inter_type
-      //
-      // If done correctly, the above command should have no output.
-      //
-      // See: https://github.com/pytorch/pytorch/issues/31271
-      cpu_kernel(iter, [](scalar_t src) -> dest_t {
-        return c10::static_cast_with_inter_type<dest_t, scalar_t>::apply(src); });
+      if (iter.has_contiguous_first_dim()) {
+        TORCH_INTERNAL_ASSERT(iter.ninputs() == 1);
+        TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+
+        iter.for_each([](char **data, const int64_t *strides, int64_t size) {
+          auto src = reinterpret_cast<const scalar_t*>(data[1]);
+          auto dst = reinterpret_cast<dest_t*>(data[0]);
+          at::vec::convert(src, dst, size);
+        });
+      } else {
+        cpu_kernel(iter, [](scalar_t x) -> dest_t {
+          return c10::convert<dest_t>(x);
+        });
+      }
     });
   });
 
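The new branch is a fast path: when every operand steps through its first dimension one element at a time, the iterator can hand `at::vec::convert` whole contiguous runs instead of invoking a lambda per element, and otherwise it falls back to the element-wise `cpu_kernel`. A simplified sketch of the two paths, with hypothetical `copy_contiguous`/`copy_strided` helpers standing in for the real `TensorIterator` machinery:

#include <cstdint>
#include <cstring>

// Fast path: both buffers are contiguous, so the loop body is a plain
// conversion the compiler can auto-vectorize.
template <typename src_t, typename dst_t>
void copy_contiguous(const src_t* src, dst_t* dst, std::int64_t n) {
  for (std::int64_t i = 0; i < n; ++i) {
    dst[i] = static_cast<dst_t>(src[i]);
  }
}

// Fallback: arbitrary byte strides per operand, one element at a time.
template <typename src_t, typename dst_t>
void copy_strided(const char* src, char* dst,
                  std::int64_t src_stride, std::int64_t dst_stride,
                  std::int64_t n) {
  for (std::int64_t i = 0; i < n; ++i) {
    src_t v;
    std::memcpy(&v, src + i * src_stride, sizeof(v));
    const dst_t out = static_cast<dst_t>(v);
    std::memcpy(dst + i * dst_stride, &out, sizeof(out));
  }
}

The note deleted above also documented a quick check for the old code: if the conversion was fully inlined, `readelf -Ws libtorch_cpu.so | grep static_cast_with_inter_type` should print nothing.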