[XLA] Fix an incorrect use of a hashmap in HloCSE.
Before the fix there were values where a == b but hash(a) != hash(b). This was because the equality was insensitive to the layout whereas the hash was sensitive to it.

PiperOrigin-RevId: 642297727
dimitar-asenov authored and tensorflower-gardener committed Jun 11, 2024
1 parent ae908e1 commit 0f05e18
Showing 5 changed files with 98 additions and 67 deletions.
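To make the failure mode described in the commit message concrete before the per-file diffs: a hash container is only correct when a == b implies hash(a) == hash(b). The sketch below is a minimal, self-contained toy in standard C++, not XLA code; Lit, Eq and BadHash are invented for illustration. It shows how an equality functor that ignores a field the hash functor includes lets an unordered container miss an element it already stores.

#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_set>

// Stand-in for a literal: the same data can be stored under different layouts.
struct Lit {
  int data;
  int layout;
};

// Layout-insensitive equality, analogous to the old Literal operator==.
struct Eq {
  bool operator()(const Lit& a, const Lit& b) const { return a.data == b.data; }
};

// Layout-sensitive hash, inconsistent with Eq above.
struct BadHash {
  std::size_t operator()(const Lit& l) const {
    return std::hash<int>()(l.data) ^ (std::hash<int>()(l.layout) << 1);
  }
};

int main() {
  std::unordered_set<Lit, BadHash, Eq> seen;
  seen.insert({42, /*layout=*/0});
  // Equal to the stored element under Eq, but it typically hashes into a
  // different bucket, so the lookup misses and a "duplicate" gets inserted.
  std::cout << seen.count({42, /*layout=*/1}) << "\n";          // usually 0
  std::cout << seen.insert({42, /*layout=*/1}).second << "\n";  // usually 1
  return 0;
}

This is exactly the a == b but hash(a) != hash(b) situation that HloCSE hit; the commit fixes it by making literal equality optionally layout sensitive and using it wherever the layout-sensitive hash is used.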
53 changes: 28 additions & 25 deletions third_party/xla/xla/literal.cc
@@ -1891,39 +1891,42 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
subshape().element_type());
}

-bool LiteralBase::operator==(const LiteralBase& other) const {
+bool LiteralBase::Equal(const LiteralBase& other, bool layout_sensitive) const {
  // Checking the structure of tuple literals. Checks for dense arrays are
  // performed below.
  if (!ShapeUtil::EqualStructure(shape(), other.shape())) {
    return false;
  }

-  return root_piece().ForEachSubpieceWithBool(
-      [&](const ShapeIndex& index, const Piece& piece) {
-        const Piece& other_piece = other.piece(index);
-        const Shape& subshape = piece.subshape();
-        const Shape& other_subshape = other_piece.subshape();
-        if (subshape.element_type() != other_subshape.element_type()) {
-          return false;
-        }
-        if (!piece.subshape().IsArray()) {
-          return true;
-        }
-        if (subshape.rank() != other_subshape.rank()) {
-          return false;
-        }
+  return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
+                                                  const Piece& piece) {
+    const Piece& other_piece = other.piece(index);
+    const Shape& subshape = piece.subshape();
+    const Shape& other_subshape = other_piece.subshape();
+    if (subshape.element_type() != other_subshape.element_type()) {
+      return false;
+    }
+    if (!piece.subshape().IsArray()) {
+      return true;
+    }
+    if (subshape.rank() != other_subshape.rank()) {
+      return false;
+    }
+    if (layout_sensitive && (subshape.layout() != other_subshape.layout())) {
+      return false;
+    }

-        for (int64_t i = 0; i < subshape.rank(); ++i) {
-          if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
-            return false;
-          }
-        }
+    for (int64_t i = 0; i < subshape.rank(); ++i) {
+      if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) {
+        return false;
+      }
+    }

-        if (!piece.EqualElements(other_piece)) {
-          return false;
-        }
-        return true;
-      });
+    if (!piece.EqualElements(other_piece)) {
+      return false;
+    }
+    return true;
+  });
}

template <typename NativeT>
89 changes: 49 additions & 40 deletions third_party/xla/xla/literal.h
@@ -76,10 +76,19 @@ class LiteralBase {
virtual ~LiteralBase() = 0;

// Literals are equal if they have compatible shapes and the same data
-  // values. Layout is not compared.
-  bool operator==(const LiteralBase& other) const;
+  // values. Layout is not compared. For a layout sensitive comparison
+  // call Equal() with layout_sensitive=true.
+  bool operator==(const LiteralBase& other) const {
+    return Equal(other, false);
+  }
  bool operator!=(const LiteralBase& other) const { return !(*this == other); }

+  // Compares two literals with optional layout sensitivity. If you use
+  // literals in a hash map, together with AbslHashValue or Hash defined below,
+  // you must use this method instead of operator== to ensure proper layout
+  // handling.
+  bool Equal(const LiteralBase& other, bool layout_sensitive) const;
+
// Returns the shape of the literal.
const Shape& shape() const;

@@ -347,56 +356,56 @@
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
}

-  // Compute a hash for this literal.
+  // Compute a hash for this literal. Always use this together with the Equal
+  // method and not operator== in order to handle layout sensitivity properly.
  template <typename H>
  friend H AbslHashValue(H state, const LiteralBase& value) {
    return LiteralBase::Hash(std::move(state), value);
  }

+  // Always use this together with the Equal method and not operator== in order
+  // to handle layout sensitivity properly.
template <typename H, bool kIsLayoutSensitive = true,
int64_t kByteLimit = std::numeric_limits<int64_t>::max()>
static H Hash(H state, const LiteralBase& literal) {
state =
Shape::Hash<H, kIsLayoutSensitive>(std::move(state), literal.shape());

-    ShapeUtil::ForEachSubshape(
-        literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) {
-          if (!subshape.IsArray()) {
-            return;
-          }
+    ShapeUtil::ForEachSubshape(literal.shape(), [&](const Shape& subshape,
+                                                    const ShapeIndex& index) {
+      if (!subshape.IsArray()) {
+        return;
+      }

-          CHECK(LayoutUtil::IsDenseArray(subshape));
-          const int64_t size_bytes = literal.size_bytes(index);
-          const int64_t bytes_to_hash = std::min(size_bytes, kByteLimit);
-          // When layout insensitive, we need to hash the data bytes in logical
-          // order rather than physical order.
-          const bool use_physical_order =
-              kIsLayoutSensitive || !subshape.has_layout();
-          auto data = absl::MakeConstSpan(
-              static_cast<const char*>(literal.untyped_data(index)),
-              size_bytes);
-          if (use_physical_order) {
-            state = H::combine(std::move(state), data.first(bytes_to_hash));
-            return;
-          }
-          const int64_t elem_size =
-              ShapeUtil::ByteSizeOfPrimitiveType(subshape.element_type());
-          absl::Span<const int64_t> minor_to_major =
-              subshape.layout().minor_to_major();
-          DimensionVector elem_index(subshape.dimensions_size());
-          absl::Span<int64_t> elem_index_span(elem_index.data(),
-                                              elem_index.size());
-          int64_t bytes_hashed = 0;
-          while (bytes_hashed < bytes_to_hash) {
-            int64_t offset =
-                elem_size * IndexUtil::MultidimensionalIndexToLinearIndex(
-                                subshape, minor_to_major, elem_index);
-            state =
-                H::combine(std::move(state), data.subspan(offset, elem_size));
-            if (!IndexUtil::BumpIndices(subshape, elem_index_span)) return;
-            bytes_hashed += elem_size;
-          }
-        });
+      CHECK(LayoutUtil::IsDenseArray(subshape));
+      const int64_t size_bytes = literal.size_bytes(index);
+      const int64_t bytes_to_hash = std::min(size_bytes, kByteLimit);
+      // When layout insensitive, we need to hash the data bytes in logical
+      // order rather than physical order.
+      const bool use_physical_order =
+          kIsLayoutSensitive || !subshape.has_layout();
+      auto data = absl::MakeConstSpan(
+          static_cast<const char*>(literal.untyped_data(index)), size_bytes);
+      if (use_physical_order) {
+        state = H::combine(std::move(state), data.first(bytes_to_hash));
+        return;
+      }
+      const int64_t elem_size =
+          ShapeUtil::ByteSizeOfPrimitiveType(subshape.element_type());
+      absl::Span<const int64_t> minor_to_major =
+          subshape.layout().minor_to_major();
+      DimensionVector elem_index(subshape.dimensions_size());
+      absl::Span<int64_t> elem_index_span(elem_index.data(), elem_index.size());
+      int64_t bytes_hashed = 0;
+      while (bytes_hashed < bytes_to_hash) {
+        int64_t offset =
+            elem_size * IndexUtil::MultidimensionalIndexToLinearIndex(
+                            subshape, minor_to_major, elem_index);
+        state = H::combine(std::move(state), data.subspan(offset, elem_size));
+        if (!IndexUtil::BumpIndices(subshape, elem_index_span)) return;
+        bytes_hashed += elem_size;
+      }
+    });

return std::move(state);
}
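The new comment in literal.h spells out the contract: the default AbslHashValue for literals is layout sensitive, so any container keyed on literals must pair it with layout-sensitive equality. A hedged sketch of that pairing follows; the functor names and the LiteralSet alias are illustrative, not part of XLA, and the include paths assume the OpenXLA layout. It mirrors the ir_emitter.h change further down.

#include <cstddef>

#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "xla/literal.h"

// Hashes through AbslHashValue, i.e. LiteralBase::Hash with the default
// kIsLayoutSensitive = true.
struct LiteralPtrHash {
  std::size_t operator()(const xla::Literal* literal) const {
    return absl::HashOf(*literal);
  }
};

// The equality must agree with the hash above, so it is layout sensitive too.
struct LiteralPtrEq {
  bool operator()(const xla::Literal* lhs, const xla::Literal* rhs) const {
    return lhs->Equal(*rhs, /*layout_sensitive=*/true);
  }
};

using LiteralSet =
    absl::flat_hash_set<const xla::Literal*, LiteralPtrHash, LiteralPtrEq>;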
19 changes: 19 additions & 0 deletions third_party/xla/xla/literal_test.cc
@@ -484,6 +484,25 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
EXPECT_EQ(rowmajor, colmajor);
}

+TEST_F(LiteralUtilTest, DifferentLayoutInEquality) {
+  // Test in equality with literals which have different layouts when layout
+  // sensitive equality is used.
+  Literal colmajor(ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 2}, {0, 1}));
+  colmajor.Set<float>({0, 0}, 1.0);
+  colmajor.Set<float>({0, 1}, 2.0);
+  colmajor.Set<float>({1, 0}, 3.0);
+  colmajor.Set<float>({1, 1}, 4.0);
+
+  Literal rowmajor(ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 2}, {1, 0}));
+  rowmajor.Set<float>({0, 0}, 1.0);
+  rowmajor.Set<float>({0, 1}, 2.0);
+  rowmajor.Set<float>({1, 0}, 3.0);
+  rowmajor.Set<float>({1, 1}, 4.0);
+
+  EXPECT_FALSE(rowmajor.Equal(colmajor, true));
+  EXPECT_FALSE(colmajor.Equal(rowmajor, true));
+}
+
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
auto scalar = LiteralUtil::CreateR0<float>(1.0);
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/cpu/ir_emitter.h
@@ -689,7 +689,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,

struct LiteralPtrEqualityFunctor {
bool operator()(const Literal* lhs, const Literal* rhs) const {
-      return *lhs == *rhs && lhs->shape().layout() == rhs->shape().layout();
+      return lhs->Equal(*rhs, true);
}
};

2 changes: 1 addition & 1 deletion third_party/xla/xla/service/hlo_cse.cc
@@ -49,7 +49,7 @@ struct ConstantKey {
(kIsLayoutSensitive ? Shape::Equal()
: Shape::Equal().IgnoreLayout())(
lhs.hlo->shape(), rhs.hlo->shape()) &&
-           lhs.hlo->literal() == rhs.hlo->literal();
+           lhs.hlo->literal().Equal(rhs.hlo->literal(), kIsLayoutSensitive);
}
HloConstantInstruction* hlo;
int64_t domain;
Expand Down
