Skip to content

Commit

Permalink
Use absl::Status rather than tsl::Status
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622291171
  • Loading branch information
klucke authored and tensorflower-gardener committed Apr 5, 2024
1 parent 5b1f2fb commit 9989611
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 25 deletions.
24 changes: 12 additions & 12 deletions third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc
Expand Up @@ -277,8 +277,8 @@ const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain) {
return "";
}

tsl::Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid,
const void* cbdata) {
absl::Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid,
const void* cbdata) {
/* Some APIs such as hipMalloc, implicitly work on th devices set by the
user using APIs such as hipSetDevice. API callbacks and activity records
for functions like hipMalloc does not return the device id (CUDA does). To
Expand Down Expand Up @@ -838,8 +838,8 @@ void RocmApiCallbackImpl::AddSynchronizeEventUponApiExit(
collector_->AddEvent(std::move(event), is_auxiliary);
}

tsl::Status RocmActivityCallbackImpl::operator()(const char* begin,
const char* end) {
absl::Status RocmActivityCallbackImpl::operator()(const char* begin,
const char* end) {
// we do not dump activities in this set in logger

static std::set<activity_op_t> dump_excluded_activities = {
Expand Down Expand Up @@ -1359,14 +1359,14 @@ void ApiCallback(uint32_t domain, uint32_t cbid, const void* cbdata,
tracer->ApiCallbackHandler(domain, cbid, cbdata).IgnoreError();
}

tsl::Status RocmTracer::ApiCallbackHandler(uint32_t domain, uint32_t cbid,
const void* cbdata) {
absl::Status RocmTracer::ApiCallbackHandler(uint32_t domain, uint32_t cbid,
const void* cbdata) {
if (api_tracing_enabled_)
TF_RETURN_IF_ERROR((*api_cb_impl_)(domain, cbid, cbdata));
return tsl::OkStatus();
}

tsl::Status RocmTracer::EnableApiTracing() {
absl::Status RocmTracer::EnableApiTracing() {
if (api_tracing_enabled_) return tsl::OkStatus();
api_tracing_enabled_ = true;

Expand All @@ -1392,7 +1392,7 @@ tsl::Status RocmTracer::EnableApiTracing() {
return tsl::OkStatus();
}

tsl::Status RocmTracer::DisableApiTracing() {
absl::Status RocmTracer::DisableApiTracing() {
if (!api_tracing_enabled_) return tsl::OkStatus();
api_tracing_enabled_ = false;

Expand Down Expand Up @@ -1423,8 +1423,8 @@ void ActivityCallback(const char* begin, const char* end, void* user_data) {
tracer->ActivityCallbackHandler(begin, end).IgnoreError();
}

tsl::Status RocmTracer::ActivityCallbackHandler(const char* begin,
const char* end) {
absl::Status RocmTracer::ActivityCallbackHandler(const char* begin,
const char* end) {
if (activity_tracing_enabled_) {
TF_RETURN_IF_ERROR((*activity_cb_impl_)(begin, end));
} else {
Expand Down Expand Up @@ -1452,7 +1452,7 @@ tsl::Status RocmTracer::ActivityCallbackHandler(const char* begin,
return tsl::OkStatus();
}

tsl::Status RocmTracer::EnableActivityTracing() {
absl::Status RocmTracer::EnableActivityTracing() {
if (activity_tracing_enabled_) return tsl::OkStatus();
activity_tracing_enabled_ = true;

Expand Down Expand Up @@ -1493,7 +1493,7 @@ tsl::Status RocmTracer::EnableActivityTracing() {
return tsl::OkStatus();
}

tsl::Status RocmTracer::DisableActivityTracing() {
absl::Status RocmTracer::DisableActivityTracing() {
if (!activity_tracing_enabled_) return tsl::OkStatus();

for (auto& iter : options_->activity_tracing) {
Expand Down
18 changes: 9 additions & 9 deletions third_party/xla/xla/backends/profiler/gpu/rocm_tracer.h
Expand Up @@ -59,7 +59,7 @@ class RocmApiCallbackImpl {
RocmTraceCollector* collector)
: options_(options), tracer_(tracer), collector_(collector) {}

tsl::Status operator()(uint32_t domain, uint32_t cbid, const void* cbdata);
absl::Status operator()(uint32_t domain, uint32_t cbid, const void* cbdata);

private:
void AddKernelEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
Expand Down Expand Up @@ -97,7 +97,7 @@ class RocmActivityCallbackImpl {
RocmTraceCollector* collector)
: options_(options), tracer_(tracer), collector_(collector) {}

tsl::Status operator()(const char* begin, const char* end);
absl::Status operator()(const char* begin, const char* end);

private:
void AddHipKernelActivityEvent(const roctracer_record_t* record);
Expand Down Expand Up @@ -127,9 +127,9 @@ class RocmTracer {
void Enable(const RocmTracerOptions& options, RocmTraceCollector* collector);
void Disable();

tsl::Status ApiCallbackHandler(uint32_t domain, uint32_t cbid,
const void* cbdata);
tsl::Status ActivityCallbackHandler(const char* begin, const char* end);
absl::Status ApiCallbackHandler(uint32_t domain, uint32_t cbid,
const void* cbdata);
absl::Status ActivityCallbackHandler(const char* begin, const char* end);

static uint64_t GetTimestamp();
static int NumGpus();
Expand All @@ -153,11 +153,11 @@ class RocmTracer {
explicit RocmTracer() : num_gpus_(NumGpus()) {}

private:
tsl::Status EnableApiTracing();
tsl::Status DisableApiTracing();
absl::Status EnableApiTracing();
absl::Status DisableApiTracing();

tsl::Status EnableActivityTracing();
tsl::Status DisableActivityTracing();
absl::Status EnableActivityTracing();
absl::Status DisableActivityTracing();

int num_gpus_;
std::optional<RocmTracerOptions> options_;
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc
Expand Up @@ -61,7 +61,7 @@ class ProfilerStatusHelper {
stream_executor::tpu::ProfilerApiFn()->TpuStatus_FreeFn(c_status);
}

static tsl::Status FromC( // TENSORFLOW_STATUS_OK
static absl::Status FromC( // TENSORFLOW_STATUS_OK
TF_Status* const c_status) {
if (stream_executor::tpu::ProfilerApiFn()->TpuStatus_CodeFn(c_status) ==
TSL_OK) {
Expand All @@ -80,7 +80,7 @@ class ProfilerStatusHelper {
TSL_OK;
}

tsl::Status status() const { // TENSORFLOW_STATUS_OK
absl::Status status() const { // TENSORFLOW_STATUS_OK
return FromC(c_status);
}

Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/stream_executor/tpu/tsl_status_helper.h
Expand Up @@ -28,7 +28,8 @@ class TslStatusHelper {

~TslStatusHelper() { TSL_DeleteStatus(c_status); }

static tsl::Status FromC(TF_Status* const c_status) { // TENSORFLOW_STATUS_OK
static absl::Status FromC(
TF_Status* const c_status) { // TENSORFLOW_STATUS_OK
absl::StatusCode code = tsl::StatusCodeFromTSLCode(TSL_GetCode(c_status));
if (code == absl::StatusCode::kOk) {
return tsl::OkStatus();
Expand All @@ -41,7 +42,7 @@ class TslStatusHelper {
absl::StatusCode::kOk;
}

tsl::Status status() const { // TENSORFLOW_STATUS_OK
absl::Status status() const { // TENSORFLOW_STATUS_OK
return FromC(c_status);
}

Expand Down

0 comments on commit 9989611

Please sign in to comment.