Skip to content

Commit 7e32aa5

Browse files
committed
Pass error status to the caller from RunOnInstruction().
1 parent 32572a3 commit 7e32aa5

File tree

1 file changed

+46
-15
lines changed

1 file changed

+46
-15
lines changed

tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -556,11 +556,9 @@ CudnnConvBackendConfig GetDefaultBackendConfig() {
556556
return config;
557557
}
558558

559-
// Tries to rewrite a single convolution into a call to cudnn.
560-
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
561-
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
562-
563-
HloInstruction* custom_call = [&]() -> StatusOr<HloInstruction*> {
559+
// Helper function to create a custom_call instruction to replace the given
560+
// conv instruction. Returns nullptr when conv matches none of the rewrites.
561+
static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
564562
bool match;
565563
Window window;
566564
ConvolutionDimensionNumbers dnums;
@@ -584,13 +582,40 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
584582
// If all else fails, try a forward convolution.
585583
if (CanImplementAsCudnnForwardConv(conv)) {
586584
if (primitive_util::IsIntegralType(
587-
conv->operand(0)->shape().element_type()) &&
588-
conv->shape().element_type() != F32) {
589-
return Unimplemented(
590-
"The convolution instruction with integer inputs only allows "
591-
"float outputs. Insert a clamp instruction with range [-128, 127) "
592-
"followed by a convert "
593-
"instruction after the convolution instruction for int8 outputs.");
585+
conv->operand(0)->shape().element_type())) {
586+
// Besides being replaced by a custom call, an integer convolution must
587+
// already have the following pattern so that it matches CuDNN
588+
// semantics:
589+
// conv<InputT=int32, ResultT=int32>(
590+
// convert<int32>(int8_x), convert<int32>(int8_y))
591+
// We transform it to:
592+
// custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
593+
//
594+
// We return an Unimplemented error if an integer convolution does not
595+
// match this pattern.
596+
const auto is_int8_to_int32_cast =
597+
[](const HloInstruction* instr) -> bool {
598+
return (instr->opcode() == HloOpcode::kConvert &&
599+
instr->operand(0)->shape().element_type() == S8 &&
600+
instr->shape().element_type() == S32);
601+
};
602+
HloInstruction* input_convert = conv->mutable_operand(0);
603+
HloInstruction* kernel_convert = conv->mutable_operand(1);
604+
if (conv->shape().element_type() != S32 ||
605+
!is_int8_to_int32_cast(input_convert) ||
606+
!is_int8_to_int32_cast(kernel_convert)) {
607+
return Unimplemented(
608+
"Integer convolutions for CuDNN must have this pattern: "
609+
"conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), "
610+
"convert<int32>(int8_y))");
611+
}
612+
// Bypass the convert<int32> for both inputs.
613+
conv->ReplaceOperandWithDifferentShape(
614+
0, input_convert->mutable_operand(0));
615+
conv->parent()->RemoveInstructionAndUnusedOperands(input_convert);
616+
conv->ReplaceOperandWithDifferentShape(
617+
1, kernel_convert->mutable_operand(0));
618+
conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert);
594619
}
595620
return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(),
596621
conv->mutable_operand(0), conv->mutable_operand(1),
@@ -600,8 +625,14 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
600625
}
601626

602627
return nullptr;
603-
}().ValueOrDie();
628+
}
629+
630+
// Tries to rewrite a single convolution into a call to cudnn.
631+
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
632+
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
604633

634+
TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
635+
CreateCustomCallHelper(conv));
605636
if (custom_call == nullptr) {
606637
return false;
607638
}
@@ -612,8 +643,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
612643
VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
613644
<< custom_call->ToString();
614645

615-
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
616-
// the conv result and replace `conv` with it.
646+
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract
647+
// out the conv result and replace `conv` with it.
617648
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
618649
conv,
619650
HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));

0 commit comments

Comments
 (0)