-
Notifications
You must be signed in to change notification settings - Fork 566
Re-introduce "XLA_USE_32BIT_LONG" flag #8571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
jeffhataws
merged 3 commits into
pytorch:master
from
rpsilva-aws:rpsilva_use_int32_flag
Jan 16, 2025
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,82 @@ | ||
import os | ||
import sys | ||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.utils.utils as xu | ||
import unittest | ||
|
||
|
||
def check_env_flag(name, default=''): | ||
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] | ||
class XlaDataTypeTest(unittest.TestCase): | ||
|
||
def setUp(cls): | ||
cls.original_env = { | ||
'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'), | ||
'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'), | ||
'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG') | ||
} | ||
|
||
class XlaDataTypeTest(unittest.TestCase): | ||
def tearDown(self): | ||
for key, value in self.original_env.items(): | ||
if value is None: | ||
os.environ.pop(key, None) | ||
else: | ||
os.environ[key] = value | ||
|
||
def test_datatype_f32(self): | ||
t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device()) | ||
t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device()) | ||
t3 = torch.div(t1, t2, rounding_mode='floor') | ||
assert t3.dtype == torch.float | ||
def _set_env(self, **kwargs): | ||
for key, value in kwargs.items(): | ||
os.environ[key] = value | ||
|
||
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3]) | ||
device_data_hlo = hlo_text.split('\n')[1] | ||
assert 'xla::device_data' in device_data_hlo, device_data_hlo | ||
if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'): | ||
assert 'bf16' in device_data_hlo, device_data_hlo | ||
elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'): | ||
assert 'f16' in device_data_hlo, device_data_hlo | ||
else: | ||
assert 'f32' in device_data_hlo, device_data_hlo | ||
|
||
def test_datatype_f64(self): | ||
t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device()) | ||
t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device()) | ||
t3 = torch.div(t1, t2, rounding_mode='floor') | ||
assert t3.dtype == torch.double | ||
def _test_datatype(self, dtype, expected_type, op): | ||
t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) | ||
t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) | ||
t3 = op(t1, t2) | ||
self.assertEqual(t3.dtype, dtype) | ||
|
||
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3]) | ||
device_data_hlo = hlo_text.split('\n')[1] | ||
assert 'xla::device_data' in device_data_hlo, device_data_hlo | ||
if check_env_flag('XLA_USE_BF16'): | ||
assert 'bf16' in device_data_hlo, device_data_hlo | ||
elif check_env_flag('XLA_USE_FP16'): | ||
assert 'f16' in device_data_hlo, device_data_hlo | ||
elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag( | ||
'XLA_DOWNCAST_FP16'): | ||
assert 'f32' in device_data_hlo, device_data_hlo | ||
else: | ||
assert 'f64' in device_data_hlo, device_data_hlo | ||
device_data_hlo = hlo_text.split('\n')[2] | ||
self.assertIn('xla::device_data', device_data_hlo) | ||
self.assertIn(expected_type, device_data_hlo) | ||
|
||
def test_datatype_use_bf16(self): | ||
self._set_env(XLA_USE_BF16='1') | ||
self._test_datatype(torch.double, 'bf16', torch.floor_divide) | ||
self._test_datatype(torch.float, 'bf16', torch.floor_divide) | ||
|
||
def test_datatype_downcast_bf16(self): | ||
self._set_env(XLA_DOWNCAST_BF16='1') | ||
self._test_datatype(torch.double, 'bf16', torch.floor_divide) | ||
self._test_datatype(torch.float, 'bf16', torch.floor_divide) | ||
|
||
def test_datatype_use_32bit_long(self): | ||
self._set_env(XLA_USE_32BIT_LONG='1') | ||
self._test_datatype(torch.int64, 's32', torch.add) | ||
self._test_datatype(torch.uint64, 'u32', torch.add) | ||
|
||
def test_module_to_dtype(self): | ||
device = torch_xla.device() | ||
linear = torch.nn.Linear( | ||
5, 10, dtype=torch.float32).to(device).to(torch.bfloat16) | ||
input = torch.randn( | ||
10, | ||
5, | ||
).to(device).to(torch.bfloat16) | ||
input = torch.randn(10, 5).to(device).to(torch.bfloat16) | ||
xm.mark_step() | ||
res = linear(input) | ||
|
||
hlo_text = torch_xla._XLAC._get_xla_tensors_text([res]) | ||
res_hlo = hlo_text.split('\n')[-3] | ||
assert 'bf16' in res_hlo, res_hlo | ||
self.assertIn('bf16', res_hlo) | ||
|
||
linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight | ||
]).split('\n')[-3] | ||
assert 'bf16' in linear_weight_hlo, linear_weight_hlo | ||
self.assertIn('bf16', linear_weight_hlo) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) | ||
suite = unittest.TestSuite() | ||
suite.addTest(XlaDataTypeTest("test_datatype_use_bf16")) | ||
suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16")) | ||
suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long")) | ||
suite.addTest(XlaDataTypeTest("test_module_to_dtype")) | ||
runner = unittest.TextTestRunner(failfast=True) | ||
result = runner.run(suite) | ||
sys.exit(0 if result.wasSuccessful() else 1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.