diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc
index 1dfd87f8c3b241..e45db4c6d82057 100755
--- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc
@@ -556,11 +556,9 @@ CudnnConvBackendConfig GetDefaultBackendConfig() {
   return config;
 }
 
-// Tries to rewrite a single convolution into a call to cudnn.
-StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
-  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
-
-  HloInstruction* custom_call = [&]() -> StatusOr<HloInstruction*> {
+// Helper function to create a custom_call instruction to replace the given
+// conv instruction.
+static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
   bool match;
   Window window;
   ConvolutionDimensionNumbers dnums;
@@ -584,13 +582,40 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   // If all else fails, try a forward convolution.
   if (CanImplementAsCudnnForwardConv(conv)) {
     if (primitive_util::IsIntegralType(
-            conv->operand(0)->shape().element_type()) &&
-        conv->shape().element_type() != F32) {
-      return Unimplemented(
-          "The convolution instruction with integer inputs only allows "
-          "float outputs. Insert a clamp instruction with range [-128, 127) "
-          "followed by a convert "
-          "instruction after the convolution instruction for int8 outputs.");
+            conv->operand(0)->shape().element_type())) {
+      // In addition to replacing a convolution instruction with
+      // a custom call, integer convolutions must have this pattern to match
+      // CuDNN semantics:
+      //   conv<InputT=int32, ResultT=int32>(
+      //       convert<int32>(int8_x), convert<int32>(int8_y))
+      // We transform it to:
+      //   custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
+      //
+      // We will error out if the pattern is not found for an integer
+      // convolution.
+      const auto is_int8_to_int32_cast =
+          [](const HloInstruction* instr) -> bool {
+        return (instr->opcode() == HloOpcode::kConvert &&
+                instr->operand(0)->shape().element_type() == S8 &&
+                instr->shape().element_type() == S32);
+      };
+      HloInstruction* input_convert = conv->mutable_operand(0);
+      HloInstruction* kernel_convert = conv->mutable_operand(1);
+      if (conv->shape().element_type() != S32 ||
+          !is_int8_to_int32_cast(input_convert) ||
+          !is_int8_to_int32_cast(kernel_convert)) {
+        return Unimplemented(
+            "Integer convolutions for CuDNN must have this pattern: "
+            "conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), "
+            "convert<int32>(int8_y))");
+      }
+      // Bypass the convert<int32> for both inputs.
+      conv->ReplaceOperandWithDifferentShape(
+          0, input_convert->mutable_operand(0));
+      conv->parent()->RemoveInstructionAndUnusedOperands(input_convert);
+      conv->ReplaceOperandWithDifferentShape(
+          1, kernel_convert->mutable_operand(0));
+      conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert);
     }
     return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(),
                            conv->mutable_operand(0), conv->mutable_operand(1),
@@ -600,8 +625,14 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   }
 
   return nullptr;
-  }().ValueOrDie();
+}
+
+// Tries to rewrite a single convolution into a call to cudnn.
+StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
+  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
 
+  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
+                      CreateCustomCallHelper(conv));
   if (custom_call == nullptr) {
     return false;
   }
@@ -612,8 +643,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
           << custom_call->ToString();
 
-  // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract out
-  // the conv result and replace `conv` with it.
+  // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract
+  // out the conv result and replace `conv` with it.
   TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
       conv,
       HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
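
For context, a minimal HLO-level sketch of the rewrite this diff implements. This snippet is not taken from the commit: the operand names, shapes, and window attributes are made up for illustration, and the custom-call target string shown is the value kCudnnConvForwardCallTarget expands to in this version of XLA.

    // Before: an int8 forward convolution expressed through s8->s32 converts
    // on both operands, producing an s32 result.
    %input.s32  = s32[8,32,32,64] convert(s8[8,32,32,64] %input.s8)
    %kernel.s32 = s32[3,3,64,64] convert(s8[3,3,64,64] %kernel.s8)
    %conv = s32[8,32,32,64] convolution(%input.s32, %kernel.s32),
                window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f

    // After: both converts are bypassed and removed, and the convolution
    // becomes a cuDNN custom call returning (conv_result, scratch_memory);
    // a get-tuple-element on index 0 replaces the original conv.
    %cc = (s32[8,32,32,64], u8[0]) custom-call(%input.s8, %kernel.s8),
              custom_call_target="__cudnn$convForward"
    %gte = s32[8,32,32,64] get-tuple-element(%cc), index=0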