diff --git a/benchmarks/common.py b/benchmarks/common.py
index 5deaf95b37..e76f73b635 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -32,6 +32,7 @@ from torchdynamo.optimizations.python_key import python_key
 from torchdynamo.optimizations.training import aot_autograd_debug_strategy1
 from torchdynamo.optimizations.training import aot_autograd_nnc_strategy
+from torchdynamo.optimizations.training import aot_autograd_prims_nvfuser_strategy
 from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
 from torchdynamo.profiler import Profiler
 from torchdynamo.profiler import fx_insert_profiling
@@ -764,6 +765,9 @@ def parse_args():
     parser.add_argument(
         "--nvfuser", action="store_true", help="enable nvfuser globally"
     )
+    parser.add_argument(
+        "--prims-nvfuser", action="store_true", help="use prims + nvfuser backend"
+    )
     parser.add_argument(
         "--isolate", action="store_true", help="run each model in its own process"
     )
@@ -1167,6 +1171,13 @@ def main(runner, original_dir=None):
         experiment = speedup_experiment
         backend_str = "nvfuser" if args.nvfuser else "nnc"
         output_filename = f"accuracy_aot_{backend_str}_mincut.csv"
+    elif args.prims_nvfuser:
+        optimize_ctx = torchdynamo.optimize(
+            aot_autograd_prims_nvfuser_strategy, nopython=args.nopython
+        )
+        experiment = speedup_experiment
+        backend_str = "prims_nvfuser"
+        output_filename = f"accuracy_aot_{backend_str}.csv"
     elif args.print_fx:
         optimize_ctx = torchdynamo.optimize(
             print_fx,
diff --git a/torchdynamo/optimizations/training.py b/torchdynamo/optimizations/training.py
index f28ee5c63e..8438611c3d 100644
--- a/torchdynamo/optimizations/training.py
+++ b/torchdynamo/optimizations/training.py
@@ -159,3 +159,63 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs):
 
 
 aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()
+
+
+class AOTAutogradPrimsNvFuser(AOTAutogradStrategy):
+    """
+    Use FX graph partitioner + Aten2Prims ref + trace executor + nvFuser
+    """
+
+    def __init__(self, gm: torch.fx.GraphModule, example_inputs):
+        super(AOTAutogradPrimsNvFuser, self).__init__(gm, example_inputs)
+
+        from functorch.compile import min_cut_rematerialization_partition
+        from torch.fx.passes.backends.nvfuser import NvFuserBackend
+
+        self.nvfuser = NvFuserBackend()
+        self.min_cut_rematerialization_partition = min_cut_rematerialization_partition
+        self.populate_aten2aten_decomps()
+
+    def populate_aten2aten_decomps(self):
+        from torch._decomp import get_decompositions
+
+        aten = torch.ops.aten
+        default_decompositions = {
+            aten.detach,
+            aten.gelu_backward,
+            aten.leaky_relu_backward,
+            aten.sigmoid_backward,
+            aten.threshold_backward,
+            aten.hardtanh_backward,
+            aten.hardsigmoid_backward,
+            aten.hardswish_backward,
+            aten.tanh_backward,
+            aten.silu_backward,
+            aten.elu_backward,
+            aten.cudnn_batch_norm,
+            aten.cudnn_batch_norm_backward,
+            aten.masked_fill.Scalar,
+            aten.masked_fill.Tensor,
+            aten.elu,
+            aten.leaky_relu,
+            aten.hardtanh,
+            aten.hardswish,
+            aten.hardsigmoid,
+            aten.rsub,
+            aten.native_batch_norm_backward,
+        }
+
+        self.aten2aten_decompositions = get_decompositions(default_decompositions)
+
+    def candidate(self):
+        return BACKENDS["aot_autograd"](
+            self.gm,
+            self.example_inputs,
+            fw_compiler=self.nvfuser,
+            partition_fn=self.min_cut_rematerialization_partition,
+            hasher_type="StaticShapeHasher",
+            decompositions=self.aten2aten_decompositions,
+        )
+
+
+aot_autograd_prims_nvfuser_strategy = AOTAutogradPrimsNvFuser.compile_fn