Skip to content

Commit

Permalink
Add support for channels_last contig format in cat.
Browse files Browse the repository at this point in the history
Summary:
Existing cat implementation produces output tensor in contig format
disregarding in the input memory format. This PR fixes the kernel as
well as op implementation to account for that.

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
kimishpatel committed Jun 17, 2020
1 parent b5bf21a commit 5be24f3
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 19 deletions.
14 changes: 10 additions & 4 deletions aten/src/ATen/native/TensorShape.cpp
Expand Up @@ -145,6 +145,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {

// compute size of the result in the cat dimension
int64_t cat_dim_size = 0;
auto first_tensor_mem_format = tensors[0].suggest_memory_format();
for (int i = 0; i < tensors.size(); i++) {
auto const &tensor = tensors[i];
if (should_skip(tensor)) {
Expand All @@ -155,7 +156,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i);
cat_dim_size += tensor.size(dim);

if (!tensor.is_contiguous()) {
if (!tensor.is_contiguous(first_tensor_mem_format)) {
allContiguous = false;
}

Expand All @@ -170,19 +171,24 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
// compute the size of the result
auto result_size = notSkippedTensor.sizes().vec();
result_size[dim] = cat_dim_size;
result.resize_(result_size);
result.resize_(result_size, first_tensor_mem_format);
if (result.numel() == 0) {
return result;
}

// fast path for single thread when both inputs and result are contiguous and not empty
allContiguous = allContiguous && result.is_contiguous(first_tensor_mem_format);
bool use_serial_kernel = result.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
allContiguous = allContiguous && result.is_contiguous();
ScalarType dtype = notSkippedTensor.scalar_type();
if (use_serial_kernel && allContiguous && no_type_promotion && (dtype == ScalarType::Double || dtype == ScalarType::Float)) {
cat_serial_stub(kCPU, result, tensors, dim);
return result;
}

int64_t offset = 0;
if (reuse_iterator && result.is_contiguous() && no_type_promotion) {
if (reuse_iterator &&
result.is_contiguous(first_tensor_mem_format) &&
no_type_promotion) {
auto source_slice = notSkippedTensor;
auto slice_dim_size = source_slice.size(dim);
auto result_slice = result.narrow(dim, 0, slice_dim_size);
Expand Down
22 changes: 7 additions & 15 deletions aten/src/ATen/native/cpu/CatKernel.cpp
Expand Up @@ -20,27 +20,19 @@ struct InputMeta {

template <typename scalar_t>
void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
auto size = result.sizes().vec();
int64_t outer = 1, inner = 1;
for (int64_t i = 0; i < dim; i++) {
outer *= size[i];
}
for (int64_t i = dim + 1; i < size.size(); i++) {
inner *= size[i];
}
int64_t outer = result.numel() / (result.size(dim) * result.stride(dim));
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t ninputs = tensors.size();
std::vector<InputMeta> inputs;
inputs.reserve(ninputs);
for (auto const &tensor : tensors) {
inputs.emplace_back(tensor, dim, inner);
inputs.emplace_back(tensor, dim, tensor.stride(dim));
}

using Vec = vec256::Vec256<scalar_t>;
int64_t offset = 0;
for (int64_t i = 0; i < outer; i++) {
scalar_t* result_ptr = result_data;
for (int64_t i = 0; i < outer; ++i) {
for (int64_t j = 0; j < ninputs; j++) {
scalar_t* result_ptr = result_data + offset;
int64_t local_inner = inputs[j].inner_size;
scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
if (local_inner < Vec::size()) {
Expand All @@ -57,13 +49,13 @@ void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
input_ptr,
local_inner);
}
offset += local_inner;
result_ptr += local_inner;
}
}
}

void cat_serial_kernel(Tensor& result, TensorList tensors, int64_t dim) {
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cat_serial_kernel", [&]() {
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cat_contig_kernel", [&]() {
cat_serial_kernel_impl<scalar_t>(result, tensors, dim);
});
}
Expand Down
50 changes: 50 additions & 0 deletions test/test_torch.py
Expand Up @@ -7045,6 +7045,56 @@ def test_cat_out_channels_last(self, device):
res2 = torch.cat((x, y), out=z)
self.assertEqual(res1, res2)

def test_cat_in_channels_last(self, device):
x = torch.randn((4, 15, 8, 8))
y = torch.randn(x.shape)
res1 = torch.cat((x, y), dim=1)
x = x.clone().contiguous(memory_format=torch.channels_last)
y = y.clone().contiguous(memory_format=torch.channels_last)
res2 = torch.cat((x, y), dim=1)
self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(res1, res2)

# Size larger than grain size.
x = torch.randn((4, 15, 256, 256))
y = torch.randn(x.shape)
res1 = torch.cat((x, y), dim=1)
x = x.clone().contiguous(memory_format=torch.channels_last)
y = y.clone().contiguous(memory_format=torch.channels_last)
res2 = torch.cat((x, y), dim=1)
self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(res1, res2)

# Concat across dim 0
x = torch.randn((4, 15, 8, 8))
y = torch.randn(x.shape)
res1 = torch.cat((x, y), dim=0)
x = x.clone().contiguous(memory_format=torch.channels_last)
y = y.clone().contiguous(memory_format=torch.channels_last)
res2 = torch.cat((x, y), dim=0)
self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(res1, res2)

# Concat across dim 2
x = torch.randn((4, 15, 8, 8))
y = torch.randn(x.shape)
res1 = torch.cat((x, y), dim=2)
x = x.clone().contiguous(memory_format=torch.channels_last)
y = y.clone().contiguous(memory_format=torch.channels_last)
res2 = torch.cat((x, y), dim=2)
self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(res1, res2)

# Concat across dim 3
x = torch.randn((4, 15, 8, 8))
y = torch.randn(x.shape)
res1 = torch.cat((x, y), dim=3)
x = x.clone().contiguous(memory_format=torch.channels_last)
y = y.clone().contiguous(memory_format=torch.channels_last)
res2 = torch.cat((x, y), dim=3)
self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(res1, res2)

@onlyCUDA
def test_cat_preserve_channels_last(self, device):
x = torch.randn((4, 3, 8, 8), device=device)
Expand Down

0 comments on commit 5be24f3

Please sign in to comment.