Skip to content

Commit

Permalink
Merge pull request #30771 from yongfeng-nv:int8-convolution-phase3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 268764354
  • Loading branch information
tensorflower-gardener committed Sep 12, 2019
2 parents 3576dfe + f04619b commit eb94f41
Show file tree
Hide file tree
Showing 19 changed files with 1,106 additions and 174 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/literal.cc
Expand Up @@ -944,6 +944,8 @@ absl::optional<complex128> LiteralBase::GetAsComplex128(
return {Get<complex64>(multi_index)};
case C128:
return {Get<complex128>(multi_index)};
case S8:
return {Get<int8>(multi_index)};
default:
return absl::nullopt;
}
Expand Down
25 changes: 17 additions & 8 deletions tensorflow/compiler/xla/service/gpu/BUILD
Expand Up @@ -716,6 +716,7 @@ tf_cc_test(
deps = [
":cudnn_conv_rewriter",
":ir_emission_utils",
"//tensorflow/compiler/jit:xla_gpu_jit",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo",
Expand Down Expand Up @@ -949,11 +950,12 @@ cc_library(
)

cc_library(
name = "cudnn_conv_pad_for_tensor_cores",
srcs = ["cudnn_conv_pad_for_tensor_cores.cc"],
hdrs = ["cudnn_conv_pad_for_tensor_cores.h"],
name = "cudnn_pad_for_convolutions",
srcs = ["cudnn_pad_for_convolutions.cc"],
hdrs = ["cudnn_pad_for_convolutions.h"],
deps = [
":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
Expand All @@ -963,10 +965,10 @@ cc_library(
)

tf_cc_test(
name = "cudnn_conv_pad_for_tensor_cores_test",
srcs = ["cudnn_conv_pad_for_tensor_cores_test.cc"],
name = "cudnn_pad_for_convolutions_test",
srcs = ["cudnn_pad_for_convolutions_test.cc"],
deps = [
":cudnn_conv_pad_for_tensor_cores",
":cudnn_pad_for_convolutions",
":ir_emission_utils",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
Expand Down Expand Up @@ -1054,6 +1056,7 @@ cc_library(
":cudnn_batchnorm_rewriter",
":cudnn_conv_padding_legalization",
":cudnn_conv_rewriter",
":cudnn_pad_for_convolutions",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
Expand Down Expand Up @@ -1154,10 +1157,10 @@ cc_library(
deps = [
":cublas_gemm_pad_for_tensor_cores",
":cudnn_conv_algorithm_picker",
":cudnn_conv_pad_for_tensor_cores",
":cudnn_conv_padding_legalization",
":cudnn_conv_rewriter",
":cudnn_fused_conv_rewriter",
":cudnn_pad_for_convolutions",
":cusolver_rewriter",
":gemm_algorithm_picker",
":gemm_rewriter",
Expand Down Expand Up @@ -1507,9 +1510,15 @@ cc_library(
tf_cc_test(
name = "cudnn_fused_conv_rewriter_test",
srcs = ["cudnn_fused_conv_rewriter_test.cc"],
tags = tf_cuda_tests_tags(),
tags = [
"noasan",
"nomsan",
"requires-gpu-sm70",
],
deps = [
":cudnn_fused_conv_rewriter",
":ir_emission_utils",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
Expand Down
Expand Up @@ -271,10 +271,10 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(

int64 rng_state = 0;

const auto initialize_buffer = [stream, &result_shape,
&rng_state](DeviceMemoryBase buffer) {
InitializeFloatBuffer(stream, result_shape.element_type(), &rng_state,
buffer);
const auto initialize_buffer = [&stream, &rng_state](
DeviceMemoryBase buffer,
const Shape& buffer_shape) {
InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
};

const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
Expand All @@ -287,13 +287,13 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
TF_ASSIGN_OR_RETURN(auto buffer,
input_output_allocator.AllocateBytes(
ShapeUtil::ByteSizeOf(operand->shape())));
initialize_buffer(buffer);
initialize_buffer(buffer, operand->shape());
operand_buffers.push_back(buffer);
}
TF_ASSIGN_OR_RETURN(auto result_buffer,
input_output_allocator.AllocateBytes(
ShapeUtil::ByteSizeOf(result_shape)));
initialize_buffer(result_buffer);
initialize_buffer(result_buffer, result_shape);

TF_ASSIGN_OR_RETURN(auto backend_config,
instr->backend_config<CudnnConvBackendConfig>());
Expand Down

This file was deleted.

108 changes: 75 additions & 33 deletions tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc
Expand Up @@ -619,43 +619,85 @@ CudnnConvBackendConfig GetDefaultBackendConfig() {
return config;
}

// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
// Helper function to create a custom_call instruction to replace the given
// conv instruction
static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
bool match;
Window window;
ConvolutionDimensionNumbers dnums;
HloInstruction* rhs;
HloInstruction* lhs;

std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
conv->mutable_operand(0), rhs, window, dnums,
conv->feature_group_count(), conv->metadata());
}

HloInstruction* custom_call = [&]() -> HloInstruction* {
bool match;
Window window;
ConvolutionDimensionNumbers dnums;
HloInstruction* rhs;
HloInstruction* lhs;

std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
conv->mutable_operand(0), rhs, window, dnums,
conv->feature_group_count(), conv->metadata());
}
std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv);
if (match) {
return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
lhs, conv->mutable_operand(1), window, dnums,
conv->feature_group_count(), conv->metadata());
}

std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv);
if (match) {
return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
lhs, conv->mutable_operand(1), window, dnums,
conv->feature_group_count(), conv->metadata());
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
if (primitive_util::IsIntegralType(
conv->operand(0)->shape().element_type())) {
// In addition to replacing a convolution instruction with
// a custom call, integer convolutions must have this pattern to match
// CuDNN semantics:
// conv<InputT=int32, ResultT=int32>(
// convert<int32>(int8_x), convert<int32>(int8_y))
// We transform it to:
// custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
//
// We will error out, if the pattern is not found for integer
// convolution.
const auto is_int8_to_int32_cast =
[](const HloInstruction* instr) -> bool {
return (instr->opcode() == HloOpcode::kConvert &&
instr->operand(0)->shape().element_type() == S8 &&
instr->shape().element_type() == S32);
};
HloInstruction* input_convert = conv->mutable_operand(0);
HloInstruction* kernel_convert = conv->mutable_operand(1);
if (conv->shape().element_type() != S32 ||
!is_int8_to_int32_cast(input_convert) ||
!is_int8_to_int32_cast(kernel_convert)) {
return Unimplemented(
"Integer convolutions for CuDNN must have this pattern: "
"conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), "
"convert<int32>(int8_y))");
}
// Bypass the convert<int32> for both inputs.
TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
0, input_convert->mutable_operand(0)));
TF_RETURN_IF_ERROR(
conv->parent()->RemoveInstructionAndUnusedOperands(input_convert));
TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
1, kernel_convert->mutable_operand(0)));
TF_RETURN_IF_ERROR(
conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert));
}
return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(),
conv->mutable_operand(0), conv->mutable_operand(1),
conv->window(),
conv->convolution_dimension_numbers(),
conv->feature_group_count(), conv->metadata());
}

// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(),
conv->mutable_operand(0), conv->mutable_operand(1),
conv->window(),
conv->convolution_dimension_numbers(),
conv->feature_group_count(), conv->metadata());
}
return nullptr;
}

return nullptr;
}();
// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);

TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
CreateCustomCallHelper(conv));
if (custom_call == nullptr) {
return false;
}
Expand All @@ -666,8 +708,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
<< custom_call->ToString();

// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
// the conv result and replace `conv` with it.
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract
// out the conv result and replace `conv` with it.
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
conv,
HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h
Expand Up @@ -24,6 +24,14 @@ namespace gpu {

// Rewrites plain convolutions, backwards-filter convolutions, and
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
// For integer convolution, it requires the following pattern:
// conv<InputT=int32, ResultT=int32>(
// convert<int32>(int8_x), convert<int32>(int8_y))
// We transform it to:
// custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
// Note that this pattern is necessary but not sufficient to map convolutions
// to CuDNN. More patterns will be matched in cudnn_fused_conv_rewriter.

class CudnnConvRewriter : public HloModulePass {
public:
absl::string_view name() const override { return "cudnn-conv-rewriter"; }
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc
Expand Up @@ -711,6 +711,21 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) {
0));
}

// Check that a forward convolution whose operands are plain s8 parameters
// (i.e. without the convert<int32>(s8) pattern that cuDNN integer
// convolutions require) is rejected by the rewriter.
TEST_F(CudnnConvRewriterTest, TestForwardInt8Convolution) {
  // NOTE: absl::StrFormat was a no-op here — the string has no format
  // arguments — so use the raw string literal directly.
  const string module_str = R"(
  HloModule Test
  ENTRY Test {
  input = s8[1,2,3,3] parameter(0)
  filter = s8[3,3,2,5] parameter(1)
  ROOT conv = s8[1,5,3,3] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // The pass must return an error status (not merely "no change"): the
  // integer-convolution operands are not int8->int32 converts.
  ASSERT_FALSE(CudnnConvRewriter().Run(m.get()).ok());
}
} // anonymous namespace
} // namespace gpu
} // namespace xla

0 comments on commit eb94f41

Please sign in to comment.