diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 983f701849..5a6e89c25f 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_quant(up_quant(example_input)) mesh = self.build_device_mesh() - mesh.device_type = "cuda" + mesh._device_type = "cuda" # Shard the models up_dist = self.colwise_shard(up_quant, mesh)