Skip to content

Commit

Permalink
Add support for nested tuples in MutableBorrowingLiteral
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627266147
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 23, 2024
1 parent 63e6743 commit e0d2d88
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 3 deletions.
22 changes: 21 additions & 1 deletion third_party/xla/xla/literal.cc
Expand Up @@ -2700,6 +2700,25 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs,
}
}

MutableBorrowingLiteral::MutableBorrowingLiteral(ShapeTree<char*> src_buf_ptrs)
: MutableLiteralBase() {
shape_ = std::make_unique<Shape>(src_buf_ptrs.shape());

root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
BuildPieceSubtree(*shape_, root_piece_);

root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::GetSubshape(*shape_, index).IsTuple()) {
DCHECK_EQ(src_buf_ptrs.element(index), nullptr)
<< "Tuples should not have buffer pointers";
return;
}
piece->set_buffer(const_cast<char*>(src_buf_ptrs.element(index)));
});
}

MutableBorrowingLiteral::~MutableBorrowingLiteral() {
if (root_piece_ != nullptr) {
delete root_piece_;
Expand Down Expand Up @@ -2749,7 +2768,8 @@ BorrowingLiteral::BorrowingLiteral(ShapeTree<const char*> src_buf_ptrs)
root_piece_.ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::GetSubshape(*shape_, index).IsTuple()) {
DCHECK_EQ(src_buf_ptrs.element(index), nullptr);
DCHECK_EQ(src_buf_ptrs.element(index), nullptr)
<< "Tuples should not have buffer pointers";
return;
}
piece->set_buffer(const_cast<char*>(src_buf_ptrs.element(index)));
Expand Down
13 changes: 11 additions & 2 deletions third_party/xla/xla/literal.h
Expand Up @@ -1484,12 +1484,21 @@ class MutableBorrowingLiteral : public MutableLiteralBase {
MutableBorrowingLiteral(MutableLiteralBase* literal);
MutableBorrowingLiteral(MutableBorrowingLiteral literal,
const ShapeIndex& view_root);

// 'src_buf_ptr' is not owned by this class and must outlive the
// lifetime of this class. It points to an appropriately sized buffer with
// data interpreted as indicated by 'shape'.
// This constructor is only used for array shapes.
MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);

// Create a literal from a list of buffers and a shape.
// Returns a tuple literal if `shape` is a tuple type.
// Similar as above, except to be used for constructing non-nested tuples.
MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape);

// Similar as above, except to be used for constructing literals with
// potentially nested tuples (same shape as `src_buf_ptrs`) with borrowed
// buffers for each shape index.
explicit MutableBorrowingLiteral(ShapeTree<char*> src_buf_ptrs);

private:
const Piece& root_piece() const override { return *root_piece_; };
// Recursively copies the subtree from the `src_piece` at the given child
Expand Down
19 changes: 19 additions & 0 deletions third_party/xla/xla/literal_test.cc
Expand Up @@ -2022,6 +2022,25 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromShapeTree) {
EXPECT_THAT(literal.data<float>({1}), ElementsAre(1.0, 2.0, 3.0));
}

TEST_F(LiteralUtilTest, MutableBorrowingLiteralFromShapeTree) {
std::vector<float> data = {1.0, 2.0, 3.0};

Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {3});
Shape tuple = ShapeUtil::MakeTupleShape({shape, shape});
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, shape});

ShapeTree<char*> ptr_tree(nested_tuple);
*ptr_tree.mutable_element({0, 0}) = reinterpret_cast<char*>(data.data());
*ptr_tree.mutable_element({0, 1}) = reinterpret_cast<char*>(data.data());
*ptr_tree.mutable_element({1}) = reinterpret_cast<char*>(data.data());

MutableBorrowingLiteral literal(ptr_tree);

EXPECT_THAT(literal.data<float>({0, 0}), ElementsAre(1.0, 2.0, 3.0));
EXPECT_THAT(literal.data<float>({0, 1}), ElementsAre(1.0, 2.0, 3.0));
EXPECT_THAT(literal.data<float>({1}), ElementsAre(1.0, 2.0, 3.0));
}

TEST_F(LiteralUtilTest, LiteralMove) {
Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal(std::move(matrix));
Expand Down

0 comments on commit e0d2d88

Please sign in to comment.