-
Notifications
You must be signed in to change notification settings - Fork 1k
[Graph| tests, example, doc] Add GQA v2 support for implicit causal mask and example, doc update #3409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Graph| tests, example, doc] Add GQA v2 support for implicit causal mask and example, doc update #3409
Changes from all commits
3a02dd7
cd63179
814b9b9
8456adf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -30,14 +30,24 @@ The notations used in the document: | |||||
|
||||||
Similar to how SDPA is supported, the GQA pattern is also defined as a | ||||||
directional acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the | ||||||
[SDPA pattern](@ref dev_guide_graph_sdpa) to support floating-point (f32, bf16, | ||||||
and f16) GQA as follows. The blue nodes are required when defining a GQA pattern | ||||||
while the brown nodes are optional. | ||||||
[SDPA pattern](@ref dev_guide_graph_sdpa) to support two types of floating-point | ||||||
(f32, bf16, and f16) GQA patterns. The blue nodes are required when defining a | ||||||
GQA pattern while the brown nodes are optional. The key difference between the | ||||||
two types of GQA patterns lies in whether the input and output tensors have 4D | ||||||
or 5D shapes. The optional StaticReshape operations are used to convert the tensors | ||||||
between 4D and 5D shape formats, depending on whether the input and output tensors | ||||||
are in 4D shapes. | ||||||
|
||||||
 | ||||||
|
||||||
Compared to a typical SDPA pattern, there are a few differences in the GQA | ||||||
pattern: | ||||||
### GQA Pattern with 4D input and output | ||||||
|
||||||
Due to the broadcasting semantics of MatMul, implementing GQA often requires | ||||||
additional tensor manipulation. Specifically, when working with 4D input tensors, | ||||||
where Query has shape (N, H_q, S, D) and Key/Value have shape (N, H_kv, S, D), | ||||||
it is necessary to introduce extra StaticReshape operations to align tensor | ||||||
dimensions for the MatMul operations. Therefore, the 4D GQA pattern involves the | ||||||
following differences: | ||||||
|
||||||
1. The input Query has shape (N, H_q, S, D). It will be reshaped to (N, H_kv, | ||||||
N_rep, S, D) by splitting H_q dimension into H_kv and N_rep. The reshaping | ||||||
|
@@ -56,6 +66,19 @@ pattern: | |||||
similarly. Besides that, they have the same definition as described in the | ||||||
typical SDPA pattern. | ||||||
|
||||||
### GQA Pattern with 5D input and output | ||||||
|
||||||
To simplify process and avoid unnecessary reshapes, oneDNN also supports native | ||||||
5D GQA pattern. In this approach, the input Query, Key, and Value tensors are | ||||||
already provided in grouped format. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
1. The input Query has 5D shape: (N, H_kv, N_rep, S, D) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
2. The input Key/Value have 5D shape: (N, H_kv, 1, S, D) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The input paramater can contain either a Key or Value? IF it's an either/or situation, then singular verb should be used.
Suggested change
|
||||||
3. The second MatMul calculates the dot products between the probabilities after | ||||||
SoftMax and Value nodes and generates output with shape (N, H_kv, N_rep, S, D). | ||||||
4. The input scale factor and mask in the pattern also need to meet the | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
operations' shape requirement. | ||||||
|
||||||
## Data Types | ||||||
|
||||||
oneDNN supports the floating-point GQA pattern with data types f32, bf16, and | ||||||
|
@@ -77,24 +100,28 @@ platforms follow the general description in @ref dev_guide_data_types. | |||||
2. The GQA patterns functionally support all input shapes meeting the shape | ||||||
requirements of each operation in the graph. | ||||||
3. CPU | ||||||
- Optimized implementation is available for 4D Q/K/V tensors with shape | ||||||
defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for Key and Value. | ||||||
- Optimized implementation is available for 4D and 5D GQA patterns. For 4D, | ||||||
the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for | ||||||
Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for | ||||||
Query and (N, H_kv, 1, S, D) for Key and Value. | ||||||
- Optimized implementation is available for OpenMP runtime and Threadpool | ||||||
runtime on Intel Architecture Processors. | ||||||
- Specifically for OpenMP runtime, the optimized implementation requires `N * | ||||||
H_q > 2 * thread number` to get enough parallelism. | ||||||
4. GPU | ||||||
- Optimized implementation is available for 4D Q/K/V tensors with shape | ||||||
defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for Key and Value. | ||||||
- Optimized implementation is available for floating-point GQA with `f16` | ||||||
data type and `D <= 512` on Intel Graphics Products with Intel(R) Xe Matrix | ||||||
Extensions (Intel(R) XMX) support. | ||||||
- Optimized implementation is available for 4D and 5D GQA patterns. For 4D, | ||||||
the shapes are defined as (N, H_q, S, D) for Query and (N, H_kv, S, D) for | ||||||
Key and Value. For 5D, the shapes are defined as (N, H_kv, N_rep, S, D) for | ||||||
Query and (N, H_kv, 1, S, D) for Key and Value. | ||||||
- Optimized implementation is available for floating-point GQA with `f16` and | ||||||
`bf16` data type and `D <= 512` on Intel Graphics Products with Intel(R) | ||||||
Xe Matrix Extensions (Intel(R) XMX) support. | ||||||
|
||||||
## Example | ||||||
|
||||||
oneDNN provides a [GQA | ||||||
example](https://github.com/uxlfoundation/oneDNN/tree/main/examples/graph/gqa.cpp) | ||||||
demonstrating how to construct a floating-point GQA pattern with oneDNN Graph | ||||||
demonstrating how to construct a 5D floating-point GQA pattern with oneDNN Graph | ||||||
API on CPU and GPU with different runtimes. | ||||||
|
||||||
## References | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.