Commit
Simplify LogicalBufferStruct constructor.
Reverts fd5333c

PiperOrigin-RevId: 621256228
tensorflower-gardener committed Apr 2, 2024
1 parent 18a8957 commit aae698e
Showing 8 changed files with 115 additions and 65 deletions.
@@ -38,6 +38,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/layout_util.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
@@ -57,12 +58,11 @@ using ::xla::LogicalBufferProto;
using ::xla::Shape;
using ::xla::ShapeUtil;

const Shape* ResolveShapeIndex(const Shape* shape,
absl::Span<const int64_t> shape_index) {
for (int64_t value : shape_index) {
shape = &shape->tuple_shapes(value);
}
return shape;
Shape ResolveShapeIndex(const xla::ShapeProto& shape_proto,
absl::Span<const int64_t> shape_index) {
if (shape_index.empty()) return Shape(shape_proto);
// Choosing the last subshape to maintain historical behavior.
return Shape(shape_proto.tuple_shapes(shape_proto.tuple_shapes_size() - 1));
}

std::string ShapeDescription(const Shape& shape) {
@@ -132,12 +132,12 @@ struct LogicalBufferStruct {
LogicalBufferStruct(const LogicalBufferProto& p,
const BufferAllocationStruct& b,
const ::xla::HloInstructionProto& i, uint64_t offset)
: proto(p), buffer_allocation(b), hlo_instruction(i), offset(offset) {
// Get shape of logical buffer.
const Shape top_level_shape(hlo_instruction.shape());
shape =
*ResolveShapeIndex(&top_level_shape, proto.defined_at().shape_index());
}
: proto(p),
buffer_allocation(b),
hlo_instruction(i),
offset(offset),
shape(ResolveShapeIndex(hlo_instruction.shape(),
proto.defined_at().shape_index())) {}

absl::string_view instruction_name() const { return hlo_instruction.name(); }

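As context for the constructor change above, which now initializes shape directly in the member initializer list instead of assigning it inside the constructor body, here is a minimal standalone sketch of that general C++ pattern. The types and names below are hypothetical and are not part of this commit.

#include <string>

// Hypothetical example: assigning a member inside the constructor body first
// default-constructs it and then overwrites it with the computed value.
struct WidgetAssignInBody {
  explicit WidgetAssignInBody(const std::string& raw) {
    name = "widget:" + raw;  // two steps: default-construct, then assign
  }
  std::string name;
};

// Initializing the member in the member initializer list constructs it once,
// directly from the computed value, and would also allow the member to be
// declared const or as a reference.
struct WidgetInitInList {
  explicit WidgetInitInList(const std::string& raw) : name("widget:" + raw) {}
  std::string name;
};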
3 changes: 2 additions & 1 deletion tensorflow/core/profiler/lib/profiler_factory.h
@@ -17,6 +17,7 @@ limitations under the License.

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/base/macros.h"
@@ -40,7 +41,7 @@ using ProfilerFactor ABSL_DEPRECATE_AND_INLINE() =
// Registers a profiler factory. Should be invoked at most once per factory.
ABSL_DEPRECATE_AND_INLINE()
inline void RegisterProfilerFactory(ProfilerFactor factory) {
tsl::profiler::RegisterProfilerFactory(factory);
tsl::profiler::RegisterProfilerFactory(std::move(factory));
}

// Invokes all registered profiler factories with the given options, and
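The std::move(factory) change above forwards the by-value callable instead of copying it again. A small standalone sketch of the same pattern with a hypothetical registry (not the real profiler API):

#include <functional>
#include <utility>
#include <vector>

// Hypothetical callback registry used only to illustrate the std::move above.
std::vector<std::function<void()>>& CallbackRegistry() {
  static std::vector<std::function<void()>> registry;
  return registry;
}

void RegisterCallback(std::function<void()> callback) {
  // The by-value parameter already owns its copy of the callable; moving it
  // into storage avoids a second copy of any captured state.
  CallbackRegistry().push_back(std::move(callback));
}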
3 changes: 2 additions & 1 deletion tensorflow/core/profiler/lib/traceme_encode.h
@@ -19,6 +19,7 @@ limitations under the License.

#include <initializer_list>
#include <string>
#include <utility>

#include "absl/base/macros.h"
#include "absl/strings/match.h"
@@ -42,7 +43,7 @@ using TraceMeArg ABSL_DEPRECATE_AND_INLINE() =
ABSL_DEPRECATE_AND_INLINE()
inline std::string TraceMeEncode(std::string name,
std::initializer_list<TraceMeArg> args) {
return tsl::profiler::TraceMeEncode(name, args);
return tsl::profiler::TraceMeEncode(std::move(name), args);
}

ABSL_DEPRECATE_AND_INLINE()
1 change: 0 additions & 1 deletion third_party/xla/xla/BUILD
@@ -608,7 +608,6 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
50 changes: 4 additions & 46 deletions third_party/xla/xla/literal.h
@@ -35,9 +35,7 @@ limitations under the License.
#include "absl/base/attributes.h"
#include "absl/base/casts.h"
#include "absl/base/config.h"
#include "absl/base/optimization.h"
#include "absl/functional/function_ref.h"
#include "absl/hash/hash.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/array.h"
@@ -355,20 +353,6 @@ class LiteralBase {
return LiteralBase::Hash(std::move(state), value);
}

private:
// With C++20, we can use `requires { absl::Hash<NativeT>(); }`.
template <typename T>
static constexpr bool IsAbslHashable() {
#ifdef _MSC_VER
// `std::is_invocable_v<absl::Hash<T>, T>` doesn't work on MSVC.
// See https://godbolt.org/z/Wj9d7zrav.
return std::is_arithmetic_v<T>;
#else
return std::is_invocable_v<absl::Hash<T>, T>;
#endif
}

public:
template <typename H, bool kIsLayoutSensitive = true,
int64_t kByteLimit = std::numeric_limits<int64_t>::max()>
static H Hash(H state, const LiteralBase& literal) {
@@ -382,36 +366,10 @@ }
}

CHECK(LayoutUtil::IsDenseArray(subshape));
const auto hash_func = [&](auto primitive_type_constant) {
using NativeT =
primitive_util::NativeTypeOf<primitive_type_constant>;
// If we can hash NativeT, then do so. Otherwise, hash raw buffer
// data taking care to avoid invalid parts of 4-bit type data.
if constexpr (IsAbslHashable<NativeT>()) {
state = H::combine(std::move(state),
literal.piece(index).data<NativeT>());
} else {
const int64_t num_bytes =
std::min(kByteLimit, literal.size_bytes(index));
const char* buffer =
static_cast<const char*>(literal.untyped_data(index));
if (primitive_util::Is4BitType(subshape.element_type())) {
// Note: in this case, we could potentially read 8 bytes at a
// time, mask out the upper 4 bits of each byte, and then hash 8
// bytes, but it adds complexity and needs special handling for
// the non-divisible-by-8 leftover bytes.
for (int64_t i = 0; i < num_bytes; ++i) {
state =
H::combine(std::move(state), buffer[i] & uint8_t{0xf});
}
} else {
auto data = absl::MakeConstSpan(buffer, num_bytes);
state = H::combine(std::move(state), data);
}
}
};
primitive_util::ArrayTypeSwitch<void>(hash_func,
subshape.element_type());
auto data = absl::MakeConstSpan(
static_cast<const char*>(literal.untyped_data(index)),
std::min(kByteLimit, literal.size_bytes(index)));
state = H::combine(std::move(state), data);
});

return std::move(state);
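The simplified Hash() above combines the literal's raw bytes, capped at kByteLimit, as one contiguous span instead of dispatching on the element type. A rough standalone sketch of that byte-span hashing pattern, using absl::HashOf rather than this class's H state; the function name and signature here are illustrative only:

#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "absl/hash/hash.h"
#include "absl/types/span.h"

// Hash up to `byte_limit` bytes of an arbitrary buffer as a single span of
// chars, roughly analogous to what the dense-array branch above now does.
size_t HashRawBytes(const void* data, int64_t size_bytes, int64_t byte_limit) {
  const auto bytes = absl::MakeConstSpan(static_cast<const char*>(data),
                                         std::min(byte_limit, size_bytes));
  return absl::HashOf(bytes);
}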
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
@@ -4328,6 +4328,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
Expand Down
34 changes: 30 additions & 4 deletions third_party/xla/xla/service/hlo_value_semantics_analysis.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
@@ -487,6 +488,13 @@ absl::Status EinsumDepthAnalysis::HandleAfterAll(HloInstruction* after_all) {
}

absl::Status EinsumDepthAnalysis::HandleAfterAll(HloInstruction* after_all) {
auto depth_iter = GetDepthTreeOrDie(after_all);
const ShapeTree<int>& depth_tree = depth_iter->second;
int max_depth = GetMaxDepth(depth_tree);
for (HloInstruction* operand_token : after_all->mutable_operands()) {
CHECK(operand_token->shape().IsToken());
TF_RETURN_IF_ERROR(SetInstructionDepth(operand_token, max_depth));
}
return OkStatus();
}
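The new HandleAfterAll body above takes the maximum depth recorded for the after-all instruction and stamps every token operand with that depth. A generic sketch of that propagate-the-max pattern using plain containers as hypothetical stand-ins for the ShapeTree-based bookkeeping:

#include <algorithm>
#include <vector>

// Hypothetical stand-in for the depth trees above: find the maximum depth
// recorded for an instruction and assign that value to every operand slot.
// Assumes `own_depths` is non-empty.
void PropagateMaxDepth(const std::vector<int>& own_depths,
                       std::vector<int>& operand_depths) {
  const int max_depth =
      *std::max_element(own_depths.begin(), own_depths.end());
  std::fill(operand_depths.begin(), operand_depths.end(), max_depth);
}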

@@ -697,7 +705,12 @@ absl::Status EinsumHeightAnalysis::HandleCalledComputation(
const HloComputation& computation,
absl::Span<HloInstruction* const> operands) {
if (!operands.empty()) {
CHECK(computation.num_parameters() == operands.size());
if (computation.num_parameters() != operands.size()) {
return absl::InvalidArgumentError(absl::StrCat(
operands.size(), " operands were passed for the computation ",
computation.name(), " with ", computation.num_parameters(),
" parameters."));
}
for (int parameter_index = 0;
parameter_index < computation.num_parameters(); ++parameter_index) {
HloInstruction* parameter =
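The change above replaces a CHECK with a returned absl::InvalidArgumentError, so an operand/parameter mismatch is reported to the caller instead of crashing the process. A minimal sketch of that status-returning pattern; the helper below is hypothetical:

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

// Hypothetical helper illustrating the arity check above: report the mismatch
// as a status rather than terminating with CHECK.
absl::Status CheckOperandArity(int num_operands, int num_parameters) {
  if (num_operands != num_parameters) {
    return absl::InvalidArgumentError(absl::StrCat(
        num_operands, " operands were passed for a computation with ",
        num_parameters, " parameters."));
  }
  return absl::OkStatus();
}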
Expand Down Expand Up @@ -802,9 +815,15 @@ absl::Status EinsumHeightAnalysis::HandleConditional(
RETURN_IF_HEIGHT_EXISTS(conditional);
auto conditional_height_iter = GetOrCreateHeightTree(conditional);
ShapeTree<int>& height_tree = conditional_height_iter->second;
for (HloComputation* computation : conditional->branch_computations()) {
TF_RETURN_IF_ERROR(
HandleCalledComputation(*computation, conditional->mutable_operands()));
for (size_t i = 0; i < conditional->branch_count(); ++i) {
HloComputation* computation = conditional->branch_computation(i);
// An N-way conditional op has N + 1 operands where the first one is the
// branch index determining what branch to take, and the remaining N
// operands correspond to arguments to be passed to each of the N branch
// computations, if they are executed. So the (i + 1)th operand corresponds
// to the ith branch computation.
TF_RETURN_IF_ERROR(HandleCalledComputation(
*computation, {conditional->mutable_operands()[i + 1]}));
auto branch_root_height_iter =
GetHeightTreeOrDie(computation->root_instruction());
SetHeight(height_tree, branch_root_height_iter->second);
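The comment in the loop above describes the operand layout of an N-way conditional: operand 0 is the branch index, and operand i + 1 carries the argument for branch computation i. A tiny compile-time illustration of that indexing; the helper name is hypothetical:

// Hypothetical helper: map a branch computation's index to the conditional
// operand that carries its argument (operand 0 is the branch selector).
constexpr int ConditionalArgOperandIndex(int branch_index) {
  return branch_index + 1;
}
static_assert(ConditionalArgOperandIndex(0) == 1,
              "branch 0 reads conditional operand 1");
static_assert(ConditionalArgOperandIndex(1) == 2,
              "branch 1 reads conditional operand 2");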
@@ -1476,6 +1495,13 @@
replace_operands_semantics_with(semantics);
continue;
}
if (operand_list[0].label() == HloValueSemanticLabel::kTupleOrToken &&
operand_list[1].label() == HloValueSemanticLabel::kTupleOrToken) {
HloValueSemantics semantics =
CopySemanticsWithNewOrigin(operand_list[0], instruction);
replace_operands_semantics_with(semantics);
continue;
}
LOG(FATAL) << "We don't expect to handle operands of label "
<< HloValueSemanticLabelToString(operand_list[0].label())
<< " and "
64 changes: 64 additions & 0 deletions third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc
@@ -195,6 +195,12 @@ class HloValueSemanticsAnalysisTest : public HloTestBase {
return HasLabel(hlo_value_semantics_analysis, module, instruction_name,
HloValueSemanticLabel::kWeightGradient);
}
bool IsTupleOrToken(
const HloValueSemanticsAnalysis& hlo_value_semantics_analysis,
HloModule* module, absl::string_view instruction_name) {
return HasLabel(hlo_value_semantics_analysis, module, instruction_name,
HloValueSemanticLabel::kTupleOrToken);
}
};

TEST_F(HloValueSemanticsAnalysisTest, OneMatmul) {
@@ -244,6 +250,41 @@ ENTRY entry {
EXPECT_TRUE(IsWeight(*hlo_value_semantics_analysis, module.get(), "dot.2"));
}

TEST_F(HloValueSemanticsAnalysisTest, HandleConditional) {
const std::string module_str = R"(
HloModule Module
branch0 {
tparam = f32[4] parameter(0)
tgte1 = f32[4] ceil(tparam)
ROOT tuple = (f32[4], f32[4]) tuple(tparam, tgte1)
}
branch1 {
fparam = f32[4] parameter(0)
%async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo"
%async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start)
ROOT tuple = (f32[4], f32[4]) tuple(fparam, %async-done)
}
ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[] parameter(1)
ROOT conditional = (f32[4], f32[4]) conditional(b0, p0, p0),
branch_computations={branch0, branch1}
}
)";

TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1,
/*num_partitions=*/2));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloValueSemanticsAnalysis> hlo_value_semantics_analysis,
HloValueSemanticsAnalysis::Run(*module));
EXPECT_TRUE(IsTupleOrToken(*hlo_value_semantics_analysis, module.get(),
"conditional"));
}

TEST_F(HloValueSemanticsAnalysisTest, TwoMatmuls) {
const std::string module_str = R"(
HloModule TwoMatmuls
@@ -622,6 +663,29 @@ TEST_F(EinsumDepthAnalysisTest, HandleConditional) {
0);
}

TEST_F(EinsumDepthAnalysisTest, HandleAfterAll) {
const char* const hlo_string = R"(
ENTRY entry {
after-all.1 = token[] after-all()
parameter.1 = f32[] parameter(0)
send.1 = (f32[], u32[], token[]) send(parameter.1, after-all.1), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"}
send-done.1 = token[] send-done(send.1), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"}
ROOT after-all.2 = token[] after-all(send-done.1), frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<EinsumDepthAnalysis> einsum_depth_analysis,
EinsumDepthAnalysis::Run(*module->entry_computation(),
SendRecvGroupMap(*module)));
const EinsumDepthMap& einsum_depth_map =
einsum_depth_analysis->GetEinsumDepthMap();
HloComputation* computation = module->GetComputationWithName("entry");
EXPECT_EQ(GetInstructionDepth(einsum_depth_map, computation, "after-all.2"),
0);
}

class EinsumHeightAnalysisTest : public HloTestBase {
public:
int GetInstructionHeight(const EinsumHeightMap& height_map,
