
Unable to run ./yolo_inference on GPU #132

Closed
mattpopovich opened this issue Jul 12, 2021 · 8 comments

@mattpopovich
Contributor

mattpopovich commented Jul 12, 2021

Hi, thanks for putting this repo together. I am using it because I am trying to run inference with my yolov5 model in C++, with pre- and post-processing on the GPU, as I mentioned here.

I converted my model from yolov5 to yolov5-rt-stack and it seemed to work without issue, but I was having issues trying to run it. Before diving into that issue too deeply, I decided to try and run your sample code first to see if that worked.

I followed your README and I was able to run inference via CPU without issue. However, when I try to run using the --gpu flag, I get the following error:


root@pc:/home/user/git/yolov5-rt-stack/deployment/build# ./yolo_inference --input_source /path/to/dog.jpg --checkpoint ../../test/tracing/yolov5s.torchscript.pt --labelmap ../../notebooks/assets/coco.names --gpu 
>>> Set GPU mode
>>> Loading model
>>> Model loaded
>>> Run once on empty image
[W TensorImpl.h:1153] Warning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (function operator())
terminate called after throwing an instance of 'c10::NotImplementedError'
  what():  The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/yolort/models/yolo_module.py", line 31, in forward
    inputs: List[Tensor],
    targets: Optional[List[Dict[str, Tensor]]]=None) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
    _0 = (self)._forward_impl(inputs, targets, )
          ~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _0
  def _forward_impl(self: __torch__.yolort.models.yolo_module.YOLOModule,
  File "code/__torch__/yolort/models/yolo_module.py", line 51, in _forward_impl
    _4 = (self.transform).forward(inputs, targets, )
    samples, targets0, = _4
    outputs = (self.model).forward(samples.tensors, targets0, )
               ~~~~~~~~~~~~~~~~~~~ <--- HERE
    losses = annotate(Dict[str, Tensor], {})
    detections = annotate(List[Dict[str, Tensor]], [])
  File "code/__torch__/yolort/models/box_head.py", line 293, in forward
      _105 = annotate(List[Optional[Tensor]], [inds, labels])
      scores0 = torch.index(scores, _105)
      keep = _92(boxes0, scores0, labels, self.nms_thresh, )
             ~~~ <--- HERE
      keep0 = torch.slice(keep, 0, None, self.detections_per_img)
      _106 = annotate(List[Optional[Tensor]], [keep0])
  File "code/__torch__/torchvision/ops/boxes.py", line 16, in batched_nms
    _5 = torch.unsqueeze(torch.slice(offsets), 1)
    boxes_for_nms = torch.add(boxes, _5)
    keep = __torch__.torchvision.ops.boxes.nms(boxes_for_nms, scores, iou_threshold, )
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _0 = keep
  return _0
  File "code/__torch__/torchvision/ops/boxes.py", line 87, in nms
  _16 = __torch__.torchvision.extension._assert_has_ops
  _17 = _16()
  _18 = ops.torchvision.nms(boxes, scores, iou_threshold)
        ~~~~~~~~~~~~~~~~~~~ <--- HERE
  return _18

Traceback of TorchScript, original code (most recent call last):
  File "/home/user/git/yolov5-rt-stack/yolort/models/yolo_module.py", line 137, in forward
        ``training_step``). We keep ``targets`` here for Backward Compatible.
        """
        return self._forward_impl(inputs, targets)
               ~~~~~~~~~~~~~~~~~~ <--- HERE
  File "/home/user/git/yolov5-rt-stack/yolort/models/box_head.py", line 376, in forward
    
            # non-maximum suppression, independently done per level
            keep = batched_nms(boxes, scores, labels, self.nms_thresh)
                   ~~~~~~~~~~~ <--- HERE
            # keep only topk scoring head_outputs
            keep = keep[:self.detections_per_img]
  File "/usr/local/lib/python3.8/dist-packages/torchvision-0.8.0a0+2f40a48-py3.8-linux-x86_64.egg/torchvision/ops/boxes.py", line 42, in nms
    """
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
           ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: Could not run 'torchvision::nms' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'torchvision::nms' is only available for these backends: [CPU, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /resources/vision/torchvision/csrc/vision.cpp:59 [kernel]
BackendSelect: fallthrough registered at /resources/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /resources/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
ADInplaceOrView: fallthrough registered at /resources/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:60 [backend fallback]
AutogradOther: fallthrough registered at /resources/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at /resources/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at /resources/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
AutogradXLA: fallthrough registered at /resources/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:51 [backend fallback]
UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at /resources/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradMLC: fallthrough registered at /resources/pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:55 [backend fallback]
Tracer: fallthrough registered at /resources/pytorch/torch/csrc/jit/frontend/tracer.cpp:1036 [backend fallback]
Autocast: fallthrough registered at /resources/pytorch/aten/src/ATen/autocast_mode.cpp:255 [backend fallback]
Batched: registered at /resources/pytorch/aten/src/ATen/BatchingRegistrations.cpp:1019 [backend fallback]
VmapMode: fallthrough registered at /resources/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]


Exception raised from reportError at /resources/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:392 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7fd3681ff7ac in /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x9d73bf (0x7fd35aed33bf in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x1047ef6 (0x7fd35b543ef6 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x4206df7 (0x7fd35e702df7 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x4005af6 (0x7fd35e501af6 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::jit::InterpreterState::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x30 (0x7fd35e4f1480 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x3fec2ee (0x7fd35e4e82ee in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #7: torch::jit::GraphFunction::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&) + 0x3e (0x7fd35e242a7e in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::jit::Method::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&) + 0x168 (0x7fd35e253198 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x6afd7 (0x5599ef5a0fd7 in ./yolo_inference)
frame #10: <unknown function> + 0x5caf5 (0x5599ef592af5 in ./yolo_inference)
frame #11: __libc_start_main + 0xf3 (0x7fd3165060b3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #12: <unknown function> + 0x59a9e (0x5599ef58fa9e in ./yolo_inference)

Aborted (core dumped)

I think the main thing to note in that error log is the following:

RuntimeError: Could not run 'torchvision::nms' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'torchvision::nms' is only available for these backends: [CPU, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /resources/vision/torchvision/csrc/vision.cpp:59 [kernel]

My takeaway from that is that either I am building TorchVision for CPU only rather than CUDA, or torchvision::nms does not support CUDA?
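One way to test the first hypothesis is to read the dispatcher's backend list in the error message: a TorchVision build with CUDA support would register an `nms` kernel under the plain `CUDA` key, whereas `AutogradCUDA` in the list above is only an autograd fallthrough. The Python sketch below extracts that list from the message text and checks for a `CUDA` entry. The helper names `registered_backends` and `has_cuda_kernel` are hypothetical (not part of PyTorch or yolort), and the sketch assumes the message keeps the `only available for these backends: [...]` wording shown above.

```python
import re
from typing import List

# Hypothetical helpers for illustration; they only parse the error text.
BACKENDS_PATTERN = re.compile(r"only available for these backends: \[([^\]]*)\]")


def registered_backends(error_message: str) -> List[str]:
    """Return the dispatch keys listed in a 'Could not run ...' error, or []."""
    match = BACKENDS_PATTERN.search(error_message)
    if match is None:
        return []
    return [name.strip() for name in match.group(1).split(",") if name.strip()]


def has_cuda_kernel(backends: List[str]) -> bool:
    # 'AutogradCUDA' is an autograd fallthrough, not a kernel; a CUDA build of
    # the operator shows up under the plain 'CUDA' dispatch key.
    return "CUDA" in backends


if __name__ == "__main__":
    message = (
        "'torchvision::nms' is only available for these backends: "
        "[CPU, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, "
        "AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, Tracer, "
        "Autocast, Batched, VmapMode]."
    )
    backends = registered_backends(message)
    print(backends[0], has_cuda_kernel(backends))  # CPU False
```

A `False` result is consistent with the first hypothesis (a CPU-only TorchVision build), which is where this thread ends up.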

My environment:

root@pc:/home/user/git/yolov5-rt-stack# python3 -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.8.0a0+56b43f4
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080
GPU 2: GeForce GTX 1080

Nvidia driver version: 460.84
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.0
[pip3] pytorch-lightning==1.3.8
[pip3] torch==1.8.0a0+56b43f4
[pip3] torchmetrics==0.4.1
[pip3] torchvision==0.8.0a0+45f960c
[conda] Could not collect

I installed TorchVision via your instructions listed under number 2 here. I've tried checking out release/0.8.0, v0.8.1, and v0.8.2 all with the same issue. I've also tried v0.9.0 and v0.10.0 but your build instructions do not work for them so I ignored them for the time being.

Also worth noting: there are two dependencies whose versions I don't match:

  • Me
    • CUDA 11.2
    • Ubuntu 20.04
  • You
    • CUDA 10.2
    • Ubuntu 18.04

Similar issues that I've found:
pytorch/vision#3058
WongKinYiu/PyTorch_YOLOv4#169

Any thoughts or ideas? Does the --gpu flag work for you?

Thanks,

Matt

Edited by @zhiqwang, updating some links in deployment.

@mattpopovich mattpopovich added the question Further information is requested label Jul 12, 2021
@mattpopovich
Contributor Author

Update: I might have to add -DWITH_CUDA=ON to the TorchVision cmake command...
Investigating...

@zhiqwang
Owner

Hi @mattpopovich, the code in deployment is truly outdated; we plan to update it this week.

@mattpopovich
Contributor Author

That'd be great! Thank you. I'd be happy to test once they are updated.

@zhiqwang
Owner

zhiqwang commented Jul 12, 2021

Hi @mattpopovich ,

It seems that you need to add the -DWITH_CUDA=ON flag when building TorchVision if you want to use the CUDA version of NMS.

git clone https://github.com/pytorch/vision.git
cd vision
git checkout release/0.8.0
mkdir build && cd build
cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch -DWITH_CUDA=ON
make -j4
sudo make install

I guess that the above modification will temporarily solve your problem.
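Because `-D` options are recorded in the build directory's `CMakeCache.txt`, one way to confirm the flag took effect before running `make` is to look for the `WITH_CUDA` entry there. The sketch below assumes CMake's usual `NAME:TYPE=VALUE` cache-line format and that the option is named `WITH_CUDA`, as in the command above; the helper names are hypothetical, and only a subset of CMake's true constants is recognized.

```python
import re
from pathlib import Path

# Subset of the values CMake treats as true (it also accepts non-zero numbers).
CMAKE_TRUE = {"1", "ON", "YES", "TRUE", "Y"}
WITH_CUDA_ENTRY = re.compile(r"^WITH_CUDA:[A-Z]+=(.*)$")


def cuda_enabled_in_cache(cache_text: str) -> bool:
    """Return True if a CMakeCache.txt body sets WITH_CUDA to a true value."""
    for line in cache_text.splitlines():
        # Comment lines start with '//' or '#', so they never match the pattern.
        match = WITH_CUDA_ENTRY.match(line.strip())
        if match:
            return match.group(1).strip().upper() in CMAKE_TRUE
    return False  # entry missing: the option was never configured


def check_build_dir(build_dir: str) -> bool:
    """Hypothetical convenience wrapper around <build_dir>/CMakeCache.txt."""
    cache = Path(build_dir) / "CMakeCache.txt"
    if not cache.is_file():
        return False
    return cuda_enabled_in_cache(cache.read_text())
```

For example, `check_build_dir("vision/build")` run after the `cmake` step would be expected to return `True` when `-DWITH_CUDA=ON` was passed.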

FYI, TorchVision updated its C++ interface in pytorch/vision#3146, so the code in deployment (tree 82d6afb) only works with PyTorch 1.7.x / TorchVision 0.8.x. Using libtorch with TorchVision 0.8.x differs from using it with TorchVision 0.9.x, because TorchVision refactored these C++ interfaces half a year ago. Our plan is to update the libtorch interface to PyTorch 1.8.0+ and TorchVision 0.9.0+ in the next release (v0.4.1) for better maintainability.
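Since the C++ interface differs between TorchVision 0.8.x and 0.9.x, it can help to check which PyTorch/TorchVision pairing is installed before picking a branch. The sketch below encodes only the pairings mentioned in this thread (PyTorch 1.7.x with TorchVision 0.8.x, 1.8.x with 0.9.x, and 1.9.x with 0.10.x); the table and helper names are illustrative assumptions, not a yolort or TorchVision API. Version strings such as `1.8.0a0+56b43f4` are reduced to their major and minor numbers.

```python
import re
from typing import Optional, Tuple

# PyTorch (major, minor) -> TorchVision (major, minor), as named in this thread.
TORCHVISION_FOR_TORCH = {(1, 7): (0, 8), (1, 8): (0, 9), (1, 9): (0, 10)}


def major_minor(version: str) -> Tuple[int, int]:
    """Parse '1.8.0a0+56b43f4' -> (1, 8); pre-release and local tags are ignored."""
    match = re.match(r"^(\d+)\.(\d+)", version.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def expected_torchvision(torch_version: str) -> Optional[Tuple[int, int]]:
    """Return the TorchVision series paired with this PyTorch, or None if unlisted."""
    return TORCHVISION_FOR_TORCH.get(major_minor(torch_version))


def is_listed_pair(torch_version: str, torchvision_version: str) -> bool:
    return expected_torchvision(torch_version) == major_minor(torchvision_version)


if __name__ == "__main__":
    print(is_listed_pair("1.7.1", "0.8.2"))   # True
    print(is_listed_pair("1.9.0", "0.10.0"))  # True
    print(is_listed_pair("1.8.0", "0.8.2"))   # False
```

Pre-release builds (such as the `a0` versions in the environment above) come from a point between releases, so a `False` here is a prompt to double-check rather than proof of incompatibility.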

@zhiqwang
Owner

zhiqwang commented Jul 13, 2021

Hi @mattpopovich, we've updated the C++ interfaces in #136, and we've tested the new interfaces with PyTorch 1.8.0 / TorchVision 0.9.0 and with PyTorch 1.9.0 / TorchVision 0.10.0.

As for the problem above, I guess it's because you forgot to add the -DWITH_CUDA=ON flag when you built the C++ version of TorchVision.

git clone https://github.com/pytorch/vision.git
cd vision
git checkout release/0.9  # Assumes you're using PyTorch 1.8.0; replace this with `release/0.10` if you're using PyTorch 1.9.0
mkdir build && cd build
cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch -DWITH_CUDA=ON
make -j4
sudo make install
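If these steps are scripted, it is easy to drop the CUDA flag by accident. The following sketch assembles the same command sequence as above with `-DWITH_CUDA` set explicitly. `torchvision_build_commands` is a hypothetical helper, the branch name is passed in by the caller (for example `release/0.9` or `release/0.10`), `/opt/libtorch` is only an illustrative path, and the commands are returned as strings, not executed.

```python
import shlex
from typing import List


def torchvision_build_commands(
    torch_path: str, branch: str, with_cuda: bool = True, jobs: int = 4
) -> List[str]:
    """Return the TorchVision C++ build steps from this thread as shell strings."""
    cmake_args = [
        "cmake",
        "..",
        f"-DTorch_DIR={torch_path}/share/cmake/Torch",
        # The flag that decides whether CUDA kernels (e.g. for nms) get compiled.
        f"-DWITH_CUDA={'ON' if with_cuda else 'OFF'}",
    ]
    return [
        "git clone https://github.com/pytorch/vision.git",
        "cd vision",
        f"git checkout {shlex.quote(branch)}",
        "mkdir build && cd build",
        " ".join(shlex.quote(arg) for arg in cmake_args),
        f"make -j{jobs}",
        "sudo make install",
    ]


if __name__ == "__main__":
    for command in torchvision_build_commands("/opt/libtorch", "release/0.9"):
        print(command)
```

Pass a resolved install prefix for `torch_path` rather than the literal `$TORCH_PATH`: `shlex.quote` would wrap it in single quotes and stop the shell from expanding it.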

BTW, we don't impose strict requirements on the CUDA version or the Ubuntu release.

I believe this will solve your problem, so I'm closing this issue. Feel free to reopen it or file a new issue if you have further questions.

@zhiqwang zhiqwang self-assigned this Jul 13, 2021
@mattpopovich
Contributor Author

Thanks @zhiqwang. I think that solves my problem. I'm still having issues building but I've raised those concerns with pytorch/vision: pytorch/vision#4175

@Jelly123456

@zhiqwang @mattpopovich How can I convert a customized yolov5 model to yolov5rt?

I trained a customized yolov5 model on my own dataset, and now I need to convert it to yolov5rt because I need to use the torch.jit.script function to export the weights to TorchScript format.

Your help would be appreciated.

@zhiqwang
Owner

zhiqwang commented Jul 14, 2021

Hi @Jelly123456, we provide a notebook, https://github.com/zhiqwang/yolov5-rt-stack/blob/master/notebooks/how-to-align-with-ultralytics-yolov5.ipynb, that shows how to convert a customized yolov5 model trained with ultralytics to yolort; you can refer to it.
