Skip to content

Commit

Permalink
[tsl:concurrency] Specify Isa/DynCast/Cast semantics for indirect asy…
Browse files Browse the repository at this point in the history
…nc values

PiperOrigin-RevId: 622971220
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 8, 2024
1 parent 3eba170 commit 2e2fad9
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 28 deletions.
Expand Up @@ -170,6 +170,18 @@ TEST(AsyncValuePtrTest, Isa) {

EXPECT_TRUE(Isa<A>(a_err.AsPtr()));
EXPECT_TRUE(Isa<B>(b_err.AsPtr()));

// Indirect async value is Isa<T> only if it would be a no-op cast.
auto indirect = MakeIndirectAsyncValue();
AsyncValueRef<A> c_indirect(indirect);
EXPECT_TRUE(Isa<A>(c_indirect.AsPtr()));
EXPECT_FALSE(Isa<C>(c_indirect.AsPtr()));

// After forwarding indirect async value to a concrete one it correctly
// returns true from Isa<T> check.
indirect->ForwardTo(c_ref.CopyRCRef());
EXPECT_TRUE(Isa<A>(c_indirect.AsPtr()));
EXPECT_TRUE(Isa<C>(c_indirect.AsPtr()));
}

TEST(AsyncValuePtrTest, DynCast) {
Expand Down Expand Up @@ -212,6 +224,19 @@ TEST(AsyncValuePtrTest, DynCast) {
EXPECT_TRUE(DynCast<A>(a_err.AsPtr()));
EXPECT_TRUE(DynCast<B>(b_err.AsPtr()));
EXPECT_FALSE(DynCast<C>(a_err.AsPtr()));

// Indirect async value can't be DynCast until it's forwarded unless it's a
// no-op DynCast to the same type.
auto indirect = MakeIndirectAsyncValue();
AsyncValueRef<A> c_indirect(indirect);
EXPECT_TRUE(DynCast<A>(c_indirect.AsPtr()));
EXPECT_FALSE(DynCast<C>(c_indirect.AsPtr()));

// After forwarding indirect async value to a concrete one it can be DynCast
// to a concrete type.
indirect->ForwardTo(c_ref.CopyRCRef());
EXPECT_TRUE(DynCast<A>(c_indirect.AsPtr()));
EXPECT_TRUE(DynCast<C>(c_indirect.AsPtr()));
}

TEST(AsyncValuePtrTest, Cast) {
Expand Down Expand Up @@ -243,6 +268,18 @@ TEST(AsyncValuePtrTest, Cast) {

EXPECT_TRUE(Cast<A>(a_err.AsPtr()));
EXPECT_TRUE(Cast<B>(b_err.AsPtr()));

// Indirect async value can't be Cast until it's forwarded unless it's a
// no-op Cast to the same type.
auto indirect = MakeIndirectAsyncValue();
AsyncValueRef<A> c_indirect(indirect);
EXPECT_TRUE(Cast<A>(c_indirect.AsPtr()));

// After forwarding indirect async value to a concrete one it can be Cast
// to a concrete type.
indirect->ForwardTo(c_ref.CopyRCRef());
EXPECT_TRUE(Cast<A>(c_indirect.AsPtr()));
EXPECT_TRUE(Cast<C>(c_indirect.AsPtr()));
}

} // namespace tsl
46 changes: 18 additions & 28 deletions third_party/xla/third_party/tsl/tsl/concurrency/async_value_ref.h
Expand Up @@ -96,8 +96,13 @@ class AsyncValueRef {

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
bool Isa() const {
return value_ && (value_->IsType<Derived>() ||
value_->IsType<DummyValueForErrorAsyncValue>());
// Isa is successful if:
// (1) This is no-op cast even if concrete payload has different type.
// (2) Type id of a concrete payload matches Derived type id.
// (3) Payload is for a special case of ErrorAsyncValue.
return value_ && (std::is_same_v<Derived, T> || // (1)
value_->IsType<Derived>() || // (2)
value_->IsType<DummyValueForErrorAsyncValue>()); // (3)
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
Expand All @@ -109,18 +114,8 @@ class AsyncValueRef {
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef<Derived> DynCast() const {
DCHECK(value_) << "Async value must be not null";
// Cast is successful if:
// (1) This is no-op cast even if concrete payload has different type.
// (2) Type id of a concrete payload matches Derived type id.
// (3) Payload is for a special case of ErrorAsyncValue.
if (std::is_same_v<Derived, T> || // (1)
value_->IsType<Derived>() || // (1)
value_->IsType<DummyValueForErrorAsyncValue>()) // (3)
{
return AsyncValueRef<Derived>(value_);
} else {
return AsyncValueRef<Derived>(nullptr);
}
return Isa<Derived>() ? AsyncValueRef<Derived>(value_)
: AsyncValueRef<Derived>(nullptr);
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
Expand Down Expand Up @@ -241,8 +236,13 @@ class AsyncValuePtr {

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
bool Isa() const {
return value_ && (value_->IsType<Derived>() ||
value_->IsType<DummyValueForErrorAsyncValue>());
// Isa is successful if:
// (1) This is no-op cast even if concrete payload has different type.
// (2) Type id of a concrete payload matches Derived type id.
// (3) Payload is for a special case of ErrorAsyncValue.
return value_ && (std::is_same_v<Derived, T> || // (1)
value_->IsType<Derived>() || // (2)
value_->IsType<DummyValueForErrorAsyncValue>()); // (3)
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
Expand All @@ -254,18 +254,8 @@ class AsyncValuePtr {
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValuePtr<Derived> DynCast() const {
DCHECK(value_) << "Async value must be not null";
// DynCast is successful if:
// (1) This is no-op cast even if concrete payload has different type.
// (2) Type id of a concrete payload matches Derived type id.
// (3) Payload is for a special case of ErrorAsyncValue.
if (std::is_same_v<Derived, T> || // (1)
value_->IsType<Derived>() || // (1)
value_->IsType<DummyValueForErrorAsyncValue>()) // (3)
{
return AsyncValuePtr<Derived>(value_);
} else {
return AsyncValuePtr<Derived>(nullptr);
}
return Isa<Derived>() ? AsyncValuePtr<Derived>(value_)
: AsyncValuePtr<Derived>(nullptr);
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
Expand Down
Expand Up @@ -225,6 +225,18 @@ TEST(AsyncValueRefTest, Isa) {

EXPECT_TRUE(Isa<A>(a_err));
EXPECT_TRUE(Isa<B>(b_err));

// Indirect async value is Isa<T> only if it would be a no-op cast.
auto indirect = MakeIndirectAsyncValue();
AsyncValueRef<A> c_indirect(indirect);
EXPECT_TRUE(Isa<A>(c_indirect));
EXPECT_FALSE(Isa<C>(c_indirect));

// After forwarding indirect async value to a concrete one it correctly
// returns true from Isa<T> check.
indirect->ForwardTo(c_ref.CopyRCRef());
EXPECT_TRUE(Isa<A>(c_indirect));
EXPECT_TRUE(Isa<C>(c_indirect));
}

TEST(AsyncValueRefTest, DynCast) {
Expand Down Expand Up @@ -267,6 +279,19 @@ TEST(AsyncValueRefTest, DynCast) {
EXPECT_TRUE(DynCast<A>(a_err));
EXPECT_TRUE(DynCast<B>(b_err));
EXPECT_FALSE(DynCast<C>(a_err));

// Indirect async value can't be DynCast until it's forwarded unless it's a
// no-op DynCast to the same type.
auto indirect = MakeIndirectAsyncValue();
AsyncValueRef<A> c_indirect(indirect);
EXPECT_TRUE(DynCast<A>(c_indirect));
EXPECT_FALSE(DynCast<C>(c_indirect));

// After forwarding indirect async value to a concrete one it can be DynCast
// to a concrete type.
indirect->ForwardTo(c_ref.CopyRCRef());
EXPECT_TRUE(DynCast<A>(c_indirect));
EXPECT_TRUE(DynCast<C>(c_indirect));
}

TEST(AsyncValueRefTest, Cast) {
Expand Down Expand Up @@ -298,6 +323,18 @@ TEST(AsyncValueRefTest, Cast) {

EXPECT_TRUE(Cast<A>(a_err));
EXPECT_TRUE(Cast<B>(b_err));

// Indirect async value can't be Cast until it's forwarded unless it's a
// no-op Cast to the same type.
auto indirect = MakeIndirectAsyncValue();
AsyncValueRef<A> c_indirect(indirect);
EXPECT_TRUE(Cast<A>(c_indirect));

// After forwarding indirect async value to a concrete one it can be Cast
// to a concrete type.
indirect->ForwardTo(c_ref.CopyRCRef());
EXPECT_TRUE(Cast<A>(c_indirect));
EXPECT_TRUE(Cast<C>(c_indirect));
}

} // namespace tsl

0 comments on commit 2e2fad9

Please sign in to comment.