❓ [Question] Runtimes for timm + TensorRT #1860

Closed
SimJeg opened this issue Apr 26, 2023 · 8 comments
SimJeg commented Apr 26, 2023

❓ Question

I created a script to compare inference runtimes with plain torch, torch.compile, and torch_tensorrt.compile for any timm model, input shape, and dtype. Some runtimes are worse with TensorRT. Why is that?

What you have already tried

I used the latest NVIDIA PyTorch container (nvcr.io/nvidia/pytorch:23.04-py3, released today) on a g5.2xlarge AWS instance (A10G GPU). You can find the script (benchmark.py) at the end of this issue and the command used to run it below:

docker run --gpus all --rm --volume $DIR:/app nvcr.io/nvidia/pytorch:23.04-py3 /bin/bash -c "pip install --pre timm && python /app/benchmark.py"

where $DIR is the path to the directory where I saved the script. Here are a few results:

| model | dtype | shape | torch | torch.compile | torch_tensorrt.compile |
|---|---|---|---|---|---|
| resnet50 | float32 | (16, 3, 224, 224) | 16.0ms | 11.4ms | 7.6ms |
| resnet50 | float16 | (16, 3, 224, 224) | 9.0ms | 6.3ms | 3.6ms |
| convnext_large | float32 | (16, 3, 224, 224) | 70.5ms | 56.7ms | 145.9ms |
| convnext_large | float16 | (16, 3, 224, 224) | 35.4ms | 28.3ms | 64.8ms |
| vit_base_patch16_224 | float32 | (16, 3, 224, 224) | 28.6ms | 28.2ms | 30.5ms |
| vit_large_patch14_clip_336 | float32 | (16, 3, 336, 336) | 288.1ms | 284.2ms | 310.2ms |
| vit_large_patch14_clip_336 | float16 | (16, 3, 336, 336) | 129.1ms | 127.5ms | error° |

(error°: Expected input tensors to have type Half, found type Float. Maybe some forcing of the LayerNorm layers to FP32 is applied and I should enable mixed precision somehow?)

Everything goes well for the resnet50 model, but for the convnext_large and vit models the torch_tensorrt.compile option gets lower throughput and even fails in one case. And of course these are the models I am interested in 😅

Several questions:

  • Do you see any issue with the script I provided or with how I ran it?
  • How can I minimize the runtimes for the convnext_large and vit_large_patch14_clip_336 models? Would using ONNX + TensorRT provide different results? Is it related to how these models are implemented in timm?

I can provide more details if needed (e.g. the stack trace).
Thanks for your help and support,
Simon


from time import time
import timm
import torch
import torch_tensorrt


def benchmark(model, inputs, compile_torch=False, compile_tensorrt=False, n_warmups=5, n_runs=100):
    """
    1. Optionally compile the model
    2. Warmup phase (n_warmups) 
    3. Benchmark phase (n_runs)
    """

    assert not (compile_torch and compile_tensorrt), "Cannot compile both torch and tensorrt"

    # 1. Compile
    if compile_tensorrt:
        model = torch_tensorrt.compile(model,
                                       inputs=[torch_tensorrt.Input(inputs.shape, dtype=inputs.dtype)],
                                       enabled_precisions={inputs.dtype})

    if compile_torch:
        model = torch.compile(model)

    # 2. Warmup
    for _ in range(n_warmups):
        with torch.no_grad():
            model(inputs)
    torch.cuda.synchronize()

    # 3. Benchmark
    runtimes = []
    for _ in range(n_runs):
        with torch.no_grad():
            start = time()
            model(inputs)
            torch.cuda.synchronize()
            runtimes.append(time() - start)
            
    # Print result
    print('*' * 80)
    print(f"Average: {1000*sum(runtimes)/n_runs:.2f}ms")
    print('*' * 80)


if __name__ == '__main__':

    # To run this script using the latest pytorch docker image, save it into a directory (DIR) and run:
    # docker run --gpus all --rm --volume $DIR:/app nvcr.io/nvidia/pytorch:23.04-py3 /bin/bash -c "pip install --pre timm && python /app/benchmark.py"

    # Parameters
    model_name = 'resnet50'
    shape = (16, 3, 224, 224)
    dtype = torch.float32

    # Prepare model and inputs
    model = timm.create_model(model_name)
    model.eval().cuda().type(dtype)
    inputs = torch.randn(*shape).type(dtype).cuda()

    benchmark(model, inputs)
    benchmark(model, inputs, compile_torch=True)
    benchmark(model, inputs, compile_tensorrt=True)

SimJeg added the question label on Apr 26, 2023
gs-olive commented Apr 28, 2023

Hello - thanks for the detailed results - to answer the questions:

total_runtime = 0
for _ in range(n_runs):
    with torch.no_grad():
        start = time()
        model(inputs)
        torch.cuda.synchronize()
        end = time()
        total_runtime += end - start

avg_runtime = total_runtime / n_runs
  • I do not believe the above will substantially change the rankings, but it may improve the individual times.
  • Regarding the runtimes of the convnext_large and vit_large_patch14_clip_336 models, one thing to try might be tuning the min_block_size argument which can be passed to torch_tensorrt.compile(...). This controls the size of the smallest block that can be converted to a TRT Engine (it defaults to 3), and increasing this value can lead to improved performance by decreasing the amount of segmentation in the graph (see the sketch after this list).
  • Another option is to identify any unimplemented operators as listed in the debug logs, and file feature requests for those.
  • ONNX + TensorRT could potentially give different performance results; I am not certain in this case.
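
Putting the min_block_size and debug-logging suggestions together, a rough (untested) sketch, reusing the model / inputs names from your benchmark script, could look like:

import torch_tensorrt

# Untested sketch: enable debug logs to see which ops fall back to PyTorch, and
# raise min_block_size to reduce graph segmentation (the value 10 is illustrative)
with torch_tensorrt.logging.debug():
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[torch_tensorrt.Input(inputs.shape, dtype=inputs.dtype)],
        enabled_precisions={inputs.dtype},
        min_block_size=10,  # default is 3
    )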

I will evaluate the key models you have mentioned on our current main branch and update with any additional findings or suggestions.

SimJeg commented May 2, 2023

Hi @gs-olive,

Thanks for your detailed answer. I updated the script and tried min_block_size = 3, 4, 5, 10, 50, 100 for the convnext_large float16 model but did not notice any difference or improvement. For this model, the only warning I get is about LayerNorm: WARNING: [Torch-TensorRT TorchScript Conversion Context] - Running layernorm after self-attention in FP16 may cause overflow. Forcing layernorm layers to run in FP32 precision can help with preserving accuracy.
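
Following that warning, one thing I could also try (an untested sketch, reusing the names from my benchmark script) is enabling both FP16 and FP32 so TensorRT can keep the LayerNorm layers in FP32 while the rest of the network runs in FP16:

# Untested sketch: allow both precisions so layers forced to FP32 (e.g. LayerNorm)
# can still be converted while the rest of the network runs in FP16
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(inputs.shape, dtype=torch.half)],
    enabled_precisions={torch.half, torch.float},
)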

I will try ONNX + TensorRT and give you feedback.

SimJeg commented May 2, 2023

Following this blog post, I created a new script that uses torch.onnx.export + onnxsim.simplify + trtexec. Here are a few results in float16 with a (16, 3, size, size) input shape:

| model | torch.compile | torch_tensorrt.compile | torch.onnx.export + trtexec |
|---|---|---|---|
| resnet50 | 9.0ms | 3.6ms | 3.6ms |
| convnext_large | 28.3ms | 64.8ms | 48.3ms |
| vit_large_patch14_clip_336 | 127.5ms | error | error (update: 121.4ms) |

The results are the ones reported directly by trtexec. The errors for the ViT models are related to this issue. Again, ConvNext runtimes are worse using TensorRT 😢

I checked that I get similar results when using torchtrtc to convert the TensorRT engine into a TorchScript module and running the same benchmark script as before. I initially wanted to provide a script running torchtrtc via os.system for a better benchmark; however, I get an error I was not able to solve: ERROR: 3: [runtime.cpp::deserializeCudaEngineEx::96] Error Code 3: Internal Error (Cannot deserialize with an empty memory buffer.). Here is the full script:

"""
Script to benchmark a model using ONNX export + TensorRT
To run this script using the latest pytorch docker image, save it into a directory (DIR) and run:
docker run --gpus all --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --volume $DIR:/app nvcr.io/nvidia/pytorch:23.04-py3 /bin/bash -c "pip install --pre timm onnx onnxruntime onnx-simplifier && python /app/benchmark.py"
"""

import os
import warnings
from time import time
from tempfile import TemporaryDirectory

import torch
import torch_tensorrt # import required to load the .ts model
import timm
import onnx
from onnxsim import simplify


# Parameters
n_warmups = 5
n_runs = 100
opset_version = 18 

# Load model
model_name = 'resnet50'
shape = (16, 3, 224, 224)
model = timm.create_model(model_name, exportable=True)
model.eval().cuda().half()

with TemporaryDirectory() as tmpdir:
    name = lambda ext: f'{tmpdir}/{model_name}.{ext}'
    

    # 1. Compile model using ONNX export + TensorRT

    # Export to ONNX
    with torch.inference_mode(), torch.autocast("cuda"):
        inputs = torch.randn(*shape, dtype=torch.half, device='cuda')
        torch.onnx.export(model, inputs, name('onnx'), export_params=True, opset_version=opset_version,
                        do_constant_folding=True, input_names = ['input_0'], output_names = ['output_0'])

    # Simplify using onnx-simplifier
    model = onnx.load(name('onnx'))
    simplified_model, check = simplify(model)
    if not check:
        warnings.warn('Simplified ONNX model could not be validated, using original ONNX model')
    else:
        onnx.save(simplified_model, name('onnx'))

    # Convert to TensorRT using default settings
    os.system(f'trtexec --onnx={name("onnx")} --saveEngine={name("trt")} --fp16')

    exit() # The command below is not working yet
    os.system(f'torchtrtc {name("trt")} {name("ts")} --embed-engine --device-type=gpu')

    # 2. Get runtime

    model = torch.jit.load(name('ts'))
    model.eval().half().cuda()
    inputs = torch.randn(*shape, dtype=torch.half, device='cuda')

    # Warmup
    for _ in range(n_warmups):
        with torch.no_grad():
            model(inputs)
    torch.cuda.synchronize()

    # Benchmark
    runtimes = []
    for _ in range(n_runs):
        with torch.no_grad():
            start = time()
            model(inputs)
            torch.cuda.synchronize()
            runtimes.append(time() - start)
            
    # Print result
    print('*' * 80)
    print(f"Average: {1000*sum(runtimes)/n_runs:.2f}ms")
    print('*' * 80)

SimJeg commented May 2, 2023

I just noticed the exportable kwarg in timm.create_model, which solves the issue I mentioned for ViT (see this discussion). Here are updated results for the ViT models with exportable=True. Slightly better than torch.compile, but not the 2x speedup I expected (as reported for NVIDIA FasterTransformer, for instance).

| model | torch.compile | torch_tensorrt.compile | torch.onnx.export + trtexec |
|---|---|---|---|
| vit_base_patch16_224 | 13.9ms | error | 13.3ms |
| vit_large_patch14_clip_336 | 127.5ms | error | 121.4ms |
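
For reference, the only change needed is the model creation line (already reflected in the export script above):

model = timm.create_model(model_name, exportable=True)  # export-friendly model configuration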

SimJeg commented May 7, 2023

For the convnext_large model, passing the kwarg conv_mlp=True to timm.create_model closes the gap between torch.compile and torch_tensorrt.compile. It changes how the forward pass is performed (see here) by removing two permutations. Still, TensorRT does not bring improvements :(
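
For reference, the change is just the model creation line (a sketch of what I ran):

# conv_mlp=True makes the ConvNeXt MLP blocks operate in (N, C, H, W) layout,
# removing the two permutations mentioned above
model = timm.create_model('convnext_large', conv_mlp=True)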

github-actions bot commented Aug 6, 2023

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.

github-actions bot commented Nov 7, 2023

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.

@samar-khanna

@SimJeg hi, did you manage to figure out the reason for the slower runtime for convnext when using TRT?
