Similar to how SDPA is supported, the GQA pattern is also defined as a
directed acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
[SDPA pattern](@ref dev_guide_graph_sdpa) to support two types of floating-point
(f32, bf16, and f16) GQA patterns. The blue nodes are required when defining a
GQA pattern while the brown nodes are optional. The key difference between the
two types of GQA patterns lies in whether the input and output tensors have 4D
or 5D shapes. The optional StaticReshape operations convert the tensors between
the 4D and 5D formats when the inputs and outputs are provided in 4D shapes.

![GQA pattern](images/gqa.png)

### GQA Pattern with 4D Input and Output

Due to the broadcasting semantics of MatMul, implementing GQA often requires
additional tensor manipulation. Specifically, when working with 4D input tensors,
where Query has shape (N, H_q, S, D) and Key/Value have shape (N, H_kv, S, D),
it is necessary to introduce extra StaticReshape operations to align tensor
dimensions for the MatMul operations. Therefore, the 4D GQA pattern differs from
the typical SDPA pattern in the following ways, as illustrated in the sketch
after the list:

1. The input Query has shape (N, H_q, S, D). It will be reshaped to (N, H_kv,
   N_rep, S, D) by splitting the H_q dimension into H_kv and N_rep. The
   reshaping is done with a StaticReshape operation.
2. The input Key and Value have shape (N, H_kv, S, D). They are reshaped to
   (N, H_kv, 1, S, D) so that they can be broadcast to the reshaped Query in
   the MatMul operations.
3. The output of the second MatMul has shape (N, H_kv, N_rep, S, D) and is
   reshaped back to (N, H_q, S, D) by combining the H_kv and N_rep dimensions.
4. The input scale factor and mask in the pattern also need to be reshaped
   similarly. Besides that, they have the same definition as described in the
   typical SDPA pattern.
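
Below is a minimal sketch of how the 4D pattern above could be assembled with
the oneDNN Graph C++ API. It is illustrative only: the optional scale and mask
nodes are omitted, the tensor sizes and all variable names are hypothetical,
and the complete, authoritative construction is in the gqa.cpp example linked
at the end of this document.

```cpp
#include <vector>

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;
using dims = logical_tensor::dims;

int main() {
    // Hypothetical sizes for illustration; H_q = H_kv * N_rep.
    const int64_t N = 2, H_kv = 4, N_rep = 8, H_q = H_kv * N_rep, S = 32,
                  D = 64;
    const auto dt = logical_tensor::data_type::f32;
    const auto strided = logical_tensor::layout_type::strided;

    size_t id = 0; // unique ids for logical tensors and ops

    // 4D inputs (items 1 and 2) and the 4D output (item 3).
    logical_tensor q(id++, dt, dims {N, H_q, S, D}, strided);
    logical_tensor k(id++, dt, dims {N, H_kv, S, D}, strided);
    logical_tensor v(id++, dt, dims {N, H_kv, S, D}, strided);
    logical_tensor out(id++, dt, dims {N, H_q, S, D}, strided);

    // 5D intermediates produced and consumed inside the pattern.
    logical_tensor q5(id++, dt, dims {N, H_kv, N_rep, S, D}, strided);
    logical_tensor k5(id++, dt, dims {N, H_kv, 1, S, D}, strided);
    logical_tensor v5(id++, dt, dims {N, H_kv, 1, S, D}, strided);
    logical_tensor scores(id++, dt, dims {N, H_kv, N_rep, S, S}, strided);
    logical_tensor probs(id++, dt, dims {N, H_kv, N_rep, S, S}, strided);
    logical_tensor ctx5(id++, dt, dims {N, H_kv, N_rep, S, D}, strided);

    // Item 1: split H_q into (H_kv, N_rep).
    op reshape_q(id++, op::kind::StaticReshape, {q}, {q5}, "reshape_q");
    reshape_q.set_attr<std::vector<int64_t>>(
            op::attr::shape, {N, H_kv, N_rep, S, D});
    reshape_q.set_attr<bool>(op::attr::special_zero, false);

    // Item 2: insert a singleton group dimension into Key and Value.
    op reshape_k(id++, op::kind::StaticReshape, {k}, {k5}, "reshape_k");
    reshape_k.set_attr<std::vector<int64_t>>(
            op::attr::shape, {N, H_kv, 1, S, D});
    reshape_k.set_attr<bool>(op::attr::special_zero, false);

    op reshape_v(id++, op::kind::StaticReshape, {v}, {v5}, "reshape_v");
    reshape_v.set_attr<std::vector<int64_t>>(
            op::attr::shape, {N, H_kv, 1, S, D});
    reshape_v.set_attr<bool>(op::attr::special_zero, false);

    // First MatMul: Q x K^T, broadcasting K over the N_rep dimension.
    op matmul_qk(id++, op::kind::MatMul, {q5, k5}, {scores}, "matmul_qk");
    matmul_qk.set_attr<bool>(op::attr::transpose_b, true);

    // SoftMax over the last (key sequence) axis of the 5D score tensor.
    op softmax(id++, op::kind::SoftMax, {scores}, {probs}, "softmax");
    softmax.set_attr<int64_t>(op::attr::axis, 4);

    // Second MatMul: probabilities x V, again broadcasting over N_rep.
    op matmul_v(id++, op::kind::MatMul, {probs, v5}, {ctx5}, "matmul_v");

    // Item 3: merge (H_kv, N_rep) back into H_q for the 4D output.
    op reshape_o(id++, op::kind::StaticReshape, {ctx5}, {out}, "reshape_o");
    reshape_o.set_attr<std::vector<int64_t>>(op::attr::shape, {N, H_q, S, D});
    reshape_o.set_attr<bool>(op::attr::special_zero, false);

    graph g(engine::kind::cpu);
    for (const op &o : {reshape_q, reshape_k, reshape_v, matmul_qk, softmax,
                 matmul_v, reshape_o})
        g.add_op(o);
    g.finalize();

    // The whole DAG is expected to be fused into a single GQA partition.
    return g.get_partitions().size() == 1 ? 0 : 1;
}
```

Note how the StaticReshape operations on Key and Value insert the singleton
group dimension that lets MatMul broadcast them across N_rep.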
### GQA Pattern with 5D Input and Output

To simplify the process and avoid unnecessary reshapes, oneDNN also supports a
native 5D GQA pattern. In this approach, the input Query, Key, and Value tensors
are already provided in the grouped format:

1. The input Query has 5D shape: (N, H_kv, N_rep, S, D).
2. The input Key and Value have 5D shape: (N, H_kv, 1, S, D).
3. The second MatMul calculates the dot products between the probabilities after
   SoftMax and the Value nodes and generates output with shape
   (N, H_kv, N_rep, S, D).
4. The input scale factor and mask in the pattern also need to meet the
   operations' shape requirements.
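
For contrast, here is a corresponding sketch of the 5D variant, with the same
hypothetical sizes and variable names as above and the scale and mask nodes
again omitted: the grouped tensors feed the MatMul and SoftMax operations
directly, and no StaticReshape is needed. The gqa.cpp example linked below
demonstrates the full version of this variant.

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;
using dims = logical_tensor::dims;

int main() {
    // Hypothetical sizes for illustration.
    const int64_t N = 2, H_kv = 4, N_rep = 8, S = 32, D = 64;
    const auto dt = logical_tensor::data_type::f32;
    const auto strided = logical_tensor::layout_type::strided;

    size_t id = 0; // unique ids for logical tensors and ops

    // Grouped 5D inputs and output: no StaticReshape operations needed.
    logical_tensor q5(id++, dt, dims {N, H_kv, N_rep, S, D}, strided);
    logical_tensor k5(id++, dt, dims {N, H_kv, 1, S, D}, strided);
    logical_tensor v5(id++, dt, dims {N, H_kv, 1, S, D}, strided);
    logical_tensor scores(id++, dt, dims {N, H_kv, N_rep, S, S}, strided);
    logical_tensor probs(id++, dt, dims {N, H_kv, N_rep, S, S}, strided);
    logical_tensor out5(id++, dt, dims {N, H_kv, N_rep, S, D}, strided);

    // First MatMul: Q x K^T, broadcasting K over the N_rep dimension.
    op matmul_qk(id++, op::kind::MatMul, {q5, k5}, {scores}, "matmul_qk");
    matmul_qk.set_attr<bool>(op::attr::transpose_b, true);

    // SoftMax over the last (key sequence) axis of the 5D score tensor.
    op softmax(id++, op::kind::SoftMax, {scores}, {probs}, "softmax");
    softmax.set_attr<int64_t>(op::attr::axis, 4);

    // Second MatMul: probabilities x V produces the 5D output (item 3).
    op matmul_v(id++, op::kind::MatMul, {probs, v5}, {out5}, "matmul_v");

    graph g(engine::kind::cpu);
    g.add_op(matmul_qk);
    g.add_op(softmax);
    g.add_op(matmul_v);
    g.finalize();

    // The three ops are expected to be fused into a single GQA partition.
    return g.get_partitions().size() == 1 ? 0 : 1;
}
```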
## Data Types
oneDNN supports the floating-point GQA pattern with data types f32, bf16, and
f16. You can specify the data type via the input and output logical tensors'
data type fields for each operation. The definition of the data types and
support status on different CPU and GPU platforms follow the general
description in @ref dev_guide_data_types.
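
As a small, hypothetical illustration of this point: the data type is carried
by each logical tensor, so moving either sketch above from f32 to bf16 or f16
only changes the `data_type` value used when the logical tensors are created.

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    // Creating a logical tensor with bf16 instead of f32; using the same
    // data type for every tensor in the sketches above switches the whole
    // GQA pattern to bf16. Shape values are hypothetical.
    const auto dt = logical_tensor::data_type::bf16;
    logical_tensor q5(0, dt, logical_tensor::dims {2, 4, 8, 32, 64},
            logical_tensor::layout_type::strided);
    (void)q5;
    return 0;
}
```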

## Implementation Limitations

1. oneDNN primitive-based GQA is implemented as the reference implementation on
   both Intel Architecture Processors and Intel Graphics Products.
2. The GQA patterns functionally support all input shapes meeting the shape
   requirements of each operation in the graph.
3. CPU
   - Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
     the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
     Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
     Query and (N, H_kv, 1, S, D) for Key and Value.
   - Optimized implementation is available for OpenMP runtime and Threadpool
     runtime on Intel Architecture Processors.
   - Specifically for OpenMP runtime, the optimized implementation requires
     `N * H_q > 2 * thread number` to get enough parallelism. For example, on a
     machine running 28 OpenMP threads, `N * H_q` should exceed 56.
4. GPU
   - Optimized implementation is available for 4D and 5D GQA patterns. For 4D,
     the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for
     Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for
     Query and (N, H_kv, 1, S, D) for Key and Value.
   - Optimized implementation is available for floating-point GQA with `f16` and
     `bf16` data types and `D <= 512` on Intel Graphics Products with Intel(R)
     Xe Matrix Extensions (Intel(R) XMX) support.
## Example
oneDNN provides a [GQA
example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gqa.cpp)
demonstrating how to construct a 5D floating-point GQA pattern with oneDNN Graph
API on CPU and GPU with different runtimes.
## References