Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 149 additions & 94 deletions py/torch_tensorrt/dynamo/_compiler.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,9 @@ def native_group_norm(

shape = [1, group] + [1] * (rank - 2)

weight_torch = torch.ones(shape)
bias_torch = torch.zeros(shape)
with unset_fake_temporarily():
weight_torch = torch.ones(shape)
bias_torch = torch.zeros(shape)

weight_one = get_trt_tensor(ctx, weight_torch, f"{name}_weight_one", input.dtype)
bias_zero = get_trt_tensor(ctx, bias_torch, f"{name}_bias_zero", input.dtype)
Expand Down
13 changes: 7 additions & 6 deletions tools/perf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ This is a comprehensive Python benchmark suite to run perf runs using different
2. Torch-TensorRT [Torchscript]
3. Torch-TensorRT [Dynamo]
4. Torch-TensorRT [torch_compile]
5. TensorRT
5. Torch Inductor
6. ONNX-TensorRT


## Prerequisite
Expand Down Expand Up @@ -42,8 +43,8 @@ Benchmark scripts depends on following Python packages in addition to requiremen

Here is the list of `CompileSpec` options that can be provided directly to compile the PyTorch module

* `--backends` : Comma separated string of backends. Eg: torch, torch_compile, dynamo, tensorrt
* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module).
* `--backends` : Comma separated string of backends. Eg: torch, ts_trt, dynamo, torch_compile, inductor, onnx_trt
* `--model` : Name of the model file. Can be a TorchScript module or, when passed together with `--is_trt_engine`, a TensorRT engine. If the backend is `dynamo` or `torch_compile`, the input should be a PyTorch module (instead of a TorchScript module).
* `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend)
* `--onnx` : ONNX model file which helps bypass the step of exporting ONNX from `model_torch`. If this argument is provided, the ONNX will be directly converted to TRT engine
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
Expand All @@ -60,16 +61,16 @@ Eg:
```
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
--model_torch ${MODELS_DIR}/vgg16_torch.pt \
--precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
--precision fp32,fp16 \
--inputs "(1, 3, 224, 224)@fp32" \
--batch_size 1 \
--backends torch,ts_trt,dynamo,torch_compile,tensorrt \
--backends torch,ts_trt,dynamo,torch_compile,inductor,onnx_trt \
--report "vgg_perf_bs1.txt"
```

Note:

1. Please note that measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for `torch_tensorrt` backend.
2. TensorRT engine filename should end with `.plan` otherwise it will be treated as Torchscript module.

### Example models

Expand Down
42 changes: 36 additions & 6 deletions tools/perf/benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ python hub.py

batch_sizes=(1 2 4 8 16 32 64 128 256)
large_model_batch_sizes=(1 2 4 8 16 32 64)
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor" "tensorrt")
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor" "tensorrt")
backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor" "onnx_trt")
backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor" "onnx_trt")


# Benchmark VGG16 model
Expand Down Expand Up @@ -107,18 +107,48 @@ do
done
done

# Benchmark Stable Diffusion UNet model
echo "Benchmarking SD UNet model"
# Benchmark Stable Diffusion v1.4 UNet model
echo "Benchmarking SD-v1.4 UNet model"
for bs in ${large_model_batch_sizes[@]}
do
for backend in ${backends_no_torchscript[@]}
do
python perf_run.py --model_torch sd_unet \
python perf_run.py --model_torch sd1.4_unet \
--precision fp16 --inputs="(${bs}, 4, 64, 64);(${bs});(${bs}, 1, 768)" \
--batch_size ${bs} \
--truncate \
--backends ${backend} \
--report "sd_unet_perf_bs${bs}_backend_${backend}.csv"
--report "sd1.4_unet_perf_bs${bs}_backend_${backend}.csv"
done
done

# Benchmark Stable Diffusion v2.1 UNet model
# Runs perf_run.py once per (batch size, backend) pair and writes one CSV
# report per run. Array expansions and scalar arguments are quoted so that
# each element is passed as exactly one word (no word splitting or globbing).
echo "Benchmarking SD-v2.1 UNet model"
for bs in "${large_model_batch_sizes[@]}"
do
  for backend in "${backends_no_torchscript[@]}"
  do
    python perf_run.py --model_torch sd2.1_unet \
                       --precision fp16 --inputs="(${bs}, 4, 64, 64);(${bs});(${bs}, 1, 1024)" \
                       --batch_size "${bs}" \
                       --truncate \
                       --backends "${backend}" \
                       --report "sd2.1_unet_perf_bs${bs}_backend_${backend}.csv"
  done
done

# Benchmark Stable Diffusion v2.1 VAE decoder model
# Runs perf_run.py once per (batch size, backend) pair and writes one CSV
# report per run. Array expansions and scalar arguments are quoted so that
# each element is passed as exactly one word (no word splitting or globbing).
echo "Benchmarking SD-v2.1 VAE decoder model"
for bs in "${large_model_batch_sizes[@]}"
do
  for backend in "${backends_no_torchscript[@]}"
  do
    python perf_run.py --model_torch sd2.1_vae_decoder \
                       --precision fp16 --inputs="(${bs}, 4, 64, 64)" \
                       --batch_size "${bs}" \
                       --truncate \
                       --backends "${backend}" \
                       --report "sd2.1_vae_decoder_perf_bs${bs}_backend_${backend}.csv"
  done
done

Expand Down
33 changes: 30 additions & 3 deletions tools/perf/custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def BertInputs():
return [tokens_tensor, segments_tensors]


def StableDiffusionUnet():
def StableDiffusion1_4_Unet():
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
Expand All @@ -35,7 +35,25 @@ def StableDiffusionUnet():
return pipe.unet


def UNet():
def StableDiffusion2_1_Unet():
    """Return the UNet of the Stable Diffusion 2.1 pipeline.

    The pipeline is loaded through ``StableDiffusionPipeline.from_pretrained``
    with ``torch_dtype=torch.float16``; only its ``unet`` attribute is returned.
    """
    # Imported lazily so that diffusers is only required when this model is used.
    from diffusers import StableDiffusionPipeline

    model_id = "stabilityai/stable-diffusion-2-1"
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    )
    unet = pipeline.unet
    return unet


def StableDiffusion2_1_VaeDecoder():
    """Return the VAE decoder of the Stable Diffusion 2.1 pipeline.

    The pipeline is loaded through ``StableDiffusionPipeline.from_pretrained``
    with ``torch_dtype=torch.float16``; only ``vae.decoder`` is returned.
    """
    # Imported lazily so that diffusers is only required when this model is used.
    from diffusers import StableDiffusionPipeline

    model_id = "stabilityai/stable-diffusion-2-1"
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    )
    vae = pipeline.vae
    return vae.decoder


def MonaiUNet():
from monai.networks.nets import UNet

model = UNet(
Expand All @@ -46,4 +64,13 @@ def UNet():
strides=(2, 2),
num_res_units=2,
)
return model.eval().cuda()
return model


def GoogleViTForImageClassification():
    """Return the ``google/vit-base-patch16-224`` image-classification model.

    The model is loaded through ``ViTForImageClassification.from_pretrained``
    with ``torch_dtype=torch.float16`` and returned as-is.
    """
    # Imported lazily so that transformers is only required when this model is used.
    from transformers import ViTForImageClassification

    checkpoint = "google/vit-base-patch16-224"
    return ViTForImageClassification.from_pretrained(
        checkpoint, torch_dtype=torch.float16
    )
Loading
Loading