[PyTorch] Add Vulkan support for at::softmax 1,2,3 dimension tensors
Summary: This rounds out support for the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html) on the Vulkan GPU backend, which previously accepted only 4-dimensional inputs. The test inputs for the 1-, 2-, and 3-dimension cases are simply truncations of the existing 4-dimension inputs, and the existing shader algorithms are reused.
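
As a rough sketch of what this change enables at the ATen call level (the helper name, shape, and dim choice below are illustrative, not taken from the diff, and a Vulkan-enabled build is assumed):

```
#include <ATen/ATen.h>

// Hedged usage sketch: after this change, Vulkan softmax accepts 1- to
// 4-dimensional inputs instead of requiring exactly 4 dimensions.
void softmax_vulkan_sketch() {
  // A 3-dimensional input; the old dim() == 4 check would have rejected it.
  const at::Tensor in_cpu =
      at::rand({4, 8, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));

  // Compute softmax along dim 1 on both backends, mirroring the tests below.
  const at::Tensor out_cpu = at::softmax(in_cpu, 1);
  const at::Tensor out_vulkan = at::softmax(in_cpu.vulkan(), 1);
  // out_vulkan.cpu() should approximately equal out_cpu.
}
```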

Test Plan:
1. `buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource  //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1` on Apple M1 MacBook
2. Confirm all tests pass with no regressions, and that the added `*softmax*` tests pass under `-- --gtest_filter="*softmax*"`
2a. All tests P782531732
2b. `softmax` tests P782529114

```
~/fbsource » buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*softmax*"
Buck UI: https://www.internalfb.com/buck2/692eb82d-c2ee-49bb-833f-3c11d6e2fea9
Jobs completed: 4. Time elapsed: 0.1s.
Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc
Note: Google Test filter = *softmax*
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from VulkanAPITest
[ RUN      ] VulkanAPITest.softmax
[       OK ] VulkanAPITest.softmax (42 ms)
[ DISABLED ] VulkanAPITest.DISABLED_log_softmax
[----------] 1 test from VulkanAPITest (42 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (42 ms total)
[  PASSED  ] 1 test.

  YOU HAVE 1 DISABLED TEST

```

Reviewed By: SS-JIA

Differential Revision: D46985319

fbshipit-source-id: 0c052f09ba365dc0ae1946509c64bab195fe76e6
liuk22 authored and facebook-github-bot committed Jul 11, 2023
1 parent 2f95a3d commit c3a9ebe
Showing 2 changed files with 122 additions and 64 deletions.
aten/src/ATen/native/vulkan/ops/Softmax.cpp: 97 additions & 44 deletions

@@ -10,34 +10,103 @@ namespace {
 
 using namespace api::utils;
 
+void set_softmax_kernel_params(
+    const long long num_dims,
+    const long long softmax_dim,
+    const IntArrayRef v_input_sizes,
+    api::ShaderInfo& shader_descriptor,
+    api::utils::ivec4& input_shader_extents,
+    api::utils::ivec4& early_exit,
+    api::utils::ivec4& input_dim_stride,
+    api::utils::ivec4& input_tensor_dims) {
+  if (num_dims == 1) {
+    early_exit.data[0u] = 1;
+    input_dim_stride.data[0u] = 1;
+    shader_descriptor = VK_KERNEL(softmax_batch_height_width);
+  } else if (num_dims == 2) {
+    // for height, width dim case, we can reuse a single shader
+    // with vectorized parameters
+    if (softmax_dim == 0) {
+      early_exit.data[1u] = 1;
+      input_dim_stride.data[1u] = 1;
+      shader_descriptor = VK_KERNEL(softmax_batch_height_width);
+    } else { // dim == 1
+      early_exit.data[0u] = 1;
+      input_dim_stride.data[0u] = 1;
+      shader_descriptor = VK_KERNEL(softmax_batch_height_width);
+    }
+  } else if (num_dims == 3) {
+    // for height, width dim case, we can reuse a single shader
+    // with vectorized parameters
+    for (uint32_t i = 0; i < num_dims; i++) {
+      input_tensor_dims.data[i + 1] = safe_downcast<int32_t>(v_input_sizes[i]);
+    }
+    if (softmax_dim == 0) {
+      early_exit.data[2u] = 1;
+      input_dim_stride.data[2u] = 1;
+      shader_descriptor = VK_KERNEL(softmax_channel);
+    } else if (softmax_dim == 1) {
+      early_exit.data[1u] = 1;
+      input_dim_stride.data[1u] = 1;
+      shader_descriptor = VK_KERNEL(softmax_batch_height_width);
+    } else { // dim == 2
+      early_exit.data[0u] = 1;
+      input_dim_stride.data[0u] = 1;
+      shader_descriptor = VK_KERNEL(softmax_batch_height_width);
+    }
+  } else {
+    // assume num_dims is 4
+    // for batch, height, width dim case, we can reuse a single shader
+    // with vectorized parameters
+    for (uint32_t i = 0; i < num_dims; i++) {
+      input_tensor_dims.data[i] = safe_downcast<int32_t>(v_input_sizes[i]);
+    }
+    if (softmax_dim == 1) {
+      // for 4-rank Tensor, softmax along channel dim case, the memory layout
+      // forces a different shader algorithm than other dims
+      input_shader_extents.data[2u] =
+          v_input_sizes[Layout::Activation4D::batch];
+      shader_descriptor = VK_KERNEL(softmax_channel);
+    } else {
+      if (softmax_dim == 0) {
+        early_exit.data[2u] = safe_downcast<int32_t>(
+            std::ceil(v_input_sizes[Layout::Activation4D::channels] / 4.0));
+        input_dim_stride.data[2u] = safe_downcast<int32_t>(
+            std::ceil(v_input_sizes[Layout::Activation4D::channels] / 4.0));
+      } else if (softmax_dim == 2) {
+        early_exit.data[1u] = 1;
+        input_dim_stride.data[1u] = 1;
+      } else { // dim == 3
+        early_exit.data[0u] = 1;
+        input_dim_stride.data[0u] = 1;
+      }
+      shader_descriptor = VK_KERNEL(softmax_batch_height_width);
+    }
+  }
+}
+
 Tensor softmax_internal(
     const at::Tensor& input_arg,
     const int64_t dim,
     const bool half_to_float,
     const bool log_softmax) {
   TORCH_CHECK(
-      input_arg.dim() == 4, "Vulkan softmax expects 4-dimensional input!");
-
+      input_arg.dim() >= 1 && input_arg.dim() <= 4,
+      "Vulkan softmax expects 1,2,3 or 4-dimensional input!");
+  TORCH_CHECK(
+      dim >= 0 && dim < input_arg.dim(),
+      "Softmax dim input was ", dim, " out of range for Tensor input with dimensions ", input_arg.dim());
   api::Context* const context = api::context();
 
   const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
   const vTensor& v_input = convert(input);
   const IntArrayRef v_input_sizes = v_input.sizes();
-  c10::SmallVector<int64_t, 4u> output_sizes{
-      v_input_sizes[Layout::Activation4D::batch],
-      v_input_sizes[Layout::Activation4D::channels],
-      v_input_sizes[Layout::Activation4D::height],
-      v_input_sizes[Layout::Activation4D::width],
-  };
 
   vTensor v_output{
       context,
-      output_sizes,
+      v_input_sizes,
       input_arg.scalar_type(),
   };
-
-  // we have custom global workgroup extents for softmax to enable
-  // shader algorithms that avoid redundant denominator computations
   const api::utils::uvec3 global_workgroup_extents = v_output.extents();
   api::utils::ivec4 input_shader_extents = {
       safe_downcast<int32_t>(v_input.extents().data[0u]),
@@ -61,6 +130,12 @@ Tensor softmax_internal(
       0,
       0, // zero pad
   };
+  api::utils::ivec4 input_tensor_dims = {
+      0,
+      0,
+      0,
+      0,
+  };
   api::ShaderInfo shader_descriptor;
   if (log_softmax) {
     if (dim == 1) {
@@ -71,29 +146,15 @@ Tensor softmax_internal(
           "Vulkan log_softmax expects 4-dimensional input with dim=1!");
     }
   } else {
-    if (dim == 1) {
-      // for channel dim case, the memory layout forces
-      // a different shader algorithm than other dims
-      input_shader_extents.data[2u] =
-          v_input_sizes[Layout::Activation4D::batch];
-      shader_descriptor = VK_KERNEL(softmax_channel);
-    } else {
-      // for batch, height, width dim case, we can reuse a single shader
-      // with vectorized parameters
-      if (dim == 0) {
-        early_exit.data[2u] = safe_downcast<int32_t>(
-            std::ceil(v_input_sizes[Layout::Activation4D::channels] / 4.0));
-        input_dim_stride.data[2u] = safe_downcast<int32_t>(
-            std::ceil(v_input_sizes[Layout::Activation4D::channels] / 4.0));
-      } else if (dim == 2) {
-        early_exit.data[1u] = 1;
-        input_dim_stride.data[1u] = 1;
-      } else { // dim == 3
-        early_exit.data[0u] = 1;
-        input_dim_stride.data[0u] = 1;
-      }
-      shader_descriptor = VK_KERNEL(softmax_batch_height_width);
-    }
+    set_softmax_kernel_params(
+        input_arg.dim(),
+        dim,
+        v_input_sizes,
+        shader_descriptor,
+        input_shader_extents,
+        early_exit,
+        input_dim_stride,
+        input_tensor_dims);
   }
 
   const struct Block final {
@@ -102,15 +163,7 @@ Tensor softmax_internal(
     ivec4 input_dim_stride;
     ivec4 early_exit;
   } block{
-      input_shader_extents,
-      {
-          safe_downcast<int32_t>(v_input_sizes[Layout::Activation4D::batch]),
-          safe_downcast<int32_t>(v_input_sizes[Layout::Activation4D::channels]),
-          safe_downcast<int32_t>(v_input_sizes[Layout::Activation4D::height]),
-          safe_downcast<int32_t>(v_input_sizes[Layout::Activation4D::width]),
-      }, // input_tensor_dims
-      input_dim_stride,
-      early_exit};
+      input_shader_extents, input_tensor_dims, input_dim_stride, early_exit};
   api::UniformParamsBuffer params(context, block);
   api::PipelineBarrier pipeline_barrier{};
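
A note on the `ceil(channels / 4.0)` values in the `softmax_dim == 0` branch above: the Vulkan backend packs tensors four channels per texel, so one step along the batch dimension spans `ceil(C / 4)` texel layers, and that is the stride and early-exit bound handed to the shader. Below is a minimal standalone sketch of that arithmetic; the `texel_layers` helper is hypothetical, not part of the codebase:

```
#include <cmath>
#include <cstdint>
#include <iostream>

// Illustrative: texel layers occupied by C channels when packed 4 per texel.
int32_t texel_layers(int64_t channels) {
  return static_cast<int32_t>(std::ceil(channels / 4.0));
}

int main() {
  // e.g. channels = 11: each batch step spans ceil(11 / 4.0) = 3 texel
  // layers, the stride a batch-dim softmax dispatch receives.
  std::cout << texel_layers(11) << std::endl; // prints 3
  return 0;
}
```

The same packing is why only the 4-rank channel case (`softmax_dim == 1`) and the 3-rank `softmax_dim == 0` case select `softmax_channel`; writing the 3-rank sizes into `input_tensor_dims` at offset `i + 1` suggests the 3-rank tensor is handled as if it had an implicit leading batch dimension of 1. Every other (rank, dim) combination reuses the vectorized `softmax_batch_height_width` shader with `early_exit` and `input_dim_stride` pointing along the reduction axis.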
aten/src/ATen/test/vulkan_api_test.cpp: 25 additions & 20 deletions

@@ -3388,29 +3388,34 @@ TEST_F(VulkanAPITest, sigmoid_) {
   ASSERT_TRUE(check);
 }
 
-void test_softmax_4d(int64_t dim) {
-  c10::InferenceMode mode;
-
-  at::Tensor test_in[] = {
-      at::rand({1, 3, 4, 2}, at::TensorOptions(at::kCPU).dtype(at::kFloat)),
-      at::rand({4, 8, 5, 7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)),
-      at::rand({9, 11, 12, 12}, at::TensorOptions(at::kCPU).dtype(at::kFloat)),
-  };
-  for (auto in_cpu : test_in) {
-    const auto out_cpu = at::softmax(in_cpu, dim);
-    const auto in_vulkan = in_cpu.vulkan();
-    const auto out_vulkan = at::softmax(in_vulkan, dim);
-    const auto check = almostEqual(out_cpu, out_vulkan.cpu());
-    if (!check) {
-      showRtol(out_cpu, out_vulkan.cpu());
-    }
-    ASSERT_TRUE(check);
-  }
-}
-
-TEST_F(VulkanAPITest, softmax_4d) {
-  for (int dim = 0; dim < 4; dim++) {
-    test_softmax_4d(dim);
-  }
-}
+TEST_F(VulkanAPITest, softmax) {
+  c10::InferenceMode mode;
+  std::vector<std::vector<int64_t>> test_in_dims = {
+      {1, 3, 4, 2},
+      {4, 8, 5, 7},
+      {9, 11, 12, 12},
+  };
+  for (const std::vector<int64_t>& dim_vec : test_in_dims) {
+    for (int trunc = 0; trunc < dim_vec.size(); trunc++) {
+      const std::vector<int64_t> trunc_dim_vec = std::vector<int64_t>(dim_vec.begin(), dim_vec.end() - trunc);
+      at::Tensor in_cpu = at::rand(trunc_dim_vec, at::TensorOptions(at::kCPU).dtype(at::kFloat));
+      for (int dim = 0; dim < trunc_dim_vec.size(); dim++) {
+        const at::Tensor out_cpu = at::softmax(in_cpu, dim);
+        const at::Tensor in_vulkan = in_cpu.vulkan();
+        const at::Tensor out_vulkan = at::softmax(in_vulkan, dim);
+        const bool check = almostEqual(out_cpu, out_vulkan.cpu());
+        if (!check) {
+          std::cout << "Softmax test failed on axis " << dim << "for tensor dims {";
+          for (int place = 0; place < trunc_dim_vec.size() - 1; place++) {
+            std::cout << trunc_dim_vec[place] << " ";
+          }
+          std::cout << trunc_dim_vec.back() << "}" << std::endl;
+          showRtol(out_cpu, out_vulkan.cpu());
+        }
+        ASSERT_TRUE(check);
+      }
+    }
+  }
+}
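
To make the truncation scheme concrete: each base shape is tested at every truncated rank, and softmax runs along every valid dim of the truncated shape. For the base shape `{4, 8, 5, 7}`, the loops above cover:

```
trunc = 0  ->  shape {4, 8, 5, 7},  dims 0, 1, 2, 3
trunc = 1  ->  shape {4, 8, 5},     dims 0, 1, 2
trunc = 2  ->  shape {4, 8},        dims 0, 1
trunc = 3  ->  shape {4},           dim  0
```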
