Skip to content

Commit

Permalink
Replace the use of xla::OkStatus with absl::OkStatus now that they're…
Browse files Browse the repository at this point in the history
… the same.

PiperOrigin-RevId: 635887485
  • Loading branch information
klucke authored and tensorflower-gardener committed May 21, 2024
1 parent 9201348 commit f734975
Show file tree
Hide file tree
Showing 24 changed files with 92 additions and 66 deletions.
4 changes: 4 additions & 0 deletions third_party/xla/xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ xla_cc_test(
"//xla/client:client_library",
"//xla/service:cpu_plugin",
"//xla/stream_executor:device_memory_allocator",
"@com_google_absl//absl/status",
"@local_tsl//tsl/platform:test_main",
],
)
Expand Down Expand Up @@ -364,6 +365,7 @@ cc_library(
"//xla/service:hlo_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//mlir:FuncDialect",
Expand Down Expand Up @@ -692,6 +694,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
Expand Down Expand Up @@ -864,6 +867,7 @@ xla_cc_test(
":host_callback",
":pjrt_client",
"//xla/tests:literal_test_util",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/lib/core:status_test_util",
],
Expand Down
7 changes: 4 additions & 3 deletions third_party/xla/xla/pjrt/host_callback_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <utility>

#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/tests/literal_test_util.h"
#include "tsl/lib/core/status_test_util.h"
Expand All @@ -46,7 +47,7 @@ class TestPjRtHostMemoryForDeviceManager
size_t dst_size, const Shape& dst_shape) override {
CHECK_EQ(src_size, dst_size);
std::memcpy(dst_data, src_data, src_size);
return OkStatus();
return absl::OkStatus();
}
};

Expand Down Expand Up @@ -80,7 +81,7 @@ TEST(HostCallbackTest, Basic) {
host_callback.results = {HostCallbackArgInfo{/*channel_id=*/2, shape}};
host_callback.callback = [byte_size](void** outputs, void** inputs) {
std::memcpy(outputs[0], inputs[0], byte_size);
return OkStatus();
return absl::OkStatus();
};

HostCallbackStates states;
Expand Down Expand Up @@ -128,7 +129,7 @@ TEST(HostCallbackTest, NonBlockingRecv) {
host_callback.results = {HostCallbackArgInfo{/*channel_id=*/2, shape}};
host_callback.callback = [byte_size](void** outputs, void** inputs) {
std::memcpy(outputs[0], inputs[0], byte_size);
return OkStatus();
return absl::OkStatus();
};

HostCallbackStates states;
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/pjrt/mlir_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ absl::Status MlirToXlaComputation(mlir::ModuleOp module,
ConvertMlirHloToHlo(module, &proto, use_tuple_args, return_tuple));

xla_computation = XlaComputation(std::move(*proto.mutable_hlo_module()));
return OkStatus();
return absl::OkStatus();
}

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ParseMlirModuleString(
Expand Down Expand Up @@ -343,7 +343,7 @@ absl::Status UpgradeVersionedStablehlo(mlir::ModuleOp mlir_module) {
mlir::stablehlo::createStablehloDeserializePipeline(pm);
if (!mlir::succeeded(pm.run(mlir_module)))
return xla::InvalidArgument("Failed to upgrade versioned StableHLO.");
return OkStatus();
return absl::OkStatus();
}

} // namespace xla
24 changes: 12 additions & 12 deletions third_party/xla/xla/pjrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,31 +572,31 @@ absl::Status CompileOptions::ApplyOption(const std::string& key,
if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_BOOL &&
std::holds_alternative<bool>(value)) {
reflection->SetBool(&debug_options, xla_field, std::get<bool>(value));
return OkStatus();
return absl::OkStatus();
} else if (std::holds_alternative<std::string>(value)) {
TF_RETURN_IF_ERROR(
ApplyOptionFromString(xla_field, std::get<std::string>(value)));
return OkStatus();
return absl::OkStatus();
} else if (xla_field->type() ==
tsl::protobuf::FieldDescriptor::TYPE_INT32 &&
std::holds_alternative<int64_t>(value)) {
reflection->SetInt32(&debug_options, xla_field, std::get<int64_t>(value));
return OkStatus();
return absl::OkStatus();
} else if (xla_field->type() ==
tsl::protobuf::FieldDescriptor::TYPE_INT64 &&
std::holds_alternative<int64_t>(value)) {
reflection->SetInt64(&debug_options, xla_field, std::get<int64_t>(value));
return OkStatus();
return absl::OkStatus();
} else if (xla_field->type() ==
tsl::protobuf::FieldDescriptor::TYPE_FLOAT &&
std::holds_alternative<double>(value)) {
reflection->SetFloat(&debug_options, xla_field, std::get<double>(value));
return OkStatus();
return absl::OkStatus();
} else if (xla_field->type() ==
tsl::protobuf::FieldDescriptor::TYPE_DOUBLE &&
std::holds_alternative<double>(value)) {
reflection->SetDouble(&debug_options, xla_field, std::get<double>(value));
return OkStatus();
return absl::OkStatus();
} else {
return InvalidArgument(
"While setting option %s, '%s' is not a valid %s value.", key,
Expand All @@ -612,7 +612,7 @@ absl::Status CompileOptions::ApplyAllOptionOverrides() {
for (auto& option : env_option_overrides) {
TF_RETURN_IF_ERROR(ApplyOption(option.first, option.second));
}
return OkStatus();
return absl::OkStatus();
}

absl::Status CompileOptions::ApplyOptionFromString(
Expand All @@ -622,30 +622,30 @@ absl::Status CompileOptions::ApplyOptionFromString(
const tsl::protobuf::Reflection* reflection = debug_options.GetReflection();
if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_STRING) {
reflection->SetString(&debug_options, field, value);
return OkStatus();
return absl::OkStatus();
} else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_INT32) {
int int_value;
if (absl::SimpleAtoi(value, &int_value)) {
reflection->SetInt32(&debug_options, field, int_value);
return OkStatus();
return absl::OkStatus();
}
} else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_INT64) {
int int_value;
if (absl::SimpleAtoi(value, &int_value)) {
reflection->SetInt64(&debug_options, field, int_value);
return OkStatus();
return absl::OkStatus();
}
} else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_FLOAT) {
float float_value;
if (absl::SimpleAtof(value, &float_value)) {
reflection->SetFloat(&debug_options, field, float_value);
return OkStatus();
return absl::OkStatus();
}
} else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_BOOL) {
bool bvalue = value == "True";
if (value == "True" || value == "False") {
reflection->SetBool(&debug_options, field, bvalue);
return OkStatus();
return absl::OkStatus();
}
}
return InvalidArgument(
Expand Down
10 changes: 5 additions & 5 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ class ScopedHoldAsExternalReference : public PjRtBuffer::ExternalReference {
external_reference_->definition_events()) {
TF_RETURN_IF_ERROR(event->WaitForEventOnExternalStream(stream));
}
return OkStatus();
return absl::OkStatus();
}

