Skip to content

Commit

Permalink
[tsl:concurrency] NFC: Use port::Aligned(Malloc|Free) instead of custom ones

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 622929778
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Apr 8, 2024
1 parent d626e3f commit 6ff4b63
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 62 deletions.
1 change: 1 addition & 0 deletions third_party/xla/third_party/tsl/tsl/concurrency/BUILD
Expand Up @@ -24,6 +24,7 @@ cc_library(
":concurrent_vector",
":ref_count",
"//tsl/platform:logging",
"//tsl/platform:platform_port",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
Expand Down
71 changes: 18 additions & 53 deletions third_party/xla/third_party/tsl/tsl/concurrency/async_value.cc
Expand Up @@ -16,7 +16,9 @@ limitations under the License.
#include "tsl/concurrency/async_value.h"

#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <utility>

#include "absl/container/inlined_vector.h"
Expand All @@ -25,46 +27,10 @@ limitations under the License.
#include "absl/types/span.h"
#include "tsl/concurrency/async_value_ref.h"
#include "tsl/concurrency/ref_count.h"
#include "tsl/platform/logging.h"

namespace tsl {

namespace internal {

// Allocates at least `size` bytes whose address is a multiple of `alignment`,
// dispatching to the platform's aligned-allocation primitive.
// Returns nullptr on allocation failure.
void* AlignedAlloc(size_t alignment, size_t size) {
  // Round the request up to a multiple of `alignment`; some platform
  // allocators (e.g. C11 aligned_alloc) require this invariant.
  size = (size + alignment - 1) / alignment * alignment;
#ifdef _WIN32
  // The MSVC runtime doesn't support aligned_alloc(). See
  // https://developercommunity.visualstudio.com/t/c17-stdaligned-alloc%E7%BC%BA%E5%A4%B1/468021#T-N473365
  return _aligned_malloc(size, alignment);
#elif defined(__ANDROID__) || defined(OS_ANDROID)
  return memalign(alignment, size);
#else
  // posix_memalign requires the requested alignment to be at least
  // alignof(void*). malloc already guarantees pointer alignment, so it
  // suffices for alignments at or below that.
  if (alignment <= alignof(void*)) return std::malloc(size);
  void* result = nullptr;
  return posix_memalign(&result, alignment, size) == 0 ? result : nullptr;
#endif
}

// Releases memory previously obtained from AlignedAlloc().
void AlignedFree(void* ptr) {
#ifdef _WIN32
  // Memory from _aligned_malloc() must be paired with _aligned_free();
  // handing it to free() causes hard-to-debug runtime corruption on Windows.
  _aligned_free(ptr);
#else
  // Every non-Windows AlignedAlloc path (malloc, memalign, posix_memalign)
  // returns memory that plain free() can release.
  std::free(ptr);
#endif
}

} // namespace internal

// This is a singly linked list of nodes waiting for notification, hanging off
// of AsyncValue. When the value becomes available or if an error occurs, the
// callbacks are informed.
Expand All @@ -83,9 +49,8 @@ class NotifierListNode {
// Registers a new AsyncValue payload type in the process-wide type-info
// table and returns its id. Ids are 1-based: 0 is reserved so that a
// type_id_ of 0 can mean "not yet assigned" (see GetTypeInfo).
// NOTE(review): this span is scraped from a unified diff — both the removed
// assert() and the added DCHECK() overflow checks appear below; the real
// file contains exactly one of them. Confirm against the repository.
uint16_t AsyncValue::CreateTypeInfoAndReturnTypeIdImpl(
const TypeInfo& type_info) {
size_t type_id = GetTypeInfoTableSingleton()->emplace_back(type_info) + 1;
// Detect overflow.
assert(type_id < std::numeric_limits<uint16_t>::max() &&
"Too many different AsyncValue types.");
DCHECK(type_id < std::numeric_limits<uint16_t>::max())
<< "Too many different AsyncValue types.";
// Implicit narrowing to uint16_t is safe: the check above bounds type_id.
return type_id;
}

Expand All @@ -99,7 +64,7 @@ std::atomic<size_t> AsyncValue::total_allocated_async_values_;

// Looks up this value's TypeInfo in the global table. type_id_ is 1-based
// (0 means "unset"), hence the "- 1" when indexing.
// NOTE(review): diff artifact — both the removed assert() and the added
// DCHECK_NE() appear below; only one exists in the actual file.
const AsyncValue::TypeInfo& AsyncValue::GetTypeInfo() const {
TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton();
assert(type_id_ != 0);
DCHECK_NE(type_id_, 0);
return (*type_info_table)[type_id_ - 1];
}

Expand All @@ -108,17 +73,17 @@ const AsyncValue::TypeInfo& AsyncValue::GetTypeInfo() const {
// need to change our state and clear out the notifications. The current state
// must be unavailable (i.e. kUnconstructed or kConstructed).
void AsyncValue::NotifyAvailable(State available_state) {
assert((kind() == Kind::kConcrete || kind() == Kind::kIndirect) &&
"Should only be used by ConcreteAsyncValue or IndirectAsyncValue");
DCHECK((kind() == Kind::kConcrete || kind() == Kind::kIndirect))
<< "Should only be used by ConcreteAsyncValue or IndirectAsyncValue";

assert(available_state == State::kConcrete ||
DCHECK(available_state == State::kConcrete ||
available_state == State::kError);

// Mark the value as available, ensuring that new queries for the state see
// the value that got filled in.
auto old_value = waiters_and_state_.exchange(
WaitersAndState(nullptr, available_state), std::memory_order_acq_rel);
assert(old_value.state() == State::kUnconstructed ||
DCHECK(old_value.state() == State::kUnconstructed ||
old_value.state() == State::kConstructed);

RunWaiters(old_value.waiter());
Expand Down Expand Up @@ -158,7 +123,7 @@ void AsyncValue::EnqueueWaiter(absl::AnyInvocable<void()> waiter,
// so, just run the waiter.
if (old_value.state() == State::kConcrete ||
old_value.state() == State::kError) {
assert(old_value.waiter() == nullptr);
DCHECK(old_value.waiter() == nullptr);
node->notification_();
delete node;
return;
Expand All @@ -169,16 +134,16 @@ void AsyncValue::EnqueueWaiter(absl::AnyInvocable<void()> waiter,

// compare_exchange_weak succeeds. The old_value must be in either
// kUnconstructed or kConstructed state.
assert(old_value.state() == State::kUnconstructed ||
DCHECK(old_value.state() == State::kUnconstructed ||
old_value.state() == State::kConstructed);
}

void AsyncValue::SetError(absl::Status status) {
assert(!status.ok());
DCHECK(!status.ok());
if (kind() == Kind::kConcrete) {
GetTypeInfo().set_error(this, std::move(status));
} else {
assert(kind() == Kind::kIndirect);
DCHECK(kind() == Kind::kIndirect);
auto error_av = MakeErrorAsyncValueRef(std::move(status));
static_cast<IndirectAsyncValue*>(this)->ForwardTo(std::move(error_av));
}
Expand All @@ -187,17 +152,17 @@ void AsyncValue::SetError(absl::Status status) {
// Mark this IndirectAsyncValue as forwarding to the specified value. This
// gives the IndirectAsyncValue a +1 reference.
void IndirectAsyncValue::ForwardTo(RCReference<AsyncValue> value) {
assert(IsUnavailable());
DCHECK(IsUnavailable());

auto s = value->state();
if (s == State::kConcrete || s == State::kError) {
assert(!value_ && "IndirectAsyncValue::ForwardTo is called more than once");
DCHECK(!value_) << "IndirectAsyncValue::ForwardTo is called more than once";
auto* concrete_value = value.release();
if (concrete_value->kind() == Kind::kIndirect) {
auto* indirect_value = static_cast<IndirectAsyncValue*>(concrete_value);
concrete_value = indirect_value->value_;
assert(concrete_value != nullptr);
assert(concrete_value->kind() == Kind::kConcrete);
DCHECK(concrete_value != nullptr);
DCHECK(concrete_value->kind() == Kind::kConcrete);
concrete_value->AddRef();
indirect_value->DropRef();
}
Expand Down
10 changes: 3 additions & 7 deletions third_party/xla/third_party/tsl/tsl/concurrency/async_value.h
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tsl/concurrency/concurrent_vector.h"
#include "tsl/concurrency/ref_count.h"
#include "tsl/platform/mem.h"

namespace tsl {

Expand All @@ -43,11 +44,6 @@ class ConcreteAsyncValue;
template <typename T>
constexpr bool kMaybeBase = std::is_class<T>::value && !std::is_final<T>::value;

// TODO(ezhulenev): Switch to `tsl::port::Aligned(Malloc|Free)` once TFRT will
// be able to properly depend on TSL in the open source build.
void* AlignedAlloc(size_t alignment, size_t size);
void AlignedFree(void* ptr);

} // namespace internal

// This is a future of the specified value type. Arbitrary C++ types may be used
Expand Down Expand Up @@ -974,12 +970,12 @@ inline void AsyncValue::Destroy() {
// explicit check and instead make ~IndirectAsyncValue go through the
// GetTypeInfo().destructor case below.
static_cast<IndirectAsyncValue*>(this)->~IndirectAsyncValue();
if (was_ref_counted) internal::AlignedFree(this);
if (was_ref_counted) port::AlignedFree(this);
return;
}

GetTypeInfo().destructor(this);
if (was_ref_counted) internal::AlignedFree(this);
if (was_ref_counted) port::AlignedFree(this);
}

inline bool AsyncValue::IsUnique() const {
Expand Down
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "tsl/concurrency/async_value.h"
#include "tsl/concurrency/ref_count.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/mem.h"

namespace tsl {

Expand Down Expand Up @@ -466,8 +467,7 @@ T* PlacementConstruct(void* buf, Args&&... args) {

// Allocates storage aligned for T and constructs a T in it, forwarding
// `args` to T's constructor. The caller owns the resulting object.
// NOTE(review): diff artifact — both the removed internal::AlignedAlloc call
// and the added port::AlignedMalloc call appear below; the real file defines
// `buf` exactly once. Beware the argument order differs between the two
// APIs: AlignedAlloc(alignment, size) vs AlignedMalloc(size, alignment).
template <typename T, typename... Args>
T* AllocateAndConstruct(Args&&... args) {
// TODO(ezhulenev): `port::AlignedMalloc` has a different order of arguments!
void* buf = internal::AlignedAlloc(alignof(T), sizeof(T));
void* buf = port::AlignedMalloc(sizeof(T), alignof(T));
return PlacementConstruct<T, Args...>(buf, std::forward<Args>(args)...);
}

Expand Down

0 comments on commit 6ff4b63

Please sign in to comment.