[XLA GPU] int8 convolution on CUDA #30771
Conversation
@timshen91 I have updated the PR to address your comments. Please review and approve it again.
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
// Helper function to create a custom_call instruction to replace the given
// conv instruction
static StatusOr<HloInstruction*> CreateCustomCall(HloInstruction* conv) {
Pull the lambda out into a regular helper function and call it with TF_ASSIGN_OR_RETURN, so that the error status returns to the caller.
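For illustration, a minimal sketch of the shape this refactor takes; the call site, error message, and placeholder return are hypothetical, not the PR's actual code:

```cpp
// Regular helper instead of a lambda: failures propagate as a Status
// inside the StatusOr instead of being dropped.
static StatusOr<HloInstruction*> CreateCustomCall(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
  if (conv->operand_count() < 2) {
    return InternalError("Convolution is missing operands.");
  }
  // ... build the cudnn custom-call instruction here ...
  return conv;  // Placeholder; the real code returns the new custom-call.
}

// Hypothetical call site: TF_ASSIGN_OR_RETURN unwraps the value on
// success and forwards the error Status to this function's caller.
static StatusOr<bool> RunOnConvolution(HloInstruction* conv) {
  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call, CreateCustomCall(conv));
  return custom_call != conv;
}
```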
@yongfeng-nv Could you please resolve the conflicts? Thanks!
@yongfeng-nv gentle ping to resolve the conflicts. Thanks!
Force-pushed from 47997dd to 7e32aa5
instr->operand(0)->shape().element_type() == X);
};
HloInstruction* convert = match.convert_or_clamp->users()[0];
if (match.conv->operand_count() < 4 &&
The hard-coded 4 is not ideal. I am open to defining it in a proper file.
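One possible shape for that, as a sketch only; the constant name and its location are hypothetical, not part of this PR:

```cpp
// Hypothetical shared definition, e.g. next to the custom-call targets:
// a fused cudnn convolution custom-call takes at most input, filter,
// side-input, and bias operands.
constexpr int64 kMaxCudnnConvOperands = 4;
```

The call site above would then read `match.conv->operand_count() < kMaxCudnnConvOperands`.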
@@ -260,9 +260,7 @@ StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(

// Fuse bias/scaling/ReLU with convolution custom call with floating point
// output
StatusOr<bool> RunFuseBiasSideActivation(
    HloModule* module,
    std::unordered_set<const HloInstruction*>& tracked_custom_calls) {
We don't need tracked_custom_calls. In the previous commit, I used it to track partially matched convolutions, especially those with float output.
@@ -319,22 +312,23 @@ absl::optional<ConvWithConvertOrClamp> FindConvWithClamp(
using match::Op;

// The pattern we want to match:
// clamp(broadcast(-128), (get_tuple_element(custom_call(int8_x,
// int8_w, ...)), broadcast(127);
// convert<int8>(clamp(broadcast(-128), (get_tuple_element(custom_call(int8_x,
Since only int8 output needs a clamp, and the clamp must come with a convert, the pattern includes both of them.
The previous implementation required a clamp<-128,127> on the output of an int8-to-float convolution. This matches the current behavior of CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, but the clamp is not supposed to be there for float output. This commit removes the clamp from the corresponding patterns.
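For reference, a rough sketch of what the combined convert-plus-clamp pattern looks like with XLA's pattern matcher; the capture variable and the custom-call sub-pattern here are illustrative assumptions, not lines from the PR:

```cpp
// convert<int8>(clamp(broadcast(-128),
//                     get_tuple_element(custom_call(...)),
//                     broadcast(127)))
HloInstruction* gte = nullptr;
bool matched = Match(
    instr,
    match::Convert(
        match::Clamp(
            match::Broadcast(match::ConstantScalar(-128)),
            match::GetTupleElement(
                &gte, match::Op().WithOpcode(HloOpcode::kCustomCall)),
            match::Broadcast(match::ConstantScalar(127))))
        .WithShape(match::Shape().WithElementType(S8)));
```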
@yongfeng-nv Can you please check the build failures? Thanks!
@gbaned I have submitted a fix for the failure under "Ubuntu Sanity — Internal CI build failed". Please let me know if there is anything else to fix.
PiperOrigin-RevId: 268764354
@yongfeng-nv Merged, but this required a few changes: please try to take care of compiler warnings before submitting. I'll try to upgrade those to errors on the presubmission testing bots.
Can you show me the log with warnings? I will fix them.
…-phase3 PiperOrigin-RevId: 268764354
for (int64 i = 0; i < conv->operand_count(); ++i) {
  check_size_increase(conv->operand(i)->shape(), new_input_shapes[i]);
}
check_size_increase(result_shape, new_result_shape);
}
This is kind of a bad bug: the lambda doesn't return anything, so the lines below don't actually check anything. And apparently no unit test covered this case.
This went undetected for ~9 months.
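A self-contained toy with the same shape as the bug; the byte counts and threshold are invented, not XLA's. The guard lambda returns nothing, so its call sites cannot act on it and the "check" is a no-op:

```cpp
#include <iostream>

int main() {
  // BUG pattern: this lambda was meant to veto the rewrite when the size
  // increase is too large, but it has no return statement at all, so its
  // deduced return type is void and calling it affects nothing.
  auto check_size_increase = [](int old_bytes, int new_bytes) {
    if (new_bytes > old_bytes * 2) {
      std::cout << "size increase too large\n";
      // Intended: return false;  -- missing, so nothing is reported.
    }
  };

  int old_bytes = 100, new_bytes = 300;
  check_size_increase(old_bytes, new_bytes);  // void result: nothing checked
  // ... the padding rewrite would proceed unconditionally here ...

  // FIX pattern: return a bool and act on it at every call site.
  auto size_increase_ok = [](int old_bytes, int new_bytes) -> bool {
    return new_bytes <= old_bytes * 2;
  };
  if (!size_increase_ok(old_bytes, new_bytes)) {
    std::cout << "bailing out instead of padding\n";
    return 0;
  }
  return 0;
}
```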
This is a breakdown of the previous PR (#29158), per @timshen91's suggestion. It doesn't depend on #30761 or #30762.