Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nvfuser python API import fix #94036

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,11 @@ if(NOT USE_CUDA AND NOT USE_ROCM)
endif()

if(BUILD_NVFUSER)
add_subdirectory(third_party/nvfuser)
if(DEFINED ENV{NVFUSER_SOURCE_DIR})
add_subdirectory($ENV{NVFUSER_SOURCE_DIR} nvfuser)
else()
add_subdirectory(third_party/nvfuser nvfuser)
endif()
malfet marked this conversation as resolved.
Show resolved Hide resolved
endif()

include(cmake/Summary.cmake)
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@
# NCCL_INCLUDE_DIR
# specify where nccl is installed
#
# NVFUSER_SOURCE_DIR
# specify the nvfuser source root directory to build from (used instead of the default third_party/nvfuser when set)
#
# NVTOOLSEXT_PATH (Windows only)
# specify where nvtoolsext is installed
#
Expand Down
6 changes: 5 additions & 1 deletion test/test_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ def test_nvfuser_impl_is_used(self, device):
# This test ensures that the nvfuser implementation is used whenever it exists
# Assuming one-to-one mapping between prims and nvfuser implementations
# This test is not intended to test the correctness of the nvfuser implementation
from nvfuser._C import FusionDefinition as fd
try:
from nvfuser import FusionDefinition as fd
except ImportError:
from nvfuser._C import FusionDefinition as fd


prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops))
ops_without_nvfuser_impl = {
Expand Down
38 changes: 29 additions & 9 deletions torch/_prims/nvfuser_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,29 @@
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

if torch.cuda.is_available():
from nvfuser._C import ( # type: ignore[import]
DataType,
Fusion,
FusionDefinition,
Tensor,
)
try:
from nvfuser import ( # type: ignore[attr-defined, import]
DataType,
FusionDefinition,
Tensor,
)

def create_fusion_definition():
fd = FusionDefinition()
return fd, fd

except ImportError:
from nvfuser._C import ( # type: ignore[import]
DataType,
Fusion,
FusionDefinition,
Tensor,
)

def create_fusion_definition():
fusion = Fusion()
return fusion, FusionDefinition(fusion)

else:
DataType = None

Expand Down Expand Up @@ -74,7 +91,10 @@ def compute_contiguity(shape, strides):
Contiguous dimensions are represented by True, strided dimensions
are represented by False.
"""
from nvfuser._C import compute_contiguity
try:
from nvfuser import compute_contiguity # type: ignore[attr-defined]
except ImportError:
from nvfuser._C import compute_contiguity

return compute_contiguity(shape, strides)

Expand Down Expand Up @@ -148,8 +168,8 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes))
orig_flat_out, _ = tree_flatten(output_node.args[0])

fusion = Fusion()
with FusionDefinition(fusion) as fd:
fusion, fd = create_fusion_definition()
with fd:

def _to_nvfuser_constant(arg):
if isinstance(arg, Number):
Expand Down
12 changes: 10 additions & 2 deletions torch/_prims/nvfuser_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@

def _assert_nvfuser_op_exists(fname: str):
try:
from nvfuser._C import FusionDefinition as fd # type: ignore[import]
try:
from nvfuser import ( # type: ignore[import, attr-defined]
FusionDefinition as fd,
)
except ImportError:
from nvfuser._C import FusionDefinition as fd # type: ignore[import]

assert getattr(fd.Operators, fname)
except ImportError:
Expand Down Expand Up @@ -285,7 +290,10 @@ def _sum_nvfuser(
dims: DimsSequenceType,
):
keep_dims = False
from nvfuser._C import DataType # type: ignore[import]
try:
from nvfuser import DataType # type: ignore[import, attr-defined]
except ImportError:
from nvfuser._C import DataType # type: ignore[import]

output_dtype = DataType.Null
return fd.ops.sum(a, dims, keep_dims, output_dtype)
Expand Down
5 changes: 4 additions & 1 deletion torch/_prims_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from torch import sym_float, sym_int, sym_max

try:
from nvfuser._C import DataType # type: ignore[import]
try:
from nvfuser import DataType # type: ignore[import, attr-defined]
except ImportError:
from nvfuser._C import DataType # type: ignore[import]

_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
Expand Down