diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
index d59073df140f..1f5025239b32 100755
--- a/test/neuron/run_tests.sh
+++ b/test/neuron/run_tests.sh
@@ -160,9 +160,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/dynamo/test_graph_input_matcher.py"
   run_test "$CDIR/dynamo/test_dynamo_config.py"
   run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
-  #run_test "$CDIR/test_data_type.py"
-  run_use_bf16 "$CDIR/test_data_type.py"
-  run_downcast_bf16 "$CDIR/test_data_type.py"
+  run_test "$CDIR/test_data_type.py"
   #run_test "$CDIR/test_fp8.py"
   run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
   run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
diff --git a/test/neuron/test_neuron_data_types.py b/test/neuron/test_neuron_data_types.py
index 1687e712ebc7..ae9b3db2e3a4 100644
--- a/test/neuron/test_neuron_data_types.py
+++ b/test/neuron/test_neuron_data_types.py
@@ -27,10 +27,10 @@ def test_datatypes(self):
                   (torch.double, "f32", torch.floor_divide),
                   (torch.int16, "s32", torch.add),
                   (torch.int32, "s32", torch.add),
-                  (torch.int64, "s32", torch.add),
+                  (torch.int64, "s64", torch.add),
                   (torch.uint16, "u32", torch.add),
                   (torch.uint32, "u32", torch.add),
-                  (torch.uint64, "u32", torch.add)]
+                  (torch.uint64, "u64", torch.add)]
 
     for dtype, op_xla_dtype, op in test_cases:
       with self.subTest(dtype=dtype, op_xla_dtype=op_xla_dtype, op=op):
diff --git a/test/run_tests.sh b/test/run_tests.sh
index da0ebc15a06c..ba82255cb045 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -178,8 +178,6 @@ function run_xla_op_tests1 {
   run_test "$CDIR/dynamo/test_dynamo_config.py"
   run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
   run_test "$CDIR/test_data_type.py"
-  run_use_bf16 "$CDIR/test_data_type.py"
-  run_downcast_bf16 "$CDIR/test_data_type.py"
   run_test "$CDIR/test_fp8.py"
   run_xla_ir_debug run_test "$CDIR/test_env_var_mapper.py"
   run_xla_hlo_debug run_test "$CDIR/test_env_var_mapper.py"
diff --git a/test/test_data_type.py b/test/test_data_type.py
index 8e06e15b40a5..556f35d4a599 100644
--- a/test/test_data_type.py
+++ b/test/test_data_type.py
@@ -1,73 +1,82 @@
 import os
+import sys
+import unittest
 
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.utils.utils as xu
-import unittest
 
 
-def check_env_flag(name, default=''):
-  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
+class XlaDataTypeTest(unittest.TestCase):
 
+  def setUp(cls):
+    cls.original_env = {
+        'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'),
+        'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'),
+        'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG')
+    }
 
-class XlaDataTypeTest(unittest.TestCase):
+  def tearDown(self):
+    for key, value in self.original_env.items():
+      if value is None:
+        os.environ.pop(key, None)
+      else:
+        os.environ[key] = value
 
-  def test_datatype_f32(self):
-    t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
-    t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
-    t3 = torch.div(t1, t2, rounding_mode='floor')
-    assert t3.dtype == torch.float
+  def _set_env(self, **kwargs):
+    for key, value in kwargs.items():
+      os.environ[key] = value
 
-    hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
-    device_data_hlo = hlo_text.split('\n')[1]
-    assert 'xla::device_data' in device_data_hlo, device_data_hlo
-    if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
-      assert 'bf16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_USE_FP16') or check_env_flag(
-        'XLA_DOWNCAST_FP16'):
-      assert 'f16' in device_data_hlo, device_data_hlo
-    else:
-      assert 'f32' in device_data_hlo, device_data_hlo
-
-  def test_datatype_f64(self):
-    t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
-    t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
-    t3 = torch.div(t1, t2, rounding_mode='floor')
-    assert t3.dtype == torch.double
+  def _test_datatype(self, dtype, expected_type, op):
+    t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
+    t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
+    t3 = op(t1, t2)
+    self.assertEqual(t3.dtype, dtype)
 
     hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
