Skip to content

Commit

Permalink
Fix test_jit_cuda_archflags on machine with more than one arch (#50405)
Browse files Browse the repository at this point in the history
Summary:
This fixes the following flaky test on machines with GPUs of different architectures:
```
_________________________________________________________________________________________________________________ TestCppExtensionJIT.test_jit_cuda_archflags __________________________________________________________________________________________________________________

self = <test_cpp_extensions_jit.TestCppExtensionJIT testMethod=test_jit_cuda_archflags>

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(TEST_ROCM, "disabled on rocm")
    def test_jit_cuda_archflags(self):
        # Test a number of combinations:
        #   - the default for the machine we're testing on
        #   - Separators, can be ';' (most common) or ' '
        #   - Architecture names
        #   - With/without '+PTX'

        capability = torch.cuda.get_device_capability()
        # expected values is length-2 tuple: (list of ELF, list of PTX)
        # note: there should not be more than one PTX value
        archflags = {
            '': (['{}{}'.format(capability[0], capability[1])], None),
            "Maxwell+Tegra;6.1": (['53', '61'], None),
            "Pascal 3.5": (['35', '60', '61'], None),
            "Volta": (['70'], ['70']),
        }
        if int(torch.version.cuda.split('.')[0]) >= 10:
            # CUDA 9 only supports compute capability <= 7.2
            archflags["7.5+PTX"] = (['75'], ['75'])
            archflags["5.0;6.0+PTX;7.0;7.5"] = (['50', '60', '70', '75'], ['60'])

        for flags, expected in archflags.items():
>           self._run_jit_cuda_archflags(flags, expected)

test_cpp_extensions_jit.py:198:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_cpp_extensions_jit.py:158: in _run_jit_cuda_archflags
    _check_cuobjdump_output(expected[0])
test_cpp_extensions_jit.py:134: in _check_cuobjdump_output
    self.assertEqual(actual_arches, expected_arches,
../../.local/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py:1211: in assertEqual
    super().assertEqual(len(x), len(y), msg=self._get_assert_msg(msg, debug_msg=debug_msg))
E   AssertionError: 2 != 1 : Attempted to compare the lengths of [iterable] types: Expected: 2; Actual: 1.
E   Flags: ,  Actual: ['sm_75', 'sm_86'],  Expected: ['sm_86']
E   Stderr:
E   Output: ELF file    1: cudaext_archflags.1.sm_75.cubin
E   ELF file    2: cudaext_archflags.2.sm_86.cubin

```

Pull Request resolved: #50405

Reviewed By: albanD

Differential Revision: D25920200

Pulled By: mrshenli

fbshipit-source-id: 1042a984142108f954a283407334d39e3ec328ce
  • Loading branch information
zasdfgbnm authored and facebook-github-bot committed Jan 26, 2021
1 parent 5f297cc commit 5834b3b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/test_cpp_extensions_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _check_cuobjdump_output(expected_values, is_ptx=False):
err, output))

actual_arches = sorted(re.findall(r'sm_\d\d', output))
expected_arches = ['sm_' + xx for xx in expected_values]
expected_arches = sorted(['sm_' + xx for xx in expected_values])
self.assertEqual(actual_arches, expected_arches,
msg="Flags: {}, Actual: {}, Expected: {}\n"
"Stderr: {}\nOutput: {}".format(
Expand Down Expand Up @@ -180,11 +180,12 @@ def test_jit_cuda_archflags(self):
# - Architecture names
# - With/without '+PTX'

capability = torch.cuda.get_device_capability()
n = torch.cuda.device_count()
capabilities = {torch.cuda.get_device_capability(i) for i in range(n)}
# expected values is length-2 tuple: (list of ELF, list of PTX)
# note: there should not be more than one PTX value
archflags = {
'': (['{}{}'.format(capability[0], capability[1])], None),
'': (['{}{}'.format(capability[0], capability[1]) for capability in capabilities], None),
"Maxwell+Tegra;6.1": (['53', '61'], None),
"Pascal 3.5": (['35', '60', '61'], None),
"Volta": (['70'], ['70']),
Expand Down

0 comments on commit 5834b3b

Please sign in to comment.