Enable addmm + GELU epilogue fusion via cuBLASLt #103811
Changes from all commits: 0cc8142, d8f39a9, 79df8aa, f190760
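For context on the op under test: `torch._addmm_activation(input, mat1, mat2, beta=1, alpha=1, use_gelu=False)` computes roughly `beta * input + alpha * (mat1 @ mat2)` followed by ReLU, or by GELU when `use_gelu=True`; this PR teaches the CUDA path to fuse that GELU into the cuBLASLt matmul epilogue. A minimal usage sketch (shapes are illustrative, not from the diff; on CPU the op falls back to an unfused addmm plus GELU):

```python
import torch
import torch.nn.functional as F

m1 = torch.randn(10, 50)
m2 = torch.randn(50, 25)
bias = torch.randn(25)  # 1-D bias: the shape the cuBLASLt epilogue can fuse on CUDA

# Fused-capable op: beta * bias + (m1 @ m2), then GELU
out = torch._addmm_activation(bias, m1, m2, beta=1, use_gelu=True)

# Unfused reference
ref = F.gelu(torch.addmm(bias, m1, m2))
print(torch.allclose(out, ref, atol=1e-6))  # expected True on CPU (no fused epilogue)
```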
Review comment: In the future, we should move these tests to the …
```diff
@@ -5403,20 +5403,47 @@ def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=
         else:
             alpha = 1.2 if alpha is None else alpha
             beta = 0.8 if beta is None else beta
-        res1 = f(t, m, v, alpha=alpha, beta=beta)
+        if activation == "gelu":
+            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
+        else:
+            res1 = f(t, m, v, alpha=alpha, beta=beta)
         res2 = torch.full_like(res1, math.nan)
         if transpose_out:
             res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
-        f(t, m, v, alpha=alpha, beta=beta, out=res2)
+        if activation == "gelu":
+            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
+        else:
+            f(t, m, v, alpha=alpha, beta=beta, out=res2)
         res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
+        res1_fused_epilogue = (t.is_cuda and t.dim() == 1 and beta == 1)
+        if TEST_WITH_ROCM or _get_torch_cuda_version() < (11, 8):
+            # epilogue fusion enabled only on CUDA >= 11.8
+            res1_fused_epilogue = False
+        res2_fused_epilogue = res1_fused_epilogue and res2.is_contiguous()
         if beta != 0:
             res3 += (beta * t).to(numpy_dtype).cpu().numpy()
         if activation == "relu":
             res3 = res3 * (res3 > 0)
+        elif activation == "gelu":
+            res3_t = torch.from_numpy(res3).to(dtype)
+            approximate = "none"
+            if res1_fused_epilogue:
+                # fused GELU epilogue used in CUDA utilizes
+                # the tanh approximation to compute GELU
+                approximate = "tanh"
+            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
+            res3 = res3_t.to(numpy_dtype).cpu().numpy()
         else:
             assert activation is None, f"unsupported activation {activation}"
         res3 = torch.from_numpy(res3).to(dtype)
-        self.assertEqual(res1, res2)
+        if activation == "gelu" and res1_fused_epilogue and not res2_fused_epilogue:
+            # when out=res2 is transposed (not contiguous), the epilogue is unfused;
+            # in this case, when the activation is GELU and res1's epilogue is fused,
+            # the difference between res1 and res2 will be larger due to the tanh
+            # approximation of GELU in res1 computation, but not in res2
+            self.assertEqual(res1, res2, atol=1e-3, rtol=0)
+        else:
+            self.assertEqual(res1, res2)
         self.assertEqual(res1, res3)
 
     @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8,
```
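The looser `atol=1e-3` in the GELU branch exists because the fused cuBLASLt epilogue evaluates the tanh approximation of GELU, while the unfused path uses the exact erf-based form. The size of that gap can be eyeballed with the public `approximate=` argument of `torch.nn.functional.gelu` (a standalone sketch, not part of the diff):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, steps=100_001)
exact = F.gelu(x, approximate="none")   # erf-based GELU
approx = F.gelu(x, approximate="tanh")  # tanh approximation, as in the fused epilogue
print((exact - approx).abs().max())     # small but nonzero, well under the 1e-3 atol above
```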
```diff
@@ -5494,6 +5521,10 @@ def _test_addmm_impl(self, func, activation, device, dtype):
         m2 = torch.randn(50, 25, device=device).to(dtype)
         self._test_addmm_addmv(func, M, m1, m2, activation=activation)
 
+        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
+        V = torch.randn(25, device=device).to(dtype)
+        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
+
         # Test 0-strided
         M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
         m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
```
```diff
@@ -5518,6 +5549,10 @@ def maybe_transpose(cond, m):
             m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
             self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
 
+            if t1:
+                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
+                self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
+
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfMPS(torch.float32)
```
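Mirroring the predicate added in `_test_addmm_addmv` (`res1_fused_epilogue = t.is_cuda and t.dim() == 1 and beta == 1`), only the vector-bias, `beta=1` call can hit the fused epilogue; a matrix bias always takes the unfused path. A sketch contrasting the two call shapes (the 10x50 @ 50x25 sizes mirror the test; the `device` fallback is added here for portability):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
m1 = torch.randn(10, 50, device=device)
m2 = torch.randn(50, 25, device=device)

M = torch.randn(10, 25, device=device)  # matrix bias: never uses the fused epilogue
V = torch.randn(25, device=device)      # vector bias: fusion candidate when beta == 1

out_unfused = torch._addmm_activation(M, m1, m2, beta=0.8, alpha=1.2, use_gelu=True)
out_fusable = torch._addmm_activation(V, m1, m2, beta=1, use_gelu=True)
```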
```diff
@@ -5534,9 +5569,18 @@ def test_addmm(self, device, dtype):
                   *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
     @dtypes(*floating_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
-    def test_addmm_activation(self, device, dtype):
+    def test_addmm_relu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
 
+    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
+                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+    @dtypesIfCUDA(*floating_types_and(
+        *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
+    @dtypes(*floating_types_and(torch.bfloat16))
+    @tf32_on_and_off(0.05)
+    def test_addmm_gelu(self, device, dtype):
+        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
+
     @dtypes(torch.float, torch.double)
     @dtypesIfCUDA(*floating_and_complex_types())
     @tf32_on_and_off(0.005)
```

Review comment (on the new precisionOverride bounds): I'm surprised the torch.half bounds are so high here, but this is a prior issue, no need to fix on your PR.

Reply: My understanding is that these tests are not run on …
Review comment (on the CUDA version gate): For anyone looking at this in the future: we're using 11.8 here because that's what our CI has test coverage for.
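A rough, hypothetical reconstruction of that gate using public APIs, assuming `torch.version.cuda` carries the usual "11.8"-style version string (it is `None` on CPU-only and ROCm builds):

```python
import torch

def cuda_version_tuple():
    # torch.version.cuda is e.g. "11.8" on CUDA builds, None otherwise
    if torch.version.cuda is None:
        return None
    return tuple(int(part) for part in torch.version.cuda.split("."))

version = cuda_version_tuple()
fusion_expected = version is not None and version >= (11, 8)
print(fusion_expected)
```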