
Commit

Automated Code Change
PiperOrigin-RevId: 609262042
tensorflower-gardener committed Feb 22, 2024
1 parent 6125020 commit ef4aab0
Showing 6 changed files with 21 additions and 20 deletions.
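
Note: the change mechanically qualifies bare StatusOr return types as absl::StatusOr across the GPU PJRT code; in current XLA/TSL these spellings refer to the same template, so callers should be unaffected. A minimal sketch of how such a return value is typically consumed (the ParseDeviceOrdinal helper is hypothetical, and TF_ASSIGN_OR_RETURN is assumed to come from tsl/platform/statusor.h):

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "tsl/platform/statusor.h"  // TF_ASSIGN_OR_RETURN (assumed location)

// Hypothetical helper, not part of this commit: parse a device ordinal or
// return an error.
absl::StatusOr<int> ParseDeviceOrdinal(const std::string& s) {
  int ordinal = 0;
  if (!absl::SimpleAtoi(s, &ordinal)) {
    return absl::InvalidArgumentError("not an integer: " + s);
  }
  return ordinal;
}

// Binds the value on success and propagates the error otherwise, which is how
// the functions touched by this commit are consumed by their callers.
absl::Status UseDevice(const std::string& s) {
  TF_ASSIGN_OR_RETURN(int ordinal, ParseDeviceOrdinal(s));
  (void)ordinal;
  return absl::OkStatus();
}
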
6 changes: 3 additions & 3 deletions third_party/xla/xla/pjrt/gpu/gpu_helpers.cc
@@ -34,7 +34,7 @@ limitations under the License.
namespace xla {

// Builds an xla::LocalClient for the GPU platform.
-StatusOr<LocalClient*> GetGpuXlaClient(
+absl::StatusOr<LocalClient*> GetGpuXlaClient(
const std::optional<std::string>& platform_name,
const std::optional<std::set<int>>& allowed_devices) {
TF_ASSIGN_OR_RETURN(
@@ -71,7 +71,7 @@ void EnablePeerAccess(absl::Span<se::StreamExecutor* const> executors) {
}

// Builds a BFCAllocator for all local GPUs.
-StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
+absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
se::StreamExecutor* executor, double memory_fraction, bool preallocate) {
bool enable_unified_memory;
Status status = tsl::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY", false,
@@ -119,7 +119,7 @@ StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
}

// Builds a BFCAllocator for all local GPUs that uses collective memory.
-StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
+absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
se::StreamExecutor* executor, double memory_fraction,
size_t collective_memory_size) {
int device_ordinal = executor->device_ordinal();
6 changes: 3 additions & 3 deletions third_party/xla/xla/pjrt/gpu/gpu_helpers.h
@@ -31,7 +31,7 @@ limitations under the License.
namespace xla {

// Builds an xla::LocalClient for the GPU platform.
-StatusOr<LocalClient*> GetGpuXlaClient(
+absl::StatusOr<LocalClient*> GetGpuXlaClient(
const std::optional<std::string>& platform_name,
const std::optional<std::set<int>>& allowed_devices);

@@ -70,11 +70,11 @@ std::unique_ptr<tsl::BFCAllocator> GetGpuHostAllocator(
se::StreamExecutor* executor);

// Builds a BFCAllocator for all local GPUs.
-StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
+absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
se::StreamExecutor* executor, double memory_fraction, bool preallocate);

// Builds a BFCAllocator for all local GPUs that uses collective memory.
-StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
+absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
se::StreamExecutor* executor, double memory_fraction,
size_t collective_memory_size);

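
For context, a minimal sketch of calling the CreateBFCAllocator declaration above; the memory fraction, preallocate flag, and wrapper function are illustrative, not taken from this commit:

#include <memory>

#include "xla/pjrt/gpu/gpu_helpers.h"

namespace xla {

// Sketch: build a BFC allocator for one executor and forward the
// absl::StatusOr result to the caller.
absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> MakeAllocatorForExecutor(
    se::StreamExecutor* executor) {
  return CreateBFCAllocator(executor, /*memory_fraction=*/0.9,
                            /*preallocate=*/true);
}

}  // namespace xla
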
2 changes: 1 addition & 1 deletion third_party/xla/xla/pjrt/gpu/nccl_id_store.cc
@@ -29,7 +29,7 @@ limitations under the License.

namespace xla {

-StatusOr<gpu::NcclCliqueId> NcclIdStore::GetNcclUniqueId(
+absl::StatusOr<gpu::NcclCliqueId> NcclIdStore::GetNcclUniqueId(
const gpu::NcclCliqueKey& key) {
// The caller must ensure that threads calling this method concurrently have
// unique keys, otherwise the global key-value store may hold the wrong value.
3 changes: 2 additions & 1 deletion third_party/xla/xla/pjrt/gpu/nccl_id_store.h
@@ -41,7 +41,8 @@ class NcclIdStore {
device_to_node_(std::move(device_to_node)),
kv_store_(std::move(kv_store)) {}

-StatusOr<gpu::NcclCliqueId> GetNcclUniqueId(const gpu::NcclCliqueKey& key);
+absl::StatusOr<gpu::NcclCliqueId> GetNcclUniqueId(
+    const gpu::NcclCliqueKey& key);

private:
const int node_id_;
18 changes: 9 additions & 9 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h
@@ -182,26 +182,26 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
tsl::Fingerprint64(platform_name), platform_name,
devices_.back()->device_kind(), devices_)) {}

-xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
+absl::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;

absl::string_view platform_version() const override;

-StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
+absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
CreateBuffersForAsyncHostToDevice(absl::Span<const Shape> shapes,
PjRtDevice* device) override;

PjRtFuture<Status> CopyRawSubBufferToHost(PjRtBuffer* buffer, void* dst,
int64_t offset,
int64_t transfer_size) override;

-StatusOr<const xla::PjRtTopologyDescription*> GetTopologyDescription()
+absl::StatusOr<const xla::PjRtTopologyDescription*> GetTopologyDescription()
const override {
return &topology_;
}

// TODO(b/285385306): Enable loading a non-loaded PjRtExecutable.
-StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Load(
+absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Load(
std::unique_ptr<PjRtExecutable> executable,
const LoadOptions& load_options) override {
return absl::WrapUnique<PjRtLoadedExecutable>(
@@ -210,20 +210,20 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {

// TODO(b/296466237): Unify `Load` method after (de)serialization and tests on
// existing use cases are done.
-StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Load(
+absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Load(
std::unique_ptr<PjRtExecutable> executable);

// TODO(b/296466237): Unify `LoadSerializedExecutable` after fixing existing
// tests.
-StatusOr<std::unique_ptr<PjRtLoadedExecutable>> LoadSerialized(
+absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> LoadSerialized(
absl::string_view serialized, std::optional<CompileOptions> options,
const LoadOptions& load_options);

-StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
+absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> DeserializeExecutable(
absl::string_view serialized,
std::optional<CompileOptions> options) override;

-StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
+absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;

private:
@@ -253,7 +253,7 @@ struct GpuClientOptions {
bool enable_mock_nccl = false;
};

-StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
+absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
const GpuClientOptions& options);

} // namespace xla
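
A small usage sketch for the GetStreamExecutorGpuClient declaration above, assuming default GpuClientOptions; the wrapper function is illustrative only:

#include <memory>

#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"

// Sketch: construct a GPU PjRtClient with default options (enable_mock_nccl
// stays false) and surface any failure through absl::StatusOr.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>> MakeGpuClient() {
  xla::GpuClientOptions options;
  return xla::GetStreamExecutorGpuClient(options);
}
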
6 changes: 3 additions & 3 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -55,7 +55,7 @@ using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::tsl::testing::StatusIs;

-StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> CompileExecutable(
+absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> CompileExecutable(
absl::string_view program, xla::PjRtClient& client,
xla::CompileOptions compile_options = xla::CompileOptions()) {
TF_ASSIGN_OR_RETURN(auto hlo_module,
@@ -67,8 +67,8 @@ StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> CompileExecutable(

// Given the result of a PjrtExecutable::Execute call (TF-status of vectors of
// vectors), extract the zeroth result from the zeroth device.
-StatusOr<std::shared_ptr<xla::Literal>> ExtractSingleResult(
-    xla::StatusOr<std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>>&
+absl::StatusOr<std::shared_ptr<xla::Literal>> ExtractSingleResult(
+    absl::StatusOr<std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>>&
result) {
TF_RETURN_IF_ERROR(result.status());
TF_RET_CHECK(result->size() == 1);
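
For completeness, a sketch of how CompileExecutable and ExtractSingleResult are typically combined in this test file; the kHloText constant, the client setup, and the TF_ASSERT_OK_AND_ASSIGN wiring are assumptions, not part of the diff:

// Inside a test body, roughly:
TF_ASSERT_OK_AND_ASSIGN(auto client,
                        GetStreamExecutorGpuClient(GpuClientOptions()));
TF_ASSERT_OK_AND_ASSIGN(auto executable,
                        CompileExecutable(kHloText, *client));
auto result = executable->Execute(/*argument_handles=*/{{}}, ExecuteOptions());
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> literal,
                        ExtractSingleResult(result));
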
