Merged
torchdynamo/optimizations/training.py (46 changes: 46 additions & 0 deletions)

```diff
@@ -1,6 +1,7 @@
 import logging
 import operator
 from collections import defaultdict
+from functools import partial
 from typing import Set
 
 import torch
@@ -262,6 +263,45 @@ def candidate(self):
 aot_prims_nvfuser = AotPrimsNvfuser.compile_fn
 
 
+def prims_executor(gm, inputs, *, executor):
+    # This function is called once per forward/backward pass of a graph in AOT
+    # Autograd. We use it to set up the nvFuser-specific FX graph and return
+    # an execute function.
+    from torch._prims.context import TorchRefsNvfuserCapabilityMode
+    from torch._prims.executor import execute
+    from torch.fx.experimental.proxy_tensor import make_fx
+
+    # First we trace the graph, conditionally decomposing nodes
+    # that can be sent to the nvFuser executor.
+    with TorchRefsNvfuserCapabilityMode():
+        prim_gm = make_fx(gm)(*inputs)
+
+    # Then we return a callable that executes the "prim_gm" graph.
+    return partial(execute, prim_gm, executor=executor)
+
+
+def create_nvprims_backend(*, executor):
+    class NvPrims(AotAutogradStrategy):
+        def __init__(self, gm: torch.fx.GraphModule, example_inputs):
+            super(NvPrims, self).__init__(gm, example_inputs)
+            self.executor = executor
+
+        def candidate(self):
+            return BACKENDS["aot_autograd"](
+                self.gm,
+                self.example_inputs,
+                fw_compiler=partial(prims_executor, executor=self.executor),
+                bw_compiler=partial(prims_executor, executor=self.executor),
+                hasher_type="StaticShapeHasher",
+            )
+
+    return NvPrims
+
+
+aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn
+aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn
+
+
 def cloner(t):
     if isinstance(t, torch.Tensor):
         return t.clone()
```
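
For orientation, here is a hypothetical standalone use of prims_executor with the portable "aten" executor, so no GPU is required. Only prims_executor comes from the diff above; the toy function f, its input, and the calling convention are illustrative assumptions about the torch._prims.executor.execute API of this PyTorch era, not part of the PR.

```python
import torch
from torch.fx import symbolic_trace

# Trace a toy function into an FX GraphModule, then hand it to
# prims_executor from the diff above. executor="aten" interprets the
# decomposed graph with regular ATen ops, so this sketch runs on CPU.
# (f, its input, and this calling pattern are made up for illustration.)
def f(x):
    return torch.sigmoid(x) + x

gm = symbolic_trace(f)
example = torch.randn(4)

run = prims_executor(gm, [example], executor="aten")  # returns a callable
print(run(example))  # executes the decomposed "prim_gm" graph
```
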
```diff
@@ -431,6 +471,12 @@ def create_aot_backends():
     # directly lowers to NVFuser without relying on TorchScript.
     BACKENDS["prims_nvfuser"] = aot_prims_nvfuser
```

Contributor:
Is this backend still needed? What is the difference between it and nvprims_nvfuser?

Contributor Author:
It's not needed for our team, and there are no plans to use it or to recommend it to anyone. I'm not the one who added it, so I'm not going to touch it.

"prims_nvfuser" was added in #584 and it makes a decision of translating aten->prims based on a static outdated and not maintained table from https://github.com/pytorch/pytorch/blob/09965957cd8ecc696852e73022892b3ad4475783/torch/fx/passes/backends/nvfuser.py#L70-L71
The performance of this backend is terrible, on the order of 5% of eager mode's performance.

"nvprims_nvfuser" decomposes aten ops into nvprims, a subset of prims that are guaranteed to be executable by nvFuser, and the decomposition is used only if given aten op is fully decomposable into nvprims.

Contributor:
Let's just remove it if you don't recommend anyone use it.

cc @SherlockNoMad

Contributor Author:
Maybe Sherlock would like to continue working on his approach. Can we handle this separately and not remove the "prims_nvfuser" backend in this PR?

Contributor:
I created this backend for your team's use.
Now that you have your own implementation, please feel free to remove this.

```diff
# "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
# supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
BACKENDS["nvprims_nvfuser"] = aot_nvprims_nvfuser
# This is useful for debugging. Can be removed later.
BACKENDS["nvprims_aten"] = aot_nvprims_aten

# aot_nvfuser uses the memory efficient fusion algorithm from AOT Autograd.
# It uses min cut rematerialization algorithm, and uses nvfuser as the
# compiler backend. This is the most optimized setting with nvfuser for
Expand Down
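
With these registrations, the new backends can be selected by name. A hypothetical usage sketch, assuming the torchdynamo.optimize context-manager API of this era accepts registered backend names (fn and its input are made up):

```python
import torch
import torchdynamo

def fn(x):
    return torch.softmax(x + 1.0, dim=-1) * x

# "nvprims_nvfuser" needs a CUDA build with nvFuser available;
# "nvprims_aten" runs the same decompositions through plain ATen,
# which is handy for debugging on CPU.
with torchdynamo.optimize("nvprims_aten"):
    out = fn(torch.randn(8, 8))
print(out.shape)
```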