Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torchdynamo.optimizations.python_key import python_key
from torchdynamo.optimizations.training import aot_autograd_debug_strategy1
from torchdynamo.optimizations.training import aot_autograd_nnc_strategy
from torchdynamo.optimizations.training import aot_autograd_prims_nvfuser_strategy
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
from torchdynamo.profiler import Profiler
from torchdynamo.profiler import fx_insert_profiling
Expand Down Expand Up @@ -764,6 +765,9 @@ def parse_args():
parser.add_argument(
"--nvfuser", action="store_true", help="enable nvfuser globally"
)
parser.add_argument(
"--prims-nvfuser", action="store_true", help="use prims + nvfuser backend"
)
parser.add_argument(
"--isolate", action="store_true", help="run each model in its own process"
)
Expand Down Expand Up @@ -1167,6 +1171,13 @@ def main(runner, original_dir=None):
experiment = speedup_experiment
backend_str = "nvfuser" if args.nvfuser else "nnc"
output_filename = f"accuracy_aot_{backend_str}_mincut.csv"
elif args.prims_nvfuser:
optimize_ctx = torchdynamo.optimize(
aot_autograd_prims_nvfuser_strategy, nopython=args.nopython
)
experiment = speedup_experiment
backend_str = "prims_nvfuser"
output_filename = f"accuracy_aot_{backend_str}.csv"
elif args.print_fx:
optimize_ctx = torchdynamo.optimize(
print_fx,
Expand Down
60 changes: 60 additions & 0 deletions torchdynamo/optimizations/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,63 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs):


aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()


class AOTAutogradPrimsNvFuser(AOTAutogradStrategy):
    """
    Use FX graph partitioner + Aten2Prims ref + trace executor + nvFuser
    """

    def __init__(self, gm: torch.fx.GraphModule, example_inputs):
        super().__init__(gm, example_inputs)

        # Imported lazily so that merely importing this module does not
        # require functorch / the nvFuser FX backend to be present.
        from functorch.compile import min_cut_rematerialization_partition
        from torch.fx.passes.backends.nvfuser import NvFuserBackend

        self.nvfuser = NvFuserBackend()
        self.min_cut_rematerialization_partition = min_cut_rematerialization_partition
        self.populate_aten2aten_decomps()

    def populate_aten2aten_decomps(self):
        """Build and store the aten-to-aten decomposition table on `self`."""
        from torch._decomp import get_decompositions

        aten = torch.ops.aten
        # Ops decomposed into simpler aten ops before lowering, so the
        # prims/nvFuser path only sees a reduced operator surface.
        decomp_ops = {
            aten.detach,
            aten.gelu_backward,
            aten.leaky_relu_backward,
            aten.sigmoid_backward,
            aten.threshold_backward,
            aten.hardtanh_backward,
            aten.hardsigmoid_backward,
            aten.hardswish_backward,
            aten.tanh_backward,
            aten.silu_backward,
            aten.elu_backward,
            aten.cudnn_batch_norm,
            aten.cudnn_batch_norm_backward,
            aten.masked_fill.Scalar,
            aten.masked_fill.Tensor,
            aten.elu,
            aten.leaky_relu,
            aten.hardtanh,
            aten.hardswish,
            aten.hardsigmoid,
            aten.rsub,
            aten.native_batch_norm_backward,
        }
        self.aten2aten_decompositions = get_decompositions(decomp_ops)

    def candidate(self):
        """Compile via AOTAutograd: min-cut partition + nvFuser forward compiler."""
        aot_autograd = BACKENDS["aot_autograd"]
        return aot_autograd(
            self.gm,
            self.example_inputs,
            fw_compiler=self.nvfuser,
            partition_fn=self.min_cut_rematerialization_partition,
            hasher_type="StaticShapeHasher",
            decompositions=self.aten2aten_decompositions,
        )


# Entry point used by benchmarks/common.py (--prims-nvfuser); `compile_fn` is
# presumably provided by the AOTAutogradStrategy base class — defined outside
# this view, so confirm against training.py's earlier strategies.
aot_autograd_prims_nvfuser_strategy = AOTAutogradPrimsNvFuser.compile_fn