Skip to content

Commit

Permalink
PR #12991: [GPU] Refactor GEMM fusion autotuner.
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#12991

Copybara import of the project:

--
f6e85691b51bdd17fc6167ad8f8cb00d7411fec9 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Replace a hash map by a vector of pairs in GEMM fusion autotuner.

This will be needed for multi-host autotuning to split tasks between hosts reliably.

--
44682f67eb6e9c521f381a7c3dd93fee95f59366 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU][NFC] Refactor for loop iterators.

Merging this change closes #12991

PiperOrigin-RevId: 636506422
  • Loading branch information
sergachev authored and tensorflower-gardener committed May 23, 2024
1 parent 59c9b09 commit 2dda854
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 37 deletions.
50 changes: 19 additions & 31 deletions third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ namespace xla {
namespace gpu {

using Config = GemmFusionAutotunerImpl::Config;
using TilingConfigs = GemmFusionAutotunerImpl::TilingConfigs;
using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;

namespace {
Expand Down Expand Up @@ -204,16 +205,13 @@ class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor {
AutotuneConfig config_;
};

using TilingConfigsMap =
absl::flat_hash_map<const HloFusionInstruction*, std::vector<Config>>;

class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {
public:
explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl)
: impl_(impl) {}

// Find configurations to tune.
absl::StatusOr<TilingConfigsMap> CollectGemmConfigSets(
absl::StatusOr<TilingConfigs> CollectGemmConfigSets(
const HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
error_out_on_cache_miss_ =
Expand Down Expand Up @@ -242,9 +240,9 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {

AutotuneCacheKey key = AutotunerUtil::GetKey(hlo, impl_->GetConfig());

auto insertion_result = fusion_count_map_.insert({key, 1});
if (!insertion_result.second) {
++(insertion_result.first->second);
auto [iterator, inserted] = fusion_count_map_.insert({key, 1});
if (!inserted) {
++(iterator->second);
}

if (AutotunerUtil::IsInCache(key) || handled_fusions_.contains(key)) {
Expand All @@ -265,8 +263,7 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {

TF_ASSIGN_OR_RETURN(std::vector<Config> configs,
impl_->GenerateConfigs(*fusion));
TF_RET_CHECK(
gemm_config_sets_.insert({fusion, std::move(configs)}).second);
gemm_config_sets_.push_back({fusion, std::move(configs)});
}

handled_fusions_.insert(key);
Expand All @@ -280,7 +277,7 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault {
private:
bool error_out_on_cache_miss_;
GemmFusionAutotunerImpl* impl_;
TilingConfigsMap gemm_config_sets_;
TilingConfigs gemm_config_sets_;
AutoTuneCacheKeyCount fusion_count_map_;
absl::flat_hash_set<AutotuneCacheKey> handled_fusions_;
};
Expand Down Expand Up @@ -699,10 +696,8 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
absl::StatusOr<absl::flat_hash_map<
const HloFusionInstruction*,
std::vector<GemmFusionAutotunerImpl::ExecutableCandidate>>>
GemmFusionAutotunerImpl::CompileAll(
AutotunerCompileUtil& compile_util,
const absl::flat_hash_map<const HloFusionInstruction*, std::vector<Config>>&
task) {
GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util,
const TilingConfigs& task) {
tsl::profiler::ScopedAnnotation annotation("XlaAutotunerCompilation");
absl::Mutex results_mu;
absl::flat_hash_map<const HloFusionInstruction*,
Expand All @@ -714,8 +709,8 @@ GemmFusionAutotunerImpl::CompileAll(

const int log_every_n = GetLogEveryN();
int64_t config_count = 0;
for (const auto& key_value : task) {
config_count += key_value.second.size();
for (const auto& [unused, configs] : task) {
config_count += configs.size();
}

std::atomic<int> done_count = 0;
Expand Down Expand Up @@ -820,10 +815,7 @@ GemmFusionAutotunerImpl::CompileAll(
<< task.size() << " fusions on a single thread.";
}

for (const auto& key_value : task) {
const HloFusionInstruction* fusion = key_value.first;
const auto& gemm_config_set = key_value.second;

for (const auto& [fusion, gemm_config_set] : task) {
VLOG(10) << "Compiling fusion: " << fusion->name();
VLOG(10) << "Dumping fusion computation: "
<< fusion->called_computation()->ToString();
Expand Down Expand Up @@ -1060,28 +1052,24 @@ absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts,
}

absl::Status GemmFusionAutotunerImpl::Autotune(
AutotunerCompileUtil& compile_util,
const absl::flat_hash_map<const HloFusionInstruction*, std::vector<Config>>&
gemm_config_sets,
AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
AutoTuneCacheKeyCount fusion_count_map) {
TF_ASSIGN_OR_RETURN(auto executable_sets,
CompileAll(compile_util, gemm_config_sets));

// Sort the candidates to make their execution order well-defined for each
// fusion.
for (auto& key_value : executable_sets) {
absl::c_sort(key_value.second, [](const auto& a, const auto& b) {
for (auto& [unused, candidates] : executable_sets) {
absl::c_sort(candidates, [](const auto& a, const auto& b) {
return a.config < b.config;
});
}

AutotuningLogs autotuning_logs;
int fusion_id = 0;
for (const auto& key_value : executable_sets) {
const HloFusionInstruction* fusion = key_value.first;
TF_ASSIGN_OR_RETURN(
std::vector<AutotuneResult> results,
Profile(compile_util, *fusion, /*candidates=*/key_value.second));
for (const auto& [fusion, candidates] : executable_sets) {
TF_ASSIGN_OR_RETURN(std::vector<AutotuneResult> results,
Profile(compile_util, *fusion, candidates));

// The reference config (if it exists) will be the first in the results,
// due to how sorting the variants work.
Expand Down Expand Up @@ -1143,7 +1131,7 @@ absl::StatusOr<bool> GemmFusionAutotuner::Run(
GemmFusionAutotunerImpl autotuner(config_, toolkit_version_, debug_options,
thread_pool_);
GemmConfigSetCollector gemm_config_set_collector(&autotuner);
TF_ASSIGN_OR_RETURN(TilingConfigsMap gemm_config_sets,
TF_ASSIGN_OR_RETURN(TilingConfigs gemm_config_sets,
gemm_config_set_collector.CollectGemmConfigSets(
module, execution_threads));

Expand Down
10 changes: 4 additions & 6 deletions third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class GemmFusionAutotunerImpl {
bool operator<(const CuDnnConfig& other) const;
};
using Config = std::variant<CuBlasConfig, CuDnnConfig, TritonGemmConfig>;
using TilingConfigs =
std::vector<std::pair<const HloFusionInstruction*, std::vector<Config>>>;

struct ExecutableCandidate {
Config config;
Expand All @@ -102,9 +104,7 @@ class GemmFusionAutotunerImpl {
// Compile all executables for all fusions.
absl::StatusOr<absl::flat_hash_map<const HloFusionInstruction*,
std::vector<ExecutableCandidate>>>
CompileAll(AutotunerCompileUtil& compile_util,
const absl::flat_hash_map<const HloFusionInstruction*,
std::vector<Config>>& task);
CompileAll(AutotunerCompileUtil& compile_util, const TilingConfigs& task);

// Profile all executables for a fusion.
absl::StatusOr<std::vector<AutotuneResult>> Profile(
Expand All @@ -113,9 +113,7 @@ class GemmFusionAutotunerImpl {

// Autotune and save the results to the autotuning cache.
absl::Status Autotune(
AutotunerCompileUtil& compile_util,
const absl::flat_hash_map<const HloFusionInstruction*,
std::vector<Config>>& gemm_config_sets,
AutotunerCompileUtil& compile_util, const TilingConfigs& gemm_config_sets,
absl::flat_hash_map<AutotuneCacheKey, uint64_t> fusion_count_map);

// Helper methods.
Expand Down

0 comments on commit 2dda854

Please sign in to comment.