diff --git a/torchdynamo/optimizations/training.py b/torchdynamo/optimizations/training.py
index 7f64c86c1f..647811dc6a 100644
--- a/torchdynamo/optimizations/training.py
+++ b/torchdynamo/optimizations/training.py
@@ -1,6 +1,7 @@
 import logging
 import operator
 from collections import defaultdict
+from functools import partial
 from typing import Set
 
 import torch
@@ -262,6 +263,45 @@ def candidate(self):
 aot_prims_nvfuser = AotPrimsNvfuser.compile_fn
 
 
+def prims_executor(gm, inputs, *, executor):
+    # This function is called once per forward/backward pass of a graph in AOT
+    # Autograd. We use it to set up the nvFuser-specific FX graph and return
+    # the execute function.
+    from torch._prims.context import TorchRefsNvfuserCapabilityMode
+    from torch._prims.executor import execute
+    from torch.fx.experimental.proxy_tensor import make_fx
+
+    # First we trace the graph, conditionally decomposing nodes
+    # that can be sent to the nvFuser executor.
+    with TorchRefsNvfuserCapabilityMode():
+        prim_gm = make_fx(gm)(*inputs)
+
+    # Then we return a callable that executes the "prim_gm" graph
+    return partial(execute, prim_gm, executor=executor)
+
+
+def create_nvprims_backend(*, executor):
+    class NvPrims(AotAutogradStrategy):
+        def __init__(self, gm: torch.fx.GraphModule, example_inputs):
+            super(NvPrims, self).__init__(gm, example_inputs)
+            self.executor = executor
+
+        def candidate(self):
+            return BACKENDS["aot_autograd"](
+                self.gm,
+                self.example_inputs,
+                fw_compiler=partial(prims_executor, executor=self.executor),
+                bw_compiler=partial(prims_executor, executor=self.executor),
+                hasher_type="StaticShapeHasher",
+            )
+
+    return NvPrims
+
+
+aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn
+aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn
+
+
 def cloner(t):
     if isinstance(t, torch.Tensor):
         return t.clone()
@@ -431,6 +471,12 @@ def create_aot_backends():
     # directly lowers to NVFuser without relying no Torchscript.
     BACKENDS["prims_nvfuser"] = aot_prims_nvfuser
 
+    # "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
+    # supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
+    BACKENDS["nvprims_nvfuser"] = aot_nvprims_nvfuser
+    # This is useful for debugging. Can be removed later.
+    BACKENDS["nvprims_aten"] = aot_nvprims_aten
+
     # aot_nvfuser uses the memory efficient fusion algorithm from AOT Autograd.
     # It uses min cut rematerialization algorithm, and uses nvfuser as the
     # compiler backend. This is the most optimized setting with nvfuser for
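
For context, a hedged usage sketch of the two backends this patch registers. It is not part of the patch itself; the toy function, tensor shapes, and variable names are illustrative only, and it assumes TorchDynamo's torchdynamo.optimize(<backend name>) entry point and a CUDA build for nvFuser.

import torch
import torchdynamo


def toy_fn(a, b):
    # A hypothetical function; any traceable PyTorch code would do.
    return torch.sigmoid(a + b).sum()


a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

# Route traced graphs through AOT Autograd + nvprims, executed by nvFuser.
with torchdynamo.optimize("nvprims_nvfuser"):
    out = toy_fn(a, b)

# The "nvprims_aten" backend runs the same nvprims trace through the eager
# ATen executor, which is handy for debugging numerical differences.
with torchdynamo.optimize("nvprims_aten"):
    out_ref = toy_fn(a, b)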