Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions test/cpp/test_status_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ namespace cpp_test {

// Prefix of the C++ stacktrace PyTorch adds to the error message.
constexpr inline char kTorchCppStacktracePrefix[] =
"Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:";
"Exception raised from OkOrThrow at torch_xla/csrc/status.cpp:";

constexpr inline char kNewMessage[] = "New test error message";
constexpr inline char kMessage[] = "Test error message";
Expand All @@ -100,15 +100,15 @@ inline std::string GetStatusPropagationTrace(const absl::Status& status) {
: "";
}

TEST_P(StatusTest, MaybeThrowWithOkStatus) {
TEST_P(StatusTest, OkOrThrowWithOkStatus) {
absl::Status ok_status = absl::OkStatus();
EXPECT_NO_THROW(MaybeThrow(ok_status));
EXPECT_NO_THROW(OkOrThrow(ok_status));
}

TEST_P(StatusTest, MaybeThrowWithErrorStatus) {
TEST_P(StatusTest, OkOrThrowWithErrorStatus) {
try {
absl::Status error_status = absl::InvalidArgumentError(kMessage);
MaybeThrow(error_status);
OkOrThrow(error_status);
} catch (const c10::Error& error) {
if (IsShowCppStacktracesMode()) {
EXPECT_THAT(std::string_view(error.what()),
Expand Down Expand Up @@ -343,7 +343,7 @@ TEST_P(StatusTest, MacroErrorWithLocation) {
}
}

TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) {
TEST_P(StatusTest, OkOrThrowWithErrorPropagationWithNewMessage) {
int32_t errline0 = __LINE__ + 2;
auto innerfn = [&]() -> absl::Status {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage));
Expand All @@ -362,7 +362,7 @@ TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) {
};

try {
MaybeThrow(outerfn());
OkOrThrow(outerfn());
} catch (const c10::Error& error) {
if (IsShowCppStacktracesMode()) {
// Expected Error Message Prefix
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self,
} else {
auto dst_tensor = std::move(dst_tensor_status).value();
tensor_methods::copy_(dst_tensor, self_tensor_status.value());
MaybeThrow(bridge::ReplaceXlaTensor(dst, dst_tensor));
OkOrThrow(bridge::ReplaceXlaTensor(dst, dst_tensor));
}
return dst;
}
Expand Down Expand Up @@ -3438,7 +3438,7 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self,
const at::Tensor& source) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr source_tensor = GetValueOrThrow(bridge::GetXlaTensor(source));
MaybeThrow(bridge::ReplaceXlaTensor(self, source_tensor));
OkOrThrow(bridge::ReplaceXlaTensor(self, source_tensor));
return self;
}

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/dl_convertor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
pack->external_reference =
GetValueOrThrow(pjrt_buffer->AcquireExternalReference());
xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture();
MaybeThrow(future.Await());
OkOrThrow(future.Await());
}
pack->buffer_reference = pjrt_buffer;

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ void AllReduceInPlace(const std::string& reduce_type,
replica_groups, pin_layout);
std::vector<XLATensorPtr> new_xtensors =
GetValueOrThrow(bridge::GetXlaTensors(tensors));
MaybeThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors));
OkOrThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors));
}

at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input,
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TensorSource {

virtual std::vector<int64_t> byte_strides() const {
std::vector<int64_t> byte_strides(shape().dimensions_size());
MaybeThrow(
OkOrThrow(
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
return byte_strides;
}
Expand Down
9 changes: 5 additions & 4 deletions torch_xla/csrc/status.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,16 @@ static std::string LineBreakIfCppStacktracesEnabled() {
return torch::get_cpp_stacktraces_enabled() ? "\n" : "";
}

void MaybeThrow(const absl::Status& status) {
void OkOrThrow(const absl::Status& status) {
TORCH_CHECK(status.ok(), absl::StrCat(BuildStatusErrorMessage(status),
LineBreakIfCppStacktracesEnabled()));
}

void GetValueOrThrow(const absl::Status& status) { MaybeThrow(status); }
void GetValueOrThrow(const absl::Status& status) { OkOrThrow(status); }

void OkOrDie(const absl::Status& status, const char* file, const int32_t line,
const char* function, std::string_view message) {
void status_internal::OkOrDie(const absl::Status& status, const char* file,
const int32_t line, const char* function,
std::string_view message) {
if (status.ok()) {
return;
}
Expand Down
33 changes: 17 additions & 16 deletions torch_xla/csrc/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ constexpr char kStatusPropagationTraceKey[] =
// If `FnThatReturnStatus()` returns a non-ok status, this macro will
// call `ABSL_CHECK()`, which will crash.
//
#define XLA_CHECK_OK(status, ...) \
::torch_xla::OkOrDie(::torch_xla::status_internal::GetStatus(status), \
__FILE__, __LINE__, __FUNCTION__, ##__VA_ARGS__)
#define XLA_CHECK_OK(status, ...) \
::torch_xla::status_internal::OkOrDie( \
::torch_xla::status_internal::GetStatus(status), __FILE__, __LINE__, \
__FUNCTION__, ##__VA_ARGS__)

namespace status_internal {

Expand Down Expand Up @@ -190,6 +191,14 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
int32_t line, const char* function,
std::string_view new_message = "");

// Checks that `status` is an ok status.
//
// Otherwise, it will create a new status instance with the given source
// location information, and incorporate its message (alongside the
// status propagation trace) to the crash report.
void OkOrDie(const absl::Status& status, const char* file, const int32_t line,
const char* function, std::string_view message = "");

} // namespace status_internal

// Builds the complete error message for the given `status`.
Expand All @@ -200,43 +209,35 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
// It doesn't add a trailing line break.
std::string BuildStatusErrorMessage(const absl::Status& status);

// Maybe throws an exception if `status` has a non-ok code.
// Throws an exception if `status` has a non-ok code.
//
// Ideally, this function should be used only used in the project's
// boundary, e.g. when we need to throw an exception for the user to see.
void MaybeThrow(const absl::Status& status);
void OkOrThrow(const absl::Status& status);

// Either returns the value `status` holds, if it's an ok-status, or throw an
// exception from its error status.
template <class T>
T& GetValueOrThrow(absl::StatusOr<T>& status) {
MaybeThrow(status.status());
OkOrThrow(status.status());
return status.value();
}

template <class T>
const T& GetValueOrThrow(const absl::StatusOr<T>& status) {
MaybeThrow(status.status());
OkOrThrow(status.status());
return status.value();
}

template <class T>
T GetValueOrThrow(absl::StatusOr<T>&& status) {
MaybeThrow(status.status());
OkOrThrow(status.status());
return std::move(status).value();
}

// `GetValueOrThrow` overload for `Status`.
void GetValueOrThrow(const absl::Status& status);

// Checks that `status` is an ok status.
//
// Otherwise, it will create a new status instance with the given source
// location information, and incorporate its message (alongside the
// status propagation trace) to the crash report.
void OkOrDie(const absl::Status& status, const char* file, const int32_t line,
const char* function, std::string_view message = "");

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_STATUS_H_