graph: doc, interface, backend: support SDPA training #3396
base: main
Conversation
The branch was force-pushed from 2828827 to 9e6b119.
| Q/K/V | Dst             | Stats |
|:------|:----------------|:------|
| f32   | f32, bf16, f16  | f32   |
| bf16  | bf16            | f32   |
| f16   | f16             | f32   |
@LuFinch Could you please confirm that the stats tensor is always f32 in PyTorch?
@ElaineBao It would be better to clarify this in the RFC document as well. Thanks.
Confirmed that all the CUDA FlashAttention/EfficientAttention/CudnnAttention backends use fp32 for logsumexp.
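For reference, a rough sketch of how a graph could declare the stats output in f32 while Q/K/V stay in f16. The tensor ids, shapes, and the wiring of stats as a second SoftMax output are illustrative assumptions following the RFC's direction, not code from this PR:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"
using namespace dnnl::graph;

// Illustrative shapes/ids; exposing stats as a second SoftMax output is an
// assumption based on the RFC, not code taken from this PR.
void build_softmax_with_stats(graph &g) {
    logical_tensor score {0, logical_tensor::data_type::f16, {1, 16, 384, 384},
            logical_tensor::layout_type::strided};
    logical_tensor prob {1, logical_tensor::data_type::f16, {1, 16, 384, 384},
            logical_tensor::layout_type::strided};
    // stats (logsumexp) stays f32, even though Q/K/V and dst are f16.
    logical_tensor stats {2, logical_tensor::data_type::f32, {1, 16, 384, 1},
            logical_tensor::layout_type::strided};

    op softmax {0, op::kind::SoftMax, "softmax"};
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.add_input(score);
    softmax.add_output(prob);
    softmax.add_output(stats); // hypothetical second output carrying stats

    g.add_op(softmax);
}
```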
// The order of input logical tensors in inputs is not certain, we need
// to record the input offset in a certain order of ops.
CHECK_BOOL(record_input_offset(sg, inputs));
dims src1_user_dims = ltw(inputs[graph_inport[mm1_src]]).vdims();
ndims = src1_user_dims.size();
VCHECK_SDP_DECOMP(ndims == 4 || ndims == 5, false,
        "Input dims should be 4 or 5, but got %zu", src1_user_dims.size());
VCHECK_SDP_DECOMP(
        outputs.size() == 1, false, "Doesn't support SDPA training yet");
Better to make the error message consistent with the checks.
Suggested change:
-         outputs.size() == 1, false, "Doesn't support SDPA training yet");
+         outputs.size() == 1, false, "does not support multiple outputs");
// At least 3 inputs: Q, K, V
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
        "At least 3 inputs are required");
VCHECK_SDP_PRIMITIVE(outputs.size() == 1, status::unimplemented,
        "Doesn't support SDPA training yet");
Ditto.
}

auto f32_dst = dst;
if (f32_dst->get_logical_tensor().data_type == dnnl::impl::data_type::f32) {
I guess below should work:
Suggested change:
- if (f32_dst->get_logical_tensor().data_type == dnnl::impl::data_type::f32) {
+ if (f32_dst->get_logical_tensor().data_type == data_type::f32) {
        = empty_logical_tensor_with_default_id();
f32_dst = std::make_shared<value_t>(
        *new_softmax_op, 0, softmax_op_out_lt, true);
f32_dst->set_data_type(dnnl::impl::data_type::f32);
Same here. I think there's no need to spell out all the nested namespaces.
@@ -698,6 +698,168 @@ static status_t select_handler(
    return status::success;
}

static status_t softmax_handler(
        const std::shared_ptr<op_t> &op, subgraph_rewriter_t &rewriter) {
What's the purpose of the lowering function?
The main purpose is to compute stats with multiple primitives. If the stats output does not exist, a single softmax primitive is used instead.
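For readers unfamiliar with the decomposition, here is a plain, self-contained C++ sketch of the math such a multi-primitive lowering would produce (reduce-max, subtract, exp, reduce-sum, log). It only illustrates the computation; it is not the backend code:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> row = {1.5f, -0.25f, 3.0f, 0.5f}; // one row of the score matrix
    // reduce-max for numerical stability
    float row_max = *std::max_element(row.begin(), row.end());
    // subtract the max and exponentiate, then reduce-sum
    std::vector<float> ex(row.size());
    float sum = 0.f;
    for (size_t i = 0; i < row.size(); ++i) {
        ex[i] = std::exp(row[i] - row_max);
        sum += ex[i];
    }
    // softmax output
    for (size_t i = 0; i < row.size(); ++i)
        printf("softmax[%zu] = %f\n", i, ex[i] / sum);
    // stats = logsumexp = max + log(sum(exp(x - max))), kept in f32
    float stats = row_max + std::log(sum);
    printf("stats (logsumexp) = %f\n", stats);
    return 0;
}
```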
@@ -177,6 +177,48 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_fusion_gpu)
    return std::make_shared<sdp_base_t<>>();
});

DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion)
Does that mean the inference fusion pattern will be reused for training forward?
Yes, correct
The branch was later force-pushed from 9e6b119 to 6f12739, then from 6f12739 to 4efa167.
switch (exec_arg) {
    case DNNL_ARG_SRC:
        SAFE(::custom::fill_mem(mem, ref_mem, -8, 7), WARN);
How do we decide this data range?
@@ -203,6 +210,13 @@ int ref_partition_t::init_graph_mem(
}

void ref_partition_t::exec_ops(res_t *res) {
    // check if there's softmax backward op in the partition,
    // which will be a candidate for sdpa training backward pattern
    bool has_softmax_backward = std::any_of(partition_ops_ref_.begin(),
May use `static`?
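For context, the detection amounts to a simple scan of the partition's ops. Below is an illustrative standalone version that uses plain strings instead of benchdnn's actual op representation (an assumption for brevity):

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Illustrative only: return true if any op kind in the partition is SoftMaxBackward.
bool contains_softmax_backward(const std::vector<std::string> &op_kinds) {
    return std::any_of(op_kinds.begin(), op_kinds.end(),
            [](const std::string &kind) { return kind == "SoftMaxBackward"; });
}
```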
@@ -263,7 +288,7 @@ void ref_partition_t::exec_ops(res_t *res) {
            || (parent0 == "Multiply" && parent1 == "MatMul");
}

-    if (is_sdpa_pattern || is_gated_mlp_pattern) {
+    if (is_sdpa_pattern || is_sdpa_bwd_pattern || is_gated_mlp_pattern) {
The names are a little confusing: they actually indicate whether the current op needs precision downcasting, not the status of the whole pattern. Could you help improve them?
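One possible way to address this (the name `requires_dt_downcast` below is only an illustration, not taken from the PR):

```cpp
// Hypothetical rename so the flag describes what it controls for the current op:
const bool requires_dt_downcast
        = is_sdpa_pattern || is_sdpa_bwd_pattern || is_gated_mlp_pattern;
if (requires_dt_downcast) {
    // downcast the op's f32 reference data to the partition's lower precision
}
```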
Description
This PR implements the RFC "graph api: support SDPA training" (#3233).