Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions third_party/xla_client/xrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -993,17 +993,14 @@ void XrtComputationClient::ReleaseHandle(int64 handle, const string& device,
triggered_task_->Activate();
}

void XrtComputationClient::ReleaseXrtData(XrtData* xrt_data) {
ReleaseHandle(xrt_data->get_handle(), xrt_data->device(),
&released_data_handles_);
void XrtComputationClient::ReleaseXrtData(const string& device, int64 handle) {
ReleaseHandle(handle, device, &released_data_handles_);
ReleaseDataHandlesCounter()->AddValue(1);
}

void XrtComputationClient::ReleaseXrtComputation(
XrtComputation* xrt_computation) {
ReleaseHandle(xrt_computation->get_handle(),
xrt_computation->compilation_device,
&released_compile_handles_);
const string& compilation_device, int64 handle) {
ReleaseHandle(handle, compilation_device, &released_compile_handles_);
ReleaseCompileHandlesCounter()->AddValue(1);
}

Expand Down
37 changes: 16 additions & 21 deletions third_party/xla_client/xrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ class XrtComputationClient : public ComputationClient {
};

struct XrtHandle {
XrtHandle(XrtComputationClient* self, int64 handle)
: self(self), handle(handle) {}
XrtHandle(int64 handle, std::function<void()> releaser)
: handle(handle), releaser(std::move(releaser)) {}

~XrtHandle() { releaser(); }

XrtComputationClient* self;
int64 handle;
std::function<void()> releaser;
};

using XrtHandlePtr = std::shared_ptr<XrtHandle>;
Expand All @@ -55,13 +57,10 @@ class XrtComputationClient : public ComputationClient {
XrtData(XrtComputationClient* self, string device, Shape device_shape,
int64 handle)
: Data(std::move(device), std::move(device_shape)),
handle_ptr(std::make_shared<XrtHandle>(self, handle)) {}

~XrtData() override {
if (handle_ptr != nullptr && handle_ptr.use_count() == 1) {
handle_ptr->self->ReleaseXrtData(this);
}
}
handle_ptr(std::make_shared<XrtHandle>(
handle, [self, device = this->device(), handle]() {
self->ReleaseXrtData(device, handle);
})) {}

int64 get_handle() const { return handle_ptr->handle; }

Expand All @@ -78,19 +77,15 @@ class XrtComputationClient : public ComputationClient {
int64 handle, string compilation_device)
: Computation(std::move(computation), std::move(program_shape),
std::move(devices)),
handle_ptr(std::make_shared<XrtHandle>(self, handle)),
compilation_device(std::move(compilation_device)) {}

~XrtComputation() override {
if (handle_ptr.use_count() == 1) {
handle_ptr->self->ReleaseXrtComputation(this);
}
}
handle_ptr(std::make_shared<XrtHandle>(
handle, [self, compilation_device = std::move(compilation_device),
handle]() {
self->ReleaseXrtComputation(compilation_device, handle);
})) {}

int64 get_handle() const { return handle_ptr->handle; }

XrtHandlePtr handle_ptr;
string compilation_device;
};

public:
Expand Down Expand Up @@ -283,9 +278,9 @@ class XrtComputationClient : public ComputationClient {
void ReleaseHandle(int64 handle, const string& device,
std::vector<DeviceHandle>* handles);

void ReleaseXrtData(XrtData* xrt_data);
void ReleaseXrtData(const string& device, int64 handle);

void ReleaseXrtComputation(XrtComputation* xrt_computation);
void ReleaseXrtComputation(const string& compilation_device, int64 handle);

// Starts the handle releaser thread (which runs the HandleReleaser() API).
void StartHandleReleaser();
Expand Down