@@ -556,11 +556,9 @@ CudnnConvBackendConfig GetDefaultBackendConfig() {
556
556
return config;
557
557
}
558
558
559
- // Tries to rewrite a single convolution into a call to cudnn.
560
- StatusOr<bool > RunOnInstruction (HloInstruction* conv) {
561
- CHECK_EQ (conv->opcode (), HloOpcode::kConvolution );
562
-
563
- HloInstruction* custom_call = [&]() -> StatusOr<HloInstruction*> {
559
+ // Helper function to create a custom_call instruction to replace the given
560
+ // conv instruction
561
+ static StatusOr<HloInstruction*> CreateCustomCallHelper (HloInstruction* conv) {
564
562
bool match;
565
563
Window window;
566
564
ConvolutionDimensionNumbers dnums;
@@ -584,13 +582,40 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
584
582
// If all else fails, try a forward convolution.
585
583
if (CanImplementAsCudnnForwardConv (conv)) {
586
584
if (primitive_util::IsIntegralType (
587
- conv->operand (0 )->shape ().element_type ()) &&
588
- conv->shape ().element_type () != F32) {
589
- return Unimplemented (
590
- " The convolution instruction with integer inputs only allows "
591
- " float outputs. Insert a clamp instruction with range [-128, 127) "
592
- " followed by a convert "
593
- " instruction after the convolution instruction for int8 outputs." );
585
+ conv->operand (0 )->shape ().element_type ())) {
586
+ // In addition to replacing a convolution instruction with
587
+ // a custom call, integer convolutions must have this pattern to match
588
+ // CuDNN semantics:
589
+ // conv<InputT=int32, ResultT=int32>(
590
+ // convert<int32>(int8_x), convert<int32>(int8_y))
591
+ // We transform it to:
592
+ // custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
593
+ //
594
+ // We will error out, if the pattern is not found for integer
595
+ // convolution.
596
+ const auto is_int8_to_int32_cast =
597
+ [](const HloInstruction* instr) -> bool {
598
+ return (instr->opcode () == HloOpcode::kConvert &&
599
+ instr->operand (0 )->shape ().element_type () == S8 &&
600
+ instr->shape ().element_type () == S32);
601
+ };
602
+ HloInstruction* input_convert = conv->mutable_operand (0 );
603
+ HloInstruction* kernel_convert = conv->mutable_operand (1 );
604
+ if (conv->shape ().element_type () != S32 ||
605
+ !is_int8_to_int32_cast (input_convert) ||
606
+ !is_int8_to_int32_cast (kernel_convert)) {
607
+ return Unimplemented (
608
+ " Integer convolutions for CuDNN must have this pattern: "
609
+ " conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), "
610
+ " convert<int32>(int8_y))" );
611
+ }
612
+ // Bypass the convert<int32> for both inputs.
613
+ conv->ReplaceOperandWithDifferentShape (
614
+ 0 , input_convert->mutable_operand (0 ));
615
+ conv->parent ()->RemoveInstructionAndUnusedOperands (input_convert);
616
+ conv->ReplaceOperandWithDifferentShape (
617
+ 1 , kernel_convert->mutable_operand (0 ));
618
+ conv->parent ()->RemoveInstructionAndUnusedOperands (kernel_convert);
594
619
}
595
620
return CreateCudnnConv (kCudnnConvForwardCallTarget , conv->shape (),
596
621
conv->mutable_operand (0 ), conv->mutable_operand (1 ),
@@ -600,8 +625,14 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
600
625
}
601
626
602
627
return nullptr ;
603
- }().ValueOrDie ();
628
+ }
629
+
630
+ // Tries to rewrite a single convolution into a call to cudnn.
631
+ StatusOr<bool > RunOnInstruction (HloInstruction* conv) {
632
+ CHECK_EQ (conv->opcode (), HloOpcode::kConvolution );
604
633
634
+ TF_ASSIGN_OR_RETURN (HloInstruction * custom_call,
635
+ CreateCustomCallHelper (conv));
605
636
if (custom_call == nullptr ) {
606
637
return false ;
607
638
}
@@ -612,8 +643,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
612
643
VLOG (1 ) << " Replacing convolution " << conv->ToString () << " with "
613
644
<< custom_call->ToString ();
614
645
615
- // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
616
- // the conv result and replace `conv` with it.
646
+ // The CustomCall returns a tuple (conv_result, scratch_memory). Extract
647
+ // out the conv result and replace `conv` with it.
617
648
TF_RETURN_IF_ERROR (conv->parent ()->ReplaceWithNewInstruction (
618
649
conv,
619
650
HloInstruction::CreateGetTupleElement (conv->shape (), custom_call, 0 )));
0 commit comments