diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index 119bfb8d05..835bd742de 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -51,7 +51,11 @@
     torch_version_at_least,
 )
 
-_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+_DEVICE = torch.device(
+    torch.accelerator.current_accelerator().type
+    if torch.accelerator.is_available()
+    else "cpu"
+)
 
 
 class M(nn.Module):
diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py
index ab25a38bb3..bfe0a9457f 100644
--- a/test/prototype/test_quantized_training.py
+++ b/test/prototype/test_quantized_training.py
@@ -34,11 +34,17 @@
     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
+from torchao.utils import get_current_accelerator_device
 
 if common_utils.SEED is None:
     common_utils.SEED = 1234
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICES = (
+    ["cpu"]
+    + (["cuda"] if torch.cuda.is_available() else [])
+    + (["xpu"] if torch.xpu.is_available() else [])
+)
+_DEVICE = get_current_accelerator_device()
 
 
 def _reset():
@@ -182,12 +188,14 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = "cuda"
+        device = _DEVICE
 
         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
@@ -221,7 +229,9 @@ def snr(ref, actual):
 
     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
    def test_bitnet_training(self, compile):
        # reference implementation
        # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -246,7 +256,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = "cuda"
+        device = _DEVICE
 
         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -342,7 +352,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).cuda()
+        base_model = Transformer(model_args).to(_DEVICE)
         fsdp_model = copy.deepcopy(base_model)
 
         quantize_(base_model.layers, quantize_fn)
@@ -362,7 +372,7 @@ def _run_subtest(self, args):
 
         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -387,14 +397,18 @@ def _run_subtest(self, args):
     )
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )
 
-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(
+            _DEVICE
+        )
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)
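
Note: test_quantized_training.py now imports get_current_accelerator_device from torchao.utils, but its definition is not part of this diff. A minimal sketch of what such a helper might look like, assuming it mirrors the torch.accelerator-based selection introduced in test_parq.py above (the actual torchao.utils implementation may differ):

    import torch

    def get_current_accelerator_device() -> torch.device:
        # torch.accelerator (PyTorch >= 2.6) abstracts over accelerator
        # backends such as cuda and xpu; fall back to cpu when none exists.
        if torch.accelerator.is_available():
            return torch.device(torch.accelerator.current_accelerator().type)
        return torch.device("cpu")

Under this assumption, _DEVICE resolves to cuda, xpu, or cpu depending on the host, which is what lets the tests above replace their hard-coded "cuda" strings and CUDA-only skip conditions.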