
Enable BFloat support for gemms on arch other than ampere #50442

Closed
zasdfgbnm wants to merge 20 commits

Conversation

zasdfgbnm
Collaborator

Fixes #{issue number}

@zasdfgbnm zasdfgbnm changed the title Enable BFloat support for gemms on arch other than ampere [WIP]Enable BFloat support for gemms on arch other than ampere Jan 12, 2021
@facebook-github-bot
Contributor

facebook-github-bot commented Jan 12, 2021

💊 CI failures summary and remediations

As of commit c33a608 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda11.1_test2 (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

AssertionError: "Simulate error" does not match "grad can be implicitly created only for scalar outputs"

Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 290, in instantiated_test
    result = test_fn(self, *args)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_device_type.py", line 687, in only_fn
    return fn(slf, device, *args, **kwargs)
  File "test_autograd.py", line 6652, in test_reentrant_parent_error_on_cpu
    self._test_reentrant_parent_error_on_cpu(device)
  File "test_autograd.py", line 6638, in _test_reentrant_parent_error_on_cpu
    torch.autograd.backward([t5.sum(), t7.sum()])
AssertionError: "Simulate error" does not match "grad can be implicitly created only for scalar outputs"

----------------------------------------------------------------------
Ran 2794 tests in 2826.455s

FAILED (failures=1, skipped=23, expected failures=1)

Generating XML reports...
Generated XML report: test-reports\python-unittest\TEST-TestAutograd-20210125173827.xml
Generated XML report: test-reports\python-unittest\TEST-TestAutogradComplex-20210125173827.xml
Generated XML report: test-reports\python-unittest\TEST-TestAutogradDeviceTypeCPU-20210125173827.xml

This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@zasdfgbnm zasdfgbnm changed the title [WIP]Enable BFloat support for gemms on arch other than ampere Enable BFloat support for gemms on arch other than ampere Jan 14, 2021
@zasdfgbnm
Collaborator Author

This should be ready. Test failures are unrelated.

@mrshenli added the module: bfloat16 and triaged labels Jan 15, 2021
  } else {
    TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
  }
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
Collaborator

setting and resetting cublas MathMode is not required if you specify CUBLAS_GEMM_DFALT_TENSOR_OP?

Collaborator Author

According to https://docs.nvidia.com/cuda/cublas/index.html#cublasmath_t

CUBLAS_DEFAULT_MATH: This is the default and highest-performance mode that uses compute and intermediate storage precisions with at least the same number of mantissa and exponent bits as requested. Tensor Cores will be used whenever possible.
CUBLAS_TENSOR_OP_MATH: This mode is deprecated and will be removed in a future release. Allows the library to use Tensor Core operations whenever possible. For single precision GEMM routines cuBLAS will use the CUBLAS_COMPUTE_32F_FAST_16F compute type.

test/test_linalg.py (outdated review thread, resolved)
b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2))
if not is_cuda_bfloat:
Collaborator

is_supported=False, is_cuda_bfloat=False is an impossible situation?

Collaborator Author

@zasdfgbnm Jan 19, 2021

Some ops are supported on SM52 and some are not. I don't think it's worth the maintenance effort to keep a precise list of which op is supported on which SM, so what I implemented here is:

SM >= 53 ---> supported
SM < 53 ---> undefined behavior
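
For illustration only, a test could gate CUDA bfloat16 GEMM cases on this convention roughly as follows (a minimal sketch assuming the SM >= 53 rule above; the helper name is hypothetical and is not the code added in this PR):

    import torch

    def cuda_bf16_gemm_expected_to_work(device):
        # Hypothetical helper: follow the "SM >= 53 -> supported,
        # SM < 53 -> undefined behavior" convention described above.
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability(device)
        return (major, minor) >= (5, 3)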

torch/testing/_internal/common_cuda.py (outdated review thread, resolved)
test/test_linalg.py (3 outdated review threads, resolved)
@codecov

codecov bot commented Jan 20, 2021

Codecov Report

Merging #50442 (79a68e3) into master (8e9ed27) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #50442      +/-   ##
==========================================
- Coverage   81.00%   81.00%   -0.01%     
==========================================
  Files        1916     1916              
  Lines      209481   209484       +3     
==========================================
+ Hits       169690   169692       +2     
- Misses      39791    39792       +1     

@zasdfgbnm
Collaborator Author

@ngimel @mruberry I think I have resolved all review comments, and all tests pass.

Collaborator

@mruberry left a comment

Cool! Thanks @zasdfgbnm!

Would you just rebase this? Sorry PyTorch is especially popular these days.

@zasdfgbnm
Collaborator Author

@mruberry rebased

Contributor

@facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry
Collaborator

Internal builds are failing with:

    from torch.testing._internal.common_cuda import _get_torch_cuda_version
  File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/fbcode/buck-out/dev/gen/caffe2/caffe2/fb/high_perf_models/pytorch/torchscript/test/test_ir_bench#binary,link-tree/torch/testing/_internal/common_cuda.py", line 18, in <module>
    CUDA11OrLater = torch.version.cuda and float(torch.version.cuda) >= 11
ValueError: could not convert string to float: '9.2.0'

We typically use LooseVersion for version comparisons. See

active_if=LooseVersion(scipy.__version__) < "1.4.0"),

for an example.
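
For reference, the failing comparison in common_cuda.py could be rewritten along these lines (a sketch only, mirroring the LooseVersion example quoted above; the actual fix in this PR may differ):

    from distutils.version import LooseVersion

    import torch

    # Parse the CUDA version as a version string rather than a float,
    # so values like '9.2.0' or '11.1' compare correctly.
    CUDA11OrLater = torch.version.cuda and LooseVersion(torch.version.cuda) >= "11.0"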

@zasdfgbnm
Collaborator Author

@mruberry fixed

Contributor

@facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in b822aba.

Labels
cla signed, Merged, module: bfloat16, open source, triaged