@@ -836,12 +836,47 @@ void InitXlaModuleBindings(py::module m) {
836836 [](const at::Tensor& tensor) { return GetTensorViewAliasId (tensor); });
837837 m.def (" _xla_get_tensor_id" ,
838838 [](const at::Tensor& tensor) { return GetTensorId (tensor); });
839- m.def (" _xla_get_devices" ,
839+ m.def (" _xla_get_devices" , []() {
840+ if (UseVirtualDevice ()) {
841+ // Under SPMD context, there is only one virtual devices from user
842+ // perspective.
843+ std::vector<std::string> all_devices =
844+ runtime::GetComputationClient ()->GetAllDevices ();
845+ all_devices.resize (1 );
846+ return all_devices;
847+ } else {
848+ return runtime::GetComputationClient ()->GetLocalDevices ();
849+ }
850+ });
851+ m.def (" _xla_num_devices" , []() -> int64_t {
852+ if (UseVirtualDevice ()) {
853+ return 1 ;
854+ } else {
855+ return runtime::GetComputationClient ()->GetNumDevices ();
856+ }
857+ });
858+ m.def (" _xla_get_all_devices" , []() {
859+ std::vector<std::string> all_devices =
860+ runtime::GetComputationClient ()->GetAllDevices ();
861+ if (UseVirtualDevice ()) {
862+ // Under SPMD context, there is only one virtual devices from user
863+ // perspective.
864+ std::vector<std::string> devices = {all_devices[0 ]};
865+ return devices;
866+ } else {
867+ return all_devices;
868+ }
869+ });
870+ m.def (" _xla_get_runtime_devices" ,
840871 []() { return runtime::GetComputationClient ()->GetLocalDevices (); });
841- m.def (" _xla_num_devices" ,
842- []() { return runtime::GetComputationClient ()->GetNumDevices (); });
843- m.def (" _xla_get_all_devices" ,
844- []() { return runtime::GetComputationClient ()->GetAllDevices (); });
872+ m.def (" _xla_num_runtime_devices" , []() -> int64_t {
873+ return runtime::GetComputationClient ()->GetNumDevices ();
874+ });
875+ m.def (" _xla_get_all_runtime_devices" , []() {
876+ std::vector<std::string> all_devices =
877+ runtime::GetComputationClient ()->GetAllDevices ();
878+ return all_devices;
879+ });
845880 m.def (" _xla_real_devices" , [](const std::vector<std::string>& devices) {
846881 std::vector<std::string> xla_devices;
847882 {
0 commit comments