
JIT-related failures using Core nightly 20220426 #5881

@datumbox

🐛 Describe the bug

The latest main branch fails the cmake_* jobs with the following error:

+ ./test_frcnn_tracing
Loading model
Model loaded
terminate called after throwing an instance of 'std::runtime_error'
  what():  The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torchvision/models/detection/faster_rcnn.py", line 93, in forward
    roi_heads = self.roi_heads
    image_sizes = images0.image_sizes
    _26 = (roi_heads).forward(features0, proposals, image_sizes, targets4, )
           ~~~~~~~~~~~~~~~~~~ <--- HERE
    detections, detector_losses, = _26
    transform0 = self.transform
  File "code/__torch__/torchvision/models/detection/roi_heads.py", line 341, in forward
      labels8 = torch.index(labels7, _104)
      nms_thresh = self.nms_thresh
      keep0 = _90(boxes4, scores3, labels8, nms_thresh, )
              ~~~ <--- HERE
      detections_per_img = self.detections_per_img
      keep1 = torch.slice(keep0, 0, None, detections_per_img)
  File "code/__torch__/torchvision/ops/boxes.py", line 49, in batched_nms
    _14 = False
  if _14:
    _17 = _9(boxes, scores, idxs, iou_threshold, )
          ~~ <--- HERE
    _16 = _17
  else:
  File "code/__torch__/torchvision/ops/boxes.py", line 74, in _batched_nms_vanilla
    iou_threshold: float) -> Tensor:
  _26 = __torch__.torch.functional._return_output
  keep_mask = torch.zeros_like(scores, dtype=11)
              ~~~~~~~~~~~~~~~~ <--- HERE
  _27 = _26(idxs, True, False, False, None, )
  for _28 in range(torch.len(_27)):

Traceback of TorchScript, original code (most recent call last):
  File "/root/project/torchvision/models/detection/generalized_rcnn.py", line 105, in forward
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
                                      ~~~~~~~~~~~~~~ <--- HERE
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
    
  File "/root/project/torchvision/models/detection/roi_heads.py", line 716, in forward
    
            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
                   ~~~~~~~~~~~~~~~~~~~ <--- HERE
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
  File "/root/project/torchvision/ops/boxes.py", line 72, in batched_nms
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
               ~~~~~~~~~~~~~~~~~~~~ <--- HERE
    else:
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
  File "/root/project/torchvision/ops/boxes.py", line 105, in _batched_nms_vanilla
) -> Tensor:
    # Based on Detectron2 implementation, just manually call nms() on each class independently
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
                ~~~~~~~~~~~~~~~~ <--- HERE
    for class_id in torch.unique(idxs):
        curr_indices = torch.where(idxs == class_id)[0]
RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1650956998902/work/torch/csrc/autograd/functions/utils.h":65, please report a bug to PyTorch. 

packaging/build_cmake.sh: line 106:  2593 Aborted                 (core dumped) ./test_frcnn_tracing

Exited with code exit status 134
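
The failing frame is the keep_mask = torch.zeros_like(scores, dtype=torch.bool) allocation in _batched_nms_vanilla. Below is a minimal sketch of the pattern the interpreter ends up executing; it is a hypothetical repro written for illustration (the requires_grad setup is an assumption about how the traced model reaches this code path), not a script from our CI:

import torch

@torch.jit.script
def make_keep_mask(scores: torch.Tensor) -> torch.Tensor:
    # Same pattern as torchvision.ops.boxes._batched_nms_vanilla:
    # allocate a boolean mask with the shape of the per-box scores.
    return torch.zeros_like(scores, dtype=torch.bool)

# Scores typically participate in autograd when the traced Faster R-CNN model runs,
# which is the situation in which the isDifferentiableType assert is reported above.
scores = torch.rand(100, requires_grad=True)
print(make_keep_mask(scores).dtype)  # torch.bool on a healthy build

On a working nightly this simply prints torch.bool; on 20220426 the same zeros_like-with-bool-dtype pattern trips the INTERNAL ASSERT shown above.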

More JIT-related failures can be seen in the unittest_linux_gpu_py3.8 job:

test_jit[4-dtype0-cuda-mean-2-1.0]
Traceback (most recent call last):
  File "/home/circleci/project/test/test_ops.py", line 1566, in test_jit
    scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: required keyword attribute 'cache_id' is undefined

It looks like this is due to an upstream change in PyTorch Core.
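
For reference, here is a minimal sketch of what that test exercises, assuming the public torchvision.ops.sigmoid_focal_loss API; the shapes and the gamma/alpha values are illustrative rather than the exact failing parametrization:

import torch
from torchvision.ops import sigmoid_focal_loss

# Script the op the same way test_ops.py::test_jit does, then call it with keyword args.
script_fn = torch.jit.script(sigmoid_focal_loss)

inputs = torch.randn(16, 4, requires_grad=True)
targets = torch.randint(0, 2, (16, 4), dtype=torch.float32)

loss = script_fn(inputs, targets, gamma=2.0, alpha=0.25, reduction="mean")
# On the 20220426 nightly the call above fails with:
# RuntimeError: required keyword attribute 'cache_id' is undefined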

Versions

The breakage in the cmake_* jobs first appeared with the 20220426 nightly, see 66ed693.

Though other breakages existed with the 20220425 nightly, the cmake jobs were passing normally (see de31e4b). All of those breakages are due to upstream changes and are documented in #5873.

We have merged #5875 to resolve the issues related to the 20220425 nightly. For the new issues caused by the 20220426 nightly, please check the CI results of commit cc53cd0, which provide a clearer view of the problem.

cc @seemethere
