[XLA:GPU] Minor refactoring to clean up argument passing in `conv_algorithm_picker`.

The `PickBest...` methods have more arguments than necessary. This change removes the unnecessary arguments, pushing some logic into the leaf methods. This makes the top-level `PickBestAlgorithmNoCache` simpler and symmetrical across the different platforms (a minimal sketch of the pattern follows the changed-files summary below).

PiperOrigin-RevId: 626162732
dimitar-asenov authored and tensorflower-gardener committed Apr 18, 2024
1 parent 0e9cf5f commit 0731ae2
Showing 2 changed files with 17 additions and 30 deletions.
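
A minimal, self-contained C++ sketch of the pattern described in the commit message: the top-level dispatcher passes only the instruction, and each leaf method derives what it needs from the stored config instead of receiving pre-built arguments. All names below are invented for illustration; they are not the actual XLA types.

#include <iostream>
#include <string>
#include <utility>

struct Config {
  std::string model_str = "sm_80";
};

class Picker {
 public:
  explicit Picker(Config config) : config_(std::move(config)) {}

  // Top-level dispatch is symmetrical: both leaves take only `instr`.
  std::string PickBestNoCache(const std::string& instr, bool is_rocm) {
    return is_rocm ? PickBestNoCacheRocm(instr) : PickBestNoCacheCuda(instr);
  }

 private:
  std::string PickBestNoCacheCuda(const std::string& instr) {
    // The cache key is built here in the leaf, not by the caller.
    std::string key = config_.model_str + ":" + instr;
    return "cuda algorithm for " + key;
  }

  std::string PickBestNoCacheRocm(const std::string& instr) {
    // Likewise, any allocator/config lookup happens inside the leaf.
    return "rocm algorithm for " + config_.model_str + ":" + instr;
  }

  Config config_;
};

int main() {
  Picker picker(Config{});
  std::cout << picker.PickBestNoCache("conv.1", /*is_rocm=*/false) << "\n";
}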
40 changes: 15 additions & 25 deletions third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
@@ -395,7 +395,6 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
 
 absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithmNoCache(
     const HloCustomCallInstruction* instr) {
-  AutotuneCacheKey key(config_.GetModelStr(), *instr);
   if (config_.IsDeviceless()) {
     // Return an autotune result with algo id -1, which means that we autotune
     // at runtime.
@@ -423,24 +422,16 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithmNoCache(
         "Failed to synchronize GPU for autotuning conv instruction");
   }
 
-  // allocator either points to this->allocator_ or, if that's null, to a
-  // se::StreamExecutorMemoryAllocator for stream_exec.
-  se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
-
   absl::StatusOr<AutotuneResult> result_or(Internal("Unknown platform."));
   // Check StreamExecutor on which platform it is. ROCm and Cuda implementation
   // have diverged. Specifically, we need to make sure redzone allocator related
   // utilities are not used in ROCm routine
   se::Platform::Id platform_id = stream_exec->platform()->id();
   if (platform_id == se::rocm::kROCmPlatformId) {
-    result_or = PickBestAlgorithmNoCacheRocm(instr, allocator);
+    result_or = PickBestAlgorithmNoCacheRocm(instr);
   } else if (platform_id == se::cuda::kCudaPlatformId) {
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-    DebugOptions debug_opts = instr->GetModule()->config().debug_options();
-    TF_ASSIGN_OR_RETURN(
-        AutotuneRuntimeArguments runtime_arguments,
-        AutotuneRuntimeArguments::FromInstruction(instr, config_, debug_opts));
-    result_or = PickBestAlgorithmNoCacheCuda(instr, key, runtime_arguments);
+    result_or = PickBestAlgorithmNoCacheCuda(instr);
 #endif
   }
 
@@ -506,7 +497,7 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
   AlgorithmDesc alg_key(alg.algo_id(), alg.tensor_ops_enabled(), std::nullopt);
 
   std::string instr_str = instruction_info.has_value()
-                              ? instruction_info->GetHlo().data()
+                              ? std::string(instruction_info->GetHlo())
                               : "<unknown>";
 
   if (absl::c_linear_search(disabled_algos, alg_key)) {
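
An aside on the `.data()` to `std::string(...)` change in the hunk above: assuming `GetHlo()` returns a string view (its return type is not shown in this diff), `.data()` is not guaranteed to point at a null-terminated string, whereas constructing a `std::string` from the view copies exactly the viewed bytes. A standalone sketch of the pitfall, with made-up data:

#include <iostream>
#include <string>
#include <string_view>

int main() {
  std::string backing = "conv.1 = custom-call(...)";
  std::string_view hlo(backing.data(), 6);  // a view of just "conv.1"

  std::string from_data(hlo.data());  // reads until the next '\0': too much
  std::string from_view(hlo);         // copies exactly the 6 viewed bytes

  std::cout << from_data << "\n";  // prints the whole backing string
  std::cout << from_view << "\n";  // prints "conv.1"
}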
@@ -748,30 +739,27 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
 
 absl::StatusOr<AutotuneResult>
 GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
-    const HloCustomCallInstruction* instr,
-    std::optional<AutotuneCacheKey> instruction_info,
-    const AutotuneRuntimeArguments& runtime_arguments) {
-  se::StreamExecutor* stream_exec = config_.GetExecutor();
-
-  std::string instr_str = instruction_info.has_value()
-                              ? instruction_info->GetHlo().data()
-                              : "<unknown>";
-
+    const HloCustomCallInstruction* instr) {
+  AutotuneCacheKey instruction_info{config_.GetModelStr(), *instr};
+  std::string instr_str(instruction_info.GetHlo());
   XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
       "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr_str));
 
   const DebugOptions& debug_options =
-      runtime_arguments.hlo_module_config.debug_options();
-
+      instr->GetModule()->config().debug_options();
   const bool crash_on_checking_failure =
       debug_options.xla_gpu_crash_on_verification_failures();
 
   std::string blas_version;
+  se::StreamExecutor* stream_exec = config_.GetExecutor();
   if (auto* blas = stream_exec->AsBlas()) {
     (void)blas->GetVersion(&blas_version);
   }
 
   absl::Span<const AlgorithmDesc> disabled_algos;
+  TF_ASSIGN_OR_RETURN(
+      AutotuneRuntimeArguments runtime_arguments,
+      AutotuneRuntimeArguments::FromInstruction(instr, config_, debug_options));
   if (runtime_arguments.canonical_hlo.has_value()) {
     disabled_algos = GetDisabledConvAlgorithms(
         GetComputeCapability(stream_exec), GetCudnnVersion(stream_exec),
@@ -884,8 +872,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
 
 absl::StatusOr<AutotuneResult>
 GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
-    const HloCustomCallInstruction* instr,
-    se::DeviceMemoryAllocator* allocator) {
+    const HloCustomCallInstruction* instr) {
   XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
       "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
 
@@ -901,6 +888,9 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
   const auto device_ordinal = stream_exec->device_ordinal();
   std::vector<se::DeviceMemoryBase> operand_buffers;
 
+  // allocator either points to this->allocator_ or, if that's null, to a
+  // se::StreamExecutorMemoryAllocator for stream_exec.
+  se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
   ScratchAllocator input_output_allocator(device_ordinal, allocator);
   TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
   const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
7 changes: 2 additions & 5 deletions third_party/xla/xla/service/gpu/conv_algorithm_picker.h
@@ -144,14 +144,11 @@ class GpuConvAlgorithmPicker : public HloModulePass {
 
   // Pick the best algorithm for CUDA platform.
   absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheCuda(
-      const HloCustomCallInstruction* instr,
-      std::optional<AutotuneCacheKey> instruction_info,
-      const AutotuneRuntimeArguments& runtime_arguments);
+      const HloCustomCallInstruction* instr);
 #endif
 
   absl::StatusOr<AutotuneResult> PickBestAlgorithmNoCacheRocm(
-      const HloCustomCallInstruction* instr,
-      se::DeviceMemoryAllocator* allocator);
+      const HloCustomCallInstruction* instr);
 
  private:
   AutotuneConfig config_;
