
[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1) #110890

Closed
wants to merge 30 commits

Conversation

@kadeng (Contributor) commented Oct 9, 2023

Stack from ghstack (oldest at bottom):

Summary:

This PR adds epilogue fusion code generation support for the new experimental
Inductor Cutlass backend (#108015).

Details:

A fusion happens at the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul kernel template
and adding on top a custom template functor based on Cutlass' new “Epilogue Visitor Trees” (EVT), which represents and
performs the computation of the fused pointwise / elementwise computation nodes.

This is the approach demonstrated in [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu),
which is currently the only documentation and example of Cutlass Epilogue Visitor Trees.

This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform.
A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments
to each of the functor subexpressions.
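To make these two codegen outputs easier to picture, here is a purely illustrative Python analogy; all names are hypothetical, and the real codegen emits nested C++ template expressions (the EVT functor type) plus a matching nested arguments initializer rather than Python objects:

```python
# Purely illustrative analogy, not actual Inductor or CUTLASS code: the real
# codegen emits a nested C++ template expression (the EVT functor type) plus a
# matching nested arguments initializer. All names here are hypothetical.
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple

@dataclass
class EpilogueNode:
    op: str                                              # e.g. "acc_fetch", "mul", "relu"
    args: Dict[str, Any] = field(default_factory=dict)   # per-node arguments
    children: Tuple["EpilogueNode", ...] = ()

# Functor expression: an AST for the fused epilogue relu(accumulator * 0.5)
functor = EpilogueNode(
    "relu",
    children=(EpilogueNode("mul", args={"scalar": 0.5},
                           children=(EpilogueNode("acc_fetch"),)),),
)

def init_expression(node: EpilogueNode):
    """Hierarchical initializer: one argument entry per functor subexpression,
    nested in the same shape as the functor expression itself."""
    return (node.args, tuple(init_expression(c) for c in node.children))
```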

Step 1 functionality:

  • End-to-end code generation is possible using the above approach.
  • Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants)
    after a matmul.
  • Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum, etc.
  • Examples / unit tests include ReLU and ReLU6 fusion (see the sketch after this list).
  • Support for fp16 and fp16 with fp32 accumulation data types.
  • Generates SM90 (Hopper) based CUDA kernels (as Cutlass up to 3.2.0 only supported EVT for SM90).
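As a concrete illustration of the kind of pattern Step 1 targets, the snippet below is a rough sketch (not taken verbatim from the unit tests; shapes and compile options are illustrative) of a matmul followed by a chain of elementwise ops with scalar constants, expressed as ReLU6 and compiled under max-autotune so that Inductor's GEMM templates can be autotuned:

```python
# Rough sketch of a fusable pattern: a matmul followed by a chain of
# elementwise ops with scalar constants (ReLU6 via clamp). Illustrative only.
import torch

def mm_relu6(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.clamp(a @ b, min=0.0, max=6.0)

a = torch.randn(256, 128, device="cuda", dtype=torch.float16)
b = torch.randn(128, 256, device="cuda", dtype=torch.float16)

# max-autotune is required so Inductor considers GEMM templates at all; the
# CUTLASS backend additionally has to be enabled explicitly (see the
# configuration sketch further below).
compiled = torch.compile(mm_relu6, mode="max-autotune")
out = compiled(a, b)
```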

The following is not yet supported and is left for future work:

  • Full operation support (e.g. the full set of ops usually handled via V.ops handlers)
  • Cutlass EVT with SM80 support (possible in Cutlass 3.2.1 according to the release notes, but not yet documented)
  • Support for additional (auxiliary) inputs, which changes the template kernels' call signature
  • Support for additional (auxiliary) outputs (requires support for full computation graphs)
  • Support for reduction operations and operations which use a different output layout than the input
  • Support for additional dtypes (as far as Cutlass allows)

This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features
for the inductor backend.

See also Cutlass release notes:
https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2

Notable changes in Cutlass 3.2.1 include:

  • The Cutlass codegen Python code has moved into a package under the "cutlass_library" namespace, which makes it possible to
    prevent namespace clashes without resorting to monkey-patching (which was done earlier); see the import sketch after this list.
  • Support for SM80 epilogue visitor trees (according to the release notes; not tried yet)
  • Small API changes to the cutlass_library API (requires adapting the inductor backend code)
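A minimal sketch of what the namespace change means for import code; the module names exist in the cutlass_library package, while the description of the old sys.path-based approach is an assumption about how the scripts were typically used before:

```python
# Before 3.2.1 the generator scripts lived under tools/library/scripts and
# were usually imported as top-level modules (e.g. `generator`,
# `gemm_operation`) after adding that directory to sys.path, which risked
# clashes with unrelated modules of the same name.
#
# From 3.2.1 on they form a proper package and can be imported under a
# dedicated namespace instead:
import cutlass_library.generator as cutlass_generator
import cutlass_library.gemm_operation as cutlass_gemm_op
```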

Notable changes in Cutlass 3.2.2 include:

  • A fix for a bug that caused CUDA illegal memory accesses in some PyTorch unit tests involving flash attention

Test Plan:

  • CI
  • pytest test/inductor/test_max_autotune.py

Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled.
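For reference, a rough sketch of how the backend could be enabled for experimentation. The `cuda.*` keys mirror the unit-test settings quoted later in this review thread; the `max_autotune_gemm_backends` option name and its value are assumptions and may differ from the actual configuration surface:

```python
import torch
import torch._inductor.config as inductor_config

# Sketch only: the cuda.* keys mirror the unit-test settings quoted later in
# this review; max_autotune_gemm_backends = "CUTLASS" is an assumption.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CUTLASS"            # assumed
inductor_config.cuda.cutlass_dir = "/path/to/third_party/cutlass"
inductor_config.cuda.cutlass_max_profiling_configs = 4
inductor_config.cuda.cutlass_only_evt_capable_ops = True
inductor_config.cuda.version = "12.2"  # required to enable the kernels needed here

compiled_mm_relu = torch.compile(lambda a, b: torch.relu(a @ b), mode="max-autotune")
```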

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @aakhundov @ColinPeppler

Differential Revision: D50988161

pytorch-bot bot commented Oct 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110890

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 601d762 with merge base dc1a358:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kadeng added a commit that referenced this pull request Oct 9, 2023
ghstack-source-id: 939bcc85f645b8369deac718f642693b338cb2d8
Pull Request resolved: #110890
@kadeng (Contributor, PR author) commented Oct 10, 2023

The test failures appear to be related to the Cutlass update to v3.2.1 rather than to the code changes within PyTorch itself. This test appears problematic:

```python
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                           head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                           scale: str):
    if isSM86or89Device and head_dim in range(193, 256 + 1):
        self.skipTest("Flash attention on sm86 and sm89 for headdim > 192 currently disabled")
    scale = scale if scale is None else (1 / head_dim)
    n_heads = 4
    query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                       device=device, dtype=dtype, requires_grad=True)
    key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                     dtype=dtype, requires_grad=True)
    value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                       device=device, dtype=dtype, requires_grad=True)
    # Run the math kernel on low precision references
    query_ref_lp, key_ref_lp, value_ref_lp = self.query_key_value_clones(query, key, value, dtype=dtype)
    higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
    query_ref, key_ref, value_ref = self.query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
    is_dropout = dropout_p > 0.0
    if not is_dropout:
        # Problem: We pad sizes in the composite region of the top level SDPA. But we need the
        # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
        with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
            out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
        with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
            # High Precision Math Reference
            out_ref = F.scaled_dot_product_attention(
                query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
            # Low Precision Math Reference
            out_lp_ref = F.scaled_dot_product_attention(
                query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
    else:
        q_padded, q_og_size = pad_last_dim(query, 8)
        k_padded, k_og_size = pad_last_dim(key, 8)
        v_padded, v_og_size = pad_last_dim(value, 8)
        # scale needs to be calculated on the og head_size
        if scale is None:
            scale = 1 / math.sqrt(q_og_size)
        output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
            q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
        out = output_tuple[0]
        out = out[..., :v_og_size]
        # Build dropout_mask
        dbug_mask = output_tuple[-1]
        query_padding_mask = torch.ones(
            batch_size, seq_len_q, device=device, dtype=torch.bool)
        key_padding_mask = torch.ones(
            batch_size, seq_len_k, device=device, dtype=torch.bool)
        softmax_mask = self.convert_flash_attn_S_to_softmax(
            dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim,
            causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
        dropout_mask = softmax_mask >= 0
        # High Precision Math Reference
        out_ref = torch.ops.aten._scaled_dot_product_attention_math(
            query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
        # Low Precision Math Reference
        out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
            query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
            dropout_mask=dropout_mask)[0]
    upstream_grad = torch.rand_like(out, requires_grad=False)
    # backward for flash attention on sm86 and sm89 for headdim > 64 currently disabled
    if isSM86or89Device and head_dim in range(193, 256):
        self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
        return
    out.backward(upstream_grad)
    out_ref.backward(upstream_grad.to(out_ref.dtype))
    out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
    # See [Note] Fused Tolerances above
    output_fudge_factor = 3 if head_dim % 8 != 0 else 1
    output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)
    # TODO: Investigate why grad_q needs larger tolerances
    query_fudge_factor = 4
    grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
    grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad)
    value_fudge_factor = 2
    grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
    self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
    self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                     atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
    self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                     atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
    self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                     atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
```
This test uses a native SDPA kernel, which in turn is likely to use Cutlass.

@drisspg It seems the test which fails with Cutlass 3.2.1 was added by you. Adding you to the list of reviewers so we can ensure the Cutlass upgrade does not lead to regressions.

I can reproduce a CUDA memory access violation using the following commandline:

pytest test/test_transformers.py -k test_flash_attention_vs_math_ref_grads_batch_size_1_seq_len_q_64_seq_len_k_64_head_dim_64_is_causal_False_dropout_p_0_0_bfloat16_scale_None_cuda_bfloat16

When I run this through CUDA's compute-sanitizer in a debug build (including CUDA lineinfo), I get the following trace (repeated many times with different thread IDs etc.):

========= Invalid __global__ atomic of size 4 bytes
=========     at 0x6c10 in void pytorch_flash::flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<pytorch_flash::Flash_bwd_kernel_traits<(int)64, (int)128, (int)128, (int)8, (int)4, (int)4, (int)4, (bool)0, (bool)0, cutlass::bfloat16_t, pytorch_flash::Flash_kernel_traits<(int)64, (int)128, (int)128, (int)8, cutlass::bfloat16_t>>, (bool)0, (bool)0, (bool)0, (bool)1>(pytorch_flash::Flash_bwd_params)
=========     by thread (0,0,0) in block (0,0,0)
=========     Address 0x7f38d32f8c00 is out of bounds
=========     and is 17,178,790,913 bytes after the nearest allocation at 0x7f34d3200000 of size 2,097,152 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x304fd2]
=========                in /lib64/libcuda.so.1
=========     Host Frame: [0x1490c]
=========                in /home/klondenberg/local/cuda/lib64/libcudart.so.12
=========     Host Frame:cudaLaunchKernel [0x6bb4b]
=========                in /home/klondenberg/local/cuda/lib64/libcudart.so.12
=========     Host Frame:/tmp/tmpxft_0031794f_00000000-6_flash_bwd_hdim64_bf16_sm80.cudafe1.stub.c:129:__device_stub__ZN13pytorch_flash44flash_bwd_dq_dk_dv_loop_seqk_parallel_kernelINS_23Flash_bwd_kernel_traitsILi64ELi128ELi128ELi8ELi4ELi4ELi4ELb0ELb0EN7cutlass10bfloat16_tENS_19Flash_kernel_traitsILi64ELi128ELi128ELi8ES3_EEEELb0ELb0ELb0ELb1EEEvNS_16Flash_bwd_paramsE(pytorch_flash::Flash_bwd_params&) [0x2870426]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h:68:pytorch_flash::run_flash_bwd_seqk_parallel<pytorch_flash::Flash_bwd_kernel_traits<64, 128, 128, 8, 4, 4, 4, false, false, cutlass::bfloat16_t, pytorch_flash::Flash_kernel_traits<64, 128, 128, 8, cutlass::bfloat16_t> >, false>(pytorch_flash::Flash_bwd_params&, CUstream_st*, bool)::{lambda()#1}::operator()() const::{lambda()#2}::operator()() const [0x2871c7d]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h:213:void pytorch_flash::run_mha_bwd_hdim64<cutlass::bfloat16_t>(pytorch_flash::Flash_bwd_params&, CUstream_st*, bool) [0x287299d]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp:756:pytorch_flash::mha_bwd(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor>&, c10::optional<at::Tensor>&, c10::optional<at::Tensor>&, float, float, bool, at::Tensor, at::Tensor) [0x2c02603]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/build/aten/src/ATen/core/TensorBody.h:92:at::native::_flash_attention_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x26a8b8d]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/build/aten/src/ATen/RegisterCUDA.cpp:41269:at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___flash_attention_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x2914ac0]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:468:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___flash_attention_backward>, std::tuple<at::Tensor, at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double> > >, std::tuple<at::Tensor, at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x2938011]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/build/aten/src/ATen/Operators_1.cpp:12722:at::_ops::_flash_attention_backward::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x1e3a985]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/aten/src/ATen/native/transformers/cuda/attention_backward.cu:532:at::native::_scaled_dot_product_flash_attention_backward_cuda(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x26a9331]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/build/aten/src/ATen/RegisterCUDA.cpp:41230:at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___scaled_dot_product_flash_attention_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x29147b0]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:584:c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___scaled_dot_product_flash_attention_backward>, std::tuple<at::Tensor, at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double> > >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x2ab226e]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cuda.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:350:torch::autograd::autogradNotImplementedFallbackImpl(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x410f049]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/aten/src/ATen/core/boxing/impl/boxing.h:231:c10::impl::BoxedKernelWrapper<std::tuple<at::Tensor, at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x1e5b39f]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/build/aten/src/ATen/Operators_4.cpp:11173:at::_ops::_scaled_dot_product_flash_attention_backward::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::SymInt, c10::SymInt, double, bool, at::Tensor const&, at::Tensor const&, c10::optional<double>) [0x2228c87]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/torch/csrc/autograd/generated/Functions.cpp:18060:torch::autograd::generated::ScaledDotProductFlashAttentionBackward0::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) [0x36a9bd9]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/torch/csrc/autograd/function.h:181:torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) [0x4128e5b]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/torch/csrc/autograd/engine.cpp:1013:torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) [0x4122c30]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/torch/csrc/autograd/engine.cpp:576:torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) [0x4123d34]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/torch/csrc/autograd/engine.cpp:380:torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) [0x411b68b]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_cpu.so
=========     Host Frame:/home/klondenberg/github/pytorch/pytorch/torch/csrc/autograd/python_engine.cpp:85:torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) [0x6b6a21]
=========                in /home/klondenberg/github/pytorch/pytorch/torch/lib/libtorch_python.so
=========     Host Frame:/opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/src/c++11/thread.cc:84:execute_native_thread_routine [0xdbbf4]
=========                in /data/users/klondenberg/miniconda3/envs/pytorch/bin/../lib/libstdc++.so.6
=========     Host Frame:start_thread [0x9f802]
=========                in /lib64/libc.so.6
=========     Host Frame:__GI___clone3 [0x3f450]
=========                in /lib64/libc.so.6

UPDATE:

I could narrow this down to Cutlass 3.2.1. If I check out v3.2.0 within third_party/cutlass, the memory access violation is gone. Sadly, I could not get the exact line number where the memory access violation happens, even when building with full CUDA debug info. Neither compute-sanitizer nor cuda-gdb provides it in their stack traces.

@kadeng kadeng marked this pull request as ready for review October 10, 2023 09:25
@drisspg (Contributor) commented Oct 10, 2023

I have actually already started the 3.2.1 upgrade here: #108070

And I agree, I do indeed see the memory issue. There are also some internal Windows builds failing with the upgrade. I haven't had a chance to dive deeper into these issues yet.

I am going to see if building the latest flash_attn_v3 with cutlass 3.2.1 also produces the IMA; if so, I think it will be a harder upgrade.

@kadeng (Contributor, PR author) commented Oct 10, 2023

> I have actually already started the 3.2.1 upgrade here: #108070
>
> And agree I do indeed see the memory issue. As well there are also some internal window builds failing with the upgrade. I haven't had a chance to dive deeper into this issues yet.
>
> I am going to see if building the latest flash_attn_v3 with cutlass 3.2.1 also produces the oom if so I think it will be a harder upgrade

In that case I think the review of this PR should proceed, but it can't be merged until this is solved somehow. For this PR, we would need at least Cutlass 3.2.0, but I would have to make changes depending on the actual version being used.

@ipiszy (Contributor) left a comment

Thanks @kadeng !
Only finished reviewing 1/3 of the PR, will review the rest tomorrow.

I'd suggest moving the CUTLASS version upgrade to a separate PR, since it's not very relevant to the CUTLASS epilogue fusion logic implemented here.

# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
# to support EVT similar to
# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L315C69-L315C69 # noqa: B950
class EmitGemmUniversal3xInstanceWithEVT:
Contributor:

Can we upstream this change and extract the common part of EmitGemmUniversal3xInstanceWithEVT and EmitGemmUniversal3xInstance? That way, if there are future changes in CUTLASS, the EVT part will also be covered.

Contributor (author):

I can try to get a PR landed in cutlass, yes. Until then I would leave it here to unblock.

Contributor:

I agree we should upstream this. Otherwise, it would be hard to maintain consistency with the CUTLASS codebase, especially at the PyTorch scale.

torch/_inductor/codegen/cuda/cutlass_utils.py (resolved, outdated)
torch/_inductor/codegen/cuda/gemm_template.py (resolved, outdated)
```diff
@@ -133,14 +133,31 @@ def tuned_mm(mat1, mat2, *, layout=None):
     )

     if m * n != 0 and use_cutlass_template(layout):
-        cutlass_template = CUTLASSGemmTemplate([mat1, mat2], layout, alpha=1, beta=0)
+        cutlass_template = CUTLASSGemmTemplate(
```
Contributor:

I think we probably only need to keep templates which support EVT.
Reasons are:

  1. According to Vijay, for most cases we only need to keep TMA mainloop configs which support EVT (https://fburl.com/gdoc/575xeqlt).
  2. For templates which do not support EVT, it's likely that CUBLAS would provide similar perf (since epilogue fusion cannot be done).

@kadeng (author), Oct 11, 2023:

That doesn't leave many op templates. Most are discarded. It's a configurable option already which I use in the unit tests.

@aakhundov (Contributor), Oct 11, 2023:

In my experience, quantity doesn't matter much with CUTLASS 3.x kernels. There are redundancies, and many options have bad perf.

Contributor:

Agree with Adnan, there are way too many configs profiled now and compilation time is too long. Using only TMA configs was suggested by Vijay (please see the link in my comment above). We could confirm with him in the next CUTLASS meeting.

Contributor (author):

If I make this change, I think many preexisting tests in test_max_autotune.py will fail due to a lack of op candidates. The EVT-capable ops support only fp16 output at the moment.

Contributor:

Do you mean EVT doesn't support fp32 accumulation, or fp32 data types? Let's confirm with CUTLASS folks.

@kadeng (author), Oct 31, 2023:

As discussed, fp32 accumulation is supported, but no EVT-capable kernels which use fp32 as the main data type are being generated / enumerated by the cutlass_library functions. That's likely not a fundamental issue, though.

torch/_inductor/codegen/cuda/cuda_template.py (resolved, outdated)
torch/_inductor/codegen/cuda/gemm_template.py (resolved)
torch/_inductor/codegen/cuda/gemm_template.py (resolved)
torch/_inductor/codegen/cuda/gemm_template.py (resolved, outdated)
torch/_inductor/ir.py (resolved, outdated)
torch/_inductor/codegen/cuda/cuda_kernel.py (resolved, outdated)
@kadeng (Contributor, PR author) commented Oct 11, 2023

> Thanks @kadeng ! Only finished reviewing 1/3 of the PR, will review the rest tomorrow.
>
> I'd suggest moving CUTLASS version upgrade to a separate PR, since it's not very relevant to CUTLASS epilogue fusion logics implemented here.

Yes, we would go with #108070 from drisspg. It's not clear yet whether we will be able to upgrade to Cutlass 3.2.0 or 3.2.1, but these are not the same from the API perspective of this PR. So, until that is clear, I will leave the change to third_party/cutlass in here, but then remove it prior to any merge.

@aakhundov (Contributor) left a comment

Thank you @kadeng! Left a few comments, mostly nits and questions.

test/inductor/test_max_autotune.py (resolved, outdated)
test/inductor/test_max_autotune.py (resolved)
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 4,
"cuda.cutlass_only_evt_capable_ops": True,
"cuda.version": "12.2", # required to enable the Kernels we need
Contributor:

But the underlying CUDA in the CI is 12.1? Does this work fine?

Contributor (author):

Yes, it's necessary here, otherwise the Kernels we need for testing won't be listed at all. It does not mean that these Kernels are performant when CUDA < 12.2, but at least they seem to exist, and tests are working.

Contributor:

Yeah, I understand that the value is passed to the CUTLASS generator, which won't generate the EVT kernels below 12.2 (although I'd have expected 12.1 to be sufficient). I was just trying to make sure that the kernels are functioning. Perf doesn't matter at this point. Thanks for confirming.

Contributor:

@kadeng I think 12.1 should be enough. I just checked https://github.com/NVIDIA/cutlass/blob/main/python/cutlass_library/generator.py and only found rules for 12.1, not 12.2.

test/inductor/test_max_autotune.py (resolved, outdated)
test/inductor/test_max_autotune.py (resolved)
torch/_inductor/kernel/mm.py (resolved, outdated)
torch/_inductor/scheduler.py (resolved, outdated)
torch/_inductor/codegen/cuda/gemm_template.py (resolved, outdated)
torch/_inductor/codegen/cuda/gemm_template.py (resolved, outdated)

kadeng added a commit that referenced this pull request Oct 11, 2023
ghstack-source-id: e05c2eaba8305d90b82750aa7fcabfc67c9fb54f
Pull Request resolved: #110890
@kadeng (Contributor, PR author) commented Oct 11, 2023

For reference: The Cutlass 3.2.1 CUDA IMA is now tracked in Cutlass issues here, reported by driss: NVIDIA/cutlass#1138

Update:

The CUDA IMA is supposed to be resolved in Cutlass 3.2.2.

@ipiszy (Contributor) left a comment

Thanks @kadeng !


torch/_inductor/codegen/cuda/gemm_template.py (resolved)
torch/_inductor/codegen/cuda/gemm_template.py (resolved, outdated)
torch/_inductor/codegen/cuda/gemm_template.py (resolved, outdated)
torch/_inductor/codegen/cuda/gemm_template.py (resolved, outdated)
torch/_inductor/codegen/cuda/cuda_scheduling.py (resolved, outdated)
torch/_inductor/codegen/cuda/cuda_kernel.py (resolved, outdated)
torch/_inductor/codegen/cuda/cuda_kernel.py (resolved, outdated)
torch/_inductor/codegen/cuda/cuda_template.py (resolved, outdated)
kadeng added a commit that referenced this pull request Oct 12, 2023
ghstack-source-id: 505be19b9dd0ac21fe2c8835ee48d02cbf0705a6
Pull Request resolved: #110890
kadeng added a commit that referenced this pull request Oct 12, 2023
ghstack-source-id: cce7d5ec921ee759efaafe982314c71154d6727a
Pull Request resolved: #110890
kadeng added a commit that referenced this pull request Oct 17, 2023
kadeng added a commit that referenced this pull request Oct 17, 2023
kadeng added a commit that referenced this pull request Nov 3, 2023
kadeng added a commit that referenced this pull request Nov 3, 2023
kadeng added a commit that referenced this pull request Nov 3, 2023
@kadeng
Contributor Author

kadeng commented Nov 3, 2023

@kadeng has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment

@kadeng kadeng requested a review from jansel November 4, 2023 08:32
@kadeng
Contributor Author

kadeng commented Nov 4, 2023

@jansel All requested changes are done. The upgrade to Cutlass 3.2.2 was factored out into a separate PR, which is also merged by now. Would be great if you could take another look.

@kadeng
Contributor Author

kadeng commented Nov 6, 2023

Thanks, @jansel @ipiszy @aakhundov @drisspg !

@kadeng
Contributor Author

kadeng commented Nov 6, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 6, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…110890)

Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled.

Differential Revision: [D50988161](https://our.internmc.facebook.com/intern/diff/D50988161)
Pull Request resolved: pytorch#110890
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#112762
@facebook-github-bot facebook-github-bot deleted the gh/kadeng/3/head branch November 10, 2023 15:24
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023