Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch Vulkan] Fix aten::cat bug when concatenating 3D tensors along the channel dimension with a channel count that is a multiple of 4 #103718

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 8 additions & 7 deletions aten/src/ATen/native/vulkan/ops/Concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ Tensor cat_feature(

Tensor cat_feature_mult4ch(
const MaterializedITensorListRef& tensors,
vTensor& v_output) {
vTensor& v_output,
uint32_t ndim) {
api::Context* const context = api::context();

int64_t depth_size_allprior = 0;
int64_t ch_interval = 0;
for (const at::Tensor& tensor : tensors) {
ch_interval += tensor.sizes()[1];
ch_interval += get_dim<Dim4D::Channel>(tensor);
}
const int64_t depth_interval = ch_interval / 4;

Expand All @@ -159,12 +160,13 @@ Tensor cat_feature_mult4ch(
tensor_arg.is_vulkan() ? tensor_arg : tensor_arg.vulkan();
const vTensor& v_self = convert(tensor);

const uint32_t depth_slice = safe_downcast<uint32_t>(tensor.sizes()[1] / 4);
const uint32_t depth_slice =
safe_downcast<uint32_t>(get_dim<Dim4D::Channel>(tensor) / 4);

uvec3 copy_extents{
v_self.extents().data[0u], v_self.extents().data[1u], depth_slice};

for (const auto b : c10::irange(tensor.sizes()[0])) {
for (const auto b : c10::irange(get_dim<Dim4D::Batch>(tensor))) {
src_offset.data[2u] = safe_downcast<uint32_t>(depth_slice * b);
dst_offset.data[2u] =
depth_size_allprior + safe_downcast<uint32_t>(depth_interval * b);
Expand Down Expand Up @@ -267,7 +269,6 @@ Tensor cat_height(

Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
TORCH_CHECK(!tensors.empty(), "Vulkan cat expects at least one tensor");

auto materialized = tensors.materialize();
TORCH_INTERNAL_ASSERT(!materialized.empty(), "Accessing empty array");
const at::Tensor& tensor = materialized[0];
Expand All @@ -283,7 +284,7 @@ Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
t.dim(),
"d");

if (ndim < 3 || t.sizes()[1 - (4u - ndim)] % 4 != 0) {
if (ndim < 3 || get_dim<Dim4D::Channel>(t) % 4 != 0) {
is_mult4ch = false;
}

Expand Down Expand Up @@ -311,7 +312,7 @@ Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
return cat_height(materialized, v_output);
} else if (dim == ndim - 3) {
if (is_mult4ch) {
return cat_feature_mult4ch(materialized, v_output);
return cat_feature_mult4ch(materialized, v_output, ndim);
}
return cat_feature(materialized, v_output);
}
Expand Down
18 changes: 18 additions & 0 deletions aten/src/ATen/test/vulkan_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4311,6 +4311,24 @@ TEST_F(VulkanAPITest, cat_4d_dim3_diffwidth_success) {
ASSERT_TRUE(check);
}

TEST_F(VulkanAPITest, cat_3d_dim0_mult4ch_success) {
  // Arrange: three identically shaped 3D inputs. For a 3D tensor, dim 0 is
  // the channel dimension on the Vulkan backend; size 4 is a multiple of 4,
  // so this exercises the mult4ch concatenation path.
  const auto opts = at::device(at::kCPU).dtype(at::kFloat);
  const auto t0 = at::rand({4, 193, 113}, opts);
  const auto t1 = at::rand({4, 193, 113}, opts);
  const auto t2 = at::rand({4, 193, 113}, opts);

  // Act: concatenate along dim 0 on both the CPU and Vulkan backends.
  const auto out_cpu = at::cat({t0, t1, t2}, 0);
  const auto out_vulkan = at::cat({t0.vulkan(), t1.vulkan(), t2.vulkan()}, 0);

  // Assert: the Vulkan result must match the CPU reference within tolerance;
  // dump the relative-tolerance diff on mismatch to aid debugging.
  const auto check = almostEqual(out_cpu, out_vulkan.cpu());
  if (!check) {
    showRtol(out_cpu, out_vulkan.cpu());
  }

  ASSERT_TRUE(check);
}

TEST_F(VulkanAPITest, cat_3d_dim0_diff_channel_success) {
// Arrange
Expand Down