2 changes: 1 addition & 1 deletion test/eager/test_eager_all_reduce_in_place.py
@@ -12,7 +12,7 @@ def _mp_fn(index):

device = torch_xla.device()

if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
if xm.xla_device_hw(device) not in ('TPU', 'NEURON'):
return

ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device)
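For context, a minimal standalone sketch of the device-guard pattern these multiprocess tests now use (TPU/NEURON only). It relies on the torch_xla APIs shown in the hunk above and is illustrative rather than part of the change:

```python
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def _mp_fn(index):
  device = torch_xla.device()
  # Bail out early on unsupported hardware, mirroring the guard above.
  if xm.xla_device_hw(device) not in ('TPU', 'NEURON'):
    print(f'{device} is not a TPU device', file=sys.stderr)
    return
  # In-place all-reduce across replicas, as exercised by the test.
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  xm.all_reduce(xm.REDUCE_SUM, [ordinal_tensor])


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```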
2 changes: 0 additions & 2 deletions test/pjrt/test_ddp.py
@@ -33,8 +33,6 @@ def _ddp_init(index: int = ...):
def test_ddp_init(self):
pjrt.run_multiprocess(self._ddp_init)

@absltest.skipIf(xr.device_type() == 'CUDA',
"GPU device is not supported by pjrt.spawn_threads")
def test_ddp_init_threaded(self):
pjrt.spawn_threads(self._ddp_init)

30 changes: 12 additions & 18 deletions test/spmd/test_spmd_debugging.py
@@ -28,9 +28,8 @@ def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

@unittest.skipIf(
xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
@unittest.skipIf(xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_single_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
@@ -108,9 +107,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
@unittest.skipIf(xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
@@ -168,9 +166,8 @@ def test_single_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
@unittest.skipIf(xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_replicated_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{replicated}'
@@ -340,9 +337,8 @@ def test_single_host_replicated_cpu(self):
# e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
# e.g.: sharding={replicated}

@unittest.skipIf(
xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
@unittest.skipIf(xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_multi_host_tiled_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -468,9 +464,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
@unittest.skipIf(xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`.")
def test_multi_host_partial_replication_tpu(self):
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -560,9 +555,8 @@ def test_multi_host_partial_replication_tpu(self):
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
@unittest.skipIf(xr.device_type() == 'CPU',
f"Requires PJRT_DEVICE set to `TPU`.")
@unittest.skipIf(xr.global_runtime_device_count() != 8,
f"Limit test num_devices to 8 for function consistency")
def test_multi_host_replicated_tpu(self):
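As a reference for the skip guards above, a hedged sketch of how `visualize_sharding` is exercised in these tests. It assumes `visualize_sharding` accepts a sharding spec string plus a `use_color` flag, and that `rich` is available for capturing the rendered table:

```python
import rich.console
import torch_xla.runtime as xr
from torch_xla.distributed.spmd.debugging import visualize_sharding

if xr.device_type() != 'CPU':  # the tests above skip when PJRT_DEVICE is CPU
  sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'  # 2x4 tiling over 8 devices
  table = visualize_sharding(sharding, use_color=False)
  console = rich.console.Console()
  with console.capture() as capture:
    console.print(table)  # renders one cell per shard, labeled with its device
  output = capture.get()
```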
3 changes: 0 additions & 3 deletions test/spmd/test_train_spmd_linear_model.py
@@ -20,9 +20,6 @@
# the gradient checkpointing A/B test run for it.
SKIP_GRADIENT_CHECKPOINTING: bool = False

skipOnGpu = unittest.skipIf(xr.device_type() == 'CUDA',
'https://github.com/pytorch/xla/issues/9128')


@contextmanager
def extended_argv(args):
14 changes: 1 addition & 13 deletions test/spmd/test_xla_spmd_python_api_interaction.py
@@ -98,17 +98,6 @@ def test_global_runtime_device_count(self):
self.assertGreaterEqual(xr.global_runtime_device_count(), 4)
elif device_type == "CPU":
self.assertEqual(xr.global_runtime_device_count(), 1)
elif device_type == 'CUDA':
command = 'nvidia-smi --list-gpus | wc -l'
result = subprocess.run(
command,
capture_output=True,
shell=True,
check=True,
text=True,
)
expected_gpu_cnt = int(result.stdout)
self.assertEqual(xr.global_runtime_device_count(), expected_gpu_cnt)

def test_addressable_runtime_device_count(self):
device_type = os.environ['PJRT_DEVICE']
@@ -145,8 +134,7 @@ class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest):
def setUpClass(cls):
super().setUpClass()

@unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'],
f"TPU/GPU autocast test.")
@unittest.skipIf(xr.device_type() not in ('TPU',), f"TPU autocast test.")
def test_xla_autocast_api(self):
device = torch_xla.device()
t1 = torch.ones([2, 3], device=device, dtype=torch.float32)
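A minimal sketch of the autocast usage that the now TPU-only test above covers. It assumes `torch_xla.amp.autocast` runs the wrapped region in bfloat16 on TPU, which is the behaviour the test asserts:

```python
import torch
import torch_xla
from torch_xla.amp import autocast

device = torch_xla.device()
t1 = torch.ones([2, 3], device=device, dtype=torch.float32)
t2 = torch.ones([3, 2], device=device, dtype=torch.float32)
with autocast(device):
  t3 = torch.matmul(t1, t2)  # matmul is autocast-eligible, so it runs in bfloat16 on TPU
assert t3.dtype == torch.bfloat16
```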
20 changes: 0 additions & 20 deletions test/test_assume_pure_spmd.py
@@ -37,10 +37,6 @@ def setUp(self):

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding(self):
x = torch.randn((8, 4, 5, 128), device='xla')
result = assume_pure(mark_sharding)(x, self.spmd_mesh,
@@ -52,10 +48,6 @@ def test_assume_pure_works_with_mark_sharding(self):

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding_with_gradients(self):
x = torch.randn((8, 4, 5, 128)).to('xla').requires_grad_(True)
result = assume_pure(mark_sharding_with_gradients)(
@@ -71,10 +63,6 @@ def test_assume_pure_works_with_mark_sharding_with_gradients(self):

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding_nested(self):
mesh = get_2d_mesh("model", "batch")
set_global_mesh(mesh)
@@ -88,10 +76,6 @@ def test_assume_pure_works_with_mark_sharding_nested(self):

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding_with_gradients_nested(self):
mesh = get_2d_mesh("model", "batch")
set_global_mesh(mesh)
@@ -109,10 +93,6 @@ def test_assume_pure_works_with_mark_sharding_with_gradients_nested(self):

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_convert_to_jax_mesh(self):
jax_mesh = self.spmd_mesh.get_jax_mesh()
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
9 changes: 2 additions & 7 deletions test/test_fsdp_auto_wrap.py
@@ -30,10 +30,6 @@ def forward(self, x):
hidden2 = self.fc2(x)
return hidden1, hidden2

@unittest.skipIf(
xr.device_type() == 'CUDA',
"This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
)
def test(self):
dev = torch_xla.device()
input = torch.zeros([16, 16], device=dev)
@@ -49,13 +45,12 @@ def test(self):

def _mp_fn(index):
device = torch_xla.device()
if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
if xm.xla_device_hw(device) in ('TPU',):
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
else:
print(
'Default device {} is not a TPU or CUDA device'.format(device),
file=sys.stderr)
'Default device {} is not a TPU device'.format(device), file=sys.stderr)


if __name__ == '__main__':
4 changes: 2 additions & 2 deletions test/test_mp_all_gather.py
@@ -14,7 +14,7 @@ def _mp_fn(index):
device = torch_xla.device()
world_size = xr.world_size()
input_list_size = 5
if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
if xm.xla_device_hw(device) in ('TPU', 'NEURON'):
# Testing with a single replica group
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)
@@ -161,7 +161,7 @@ def _mp_fn(index):
# TODO: add test for torch.compile when support for list input is ready

else:
print(f'{device} is not a TPU or GPU device', file=sys.stderr)
print(f'{device} is not a TPU device', file=sys.stderr)


if __name__ == '__main__':
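For reference, a hedged sketch of the single-replica-group all_gather check from this file, under the new TPU/NEURON-only guard. It assumes `xm.all_gather` concatenates each replica's tensor along dim 0, which is what the assertions in the hunk above rely on:

```python
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _mp_fn(index):
  device = torch_xla.device()
  if xm.xla_device_hw(device) not in ('TPU', 'NEURON'):
    print(f'{device} is not a TPU device', file=sys.stderr)
    return
  world_size = xr.world_size()
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  result = xm.all_gather(ordinal_tensor, dim=0)  # shape: [world_size]
  expected = torch.arange(0, world_size, dtype=torch.float)
  assert torch.all(result.cpu() == expected), f'{result} != {expected}'


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```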
5 changes: 2 additions & 3 deletions test/test_mp_distributed_mm.py
@@ -9,7 +9,7 @@
def _mp_fn(index):
device = torch_xla.device()

if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
if xm.xla_device_hw(device) in ('TPU',):
world_size = xr.world_size()
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
torch.manual_seed(11)
@@ -34,8 +34,7 @@ def _mp_fn(index):
sys.exit(1)
else:
print(
'Default device {} is not a TPU or GPU device'.format(device),
file=sys.stderr)
'Default device {} is not a TPU device'.format(device), file=sys.stderr)


if __name__ == '__main__':
4 changes: 2 additions & 2 deletions test/test_mp_early_exit.py
@@ -13,7 +13,7 @@
def _mp_fn():
dist.init_process_group('xla', init_method='xla://')
device = torch_xla.device()
if xm.xla_device_hw(device) in ['TPU', 'CUDA']:
if xm.xla_device_hw(device) in ('TPU',):
train_loader = xu.SampleGenerator(
data=torch.zeros(1, 12), sample_count=1024)
train_loader = pl.MpDeviceLoader(train_loader, device)
@@ -23,7 +23,7 @@ def _mp_fn():
if step > max_steps:
break
else:
print(f'{device} is not a TPU or GPU device', file=sys.stderr)
print(f'{device} is not a TPU device', file=sys.stderr)


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion test/test_mp_reduce_scatter.py
@@ -13,7 +13,7 @@ def _mp_fn(index):
shard_size = 2
input_list_size = 5

if xm.xla_device_hw(device) in ['TPU', 'CUDA', 'CPU']:
if xm.xla_device_hw(device) in ['TPU', 'CPU']:
rand = torch.rand((32, shard_size * world_size, 32))
xrand = rand.to(device)

6 changes: 2 additions & 4 deletions test/test_torch_distributed_fsdp_frozen_weight.py
@@ -8,10 +8,8 @@

def _mp_fn(index):
dev = torch_xla.device()
if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'):
print(
'Default device {} is not a TPU or CUDA device'.format(dev),
file=sys.stderr)
if xm.xla_device_hw(dev) not in ('TPU',):
print('Default device {} is not a TPU device'.format(dev), file=sys.stderr)
return

model = nn.Linear(1024, 1024)
5 changes: 1 addition & 4 deletions test/torch_distributed/test_ddp.py
@@ -3,7 +3,6 @@
import sys
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.test.test_utils import skipIfCUDA

# Setup import folders.
xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
@@ -25,7 +24,7 @@ def _ddp_correctness(rank,
# We cannot run this guard before XMP,
# see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
device = torch_xla.device()
if xm.xla_device_hw(device) not in ('TPU', 'CUDA'):
if xm.xla_device_hw(device) not in ('TPU',):
print(
'Default device {} is not a TPU device'.format(device),
file=sys.stderr)
@@ -39,8 +38,6 @@ def _ddp_correctness(rank,
def test_ddp_correctness(self):
torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug))

# Ref: https://github.com/pytorch/xla/pull/8593
@skipIfCUDA("GPU CI is failing")
def test_ddp_correctness_with_gradient_as_bucket_view(self):
torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug, True))

@@ -10,7 +10,7 @@

def _mp_fn(index):
device = torch_xla.device()
if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
if xm.xla_device_hw(device) in ('TPU', 'NEURON'):
world_size = xr.world_size()
rank = xr.global_ordinal()

@@ -30,8 +30,7 @@ def _mp_fn(index):
assert torch.all(xoutput0.cpu() == expected0), f'{xoutput0} != {expected0}'
else:
print(
'Default device {} is not a TPU or GPU device'.format(device),
file=sys.stderr)
'Default device {} is not a TPU device'.format(device), file=sys.stderr)


if __name__ == '__main__':
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = torch_xla.device()
if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
if xm.xla_device_hw(device) in ('TPU', 'NEURON'):
world_size = xr.world_size()
dist.init_process_group('xla', init_method='xla://')
# note that we can't use torch.tensor(torch.distributed.get_rank()) directly
@@ -25,8 +25,7 @@ def _mp_fn(index):
xla_rank_tensor.cpu() == expected), f'{xla_rank_tensor} != {expected}'
else:
print(
'Default device {} is not a TPU or GPU device'.format(device),
file=sys.stderr)
'Default device {} is not a TPU device'.format(device), file=sys.stderr)


if __name__ == '__main__':
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = torch_xla.device()
if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
if xm.xla_device_hw(device) in ('TPU', 'NEURON'):
world_size = xr.world_size()
rank = xr.global_ordinal()

@@ -35,8 +35,7 @@ def _mp_fn(index):
scale)) == torch.tensor(True)
else:
print(
'Default device {} is not a TPU or GPU device'.format(device),
file=sys.stderr)
'Default device {} is not a TPU device'.format(device), file=sys.stderr)


if __name__ == '__main__':
@@ -10,7 +10,7 @@

def _mp_fn(index):
device = torch_xla.device()
if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
if xm.xla_device_hw(device) in ('TPU', 'NEURON'):
world_size = xr.world_size()
rank = xr.global_ordinal()

@@ -31,8 +31,7 @@ def _mp_fn(index):
xinputs.cpu() == expected), f'trial {i}, {xinputs} != {expected}'
else:
print(
'Default device {} is not a TPU or GPU device'.format(device),
file=sys.stderr)
'Default device {} is not a TPU device'.format(device), file=sys.stderr)


if __name__ == '__main__':