Skip to content

Commit

Permalink
[tsl:concurrency] Add LLVM-style type casting to AsyncValueRef<T>
Browse files Browse the repository at this point in the history
+ cleanup ref_count header and consistently use DerivedFrom helper

Coming-next: for rvalue references we can do Cast and DynCast by moving RCReference and save wasted AddRef/DropRef
PiperOrigin-RevId: 622723868
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 8, 2024
1 parent 0cf66c1 commit 2bae250
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 55 deletions.
139 changes: 92 additions & 47 deletions third_party/xla/third_party/tsl/tsl/concurrency/async_value_ref.h
Expand Up @@ -30,14 +30,6 @@ limitations under the License.

namespace tsl {

namespace internal {
// TODO(ezhulenev): Replace with C++20 concept when available.
// https://en.cppreference.com/w/cpp/concepts/derived_from
template <typename Subclass, typename Base>
using DerivedFrom =
typename std::enable_if_t<std::is_base_of_v<Base, Subclass>>;
} // namespace internal

// Forward declare non-owning typed async value pointer.
template <typename T>
class AsyncValuePtr;
Expand All @@ -60,10 +52,10 @@ class AsyncValueRef {
explicit AsyncValueRef(RCReference<AsyncValue> value)
: value_(std::move(value)) {}

// Support implicit conversion from AsyncValueRef<Subclass> to
// Support implicit conversion from AsyncValueRef<Derived> to
// AsyncValueRef<Base>.
template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValueRef(AsyncValueRef<Subclass>&& u) // NOLINT
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef(AsyncValueRef<Derived>&& u) // NOLINT
: value_(u.ReleaseRCRef()) {}

// Support implicit conversion from RCReference<ErrorAsyncValue>.
Expand Down Expand Up @@ -91,11 +83,40 @@ class AsyncValueRef {
// Return the stored value. The AsyncValueRef must be available.
T& get() const { return value_->get<T>(); }

// Return the stored value as a subclass type. The AsyncValueRef must be
// Return the stored value as a derived type. The AsyncValueRef must be
// available.
template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
Subclass& get() const {
return value_->get<Subclass>();
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
Derived& get() const {
return value_->get<Derived>();
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
bool Isa() const {
return value_ && value_->IsType<Derived>();
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef<Derived> Cast() const {
DCHECK(value_) << "Async value must be not null";
DCHECK((std::is_same_v<Derived, T> || value_->IsType<Derived>()));
return AsyncValueRef<Derived>(value_);
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef<Derived> DynCast() const {
DCHECK(value_) << "Async value must be not null";
if (std::is_same_v<Derived, T> || value_->IsType<Derived>()) {
return AsyncValueRef<Derived>(value_);
}
return AsyncValueRef<Derived>(nullptr);
}

template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef<Derived> DynCastOrNull() const {
if (std::is_same_v<Derived, T> || (value_ && value_->IsType<Derived>())) {
return AsyncValueRef<Derived>(value_);
}
return AsyncValueRef<Derived>(nullptr);
}

T* operator->() const { return &get(); }
Expand Down Expand Up @@ -209,33 +230,33 @@ class AsyncValuePtr {
return *this;
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
bool Isa() {
return value_ && value_->IsType<Subclass>();
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
bool Isa() const {
return value_ && value_->IsType<Derived>();
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> Cast() const {
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValuePtr<Derived> Cast() const {
DCHECK(value_) << "Async value must be not null";
DCHECK((std::is_same_v<Subclass, T> || value_->IsType<Subclass>()));
return AsyncValuePtr<Subclass>(value_);
DCHECK((std::is_same_v<Derived, T> || value_->IsType<Derived>()));
return AsyncValuePtr<Derived>(value_);
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCast() const {
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValuePtr<Derived> DynCast() const {
DCHECK(value_) << "Async value must be not null";
if (std::is_same_v<Subclass, T> || value_->IsType<Subclass>()) {
return AsyncValuePtr<Subclass>(value_);
if (std::is_same_v<Derived, T> || value_->IsType<Derived>()) {
return AsyncValuePtr<Derived>(value_);
}
return AsyncValuePtr<Subclass>(nullptr);
return AsyncValuePtr<Derived>(nullptr);
}

template <typename Subclass, internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCastOrNull() const {
if (std::is_same_v<Subclass, T> || (value_ && value_->IsType<Subclass>())) {
return AsyncValuePtr<Subclass>(value_);
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValuePtr<Derived> DynCastOrNull() const {
if (std::is_same_v<Derived, T> || (value_ && value_->IsType<Derived>())) {
return AsyncValuePtr<Derived>(value_);
}
return AsyncValuePtr<Subclass>(nullptr);
return AsyncValuePtr<Derived>(nullptr);
}

bool IsAvailable() const { return value_->IsAvailable(); }
Expand Down Expand Up @@ -347,28 +368,52 @@ RCReference<IndirectAsyncValue> MakeIndirectAsyncValue();
// LLVM-style type casting library for async value refs and ptrs.
//===----------------------------------------------------------------------===//

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
bool Isa(const AsyncValueRef<T>& ref) {
return ref.template Isa<Derived>();
}

template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef<Derived> Cast(const AsyncValueRef<T>& ref) {
return ref.template Cast<Derived>();
}

template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef<Derived> DynCast(const AsyncValueRef<T>& ref) {
return ref.template DynCast<Derived>();
}

template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef<Derived> DynCastOrNull(const AsyncValueRef<T>& ref) {
return ref.template DynCastOrNull<Derived>();
}

template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
bool Isa(AsyncValuePtr<T> ptr) {
return ptr.template Isa<Subclass>();
return ptr.template Isa<Derived>();
}

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> Cast(AsyncValuePtr<T> ptr) {
return ptr.template Cast<Subclass>();
template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValuePtr<Derived> Cast(AsyncValuePtr<T> ptr) {
return ptr.template Cast<Derived>();
}

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCast(AsyncValuePtr<T> ptr) {
return ptr.template DynCast<Subclass>();
template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValuePtr<Derived> DynCast(AsyncValuePtr<T> ptr) {
return ptr.template DynCast<Derived>();
}

template <typename Subclass, typename T,
internal::DerivedFrom<Subclass, T>* = nullptr>
AsyncValuePtr<Subclass> DynCastOrNull(AsyncValuePtr<T> ptr) {
return ptr.template DynCastOrNull<Subclass>();
template <typename Derived, typename T,
internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValuePtr<Derived> DynCastOrNull(AsyncValuePtr<T> ptr) {
return ptr.template DynCastOrNull<Derived>();
}

//===----------------------------------------------------------------------===//
Expand Down
Expand Up @@ -189,4 +189,67 @@ TEST(AsyncValueRefTest, Nullptr) {
EXPECT_FALSE(av_int2);
}

namespace {
struct A {
virtual ~A() = default;
};
struct B : public A {};
struct C : public B {};
struct D : public A {};
} // namespace

TEST(AsyncValueRefTest, Isa) {
// Empty async reference always returns false for any Isa<T>().
AsyncValueRef<A> null_ref;
EXPECT_FALSE(Isa<A>(null_ref));

AsyncValueRef<A> a_ref = MakeAvailableAsyncValueRef<A>();
AsyncValueRef<A> b_ref = MakeAvailableAsyncValueRef<B>();
AsyncValueRef<A> c_ref = MakeAvailableAsyncValueRef<C>();
AsyncValueRef<A> d_ref = MakeAvailableAsyncValueRef<D>();

EXPECT_TRUE(Isa<A>(a_ref));
EXPECT_TRUE(Isa<B>(b_ref));
EXPECT_TRUE(Isa<C>(c_ref));
EXPECT_TRUE(Isa<D>(d_ref));
}

TEST(AsyncValueRefTest, DynCast) {
AsyncValueRef<A> a_ref = MakeAvailableAsyncValueRef<A>();
AsyncValueRef<A> b_ref = MakeAvailableAsyncValueRef<B>();
AsyncValueRef<A> c_ref = MakeAvailableAsyncValueRef<C>();
AsyncValueRef<A> d_ref = MakeAvailableAsyncValueRef<D>();

EXPECT_TRUE(DynCast<A>(a_ref));
EXPECT_TRUE(DynCast<B>(b_ref));
EXPECT_TRUE(DynCast<C>(c_ref));
EXPECT_TRUE(DynCast<D>(d_ref));

// No-op casts are always successful.
EXPECT_TRUE(DynCast<A>(c_ref));

// We don't support casting to base (C inherits from B) because we can't do
// that safely relying just on AsyncValue type id. For safe conversion to base
// we need to introduce some kind of traits to the type hierarchy or rely on
// builtin `dynamic_cast` (will work only for constructed values).
EXPECT_FALSE(DynCast<B>(c_ref));

// Types are unrelated, although they have same base.
EXPECT_FALSE(DynCast<C>(d_ref));
}

TEST(AsyncValueRefTest, Cast) {
AsyncValueRef<A> a_ref = MakeAvailableAsyncValueRef<A>();
AsyncValueRef<A> b_ref = MakeAvailableAsyncValueRef<B>();
AsyncValueRef<A> c_ref = MakeAvailableAsyncValueRef<C>();
AsyncValueRef<A> d_ref = MakeAvailableAsyncValueRef<D>();

EXPECT_TRUE(Cast<A>(a_ref));
EXPECT_TRUE(Cast<B>(b_ref));
EXPECT_TRUE(Cast<C>(c_ref));
EXPECT_TRUE(Cast<D>(d_ref));

EXPECT_TRUE(Cast<A>(c_ref));
}

} // namespace tsl
21 changes: 13 additions & 8 deletions third_party/xla/third_party/tsl/tsl/concurrency/ref_count.h
Expand Up @@ -19,11 +19,19 @@ limitations under the License.
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>

namespace tsl {

namespace internal {
// TODO(ezhulenev): Replace with C++20 concept when available.
// https://en.cppreference.com/w/cpp/concepts/derived_from
template <typename Derived, typename Base>
using DerivedFrom = typename std::enable_if_t<std::is_base_of_v<Base, Derived>>;
} // namespace internal

#ifndef NDEBUG
inline std::atomic<size_t> total_reference_counted_objects;

Expand Down Expand Up @@ -110,8 +118,7 @@ class ReferenceCounted {
};

// This is a smart pointer that keeps the specified reference counted value
// around. It is move-only to avoid accidental copies, but it can be copied
// explicitly.
// around.
template <typename T>
class RCReference {
public:
Expand All @@ -138,14 +145,12 @@ class RCReference {
}

// Support implicit conversion from RCReference<Derived> to RCReference<Base>.
template <typename U,
typename = std::enable_if_t<std::is_base_of<T, U>::value>>
RCReference(RCReference<U>&& u) : pointer_(u.pointer_) { // NOLINT
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
RCReference(RCReference<Derived>&& u) : pointer_(u.pointer_) { // NOLINT
u.pointer_ = nullptr;
}
template <typename U,
typename = std::enable_if_t<std::is_base_of<T, U>::value>>
RCReference(const RCReference<U>& u) : pointer_(u.pointer_) { // NOLINT
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
RCReference(const RCReference<Derived>& u) : pointer_(u.pointer_) { // NOLINT
if (pointer_) pointer_->AddRef();
}

Expand Down

0 comments on commit 2bae250

Please sign in to comment.