diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 2f0719db4dc8..568856b6e241 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -791,8 +791,90 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_
   return newTensor._coalesced_(self.is_coalesced());
 }
 
+Tensor& narrow_copy_dense_cpu_out(
+  const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output
+) {
+  TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
+  TORCH_CHECK(self.dtype() == output.dtype());
+
+  Tensor self_contig = self.contiguous();
+  const auto self_sizes = self_contig.sizes();
+
+  // wrap dim if negative and do bound check
+  if (dim < 0) {
+    dim = at::maybe_wrap_dim(dim, self_sizes.size());
+  } else {
+    TORCH_CHECK(dim < self_sizes.size());
+  }
+
+  // wrap start and do bound check
+  const auto cur_size = self_sizes[dim];
+  if (start != cur_size && start < 0) { // start being the end is valid, but
+                                        // not a valid dim specification.
+    start = at::maybe_wrap_dim(start, cur_size);
+  }
+  TORCH_CHECK(
+      length >= 0 && start <= cur_size - length,
+      "start (",
+      start,
+      ") + length (",
+      length,
+      ") exceeds dimension size (",
+      cur_size,
+      ").");
+
+  // resize output
+  auto output_sizes = self_sizes.vec();
+  output_sizes[dim] = length;
+  at::native::resize_(output, output_sizes);
+
+  const int64_t unit = c10::size_from_dim_(dim + 1, self_sizes);
+  const int64_t num_blocks = c10::size_to_dim_(dim, self_sizes);
+
+  const auto itemsize = self_contig.dtype().itemsize();
+  size_t src_nbytes = itemsize * self_contig.numel();
+  size_t dst_nbytes = itemsize * output.numel();
+
+  size_t src_block_size = unit * self_sizes[dim];
+  size_t dst_block_size = unit * length;
+
+  if (num_blocks == 0 || dst_block_size == 0) {
+    return output;
+  }
+
+  char* src_bytes = static_cast<char*>(self_contig.data_ptr());
+  char* dst_bytes = static_cast<char*>(output.data_ptr());
+
+  size_t src_block_size_bytes = itemsize * src_block_size;
+  size_t dst_block_size_bytes = itemsize * dst_block_size;
+  size_t src_offset = unit * start;
+
+  char* src_offset_bytes = src_bytes + itemsize * src_offset;
+  char* dst_offset_bytes = dst_bytes;
+
+  for (size_t i = 0; i < num_blocks; ++i) {
+    char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes;
+    char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes;
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes) <=
+        static_cast<void*>(src_bytes + src_nbytes));
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes) <=
+        static_cast<void*>(dst_bytes + dst_nbytes));
+
+    memcpy(
+        local_dst_offset_bytes, local_src_offset_bytes, dst_block_size_bytes);
+  }
+  return output;
+}
+
 Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length){
-    return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
+  return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
+}
+
+Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
+  auto output = at::empty_like(self);
+  return narrow_copy_dense_cpu_out(self, dim, start, length, output);
 }
 
 Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
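
For reference, a minimal standalone sketch (not part of the patch) of the block-copy arithmetic that narrow_copy_dense_cpu_out relies on: for a contiguous tensor, every dimension before `dim` folds into `num_blocks` outer blocks and every dimension after `dim` folds into one contiguous `unit`, so the narrow reduces to `num_blocks` memcpy calls of `length * unit` elements each. The helper name `narrow_copy_contig` and the use of plain std::vector in place of ATen tensors are illustrative assumptions.

// Standalone illustration of the block decomposition used by the new kernel.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

std::vector<float> narrow_copy_contig(
    const std::vector<float>& src,
    const std::vector<int64_t>& sizes,
    int64_t dim,
    int64_t start,
    int64_t length) {
  int64_t unit = 1; // product of sizes after `dim` (c10::size_from_dim_(dim + 1, ...))
  for (size_t i = dim + 1; i < sizes.size(); ++i) unit *= sizes[i];
  int64_t num_blocks = 1; // product of sizes before `dim` (c10::size_to_dim_(dim, ...))
  for (int64_t i = 0; i < dim; ++i) num_blocks *= sizes[i];

  const int64_t src_block = sizes[dim] * unit; // elements per source block
  const int64_t dst_block = length * unit;     // elements copied per block
  std::vector<float> dst(num_blocks * dst_block);
  for (int64_t b = 0; b < num_blocks; ++b) {
    std::memcpy(
        dst.data() + b * dst_block,
        src.data() + b * src_block + start * unit,
        dst_block * sizeof(float));
  }
  return dst;
}

int main() {
  // 2x5x3 tensor holding 0..29; narrow dim=1 to [1, 3) -> shape 2x2x3.
  std::vector<float> x(30);
  for (int i = 0; i < 30; ++i) x[i] = static_cast<float>(i);
  auto y = narrow_copy_contig(x, {2, 5, 3}, /*dim=*/1, /*start=*/1, /*length=*/2);
  for (float v : y) std::cout << v << ' '; // 3..8 then 18..23
  std::cout << '\n';
  return 0;
}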
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index fb2ee2fecc16..45c649022c0b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2778,10 +2778,15 @@
     DefaultBackend: mvlgamma_
 
 - func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor
-  variants: method
+  variants: function, method
   dispatch:
-    CPU, CUDA: narrow_copy_dense
+    CPU: narrow_copy_dense_cpu
     SparseCPU, SparseCUDA: narrow_copy_sparse
+    DefaultBackend: narrow_copy_dense
+
+- func: narrow_copy.out(Tensor self, int dim, int start, int length, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: narrow_copy_dense_cpu_out
 
 - func: narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)
   variants: function, method
diff --git a/aten/src/ATen/test/math_kernel_test.cpp b/aten/src/ATen/test/math_kernel_test.cpp
index 2fab848cbcf9..6b5657d21a25 100644
--- a/aten/src/ATen/test/math_kernel_test.cpp
+++ b/aten/src/ATen/test/math_kernel_test.cpp
@@ -102,3 +102,13 @@ TEST(MathKernelTest, SiluBackward) {
   auto math_out = at::native::math_silu_backward(grad_output, input);
   ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6);
 }
+
+TEST(MathKernelTest, NarrowCopy) {
+  auto x = rand({5, 8, 7});
+  for (int64_t dim = 0; dim < 3; ++dim) {
+    const int64_t start = 1, length = 4;
+    auto y_ref = x.narrow(dim, start, length);
+    auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
+    ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
+  }
+}
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index fb8893040ceb..6fd7c7bb2859 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -33,7 +33,6 @@ bool canRunNatively(Node* n) {
   // In alphabetical order
   const static std::unordered_set<std::string> native_nodes{
       "aten::flatten",
-      "aten::narrow",
      "aten::reshape",
      "aten::slice",
      "aten::transpose",
@@ -338,6 +337,29 @@ REGISTER_OPERATOR_FUNCTOR_OPT(
     };
   });
 
+// The out variant takes precedence over native
+REGISTER_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
+  return [](ProcessedNode* p_node) {
+    auto self = p_node->Input(0).toTensor(); // self
+    auto dim = p_node->Input(1).toInt(); // dim
+    int64_t start = 0;
+    if (p_node->Input(2).isScalar()) {
+      start = p_node->Input(2).toInt();
+    } else {
+      auto t = p_node->Input(2).toTensor();
+      start = t.item<int64_t>();
+    }
+    auto length = p_node->Input(3).toInt(); // length
+
+    if (p_node->Output(0).isNone()) {
+      p_node->Output(0) = create_empty_from(self);
+    }
+    auto output = p_node->Output(0).toTensor();
+    output.resize_({0});
+    at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output);
+  };
+});
+
 std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
   auto op_name = n->kind().toQualString();
   if (SROperatorRegistry()->Has(op_name)) {
diff --git a/torch/overrides.py b/torch/overrides.py
index 795889659359..31b0b39555c8 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -557,6 +557,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.mv: lambda input, vec, out=None: -1,
         torch.mvlgamma: lambda input, p: -1,
         torch.narrow: lambda input, dim, start, length: -1,
+        torch.narrow_copy: lambda input, dim, start, length: -1,
         torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
         torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
         torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
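
A usage sketch of the new bindings from C++, not part of the patch. It assumes the standard ATen codegen conventions: `variants: function, method` exposes `at::narrow_copy`, and the `narrow_copy.out` entry generates an `at::narrow_copy_out(out, self, dim, start, length)` overload; both spellings follow the usual patterns but are assumptions here rather than something shown in the diff.

// Illustrative only; assumes the generated function and out-variant overloads.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::rand({5, 8, 7});

  // Functional form: allocates and returns a contiguous copy of the slice.
  at::Tensor y = at::narrow_copy(x, /*dim=*/1, /*start=*/1, /*length=*/4);

  // Out variant: writes into (and resizes) a caller-provided tensor, which is
  // what lets the static runtime registration above reuse its output buffer.
  at::Tensor out = at::empty({0}, x.options());
  at::narrow_copy_out(out, x, /*dim=*/1, /*start=*/1, /*length=*/4);

  std::cout << y.equal(out) << "\n"; // 1; both results have sizes [5, 4, 7]
  return 0;
}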