private:
Expand Down Expand Up @@ -2129,7 +2129,7 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking,
buffer_on_device_shape.element_type() == PrimitiveType::PRED &&
buffer_on_device_shape.dimensions_size() == 1 &&
buffer_on_device_shape.dimensions(0) == 0) {
return OkStatus();
return absl::OkStatus();
}
// TODO(misard) Support casting of tuple parameters.
if (strict_shape_checking || buffer_on_device_shape.IsTuple()) {
Expand Down Expand Up @@ -2163,7 +2163,7 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking,
ShapeUtil::HumanStringWithLayout(buffer_on_device_shape));
}
}
return OkStatus();
return absl::OkStatus();
}

// Makes a tuple from the arguments to an execution.
Expand Down Expand Up @@ -2344,7 +2344,7 @@ absl::Status PjRtStreamExecutorLoadedExecutable::SetUpDonation(
parameters_that_must_be_donated_.emplace_back(
std::move(parameters_to_donate));
}
return OkStatus();
return absl::OkStatus();
}

absl::string_view PjRtStreamExecutorLoadedExecutable::name() const {
Expand Down Expand Up @@ -2919,7 +2919,7 @@ static Status GetFirstInputError(
}
}
}
return OkStatus();
return absl::OkStatus();
}

absl::StatusOr<PjRtLoadedExecutable::Result>
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/pjrt/pjrt_stream_executor_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
case kUninitialized:
return InvalidArgument("Buffer has not been initialized");
case kValid:
return OkStatus();
return absl::OkStatus();
case kMoved:
return InvalidArgument("Buffer has been moved.");
case kConverted:
Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/pjrt/tracked_device_buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "xla/client/client_library.h"
#include "xla/literal_util.h"
#include "xla/shape_util.h"
Expand All @@ -41,7 +42,7 @@ absl::StatusOr<std::shared_ptr<TrackedDeviceBuffer>> MakeArray(
client->backend().transfer_manager()->GetByteSizeRequirement(
subshape)));
device_buffers.push_back(device_memory.Release());
return OkStatus();
return absl::OkStatus();
}));
return std::make_shared<TrackedDeviceBuffer>(
client->backend().memory_allocator(), /*device_ordinal=*/0,
Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/pjrt/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/blocking_counter.h"
Expand Down Expand Up @@ -694,12 +695,12 @@ static absl::Status ParseTilingSpecification(
if (ndim == 1) {
// Tiling doesn't do anything for a rank-1 array, except add padding. Since
// we're not going to touch any padding elements, we can ignore it.
return OkStatus();
return absl::OkStatus();
}
int offset = ndim;
offset -= tiling_spec.size();
absl::c_copy(tiling_spec, tiling.begin() + offset);
return OkStatus();
return absl::OkStatus();
}

// Helper function that builds a plan.
Expand Down
13 changes: 7 additions & 6 deletions third_party/xla/xla/pjrt/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
Expand Down Expand Up @@ -173,7 +174,7 @@ absl::Status ParseDeviceAssignmentCompileOptions(
*device_assignment =
std::make_shared<DeviceAssignment>(build_options->device_assignment());
}
return OkStatus();
return absl::OkStatus();
}

