Skip to content

Commit

Permalink
[PJRT:TFRT:CPU] Fix undefined behavior in JAX TFRT CPU backend.
Browse files Browse the repository at this point in the history
Depending on the evaluation order of EnqueueWorkWhenReady, the vector read and move can cause undefined behavior.

This shows up as incorrect results in open source which uses gcc.

google/jax#7128

PiperOrigin-RevId: 383904018
Change-Id: Id96338110bea9a39a3ce74dd88dd45a267940047
  • Loading branch information
zhangqiaorjc authored and tensorflower-gardener committed Jul 9, 2021
1 parent 440849d commit 8cc3ffa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
35 changes: 25 additions & 10 deletions tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,16 @@ static std::vector<tfrt::RCReference<tfrt::AsyncValue>> GetAsyncValues(
return avs;
}

static std::vector<tfrt::RCReference<tfrt::AsyncValue>> CopyAsyncValues(
absl::Span<const tfrt::RCReference<tfrt::AsyncValue>> events) {
std::vector<tfrt::RCReference<tfrt::AsyncValue>> avs;
avs.reserve(events.size());
for (const auto& ev : events) {
avs.push_back(ev.CopyRef());
}
return avs;
}

// Enqueue to TFRT non-blocking work queue when all `values` are ready.
static void EnqueueWorkWhenReady(
tfrt::HostContext* host_ctx,
Expand Down Expand Up @@ -964,6 +974,8 @@ void TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal,

std::vector<tfrt::RCReference<tfrt::AsyncValue>> device_buffer_wait_avs =
GetAsyncValues(device_buffer.buffer()->DefinitionEvents());
std::vector<tfrt::RCReference<tfrt::AsyncValue>> device_buffer_wait_avs_copy =
CopyAsyncValues(device_buffer_wait_avs);

bool should_sync_copy = device_buffer_wait_avs.empty() &&
literal->size_bytes() < kSmallDataTransferByteSize;
Expand Down Expand Up @@ -991,8 +1003,8 @@ void TfrtCpuBuffer::ToLiteral(MutableLiteralBase* literal,
EnqueueWorkWhenReady(
host_ctx, device_buffer_wait_avs,
[this, movable_device_buffer{device_buffer.ToClosure()},
device_buffer_wait_avs = std::move(device_buffer_wait_avs), literal,
on_ready{std::move(on_ready)}] {
device_buffer_wait_avs = std::move(device_buffer_wait_avs_copy),
literal, on_ready{std::move(on_ready)}] {
tensorflow::profiler::TraceMe traceme("D2H Dispatch");
TfrtCpuBuffer::ScopedHold device_buffer(movable_device_buffer);
// Errors in src buffer are surfaced to user.
Expand Down Expand Up @@ -1092,6 +1104,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuBuffer::CopyToDevice(
std::vector<tfrt::RCReference<tfrt::AsyncValue>>
src_device_buffer_definition_events_avs =
GetAsyncValues(src_device_buffer.buffer()->DefinitionEvents());
std::vector<tfrt::RCReference<tfrt::AsyncValue>>
src_device_buffer_definition_events_avs_copy =
CopyAsyncValues(src_device_buffer_definition_events_avs);

// Add d2d as usage event on src_buffer.
src_device_buffer.ConvertUsageHold(absl::MakeSpan(src_usage_events));
Expand All @@ -1101,7 +1116,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> TfrtCpuBuffer::CopyToDevice(
[client = client_, num_leaf_buffers, src_buffers = std::move(src_buffers),
dst_buffers_copies = dst_buffers, indirect_avs = std::move(indirect_avs),
src_device_buffer_definition_events_avs =
std::move(src_device_buffer_definition_events_avs)]() mutable {
std::move(src_device_buffer_definition_events_avs_copy)]() mutable {
tensorflow::profiler::TraceMe traceme("D2D Dispatch");
for (const auto& av : src_device_buffer_definition_events_avs) {
if (auto* error = av->GetErrorIfPresent()) {
Expand Down Expand Up @@ -1359,7 +1374,7 @@ TfrtCpuExecutable::ExecuteHelper(
// This also ensures that the returned `execute_event` dominates all inputs'
// events, and thus output buffer only need to contain `execute_event` as the
// single definition event.
std::vector<tfrt::AsyncValueRef<CpuEvent>> input_deps;
std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps;
input_deps.reserve(argument_handles.size());

auto donate_it = parameters_that_must_be_donated_.begin();
Expand Down Expand Up @@ -1394,7 +1409,7 @@ TfrtCpuExecutable::ExecuteHelper(
// Definition events are never modified after buffer construction.
for (const auto& ev : device_buffer->DefinitionEvents()) {
if (!ev.IsAvailable()) {
input_deps.push_back(ev.CopyRef());
input_deps.push_back(ev.CopyRCRef());
}
}
// If we are trying to donate this buffer, we must wait on its usage
Expand All @@ -1406,7 +1421,7 @@ TfrtCpuExecutable::ExecuteHelper(
if (must_donate) {
for (const auto& ev : device_buffer->UsageEvents()) {
if (!ev.IsAvailable()) {
input_deps.push_back(ev.CopyRef());
input_deps.push_back(ev.CopyRCRef());
}
}
}
Expand Down Expand Up @@ -1500,10 +1515,10 @@ TfrtCpuExecutable::ExecuteHelper(
if (is_a_collective_launch) {
client_->SetLastCollectiveLaunchEvent(execute_event.CopyRef());
}
std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps_avs =
GetAsyncValues(input_deps);
std::vector<tfrt::RCReference<tfrt::AsyncValue>> input_deps_avs_copy =
CopyAsyncValues(input_deps);
EnqueueWorkWhenReady(
host_context, input_deps_avs,
host_context, input_deps,
[cpu_executable, result_buffer,
buffer_pointers = std::move(buffer_pointers),
buffer_table = std::move(buffer_table),
Expand All @@ -1513,7 +1528,7 @@ TfrtCpuExecutable::ExecuteHelper(
compute_reservation = std::move(compute_reservation),
tracked_buffers = std::move(tracked_buffers),
execute_event = execute_event.CopyRef(),
input_deps_avs = std::move(input_deps_avs)]() mutable {
input_deps_avs = std::move(input_deps_avs_copy)]() mutable {
for (const auto& av : input_deps_avs) {
if (auto* error = av->GetErrorIfPresent()) {
execute_event.SetError(absl::StrCat(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes.
_version = 26
_version = 27

xla_platform_names = {
'cpu': 'Host',
Expand Down

0 comments on commit 8cc3ffa

Please sign in to comment.