Skip to content

Commit

Permalink
Add virtual method IsInlineableCallOp to CallInliner to allow sub…
Browse files Browse the repository at this point in the history
…classes to change which call instructions to inline. And clean up `#include`s in `call_inliner.cc`.

PiperOrigin-RevId: 624170852
  • Loading branch information
bartchr808 authored and tensorflower-gardener committed Apr 12, 2024
1 parent bbb92f0 commit ea461c3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
9 changes: 9 additions & 0 deletions third_party/xla/xla/service/BUILD
Expand Up @@ -969,10 +969,19 @@ cc_library(
":hlo_dce",
":hlo_domain_isolator",
":hlo_pass",
"//xla:status",
"//xla:status_macros",
"//xla:statusor",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)

Expand Down
24 changes: 20 additions & 4 deletions third_party/xla/xla/service/call_inliner.cc
Expand Up @@ -16,14 +16,26 @@ limitations under the License.
#include "xla/service/call_inliner.h"

#include <memory>

#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding_metadata.h"
#include "xla/service/call_graph.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_domain_isolator.h"
#include "xla/status.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -136,6 +148,11 @@ CallInliner::Inline(HloInstruction* call) {
return visitor.ConsumeInstructionMap();
}

bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
return instruction->opcode() == HloOpcode::kCall &&
!instruction->parent()->IsAsyncComputation();
}

absl::StatusOr<bool> CallInliner::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand All @@ -156,8 +173,7 @@ absl::StatusOr<bool> CallInliner::Run(
// used for parallel device computation.
// TODO(b/229887502): update the inliner to ignore only parallel
// device type async call instead of all.
if (instruction->opcode() == HloOpcode::kCall &&
!instruction->parent()->IsAsyncComputation()) {
if (IsInlineableCallOp(instruction)) {
const auto& callees = instruction->called_computations();
TF_RET_CHECK(callees.size() == 1);
if (!single_call_site_ || call_graph->GetNode(instruction->to_apply())
Expand All @@ -182,7 +198,7 @@ absl::StatusOr<bool> CallInliner::Run(
// Run DCE to remove called computations which are now becoming unused.
// This can result then in problems if within the called computation, there
// were send/recv instructions, which the module group verifier will flag as
// error findingthe same channel ID used for multiple send/recv
// error finding the same channel ID used for multiple send/recv
// instructions.
TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status());
}
Expand Down
10 changes: 9 additions & 1 deletion third_party/xla/xla/service/call_inliner.h
Expand Up @@ -17,8 +17,12 @@ limitations under the License.
#define XLA_SERVICE_CALL_INLINER_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/statusor.h"

namespace xla {

Expand Down Expand Up @@ -48,6 +52,10 @@ class CallInliner : public HloModulePass {
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

// Returns true if the instruction is a kCall operation and is eligible for
// inlining.
virtual bool IsInlineableCallOp(HloInstruction* instruction) const;

private:
bool single_call_site_;
bool update_domain_;
Expand Down

0 comments on commit ea461c3

Please sign in to comment.