
Inference results differ for different batch size #5640

Closed

casperroo opened this issue Apr 14, 2023 · 4 comments
Labels
question Further information is requested

Comments

@casperroo

Software version:

  • Triton Inference Server 23.03 (build 56086596) - from docker
  • image_client from c++ client examples (built on host from 23.03 tag)

In short:
Is it expected behavior for the Triton Inference Server (or the underlying backend) to yield different results for different batch sizes on the very same input data?
I am testing image classification with efficientnet.
The differences are minimal, but I wonder whether this is expected or whether some batched data gets overlapped somewhere in the pipeline.

My test:
I start the triton inference server and serve a model (efficientnet) with the following config:

name: "mynet"
platform: "tensorflow_savedmodel"
max_batch_size: 20
input [
  {
    name: "image_input"
    data_type: TYPE_FP32
    format: FORMAT_NHWC
    dims: [ 300, 300, 3 ]
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ 13 ]
    label_filename: "labels.txt"
  }
]
instance_group [
  {
    count: 4
    kind: KIND_GPU
  }
]
optimization {
  execution_accelerators {
    gpu_execution_accelerator : [
    {
      name : "auto_mixed_precision"
    }
  ]
}}

I run the client with batch size of 1:

./examples/image_client -m mynet -u 127.0.0.1:8001 -i gRPC -c 3 -b 1 /tmp/random
Request 0, batch size 1
Image '/tmp/random/1.png':
    0.978062 (11) = l12
    0.012779 (8) = l9
    0.007350 (12) = l13
Request 1, batch size 1
Image '/tmp/random/2.png':
    0.990088 (11) = l12
    0.010908 (8) = l9
    0.008616 (12) = l13

With batch size of 2:

./examples/image_client -m mynet -u 127.0.0.1:8001 -i gRPC -c 3 -b 2 /tmp/random
Request 0, batch size 2
Image '/tmp/random/1.png':
    0.978104 (11) = l12
    0.012779 (8) = l9
    0.007350 (12) = l13
Image '/tmp/random/2.png':
    0.990011 (11) = l12
    0.010951 (8) = l9
    0.008583 (12) = l13

With batch size of 10:

./examples/image_client -m mynet -u 127.0.0.1:8001 -i gRPC -c 3 -b 10 /tmp/random
Request 0, batch size 10
Image '/tmp/random/1.png':
    0.977936 (11) = l12
    0.012730 (8) = l9
    0.007379 (12) = l13
Image '/tmp/random/2.png':
    0.990011 (11) = l12
    0.010866 (8) = l9
    0.008616 (12) = l13
(...)

So, depending on the batch size, 1.png yields the following score for l12:

- Batch size  1: 0.978062
- Batch size  2: 0.978104
- Batch size 10: 0.977936

I get the same results when I run it from my own client with CUDA shared memory.
Disabling the gpu_execution_accelerator doesn't change the behavior.
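
An equivalent check with the Triton Python client looks roughly like this (a sketch, not my exact client: preprocessing is omitted, random arrays stand in for the PNGs, and the model, input, and output names come from the config above):

# Sketch: compare scores for the same image sent at batch size 1 and batch size 2.
# Assumes the "mynet" model from the config above is served at 127.0.0.1:8001.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="127.0.0.1:8001")

def infer(batch):
    # batch has shape (N, 300, 300, 3), float32, NHWC as in the model config
    inp = grpcclient.InferInput("image_input", list(batch.shape), "FP32")
    inp.set_data_from_numpy(batch)
    out = grpcclient.InferRequestedOutput("output")
    result = client.infer(model_name="mynet", inputs=[inp], outputs=[out])
    return result.as_numpy("output")  # shape (N, 13)

# Random arrays stand in for the preprocessed /tmp/random/1.png and 2.png.
img1 = np.random.rand(300, 300, 3).astype(np.float32)
img2 = np.random.rand(300, 300, 3).astype(np.float32)

scores_b1 = infer(img1[np.newaxis])         # batch size 1
scores_b2 = infer(np.stack([img1, img2]))   # batch size 2

# With nondeterministic GPU kernels, the top score for img1 can differ slightly.
print("batch 1:", scores_b1[0].max(), "batch 2:", scores_b2[0].max())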

@kthui
Contributor

kthui commented Apr 17, 2023

Triton never alters the output based on batch size. Triton can group inputs into different batch sizes, but the input data is ultimately passed to the model unmodified to produce the output. I think this might be a difference in the model's behavior at different batch sizes.

CC @oandreeva-nv in case you have some input on this.

@kthui kthui added the question Further information is requested label Apr 17, 2023
@oandreeva-nv
Contributor

I tend to agree that the underlying computations on different batch sizes may cause discrepancies. CC @rmccorm4, if I remember correctly you were exploring similar issues at some point?

@rmccorm4
Collaborator

Hi @casperroo,

Yes, it is generally expected that different batch sizes, when executed on GPU, can produce slight variance in their results. This is generally due to different CUDA kernels being selected based on the batch size. Some frameworks provide APIs to make execution more deterministic; there are some relevant docs about that for PyTorch here, but the same concepts generally apply to TensorFlow and other frameworks: https://github.com/triton-inference-server/python_backend/tree/main#determinism-and-reproducibility
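
As a rough illustration of the kind of settings those PyTorch docs describe (sketch only, and not directly applicable to the TensorFlow savedmodel in this issue):

# Sketch of the PyTorch determinism settings the linked docs refer to.
import os
import torch

# cuBLAS requires this workspace setting for deterministic algorithms on GPU.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

torch.manual_seed(0)                       # fix the RNG state
torch.use_deterministic_algorithms(True)   # raise an error on nondeterministic ops
torch.backends.cudnn.benchmark = False     # stop cuDNN from autotuning kernels per input shape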

@casperroo
Author

Thank you guys very much,

you have pointed me in the right direction. "Determinism" is the keyword I was missing.
TensorFlow has some ways around it, but in general this explains it quite well:

https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism

The following bit explains it very well:

"These differences are often caused by the use of asynchronous threads within the op nondeterministically changing the order in which floating-point numbers are added. Most of these cases of nondeterminism occur on GPUs, which have thousands of hardware threads that are used to run ops"

So this is actually so obvious once you read it: the sequence in which one executes operations on floats makes a difference.
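
For completeness, opting into this in TensorFlow looks roughly like the sketch below; note that it makes repeated runs reproducible rather than making different batch sizes agree, and it usually costs performance:

# Sketch: enabling deterministic TensorFlow ops, per the linked documentation.
import tensorflow as tf

tf.keras.utils.set_random_seed(42)               # seeds Python, NumPy, and TF RNGs
tf.config.experimental.enable_op_determinism()   # force deterministic op implementations

# Repeated runs of the same graph on the same hardware now give identical results;
# ops that have no deterministic GPU kernel raise an error instead.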
Thanks again
