
Segmentation fault during ONNX exportation. #49959

Closed

torshie opened this issue Dec 30, 2020 · 5 comments
Labels
module: onnx Related to torch.onnx triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@torshie

torshie commented Dec 30, 2020

🐛 Bug

I ran into a segmentation fault during ONNX export.

Call stack at the segmentation fault:

#0  0x00007f8cfd9d62be in torch::jit::(anonymous namespace)::ImplicitCastForONNX(torch::jit::Block*) () from /home/dengyao/Code/all-around-env/lib/python3.8/site-packages/torch/lib/libtorch_python.so
[Current thread is 1 (Thread 0x7f8cff63f740 (LWP 65224))]
#1  0x00007f8cfd99fb17 in pybind11::cpp_function::initialize<void (*&)(std::shared_ptr<torch::jit::Graph> const&), void, std::shared_ptr<torch::jit::Graph> const&, pybind11::name, pybind11::scope, pybind11::sibling>(void (*&)(std::shared_ptr<torch::jit::Graph> const&), void (*)(std::shared_ptr<torch::jit::Graph> const&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) ()
   from /home/dengyao/Code/all-around-env/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#2  0x00007f8cfd5f2346 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) ()
   from /home/dengyao/Code/all-around-env/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#3  0x00000000005f4249 in PyCFunction_Call ()
#4  0x00000000005f46d6 in _PyObject_MakeTpCall ()
#5  0x0000000000570936 in _PyEval_EvalFrameDefault ()
#6  0x000000000056955a in _PyEval_EvalCodeWithName ()
#7  0x00000000005f7323 in _PyFunction_Vectorcall ()
#8  0x000000000056c451 in _PyEval_EvalFrameDefault ()
#9  0x000000000056955a in _PyEval_EvalCodeWithName ()
#10 0x00000000005f7323 in _PyFunction_Vectorcall ()
#11 0x000000000056c451 in _PyEval_EvalFrameDefault ()
#12 0x000000000056955a in _PyEval_EvalCodeWithName ()
#13 0x00000000005f7323 in _PyFunction_Vectorcall ()
#14 0x000000000056c451 in _PyEval_EvalFrameDefault ()
#15 0x000000000056955a in _PyEval_EvalCodeWithName ()
#16 0x00000000005f7323 in _PyFunction_Vectorcall ()
#17 0x0000000000570286 in _PyEval_EvalFrameDefault ()
#18 0x000000000056955a in _PyEval_EvalCodeWithName ()
#19 0x00000000005f7323 in _PyFunction_Vectorcall ()
#20 0x000000000056c451 in _PyEval_EvalFrameDefault ()
#21 0x00000000005f7146 in _PyFunction_Vectorcall ()
#22 0x000000000056b26e in _PyEval_EvalFrameDefault ()
#23 0x000000000056955a in _PyEval_EvalCodeWithName ()
#24 0x000000000068c4a7 in PyEval_EvalCode ()
#25 0x000000000067bc91 in ?? ()
#26 0x000000000067bd0f in ?? ()
#27 0x000000000067bdcb in PyRun_FileExFlags ()
#28 0x000000000067de4e in PyRun_SimpleFileExFlags ()
#29 0x00000000006b6032 in Py_RunMain ()
#30 0x00000000006b63bd in Py_BytesMain ()
#31 0x00007f8cff8320b3 in __libc_start_main (main=0x4eea30 <main>, argc=8, 
    argv=0x7ffcae041dd8, init=<optimized out>, fini=<optimized out>, 
    rtld_fini=<optimized out>, stack_end=0x7ffcae041dc8)
    at ../csu/libc-start.c:308
#32 0x00000000005fa4de in _start ()

To Reproduce

I'm sorry, but I haven't found easier steps to reproduce this segfault. Steps to reproduce the behavior:

  1. Prepare the source code:
git clone https://github.com/ultralytics/yolov5.git
cd yolov5
git checkout 73cf75faa848cb
  2. Put test_script.py (listed under Additional context below) into the models directory
  3. Download yolov5l.pt
  4. Run the script: python3 test_script.py --weights /path/to/yolov5l.pt --test-image /path/to/a/640x640/color/image.jpg --output /tmp/foo.onnx

Expected behavior

No segmentation fault.

Environment

Collecting environment information...
PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 1060 6GB
Nvidia driver version: 450.66
cuDNN version: /usr/local/cuda-10.0/lib64/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.4
[pip3] torch==1.7.1
[pip3] torchvision==0.8.2
[conda] Could not collect

Additional context

test_script.py:

#!/usr/bin/env python3

import argparse
import sys
from typing import List

sys.path.append('.')

import torch
import cv2
import torchvision

import models
import models.yolo
from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU
from utils.general import check_img_size
from utils.datasets import letterbox


@torch.jit.script
def box_area(box):
    # box = 4xn
    return (box[2] - box[0]) * (box[3] - box[1])


@torch.jit.script
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)


@torch.jit.script
def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


@torch.jit.script
def loop_body(xi: int, x: torch.Tensor, multi_label: bool, xc: torch.Tensor,
        output: List[torch.Tensor], labels: torch.Tensor, nc: int,
        conf_thres: float, classes: torch.Tensor, agnostic: bool,
        iou_thres: float):
    max_wh = 4096
    max_det = 300
    # Apply constraints
    # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
    x = x[xc[xi]]  # confidence

    # Cat apriori labels if autolabelling
    if len(labels.size()) and labels and len(labels[xi]):
        l = labels[xi]
        v = torch.zeros((len(l), nc + 5), device=x.device)
        v[:, :4] = l[:, 1:5]  # box
        v[:, 4] = 1.0  # conf
        v[torch.arange(len(l)), l[:, 0].long() + 5] = 1.0  # cls
        x = torch.cat((x, v), 0)

    # If none remain process next image
    if not x.shape[0]:
        return

    # Compute conf
    x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

    # Box (center x, center y, width, height) to (x1, y1, x2, y2)
    box = xywh2xyxy(x[:, :4])

    # Detections matrix nx6 (xyxy, conf, cls)
    if multi_label:
        tmp = (x[:, 5:] > conf_thres).nonzero().T
        i, j = tmp[0], tmp[1]
        x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
    else:  # best class only
        conf, j = x[:, 5:].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

    # Filter by class
    if len(classes.size()) and classes:
        x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

    # Apply finite constraint
    # if not torch.isfinite(x).all():
    #     x = x[torch.isfinite(x).all(1)]

    # If none remain process next image
    if not x.shape[0]:
        return

    # Sort by confidence
    # x = x[x[:, 4].argsort(descending=True)]

    # Batched NMS
    c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
    boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
    i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
    if i.shape[0] > max_det:  # limit detections
        i = i[:max_det]

    output[xi] = x[i]


@torch.jit.script
def non_max_suppression(prediction,
        conf_thres: float = 0.25, iou_thres: float = 0.45,
        classes: torch.Tensor = torch.tensor(0),
        agnostic: bool = False, labels: torch.Tensor = torch.tensor(0)):
    """Performs Non-Maximum Suppression (NMS) on inference results

    Returns:
         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)

    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        loop_body(xi=xi, x=x, multi_label=multi_label, xc=xc, output=output,
            labels=labels, classes=classes, agnostic=agnostic,
            iou_thres=iou_thres, conf_thres=conf_thres, nc=nc)

    return output


class CombinedModel(torch.nn.Module):
    def __init__(self, yolov5):
        super().__init__()
        self.yolov5 = yolov5

    def forward(self, x):
        pred = self.yolov5(x)
        return non_max_suppression(pred[0])


def parse_cmdline():
    p = argparse.ArgumentParser()
    p.add_argument('--weights', required=True)
    p.add_argument('--img-size', nargs='+', type=int, default=[640, 640])
    p.add_argument('--test-image', required=True)
    p.add_argument('--output', required=True)
    return p.parse_args()


def load_model(path, size):
    model = attempt_load(path, map_location=torch.device('cpu'))
    gs = int(max(model.stride))  # grid size (max stride)
    size = [check_img_size(x, gs) for x in size]  # verify img_size are gs-multiples
    # Update model
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, models.common.Conv):  # assign export-friendly activations
            if isinstance(m.act, torch.nn.Hardswish):
                m.act = Hardswish()
            elif isinstance(m.act, torch.nn.SiLU):
                m.act = SiLU()
        # elif isinstance(m, models.yolo.Detect):
        #     m.forward = m.forward_export  # assign forward (optional)
    model.model[-1].export = False
    return model, size


def load_test_image(filename, size):
    image = cv2.imread(filename)
    if list(image.shape[:2]) != list(size):
        print("Invalid image shape", image.shape, file=sys.stderr)
        sys.exit(1)
    image = letterbox(image, new_shape=640)[0]
    image = image[:, :, ::-1].transpose(2, 0, 1)
    image = torch.from_numpy(image.copy()).float() / 255.0

    return image.unsqueeze(0)


def main():
    cmdline = parse_cmdline()

    yolov5, img_size = load_model(cmdline.weights, cmdline.img_size)
    batch = load_test_image(cmdline.test_image, img_size)

    # This call is a must; it changes the model's internal structure.
    y = yolov5(batch)
    print(y[0].shape)

    combined = CombinedModel(yolov5)

    traced = torch.jit.trace_module(combined, {'forward': batch})
    r = combined(batch)
    print(r[0].shape)

    torch.onnx.export(combined, batch, cmdline.output, opset_version=11)


if __name__ == '__main__':
    print(non_max_suppression.code)
    print(loop_body.code)
    main()

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof

@mruberry mruberry added high priority module: onnx Related to torch.onnx labels Dec 30, 2020
@mruberry mruberry added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jan 4, 2021
@malfet malfet self-assigned this Jan 7, 2021
@malfet
Contributor

malfet commented Jan 7, 2021

Grabbing for myself, as I have a symbolicated backtrace:

frame #0: 0x00007fffbec6bb68 libtorch_python.so`c10::TensorType::scalarType(this=0x0000000000000000) const at jit_type.h:489:12
   486 	    return device_;
   487 	  }
   488 	  c10::optional<at::ScalarType> scalarType() const {
-> 489 	    return scalar_type_;
    	           ^
   490 	  }
   491 	  c10::optional<bool> requiresGrad() const {
   492 	    return requires_grad_;
(lldb) up
frame #1: 0x00007fffbec696ad libtorch_python.so`operator(__closure=0x00007fffffffbad0, input=0x0000555579d280a0) at scalar_type_analysis.cpp:183:63
   180 	          }
   181 	        } else if (
   182 	            auto scalar_type =
-> 183 	                input->type()->cast<TensorType>()->scalarType()) {
    	                                                              ^
   184 	          typesFromTensors.emplace_back(*scalar_type);
   185 	        }
   186 	      });

malfet added a commit to malfet/pytorch that referenced this issue Jan 7, 2021
Apply a little bit of defensive programming: `type->cast<TensorType>()` returns an optional pointer so dereferencing it can lead to a hard crash.

Fixes SIGSEGV reported in pytorch#49959
facebook-github-bot pushed a commit that referenced this issue Jan 8, 2021
Summary:
Apply a little bit of defensive programming: `type->cast<TensorType>()` returns an optional pointer so dereferencing it can lead to a hard crash.

Fixes SIGSEGV reported in #49959

Pull Request resolved: #50237

Reviewed By: walterddr

Differential Revision: D25839675

Pulled By: malfet

fbshipit-source-id: 403d6df5e2392dd6adc308b1de48057f2f9d77ab
@malfet malfet removed their assignment Jan 11, 2021
@malfet
Contributor

malfet commented Jan 11, 2021

Fixed the crash, but the underlying transformation of tensor lists (i.e. Tensor[] types) still needs to be properly addressed by the ONNX exporter.
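
(For illustration: a minimal sketch of how such a Tensor[] value arises in a scripted graph. The function name and shapes are illustrative, not taken from the issue.)

import torch
from typing import List

@torch.jit.script
def repeat_list(x: torch.Tensor, n: int) -> List[torch.Tensor]:
    # List repetition scripts to a single value of list type (Tensor[]),
    # not to individual tensor values with concrete scalar types.
    return [x] * n

# The printed graph shows an output typed as Tensor[] -- the kind of value
# the exporter's scalar-type-analysis pass tripped over before the fix.
print(repeat_list.graph)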

hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this issue Jan 14, 2021
@malfet
Contributor

malfet commented Feb 9, 2021

@SplitInfinity, @BowenBao do you know if the underlying issue has been fixed? And should we remove high priority, as it is no longer causing a crash?

@thiagocrepaldi thiagocrepaldi self-assigned this Aug 18, 2022
@thiagocrepaldi
Collaborator

thiagocrepaldi commented Aug 18, 2022

Looking into it. This is a trickier case of #81386

@thiagocrepaldi thiagocrepaldi removed their assignment Aug 25, 2022
@thiagocrepaldi
Collaborator

thiagocrepaldi commented Aug 25, 2022

The generic Tensor[]? type cannot be handled by the current TorchScript implementation, but a workaround for the issue here is to replace output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] with a conventional for loop such as

output = []
for i in range(prediction.shape[0]):
    output.append(torch.zeros((0, 6), device=prediction.device))

In ONNX, every operator must have a type, while a ListConstruct (from JIT) has no such constraint. We are looking into ways of providing more context to the ONNX exporter so it can handle such scenarios without additional "passes" that replace one pattern with another.
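
(A self-contained sketch of the contrast; the function names are illustrative and the (0, 6) shape simply mirrors the snippet above.)

import torch
from typing import List

@torch.jit.script
def make_output_repeat(prediction: torch.Tensor) -> List[torch.Tensor]:
    # List multiplication scripts to a single list-typed value; per the
    # explanation above, the exporter could not assign it a concrete
    # tensor type, which led to the crash.
    return [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]

@torch.jit.script
def make_output_loop(prediction: torch.Tensor) -> List[torch.Tensor]:
    # The conventional loop builds the list element by element, giving the
    # exporter concretely typed tensors to work with.
    output: List[torch.Tensor] = []
    for i in range(prediction.shape[0]):
        output.append(torch.zeros((0, 6), device=prediction.device))
    return output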

As the crash doesn't happen anymore, I have removed the high-pri label and will close this issue.
The generic ListConstruct issue is being tracked at https://github.com/microsoft/onnx-converters-private/issues/46
