Skip to content

Commit

Permalink
Nvfuser python API import fix (#94036)
Browse files Browse the repository at this point in the history
1. Make the nvfuser Python API import work with both the devel and upstream repositories;
2. Add an environment variable (NVFUSER_SOURCE_DIR) that allows a custom nvfuser code base to be built with the upstream PyTorch core.
Pull Request resolved: #94036
Approved by: https://github.com/malfet, https://github.com/davidberard98
  • Loading branch information
jjsjann123 authored and pytorchmergebot committed Feb 16, 2023
1 parent 7aaebe0 commit 21eb7f7
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 14 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,11 @@ if(NOT USE_CUDA AND NOT USE_ROCM)
endif()

if(BUILD_NVFUSER)
add_subdirectory(third_party/nvfuser)
if(DEFINED ENV{NVFUSER_SOURCE_DIR})
add_subdirectory($ENV{NVFUSER_SOURCE_DIR} nvfuser)
else()
add_subdirectory(third_party/nvfuser nvfuser)
endif()
endif()

include(cmake/Summary.cmake)
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@
# NCCL_INCLUDE_DIR
# specify where nccl is installed
#
# NVFUSER_SOURCE_DIR
# specify nvfuser root directory
#
# NVTOOLSEXT_PATH (Windows only)
# specify where nvtoolsext is installed
#
Expand Down
6 changes: 5 additions & 1 deletion test/test_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ def test_nvfuser_impl_is_used(self, device):
# This test is to ensure that when the nvfuser implementation exists it is used
# Assuming one-to-one mapping between prims and nvfuser implementations
# This test is not intended to test the correctness of the nvfuser implementation
from nvfuser._C import FusionDefinition as fd
try:
from nvfuser import FusionDefinition as fd
except ImportError:
from nvfuser._C import FusionDefinition as fd


prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops))
ops_without_nvfuser_impl = {
Expand Down
38 changes: 29 additions & 9 deletions torch/_prims/nvfuser_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,29 @@
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

if torch.cuda.is_available():
from nvfuser._C import ( # type: ignore[import]
DataType,
Fusion,
FusionDefinition,
Tensor,
)
try:
from nvfuser import ( # type: ignore[attr-defined, import]
DataType,
FusionDefinition,
Tensor,
)

def create_fusion_definition():
fd = FusionDefinition()
return fd, fd

except ImportError:
from nvfuser._C import ( # type: ignore[import]
DataType,
Fusion,
FusionDefinition,
Tensor,
)

def create_fusion_definition():
fusion = Fusion()
return fusion, FusionDefinition(fusion)

else:
DataType = None

Expand Down Expand Up @@ -74,7 +91,10 @@ def compute_contiguity(shape, strides):
Contiguous dimensions are represented by True, strided dimensions
are represented by False.
"""
from nvfuser._C import compute_contiguity
try:
from nvfuser import compute_contiguity # type: ignore[attr-defined]
except ImportError:
from nvfuser._C import compute_contiguity

return compute_contiguity(shape, strides)

Expand Down Expand Up @@ -148,8 +168,8 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes))
orig_flat_out, _ = tree_flatten(output_node.args[0])

fusion = Fusion()
with FusionDefinition(fusion) as fd:
fusion, fd = create_fusion_definition()
with fd:

def _to_nvfuser_constant(arg):
if isinstance(arg, Number):
Expand Down
12 changes: 10 additions & 2 deletions torch/_prims/nvfuser_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@

def _assert_nvfuser_op_exists(fname: str):
try:
from nvfuser._C import FusionDefinition as fd # type: ignore[import]
try:
from nvfuser import ( # type: ignore[import, attr-defined]
FusionDefinition as fd,
)
except ImportError:
from nvfuser._C import FusionDefinition as fd # type: ignore[import]

assert getattr(fd.Operators, fname)
except ImportError:
Expand Down Expand Up @@ -285,7 +290,10 @@ def _sum_nvfuser(
dims: DimsSequenceType,
):
keep_dims = False
from nvfuser._C import DataType # type: ignore[import]
try:
from nvfuser import DataType # type: ignore[import, attr-defined]
except ImportError:
from nvfuser._C import DataType # type: ignore[import]

output_dtype = DataType.Null
return fd.ops.sum(a, dims, keep_dims, output_dtype)
Expand Down
5 changes: 4 additions & 1 deletion torch/_prims_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from torch import sym_float, sym_int, sym_max

try:
from nvfuser._C import DataType # type: ignore[import]
try:
from nvfuser import DataType # type: ignore[import, attr-defined]
except ImportError:
from nvfuser._C import DataType # type: ignore[import]

_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
Expand Down

0 comments on commit 21eb7f7

Please sign in to comment.