-    device_data_hlo = hlo_text.split('\n')[1]
-    assert 'xla::device_data' in device_data_hlo, device_data_hlo
-    if check_env_flag('XLA_USE_BF16'):
-      assert 'bf16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_USE_FP16'):
-      assert 'f16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
-        'XLA_DOWNCAST_FP16'):
-      assert 'f32' in device_data_hlo, device_data_hlo
-    else:
-      assert 'f64' in device_data_hlo, device_data_hlo
+    device_data_hlo = hlo_text.split('\n')[2]
+    self.assertIn('xla::device_data', device_data_hlo)
+    self.assertIn(expected_type, device_data_hlo)
+
+  def test_datatype_use_bf16(self):
+    self._set_env(XLA_USE_BF16='1')
+    self._test_datatype(torch.double, 'bf16', torch.floor_divide)
+    self._test_datatype(torch.float, 'bf16', torch.floor_divide)
+
+  def test_datatype_downcast_bf16(self):
+    self._set_env(XLA_DOWNCAST_BF16='1')
+    self._test_datatype(torch.double, 'bf16', torch.floor_divide)
+    self._test_datatype(torch.float, 'bf16', torch.floor_divide)
+
+  def test_datatype_use_32bit_long(self):
+    self._set_env(XLA_USE_32BIT_LONG='1')
+    self._test_datatype(torch.int64, 's32', torch.add)
+    self._test_datatype(torch.uint64, 'u32', torch.add)
 
   def test_module_to_dtype(self):
     device = torch_xla.device()
     linear = torch.nn.Linear(
         5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
-    input = torch.randn(
-        10,
-        5,
-    ).to(device).to(torch.bfloat16)
+    input = torch.randn(10, 5).to(device).to(torch.bfloat16)
     xm.mark_step()
     res = linear(input)
 
     hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
     res_hlo = hlo_text.split('\n')[-3]
-    assert 'bf16' in res_hlo, res_hlo
+    self.assertIn('bf16', res_hlo)
 
     linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
                                                               ]).split('\n')[-3]
-    assert 'bf16' in linear_weight_hlo, linear_weight_hlo
+    self.assertIn('bf16', linear_weight_hlo)
 
 
 if __name__ == '__main__':
-  test = unittest.main()
-  sys.exit(0 if test.result.wasSuccessful() else 1)
+  suite = unittest.TestSuite()
+  suite.addTest(XlaDataTypeTest("test_datatype_use_bf16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long"))
+  suite.addTest(XlaDataTypeTest("test_module_to_dtype"))
+  runner = unittest.TextTestRunner(failfast=True)
+  result = runner.run(suite)
+  sys.exit(0 if result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 99f8f6aa628b..76959214300a 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -46,6 +46,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py"
 python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
 python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
 run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"
+python3 "$TEST_CDIR/test_data_type.py"
 
 # run examples, each test should takes <2 minutes
 python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index e214e7a47a77..a600baab999b 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -178,9 +178,7 @@ def _setup_tpu_vm_library_path() -> bool:
 
 
 def _check_deprecated_env_var():
-  deprecated_env_vars = [
-      'XLA_USE_FP16', 'XLA_DOWNCAST_FP16', 'XLA_USE_32BIT_LONG'
-  ]
+  deprecated_env_vars = ['XLA_USE_FP16', 'XLA_DOWNCAST_FP16']
   for env_var in deprecated_env_vars:
     if os.environ.get(env_var):
       warnings.warn(f"The environment variable '{env_var}' is deprecated "
diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp
index 759c045f8f20..484151c6d5ed 100644
--- a/torch_xla/csrc/dtype.cpp
+++ b/torch_xla/csrc/dtype.cpp
@@ -30,6 +30,17 @@ bool ShouldDowncastToBF16() {
   return downcast_bf16;
 }
 
+bool ShouldUse32BitLong() {
+  bool use_32bit_long =
+      runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false);
+  if (use_32bit_long) {
+    std::cout
+        << "XLA_USE_32BIT_LONG will be deprecated after the 2.6 release\n";
+    TF_LOG(INFO) << "Using 32bit integers for kLong values";
+  }
+  return use_32bit_long;
+}
+
 bool UseBF16() {
   static bool use_bf16 = ShouldUseBF16();
   return use_bf16;
@@ -40,6 +51,11 @@ bool DowncastBF16() {
   return downcast_bf16;
 }
 
+bool Use32BitLong() {
+  static bool use_32bit_long = ShouldUse32BitLong();
+  return use_32bit_long;
+}
+
 }  // namespace
 
 at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -143,11 +159,9 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
       return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
                                         : xla::PrimitiveType::S16;
     case xla::PrimitiveType::S64:
-      return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
-                                        : xla::PrimitiveType::S64;
+      return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
    case xla::PrimitiveType::U64:
-      return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
-                                        : xla::PrimitiveType::U64;
+      return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
    case xla::PrimitiveType::C128:
      return xla::PrimitiveType::C128;
    default: