Vectorize the softmax calculation when not along the last dim (#59195)
Summary:
Currently, a softmax that is not taken along the last dim falls back to a [scalar version](https://github.com/pytorch/pytorch/blob/d417a094f398f1c4efd7f818b14b8471a597fbcc/aten/src/ATen/native/SoftMax.cpp#L14-L64). We found that we actually have the chance to vectorize this calculation along the inner_size dim.
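
For intuition (a sketch, not part of the commit): when the softmax dim is not the last one, the tensor can be viewed as (outer_size, dim_size, inner_size). Positions that differ only in the inner index are contiguous in memory, so a single SIMD load covers Vec::size() independent softmax columns at once. The helper below (`softmax_flat_index` is a made-up name, used only for illustration) shows the index math the new kernel relies on.

```cpp
#include <cstdint>

// For a contiguous tensor viewed as (outer_size, dim_size, inner_size),
// element (outer_idx, dim_idx, inner_idx) lives at this flat offset.
// Consecutive inner_idx values are adjacent in memory, which is what makes
// vectorizing along inner_size possible.
inline int64_t softmax_flat_index(
    int64_t outer_idx,
    int64_t dim_idx,
    int64_t inner_idx,
    int64_t dim_size,
    int64_t inner_size) {
  int64_t dim_stride = inner_size;              // step to the next element along the softmax dim
  int64_t outer_stride = dim_size * dim_stride; // step to the next outer slice
  return outer_idx * outer_stride + dim_idx * dim_stride + inner_idx;
}
```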

Changes we made:

- Use the vectorized `softmax_kernel` instead of `host_softmax` when the softmax is not along the last dim.

Performance data on a 28-core Intel 8280 CPU, with input size [32, 81, 15130] and softmax taken along the second dim (size 81); a minimal timing sketch is included after the numbers.

- FP32 Baseline: 24.67 ms
- FP32 optimized: 9.2 ms
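
The numbers above are the authors' measurement. A minimal libtorch timing sketch along the same lines is shown below; it is not part of the commit, assumes a libtorch build, pins the thread count to 28 to mirror the setup above, and will of course produce machine-dependent numbers.

```cpp
#include <torch/torch.h>
#include <ATen/Parallel.h>
#include <chrono>
#include <iostream>

int main() {
  at::set_num_threads(28);                  // mirror the 28-core setup
  auto x = torch::randn({32, 81, 15130});   // FP32 input from the description

  // Warm-up so allocation and thread-pool startup are excluded from the timing.
  for (int i = 0; i < 10; ++i) {
    torch::softmax(x, /*dim=*/1);
  }

  constexpr int iters = 100;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    torch::softmax(x, /*dim=*/1);           // softmax along the second dim
  }
  auto end = std::chrono::steady_clock::now();

  double ms = std::chrono::duration<double, std::milli>(end - start).count() / iters;
  std::cout << "softmax(dim=1) average: " << ms << " ms" << std::endl;
  return 0;
}
```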

Pull Request resolved: #59195

Reviewed By: ailzhang

Differential Revision: D28854796

Pulled By: cpuhrsch

fbshipit-source-id: 18477acc3963754c59009b1794f080496ae16c3d
leslie-fang-intel authored and facebook-github-bot committed Jun 14, 2021
1 parent d60d81b commit 68d690f
Showing 3 changed files with 123 additions and 4 deletions.
8 changes: 4 additions & 4 deletions aten/src/ATen/native/SoftMax.cpp
@@ -132,17 +132,15 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half_to_float) {
if (input.numel() == 0) {
return output;
}
if (input.dim() == 0)
input = input.view(1);
TORCH_CHECK(
dim >= 0 && dim < input.dim(),
"dim must be non-negative and less than input dimensions");
if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
softmax_lastdim_kernel(kCPU, output, input);
} else {
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "softmax", [&] {
-      host_softmax<scalar_t, false>(output, input, dim);
-    });
+    softmax_kernel(kCPU, output, input, dim);
}
return output;
}
@@ -310,6 +308,8 @@ DEFINE_DISPATCH(softmax_backward_lastdim_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(log_softmax_backward_lastdim_kernel);

DEFINE_DISPATCH(softmax_kernel);

Tensor softmax(const Tensor& self, Dimname dim, optional<ScalarType> dtype) {
return at::softmax(self, dimname_to_position(self, dim), dtype);
}
116 changes: 116 additions & 0 deletions aten/src/ATen/native/cpu/SoftMaxKernel.cpp
@@ -10,6 +10,7 @@
#include <ATen/cpu/vec/vec.h>
#include <c10/util/Optional.h>

#include <ATen/AccumulateType.h>
// [Note AVX-SSE transitions] In general we avoid calls into cmath for code
// compiled with AVX/AVX2 This is because of SSE-AVX transitions and a bug in
// Glibc2.23 See https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280
@@ -206,6 +207,113 @@ struct vec_host_softmax_lastdim {
}
};

template <typename scalar_t>
inline void _vec_softmax(
scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t outer_size,
int64_t inner_size,
int64_t dim_size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t dim_stride = inner_size;
int64_t outer_stride = dim_size * dim_stride;
int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
int vectorized_step = Vec().size(); // Currently, we only support scalar_t with double or float32
TORCH_CHECK(
(vectorized_step == 8) || (vectorized_step == 4),
"vectorized_step must be 8 with dtype float or 4 with dtype double");
parallel_for(
0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
int64_t idx = begin;
while (idx < end) {
int64_t outer_idx = idx / inner_size;
int64_t inner_idx = idx % inner_size;
if (((inner_idx + vectorized_step) <= inner_size) && ((idx + vectorized_step) <= end)) {
// Vectorization
scalar_t* input_data =
input_data_base + outer_idx * outer_stride + inner_idx;
scalar_t* output_data =
output_data_base + outer_idx * outer_stride + inner_idx;
// Step 1: Get the max score
Vec max_m256 = Vec::loadu(input_data);
for (int64_t d = 1; d < dim_size; d += 1) {
Vec input_m256 = Vec::loadu(input_data + d * dim_stride);
max_m256 = vec::maximum(max_m256, input_m256);
}
// Step 2: Calculate the sum
Vec sum_m256 = Vec(0.0);
for (int64_t d = 0; d < dim_size; d += 1) {
Vec output_m256 =
(Vec::loadu(input_data + d * dim_stride) - max_m256).exp();
output_m256.store(output_data + d * dim_stride);
sum_m256 = sum_m256 + output_m256;
}
// Step 3: Normalize by the sum
for (int64_t d = 0; d < dim_size; d += 1) {
Vec output_m256 =
Vec::loadu(output_data + d * dim_stride) / sum_m256;
output_m256.store(output_data + d * dim_stride);
}
idx += vectorized_step;
} else {
// Tail case (scalar): exactly the same logic as host_softmax in
// aten/src/ATen/native/SoftMax.cpp. Two kinds of indices fall through
// to this path:
// Case 1: idx is near the end of the chunk assigned to this thread, so
//         there are not enough elements left to fill a vector.
// Case 2: inner_idx is near the end of inner_size, so there are not
//         enough contiguous elements left to fill a vector.
int64_t tail_number = ((idx + vectorized_step) > end) ? /*Case 1*/ (end - idx) : /*Case 2*/ (inner_size - inner_idx);
for (int64_t i=0; i < tail_number; i++) {
outer_idx = (idx + i) / inner_size;
inner_idx = (idx + i) % inner_size;
scalar_t* input_data =
input_data_base + outer_idx * outer_stride + inner_idx;
scalar_t* output_data =
output_data_base + outer_idx * outer_stride + inner_idx;
// Step 1: Get the max score
scalar_t max_input = input_data[0];
for (int64_t d = 1; d < dim_size; d += 1) {
max_input = std::max(max_input, input_data[d * dim_stride]);
}
// Step 2: Calculate the sum
scalar_t sum_data = 0;
for (int64_t d = 0; d < dim_size; d += 1) {
output_data[d * dim_stride] =
std::exp(input_data[d * dim_stride] - max_input);
sum_data += output_data[d * dim_stride];
}
// Step 3: Normalize by the sum
for (int64_t d = 0; d < dim_size; d += 1) {
output_data[d * dim_stride] =
output_data[d * dim_stride]/sum_data;
}
}
idx += tail_number;
}
}
});
}

template <typename scalar_t, bool LogSoftMax>
struct vec_softmax {
static void apply(Tensor& output, const Tensor& input, int64_t dim) {
int64_t outer_size = 1;
int64_t dim_size = input.size(dim);
int64_t inner_size = 1;
for (int64_t i = 0; i < dim; ++i)
outer_size *= input.size(i);
for (int64_t i = dim + 1; i < input.dim(); ++i)
inner_size *= input.size(i);
scalar_t* input_data_base = input.data_ptr<scalar_t>();
scalar_t* output_data_base = output.data_ptr<scalar_t>();
if (LogSoftMax) {
AT_ERROR("vec_softmax not implemented for LogSoftMax");
} else {
_vec_softmax(
input_data_base, output_data_base, outer_size, inner_size, dim_size);
}
}
};

template <typename scalar_t, bool LogSoftMax>
struct vec_host_softmax_backward_lastdim {
static void
@@ -232,6 +340,12 @@ static void softmax_lastdim_kernel_impl(Tensor& result, const Tensor& self) {
});
}

static void softmax_kernel_impl(Tensor& result, const Tensor& self, int64_t dim) {
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "softmax_kernel_impl", [&] {
vec_softmax<scalar_t, false>::apply(result, self, dim);
});
}

static void log_softmax_lastdim_kernel_impl(
Tensor& result,
const Tensor& self) {
@@ -279,4 +393,6 @@ REGISTER_DISPATCH(
log_softmax_backward_lastdim_kernel,
&log_softmax_backward_lastdim_kernel_impl);

REGISTER_DISPATCH(softmax_kernel, &softmax_kernel_impl);

}} // namespace at::native
3 changes: 3 additions & 0 deletions aten/src/ATen/native/cpu/SoftmaxKernel.h
@@ -14,5 +14,8 @@ DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);

using forward_fn_with_dim = void(*)(Tensor &, const Tensor &, const int64_t);
DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);

}
}
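
Taken together, the three files wire the new kernel through ATen's DispatchStub machinery. The excerpts below simply collect the pieces added by this commit in one place for readability; they are not a standalone translation unit (the DispatchStub macros and Tensor come from the surrounding ATen headers).

```cpp
// aten/src/ATen/native/cpu/SoftmaxKernel.h
// Declare the stub with a signature that also carries the softmax dim.
using forward_fn_with_dim = void (*)(Tensor&, const Tensor&, const int64_t);
DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);

// aten/src/ATen/native/SoftMax.cpp
// Define the stub and call it for the non-last-dim case:
DEFINE_DISPATCH(softmax_kernel);
//   ... inside softmax_cpu():
//   softmax_kernel(kCPU, output, input, dim);

// aten/src/ATen/native/cpu/SoftMaxKernel.cpp
// Register the CPU implementation, which dispatches on dtype and runs _vec_softmax.
REGISTER_DISPATCH(softmax_kernel, &softmax_kernel_impl);
```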
