Enable addmm + GELU epilogue fusion via cuBLASLt
Summary:

Previously, addmm + GELU epilogue fusion was unconditionally disabled in `ATen/native/cuda/Blas.cpp` due to compilation and numerical issues in CUDA <= 11.4. This PR:

1. Enables addmm + GELU epilogue fusion for CUDA >= 11.8.

2. Restricts the fused addmm epilogue to contiguous output tensors (bugfix).

3. Extends the unit tests to cover the addmm epilogue fusion and GELU activation paths.
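For orientation, a minimal sketch (not part of this commit) of the op that exercises the fused path: `torch._addmm_activation` can hit the cuBLASLt epilogue on CUDA when the bias is a contiguous 1-D vector matching the output's column count, `beta == 1`, the output is contiguous, and the dtype is supported; otherwise it falls back to an unfused addmm followed by the activation. Shapes below are illustrative.

```python
# Sketch only: exercising torch._addmm_activation, which backs the fused epilogue path.
import torch

mat1 = torch.randn(10, 50, device="cuda")
mat2 = torch.randn(50, 25, device="cuda")
bias = torch.randn(25, device="cuda")  # 1-D bias matching mat2's column count

# beta=1 + vector bias + contiguous output => eligible for the cuBLASLt epilogue
out_relu = torch._addmm_activation(bias, mat1, mat2, beta=1)                 # ReLU epilogue
out_gelu = torch._addmm_activation(bias, mat1, mat2, beta=1, use_gelu=True)  # GELU epilogue

# Unfused reference; the fused GELU epilogue uses the tanh approximation,
# so compare with a small absolute tolerance.
ref = torch.nn.functional.gelu(torch.addmm(bias, mat1, mat2), approximate="tanh")
torch.testing.assert_close(out_gelu, ref, atol=1e-3, rtol=0)
```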

Test Plan:

$ python test/test_linalg.py -k test_addmm_relu -v

test_addmm_relu_cpu_bfloat16 (__main__.TestLinalgCPU.test_addmm_relu_cpu_bfloat16) ... ok
test_addmm_relu_cpu_float32 (__main__.TestLinalgCPU.test_addmm_relu_cpu_float32) ... ok
test_addmm_relu_cpu_float64 (__main__.TestLinalgCPU.test_addmm_relu_cpu_float64) ... ok
test_addmm_relu_cuda_bfloat16 (__main__.TestLinalgCUDA.test_addmm_relu_cuda_bfloat16) ... ok
test_addmm_relu_cuda_float32 (__main__.TestLinalgCUDA.test_addmm_relu_cuda_float32) ... ok
test_addmm_relu_cuda_float64 (__main__.TestLinalgCUDA.test_addmm_relu_cuda_float64) ... ok

$ python test/test_linalg.py -k test_addmm_gelu -v

test_addmm_gelu_cpu_bfloat16 (__main__.TestLinalgCPU.test_addmm_gelu_cpu_bfloat16) ... ok
test_addmm_gelu_cpu_float32 (__main__.TestLinalgCPU.test_addmm_gelu_cpu_float32) ... ok
test_addmm_gelu_cpu_float64 (__main__.TestLinalgCPU.test_addmm_gelu_cpu_float64) ... ok
test_addmm_gelu_cuda_bfloat16 (__main__.TestLinalgCUDA.test_addmm_gelu_cuda_bfloat16) ... ok
test_addmm_gelu_cuda_float32 (__main__.TestLinalgCUDA.test_addmm_gelu_cuda_float32) ... ok
test_addmm_gelu_cuda_float64 (__main__.TestLinalgCUDA.test_addmm_gelu_cuda_float64) ... ok

Reviewers: @eellison

[ghstack-poisoned]
aakhundov committed Jun 17, 2023
1 parent 2d745b9 commit 0cc8142
Showing 2 changed files with 50 additions and 9 deletions.
10 changes: 5 additions & 5 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -179,7 +179,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   if (!disable_addmm_cuda_lt) {
     useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
         result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
-        self.is_contiguous() &&
+        self.is_contiguous() && result.is_contiguous() &&
         (scalar_type == at::ScalarType::Double ||
          scalar_type == at::ScalarType::Float ||
          scalar_type == at::ScalarType::Half ||
@@ -287,13 +287,13 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
             self.const_data_ptr<scalar_t>(),
             result_->data_ptr<scalar_t>(),
             result_ld,
-#if 0
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
             activation_to_gemm_and_blas_arg(activation)
 #else
             // GELU is not supported (and does not compile!) prior
-            // to CUDA 11.4. Have observed accuracy issues with
+            // to CUDA 11.4. Have observed accuracy issues with
             // GELU epilogue in 11.4; disabling the GELU epilogue
-            // path until we confirm which version it's working in.
+            // path for CUDA version < 11.8.
             activation != Activation::GELU
                 ? activation_to_gemm_and_blas_arg(activation)
                 : cuda::blas::GEMMAndBiasActivationEpilogue::None
@@ -345,7 +345,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   // gating activation_to_gemm_and_blas_arg above; here we are manually
   // performing a post-GELU because we weren't able to use the GELU
   // epilogue above.
-#if !0
+#if !defined(CUDA_VERSION) || CUDA_VERSION < 11080
   if (useLtInterface && activation == Activation::GELU) {
     at::gelu_(const_cast<Tensor&>(*result_));
   }
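For orientation, a rough Python-level sketch (illustrative only, not code from this commit) of what the pre-11.8 fallback gated above amounts to: the cuBLASLt call runs without an activation epilogue and GELU is applied afterwards as a separate kernel (`at::gelu_` in the C++), using the exact erf-based GELU rather than the epilogue's tanh approximation.

```python
# Rough sketch of the pre-11.8 fallback behaviour; illustrative only.
import torch

def addmm_gelu_unfused(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    out = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)  # no activation epilogue
    return torch.nn.functional.gelu(out, approximate="none")     # separate, exact (erf) GELU

# With CUDA >= 11.8 the same computation is a single cuBLASLt call whose GELU
# epilogue uses the tanh approximation, hence the relaxed tolerances in the tests.
```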
49 changes: 45 additions & 4 deletions test/test_linalg.py
@@ -5403,20 +5403,44 @@ def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=
         else:
             alpha = 1.2 if alpha is None else alpha
             beta = 0.8 if beta is None else beta
-        res1 = f(t, m, v, alpha=alpha, beta=beta)
+        if activation == "gelu":
+            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
+        else:
+            res1 = f(t, m, v, alpha=alpha, beta=beta)
         res2 = torch.full_like(res1, math.nan)
         if transpose_out:
             res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
-        f(t, m, v, alpha=alpha, beta=beta, out=res2)
+        if activation == "gelu":
+            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
+        else:
+            f(t, m, v, alpha=alpha, beta=beta, out=res2)
         res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
+        res1_fused_epilogue = (t.is_cuda and t.dim() == 1 and beta == 1)
+        res2_fused_epilogue = res1_fused_epilogue and res2.is_contiguous()
         if beta != 0:
             res3 += (beta * t).to(numpy_dtype).cpu().numpy()
         if activation == "relu":
             res3 = res3 * (res3 > 0)
+        elif activation == "gelu":
+            res3_t = torch.from_numpy(res3).to(dtype)
+            approximate = "none"
+            if res1_fused_epilogue:
+                # fused GELU epilogue used in CUDA utilizes
+                # the tanh approximation to compute GELU
+                approximate = "tanh"
+            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
+            res3 = res3_t.to(numpy_dtype).cpu().numpy()
         else:
             assert activation is None, f"unsupported activation {activation}"
         res3 = torch.from_numpy(res3).to(dtype)
-        self.assertEqual(res1, res2)
+        if activation == "gelu" and res1_fused_epilogue and not res2_fused_epilogue:
+            # when out=res2 is transposed (not contiguous), the epilogue is unfused;
+            # in this case, when the activation is GELU and res1's epilogue is fused,
+            # the difference between res1 and res2 will be larger due to the tanh
+            # approximation of GELU in res1 computation, but not in res2
+            self.assertEqual(res1, res2, atol=1e-3, rtol=0)
+        else:
+            self.assertEqual(res1, res2)
         self.assertEqual(res1, res3)
 
     @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8,
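The relaxed `atol=1e-3` above covers the gap between the exact (erf) GELU and the tanh approximation used by the fused epilogue. A standalone sketch (not test code) of the size of that gap:

```python
# Sketch: exact vs. tanh-approximate GELU. The largest gap is a few 1e-4,
# which exceeds the tests' float32 precision (1e-4) but sits within atol=1e-3.
import torch

x = torch.linspace(-4, 4, steps=1001)
exact = torch.nn.functional.gelu(x, approximate="none")
tanh_approx = torch.nn.functional.gelu(x, approximate="tanh")
print(f"max |exact - tanh| = {(exact - tanh_approx).abs().max().item():.1e}")
```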
@@ -5494,6 +5518,10 @@ def _test_addmm_impl(self, func, activation, device, dtype):
         m2 = torch.randn(50, 25, device=device).to(dtype)
         self._test_addmm_addmv(func, M, m1, m2, activation=activation)
 
+        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
+        V = torch.randn(25, device=device).to(dtype)
+        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
+
         # Test 0-strided
         M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
         m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
@@ -5518,6 +5546,10 @@ def maybe_transpose(cond, m):
             m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
             self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
 
+            if t1:
+                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
+                self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
+
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfMPS(torch.float32)
@@ -5534,9 +5566,18 @@ def test_addmm(self, device, dtype):
                   *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
     @dtypes(*floating_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
-    def test_addmm_activation(self, device, dtype):
+    def test_addmm_relu(self, device, dtype):
         self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
 
+    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
+                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
+    @dtypesIfCUDA(*floating_types_and(
+                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
+    @dtypes(*floating_types_and(torch.bfloat16))
+    @tf32_on_and_off(0.05)
+    def test_addmm_gelu(self, device, dtype):
+        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
+
     @dtypes(torch.float, torch.double)
     @dtypesIfCUDA(*floating_and_complex_types())
     @tf32_on_and_off(0.005)
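As a footnote on the contiguity gate (a sketch, not part of the commit): the `transpose_out=True` recipe used in these tests produces an output buffer with transposed strides, which fails the new `result.is_contiguous()` check in `Blas.cpp` and is therefore routed to the unfused path.

```python
# Sketch: why transpose_out disables the fused epilogue.
import math
import torch

res = torch.full((10, 25), math.nan)
res_t = res.t().clone(memory_format=torch.contiguous_format).t()  # same shape, transposed strides

print(res.is_contiguous())    # True  -> eligible for the fused cuBLASLt epilogue (on CUDA)
print(res_t.is_contiguous())  # False -> routed to the unfused addmm + activation path
```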
