Similar to how SDPA is supported, the GQA pattern is also defined as a
directed acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
[SDPA pattern](@ref dev_guide_graph_sdpa) to support two types of
floating-point (f32, bf16, and f16) GQA patterns. The blue nodes are required
when defining a GQA pattern, while the brown nodes are optional.

### GQA Pattern with 4D Input
![GQA pattern](images/gqa-4D.png)

Due to the broadcasting semantics of MatMul, implementing GQA often requires
additional tensor manipulation. Specifically, when working with 4D input
tensors, where Query has shape (N, H_q, S, D) and Key/Value have shape
(N, H_kv, S, D), it is necessary to introduce extra StaticReshape operations
to align the tensor dimensions for the MatMul operations. Therefore, the 4D
GQA pattern involves the following differences (a minimal construction sketch
follows the list):
1. The input Query has shape (N, H_q, S, D). It will be reshaped to
   (N, H_kv, N_rep, S, D) by splitting the H_q dimension into H_kv and N_rep.
   The reshaping
   ...
   similarly. Besides that, they have the same definition as described in the
   typical SDPA pattern.
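
For illustration, below is a minimal, hypothetical sketch of how the extra
StaticReshape on Query could be constructed with the oneDNN Graph C++ API.
The dimension values, tensor IDs, and op IDs are made-up examples, not part
of the pattern definition:

```cpp
// Hypothetical sketch: reshape Query from (N, H_q, S, D) to
// (N, H_kv, N_rep, S, D) for the 4D GQA pattern.
#include <vector>
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const int64_t N = 1, H_q = 32, H_kv = 8, S = 128, D = 64;
    const int64_t N_rep = H_q / H_kv; // Query heads per KV head

    // Query in its original 4D shape and the reshaped 5D view.
    logical_tensor query {0, logical_tensor::data_type::f32,
            {N, H_q, S, D}, logical_tensor::layout_type::strided};
    logical_tensor query_5d {1, logical_tensor::data_type::f32,
            {N, H_kv, N_rep, S, D}, logical_tensor::layout_type::strided};

    // StaticReshape splits the H_q dimension into (H_kv, N_rep).
    op reshape_q {0, op::kind::StaticReshape, {query}, {query_5d},
            "reshape_q"};
    reshape_q.set_attr<std::vector<int64_t>>(
            op::attr::shape, {N, H_kv, N_rep, S, D});
    reshape_q.set_attr<bool>(op::attr::special_zero, false);

    graph g(engine::kind::cpu);
    g.add_op(reshape_q);
    g.finalize(); // a full pattern adds the remaining ops before finalize()
    return 0;
}
```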
### GQA Pattern with 5D Input

![GQA pattern](images/gqa-5D.png)
To simplify the process and avoid unnecessary reshapes, oneDNN also supports a
native 5D GQA pattern. In this approach, the input Query, Key, and Value
tensors are already provided in grouped format. A minimal sketch of the first
MatMul follows the list.
1. The input Query has a 5D shape of (N, H_kv, N_rep, S, D).
2. The input Key and Value have a 5D shape of (N, H_kv, 1, S, D).
3. The second MatMul calculates the dot products between the probabilities
   produced by SoftMax and the Value nodes, and generates output with shape
   (N, H_kv, N_rep, S, D).
4. The input scale factor and mask in the pattern also need to meet the
   operations' shape requirements.
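
The snippet below sketches how the grouped 5D inputs could feed the first
MatMul of this pattern; the sizes and IDs are hypothetical, and the scale,
mask, SoftMax, and second MatMul steps are omitted:

```cpp
// Hypothetical sketch: first MatMul of the 5D GQA pattern.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const int64_t N = 1, H_kv = 8, N_rep = 4, S = 128, D = 64;

    // Grouped 5D inputs; Key carries a broadcastable group dimension of 1.
    logical_tensor query {0, logical_tensor::data_type::f32,
            {N, H_kv, N_rep, S, D}, logical_tensor::layout_type::strided};
    logical_tensor key {1, logical_tensor::data_type::f32,
            {N, H_kv, 1, S, D}, logical_tensor::layout_type::strided};
    logical_tensor scores {2, logical_tensor::data_type::f32,
            {N, H_kv, N_rep, S, S}, logical_tensor::layout_type::strided};

    // Q x K^T: Key's size-1 group dimension broadcasts across N_rep.
    op qk {0, op::kind::MatMul, {query, key}, {scores}, "qk"};
    qk.set_attr<bool>(op::attr::transpose_b, true);

    graph g(engine::kind::cpu);
    g.add_op(qk);
    g.finalize(); // scale, mask, SoftMax, and second MatMul omitted here
    return 0;
}
```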
## Data Types
oneDNN supports the floating-point GQA pattern with data types f32, bf16, and
f16. The definition of the data types and support status on different CPU and
GPU platforms follow the general description in @ref dev_guide_data_types.
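
Since data types in oneDNN Graph are specified through the logical tensors of
each operation, a bf16 GQA pattern can be requested as in the hypothetical
snippet below (the shapes follow the 5D grouped layout and are example
values):

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using lt = dnnl::graph::logical_tensor;

// Hypothetical 5D shapes: (N, H_kv, N_rep, S, D) and (N, H_kv, 1, S, D).
const lt query_bf16 {0, lt::data_type::bf16, {1, 8, 4, 128, 64},
        lt::layout_type::strided};
const lt key_bf16 {1, lt::data_type::bf16, {1, 8, 1, 128, 64},
        lt::layout_type::strided};
const lt value_bf16 {2, lt::data_type::bf16, {1, 8, 1, 128, 64},
        lt::layout_type::strided};
```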
## Implementation Limitations

1. ...
2. The GQA patterns functionally support all input shapes meeting the shape
   requirements of each operation in the graph.
3. CPU
   - Optimized implementation is available for 4D and 5D GQA patterns. For
     4D, the shapes are defined as (N, H_q, S, D) for Query and
     (N, H_kv, S, D) for Key and Value. For 5D, the shapes are defined as
     (N, H_kv, N_rep, S, D) for Query and (N, H_kv, 1, S, D) for Key and
     Value.
   - Optimized implementation is available for OpenMP runtime and Threadpool
     runtime on Intel Architecture Processors.
   - Specifically for OpenMP runtime, the optimized implementation requires
     `N * H_q > 2 * thread number` to get enough parallelism.
4. GPU
   - Optimized implementation is available for 4D and 5D GQA patterns. For
     4D, the shapes are defined as (N, H_q, S, D) for Query and
     (N, H_kv, S, D) for Key and Value. For 5D, the shapes are defined as
     (N, H_kv, N_rep, S, D) for Query and (N, H_kv, 1, S, D) for Key and
     Value.
   - Optimized implementation is available for floating-point GQA with `f16`
     and `bf16` data types and `D <= 512` on Intel Graphics Products with
     Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
## Example
oneDNN provides a [GQA
example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gqa.cpp)
demonstrating how to construct a 5D floating-point GQA pattern with oneDNN Graph
API on CPU and GPU with different runtimes.
## References