[Inductor] addmm + activation function fusion #158137
Conversation
@pytorchbot label "topic: not user facing"
needs some benchmarks comparing against existing Triton fusions
Thanks for the PR. A couple of comments.
def addmm_gelu_pattern(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(output, input)
    return aten.gelu(output)
Is it also worth adding a pattern that targets addmm, instead of mm? I imagine most of the addmms will not be decomposed.
The addmms that have an activation after them will be decomposed, which is what we are interested in pattern matching.
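For reference, a pattern that targeted addmm directly, which is what the question above asks about, might look roughly like the sketch below. It is hypothetical and not part of this PR, since the addmms of interest are already decomposed into add + mm by the time this pass runs.

```python
import torch

aten = torch.ops.aten

# Hypothetical addmm-targeting variant (not in the PR); mirrors the mm + add patterns above.
def addmm_gelu_pattern_direct(input, mat1, mat2):
    output = aten.addmm(input, mat1, mat2)
    return aten.gelu(output, approximate="tanh")
```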
@AaronWang04 re-request when ready..
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased and force-pushed from bbb1083 to 0f9acc5.
Did you have a chance to run any dashboards? Just curious. Anyway, looks good, but it would be nice if you could use the gen_register_replacement API to avoid compile-time overhead. It's more required for training patterns but still nice either way.
args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]

for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
    register_replacement(
Now that we are parameterizing this across 4 total patterns, could I trouble you to pre-register the pattern? See gen_register_replacement:
## Precompiled Patterns
New patterns are added using register_replacement(). Patterns added in this way
can have a compile-time overhead because they need to be traced before
use. Patterns can be precompiled and added using gen_register_replacement()
instead. To do this you call gen_register_replacement() instead of
register_replacement(). The arguments are the same except for an additional
unique name which is used as a lookup key.
And https://github.com/pytorch/pytorch/blob/main/torchgen/fuse/gen_patterns.py.
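A minimal sketch of the suggested switch, assuming only what the docstring above states (gen_register_replacement takes the same arguments as register_replacement plus a leading unique name). The replacement-function names below are illustrative placeholders, not the PR's actual identifiers; args_fp32, fwd_only, and pass_patterns[2] are taken from the snippets and discussion in this thread.

```python
# Illustrative only: pre-register each traced pattern under a unique lookup key.
for pattern, replacement in [
    (addmm_relu_pattern, addmm_relu_replacement),      # replacement fns are assumed names
    (addmm_relu_pattern_2, addmm_relu_replacement_2),
]:
    name = f"{pattern.__name__}_fp32"
    gen_register_replacement(
        name,               # extra unique name used as the lookup key for the precompiled pattern
        pattern,            # search function (same as register_replacement)
        replacement,        # replace function
        args_fp32,          # example inputs used for tracing
        fwd_only,           # inference-only trace function (see discussion below)
        pass_patterns[2],   # pass dict the replacement is registered into
    )
```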
def addmm_gelu_pattern_2(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(input, output)
    return aten.gelu(output, approximate="tanh")
It's kind of unfortunate that cuBLAS only supports "tanh"... which we don't use by default in PyTorch.
@eellison For my edification, what's the dashboard, is it
@eellison the dashboard ran, results are under the branch "AaronWang04_addmmfusion_perftest". Not sure where the slowdowns are from, whether it is weird addmm_activation shapes that are slower than Triton or the benchmarks having high variance.
@eellison Ran the dashboard again with gen_register_replacement. Pasted an image of it in the top comment.
Successfully rebased and force-pushed from ff89911 to 77ff3db.
This PR implements a pass in post_grad to fuse activation(add + mm).

This was previously done similarly in pytorch#106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion. However, since then the cuBLAS team has made a lot of perf improvements here. I will update this post with more benchmarks, but preliminary benchmarks show good results.

perf dashboard

<img width="3371" height="1240" alt="Screenshot from 2025-08-07 13-41-35" src="https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738" />

ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on its input while ReLU's does not. I don't think this is fixable with the current addmm_activation API.

Graph module before and after this pass

Relu(addmm)

```
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (relu, primals_2, le, permute_1)

graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (_addmm_activation_default, primals_2, le, permute_1)
```

Gelu (addmm)

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
    %tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
    return (mul_5,)

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
    return (_addmm_activation_default,)
```

Benchmark setup:
NGC pytorch 25.06 container
cublas version: 12.9.1.4
torch.compile ran with dynamic = False and max_autotune

H100

```
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0107 ms
Average Time per Iteration (torch compile): 0.0296 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0262 ms
Average Time per Iteration (torch compile): 0.0327 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.1763 ms
Average Time per Iteration (torch compile): 0.2457 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 1.5280 ms
Average Time per Iteration (torch compile): 1.9437 ms
```

A100

```
############################################################
Testing with dtype: float16
############################################################
============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.0313 ms
Average Time per Iteration (torch compile): 0.0643 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.1149 ms
Average Time per Iteration (torch compile): 0.1255 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.6297 ms
Average Time per Iteration (torch compile): 0.7547 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas): 4.3821 ms
Average Time per Iteration (torch compile): 5.0740 ms
```

Script

```py
import torch

torch.manual_seed(0)

warmup, numrun = 10, 100
sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1]
    print(f"\n{'#'*60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#'*60}")

    for size in sizes:
        M, N, K = size, size, size
        print(f"\n{'='*60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'='*60}")

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        def func1():
            return torch._addmm_activation(C, A, B, use_gelu=True)

        def func2():
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False,
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            }
        )

        for _ in range(warmup):
            func1()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0

        start_event.record()
        for _ in range(numrun):
            func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        for _ in range(warmup):
            func2_compiled()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0

        start_event.record()
        for _ in range(numrun):
            func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")
```

Pull Request resolved: pytorch#158137
Approved by: https://github.com/eellison
This reverts commit b9d7de3. Reverted pytorch#158137 on behalf of https://github.com/malfet due to Broke inductor torchbench, see https://hud.pytorch.org/hud/pytorch/pytorch/663da17b622d0c56f288b32b31276166dbed761e/1?per_page=50&name_filter=inductor_torchbench%2C%202%2C%202 ([comment](pytorch#158137 (comment)))
@pytorchmergebot label ciflow/inductor ciflow/torchbench ciflow/inductor-periodic |
# do not fuse if there are pointwise ops after
return not all(is_pointwise_use(use) for use in output.users)
I guess we should not fuse as long as there is a pointwise op as per the comment, right? The current logic does not fuse only when pointwise ops are the only consumers.
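In other words, a sketch of the two behaviors being contrasted, with `output` and `is_pointwise_use` as in the snippet above; the wrapper names are purely illustrative, and both are written as extra checks that return True when the fusion should apply.

```python
def extra_check_current(output):
    # What the snippet above does: skips the fusion only when *all* users are pointwise.
    return not all(is_pointwise_use(use) for use in output.users)

def extra_check_per_comment(output):
    # What the comment describes ("do not fuse if there are pointwise ops after"):
    # skips the fusion as soon as *any* user is pointwise.
    return not any(is_pointwise_use(use) for use in output.users)
```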
def addmm_gelu_pattern(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(output, input)
    return aten.gelu(output, approximate="tanh")
I am definitely missing some context here, but could you please explain why there is no addmm pattern? Are the add and mm parts always decoupled?
Perhaps it gets decomposed into add and mm anyway? I forget the details here
yeah I registered this to pass_patterns[2] and moved addmm unfuse to pass_patterns[1]
Ah, ok, so there is an intention here. Maybe dropping a comment about pass_pattern could be helpful.
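A possible sketch of such a comment, based only on the ordering described in this thread:

```python
# Pass ordering matters here (hypothetical comment, reflecting the discussion above):
#   pass_patterns[1]: unfuse addmm into add + mm so the trailing activation becomes visible
#   pass_patterns[2]: re-fuse add + mm + activation into _addmm_activation
# Registering the fusion one stage later ensures it runs after the unfuse pass.
```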
args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]

for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
    name = f"{pattern.__name__}_fp32"
Any specific reason that there is not a registration for fp16 for the relu case?
bf16's and fp16's fx graphs look exactly the same (there's an upcast to fp32 for the intermediates), whereas fp32 is different because it has no upcasts.
oh, sorry, I meant bf16, since I saw you have a bf16-specific registration for gelu but not relu
ah ok, same logic there, gelu needs upcast but relu does not
def register_addmm_activation_fusion():
    shapes = [(5,), (3, 4), (4, 5)]
    args_fp32 = [torch.empty(shape) for shape in shapes]
Just for my personal understanding. Regarding these shapes -- are they arbitrary?
yes, it really doesn't matter what shapes these are; it's just for running the function to generate the dynamo graph
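Concretely, the example inputs only need shapes that are mutually consistent with the pattern so that tracing succeeds. A small annotated sketch using the shapes from the snippet above; the role comments are an interpretation, not taken from the PR:

```python
import torch

# (input, mat1, mat2) example inputs used purely to trace the pattern:
#   mat1 @ mat2 : (3, 4) @ (4, 5) -> (3, 5)
#   input       : (5,) broadcasts across rows of the mm result, i.e. a bias-style add
shapes = [(5,), (3, 4), (4, 5)]
args_fp32 = [torch.empty(shape) for shape in shapes]
args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]
```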
name = f"{pattern.__name__}_{dtype_suffix}"
gen_register_replacement(
    name,
There must be a way to tell that the substitute is only valid in inference mode; otherwise, as I see it, the backward will not work, since _addmm_activation does not even have a grad_fn registered for it.
Yeah, it's traced with fwd_only.
The backward will not work with _addmm_activation (though I think cuBLAS supports it now; it's just that the native function hasn't been updated to reflect this).
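For a concrete picture of the inference-only scope, here is a minimal eager-mode sketch using the same torch._addmm_activation call as the benchmark script in the PR description; the shapes are arbitrary and the tolerance is a loose fp16-level check, not a guarantee.

```python
import torch

# Eager-mode equivalent of what the pass rewrites the traced pattern into.
mat1 = torch.randn(16, 32, device="cuda", dtype=torch.float16)
mat2 = torch.randn(32, 64, device="cuda", dtype=torch.float16)
bias = torch.randn(64, device="cuda", dtype=torch.float16)

# gelu(bias + mat1 @ mat2) using cuBLAS's tanh-approximation epilogue.
out = torch._addmm_activation(bias, mat1, mat2, use_gelu=True)

# Reference composition (the "before" side of the pattern); should match up to fp16 numerics.
out_ref = torch.nn.functional.gelu(bias + mat1 @ mat2, approximate="tanh")
torch.testing.assert_close(out, out_ref, rtol=1e-3, atol=1e-3)

# No derivative formula is registered for _addmm_activation (see the comment above), so the
# replacement is traced with fwd_only and only kicks in for inference graphs.
```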
This PR implements a pass in post_grad to fuse activation(add + mm).
This was previously done similarly in #106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion.
However, since then the cuBLAS team has made a lot of perf improvements here. I will update this post with more benchmarks, but preliminary benchmarks show good results.
perf dashboard

ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on its input while ReLU's does not. I don't think this is fixable with the current addmm_activation API.
Graph module before and after this pass
Relu(addmm)
Gelu (addmm)
Benchmark setup:
NGC pytorch 25.06 container
cublas version: 12.9.1.4
torch.compile ran with dynamic = False and max_autotune
H100
A100
Script
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela