[NVFuser] RuntimeError: ref_id_it != replayed_concrete_ids_.vector().end() INTERNAL ASSERT FAILED #84510

Open · yueyericardo opened this issue Sep 3, 2022 · 6 comments
Labels
module: assert failure (the issue involves an assert failure) · module: nvfuser · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

yueyericardo (Contributor) commented Sep 3, 2022

🐛 Describe the bug

# debug_aev_nvfuser_minimal.py

import torch

# fuse even single-node subgraphs with nvfuser
torch._C._jit_set_nvfuser_single_node_mode(True)
# keep autodiff subgraphs from being inlined so the backward stays fusible
torch._C._debug_set_autodiff_subgraph_inlining(False)
torch.manual_seed(0)


def func(x, y, z):
    return (x + y)**z


func_script = torch.jit.script(func)

x = torch.rand([3, 1, 1, 1, 1], device="cuda").requires_grad_()
y = torch.rand([1, 1, 1, 4], device="cuda")
z = torch.rand([1, 1, 1, 1], device="cuda")

for i in range(10):
    res = func(x, y, z)
    grad = torch.autograd.grad(res, x, torch.ones_like(res))[0]

    res_script = func_script(x, y, z)
    grad_script = torch.autograd.grad(res_script, x, torch.ones_like(res))[0]

    print(f"{i}: max_result_error {(res_script-res).abs().max()}, max_grad_error {(grad_script-grad).abs().max()}")

Run with

PYTORCH_NVFUSER_DISABLE=fallback PYTORCH_JIT_LOG_LEVEL=">partition:graph_fuser:>>kernel_cache" python debug_aev_nvfuser_minimal.py

Error message:

[DEBUG kernel_cache.cpp:638] GraphCache constructor: 0x7fb774056cc0
[DUMP kernel_cache.cpp:639] GraphCache created for graph
[DUMP kernel_cache.cpp:639] graph(%0 : Float(3, 1, 1, 1, 4, strides=[4, 4, 4, 4, 1], requires_grad=0, device=cuda:0),
[DUMP kernel_cache.cpp:639]       %1 : Float(3, 1, 1, 1, 4, strides=[4, 4, 4, 4, 1], requires_grad=0, device=cuda:0),
[DUMP kernel_cache.cpp:639]       %2 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cuda:0),
[DUMP kernel_cache.cpp:639]       %3 : Float(3, 1, 1, 1, 4, strides=[4, 4, 4, 4, 1], requires_grad=0, device=cuda:0)):
[DUMP kernel_cache.cpp:639]   %4 : int[] = prim::Constant[value=[3, 1, 1, 1, 1]]()
[DUMP kernel_cache.cpp:639]   %5 : int = prim::Constant[value=1]() # <string>:240:94
[DUMP kernel_cache.cpp:639]   %6 : float = prim::Constant[value=0.]() # <string>:240:52
[DUMP kernel_cache.cpp:639]   %7 : Bool(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cuda:0) = aten::eq(%2, %6) # <string>:240:40
[DUMP kernel_cache.cpp:639]   %8 : Float(3, 1, 1, 1, 4, strides=[4, 4, 4, 4, 1], requires_grad=0, device=cuda:0) = aten::mul(%3, %2) # <string>:240:98
[DUMP kernel_cache.cpp:639]   %9 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cuda:0) = aten::sub(%2, %5, %5) # <string>:240:139
[DUMP kernel_cache.cpp:639]   %10 : Float(3, 1, 1, 1, 4, strides=[4, 4, 4, 4, 1], requires_grad=0, device=cuda:0) = aten::pow(%1, %9) # <string>:240:123
[DUMP kernel_cache.cpp:639]   %11 : Float(3, 1, 1, 1, 4, strides=[4, 4, 4, 4, 1], requires_grad=0, device=cuda:0) = aten::mul(%8, %10) # <string>:240:98
[DUMP kernel_cache.cpp:639]   %12 : Float(3, 1, 1, 1, 4, strides=[4, 4, 4, 4, 1], requires_grad=0, device=cuda:0) = aten::where(%7, %0, %11) # <string>:240:28
[DUMP kernel_cache.cpp:639]   %grad_self.20 : Float(3, 1, 1, 1, 1, strides=[1, 1, 1, 1, 1], requires_grad=0, device=cuda:0) = aten::_grad_sum_to_size(%12, %4) # <string>:13:29
[DUMP kernel_cache.cpp:639]   return (%grad_self.20)
[DEBUG kernel_cache.cpp:647] running GraphCache: 0x7fb774056cc0
Traceback (most recent call last):
  File "debug_aev_nvfuser_minimal.py", line 23, in <module>
    grad_script = torch.autograd.grad(res_script, x, torch.ones_like(res))[0]
  File "/home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: ref_id_it != replayed_concrete_ids_.vector().end() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1662103173222/work/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp":724, please report a bug to PyTorch. Could not find required iter domain in reference replay: bblockIdx.y214{( 1 * ( 1 * 1 ) )}
ref_id_it != replayed_concrete_ids_.vector().end() INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1662103173222/work/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp":724, please report a bug to PyTorch. Could not find required iter domain in reference replay: bblockIdx.y214{( 1 * ( 1 * 1 ) )}
Exception raised from constructLoopDomains at /opt/conda/conda-bld/pytorch_1662103173222/work/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp:724 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fd528ba9577 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fd528b77e2c in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::string const&) + 0x3f (0x7fd528ba749f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #3: <unknown function> + 0x2f5c6aa (0x7fd52bb566aa in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #4: <unknown function> + 0x2f5df57 (0x7fd52bb57f57 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #5: <unknown function> + 0x2f5e263 (0x7fd52bb58263 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #6: <unknown function> + 0x2f48370 (0x7fd52bb42370 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #7: <unknown function> + 0x2f4e569 (0x7fd52bb48569 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #8: <unknown function> + 0x2f4e692 (0x7fd52bb48692 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #9: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::BinaryOp const*) + 0x21 (0x7fd52bbd7301 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #10: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::IfThenElse const*) + 0xc0 (0x7fd52bbd6e90 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #11: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xdf (0x7fd52bbd815f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #12: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::IfThenElse const*) + 0xc0 (0x7fd52bbd6e90 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #13: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xdf (0x7fd52bbd815f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #14: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xdf (0x7fd52bbd815f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #15: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xdf (0x7fd52bbd815f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #16: torch::jit::fuser::cuda::IndexLowering::handle(torch::jit::fuser::cuda::kir::ForLoop const*) + 0xdf (0x7fd52bbd815f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #17: torch::jit::fuser::cuda::IndexLowering::generate(std::vector<torch::jit::fuser::cuda::Expr*, std::allocator<torch::jit::fuser::cuda::Expr*> > const&) + 0x27 (0x7fd52bbd6ca7 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #18: torch::jit::fuser::cuda::GpuLower::lower(torch::jit::fuser::cuda::Fusion*, torch::jit::fuser::cuda::DataType) + 0x13c7 (0x7fd52bc276e7 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #19: torch::jit::fuser::cuda::FusionExecutor::compileFusion(torch::jit::fuser::cuda::Fusion*, c10::ArrayRef<c10::IValue> const&, torch::jit::fuser::cuda::LaunchParams const&, torch::jit::fuser::cuda::CompileOptions) + 0xcc1 (0x7fd52baea111 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #20: torch::jit::fuser::cuda::FusionKernelRuntime::runKernelWithInput(c10::ArrayRef<c10::IValue> const&, unsigned long, torch::jit::fuser::cuda::SegmentedGroup*) + 0x591 (0x7fd52bba7421 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #21: torch::jit::fuser::cuda::FusionKernelRuntime::runWithInput(c10::ArrayRef<c10::IValue> const&, unsigned long) + 0x4ff (0x7fd52bba908f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #22: torch::jit::fuser::cuda::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&) + 0x375 (0x7fd52bbab915 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #23: <unknown function> + 0x2fb1c8f (0x7fd52bbabc8f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #24: <unknown function> + 0x302ffa8 (0x7fd52bc29fa8 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #25: torch::jit::fuser::cuda::runCudaFusionGroup(torch::jit::Node const*, std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x43c (0x7fd52bc2a7fc in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cu.so)
frame #26: <unknown function> + 0x443fef2 (0x7fd55c8f7ef2 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #27: torch::jit::InterpreterState::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x3f (0x7fd55c8e407f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #28: <unknown function> + 0x441c61a (0x7fd55c8d461a in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #29: <unknown function> + 0x441f4f6 (0x7fd55c8d74f6 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #30: <unknown function> + 0x406051b (0x7fd55c51851b in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #31: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1638 (0x7fd55c511c28 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #32: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x698 (0x7fd55c512798 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #33: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x8b (0x7fd55c509b3b in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #34: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4f (0x7fd56ce42d4f in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #35: <unknown function> + 0xdbbf4 (0x7fd57f9d8bf4 in /home/richard/program/anaconda3/envs/torch_nightly/lib/python3.8/site-packages/torch/lib/../../../../libstdc++.so.6)
frame #36: <unknown function> + 0x8609 (0x7fd5a0383609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #37: clone + 0x43 (0x7fd5a02a8133 in /lib/x86_64-linux-gnu/libc.so.6)
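
For reference, the dumped graph above is TorchScript's autodiff backward for pow with respect to the base. A rough eager-mode equivalent, under my reading of the dump (%0 appears to be a zeros tensor, %1 = x + y, %2 = z, %3 = the incoming gradient), would be:

import torch

def pow_backward_base(grad_out, base, exp, base_shape):
    # %7..%12: where(exp == 0, 0, grad_out * exp * base**(exp - 1))
    grad = torch.where(exp == 0, torch.zeros_like(grad_out),
                       grad_out * exp * base ** (exp - 1))
    # aten::_grad_sum_to_size: reduce the broadcast dims back to base's shape
    return grad.sum_to_size(base_shape)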

cc @ngimel @jjsjann123 @zasdfgbnm

Versions

The latest PyTorch nightly:

PyTorch version: 1.13.0.dev20220902
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.13 (default, Mar 28 2022, 11:38:47)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-125-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.3.109
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 510.85.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] torch==1.13.0.dev20220902
[pip3] torchani==2.3.dev174+gbe932233.d20220903
[pip3] torchaudio==0.13.0.dev20220902
[pip3] torchvision==0.14.0.dev20220902
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.3.1               h2bc3f7f_2
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0            py38h7f8727e_0
[conda] mkl_fft                   1.3.1            py38hd3c417c_0
[conda] mkl_random                1.2.2            py38h51133e4_0
[conda] numpy                     1.23.1           py38h6c91a56_0
[conda] numpy-base                1.23.1           py38ha15fc14_0
[conda] pytorch                   1.13.0.dev20220902 py3.8_cuda11.3_cudnn8.3.2_0    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchani                  2.3.dev174+gbe932233.d20220903          pypi_0    pypi
[conda] torchaudio                0.13.0.dev20220902      py38_cu113    pytorch-nightly
[conda] torchvision               0.14.0.dev20220902      py38_cu113    pytorch-nightly
zou3519 added the module: assert failure, module: nvfuser, and triaged labels on Sep 6, 2022
yueyericardo (Contributor, Author) commented Apr 4, 2023

@zasdfgbnm @ngimel

The previous error was resolved; however, after changing the shape of x from [3, 1, 1, 1, 1] to [300, 1, 1, 1, 1], it silently produces a wrong gradient result.

import torch

torch._C._jit_set_nvfuser_single_node_mode(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
torch.manual_seed(0)


def func(x, y, z):
    return (x + y)**z


func_script = torch.jit.script(func)

x = torch.rand([300, 1, 1, 1, 1], device="cuda").requires_grad_()
y = torch.rand([1, 1, 1, 4], device="cuda")
z = torch.rand([1, 1, 1, 1], device="cuda")

for i in range(10):
    res = func(x, y, z)
    grad = torch.autograd.grad(res, x, torch.ones_like(res))[0]

    res_script = func_script(x, y, z)
    grad_script = torch.autograd.grad(res_script, x, torch.ones_like(res))[0]

    print(f"{i}: max_result_error {(res_script-res).abs().max()}, max_grad_error {(grad_script-grad).abs().max()}")

Output:

0: max_result_error 0.0, max_grad_error 0.0
1: max_result_error 0.0, max_grad_error 0.0
2: max_result_error 0.0, max_grad_error 0.2535855770111084
3: max_result_error 0.0, max_grad_error 0.2535855770111084
4: max_result_error 0.0, max_grad_error 0.2535855770111084
5: max_result_error 0.0, max_grad_error 0.2535855770111084
6: max_result_error 0.0, max_grad_error 0.2535855770111084
7: max_result_error 0.0, max_grad_error 0.2535855770111084
8: max_result_error 0.0, max_grad_error 0.2535855770111084
9: max_result_error 0.0, max_grad_error 0.2535855770111084

Version

I'm using the torch nightly build; the full version information is below.

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==2.1.0.dev20230404
[pip3] torchani==2.2.3.dev5+g40cf334
[pip3] torchaudio==2.1.0.dev20230404
[pip3] torchvision==0.16.0.dev20230404
[pip3] triton==2.1.0
[conda] blas                      1.0                         mkl
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0            py38h7f8727e_0
[conda] mkl_fft                   1.3.1            py38hd3c417c_0
[conda] mkl_random                1.2.2            py38h51133e4_0
[conda] numpy                     1.23.5           py38h14f4228_0
[conda] numpy-base                1.23.5           py38h31eccc5_0
[conda] pytorch                   2.1.0.dev20230404 py3.8_cuda11.7_cudnn8.5.0_0    pytorch-nightly
[conda] pytorch-cuda              11.7                 h778d358_3    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchani                  2.2.3.dev5+g40cf334           dev_0    <develop>
[conda] torchaudio                2.1.0.dev20230404      py38_cu117    pytorch-nightly
[conda] torchtriton               2.1.0+46672772b4            py38    pytorch-nightly
[conda] torchvision               0.16.0.dev20230404      py38_cu117    pytorch-nightly

Related issue:

aiqm/torchani#628
The minimal reproduction extracted from our code is the following; once the above is fixed, it could also serve as a test to verify correctness (a possible harness is sketched after the snippet).

import torch

def func(x, y, z, w):
    ret = (x + y) * z * w * 2
    return ret

device = "cuda"
x = torch.rand([360, 1, 1, 1, 1], device=device).requires_grad_()
y = torch.rand([1, 1, 1, 4], device=device)
z = torch.rand([1, 1, 1, 1], device=device)
w = torch.rand([1, 1, 8, 1], device=device)
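
Continuing from the snippet above, a sketch of the eager-vs-scripted check that could serve as that test, mirroring the harness earlier in this issue (torch.testing.assert_close raises on mismatch):

func_script = torch.jit.script(func)

for i in range(10):
    res = func(x, y, z, w)
    grad = torch.autograd.grad(res, x, torch.ones_like(res))[0]

    res_script = func_script(x, y, z, w)
    grad_script = torch.autograd.grad(res_script, x, torch.ones_like(res))[0]

    # fails loudly instead of silently returning a wrong gradient
    torch.testing.assert_close(res_script, res)
    torch.testing.assert_close(grad_script, grad)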

zasdfgbnm (Collaborator) commented:

Looks like this bug is already fixed in the latest nvfuser, likely by csarofeen#2517; you should see it in the next upstream push. cc @jjsjann123

This is what I get:

0: max_result_error 0.0, max_grad_error 0.0
1: max_result_error 0.0, max_grad_error 0.0
2: max_result_error 0.0, max_grad_error 1.4901161193847656e-08
3: max_result_error 0.0, max_grad_error 1.4901161193847656e-08
4: max_result_error 0.0, max_grad_error 1.4901161193847656e-08
5: max_result_error 0.0, max_grad_error 1.4901161193847656e-08
6: max_result_error 0.0, max_grad_error 1.4901161193847656e-08
7: max_result_error 0.0, max_grad_error 1.4901161193847656e-08
8: max_result_error 0.0, max_grad_error 1.4901161193847656e-08
9: max_result_error 0.0, max_grad_error 1.4901161193847656e-08

jjsjann123 (Collaborator) commented:

I'll try to start the upstream push next week, since our repo is in a cleaner state now. My hands are really tied at the moment.


sef43 commented Aug 21, 2023

@jjsjann123 when will this fix be in a PyTorch release? We have users reporting errors due to this issue with PyTorch 2.0. Currently our workaround is to disable NVFuser:
openmm/openmm-torch#115

jjsjann123 (Collaborator) commented:

I don't think we'll actually try to patch this in upstream PyTorch; nvfuser is being deprecated in TorchScript at this point, and upstream has switched to NNC as the default TorchScript fuser as well.

If you are stuck with PyTorch 2.0, manually disabling nvfuser in favor of NNC sounds like a good way to get unblocked.
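
For example, a minimal sketch (assuming the private _jit_set_nvfuser_enabled toggle available in PyTorch 2.0 builds; verify against your exact version):

import torch

# fall back to NNC by turning nvfuser off for TorchScript
torch._C._jit_set_nvfuser_enabled(False)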

If you can move to a PyTorch nightly and are feeling exploratory, you can try patching the nvfuser runtime with the nvfuser PyPI package:
https://github.com/NVIDIA/Fuser/wiki/Getting-started
NOTE: this is an irreversible process, so I suggest you try it in a container or a throw-away environment.

You can basically pip install a PyTorch nightly (choose your CUDA version), and then install nvfuser as:

pip install nvfuser-cu121
patch-nvfuser

This should hot-swap the nvfuser library that ships with upstream PyTorch. 🤞

sef43 commented Aug 23, 2023

Thank you for this update!
