
PyTorch execute failure: isTensor(); Expected Tensor but got GenericList #3348

Closed · rexlow opened this issue Sep 10, 2021 · 3 comments

@rexlow commented Sep 10, 2021

Description
I've built a custom YOLOv5 model and traced it via an export script from Ultralytics. Testing it with their inference script and with my own native Python code works. Now I am attempting to deploy it with Triton.

Triton did not throw any error while deploying the model; the error below only appears when I make a client call.

A similar error can be found in #2594; I'm not sure why that issue was closed. Perhaps @CoderHam can have a look at this one? I'm passing a batched NumPy array, not a tensor.

PyTorch execute failure: isTensor() INTERNAL ASSERT FAILED at "/tmp/tritonbuild/pytorch/build/include/torch/ATen/core/ivalue_inl.h":157, please report a bug to PyTorch. Expected Tensor but got GenericList
Exception raised from toTensor at /tmp/tritonbuild/pytorch/build/include/torch/ATen/core/ivalue_inl.h:157 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f56f28405cc in /opt/tritonserver/backends/pytorch/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7f56f2806d4e in /opt/tritonserver/backends/pytorch/libc10.so)
frame #2: <unknown function> + 0xf558 (0x7f56f2d9a558 in /opt/tritonserver/backends/pytorch/libtriton_pytorch.so)
frame #3: <unknown function> + 0x155a2 (0x7f56f2da05a2 in /opt/tritonserver/backends/pytorch/libtriton_pytorch.so)
frame #4: TRITONBACKEND_ModelInstanceExecute + 0x411 (0x7f56f2da1a91 in /opt/tritonserver/backends/pytorch/libtriton_pytorch.so)
frame #5: <unknown function> + 0x2f2fd7 (0x7f573ca89fd7 in /opt/tritonserver/bin/../lib/libtritonserver.so)
frame #6: <unknown function> + 0xfdfe0 (0x7f573c894fe0 in /opt/tritonserver/bin/../lib/libtritonserver.so)
frame #7: <unknown function> + 0xd6d84 (0x7f573c2d1d84 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x9609 (0x7f573c76c609 in /usr/lib/x86_64-linux-gnu/libpthread.so.0)
frame #9: clone + 0x43 (0x7f573bfbf293 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Triton Information
Triton Docker image version 21.03

Config file

name: "my_model"
platform: "pytorch_libtorch"
max_batch_size : 128
input {
    name: "input__0"
    data_type: TYPE_FP32
    format: FORMAT_NCHW
    dims: [ 3, 640, 640 ]
}
output {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 20 ]
    label_filename: "labels.txt"
}
instance_group {
  count: 1
  kind: KIND_GPU
  gpus: 0
}
dynamic_batching {
  preferred_batch_size: [ 1, 2, 4, 8, 16 ]
  max_queue_delay_microseconds: 100
}

Python Client

import cv2
import numpy as np
import tritonclient.http as tchttp

# preprocess: resize, scale to [0, 1], HWC -> CHW, add a batch dimension
image = cv2.resize(img, (640, 640), interpolation=cv2.INTER_CUBIC)
image = image.astype(np.float32)
image /= 255.                           # [640, 640, 3]
image = np.transpose(image, (2, 0, 1))  # [3, 640, 640]
data = np.expand_dims(image, 0)         # [1, 3, 640, 640]

input_1 = tchttp.InferInput("input__0", (1, 3, 640, 640), "FP32")
input_1.set_data_from_numpy(data)
outputs = [tchttp.InferRequestedOutput(name) for name in ["output__0"]]
response = triton_client.infer(
    model_name='my_model',
    inputs=[input_1],
    outputs=outputs,
    model_version='1')
logits = response.as_numpy('output__0')
@CoderHam (Contributor) commented
While libtorch does support passing lists of tensors, Triton Server does not. You can build a simple wrapper model that converts a single tensor (or multiple tensors) into a list of tensors and passes it on to your model. Once you have this wrapper model, simply trace it and you should be able to use the traced model inside Triton.

Refer to #2593 and #2373 (comment)
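A minimal sketch of such a wrapper, assuming the original model's forward expects a list of tensors; the ListInputWrapper name and the original_model.pt path are illustrative, not from this thread:

import torch

# hypothetical: the existing TorchScript model whose forward takes a list of tensors
inner_model = torch.jit.load("original_model.pt")

class ListInputWrapper(torch.nn.Module):
    """Accepts the plain tensor Triton passes in and forwards it as a list."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor):
        # convert the single tensor into the list the inner model expects
        return self.model([x])

# trace the wrapper (not the original model) and save that for Triton
wrapped = ListInputWrapper(inner_model).eval()
traced = torch.jit.trace(wrapped, torch.rand(1, 3, 640, 640))
traced.save("model.pt")  # place under my_model/1/ in the model repository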

@wscjxky commented Oct 25, 2021

  • I had the same problem and solved it: just change your model's outputs to a tuple or namedtuple, and use jit.trace to export the .pt model.

I guess tritonserver:21.09-py3 can support multiple inputs for PyTorch models.

    # note: requires `from collections import namedtuple` at the top of the file
    def forward(self, image, points):
        coord_features = self.dist_maps(image, points)
        .........
        instance_out = feature_extractor_out[0]
        outputs = {'instances': instance_out}

        #### added code: convert dict/list outputs into (named)tuples,
        #### which jit.trace and Triton can handle
        if isinstance(outputs, dict):
            data_named_tuple = namedtuple("outputs", sorted(outputs.keys()))  # type: ignore
            outputs = data_named_tuple(**outputs)  # type: ignore
        elif isinstance(outputs, list):
            outputs = tuple(outputs)
        #### end of added code

        return outputs
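A short sketch of the export step this implies, with shapes matching the config below; MyModel is a hypothetical class containing the forward shown above:

import torch

model = MyModel().eval()          # hypothetical model with the forward shown above
image = torch.rand(1, 3, 28, 28)
points = torch.rand(1, 28, 28)

# trace with example inputs; since forward now returns a (named)tuple
# rather than a dict, tracing succeeds and Triton can unpack the outputs
traced = torch.jit.trace(model, (image, points))
traced.save("model.pt")           # goes under infer_model/1/ in the model repository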
name: "infer_model"
platform: "pytorch_libtorch"

version_policy {
  latest {
    num_versions: 1
  }
}
input {
  name: "inputs__0"
  data_type: TYPE_FP32
  dims: [1, 3, 28, 28]
}
input {
  name: "inputs__1"
  data_type: TYPE_FP32
  dims: [1, 28, 28]
}

output {
  name: "classes__0"
  data_type: TYPE_FP32
  dims: [1, 1, 28, 28]
}
  • Triton Inference code
import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype

model = np.ones((1, 28, 28), dtype=np.float32)     # the "points" input
check = np.ones((1, 3, 28, 28), dtype=np.float32)  # the "image" input

with httpclient.InferenceServerClient("localhost:8000") as client:
    inputs = [
        httpclient.InferInput("inputs__0", check.shape,
                              np_to_triton_dtype(check.dtype)),
        httpclient.InferInput("inputs__1", model.shape,
                              np_to_triton_dtype(model.dtype)),
    ]
    inputs[0].set_data_from_numpy(check)
    inputs[1].set_data_from_numpy(model)
    result = client.infer("infer_model", inputs)
    print(result.as_numpy("classes__0"))
  • Alternatively, I also tried creating wrapper code around my model, using split, cat, squeeze, and narrow to combine the multiple inputs into one,

like this:

    class Model(torch.nn.Module):
        ...
        def forward(self, inputs):
            # split the stacked input back into its original parts
            image, points = torch.split(inputs, 1)
            points = points.squeeze(1).narrow(1, 0, 1).squeeze(1)

  • inference with the combined input (all inputs must be expanded to a common shape before torch.cat)
    x0 = torch.ones(1, 3, 28, 28).to(device)
    x1 = torch.ones(1, 28, 28).to(device)
    # broadcast x1 to x0's shape so both can be stacked along the batch dim
    x1 = x1.unsqueeze(1).expand(1, 3, 28, 28)
    inputs = torch.cat([x0, x1])  # shape [2, 3, 28, 28]

    model(inputs)

@QingYuan-L commented Oct 28, 2021

Hi, I solved the problem by modifying the model forward in yolov5. Change

return x if self.training else (torch.cat(z, 1), x)

to

return x if self.training else torch.cat(z, 1)

then export again. You can also change the batch size to 1, so that Triton can send the input data and receive the single output tensor.
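For context, a sketch of where that return statement lives, assuming the stock Ultralytics yolov5 Detect head in models/yolo.py (most of the method body is elided):

import torch
from torch import nn

class Detect(nn.Module):
    def forward(self, x):
        z = []  # per-layer inference outputs, filled in by the elided loop
        ...
        # before: eval mode returned (tensor, list of tensors), which Triton rejects
        # return x if self.training else (torch.cat(z, 1), x)
        # after: return only the concatenated detections tensor
        return x if self.training else torch.cat(z, 1)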
