[StaticRuntime][ATen] Add out variant for narrow_copy (#49502)
Summary:
Pull Request resolved: #49502

The last time I landed this it broke the OSS CI, mostly the CUDA tests and the Python bindings.

Similar to permute_out, this adds an out variant of `aten::narrow` (`slice` in Caffe2) that does an actual copy. `aten::narrow` itself only creates a view, but a copy is incurred anyway when the ops that follow it call `input.contiguous()`, as in `concat_add_mul_replacenan_clip`, `casted_batch_one_hot_lengths`, and `batch_box_cox`.

{F351263599}
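
For context, a minimal Python sketch of the difference described above (tensor shape and narrow arguments are made up for illustration): `narrow` returns a view that shares storage with its input, while `narrow_copy` materializes a contiguous copy up front, so no extra copy is needed later.

```
import torch

x = torch.rand(4, 8)

view = x.narrow(1, 2, 3)               # aten::narrow: a view, nothing copied yet
copy = torch.narrow_copy(x, 1, 2, 3)   # aten::narrow_copy: an actual copy

assert view._base is x                 # shares storage with x
assert copy._base is None              # owns its own storage
assert not view.is_contiguous()        # a later .contiguous() would copy
assert copy.is_contiguous()            # already contiguous
assert torch.equal(view, copy)
```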

Test Plan:
Unit test:

```
buck test //caffe2/aten:math_kernel_test
buck test //caffe2/test:sparse -- test_narrow
```
Benchmark with the adindexer model:
```
bs = 1 is neutral

Before:
I1214 21:32:51.919239 3285258 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.0886948. Iters per second: 11274.6
After:
I1214 21:32:52.492352 3285277 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.0888019. Iters per second: 11261

bs = 20 shows larger gains, probably because the tensors are bigger and the cost of the extra copy is therefore higher

Before:
I1214 21:20:19.702445 3227229 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.527563. Iters per second: 1895.51
After:
I1214 21:20:20.370173 3227307 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.508734. Iters per second: 1965.67
```

Reviewed By: ajyu

Differential Revision: D25596290

fbshipit-source-id: bff813f29a0fd36fa56d937426a6d3a03f3af977
Hao Lu authored and facebook-github-bot committed Jan 12, 2021
1 parent cf45d65 commit e92e14b
Showing 5 changed files with 124 additions and 4 deletions.
84 changes: 83 additions & 1 deletion aten/src/ATen/native/TensorShape.cpp
@@ -791,8 +791,90 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_
return newTensor._coalesced_(self.is_coalesced());
}

Tensor& narrow_copy_dense_cpu_out(
const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output
) {
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
TORCH_CHECK(self.dtype() == output.dtype());

Tensor self_contig = self.contiguous();
const auto self_sizes = self_contig.sizes();

// wrap dim if negative and do bound check
if (dim < 0) {
dim = at::maybe_wrap_dim(dim, self_sizes.size());
} else {
TORCH_CHECK(dim < self_sizes.size());
}

// wrap start and do bound check
const auto cur_size = self_sizes[dim];
if (start != cur_size && start < 0) { // start being the end is valid, but
// not a valid dim specification.
start = at::maybe_wrap_dim(start, cur_size);
}
TORCH_CHECK(
length >= 0 && start <= cur_size - length,
"start (",
start,
") + length (",
length,
") exceeds dimension size (",
cur_size,
").");

// resize output
auto output_sizes = self_sizes.vec();
output_sizes[dim] = length;
at::native::resize_(output, output_sizes);

const int64_t unit = c10::size_from_dim_(dim + 1, self_sizes);
const int64_t num_blocks = c10::size_to_dim_(dim, self_sizes);

const auto itemsize = self_contig.dtype().itemsize();
size_t src_nbytes = itemsize * self_contig.numel();
size_t dst_nbytes = itemsize * output.numel();

size_t src_block_size = unit * self_sizes[dim];
size_t dst_block_size = unit * length;

if (num_blocks == 0 || dst_block_size == 0) {
return output;
}

char* src_bytes = static_cast<char*>(self_contig.data_ptr());
char* dst_bytes = static_cast<char*>(output.data_ptr());

size_t src_block_size_bytes = itemsize * src_block_size;
size_t dst_block_size_bytes = itemsize * dst_block_size;
size_t src_offset = unit * start;

char* src_offset_bytes = src_bytes + itemsize * src_offset;
char* dst_offset_bytes = dst_bytes;

for (size_t i = 0; i < num_blocks; ++i) {
char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes;
char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes) <=
static_cast<void*>(src_bytes + src_nbytes));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes) <=
static_cast<void*>(dst_bytes + dst_nbytes));

memcpy(
local_dst_offset_bytes, local_src_offset_bytes, dst_block_size_bytes);
}
return output;
}

Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length){
  return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
}

Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
auto output = at::empty_like(self);
return narrow_copy_dense_cpu_out(self, dim, start, length, output);
}

Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
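
To make the indexing in `narrow_copy_dense_cpu_out` above concrete, here is a rough Python re-derivation of the same block-copy scheme over the flattened contiguous input (shapes and arguments are arbitrary; this mirrors the loop rather than calling the kernel):

```
import torch

def narrow_copy_blocks(x, dim, start, length):
    # Treat the contiguous input as num_blocks blocks of src_block_size
    # elements and copy a dst_block_size slice out of each block,
    # starting at offset unit * start within the block.
    x = x.contiguous()
    sizes = list(x.shape)
    unit = 1
    for s in sizes[dim + 1:]:
        unit *= s                      # size_from_dim_(dim + 1)
    num_blocks = 1
    for s in sizes[:dim]:
        num_blocks *= s                # size_to_dim_(dim)
    src_block_size = unit * sizes[dim]
    dst_block_size = unit * length

    src = x.reshape(-1)
    out_sizes = sizes[:]
    out_sizes[dim] = length
    dst = torch.empty(num_blocks * dst_block_size, dtype=x.dtype)
    for i in range(num_blocks):
        s0 = i * src_block_size + unit * start
        dst[i * dst_block_size:(i + 1) * dst_block_size] = src[s0:s0 + dst_block_size]
    return dst.reshape(out_sizes)

x = torch.rand(5, 8, 7)
assert torch.equal(narrow_copy_blocks(x, 1, 2, 4), x.narrow(1, 2, 4))
```
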
9 changes: 7 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -2778,10 +2778,15 @@
     DefaultBackend: mvlgamma_
 
 - func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor
-  variants: method
+  variants: function, method
   dispatch:
-    CPU, CUDA: narrow_copy_dense
+    CPU: narrow_copy_dense_cpu
     SparseCPU, SparseCUDA: narrow_copy_sparse
+    DefaultBackend: narrow_copy_dense
+
+- func: narrow_copy.out(Tensor self, int dim, int start, int length, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: narrow_copy_dense_cpu_out
 
 - func: narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)
   variants: function, method
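
As a usage sketch of the new overload (assuming the generated Python binding exposes it via the `out=` keyword; note only a CPU kernel is registered for it):

```
import torch

x = torch.rand(5, 8, 7)
buf = torch.empty(0)                      # preallocated buffer, resized by the kernel

torch.narrow_copy(x, 1, 2, 4, out=buf)    # CPU: narrow_copy_dense_cpu_out
assert buf.shape == (5, 4, 7)
assert torch.equal(buf, x.narrow(1, 2, 4))
```
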
10 changes: 10 additions & 0 deletions aten/src/ATen/test/math_kernel_test.cpp
@@ -102,3 +102,13 @@ TEST(MathKernelTest, SiluBackward) {
auto math_out = at::native::math_silu_backward(grad_output, input);
ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6);
}

TEST(MathKernelTest, NarrowCopy) {
auto x = rand({5, 8, 7});
for (int64_t dim = 0; dim < 3; ++dim) {
const int64_t start = 1, length = 4;
auto y_ref = x.narrow(dim, start, length);
auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
}
}
24 changes: 23 additions & 1 deletion torch/csrc/jit/runtime/static/ops.cpp
@@ -33,7 +33,6 @@ bool canRunNatively(Node* n) {
 // In alphabetical order
 const static std::unordered_set<std::string> native_nodes{
     "aten::flatten",
-    "aten::narrow",
     "aten::reshape",
     "aten::slice",
     "aten::transpose",
@@ -338,6 +337,29 @@ REGISTER_OPERATOR_FUNCTOR_OPT(
};
});

// The out variant takes precedence over native
REGISTER_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
auto self = p_node->Input(0).toTensor(); // self
auto dim = p_node->Input(1).toInt(); // dim
int64_t start = 0;
if (p_node->Input(2).isScalar()) {
start = p_node->Input(2).toInt();
} else {
auto t = p_node->Input(2).toTensor();
start = t.item<int64_t>();
}
auto length = p_node->Input(3).toInt(); // length

if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(self);
}
auto output = p_node->Output(0).toTensor();
output.resize_({0});
at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output);
};
});

std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
auto op_name = n->kind().toQualString();
if (SROperatorRegistry()->Has(op_name)) {
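
The functor above behaves roughly like the following loop: the first run allocates an empty output, and later runs hand the same tensor back to the out variant after resizing it to zero (a sketch of the behavior in Python, not the Static Runtime API; it assumes the `out=` binding mentioned earlier):

```
import torch

x = torch.rand(20, 8, 7)
out = None
for _ in range(3):                            # repeated inference iterations
    if out is None:
        out = torch.empty(0, dtype=x.dtype)   # mirrors create_empty_from(self)
    out.resize_(0)                            # mirrors output.resize_({0})
    torch.narrow_copy(x, 1, 2, 4, out=out)    # fills via narrow_copy_dense_cpu_out
    assert torch.equal(out, x.narrow(1, 2, 4))
# After the first iteration resize_ can reuse the existing storage, which is
# the point of preferring the out variant over the native aten::narrow path.
```
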
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -557,6 +557,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.mv: lambda input, vec, out=None: -1,
torch.mvlgamma: lambda input, p: -1,
torch.narrow: lambda input, dim, start, length: -1,
torch.narrow_copy: lambda input, dim, start, length: -1,
torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
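
A quick sanity check of the new entry (assuming `torch.overrides.get_testing_overrides` works as it does on current builds):

```
import torch
from torch.overrides import get_testing_overrides

# The entry added above registers a dummy lambda with narrow_copy's signature,
# so torch.narrow_copy now participates in the __torch_function__ override tests.
overrides = get_testing_overrides()
assert torch.narrow_copy in overrides
assert overrides[torch.narrow_copy](torch.ones(3), 0, 0, 1) == -1
```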
