caffe2: RuntimeError: [enforce fail at reshape_op.h:110] with Alexnet onnx test with cuda #13598

Open
rjknight opened this issue Nov 5, 2018 · 8 comments

@rjknight
Contributor

rjknight commented Nov 5, 2018

🐛 Bug

We are seeing an intermittent failure in the Reshape operator when trying to run the "Example: End-to-end AlexNet from PyTorch to Caffe2".

Traceback (most recent call last):
  File "/tmp/test.py", line 35, in <module>
    outputs = rep.run(np.random.randn(10, 3, 227, 227).astype(np.float32))
  File "/opt/DL/pytorch/lib/python2.7/site-packages/caffe2/python/onnx/backend_rep.py", line 57, in run
    self.workspace.RunNet(self.predict_net.name)
  File "/opt/DL/pytorch/lib/python2.7/site-packages/caffe2/python/onnx/workspace.py", line 63, in f
    return getattr(workspace, attr)(*args, **kwargs)
  File "/opt/DL/pytorch/lib/python2.7/site-packages/caffe2/python/workspace.py", line 219, in RunNet
    StringifyNetName(name), num_iter, allow_fail,
  File "/opt/DL/pytorch/lib/python2.7/site-packages/caffe2/python/workspace.py", line 180, in CallWithExceptionIntercept
    return func(*args, **kwargs)
RuntimeError: [enforce fail at reshape_op.h:110] total_size == size. 92160 vs -5680358544067434496. Argument `shape` does not agree with the input data. (92160 != -5680358544067434496)Error from operator:
input: "29" input: "36" output: "37" output: "OC2_DUMMY_1" name: "" type: "Reshape" device_option { device_type: 1 device_id: 0 }

To Reproduce

Steps to reproduce the behavior:
Using the example code and the exported alexnet.onnx file, run the following sample:

import onnx
import caffe2.python.onnx.backend as backend
import numpy as np
model = onnx.load("alexnet.onnx")

rep = backend.prepare(model, device="CUDA") # or "CPU"
outputs = rep.run(np.random.randn(10, 3, 224, 224).astype(np.float32))
print(outputs[0])

Expected behavior

The example code runs successfully every time.

Environment

  • PyTorch Version (e.g., 1.0): 1.0 rc1
  • OS (e.g., Linux): RHEL 7.5 ppc64le
  • How you installed PyTorch (conda, pip, source): Build from source
  • Build command you used (if compiling from source): DEBUG=1 USE_OPENCV=1 python setup.py
  • Python version: 2.7 or 3.6
  • CUDA/cuDNN version: 10 / 7.31
  • GPU models and configuration: V100
  • Any other relevant information:

The problem only occurs with device="CUDA"; with device="CPU" the issue is not seen. It was seen on ppc64le; it is unknown whether other platforms are affected.

Additional context

Checking that very large incorrect value, we see that the least significant
32 bits of the total size are correct (0x16800, or 92160), but the
high-order 32 bits are incorrect (they should be 0x0, but contain data):

$ echo "obase=16; 2^64 -5680358544067434496" | bc
B12B500000016800

$ echo "ibase=16; 16800" | bc
92160
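The same check can be done in Python, which avoids the manual 2^64 adjustment needed with bc. This is just a convenience sketch; the variable names are illustrative:

```python
# Split the bad size from the reshape_op.h:110 error into its two 32-bit halves.
bad_size = -5680358544067434496                 # value reported in the error message
as_u64 = bad_size & 0xFFFFFFFFFFFFFFFF          # same bits, read as unsigned 64-bit
high, low = as_u64 >> 32, as_u64 & 0xFFFFFFFF

print(hex(as_u64))   # 0xb12b500000016800
print(low)           # 92160 == 10 * 9216, the correct total size
print(hex(high))     # 0xb12b5000, should have been 0x0
```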

If we build Caffe2 so that it crashes at this point, the cause of the problem seems
to be that the shape information copied back from the GPU is damaged.
For example, in one run we see:

(gdb) print actual_new_shape                                                                                                                                        
$6 = {<std::_Vector_base<long, std::allocator<long> >> = {                                                                                                          
    _M_impl = {<std::allocator<long>> = {<__gnu_cxx::new_allocator<long>> = {<No data fields>}, <No data fields>}, _M_start = 0x7ffdf836e930,                       
      _M_finish = 0x7ffdf836e940, _M_end_of_storage = 0x7ffdf836e940}}, <No data fields>}

(gdb) x/4x 0x7ffdf836e930                                                                                                                                           
0x7ffdf836e930: 0x0000000a      0x00000005      0x00002400      0x00000000

Here, two values (0x0a / 10 and 0x2400 / 9216) should have been copied in
from the GPU, but instead the most-significant 32 bits of the "10" value
have been overlaid with 0x05, resulting in an apparent value of
0x0000 0005 0000 000a (21474836490).
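As a cross-check of this reading, the four words from the x/4x dump can be reinterpreted as two little-endian int64 values (ppc64le, like x86, is little-endian). A small sketch using only the standard struct module:

```python
import struct

# The four 32-bit words dumped by `x/4x` at actual_new_shape in the run above.
words = [0x0000000a, 0x00000005, 0x00002400, 0x00000000]

# Repack them little-endian and read them back as two int64 shape entries.
raw = struct.pack("<4I", *words)
shape = struct.unpack("<2q", raw)
print(shape)            # (21474836490, 9216)
print(hex(shape[0]))    # 0x50000000a
```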

Since the problem is intermittent, it's not clear whether the overwrite
always happens (and the stray value is usually just 0x0) or only happens
sometimes.

@hartb
Contributor

hartb commented Nov 6, 2018

For reference, a full test script to recreate:

from torch.autograd import Variable
import torch.onnx
import torchvision


dummy_input = Variable(torch.randn(10, 3, 227, 227, dtype=torch.float32, device='cuda'))

model = torchvision.models.alexnet(pretrained=True).cuda()

input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

import caffe2.python.onnx.backend as backend
import numpy as np

rep = backend.prepare(model, device="CUDA") # or "CPU"

# For the Caffe2 backend:
#     rep.predict_net is the Caffe2 protobuf for the network
#     rep.workspace is the Caffe2 workspace for the network
#       (see the class caffe2.python.onnx.backend.Workspace)
outputs = rep.run(np.random.randn(10, 3, 227, 227).astype(np.float32))

# To run networks with more than one input, pass a tuple
# rather than a single numpy ndarray.
# Alternatively, use backend.run_model(); this path seems to work OK:
#img = np.random.randn(10, 3, 227, 227).astype(np.float32)
#outputs = backend.run_model(model, [img])

print(outputs[0])

@hartb
Contributor

hartb commented Nov 6, 2018

The failure seems to occur less frequently when the script is run as a whole than when only the Caffe2 ONNX import-and-run part is run against a previously exported "alexnet.onnx" file (though the difference in frequency may be due to other factors).

The corruption seems to consistently hit the most-significant half of the first shape dimension, turning what should be a "10" into a much larger value. Here's another run where 0x00000a00 is the stray value:

(gdb) print actual_new_shape
$1 = {...
    ..._M_start = 0x7ffdd82f5ef0,
      _M_finish = 0x7ffdd82f5f00,...

(gdb) x/4x 0x7ffdd82f5ef0
0x7ffdd82f5ef0: 0x0000000a      0x00000a00      0x00002400      0x00000000

Since this is little-endian, the stray data appears at offset 0x4, between two apparently intact data words at 0x0 and 0x8. Maybe that suggests this isn't a simple overwrite?

The problem doesn't appear to be sensitive to the input data. If we seed the numpy RNG just before:

np.random.seed(123)
outputs = rep.run(np.random.randn(10, 3, 227, 227).astype(np.float32))

The problem is still intermittent: sometimes not occurring, and with different stray values when it does.

In some testing we've seen a complaint from _Workspace_feed_blob() in caffe2/python/workspace.py:

    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        if arr.dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr.dtype))
            )

Here the test script specifies np.float32 as the input type, but the shape data is (inherently?) 64-bit and may be passed between host and GPU. Not sure whether this is relevant or a red herring.

@hartb
Contributor

hartb commented Nov 6, 2018

The stray data is consistent in form but not value:

(gdb) print actual_new_shape
...
(gdb) x/4x 0x............
0x............: 0x0000000a      0x????????      0x00002400      0x00000000

The stray values always seem to be various int values. There isn't any float-looking stuff showing up.

I just made 20 runs. Two finished successfully (either without an overwrite, or where the stray value happened to be 0x0); the other 18 failed with various stray values (some appearing multiple times across runs):

00000007  (    7):  2x
00000011  (   17):  1x
00000020  (   32):  1x
00000026  (   38):  2x
00000045  (   69):  1x
000000e6  (  230):  3x
000000ff  (  255):  2x
00000200  (  512):  2x
000005b0  (23296):  1x
00000a00  ( 2560):  1x
ffffffff  (   -1):  2x

@zou3519 zou3519 added the caffe2 label Nov 12, 2018
@rjknight123

It looks like the issue is in the Gather operation.
From the net trace below we can work out the inputs and outputs of each operation. In our case, the items of interest are the inputs and output of the Reshape operation,
namely %29, %36, and the expected output %37:

 %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
 %30 : Long() = onnx::Constant[value={0}]()
 %31 : Dynamic = onnx::Shape(%29), scope: AlexNet
 %32 : Long() = onnx::Gather[axis=0](%31, %30), scope: AlexNet
 %33 : Long() = onnx::Constant[value={9216}]()
 %34 : Dynamic = onnx::Unsqueeze[axes=[0]](%32), scope: AlexNet
 %35 : Dynamic = onnx::Unsqueeze[axes=[0]](%33), scope: AlexNet
 %36 : int[] = onnx::Concat[axis=0](%34, %35), scope: AlexNet
 %37 : Float(10, 9216) = onnx::Reshape(%29, %36), scope: AlexNet

%29 is the output of the pooling operation. The tensor contains all of the data,
and consistently seemed to be correctly formed.

%36 is produced by a chain of Shape, Gather, Unsqueeze and Concat operations, and it contained the data we observed to be corrupted.

Working backwards through the net trace, things seem to go wrong at the output of the Gather operation.

The output of Shape (%31) is a tensor with dtype int64 and values [10, 256, 6, 6]. It is passed
to Gather together with the index tensor (%30), which also has dtype int64 and value 0; the axis itself is given by the axis=0 attribute. The expected output of Gather is a tensor with the same dtype as the input, containing a single value equal to the first element of %31 (10).
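For reference, the shape-computing subgraph can be mirrored step by step in NumPy to show what %36 should contain on a correct run. This is only an illustration: the variable names follow the trace, x29 is a zero-filled stand-in for the real MaxPool output, and np.take / np.expand_dims / np.concatenate are used as rough equivalents of the ONNX ops:

```python
import numpy as np

# Stand-in for %29: the MaxPool output, Float(10, 256, 6, 6).
x29 = np.zeros((10, 256, 6, 6), dtype=np.float32)

x31 = np.array(x29.shape, dtype=np.int64)   # onnx::Shape     -> [10, 256, 6, 6]
x30 = np.int64(0)                            # onnx::Constant  -> 0
x32 = np.take(x31, x30, axis=0)              # onnx::Gather    -> 10 (scalar)
x33 = np.int64(9216)                         # onnx::Constant  -> 9216 (= 256 * 6 * 6)
x34 = np.expand_dims(x32, 0)                 # onnx::Unsqueeze -> [10]
x35 = np.expand_dims(x33, 0)                 # onnx::Unsqueeze -> [9216]
x36 = np.concatenate([x34, x35], axis=0)     # onnx::Concat    -> [10, 9216], int64
x37 = x29.reshape(tuple(x36))                # onnx::Reshape   -> shape (10, 9216)

print(x36.dtype, x36.tolist(), x37.shape)
```

Every intermediate here is int64, which is why a Gather that only handles 4-byte elements correctly would corrupt %32 and, through it, %36.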

The CUDA implementation of Gather uses a kernel to copy the selected slices of data into a single output tensor:

template <typename T_INDEX>
__global__ void GatherKernel(
    const float* X,
    float* Y,
    const T_INDEX* indices,
    const int N,
    const int block_size) {
  for (int i = blockIdx.x; i < N; i += gridDim.x) {
    T_INDEX idx = indices[i];
    const float* src_offset = X + idx * block_size;
    float* dst_offset = Y + i * block_size;
    for (int j = threadIdx.x; j < block_size; j += blockDim.x) {
      dst_offset[j] = src_offset[j];
    }   
  }
}

When the kernel is launched in the example, the arguments are:

GatherKernel<<< 1, 128, 0, context_.cuda_stream()>>>(src_base, out, idxs, N, block_size);

src_base   = float *
out        = float *
idxs       = int64_t *
N          = 1
block_size = 1 

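Since dst_offset and src_offset are both float pointers derived from the
data members of the output and input tensors, each copied element is only
4 bytes wide. In this case (int64 data, N = 1, block_size = 1) only half of
the 8-byte value gets copied, and the other half of the output keeps
whatever was already in that memory.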
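To illustrate the effect, here is a hypothetical NumPy emulation (not Caffe2 code) of a kernel that copies through 4-byte float* pointers while the tensors actually hold 8-byte int64 values. It assumes a little-endian host, uses uint32 views to stand in for the float* copies (so no float bit patterns are altered), and models the uninitialized output buffer with an assumed stale_word parameter:

```python
import numpy as np

def buggy_gather(data, indices, block_size=1, stale_word=0x5):
    """Hypothetical emulation of a gather that copies in 4-byte units.

    Assumes a little-endian host and a 1-D `data` whose itemsize is a
    multiple of 4 (e.g. int32, int64). `stale_word` stands in for whatever
    was already in the freshly allocated output buffer.
    """
    words_per_item = data.itemsize // 4
    src = data.view(np.uint32)  # same bytes, seen as 4-byte words
    # Output sized for the real dtype, pre-filled with "stale" words.
    dst = np.full(len(indices) * block_size * words_per_item, stale_word, dtype=np.uint32)
    for i, idx in enumerate(indices):
        # Offsets are counted in 4-byte units, like the float* arithmetic in GatherKernel.
        dst[i * block_size:(i + 1) * block_size] = src[idx * block_size:(idx + 1) * block_size]
    return dst.view(data.dtype)

shape = np.array([10, 256, 6, 6], dtype=np.int64)
print(buggy_gather(shape, [0]))                   # [21474836490], i.e. 0x00000005_0000000a
print(buggy_gather(shape.astype(np.int32), [0]))  # [10], correct for 4-byte types
```

With stale_word=0 the emulated result happens to be correct, which would be consistent with the occasional passing runs; with stale_word=0xA00 it gives 10995116277770, the same value reported by the standalone Gather test later in this thread.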

I think the fix will be to template the types of dst_offset and src_offset instead of the index tensor type; I'm working on that now.

@rjknight
Contributor Author

I've made a change to derive the data type used in GatherKernel from the input's element size, and then began to wonder which types make sense. In the current version of the code I don't see a restriction on the input data type; should any valid numeric type be accepted?
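One possible way to make the copy independent of the element type is to move raw bytes (block_size * itemsize per index) instead of typed elements. The NumPy sketch below only illustrates that idea; gather_bytes is a hypothetical name, and this is not necessarily what the actual patch does:

```python
import numpy as np

def gather_bytes(data, indices, block_size=1):
    """Hypothetical type-agnostic gather along axis 0 of a flattened array.

    Each index selects block_size elements, copied as
    block_size * data.itemsize raw bytes, so the element type never matters.
    """
    block_bytes = block_size * data.itemsize
    src = np.ascontiguousarray(data).reshape(-1).view(np.uint8)
    dst = np.empty(len(indices) * block_bytes, dtype=np.uint8)
    for i, idx in enumerate(indices):
        dst[i * block_bytes:(i + 1) * block_bytes] = src[idx * block_bytes:(idx + 1) * block_bytes]
    return dst.view(data.dtype)

print(gather_bytes(np.array([10, 256, 6, 6], dtype=np.int64), [0]))        # [10]
print(gather_bytes(np.array([10, 12, 6, 6], dtype=np.int8), [3, 1, 2, 0]))  # [ 6 12  6 10]
```

Because only itemsize enters the offset arithmetic, the same loop handles int8, int16, int32, int64, float32 and float64 without one template instantiation per type.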

rjknight added a commit to rjknight/pytorch that referenced this issue Nov 29, 2018
* Intermittent data corruption was seen in the Reshape op
  when running the end-to-end PyTorch to Caffe2 example;
  the issue was traced back to the Gather op transferring
  only float data types. See pytorch/pytorch#13598 for
  additional details
@rjknight
Contributor Author

The Gather code has been refactored recently, but we are still seeing the failure. One thing I noticed with the new code is that the issue only occurs when the input data type is int64.

I used the following test to validate this:

import numpy as np

from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2

workspace.ResetWorkspace()

device_opts = caffe2_pb2.DeviceOption()
device_opts.device_type = caffe2_pb2.CUDA

op = core.CreateOperator(
     "Gather",
     ["DATA", "INDICES"],
     ["OUTPUT"],
     "test gather operation",
     control_input = None,
     device_option=device_opts
     )   

print("Type of the created op is: {}".format(type(op)))
print("Content:\n")
print(str(op))

data = np.array([10, 256, 6, 6]) 

print("DATA:",data)

inds = np.array([0])

print("INDICES:",inds)

# try a few different data types
#workspace.FeedBlob("DATA", data.astype(np.int8))  #works
#workspace.FeedBlob("DATA", data.astype(np.int16)) #works
#workspace.FeedBlob("DATA", data.astype(np.int32)) #works
workspace.FeedBlob("DATA", data.astype(np.int64))  #fails

workspace.FeedBlob("INDICES", inds.astype(np.int64))

workspace.RunOperatorOnce(op)
print("OUTPUT:", workspace.FetchBlob("OUTPUT"))

The following output was observed for the int32 type:

(my-env) [builder@bb877cf01ed8 ~]$ python test_gather.py 
Type of the created op is: <class 'caffe2.proto.caffe2_pb2.OperatorDef'>
Content:

input: "DATA"
input: "INDICES"
output: "OUTPUT"
name: "test gather operation"
type: "Gather"
device_option {
  device_type: 1
}

('DATA:', array([ 10, 256,   6,   6]))
('INDICES:', array([0]))
('OUTPUT:', array([10], dtype=int32))

and the following with int64 as the input data type:

(my-env) [builder@bb877cf01ed8 ~]$ python test_gather.py 
Type of the created op is: <class 'caffe2.proto.caffe2_pb2.OperatorDef'>
Content:

input: "DATA"
input: "INDICES"
output: "OUTPUT"
name: "test gather operation"
type: "Gather"
device_option {
  device_type: 1
}

('DATA:', array([ 10, 256,   6,   6]))
('INDICES:', array([0]))
('OUTPUT:', array([10995116277770])) 

echo "obase=16; 10995116277770" | bc
A000000000A

I will continue to investigate a solution

@rjknight
Contributor Author

I also tried gathering multiple int8 values from a tensor; this also gave incorrect results with the GPU version of Gather:

(my-env) [builder@bb877cf01ed8 ~]$ vim test_gather.py 
(my-env) [builder@bb877cf01ed8 ~]$ python test_gather.py 
Type of the created op is: <class 'caffe2.proto.caffe2_pb2.OperatorDef'>
Content:

input: "DATA"
input: "INDICES"
output: "OUTPUT"
name: "test gather operation"
type: "Gather"
device_option {
  device_type: 1
}

('DATA:', array([ 10, 256,   6,   6]))
('INDICES:', array([3, 1, 2, 0]))
('OUTPUT:', array([ -3, 127,   0,   0], dtype=int8))  ==> incorrect values

However, if I switch the device to CPU (device_opts.device_type = caffe2_pb2.CPU), the results are correct:

(my-env) [builder@bb877cf01ed8 ~]$ python test_gather.py 
Type of the created op is: <class 'caffe2.proto.caffe2_pb2.OperatorDef'>
Content:

input: "DATA"
input: "INDICES"
output: "OUTPUT"
name: "test gather operation"
type: "Gather"
device_option {
  device_type: 0
}

('DATA:', array([10, 12,  6,  6]))
('INDICES:', array([3, 1, 2, 0]))
('OUTPUT:', array([ 6, 12,  6, 10], dtype=int8)) ==> correct values

@rjknight
Contributor Author

Created a PR (#16077), based on the refactored code, to resolve the issue.
