
Pass error status to the caller from RunOnInstruction().
yongfeng-nv committed Aug 12, 2019
1 parent 32572a3 commit 7e32aa5
Showing 1 changed file with 46 additions and 15 deletions.
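The fix replaces an unchecked ValueOrDie() unwrap with TF_ASSIGN_OR_RETURN, so an Unimplemented status raised while rewriting a convolution now propagates to the caller of RunOnInstruction() instead of crashing the compiler. Below is a minimal, self-contained sketch of that propagation pattern; the Status/StatusOr classes and the TryRewrite/RunRewrite functions are simplified stand-ins for illustration, not TensorFlow's actual types.

// Sketch only: toy stand-ins for TensorFlow's Status/StatusOr and for what
// TF_ASSIGN_OR_RETURN expands to. TryRewrite/RunRewrite are hypothetical.
#include <iostream>
#include <optional>
#include <string>
#include <utility>

struct Status {
  Status() = default;
  explicit Status(std::string msg) : ok(false), message(std::move(msg)) {}
  bool ok = true;
  std::string message;
};

template <typename T>
class StatusOr {
 public:
  StatusOr(T value) : value_(std::move(value)) {}          // success
  StatusOr(Status status) : status_(std::move(status)) {}  // failure
  bool ok() const { return status_.ok; }
  const Status& status() const { return status_; }
  const T& value() const { return *value_; }

 private:
  Status status_;
  std::optional<T> value_;
};

// Stands in for CreateCustomCallHelper(): may fail with Unimplemented.
StatusOr<int> TryRewrite(bool supported) {
  if (!supported) return Status("Unimplemented: pattern not found");
  return 1;
}

// Before this commit the helper's result was unwrapped with ValueOrDie(),
// which CHECK-fails on error. Now the status is returned to the caller,
// which is what TF_ASSIGN_OR_RETURN does.
StatusOr<bool> RunRewrite(bool supported) {
  StatusOr<int> rewritten = TryRewrite(supported);
  if (!rewritten.ok()) return rewritten.status();
  return rewritten.value() > 0;
}

int main() {
  std::cout << RunRewrite(true).ok() << "\n";               // prints 1
  std::cout << RunRewrite(false).status().message << "\n";  // prints the error
}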
61 changes: 46 additions & 15 deletions tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc
@@ -556,11 +556,9 @@ CudnnConvBackendConfig GetDefaultBackendConfig() {
   return config;
 }
 
-// Tries to rewrite a single convolution into a call to cudnn.
-StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
-  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
-
-  HloInstruction* custom_call = [&]() -> StatusOr<HloInstruction*> {
+// Helper function to create a custom_call instruction to replace the given
+// conv instruction
+static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
   bool match;
   Window window;
   ConvolutionDimensionNumbers dnums;
@@ -584,13 +582,40 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   // If all else fails, try a forward convolution.
   if (CanImplementAsCudnnForwardConv(conv)) {
     if (primitive_util::IsIntegralType(
-            conv->operand(0)->shape().element_type()) &&
-        conv->shape().element_type() != F32) {
-      return Unimplemented(
-          "The convolution instruction with integer inputs only allows "
-          "float outputs. Insert a clamp instruction with range [-128, 127) "
-          "followed by a convert "
-          "instruction after the convolution instruction for int8 outputs.");
+            conv->operand(0)->shape().element_type())) {
+      // In addition to replacing a convolution instruction with
+      // a custom call, integer convolutions must have this pattern to match
+      // CuDNN semantics:
+      //   conv<InputT=int32, ResultT=int32>(
+      //     convert<int32>(int8_x), convert<int32>(int8_y))
+      // We transform it to:
+      //   custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
+      //
+      // We will error out, if the pattern is not found for integer
+      // convolution.
+      const auto is_int8_to_int32_cast =
+          [](const HloInstruction* instr) -> bool {
+        return (instr->opcode() == HloOpcode::kConvert &&
+                instr->operand(0)->shape().element_type() == S8 &&
+                instr->shape().element_type() == S32);
+      };
+      HloInstruction* input_convert = conv->mutable_operand(0);
+      HloInstruction* kernel_convert = conv->mutable_operand(1);
+      if (conv->shape().element_type() != S32 ||
+          !is_int8_to_int32_cast(input_convert) ||
+          !is_int8_to_int32_cast(kernel_convert)) {
+        return Unimplemented(
+            "Integer convolutions for CuDNN must have this pattern: "
+            "conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), "
+            "convert<int32>(int8_y))");
+      }
+      // Bypass the convert<int32> for both inputs.
+      conv->ReplaceOperandWithDifferentShape(
+          0, input_convert->mutable_operand(0));
+      conv->parent()->RemoveInstructionAndUnusedOperands(input_convert);
+      conv->ReplaceOperandWithDifferentShape(
+          1, kernel_convert->mutable_operand(0));
+      conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert);
     }
     return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(),
                            conv->mutable_operand(0), conv->mutable_operand(1),
@@ -600,8 +625,14 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   }
 
   return nullptr;
-}().ValueOrDie();
+}
+
+// Tries to rewrite a single convolution into a call to cudnn.
+StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
+  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
 
+  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
+                      CreateCustomCallHelper(conv));
   if (custom_call == nullptr) {
     return false;
   }
@@ -612,8 +643,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
           << custom_call->ToString();
 
-  // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
-  // the conv result and replace `conv` with it.
+  // The CustomCall returns a tuple (conv_result, scratch_memory). Extract
+  // out the conv result and replace `conv` with it.
   TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
       conv,
       HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
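For illustration, here is a toy version of the operand-bypass rewrite added in the second hunk; Instr, Type, and the helper functions below are hypothetical stand-ins, not XLA's HloInstruction API. When the convolution produces S32 and both operands are convert casts from S8 to S32, the casts are spliced out so the cuDNN custom call consumes the int8 values directly.

// Toy expression nodes standing in for HloInstruction (sketch only).
#include <cassert>
#include <vector>

enum class Op { kParameter, kConvert, kConvolution };
enum class Type { S8, S32 };

struct Instr {
  Op op;
  Type type;
  std::vector<Instr*> operands;
};

// Mirrors the is_int8_to_int32_cast lambda in the diff.
bool IsInt8ToInt32Cast(const Instr* instr) {
  return instr->op == Op::kConvert && instr->type == Type::S32 &&
         instr->operands[0]->type == Type::S8;
}

// Mirrors the pattern check plus the two ReplaceOperandWithDifferentShape
// calls: on success, the conv reads the int8 parameters directly.
bool BypassConverts(Instr* conv) {
  Instr* input_convert = conv->operands[0];
  Instr* kernel_convert = conv->operands[1];
  if (conv->type != Type::S32 || !IsInt8ToInt32Cast(input_convert) ||
      !IsInt8ToInt32Cast(kernel_convert)) {
    return false;  // the real code returns Unimplemented(...) here
  }
  conv->operands[0] = input_convert->operands[0];
  conv->operands[1] = kernel_convert->operands[0];
  return true;
}

int main() {
  Instr x{Op::kParameter, Type::S8, {}};
  Instr y{Op::kParameter, Type::S8, {}};
  Instr cast_x{Op::kConvert, Type::S32, {&x}};
  Instr cast_y{Op::kConvert, Type::S32, {&y}};
  Instr conv{Op::kConvolution, Type::S32, {&cast_x, &cast_y}};
  assert(BypassConverts(&conv));
  assert(conv.operands[0] == &x && conv.operands[1] == &y);
  return 0;
}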