diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 9ad731eba82a..48c02f15df66 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -462,6 +462,8 @@ std::vector PjRtComputationClient::TransferFromServer( metrics::TimedSection timed(TransferFromServerMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromServer", tsl::profiler::TraceMeLevel::kInfo); + std::vector> futures; + futures.reserve(handles.size()); std::vector literals; literals.reserve(handles.size()); int64_t total_size = 0; @@ -471,12 +473,16 @@ std::vector PjRtComputationClient::TransferFromServer( auto new_handle = ReplicateShardedData(handle); const PjRtData& pjrt_data = dynamic_cast(*new_handle); - auto& literal = + xla::Literal& literal = literals.emplace_back(host_output_shape(pjrt_data.buffer.get())); - XLA_CHECK_OK(pjrt_data.buffer->ToLiteralSync(&literal)); + futures.push_back(pjrt_data.buffer->ToLiteral(&literal)); total_size += literal.size_bytes(); } + for (auto& future : futures) { + tsl::Status status = future.Await(); + XLA_CHECK_OK(status); + } InboundDataMetric()->AddSample(total_size); return literals;