[Graph| tests, example, doc] Add GQA v2 support for implicit causal mask and example, doc update #3409

Open · wants to merge 4 commits into main
53 changes: 40 additions & 13 deletions doc/graph/fusion_patterns/gqa.md
@@ -30,14 +30,24 @@ The notations used in the document:

Similar to how SDPA is supported, the GQA pattern is also defined as a
directed acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
[SDPA pattern](@ref dev_guide_graph_sdpa) to support two types of floating-point
(f32, bf16, and f16) GQA patterns. The blue nodes are required when defining a
GQA pattern while the brown nodes are optional. The key difference between the
two types lies in whether the input and output tensors have 4D or 5D shapes:
for 4D inputs and outputs, optional StaticReshape operations convert the
tensors between the 4D and 5D shape formats.
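
The two conventions relate as follows, where H_q = H_kv * N_rep:

| Tensor      | 4D pattern     | 5D (grouped) pattern  |
|:------------|:---------------|:----------------------|
| Query       | (N, H_q, S, D) | (N, H_kv, N_rep, S, D) |
| Key / Value | (N, H_kv, S, D)| (N, H_kv, 1, S, D)    |
| Output      | (N, H_q, S, D) | (N, H_kv, N_rep, S, D) |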

![GQA pattern](images/gqa.png)

### GQA Pattern with 4D input and output

Due to the broadcasting semantics of MatMul, implementing GQA often requires
additional tensor manipulation. Specifically, when working with 4D input
tensors, where Query has shape (N, H_q, S, D) and Key/Value have shape
(N, H_kv, S, D), extra StaticReshape operations are needed to align the tensor
dimensions for the MatMul operations. The 4D GQA pattern therefore differs from
the typical SDPA pattern as follows (a minimal reshape sketch follows the list):

1. The input Query has shape (N, H_q, S, D). It will be reshaped to (N, H_kv,
   N_rep, S, D) by splitting the H_q dimension into H_kv and N_rep. The reshaping
@@ -56,6 +66,19 @@ pattern:
similarly. Besides that, they have the same definition as described in the
typical SDPA pattern.

### GQA Pattern with 5D input and output

To simplify the process and avoid unnecessary reshapes, oneDNN also supports a
native 5D GQA pattern. In this approach, the input Query, Key, and Value
tensors are provided directly in the grouped format.

1. The input Query has the 5D shape (N, H_kv, N_rep, S, D).
2. The input Key/Value have the 5D shape (N, H_kv, 1, S, D).
3. The second MatMul calculates the dot products between the probabilities
   produced by SoftMax and the Value tensor and generates an output with shape
   (N, H_kv, N_rep, S, D).
4. The input scale factor and mask in the pattern also need to meet the shape
   requirements of the corresponding operations.
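
For instance, the grouped 5D inputs can be declared directly, with no
StaticReshape needed; a minimal sketch following the example program below
(variable names mirror that example):

```cpp
// Grouped 5D shapes: Query is (N, H_kv, N_rep, S, D), Key/Value are
// (N, H_kv, 1, S, D); MatMul broadcasting handles the group dimension.
dnnl_dim_t head_rep = p.q_head_num / p.kv_head_num; // N_rep
const dims q_sz = {p.mb, p.kv_head_num, head_rep, p.seq_len, p.head_size};
const dims kv_sz = {p.mb, p.kv_head_num, 1, p.seq_len, p.head_size};

auto query = logical_tensor(id++, dt, q_sz, layout_type::strided);
auto key = logical_tensor(id++, dt, kv_sz, layout_type::strided);
auto value = logical_tensor(id++, dt, kv_sz, layout_type::strided);
```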

## Data Types

oneDNN supports the floating-point GQA pattern with data types f32, bf16, and
@@ -77,24 +100,28 @@ platforms follow the general description in @ref dev_guide_data_types.
2. The GQA patterns functionally support all input shapes meeting the shape
requirements of each operation in the graph.
3. CPU
- Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
Query and (N, H_kv, 1, S, D) for Key and Value.
- Optimized implementation is available for OpenMP runtime and Threadpool
runtime on Intel Architecture Processors.
- Specifically for OpenMP runtime, the optimized implementation requires `N *
H_q > 2 * thread number` to get enough parallelism (see the example after
this list).
4. GPU
- Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
Query and (N, H_kv, 1, S, D) for Key and Value.
- Optimized implementation is available for floating-point GQA with `f16` and
`bf16` data types and `D <= 512` on Intel Graphics Products with Intel(R)
Xe Matrix Extensions (Intel(R) XMX) support.
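
As an example of the OpenMP requirement above: with 16 OpenMP threads, a
problem with N = 2 and H_q = 32 gives N * H_q = 64 > 32 and thus enough
parallelism for the optimized CPU implementation, while N = 1 with H_q = 16
(N * H_q = 16) does not.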

## Example

oneDNN provides a [GQA
example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gqa.cpp)
demonstrating how to construct a 5D floating-point GQA pattern with oneDNN Graph
API on CPU and GPU with different runtimes.

## References
Binary file modified doc/graph/fusion_patterns/images/gqa.png
75 changes: 16 additions & 59 deletions examples/graph/gqa.cpp
@@ -110,115 +110,72 @@ void bench_gqa(engine::kind ekind, logical_tensor::data_type dt,
    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
    // Create dnnl::stream.
    dnnl::stream strm(eng);

    // Intermediate data type
    const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;

    dnnl_dim_t head_rep = p.q_head_num / p.kv_head_num;
    // Prepare input and output shapes to construct the gqa graph.
    const dims q_sz = {p.mb, p.kv_head_num, head_rep, p.seq_len, p.head_size};
    const dims kv_sz = {p.mb, p.kv_head_num, 1, p.seq_len, p.head_size};
    const dims score_sz = {p.mb, p.kv_head_num, head_rep, p.seq_len, p.seq_len};
    const dims scale_sz = {1};
    const dims mask_sz = {p.mb, 1, 1, 1, p.seq_len};

    // Incremental IDs used to create logical tensors and operations.
    size_t id = 0;

    // score = query x key.T
    auto query = logical_tensor(id++, dt, q_sz, layout_type::strided);
    auto key = logical_tensor(id++, dt, kv_sz, layout_type::strided);
    auto score = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);

    auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
    bmm1.set_attr<bool>(op::attr::transpose_b, true);
    bmm1.add_inputs({query, key});
    bmm1.add_outputs({score});

    // scaled_score = score / scale
    auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided);
    auto scaled_score
            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
    auto scale_div = op(id++, op::kind::Divide, "scale_div");
    scale_div.add_inputs({score, scale});
    scale_div.add_outputs({scaled_score});

    // masked_score = scaled_score + mask
    auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided);
    auto masked_score
            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
    auto mask_add = op(id++, op::kind::Add, "mask_add");
    mask_add.add_inputs({scaled_score, mask});
    mask_add.add_outputs({masked_score});

    // attention_probs = softmax(masked_score)
    auto probs = logical_tensor(id++, dt, score_sz, layout_type::strided);
    auto softmax = op(id++, op::kind::SoftMax, "softmax");
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.set_attr<std::string>(op::attr::mode, "inf_as_zero");
    softmax.add_inputs({masked_score});
    softmax.add_outputs({probs});

    // attention_output = attention_probs x value
    auto value = logical_tensor(id++, dt, kv_sz, layout_type::strided);
    auto output = logical_tensor(id++, dt, q_sz, layout_type::strided);

    auto bmm2 = op(id++, op::kind::MatMul, "bmm2");
    bmm2.add_inputs({probs, value});
    bmm2.add_outputs({output});

    // Construct a gqa graph with engine kind and operations.
    dnnl::graph::graph gqa(ekind);
    gqa.add_op(bmm1);
    gqa.add_op(scale_div);
    gqa.add_op(mask_add);
    gqa.add_op(softmax);
    gqa.add_op(bmm2);
    gqa.finalize();

// Get partitions from the gqa graph.
12 changes: 9 additions & 3 deletions src/graph/backend/dnnl/passes/transform.cpp
@@ -4068,7 +4068,8 @@ status_t fuse_implicit_causal_mask(std::shared_ptr<subgraph_t> &sg) {
if (!in_val1->has_producer()) continue;
auto &in_op1 = in_val1->get_producer();
if (in_op1.get_kind() != op_kind::dnnl_gen_index) continue;
auto ndim = in_op1.get_input_value(0)->get_logical_tensor().ndims;
if (in_op1.get_attr<int64_t>(op_attr::axis) != ndim - 1) continue;
if (in_op1.get_input_value(0) != out_op.get_input_value(0)) continue;
op_list.emplace_back(in_op1.shared_from_this());

@@ -4077,7 +4078,8 @@ status_t fuse_implicit_causal_mask(std::shared_ptr<subgraph_t> &sg) {
if (!in_val0->has_producer()) continue;
auto &in_op0 = in_val0->get_producer();
if (in_op0.get_kind() == op_kind::dnnl_gen_index) {
auto ndim = in_op0.get_input_value(0)->get_logical_tensor().ndims;
if (in_op0.get_attr<int64_t>(op_attr::axis) != ndim - 2) continue;
op_list.emplace_back(in_op0.shared_from_this());
matched = true;
} else if (compare_op_kind_and_algorithm(in_op0, op_kind::dnnl_binary,
@@ -4099,7 +4101,11 @@ status_t fuse_implicit_causal_mask(std::shared_ptr<subgraph_t> &sg) {
// Check if the GenIndex op exists
if (gen_index_op.get_kind() != op_kind::dnnl_gen_index)
continue;
auto ndim = gen_index_op.get_input_value(0)
                    ->get_logical_tensor()
                    .ndims;
if (gen_index_op.get_attr<int64_t>(op_attr::axis)
        != ndim - 2)
    continue;
continue;
if (gen_index_op.get_input_value(0)
!= out_op.get_input_value(0))
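
For context, the change above only generalizes the GenIndex axis checks from
the hardcoded 4D axes (2 and 3) to ndim - 2 and ndim - 1, so the same matching
also applies to 5D GQA graphs: the row index is generated along the query
sequence axis (ndim - 2) and the column index along the key sequence axis
(ndim - 1). A sketch of the mask predicate the pass recognizes, illustrative
only and not code from the library (the bottom-right variant appears in the new
gqa-plain-bottom-right-implicit-causal-mask cases below):

```cpp
#include <cstdint>

// Top-left implicit causal mask: keep the score at (row, col) when
// col <= row; other positions are masked out.
bool keep_top_left(int64_t row, int64_t col) {
    return col <= row;
}

// Bottom-right variant for an (s_q x s_kv) score matrix: the causal
// diagonal is aligned to the bottom-right corner.
bool keep_bottom_right(int64_t row, int64_t col, int64_t s_q, int64_t s_kv) {
    return col <= row + (s_kv - s_q);
}
```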
5 changes: 5 additions & 0 deletions tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all
@@ -9,6 +9,7 @@
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16-v2.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/gqa-plain-implicit-causal-mask-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json

@@ -32,6 +33,8 @@
--reset --dt=1:f16+2:f16+3:f16+4:f16+6:f16+104:f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
--reset --dt=4:f32+9:f32+14:f32 --case=complex_fusion/mha/GQA-fp16-v2.json
--reset --dt=4:f32+9:f32+14:f32 --case=complex_fusion/mha/GQA-fp16.json
--reset --dt=1:f16+3:f16+8:f16+16:f16+19:f16+20:f16 --case=complex_fusion/mha/gqa-plain-implicit-causal-mask-fp32-bs1.json
--reset --case=complex_fusion/mha/gqa-plain-bottom-right-implicit-causal-mask-f16-f32.json
--reset --dt=3:f16+4:f16+2:f16+1:f16+11:f16+0:f16+12:f16+14:f16+16:f16 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=0:f16+1:f16+3:f16+7:f16+2:f16+8:f16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
--reset --dt=15:f32+16:f32+5:f32+21:f32 --case=complex_fusion/mha/sdpa-compressed-kv-implicit-causal-mask-int8-gs128.json
@@ -57,6 +60,8 @@
--reset --op-kind=1:Multiply,1:Divide --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
--reset --dt=3:bf16+4:bf16+2:bf16+1:bf16+11:bf16+0:bf16+12:bf16+14:bf16+16:bf16 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json
--reset --dt=4:f32+9:f32+14:f32+1:bf16+3:bf16+8:bf16+11:bf16+16:bf16+20:bf16+19:bf16 --case=complex_fusion/mha/GQA-fp16-v2.json
--reset --dt=1:bf16+3:bf16+8:bf16+16:bf16+19:bf16+20:bf16 --case=complex_fusion/mha/gqa-plain-implicit-causal-mask-fp32-bs1.json
--reset --dt=0:bf16+1:bf16+4:bf16+22:bf16+24:bf16+25:bf16 --case=complex_fusion/mha/gqa-plain-bottom-right-implicit-causal-mask-f16-f32.json
--reset --dt=4:f32+9:f32+14:f32+0:bf16+1:bf16+2:bf16+3:bf16+11:bf16+12:bf16+18:bf16+19:bf16+8:bf16+16:bf16+20:bf16+23:bf16 --case=complex_fusion/mha/GQA-fp16.json
--reset --dt=0:bf16+1:bf16+3:bf16+7:bf16+2:bf16+8:bf16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json
--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
5 changes: 5 additions & 0 deletions tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci
@@ -10,18 +10,23 @@
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16-v2.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/gqa-plain-implicit-causal-mask-fp32-bs1.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
# f16 inputs + f32 intermediates + f16 outputs
--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
--reset --dt=4:f32+9:f32+14:f32 --case=complex_fusion/mha/GQA-fp16-v2.json
--reset --dt=1:f16+3:f16+8:f16+16:f16+19:f16+20:f16 --case=complex_fusion/mha/gqa-plain-implicit-causal-mask-fp32-bs1.json
--reset --case=complex_fusion/mha/gqa-plain-bottom-right-implicit-causal-mask-f16-f32.json
--reset --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json
--reset --case=complex_fusion/mha/codegemma-bf16-f32.json
--reset --case=complex_fusion/mha/gemma2-bf16-f32.json

# bf16 inputs + f32 intermediates + bf16 outputs
--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
--reset --dt=4:f32+9:f32+14:f32+1:bf16+3:bf16+8:bf16+11:bf16+16:bf16+20:bf16+19:bf16 --case=complex_fusion/mha/GQA-fp16-v2.json
--reset --dt=1:bf16+3:bf16+8:bf16+16:bf16+19:bf16+20:bf16 --case=complex_fusion/mha/gqa-plain-implicit-causal-mask-fp32-bs1.json
--reset --dt=0:bf16+1:bf16+4:bf16+22:bf16+24:bf16+25:bf16 --case=complex_fusion/mha/gqa-plain-bottom-right-implicit-causal-mask-f16-f32.json
--reset --dt=0:bf16+1:bf16+4:bf16+22:bf16+24:bf16+25:bf16 --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json

# int8 graphs