
Commit f996701

Update Python device API for SPMD (#5129)

* Make Python API respect the virtual device when SPMD is enabled
* Fix typo

1 parent 901d154 · commit f996701

12 files changed: +152 -36 lines changed
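With SPMD enabled, the user-facing device API now reports a single virtual device, while the renamed `*_runtime_*` helpers in `torch_xla.runtime` keep reporting the physical devices. Below is a minimal sketch of that split, assuming a PJRT runtime is configured; the values in the comments are illustrative, not guaranteed.

import os

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

os.environ["XLA_USE_SPMD"] = "1"  # enable SPMD before any XLA device is initialized

# User-facing view: a single virtual device.
print(xm.xrt_world_size())  # 1
print(xm.xla_device())      # xla:0

# Runtime view: the physical devices behind the virtual device.
print(xr.global_runtime_device_count())       # hardware dependent, e.g. 4 on one TPU v4 host
print(xr.global_runtime_device_attributes())  # one attribute dict per physical device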

test/pjrt/test_runtime_tpu.py

Lines changed: 8 additions & 8 deletions
@@ -188,22 +188,22 @@ def test_spawn_threads(self):
         {i: torch.device(f'xla:{i}') for i in range(self.num_devices)})

   @staticmethod
-  def _device_attributes():
-    return xr.device_attributes(str(xm.xla_device()))
+  def _runtime_device_attributes():
+    return xr.runtime_device_attributes(str(xm.xla_device()))

-  def test_device_attributes(self):
-    result = pjrt.run_multiprocess(self._device_attributes)
+  def test_runtime_device_attributes(self):
+    result = pjrt.run_multiprocess(self._runtime_device_attributes)
     for device in result.values():
       self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys()))
       self.assertIsInstance(device['coords'], list)
       self.assertIsInstance(device['core_on_chip'], int)

   @staticmethod
-  def _global_device_attributes():
-    return xr.global_device_attributes()
+  def _global_runtime_device_attributes():
+    return xr.global_runtime_device_attributes()

-  def test_global_device_attributes(self):
-    results = pjrt.run_multiprocess(self._global_device_attributes)
+  def test_global_runtime_device_attributes(self):
+    results = pjrt.run_multiprocess(self._global_runtime_device_attributes)
     for result in results.values():
       for device in result:
         self.assertCountEqual(['coords', 'core_on_chip'], list(device.keys()))
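On TPU, each attribute dict returned by the renamed APIs carries at least the 'coords' and 'core_on_chip' keys that the assertions above check. A short sketch of querying them directly through `torch_xla.runtime`, assuming a TPU PJRT runtime; the printed values are illustrative.

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Attributes of the runtime device backing the current process.
attrs = xr.runtime_device_attributes(str(xm.xla_device()))
print(attrs['coords'], attrs['core_on_chip'])  # e.g. [0, 0, 0] 0

# Attributes of every runtime device in the job, one dict per device.
for device_attrs in xr.global_runtime_device_attributes():
  print(device_attrs['coords'], device_attrs['core_on_chip'])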

test/spmd/test_spmd_xla_model_api.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import unittest
+import os
+import sys
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import test_xla_sharding_base
+
+
+class BasicXMAPITest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    os.environ["XLA_USE_SPMD"] = "1"
+    super().setUpClass()
+
+  def test_get_xla_supported_devices(self):
+    device_type = os.environ['PJRT_DEVICE']
+    devices = xm.get_xla_supported_devices(device_type)
+    self.assertEqual(len(devices), 1)
+
+  def test_world_size(self):
+    self.assertEqual(xm.xrt_world_size(), 1)
+
+  def test_get_ordinal(self):
+    self.assertEqual(xm.get_ordinal(), 0)
+
+  def test_get_local_ordinal(self):
+    self.assertEqual(xm.get_local_ordinal(), 0)
+
+  def test_is_master_ordinal(self):
+    self.assertTrue(xm.is_master_ordinal())
+
+  def test_xla_device(self):
+    device = xm.xla_device()
+    self.assertEqual(device, torch.device('xla:0'))
+
+  def test_xla_real_devices(self):
+    device = xm.xla_device()
+    device_type = os.environ['PJRT_DEVICE']
+    self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0'])
+
+  def test_xla_device_hw(self):
+    device = xm.xla_device()
+    device_type = os.environ['PJRT_DEVICE']
+    replication_devices = xm.xla_replication_devices([device])
+    self.assertEqual(xm.xla_device_hw(device), device_type)
+
+  def test_xla_replication_devices(self):
+    device = xm.xla_device()
+    device_type = os.environ['PJRT_DEVICE']
+    replication_devices = xm.xla_replication_devices([device])
+    self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0'])
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/spmd/test_train_spmd_imagenet.py

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ def train_imagenet():

   input_mesh = None
   if FLAGS.sharding:
-    num_devices = xr.global_device_count()
+    num_devices = xr.global_runtime_device_count()
     device_ids = np.arange(num_devices)
     # Model sharding
     if 'conv' in FLAGS.sharding:

test/spmd/test_train_spmd_linear_model.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 from torch import nn
 import torch_xla
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
 import torch_xla.debug.profiler as xp
 import torch_xla.distributed.parallel_loader as pl
 import torch_xla.experimental.xla_sharding as xs
@@ -66,7 +67,7 @@ def train():
   torch.manual_seed(42)
   model = SimpleLinear().to(device)

-  num_devices = len(xm.get_xla_supported_devices())
+  num_devices = xr.global_runtime_device_count()
   print(f'num_devices: {num_devices}')
   # Define a mesh with all devices along one axis
   mesh_shape = (num_devices, 1)

test/spmd/test_xla_distributed_checkpoint.py

Lines changed: 3 additions & 3 deletions
@@ -112,7 +112,7 @@ def test_sharded_to_unsharded(self):

   # TODO(jonbolin): Enable tests for resharding into coarser meshes
   @unittest.skip("View assignment with virtual device is not yet supported")
-  @unittest.skipIf(xr.global_device_count() == 1,
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
                    "Multiple devices needed to change mesh")
   def test_different_device_mesh(self):
     dim = self.n_devices // 2
@@ -170,7 +170,7 @@ def test_local_load_plan(self):
     # If unsharded, there should be a single ReadItem per model parameter
     self.assertEqual(parameter_count, len(plan.items))

-  @unittest.skipIf(xr.global_device_count() == 1,
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
                    "Multiple devices required to shard tensors")
   def test_resolve_and_commit_sharded_tensor(self):
     model = self._get_sharded_model()
@@ -261,7 +261,7 @@ def _write_item_assertions(plan, n_devices, parameter_count):
     parameter_count = len(list(model.parameters()))
     _write_item_assertions(plan, self.n_devices, parameter_count)

-  @unittest.skipIf(xr.global_device_count() == 1,
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
                    "Multiple devices required to shard tensors")
   def test_resolve_shard_data(self):
     model = self._get_sharded_model()

test/spmd/test_xla_sharding.py

Lines changed: 5 additions & 2 deletions
@@ -18,6 +18,9 @@
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
 import test_xla_sharding_base

+import torch_xla.core.xla_env_vars as xenv
+import torch_xla.utils.utils as xu
+

 class BasicShardingTest(test_xla_sharding_base.XlaShardingTest):

@@ -649,7 +652,7 @@ def test_2d_tensor_3d_mesh(self):

   @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU v2")
   @unittest.skipUnless(
-      xm.get_xla_supported_devices("TPU"),
+      xu.getenv_as(xenv.PJRT_DEVICE, str) == "TPU",
       f"Requires PJRT_DEVICE set to `TPU`.")
   def test_hybrid_mesh_shape(self):
     mesh = self._get_mesh((1, self.n_devices))
@@ -659,7 +662,7 @@ def test_hybrid_mesh_shape(self):
                      hybrid_mesh.get_logical_mesh().shape)

   @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU v2")
-  @patch('torch_xla.runtime.global_device_attributes')
+  @patch('torch_xla.runtime.global_runtime_device_attributes')
   @patch('torch_xla.core.xla_model.xla_device_hw')
   def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
     # mock device attributes for 2 slices of v4-8

test/spmd/test_xla_sharding_base.py

Lines changed: 5 additions & 2 deletions
@@ -5,9 +5,12 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.experimental.xla_sharding as xs
 import torch_xla.runtime as xr
+import torch_xla.core.xla_env_vars as xenv
+import torch_xla.utils.utils as xu


-@unittest.skipIf(not xr.using_pjrt() or xm.get_xla_supported_devices("GPU"),
+@unittest.skipIf(not xr.using_pjrt() or
+                 xu.getenv_as(xenv.PJRT_DEVICE, str) == "GPU",
                  f"Requires PJRT_DEVICE set to `TPU` or `CPU`.")
 class XlaShardingTest(unittest.TestCase):

@@ -29,7 +32,7 @@ def forward(self, x):

   @classmethod
   def setUpClass(cls):
-    cls.n_devices = len(xm.get_xla_supported_devices())
+    cls.n_devices = xr.global_runtime_device_count()
     cls.device_ids = np.array(range(cls.n_devices))

   def _get_mesh(self, mesh_shape, device_ids=None):

test/tpu/xla_test_job.yaml

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ spec:
           python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
           python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
           python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
+          python3 /src/pytorch/xla/test/spmd/test_spmd_xla_model_api.py
           XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
           XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
           python3 /src/pytorch/xla/test/test_autocast.py

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 40 additions & 5 deletions
@@ -836,12 +836,47 @@ void InitXlaModuleBindings(py::module m) {
       [](const at::Tensor& tensor) { return GetTensorViewAliasId(tensor); });
   m.def("_xla_get_tensor_id",
        [](const at::Tensor& tensor) { return GetTensorId(tensor); });
-  m.def("_xla_get_devices",
+  m.def("_xla_get_devices", []() {
+    if (UseVirtualDevice()) {
+      // Under SPMD context, there is only one virtual devices from user
+      // perspective.
+      std::vector<std::string> all_devices =
+          runtime::GetComputationClient()->GetAllDevices();
+      all_devices.resize(1);
+      return all_devices;
+    } else {
+      return runtime::GetComputationClient()->GetLocalDevices();
+    }
+  });
+  m.def("_xla_num_devices", []() -> int64_t {
+    if (UseVirtualDevice()) {
+      return 1;
+    } else {
+      return runtime::GetComputationClient()->GetNumDevices();
+    }
+  });
+  m.def("_xla_get_all_devices", []() {
+    std::vector<std::string> all_devices =
+        runtime::GetComputationClient()->GetAllDevices();
+    if (UseVirtualDevice()) {
+      // Under SPMD context, there is only one virtual devices from user
+      // perspective.
+      std::vector<std::string> devices = {all_devices[0]};
+      return devices;
+    } else {
+      return all_devices;
+    }
+  });
+  m.def("_xla_get_runtime_devices",
         []() { return runtime::GetComputationClient()->GetLocalDevices(); });
-  m.def("_xla_num_devices",
-        []() { return runtime::GetComputationClient()->GetNumDevices(); });
-  m.def("_xla_get_all_devices",
-        []() { return runtime::GetComputationClient()->GetAllDevices(); });
+  m.def("_xla_num_runtime_devices", []() -> int64_t {
+    return runtime::GetComputationClient()->GetNumDevices();
+  });
+  m.def("_xla_get_all_runtime_devices", []() {
+    std::vector<std::string> all_devices =
+        runtime::GetComputationClient()->GetAllDevices();
+    return all_devices;
+  });
   m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
     std::vector<std::string> xla_devices;
     {
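The bindings now expose two parallel views: `_xla_get_devices`, `_xla_num_devices`, and `_xla_get_all_devices` collapse to the single virtual device when SPMD is enabled, while the new `_xla_get_runtime_devices`, `_xla_num_runtime_devices`, and `_xla_get_all_runtime_devices` always report the physical runtime devices. A rough Python-side sketch of the difference, assuming the bindings are reached through `torch_xla._XLAC` as usual; printed values depend on the host.

import os

import torch_xla

os.environ["XLA_USE_SPMD"] = "1"  # turn on the SPMD virtual device

# User-facing bindings: collapsed to one virtual device under SPMD.
print(torch_xla._XLAC._xla_num_devices())   # 1
print(torch_xla._XLAC._xla_get_devices())   # single-element device list

# Runtime bindings: unchanged physical view.
print(torch_xla._XLAC._xla_num_runtime_devices())      # number of physical devices
print(torch_xla._XLAC._xla_get_all_runtime_devices())  # every physical device in the job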

torch_xla/experimental/pjrt.py

Lines changed: 3 additions & 2 deletions
@@ -7,9 +7,7 @@

 aliases = [
     runtime.addressable_device_count,
-    runtime.device_attributes,
     runtime.device_type,
-    runtime.global_device_attributes,
     runtime.global_device_count,
     runtime.global_ordinal,
     runtime.local_device_count,
@@ -28,6 +26,9 @@
 ]

 rendezvous = deprecated(this_module, xm.xla_rendezvous)
+device_attributes = deprecated(this_module, runtime.runtime_device_attributes)
+global_device_attributes = deprecated(this_module,
+                                      runtime.global_runtime_device_attributes)

 for alias in aliases:
   register_deprecated(this_module, alias)
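`device_attributes` and `global_device_attributes` leave the auto-registered alias list and become explicit deprecated wrappers around the renamed runtime functions. The sketch below shows roughly what such a shim amounts to; the `_deprecated` helper is a hypothetical simplification written only to illustrate the forwarding behavior, not the module's actual `deprecated` utility.

import functools
import warnings

import torch_xla.runtime as xr


def _deprecated(module_name, replacement):
  """Hypothetical shim: warn, then forward to the torch_xla.runtime function."""

  @functools.wraps(replacement)
  def wrapper(*args, **kwargs):
    warnings.warn(
        f'{module_name}.{replacement.__name__} is deprecated; '
        f'use torch_xla.runtime.{replacement.__name__} instead.',
        DeprecationWarning)
    return replacement(*args, **kwargs)

  return wrapper


# Roughly what the pjrt module now exposes.
device_attributes = _deprecated('torch_xla.experimental.pjrt',
                                xr.runtime_device_attributes)
global_device_attributes = _deprecated('torch_xla.experimental.pjrt',
                                       xr.global_runtime_device_attributes)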
