Skip to content

Commit

Permalink
[ifrt] Change Client::GetTopologyForDevices to take a DeviceList inst…
Browse files Browse the repository at this point in the history
…ead of a Span of device pointers. This would enable client implementations to make use of the cached hash of the DeviceList.

PiperOrigin-RevId: 623324679
  • Loading branch information
cezheng authored and tensorflower-gardener committed Apr 10, 2024
1 parent 90f3897 commit 9c0eafe
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/ifrt/client.h
Expand Up @@ -158,7 +158,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {

// Returns a topology description for that covers the provided devices.
virtual absl::StatusOr<std::shared_ptr<const xla::PjRtTopologyDescription>>
GetTopologyForDevices(absl::Span<Device* const> devices) const = 0;
GetTopologyForDevices(const DeviceList& devices) const = 0;

// Returns the default layout on `device` for a buffer with `dtype` and
// single-shard dimensions `dims`.
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/ifrt/mock.cc
Expand Up @@ -175,7 +175,7 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
return delegated_->GetDefaultCompiler();
});
ON_CALL(*this, GetTopologyForDevices)
.WillByDefault([this](absl::Span<xla::ifrt::Device* const> devices) {
.WillByDefault([this](const xla::ifrt::DeviceList& devices) {
return delegated_->GetTopologyForDevices(devices);
});
ON_CALL(*this, GetDefaultLayoutForDevice)
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/ifrt/mock.h
Expand Up @@ -135,7 +135,7 @@ class MockClient final : public llvm::RTTIExtends<MockClient, Client> {
MOCK_METHOD(Compiler*, GetDefaultCompiler, (), (final));
MOCK_METHOD(
absl::StatusOr<std::shared_ptr<const xla::PjRtTopologyDescription>>,
GetTopologyForDevices, (absl::Span<xla::ifrt::Device* const> devices),
GetTopologyForDevices, (const xla::ifrt::DeviceList& devices),
(const, final));
MOCK_METHOD(absl::StatusOr<std::unique_ptr<xla::PjRtLayout>>,
GetDefaultLayoutForDevice,
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/python/ifrt_proxy/client/client.h
Expand Up @@ -112,8 +112,7 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
return &default_compiler_;
}
absl::StatusOr<std::shared_ptr<const xla::PjRtTopologyDescription>>
GetTopologyForDevices(
absl::Span<xla::ifrt::Device* const> devices) const override {
GetTopologyForDevices(const xla::ifrt::DeviceList& devices) const override {
return absl::UnimplementedError(
"GetTopologyForDevices is not supported for the IFRT proxy client.");
}
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc
Expand Up @@ -227,7 +227,7 @@ absl::StatusOr<tsl::RCReference<Tuple>> PjRtClient::MakeTuple(
}

absl::StatusOr<std::shared_ptr<const xla::PjRtTopologyDescription>>
PjRtClient::GetTopologyForDevices(absl::Span<Device* const> devices) const {
PjRtClient::GetTopologyForDevices(const xla::ifrt::DeviceList& devices) const {
// TODO(parkers): Consider constructing a sub-slice topology based on the
// provided devices.
TF_ASSIGN_OR_RETURN(auto topology, pjrt_client_->GetTopologyDescription());
Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/pjrt_ifrt/pjrt_compiler.h"
#include "xla/xla_data.pb.h"
#include "tsl/concurrency/ref_count.h"
Expand Down Expand Up @@ -154,7 +155,7 @@ class PjRtClient final
}

absl::StatusOr<std::shared_ptr<const xla::PjRtTopologyDescription>>
GetTopologyForDevices(absl::Span<Device* const> devices) const override;
GetTopologyForDevices(const DeviceList& devices) const override;

absl::StatusOr<std::unique_ptr<xla::PjRtLayout>> GetDefaultLayoutForDevice(
DType dtype, absl::Span<const int64_t> dims,
Expand Down
3 changes: 1 addition & 2 deletions third_party/xla/xla/python/py_compile_only_client.cc
Expand Up @@ -224,8 +224,7 @@ class CompileOnlyIfRtClient final
const PjRtTopologyDescription& topology() const { return *topology_; }

absl::StatusOr<std::shared_ptr<const xla::PjRtTopologyDescription>>
GetTopologyForDevices(
absl::Span<ifrt::Device* const> devices) const override {
GetTopologyForDevices(const xla::ifrt::DeviceList& devices) const override {
return topology_;
}

Expand Down
7 changes: 4 additions & 3 deletions third_party/xla/xla/python/xla.cc
Expand Up @@ -463,7 +463,7 @@ NB_MODULE(xla_extension, m_nb) {
"get_topology_for_devices requires >= 1 devices.");
}
auto client = py_devices[0]->client();
std::vector<PjRtDevice*> ifrt_devices;
ifrt::DeviceList::Devices ifrt_devices;
ifrt_devices.reserve(py_devices.size());
for (const auto& py_device : py_devices) {
if (py_device->client().get() != client.get()) {
Expand All @@ -473,8 +473,9 @@ NB_MODULE(xla_extension, m_nb) {
}
ifrt_devices.push_back(py_device->device());
}
return xla::ValueOrThrow(client->ifrt_client()->GetTopologyForDevices(
absl::MakeSpan(ifrt_devices)));
ifrt::DeviceList device_list(std::move(ifrt_devices));
return xla::ValueOrThrow(
client->ifrt_client()->GetTopologyForDevices(device_list));
});

TF_CHECK_OK(PyArray::RegisterTypes(m_nb));
Expand Down

0 comments on commit 9c0eafe

Please sign in to comment.