Skip to content

Commit

Permalink
switching NNC as default for TorchScript support (#105185)
Browse files Browse the repository at this point in the history
Disable nvfuser by default in TorchScript
Add deprecation warning for nvfuser usage via TorchScript and PrimTorch

Pull Request resolved: #105185
Approved by: https://github.com/malfet, https://github.com/davidberard98
  • Loading branch information
jjsjann123 authored and pytorchmergebot committed Jul 19, 2023
1 parent a10f93f commit 218b547
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 8 deletions.
5 changes: 1 addition & 4 deletions third_party/nvfuser/python_tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,7 @@ def func(a):
# nvfuser would promote output to fp32 in math, FusionDefinition should cast output dtype back
return torch.where(tmp > 0, tmp, 0.0)

# No warnings and no errors
with warnings.catch_warnings(record=True) as w:
nvfuser_result = func(input1)
self.assertEqual(len(w), 0)
nvfuser_result = func(input1)
eager_result = func.__wrapped__(input1)
self.assertEqual(eager_result, nvfuser_result)

Expand Down
1 change: 1 addition & 0 deletions torch/_prims/nvfuser_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
),
)
)
warn("nvfuser integration in primTorch is deprecated")
result = tree_unflatten(
fusion.execute(concrete_fusion_inputs), # type: ignore[has-type]
unflatten_spec, # type: ignore[has-type]
Expand Down
6 changes: 2 additions & 4 deletions torch/csrc/jit/codegen/cuda/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,8 @@ class NVFuserEnabler {
return *getCachedFuserEnabledEnvVar();
}
// 3. default value
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
// default off since TorchScript integration isn't maintained any longer
return false;
#else
return nvfuserCanBeEnabled();
#endif
}

public:
Expand Down Expand Up @@ -235,6 +232,7 @@ void fuseGraph(std::shared_ptr<Graph>& graph) {
return;
}

TORCH_WARN_ONCE("nvfuser integration in TorchScript is deprecated.");
TORCH_CHECK(
getFuserInterface()->fn_fuse_graph != nullptr,
"Running the CUDA fuser requires a CUDA build.");
Expand Down

0 comments on commit 218b547

Please sign in to comment.