Commit
move transform_bias_rescale_qkv vectorized code to cpu sub folder (#109095)

`at::vec::Vectorized<>` will not be properly vectorized under `aten/src/ATen/native/transformers`, so this patch moves the vectorized code to `aten/src/ATen/native/cpu`, where the macros `CPU_CAPABILITY_AVX2`, `CPU_CAPABILITY_AVX512`, etc. are defined.
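
As background (not part of the patch): files under `aten/src/ATen/native/cpu` are compiled once per CPU capability, so `at::vec::Vectorized<>` there lowers to real SIMD registers, while elsewhere it degrades to the scalar base implementation. A minimal sketch of the idiom, with the hypothetical helper `add_bias` standing in for the `vec::map2` calls in the kernel below:

```cpp
// Hypothetical example, not part of the patch: roughly what a vec::map2-style
// "k = k + bias" loop expands to. Under CPU_CAPABILITY_AVX2, Vec holds a
// 256-bit register (8 floats); with no capability macro defined it is a plain
// array looped element-wise -- the "not properly vectorized" case above.
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

void add_bias(float* out, const float* in, const float* bias, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  int64_t d = 0;
  for (; d + Vec::size() <= n; d += Vec::size()) {
    Vec v = Vec::loadu(in + d) + Vec::loadu(bias + d);
    v.store(out + d);  // unaligned store of one SIMD lane group
  }
  for (; d < n; d++) {  // scalar tail for the remainder
    out[d] = in[d] + bias[d];
  }
}
```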

Here are the VTune logs before and after this patch on `transform_bias_rescale_qkv_cpu`:
1. before:
![transformer_bioas_rescale_qkv_before](https://github.com/pytorch/pytorch/assets/20233731/582f6873-d86e-47a6-bd2a-620b97acc5b1)
2. after:
![transformer_bioas_rescale_qkv_after](https://github.com/pytorch/pytorch/assets/20233731/949004ab-3cbc-4a1d-a03d-9a17efa981ae)

Pull Request resolved: #109095
Approved by: https://github.com/jgong5, https://github.com/lezcano
mingfeima authored and pytorchmergebot committed Sep 18, 2023
1 parent f0fb4b3 commit a683bc5
Showing 4 changed files with 131 additions and 108 deletions.
111 changes: 111 additions & 0 deletions aten/src/ATen/native/cpu/NativeMultiheadAttnKernel.cpp
@@ -0,0 +1,111 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/transformers/attention.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

template <typename scalar_t>
void cpu_transform_bias_rescale_qkv(
    scalar_t* q_k_v_data,
    const scalar_t* qkv_data,
    const scalar_t* qkv_bias_data,
    int64_t B,
    int64_t T,
    int64_t D,
    int64_t num_head) {

  int64_t dim_per_head = D / num_head;

  // shapes and strides:
  //   qkv      : {B, T, 3, num_head, dim_per_head}
  //   qkv_bias : {3, num_head, dim_per_head}
  //   q_k_v    : {3, B, num_head, T, dim_per_head}
  //
  int64_t i_strideB = T * 3 * D;
  int64_t i_strideT = 3 * D;
  int64_t o_stride = B * num_head * T * dim_per_head;

  // inv_sqrt_dim_per_head in accumulate type
  using acc_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<acc_t>;
  const acc_t s = 1.0 / std::sqrt(static_cast<acc_t>(dim_per_head));

  // parallel on {B, num_head, T}
  int64_t grain_size = std::max(at::internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
  at::parallel_for(0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
    int64_t b{0}, nh{0}, t{0};
    data_index_init(begin, b, B, nh, num_head, t, T);

    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* q_in_ptr = qkv_data + b * i_strideB + t * i_strideT + 0 * D + nh * dim_per_head;
      const scalar_t* k_in_ptr = qkv_data + b * i_strideB + t * i_strideT + 1 * D + nh * dim_per_head;
      const scalar_t* v_in_ptr = qkv_data + b * i_strideB + t * i_strideT + 2 * D + nh * dim_per_head;

      const scalar_t* q_bias_ptr = qkv_bias_data + 0 * D + nh * dim_per_head;
      const scalar_t* k_bias_ptr = qkv_bias_data + 1 * D + nh * dim_per_head;
      const scalar_t* v_bias_ptr = qkv_bias_data + 2 * D + nh * dim_per_head;

      // we can use global index i here for output
      scalar_t* q_out_ptr = q_k_v_data + 0 * o_stride + i * dim_per_head;
      scalar_t* k_out_ptr = q_k_v_data + 1 * o_stride + i * dim_per_head;
      scalar_t* v_out_ptr = q_k_v_data + 2 * o_stride + i * dim_per_head;

      // q = (q + bias) * inv_sqrt_dim_per_head
      vec::map2<scalar_t>(
          [s](Vec q, Vec q_bias) { return (q + q_bias) * Vec(s); },
          q_out_ptr, q_in_ptr, q_bias_ptr, dim_per_head);

      // k = k + bias
      vec::map2<scalar_t>([](Vec k, Vec k_bias) { return k + k_bias; },
          k_out_ptr, k_in_ptr, k_bias_ptr, dim_per_head);

      // v = v + bias
      vec::map2<scalar_t>([](Vec v, Vec v_bias) { return v + v_bias; },
          v_out_ptr, v_in_ptr, v_bias_ptr, dim_per_head);

      // move to the next index
      data_index_step(b, B, nh, num_head, t, T);
    }
  });
}

void transform_bias_rescale_qkv_kernel_impl(
    at::ScalarType type,
    void* _q_k_v,
    const void* _qkv,
    const void* _qkv_bias,
    int64_t B,
    int64_t T,
    int64_t D,
    int64_t num_head) {

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, type, "transform_bias_rescale_qkv", [&] {
    scalar_t* q_k_v = static_cast<scalar_t*>(_q_k_v);
    const scalar_t* qkv = static_cast<const scalar_t*>(_qkv);
    const scalar_t* qkv_bias = static_cast<const scalar_t*>(_qkv_bias);
    cpu_transform_bias_rescale_qkv<scalar_t>(
        q_k_v,
        qkv,
        qkv_bias,
        B,
        T,
        D,
        num_head);
  });
}

} // anonymous namespace

REGISTER_DISPATCH(transform_bias_rescale_qkv_stub, &transform_bias_rescale_qkv_kernel_impl);

} // at::native
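
A note on the traversal (illustrative, not the ATen implementation): `data_index_init`/`data_index_step` from `ATen/native/cpu/utils.h` decompose the flat `parallel_for` index into `{B, num_head, T}` coordinates and advance them like an odometer, roughly as follows:

```cpp
#include <cstdint>

// Illustrative sketch of the flattened {B, num_head, T} traversal; the real
// helpers are variadic templates in ATen/native/cpu/utils.h.
void traverse(int64_t begin, int64_t end, int64_t B, int64_t num_head, int64_t T) {
  // data_index_init: decompose the flat start index into (b, nh, t)
  int64_t b = begin / (num_head * T);
  int64_t nh = (begin / T) % num_head;
  int64_t t = begin % T;
  for (int64_t i = begin; i < end; i++) {
    // ... kernel body works on chunk (b, nh, t); output offsets use the flat i ...
    // data_index_step: increment t first, carry into nh, then into b
    if (++t == T) {
      t = 0;
      if (++nh == num_head) {
        nh = 0;
        ++b;
      }
    }
  }
}
```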
115 changes: 7 additions & 108 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -9,7 +9,6 @@
#include <ATen/NestedTensorImpl.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
-#include <ATen/cpu/vec/vec256/vec256.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <utility>
@@ -35,6 +34,7 @@ namespace native {

DEFINE_DISPATCH(_fused_sdp_choice_stub);

+DEFINE_DISPATCH(transform_bias_rescale_qkv_stub);
DEFINE_DISPATCH(flash_attention_kernel);
DEFINE_DISPATCH(flash_attention_backward_kernel);

@@ -48,83 +48,6 @@ Tensor gemm_nt(const Tensor& self, const Tensor& other) {
  }
}

-template <typename scalar_t>
-void transform_bias_rescale_qkv_inner_loop(
-    int64_t B,
-    int64_t T,
-    int64_t _3D,
-    int64_t D,
-    int64_t num_head,
-    int64_t dim_per_head,
-    scalar_t* qkv_data,
-    scalar_t* qkv_bias_data,
-    scalar_t* q_k_v_data,
-    scalar_t inv_sqrt_dim_per_head,
-    int64_t begin,
-    int64_t end) {
-  for (auto i : c10::irange(begin, end)) {
-    auto t = i % T;
-    i /= T;
-    auto nh = i % num_head;
-    i /= num_head;
-    auto b = i;
-    using Vec = vec::Vectorized<scalar_t>;
-    auto V = vec::Vectorized<scalar_t>::size();
-    auto dh = 0;
-    auto d = nh * dim_per_head;
-    for (; dh + V <= dim_per_head; dh += V, d += V) {
-      // load
-      auto q_bias_data = Vec::loadu(&qkv_bias_data[d + 0 * D]);
-      auto k_bias_data = Vec::loadu(&qkv_bias_data[d + 1 * D]);
-      auto v_bias_data = Vec::loadu(&qkv_bias_data[d + 2 * D]);
-
-      auto q_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 0 * D]) +
-          q_bias_data;
-      auto k_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 1 * D]) +
-          k_bias_data;
-      auto v_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 2 * D]) +
-          v_bias_data;
-
-      q_data = q_data * Vec(inv_sqrt_dim_per_head);
-
-      q_data.store(&q_k_v_data
-                       [0 * B * num_head * T * dim_per_head +
-                        b * num_head * T * dim_per_head +
-                        nh * T * dim_per_head + t * dim_per_head + dh]);
-      k_data.store(&q_k_v_data
-                       [1 * B * num_head * T * dim_per_head +
-                        b * num_head * T * dim_per_head +
-                        nh * T * dim_per_head + t * dim_per_head + dh]);
-      v_data.store(&q_k_v_data
-                       [2 * B * num_head * T * dim_per_head +
-                        b * num_head * T * dim_per_head +
-                        nh * T * dim_per_head + t * dim_per_head + dh]);
-    }
-    for (; dh < dim_per_head; dh++) {
-      auto d = nh * dim_per_head + dh;
-      auto q_bias = qkv_bias_data[d + 0 * D];
-      auto k_bias = qkv_bias_data[d + 1 * D];
-      auto v_bias = qkv_bias_data[d + 2 * D];
-      auto q_data = qkv_data[b * _3D * T + t * _3D + d + 0 * D] + q_bias;
-      auto k_data = qkv_data[b * _3D * T + t * _3D + d + 1 * D] + k_bias;
-      auto v_data = qkv_data[b * _3D * T + t * _3D + d + 2 * D] + v_bias;
-      q_data = q_data * inv_sqrt_dim_per_head;
-      q_k_v_data
-          [0 * B * num_head * T * dim_per_head +
-           b * num_head * T * dim_per_head + nh * T * dim_per_head +
-           t * dim_per_head + dh] = q_data;
-      q_k_v_data
-          [1 * B * num_head * T * dim_per_head +
-           b * num_head * T * dim_per_head + nh * T * dim_per_head +
-           t * dim_per_head + dh] = k_data;
-      q_k_v_data
-          [2 * B * num_head * T * dim_per_head +
-           b * num_head * T * dim_per_head + nh * T * dim_per_head +
-           t * dim_per_head + dh] = v_data;
-    }
-  }
-}
-
Tensor transform_0213(const Tensor& a) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
@@ -284,37 +207,13 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(

  const auto qkv_contig = qkv_->expect_contiguous();
  const auto qkv_bias_contig = qkv_bias.expect_contiguous();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      ScalarType::Half,
-      ScalarType::BFloat16,
-      qkv_->scalar_type(),
-      "transform_bias_rescale_qkv",
-      [&] {
-        scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
-        scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
-        scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
-        const scalar_t inv_sqrt_dim_per_head =
-            1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head));
-
-        int64_t grain_size =
-            std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
-        parallel_for(
-            0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
-              transform_bias_rescale_qkv_inner_loop(
-                  B,
-                  T,
-                  _3D,
-                  D,
-                  num_head,
-                  dim_per_head,
-                  qkv_data,
-                  qkv_bias_data,
-                  q_k_v_data,
-                  inv_sqrt_dim_per_head,
-                  begin,
-                  end);
-            });
-      });
+  transform_bias_rescale_qkv_stub(
+      kCPU,
+      qkv_->scalar_type(),
+      q_k_v.data_ptr(),
+      qkv_contig->const_data_ptr(),
+      qkv_bias_contig->const_data_ptr(),
+      B, T, D, num_head);
  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
12 changes: 12 additions & 0 deletions aten/src/ATen/native/transformers/attention.h
@@ -20,6 +20,18 @@ TORCH_API Tensor masked_softmax(
    const Tensor& query,
    c10::optional<int64_t> mask_type = {});

+using transform_bias_rescale_qkv_fn = void(*)(
+    at::ScalarType type,
+    void* _q_k_v,
+    const void* _qkv,
+    const void* _qkv_bias,
+    int64_t B,
+    int64_t T,
+    int64_t D,
+    int64_t num_head);
+
+DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub);

TORCH_API Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
    const Tensor& b,
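
Taken together, the three files above wire up ATen's DispatchStub pattern. A simplified sketch of how the pieces fit (the `example_*` names are hypothetical):

```cpp
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>

// The stub is a plain function pointer, so it cannot be templated on
// scalar_t: callers pass type-erased void* buffers plus a ScalarType, and
// the kernel re-dispatches to the concrete type internally.

// header (cf. attention.h): declare the stub signature
using example_fn = void (*)(at::ScalarType type, void* out, const void* in, int64_t n);
DECLARE_DISPATCH(example_fn, example_stub);

// generic .cpp (cf. attention.cpp): define the stub; a call site then does
//   example_stub(kCPU, t.scalar_type(), t.data_ptr(), src.const_data_ptr(), n);
DEFINE_DISPATCH(example_stub);

// kernel .cpp under native/cpu (cf. NativeMultiheadAttnKernel.cpp),
// compiled once per CPU_CAPABILITY_*: implement and register
void example_kernel_impl(at::ScalarType type, void* out, const void* in, int64_t n) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, type, "example", [&] {
    auto* o = static_cast<scalar_t*>(out);
    auto* i = static_cast<const scalar_t*>(in);
    for (int64_t k = 0; k < n; k++) {
      o[k] = i[k];  // placeholder elementwise body
    }
  });
}
REGISTER_DISPATCH(example_stub, &example_kernel_impl);
```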
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -1141,6 +1141,7 @@ aten_native_source_codegen_list = [
"aten/src/ATen/native/cpu/MaxPooling.cpp",
"aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp",
"aten/src/ATen/native/cpu/MultinomialKernel.cpp",
"aten/src/ATen/native/cpu/NativeMultiheadAttnKernel.cpp",
"aten/src/ATen/native/cpu/PaddingKernel.cpp",
"aten/src/ATen/native/cpu/PixelShuffleKernel.cpp",
"aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp",