// Helper method that takes an ArrayAttr of DictionaryAttrs for each arg or
Expand Down Expand Up @@ -403,7 +404,7 @@ absl::Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module,
->mutable_map();
frontend_attrs["arg_layout_modes"] = GetFrontendAttr(arg_layout_modes);
frontend_attrs["out_layout_modes"] = GetFrontendAttr(out_layout_modes);
return OkStatus();
return absl::OkStatus();
}

static std::string GetFrontendAttrForMemorySpace(
Expand Down Expand Up @@ -431,7 +432,7 @@ absl::Status AddMemoryKindsToFrontendAttrs(mlir::ModuleOp module,
GetFrontendAttrForMemorySpace(arg_memory_spaces);
frontend_attrs["out_memory_spaces"] =
GetFrontendAttrForMemorySpace(out_memory_spaces);
return OkStatus();
return absl::OkStatus();
}

static absl::StatusOr<std::vector<LayoutMode>> GetLayoutModesFromFrontendAttr(
Expand Down Expand Up @@ -746,7 +747,7 @@ absl::Status DetermineArgumentLayoutsFromCompileOptions(
choose_compact_layout_for_shape_function(sharded_subshape));
*subshape->mutable_layout() = layout.layout();
}
return OkStatus();
return absl::OkStatus();
});
};
TF_ASSIGN_OR_RETURN(auto sharded_shapes,
Expand All @@ -768,7 +769,7 @@ absl::Status DetermineArgumentLayoutsFromCompileOptions(
}
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
build_options->set_result_layout(result_layout);
return OkStatus();
return absl::OkStatus();
}

absl::StatusOr<std::vector<int>> ComputeParametersThatMustBeDonated(
Expand Down Expand Up @@ -821,7 +822,7 @@ absl::StatusOr<std::vector<int>> ComputeParametersThatMustBeDonated(
}
parameters_to_donate.push_back(this_parameter);
}
return OkStatus();
return absl::OkStatus();
}));
absl::c_sort(parameters_to_donate);
return parameters_to_donate;
Expand Down
11 changes: 10 additions & 1 deletion third_party/xla/xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ cc_library(
"//xla/service:hlo_dataflow_analysis",
"//xla/service:hlo_verifier",
"//xla/service:transfer_manager",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
],
)
Expand Down Expand Up @@ -145,6 +146,7 @@ cc_library(
"//xla/service:hlo_module_config",
"//xla/service:hlo_parser",
"//xla/service:hlo_verifier",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
Expand Down Expand Up @@ -283,6 +285,7 @@ cc_library(
"//xla/service:interpreter_plugin", # reference backend
"//xla/service:platform_util",
"//xla/stream_executor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/lib/core:bitmap",
Expand All @@ -302,6 +305,7 @@ cc_library(
":filecheck",
"//xla/service:llvm_compiler",
"//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/status",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:test",
],
Expand Down Expand Up @@ -564,6 +568,7 @@ xla_test(
"//xla/service:stream_pool",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:regexp",
Expand Down Expand Up @@ -633,6 +638,7 @@ xla_test(
"//xla/client:local_client",
"//xla/client:xla_builder",
"//xla/client:xla_computation",
"@com_google_absl//absl/status",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:test",
],
Expand Down Expand Up @@ -2421,7 +2427,10 @@ xla_cc_test(
"//xla/stream_executor/cuda:cuda_platform_id",
]) + if_rocm_is_configured([
"//xla/stream_executor/rocm:rocm_platform_id",
]),
]) + [
"//xla/stream_executor/rocm:rocm_platform_id",
"@com_google_absl//absl/status",
],
)

xla_test(
Expand Down
Loading

0 comments on commit f734975

Please sign in to comment.