Skip to content

Commit

Permalink
[tsl:concurrency] Add LLVM-style type casting to AsyncValuePtr<T>
Browse files Browse the repository at this point in the history
AsyncValueRef<T> type casting coming in followup CL, it will be a bit more tricky because it's not clear when to return a ref-copy and when to return a non-owning AsyncValuePtr.

PiperOrigin-RevId: 622718726
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 8, 2024
1 parent a773ace commit 0cf66c1
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 20 deletions.
15 changes: 15 additions & 0 deletions third_party/xla/third_party/tsl/tsl/concurrency/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ cc_library(
deps = [
":concurrent_vector",
":ref_count",
"//tsl/platform:logging",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
Expand All @@ -42,13 +44,26 @@ tsl_cc_test(
],
)

tsl_cc_test(
name = "async_value_ptr_test",
srcs = ["async_value_ptr_test.cc"],
deps = [
":async_value",
"//tsl/platform:test",
"//tsl/platform:test_main",
"@com_google_absl//absl/status",
],
)

tsl_cc_test(
name = "async_value_ref_test",
srcs = ["async_value_ref_test.cc"],
deps = [
":async_value",
"//tsl/platform:test",
"//tsl/platform:test_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

Expand Down
5 changes: 2 additions & 3 deletions third_party/xla/third_party/tsl/tsl/concurrency/async_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@ limitations under the License.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "tsl/concurrency/concurrent_vector.h"
#include "tsl/concurrency/ref_count.h"
Expand Down Expand Up @@ -258,11 +256,12 @@ class AsyncValue {
// -----------------------------------------------------------
// Implementation details follow. Clients should ignore them.

friend class IndirectAsyncValue;

// Utility template for tag dispatching.
template <typename T>
struct TypeTag {};

friend class IndirectAsyncValue;
template <typename T>
AsyncValue(Kind kind, State state, bool is_refcounted, TypeTag<T>)
: refcount_(1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "absl/status/status.h"
#include "tsl/concurrency/async_value_ref.h"
#include "tsl/platform/test.h"

Expand Down Expand Up @@ -75,4 +78,67 @@ TEST(AsyncValuePtrTest, AndThen) {
EXPECT_TRUE(executed);
}

namespace {
struct A {
virtual ~A() = default;
};
struct B : public A {};
struct C : public B {};
struct D : public A {};
} // namespace

TEST(AsyncValuePtrTest, Isa) {
// Empty async pointer always returns false for any Isa<T>().
AsyncValuePtr<A> null_ptr;
EXPECT_FALSE(Isa<A>(null_ptr));

AsyncValueRef<A> a_ref = MakeAvailableAsyncValueRef<A>();
AsyncValueRef<A> b_ref = MakeAvailableAsyncValueRef<B>();
AsyncValueRef<A> c_ref = MakeAvailableAsyncValueRef<C>();
AsyncValueRef<A> d_ref = MakeAvailableAsyncValueRef<D>();

EXPECT_TRUE(Isa<A>(a_ref.AsPtr()));
EXPECT_TRUE(Isa<B>(b_ref.AsPtr()));
EXPECT_TRUE(Isa<C>(c_ref.AsPtr()));
EXPECT_TRUE(Isa<D>(d_ref.AsPtr()));
}

TEST(AsyncValuePtrTest, DynCast) {
AsyncValueRef<A> a_ref = MakeAvailableAsyncValueRef<A>();
AsyncValueRef<A> b_ref = MakeAvailableAsyncValueRef<B>();
AsyncValueRef<A> c_ref = MakeAvailableAsyncValueRef<C>();
AsyncValueRef<A> d_ref = MakeAvailableAsyncValueRef<D>();

EXPECT_TRUE(DynCast<A>(a_ref.AsPtr()));
EXPECT_TRUE(DynCast<B>(b_ref.AsPtr()));
EXPECT_TRUE(DynCast<C>(c_ref.AsPtr()));
EXPECT_TRUE(DynCast<D>(d_ref.AsPtr()));

// No-op casts are always successful.
EXPECT_TRUE(DynCast<A>(c_ref.AsPtr()));

// We don't support casting to base (C inherits from B) because we can't do
// that safely relying just on AsyncValue type id. For safe conversion to base
// we need to introduce some kind of traits to the type hierarchy or rely on
// builtin `dynamic_cast` (will work only for constructed values).
EXPECT_FALSE(DynCast<B>(c_ref.AsPtr()));

// Types are unrelated, although they have same base.
EXPECT_FALSE(DynCast<C>(d_ref.AsPtr()));
}

TEST(AsyncValuePtrTest, Cast) {
AsyncValueRef<A> a_ref = MakeAvailableAsyncValueRef<A>();
AsyncValueRef<A> b_ref = MakeAvailableAsyncValueRef<B>();
AsyncValueRef<A> c_ref = MakeAvailableAsyncValueRef<C>();
AsyncValueRef<A> d_ref = MakeAvailableAsyncValueRef<D>();

EXPECT_TRUE(Cast<A>(a_ref.AsPtr()));
EXPECT_TRUE(Cast<B>(b_ref.AsPtr()));
EXPECT_TRUE(Cast<C>(c_ref.AsPtr()));
EXPECT_TRUE(Cast<D>(d_ref.AsPtr()));

EXPECT_TRUE(Cast<A>(c_ref.AsPtr()));
}

} // namespace tsl
94 changes: 78 additions & 16 deletions third_party/xla/third_party/tsl/tsl/concurrency/async_value_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,27 @@ limitations under the License.
#define TENSORFLOW_TSL_CONCURRENCY_ASYNC_VALUE_REF_H_

#include <cstddef>
#include <cstdlib>
#include <string_view>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tsl/concurrency/async_value.h"
#include "tsl/concurrency/ref_count.h"
#include "tsl/platform/logging.h"

namespace tsl {

namespace internal {
// TODO(ezhulenev): Replace with C++20 concept when available.
// https://en.cppreference.com/w/cpp/concepts/derived_from
template <typename Subclass, typename Base>
using DerivedFrom =
typename std::enable_if_t<std::is_base_of_v<Base, Subclass>>;
} // namespace internal

// Forward declare non-owning typed async value pointer.
template <typename T>
class AsyncValuePtr;
Expand All @@ -51,14 +60,13 @@ class AsyncValueRef {
explicit AsyncValueRef(RCReference<AsyncValue> value)
: value_(std::move(value)) {}

// Support implicit conversion from AsyncValueRef<Derived> to
// Support implicit conversion from AsyncValueRef<Subclass> to
// AsyncValueRef<Base>.
template <typename DerivedT,
std::enable_if_t<std::is_base_of<T, DerivedT>::value>* = nullptr>
AsyncValueRef(AsyncValueRef<DerivedT>&& u) // NOLINT
template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValueRef(AsyncValueRef<Subclass>&& u) // NOLINT
: value_(u.ReleaseRCRef()) {}

// Support implicit conversion from RCReference<AsyncValue>.
// Support implicit conversion from RCReference<ErrorAsyncValue>.
AsyncValueRef(RCReference<ErrorAsyncValue> value) // NOLINT
: value_(std::move(value)) {}

Expand All @@ -85,10 +93,9 @@ class AsyncValueRef {

// Return the stored value as a subclass type. The AsyncValueRef must be
// available.
template <typename SubclassT,
std::enable_if_t<std::is_base_of<T, SubclassT>::value>* = nullptr>
SubclassT& get() const {
return value_->get<SubclassT>();
template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
Subclass& get() const {
return value_->get<Subclass>();
}

T* operator->() const { return &get(); }
Expand Down Expand Up @@ -130,7 +137,7 @@ class AsyncValueRef {
}

void SetError(absl::Status status) const {
assert(!status.ok() && "expected non-ok status");
DCHECK(!status.ok()) << "expected non-ok status";
return value_->SetError(std::move(status));
}

Expand Down Expand Up @@ -202,6 +209,35 @@ class AsyncValuePtr {
return *this;
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
bool Isa() {
return value_ && value_->IsType<Subclass>();
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> Cast() const {
DCHECK(value_) << "Async value must be not null";
DCHECK((std::is_same_v<Subclass, T> || value_->IsType<Subclass>()));
return AsyncValuePtr<Subclass>(value_);
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCast() const {
DCHECK(value_) << "Async value must be not null";
if (std::is_same_v<Subclass, T> || value_->IsType<Subclass>()) {
return AsyncValuePtr<Subclass>(value_);
}
return AsyncValuePtr<Subclass>(nullptr);
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCastOrNull() const {
if (std::is_same_v<Subclass, T> || (value_ && value_->IsType<Subclass>())) {
return AsyncValuePtr<Subclass>(value_);
}
return AsyncValuePtr<Subclass>(nullptr);
}

bool IsAvailable() const { return value_->IsAvailable(); }
bool IsUnavailable() const { return value_->IsUnavailable(); }

Expand All @@ -218,7 +254,7 @@ class AsyncValuePtr {
const absl::Status& GetError() const { return value_->GetError(); }

void SetError(absl::Status status) const {
assert(!status.ok() && "expected non-ok status");
DCHECK(!status.ok()) << "expected non-ok status";
return value_->SetError(std::move(status));
}

Expand Down Expand Up @@ -307,6 +343,36 @@ RCReference<ErrorAsyncValue> MakeErrorAsyncValueRef(std::string_view message);
// Construct an empty IndirectAsyncValue, not forwarding to anything.
RCReference<IndirectAsyncValue> MakeIndirectAsyncValue();

//===----------------------------------------------------------------------===//
// LLVM-style type casting library for async value refs and ptrs.
//===----------------------------------------------------------------------===//

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
bool Isa(AsyncValuePtr<T> ptr) {
return ptr.template Isa<Subclass>();
}

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> Cast(AsyncValuePtr<T> ptr) {
return ptr.template Cast<Subclass>();
}

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCast(AsyncValuePtr<T> ptr) {
return ptr.template DynCast<Subclass>();
}

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCastOrNull(AsyncValuePtr<T> ptr) {
return ptr.template DynCastOrNull<Subclass>();
}

//===----------------------------------------------------------------------===//
// Constructing reference-counted async values on the heap.
//===----------------------------------------------------------------------===//

namespace internal {
Expand All @@ -325,10 +391,6 @@ T* AllocateAndConstruct(Args&&... args) {

} // namespace internal

//===----------------------------------------------------------------------===//
// Constructing reference-counted async values on the heap.
//===----------------------------------------------------------------------===//

// Allocate an unconstructed AsyncValueRef. The AsyncValueRef should be made
// available later by invoking AsyncValueRef::emplace or
// AsyncValueRef::SetError.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ limitations under the License.

#include "tsl/concurrency/async_value_ref.h"

#include <memory>
#include <cstdint>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tsl/platform/test.h"

namespace tsl {
Expand Down

0 comments on commit 0cf66c1

Please sign in to comment.