Commit 8456adf

doc: graph: update GQA pattern spec
1 parent 814b9b9 commit 8456adf

File tree

2 files changed: +40 −13 lines changed


doc/graph/fusion_patterns/gqa.md

Lines changed: 40 additions & 13 deletions
@@ -30,14 +30,24 @@ The notations used in the document:
 
 Similar to how SDPA is supported, the GQA pattern is also defined as a
 directed acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
-[SDPA pattern](@ref dev_guide_graph_sdpa) to support floating-point (f32, bf16,
-and f16) GQA as follows. The blue nodes are required when defining a GQA pattern
-while the brown nodes are optional.
+[SDPA pattern](@ref dev_guide_graph_sdpa) to support two types of floating-point
+(f32, bf16, and f16) GQA patterns. The blue nodes are required when defining a
+GQA pattern while the brown nodes are optional. The key difference between the
+two types lies in whether the input and output tensors have 4D or 5D shapes:
+when they are 4D, optional StaticReshape operations convert the tensors
+between the 4D and 5D formats.
 
 ![GQA pattern](images/gqa.png)
 
-Compared to a typical SDPA pattern, there are a few differences in the GQA
-pattern:
+### GQA Pattern with 4D Input and Output
+
+Due to the broadcasting semantics of MatMul, implementing GQA often requires
+additional tensor manipulation. Specifically, when working with 4D input
+tensors, where Query has shape (N, H_q, S, D) and Key/Value have shape
+(N, H_kv, S, D), extra StaticReshape operations are needed to align tensor
+dimensions for the MatMul operations. The 4D GQA pattern therefore differs
+from a typical SDPA pattern as follows:
 
 1. The input Query has shape (N, H_q, S, D). It will be reshaped to (N, H_kv,
 N_rep, S, D) by splitting the H_q dimension into H_kv and N_rep. The reshaping
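For illustration, here is a minimal sketch of how the grouping reshape described in the hunk above could be expressed with the oneDNN Graph C++ API. The sizes, tensor IDs, and op names are illustrative assumptions, not taken from this commit or from the library's gqa.cpp example.

```cpp
// Minimal sketch: splitting H_q into (H_kv, N_rep) with StaticReshape.
#include <cstdint>
#include <vector>

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    // Illustrative sizes: 32 query heads grouped over 8 KV heads.
    const int64_t N = 1, H_q = 32, H_kv = 8, S = 128, D = 64;
    const int64_t N_rep = H_q / H_kv; // 4 query heads share one KV head

    using lt = logical_tensor;
    // 4D Query as it enters the pattern.
    lt query {0, lt::data_type::f16, {N, H_q, S, D}, lt::layout_type::strided};
    // 5D Query after splitting H_q into (H_kv, N_rep).
    lt query_5d {1, lt::data_type::f16, {N, H_kv, N_rep, S, D},
                 lt::layout_type::strided};

    // StaticReshape performs the H_q -> (H_kv, N_rep) split.
    op reshape_q {0, op::kind::StaticReshape, {query}, {query_5d}, "reshape_q"};
    reshape_q.set_attr<std::vector<int64_t>>(op::attr::shape,
                                             {N, H_kv, N_rep, S, D});
    reshape_q.set_attr<bool>(op::attr::special_zero, false);

    graph g {engine::kind::cpu};
    g.add_op(reshape_q);
    // The remaining GQA ops (MatMuls, Scale, Mask, SoftMax, output reshape)
    // would be added here before finalizing the graph.
    g.finalize();
    return 0;
}
```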
@@ -56,6 +66,19 @@ pattern:
 similarly. Besides that, they have the same definition as described in the
 typical SDPA pattern.
 
+### GQA Pattern with 5D Input and Output
+
+To simplify the process and avoid unnecessary reshapes, oneDNN also supports a
+native 5D GQA pattern. In this approach, the input Query, Key, and Value
+tensors are already provided in the grouped format:
+
+1. The input Query has the 5D shape (N, H_kv, N_rep, S, D).
+2. The input Key/Value have the 5D shape (N, H_kv, 1, S, D).
+3. The second MatMul calculates the dot products between the probabilities
+after SoftMax and the Value nodes and generates output with shape
+(N, H_kv, N_rep, S, D).
+4. The input scale factor and mask in the pattern also need to meet the
+operations' shape requirements.
+
 ## Data Types
 
 oneDNN supports the floating-point GQA pattern with data types f32, bf16, and
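As a companion to the 5D description above, a minimal sketch of the pattern's core: two MatMuls with Key/Value broadcast along their size-1 N_rep dimension. Again, the sizes and IDs are illustrative assumptions, and the optional Scale and Mask nodes are omitted for brevity.

```cpp
// Minimal sketch of the 5D GQA core (scale and mask omitted).
#include <cstdint>
#include <vector>

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    // Illustrative sizes: 8 KV heads, each shared by 4 query heads.
    const int64_t N = 1, H_kv = 8, N_rep = 4, S = 128, D = 64;
    using lt = logical_tensor;

    lt query {0, lt::data_type::f16, {N, H_kv, N_rep, S, D},
              lt::layout_type::strided};
    lt key {1, lt::data_type::f16, {N, H_kv, 1, S, D},
            lt::layout_type::strided};
    lt scores {2, lt::data_type::f16, {N, H_kv, N_rep, S, S},
               lt::layout_type::strided};

    // First MatMul (QK^T): Key's size-1 N_rep dimension broadcasts against
    // Query's N_rep dimension.
    op qk {0, op::kind::MatMul, {query, key}, {scores}, "qk"};
    qk.set_attr<bool>(op::attr::transpose_b, true);

    lt probs {3, lt::data_type::f16, {N, H_kv, N_rep, S, S},
              lt::layout_type::strided};
    op softmax {1, op::kind::SoftMax, {scores}, {probs}, "softmax"};
    softmax.set_attr<int64_t>(op::attr::axis, -1); // over the last dimension

    lt value {4, lt::data_type::f16, {N, H_kv, 1, S, D},
              lt::layout_type::strided};
    lt output {5, lt::data_type::f16, {N, H_kv, N_rep, S, D},
               lt::layout_type::strided};
    // Second MatMul: probabilities x Value, again broadcasting over N_rep.
    op pv {2, op::kind::MatMul, {probs, value}, {output}, "pv"};

    graph g {engine::kind::cpu};
    g.add_op(qk);
    g.add_op(softmax);
    g.add_op(pv);
    g.finalize();
    return 0;
}
```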
@@ -77,24 +100,28 @@ platforms follow the general description in @ref dev_guide_data_types.
 2. The GQA patterns functionally support all input shapes meeting the shape
 requirements of each operation in the graph.
 3. CPU
-  - Optimized implementation is available for 4D Q/K/V tensors with shape
-    defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for Key and Value.
+  - Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
+    the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
+    Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
+    Query and (N, H_kv, 1, S, D) for Key and Value.
   - Optimized implementation is available for OpenMP runtime and Threadpool
     runtime on Intel Architecture Processors.
   - Specifically for OpenMP runtime, the optimized implementation requires `N *
     H_q > 2 * thread number` to get enough parallelism.
 4. GPU
-  - Optimized implementation is available for 4D Q/K/V tensors with shape
-    defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for Key and Value.
-  - Optimized implementation is available for floating-point GQA with `f16`
-    data type and `D <= 512` on Intel Graphics Products with Intel(R) Xe Matrix
-    Extensions (Intel(R) XMX) support.
+  - Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
+    the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
+    Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
+    Query and (N, H_kv, 1, S, D) for Key and Value.
+  - Optimized implementation is available for floating-point GQA with `f16`
+    and `bf16` data types and `D <= 512` on Intel Graphics Products with
+    Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
 
 ## Example
 
 oneDNN provides a [GQA
 example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gqa.cpp)
-demonstrating how to construct a floating-point GQA pattern with oneDNN Graph
+demonstrating how to construct a 5D floating-point GQA pattern with oneDNN Graph
 API on CPU and GPU with different runtimes.
 
 ## References
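One note on the CPU limitation in the hunk above: the `N * H_q > 2 * thread number` condition is plain arithmetic that an application could evaluate before relying on the fused kernel. A minimal standalone sketch assuming the OpenMP runtime, with illustrative sizes:

```cpp
#include <cstdint>
#include <cstdio>

#include <omp.h>

int main() {
    // Illustrative sizes; substitute the real batch and query-head count.
    const int64_t N = 1, H_q = 32;
    const int64_t nthr = omp_get_max_threads();

    // Heuristic from the limitations section: the optimized CPU implementation
    // requires N * H_q > 2 * thread number to get enough parallelism.
    const bool enough = N * H_q > 2 * nthr;
    std::printf("enough parallelism for the optimized GQA kernel: %s\n",
                enough ? "yes" : "no");
    return 0;
}
```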
images/gqa.png: binary image changed (10.9 KB), preview not rendered.
