Skip to content

Commit

Permalink
Introduce MakeConstantWithShape in hlo_creation_util file.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635641775
  • Loading branch information
farzinhoushmand authored and tensorflower-gardener committed May 21, 2024
1 parent 32e5b6f commit 97a794b
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 18 deletions.
2 changes: 2 additions & 0 deletions third_party/xla/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2033,6 +2033,7 @@ cc_library(
"//xla/client/lib:comparators",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -3247,6 +3248,7 @@ cc_library(
":call_inliner",
":collective_ops_utils",
":flatten_call_graph",
":hlo_creation_utils",
":hlo_cse",
":hlo_pass",
":pattern_matcher",
Expand Down
17 changes: 17 additions & 0 deletions third_party/xla/xla/service/hlo_creation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -869,4 +870,20 @@ HloInstruction* ExpandDegenerateReshape(HloInstruction* inst) {
return nullptr;
}

std::unique_ptr<HloInstruction> MakeConstantWithShape(const Shape& shape,
int64_t value) {
return primitive_util::PrimitiveTypeSwitch<std::unique_ptr<HloInstruction>>(
[&](auto literal_constant) -> std::unique_ptr<HloInstruction> {
if constexpr (primitive_util::IsIntegralType(literal_constant)) {
using NativeT = primitive_util::NativeTypeOf<literal_constant>;
auto constant = HloInstruction::CreateConstant(
LiteralUtil::CreateR0(static_cast<NativeT>(value)));
*constant->mutable_shape() = shape;
return std::move(constant);
}
LOG(FATAL) << "Literal is of non-integral type";
},
shape.element_type());
}

} // namespace xla
4 changes: 4 additions & 0 deletions third_party/xla/xla/service/hlo_creation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ absl::StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
// adding and removing reshapes that changes only a single dimension.
HloInstruction* ExpandDegenerateReshape(HloInstruction* inst);

// Creates an integral constant with the given shape and integer value.
std::unique_ptr<HloInstruction> MakeConstantWithShape(const Shape& shape,
int64_t value);

} // namespace xla

#endif // XLA_SERVICE_HLO_CREATION_UTILS_H_
21 changes: 3 additions & 18 deletions third_party/xla/xla/service/while_loop_unroller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ limitations under the License.
#include "xla/service/call_inliner.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/flatten_call_graph.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_cse.h"
#include "xla/service/hlo_pass_fix.h"
#include "xla/service/pattern_matcher.h"
Expand All @@ -67,22 +68,6 @@ const int kUnrollTripCountThreshold = 64;
const int kUnrollInstructionCountThreshold = 800;
const int kUnrollExpandFactorThreshold = 10000;

std::unique_ptr<HloInstruction> GetConstantWithShape(const Shape& shape,
int64_t value) {
return primitive_util::PrimitiveTypeSwitch<std::unique_ptr<HloInstruction>>(
[&](auto literal_constant) -> std::unique_ptr<HloInstruction> {
if constexpr (primitive_util::IsIntegralType(literal_constant)) {
using NativeT = primitive_util::NativeTypeOf<literal_constant>;
auto constant = HloInstruction::CreateConstant(
LiteralUtil::CreateR0(static_cast<NativeT>(value)));
*constant->mutable_shape() = shape;
return std::move(constant);
}
LOG(FATAL) << "literal is of non-integral type";
},
shape.element_type());
}

// Helper function to create a condition for a single iteration while loop in
// the form of 'i <= init_value' where i is the induction variable.
std::unique_ptr<HloComputation> MakeTrivialLoopCondition(
Expand All @@ -99,7 +84,7 @@ std::unique_ptr<HloComputation> MakeTrivialLoopCondition(
param_instruction.value(), induction_idx));

HloInstruction* init_value_constant = condition_builder.AddInstruction(
GetConstantWithShape(indvar_instruction->shape(), init_value));
MakeConstantWithShape(indvar_instruction->shape(), init_value));

return condition_builder.Build(
condition_builder.AddInstruction(HloInstruction::CreateCompare(
Expand Down Expand Up @@ -155,7 +140,7 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op,
}

HloInstruction* induction_value_constant = while_body_clone->AddInstruction(
GetConstantWithShape(induction_var_hlo->shape(), induction_value));
MakeConstantWithShape(induction_var_hlo->shape(), induction_value));

// Finds all the uses of induction var within the while body and replace it
// with the constant.
Expand Down

0 comments on commit 97a794b

Please sign in to comment.