Skip to content

Commit

Permalink
[PyTorch Vulkan] fix bug of aten::cat for concatenation of 3D tenso…
Browse files Browse the repository at this point in the history
…rs at channel dim with channels as multiple of 4 (#103718)

Summary:
Pull Request resolved: #103718

The original `cat_feature_mult4ch` assumes input tensors are of 4d and use `tensor.sizes()[1]` to obtain the channel info of the tensor. This will cause bugs when the input tensors are of 3D. We generalize `cat_feature_mult4ch` to make it cover both 3D and 4D.

Test Plan:
Test for 3D tensors with channels as multiple of 4 is show below. Full test result is in P771032677.
```
(base) luwei@luwei-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*cat_3d_dim0_mult4ch_success*"
Building: finished in 0.1 sec (100%) 263/2812 jobs, 0/2812 updated
  Total time: 0.1 sec
BUILD SUCCEEDED
Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc
Note: Google Test filter = *cat_3d_dim0_mult4ch_success*
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from VulkanAPITest
[ RUN      ] VulkanAPITest.cat_3d_dim0_mult4ch_success
[       OK ] VulkanAPITest.cat_3d_dim0_mult4ch_success (129 ms)
[----------] 1 test from VulkanAPITest (129 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (129 ms total)
[  PASSED  ] 1 test.
```

Reviewed By: SS-JIA

Differential Revision: D46755034

fbshipit-source-id: 112216ff29bbe50ef7e8723782c7d0322beddc63
  • Loading branch information
copyrightly authored and facebook-github-bot committed Jun 18, 2023
1 parent 15eed5b commit ee6d0b0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
15 changes: 8 additions & 7 deletions aten/src/ATen/native/vulkan/ops/Concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ Tensor cat_feature(

Tensor cat_feature_mult4ch(
const MaterializedITensorListRef& tensors,
vTensor& v_output) {
vTensor& v_output,
uint32_t ndim) {
api::Context* const context = api::context();

int64_t depth_size_allprior = 0;
int64_t ch_interval = 0;
for (const at::Tensor& tensor : tensors) {
ch_interval += tensor.sizes()[1];
ch_interval += get_dim<Dim4D::Channel>(tensor);
}
const int64_t depth_interval = ch_interval / 4;

Expand All @@ -159,12 +160,13 @@ Tensor cat_feature_mult4ch(
tensor_arg.is_vulkan() ? tensor_arg : tensor_arg.vulkan();
const vTensor& v_self = convert(tensor);

const uint32_t depth_slice = safe_downcast<uint32_t>(tensor.sizes()[1] / 4);
const uint32_t depth_slice =
safe_downcast<uint32_t>(get_dim<Dim4D::Channel>(tensor) / 4);

uvec3 copy_extents{
v_self.extents().data[0u], v_self.extents().data[1u], depth_slice};

for (const auto b : c10::irange(tensor.sizes()[0])) {
for (const auto b : c10::irange(get_dim<Dim4D::Batch>(tensor))) {
src_offset.data[2u] = safe_downcast<uint32_t>(depth_slice * b);
dst_offset.data[2u] =
depth_size_allprior + safe_downcast<uint32_t>(depth_interval * b);
Expand Down Expand Up @@ -267,7 +269,6 @@ Tensor cat_height(

Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
TORCH_CHECK(!tensors.empty(), "Vulkan cat expects at least one tensor");

auto materialized = tensors.materialize();
TORCH_INTERNAL_ASSERT(!materialized.empty(), "Accessing empty array");
const at::Tensor& tensor = materialized[0];
Expand All @@ -283,7 +284,7 @@ Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
t.dim(),
"d");

if (ndim < 3 || t.sizes()[1 - (4u - ndim)] % 4 != 0) {
if (ndim < 3 || get_dim<Dim4D::Channel>(t) % 4 != 0) {
is_mult4ch = false;
}

Expand Down Expand Up @@ -311,7 +312,7 @@ Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
return cat_height(materialized, v_output);
} else if (dim == ndim - 3) {
if (is_mult4ch) {
return cat_feature_mult4ch(materialized, v_output);
return cat_feature_mult4ch(materialized, v_output, ndim);
}
return cat_feature(materialized, v_output);
}
Expand Down
18 changes: 18 additions & 0 deletions aten/src/ATen/test/vulkan_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4311,6 +4311,24 @@ TEST_F(VulkanAPITest, cat_4d_dim3_diffwidth_success) {
ASSERT_TRUE(check);
}

TEST_F(VulkanAPITest, cat_3d_dim0_mult4ch_success) {
// Arrange
const auto in_cpu1 = at::rand({4, 193, 113}, at::device(at::kCPU).dtype(at::kFloat));
const auto in_cpu2 = at::rand({4, 193, 113}, at::device(at::kCPU).dtype(at::kFloat));
const auto in_cpu3 = at::rand({4, 193, 113}, at::device(at::kCPU).dtype(at::kFloat));

// Act
const auto out_cpu = at::cat({in_cpu1, in_cpu2, in_cpu3}, 0);
const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 0);

// Assert
const auto check = almostEqual(out_cpu, out_vulkan.cpu());
if (!check) {
showRtol(out_cpu, out_vulkan.cpu());
}

ASSERT_TRUE(check);
}

TEST_F(VulkanAPITest, cat_3d_dim0_diff_channel_success) {
// Arrange
Expand Down

0 comments on commit ee6d0b0

Please sign in to comment.