
Commit 4870e0f

doc: graph: update GQA pattern spec
1 parent 814b9b9 commit 4870e0f

File tree

3 files changed: +39 −14 lines changed

doc/graph/fusion_patterns/gqa.md

Lines changed: 39 additions & 14 deletions
@@ -30,14 +30,20 @@ The notations used in the document:
Similar to how SDPA is supported, the GQA pattern is also defined as a
directed acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
[SDPA pattern](@ref dev_guide_graph_sdpa) to support two types of floating-point
(f32, bf16, and f16) GQA patterns. The blue nodes are required when defining a
GQA pattern while the brown nodes are optional.

### GQA Pattern with 4D input

![GQA pattern](images/gqa-4D.png)

Due to the broadcasting semantics of MatMul, implementing GQA often requires
additional tensor manipulation. Specifically, when working with 4D input tensors,
where Query has shape (N, H_q, S, D) and Key/Value have shape (N, H_kv, S, D),
it is necessary to introduce extra StaticReshape operations to align tensor
dimensions for the MatMul operations. Therefore, the 4D GQA pattern differs from
the typical SDPA pattern as follows (see the code sketch after the list):
1. The input Query has shape (N, H_q, S, D). It will be reshaped to (N, H_kv,
   N_rep, S, D) by splitting the H_q dimension into H_kv and N_rep. The reshaping
@@ -56,6 +62,21 @@ pattern:
   similarly. Besides that, they have the same definition as described in the
   typical SDPA pattern.
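Below is a minimal sketch (not the code from the oneDNN example; all dimension
values are illustrative placeholders) of how the extra StaticReshape on Query
could be expressed with the oneDNN Graph API:

```cpp
#include <vector>
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    // Illustrative sizes: 32 query heads grouped over 8 KV heads.
    const int64_t N = 2, H_q = 32, H_kv = 8, S = 128, D = 64;
    const int64_t N_rep = H_q / H_kv;

    // 4D Query as provided by the user: (N, H_q, S, D).
    logical_tensor query {0, logical_tensor::data_type::f16,
            {N, H_q, S, D}, logical_tensor::layout_type::strided};
    // 5D view consumed by the first MatMul: (N, H_kv, N_rep, S, D).
    logical_tensor query_5d {1, logical_tensor::data_type::f16,
            {N, H_kv, N_rep, S, D}, logical_tensor::layout_type::strided};

    // StaticReshape splits the H_q dimension into (H_kv, N_rep).
    op reshape_q {0, op::kind::StaticReshape, "reshape_q"};
    reshape_q.set_attr<std::vector<int64_t>>(
            op::attr::shape, {N, H_kv, N_rep, S, D});
    reshape_q.set_attr<bool>(op::attr::special_zero, false);
    reshape_q.add_input(query);
    reshape_q.add_output(query_5d);

    graph g {dnnl::engine::kind::cpu};
    g.add_op(reshape_q);
    // ... the MatMul, scale, mask, and SoftMax ops would be added here
    // to complete the GQA pattern ...
    g.finalize();
    return 0;
}
```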

### GQA Pattern with 5D input

![GQA pattern](images/gqa-5D.png)

To simplify the process and avoid unnecessary reshapes, oneDNN also supports a
native 5D GQA pattern. In this approach, the input Query, Key, and Value tensors
are already provided in grouped format (a sketch follows the list):

1. The input Query has a 5D shape: (N, H_kv, N_rep, S, D).
2. The input Key/Value have a 5D shape: (N, H_kv, 1, S, D).
3. The second MatMul calculates the dot products between the probabilities after
   SoftMax and the Value nodes and generates output with shape (N, H_kv, N_rep,
   S, D).
4. The input scale factor and mask in the pattern also need to meet the
   operations' shape requirements.
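As a rough illustration of why no reshape is needed here, the sketch below
defines the 5D inputs and the first MatMul; the size-1 dimension of Key
broadcasts over N_rep under MatMul's broadcasting semantics. Sizes are
placeholders:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const int64_t N = 2, H_kv = 8, N_rep = 4, S = 128, D = 64;

    // Grouped 5D inputs; no StaticReshape is required.
    logical_tensor query {0, logical_tensor::data_type::f16,
            {N, H_kv, N_rep, S, D}, logical_tensor::layout_type::strided};
    logical_tensor key {1, logical_tensor::data_type::f16,
            {N, H_kv, 1, S, D}, logical_tensor::layout_type::strided};
    // Scores: (N, H_kv, N_rep, S, S); Key's size-1 dim broadcasts over N_rep.
    logical_tensor scores {2, logical_tensor::data_type::f16,
            {N, H_kv, N_rep, S, S}, logical_tensor::layout_type::strided};

    op qk_matmul {0, op::kind::MatMul, "qk_matmul"};
    // Transpose Key's last two dims: (.., S, D) -> (.., D, S).
    qk_matmul.set_attr<bool>(op::attr::transpose_b, true);
    qk_matmul.add_inputs({query, key});
    qk_matmul.add_output(scores);

    graph g {dnnl::engine::kind::cpu};
    g.add_op(qk_matmul);
    // ... scale, mask, SoftMax, and the second MatMul with Value follow ...
    g.finalize();
    return 0;
}
```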
## Data Types

oneDNN supports the floating-point GQA pattern with data types f32, bf16, and
@@ -77,24 +98,28 @@ platforms follow the general description in @ref dev_guide_data_types.
2. The GQA patterns functionally support all input shapes meeting the shape
   requirements of each operation in the graph.
3. CPU
   - Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
     the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
     Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
     Query and (N, H_kv, 1, S, D) for Key and Value.
   - Optimized implementation is available for OpenMP runtime and Threadpool
     runtime on Intel Architecture Processors.
   - Specifically for OpenMP runtime, the optimized implementation requires `N *
     H_q > 2 * thread number` to get enough parallelism (see the sketch after
     this list).
4. GPU
   - Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
     the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
     Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
     Query and (N, H_kv, 1, S, D) for Key and Value.
   - Optimized implementation is available for floating-point GQA with `f16` and
     `bf16` data types and `D <= 512` on Intel Graphics Products with Intel(R)
     Xe Matrix Extensions (Intel(R) XMX) support.
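For reference, a trivial sketch of the OpenMP parallelism heuristic quoted in
the CPU notes above; `omp_get_max_threads()` stands in for the runtime's thread
number, and the problem sizes are placeholders:

```cpp
#include <cstdio>
#include <omp.h>

int main() {
    // Illustrative problem sizes.
    const long N = 2, H_q = 32;
    const long nthr = omp_get_max_threads();
    // The optimized CPU implementation expects N * H_q > 2 * thread number.
    std::printf("N * H_q = %ld, 2 * threads = %ld -> %s\n", N * H_q, 2 * nthr,
            (N * H_q > 2 * nthr) ? "enough parallelism"
                                 : "below the heuristic");
    return 0;
}
```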

## Example

oneDNN provides a [GQA
example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gqa.cpp)
demonstrating how to construct a 5D floating-point GQA pattern with oneDNN Graph
API on CPU and GPU with different runtimes.
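Sketching the typical downstream flow (see `gqa.cpp` for the complete version):
once the graph is finalized, partitions are retrieved, compiled, and executed.
Engine, stream, and tensor setup are omitted here, and `compile_partitions` is
a hypothetical helper name:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

// Compile every supported partition of a finalized GQA graph. Tensor
// allocation and the execution call are omitted for brevity.
void compile_partitions(dnnl::graph::graph &g, const dnnl::engine &eng) {
    for (auto &part : g.get_partitions()) {
        if (!part.is_supported()) continue; // unsupported ops need a fallback
        auto cp = part.compile(
                part.get_input_ports(), part.get_output_ports(), eng);
        // cp.execute(strm, input_tensors, output_tensors);
        (void)cp;
    }
}
```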

## References
