Reenable some BF16 tests on CUDA (#48805)
Summary:
Fixes #{issue number}

Pull Request resolved: #48805

Reviewed By: agolynski

Differential Revision: D25375885

Pulled By: ailzhang

fbshipit-source-id: 2e19fe725ae9450bd1a2bc4e2d308c59b9f94fac
zasdfgbnm authored and facebook-github-bot committed Dec 8, 2020
1 parent 7629612 commit e3893b8
Showing 2 changed files with 27 additions and 20 deletions.
test/test_tensor_creation_ops.py (3 changes: 1 addition & 2 deletions)
@@ -14,7 +14,7 @@
IS_WINDOWS)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
-onlyCPU, skipCUDAIfNotRocm, largeTensorTest, precisionOverride, dtypes,
+onlyCPU, largeTensorTest, precisionOverride, dtypes,
onlyCUDA, skipCPUIf, dtypesIfCUDA, dtypesIfCPU)

# TODO: refactor tri_tests_args, _compare_trilu_indices, run_additional_tri_tests
@@ -2581,7 +2581,6 @@ def test_arange_device_vs_cpu(self, device, dtype):
self.assertEqual(cpu_tensor, device_tensor)

@onlyCUDA
-@skipCUDAIfNotRocm
def test_arange_bfloat16(self, device):
ref_tensor = torch.tensor([0, 1, 2, 3], dtype=torch.bfloat16, device=device)
bfloat16_tensor = torch.arange(0, 4, dtype=torch.bfloat16, device=device)
test/test_torch.py (44 changes: 26 additions & 18 deletions)
@@ -6316,10 +6316,6 @@ def test_copy_broadcast(self, device) -> None:
torch.uint8
]

-# _types2 adds bfloat16 type to _types only on ROCm. Should eventually be unified
-# with _types when bfloat16 bringup is complete on all platforms.
-_types2 = _types + [torch.bfloat16] if TEST_WITH_ROCM else _types

_float_types = [torch.half, torch.float, torch.double]

_complex_types = [torch.cfloat, torch.cdouble]
@@ -6601,10 +6597,14 @@ def inner(self, device, dtype):
('dot', '', _medium_1d, lambda t, d: [_medium_1d(t, d)],
1e-2, 1e-5, 1e-5, _float_types + _complex_types, _cpu_types, False),
('element_size', '', _medium_1d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False),
-('eq', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types2),
-('eq', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types2),
-('ne', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types2),
-('ne', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5, _types2),
+('eq', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
+('eq', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
+('ne', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
+('ne', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
('equal', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)],
1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
('equal', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
@@ -6618,10 +6618,14 @@ def inner(self, device, dtype):
('lcm', '', _small_3d, lambda t, d: [_small_3d(t, d)], 0, 0, 0,
[torch.int16, torch.int32, torch.int64],
[torch.int16, torch.int32, torch.int64], True, [onlyOnCPUAndCUDA]),
-('ge', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2),
-('le', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2),
-('gt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2),
-('lt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5, _types2),
+('ge', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
+('le', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
+('gt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
+('lt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
('is_contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
# TODO: can't check negative case - cross-device copy is contiguous
('is_same_size', 'negative', _medium_2d, lambda t, d: [_small_3d(t, d)],
@@ -6705,12 +6709,16 @@ def inner(self, device, dtype):
torch.LongTensor([[1], [2]]).to(dtype=_convert_t(t, d), device=d),
True],
1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
-('prod', '', lambda t, d: _small_2d(t, d, oneish=True),
-lambda t, d: [], 1e-2, 1e-1, 1e-5, _types2, _cpu_types, False),
-('prod', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-1, 1e-5, _types2, _cpu_types, False),
-('prod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-1, 1e-5, _types2, _cpu_types, False),
-('sum', '', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _types2, _cpu_types, False),
-('sum', 'dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _types2, _cpu_types, False),
+('prod', '', lambda t, d: _small_2d(t, d, oneish=True), lambda t, d: [], 1e-2, 1e-1, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False),
+('prod', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-1, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False),
+('prod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-1, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False),
+('sum', '', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False),
+('sum', 'dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5,
+torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False),
('sum', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-5, 1e-5, _types, _cpu_types, False),
('sum', 'complex', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False),
('sum', 'complex_dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False),
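Note on the change: both files drop their ROCm-only gating of bfloat16. In test_tensor_creation_ops.py the skipCUDAIfNotRocm decorator is removed from test_arange_bfloat16, and in test_torch.py the _types2 list (which appended torch.bfloat16 only when TEST_WITH_ROCM was set) is replaced by torch.testing.get_all_dtypes(include_complex=False, include_bool=False), so the eq/ne/ge/le/gt/lt and prod/sum entries now exercise bfloat16 on CUDA as well as ROCm. The sketch below is not part of the commit; it only illustrates the before/after dtype-list construction, reconstructs _types from the visible tail of the hunk, and assumes the torch.testing.get_all_dtypes helper as used in the diff.

# Illustrative sketch only, not part of the commit. _types is reconstructed from
# the hunk above (only its tail is visible); TEST_WITH_ROCM comes from
# torch.testing._internal.common_utils as in the PyTorch test suite.
import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM

_types = [
    torch.half, torch.float, torch.double,
    torch.int8, torch.short, torch.int, torch.long,
    torch.uint8
]

# Old: bfloat16 was appended only on ROCm builds (this is the list removed by the commit).
_types2 = _types + [torch.bfloat16] if TEST_WITH_ROCM else _types

# New: the dtype list used for eq/ne/ge/le/gt/lt/prod/sum always contains bfloat16
# (complex and bool are excluded), independent of the platform.
new_dtypes = torch.testing.get_all_dtypes(include_complex=False, include_bool=False)

print(torch.bfloat16 in _types2)     # True only when TEST_WITH_ROCM is set
print(torch.bfloat16 in new_dtypes)  # True on any build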
