Enable addmm + GELU epilogue fusion via cuBLASLt #103811

Closed · wants to merge 4 commits
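For context (not part of the PR page): a minimal Python sketch of how the fused addmm + GELU path can be exercised, assuming a CUDA >= 11.8 build. The shapes are illustrative; torch._addmm_activation and its beta/alpha/use_gelu arguments are the same private op the tests below call.

import torch

# Illustrative shapes: (10, 50) @ (50, 25) plus a length-25 bias vector.
mat1 = torch.randn(10, 50, device="cuda")
mat2 = torch.randn(50, 25, device="cuda")
bias = torch.randn(25, device="cuda")

# A 1-D bias with beta == 1 (and a contiguous output) is what routes addmm
# through the cuBLASLt interface; use_gelu=True requests the GELU epilogue,
# which is fused only on CUDA >= 11.8 after this PR.
out = torch._addmm_activation(bias, mat1, mat2, beta=1, alpha=1, use_gelu=True)

# The fused epilogue computes GELU with the tanh approximation, so compare
# against a tanh-approximate reference rather than the exact (erf) GELU.
ref = torch.nn.functional.gelu(torch.addmm(bias, mat1, mat2), approximate="tanh")
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)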
10 changes: 5 additions & 5 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -179,7 +179,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
if (!disable_addmm_cuda_lt) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() &&
self.is_contiguous() && result.is_contiguous() &&
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
@@ -287,13 +287,13 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
self.const_data_ptr<scalar_t>(),
result_->data_ptr<scalar_t>(),
result_ld,
#if 0
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
Contributor: for anyone looking at this in the future: we're using 11.8 here because that's what our CI has test coverage for
activation_to_gemm_and_blas_arg(activation)
#else
// GELU is not supported (and does not compile!) prior
// to CUDA 11.4. Have observed accuracy issues with
// GELU epilogue in 11.4; disabling the GELU epilogue
// path until we confirm which version it's working in.
// path for CUDA version < 11.8.
activation != Activation::GELU
? activation_to_gemm_and_blas_arg(activation)
: cuda::blas::GEMMAndBiasActivationEpilogue::None
@@ -345,7 +345,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !0
#if !defined(CUDA_VERSION) || CUDA_VERSION < 11080
if (useLtInterface && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*result_));
}
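The manual at::gelu_ fallback above applies PyTorch's exact (erf-based) GELU, while the fused cuBLASLt epilogue uses the tanh approximation; that difference is why the test changes below loosen their tolerance. A small sketch, not part of the PR, showing the size of that gap:

import torch
import torch.nn.functional as F

x = torch.randn(4096)
exact = F.gelu(x)                       # erf-based GELU (unfused fallback path)
approx = F.gelu(x, approximate="tanh")  # tanh GELU (fused epilogue)
# The gap is small but nonzero, which is why the test compares the fused
# and unfused results with atol=1e-3 rather than exact equality.
print((exact - approx).abs().max())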
52 changes: 48 additions & 4 deletions test/test_linalg.py
Collaborator: In the future, we should move these tests to the test_matmul_cuda.py file. Other than that, changes look good!
@@ -5403,20 +5403,47 @@ def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=
else:
alpha = 1.2 if alpha is None else alpha
beta = 0.8 if beta is None else beta
res1 = f(t, m, v, alpha=alpha, beta=beta)
if activation == "gelu":
res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
else:
res1 = f(t, m, v, alpha=alpha, beta=beta)
res2 = torch.full_like(res1, math.nan)
if transpose_out:
res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
f(t, m, v, alpha=alpha, beta=beta, out=res2)
if activation == "gelu":
f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
else:
f(t, m, v, alpha=alpha, beta=beta, out=res2)
res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
res1_fused_epilogue = (t.is_cuda and t.dim() == 1 and beta == 1)
if TEST_WITH_ROCM or _get_torch_cuda_version() < (11, 8):
# epilogue fusion enabled only on CUDA >= 11.8
res1_fused_epilogue = False
res2_fused_epilogue = res1_fused_epilogue and res2.is_contiguous()
if beta != 0:
res3 += (beta * t).to(numpy_dtype).cpu().numpy()
if activation == "relu":
res3 = res3 * (res3 > 0)
elif activation == "gelu":
res3_t = torch.from_numpy(res3).to(dtype)
approximate = "none"
if res1_fused_epilogue:
# fused GELU epilogue used in CUDA utilizes
# the tanh approximation to compute GELU
approximate = "tanh"
res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
res3 = res3_t.to(numpy_dtype).cpu().numpy()
else:
assert activation is None, f"unsupported activation {activation}"
res3 = torch.from_numpy(res3).to(dtype)
self.assertEqual(res1, res2)
if activation == "gelu" and res1_fused_epilogue and not res2_fused_epilogue:
# when out=res2 is transposed (not contiguous), the epilogue is unfused;
# in this case, when the activation is GELU and res1's epilogue is fused,
# the difference between res1 and res2 will be larger due to the tanh
# approximation of GELU in res1 computation, but not in res2
self.assertEqual(res1, res2, atol=1e-3, rtol=0)
else:
self.assertEqual(res1, res2)
self.assertEqual(res1, res3)
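To make the res2_fused_epilogue distinction above concrete, here is a hypothetical sketch (not part of the PR, assuming a CUDA >= 11.8 build): a transposed out tensor is not contiguous, so with the new result.is_contiguous() check it stays on the unfused path and gets exact GELU, while the contiguous call gets the tanh-approximate fused epilogue.

import torch

mat1 = torch.randn(10, 50, device="cuda")
mat2 = torch.randn(50, 25, device="cuda")
bias = torch.randn(25, device="cuda")

# Non-contiguous out: skips the cuBLASLt path; GELU is applied as a separate exact op.
out_t = torch.empty(25, 10, device="cuda").t()   # shape (10, 25), not contiguous
torch._addmm_activation(bias, mat1, mat2, beta=1, use_gelu=True, out=out_t)

# Contiguous out: eligible for the fused tanh-approximate GELU epilogue.
out_c = torch._addmm_activation(bias, mat1, mat2, beta=1, use_gelu=True)

# The two paths agree only up to the approximation error, matching the
# relaxed assertEqual(res1, res2, atol=1e-3, rtol=0) above.
torch.testing.assert_close(out_t, out_c, rtol=0, atol=1e-3)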

@precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8,
@@ -5494,6 +5521,10 @@ def _test_addmm_impl(self, func, activation, device, dtype):
m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(func, M, m1, m2, activation=activation)

# vector-shaped bias and beta=1 result in epilogue fusion in CUDA
V = torch.randn(25, device=device).to(dtype)
self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)

# Test 0-strided
M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
@@ -5518,6 +5549,10 @@ def maybe_transpose(cond, m):
m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)

if t1:
# use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)

@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfMPS(torch.float32)
@@ -5534,9 +5569,18 @@ def test_addmm(self, device, dtype):
*[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
def test_addmm_activation(self, device, dtype):
def test_addmm_relu(self, device, dtype):
self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
Contributor: I'm surprised the torch.half bounds are so high here, but this is a prior issue, no need to fix on your PR.

Contributor Author: My understanding is that these tests are not run on torch.half (only on float32, float64, and bfloat16 on CUDA), so maybe it's a copy/paste artifact.
@dtypesIfCUDA(*floating_types_and(
*[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
def test_addmm_gelu(self, device, dtype):
self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

@dtypes(torch.float, torch.double)
@dtypesIfCUDA(*floating_and_complex_types())
@tf32_on_and_off(0.005)