The purpose of this example is to demonstrate the overall flow of lowering a PyTorch model
to TensorRT conveniently with lower.py. The transformation process, including `TRTInterpreter`, `TRTModule`, and pass optimization, is integrated into the `lower_to_trt` API. Users are encouraged to check the API's docstring and tune it to meet their needs.

In [4]:
import typing as t
from copy import deepcopy
from dataclasses import dataclass, field, replace

import torch
import torchvision
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

Specify the `Configuration` class used for FX path lowering and benchmarking. To extend it, add a new field to this class and modify the lowering or benchmark behavior in `run_configuration_benchmark()` accordingly. Each benchmark run stores the `Configuration` it used in a `Result` dataclass.  
`Result` is another dataclass that holds the raw benchmark values, such as batch size, QPS, and accuracy.


In [22]:
@dataclass
class Configuration:
    # number of inferences to run
    batch_iter: int

    # Input batch size
    batch_size: int

    # Friendly name of the configuration
    name: str = ""

    # Whether to apply TRT lowering to the model before benchmarking
    trt: bool = False

    # Whether to apply engine holder to the lowered model
    jit: bool = False

    # Whether to enable FP16 mode for TRT lowering
    fp16: bool = False

    # Relative tolerance for accuracy check after lowering. -1 means do not
    # check accuracy.
    accuracy_rtol: float = -1  # disable
    
@dataclass
class Result:
    module: torch.nn.Module = field(repr=False)
    input: t.Any = field(repr=False)
    conf: Configuration
    time_sec: float
    accuracy_res: t.Optional[bool] = None

    @property
    def time_per_iter_ms(self) -> float:
        return self.time_sec * 1.0e3

    @property
    def qps(self) -> float:
        return self.conf.batch_size / self.time_sec

    def format(self) -> str:
        return (
            f"== Benchmark Result for: {self.conf}\n"
            f"BS: {self.conf.batch_size}, "
            f"Time per iter: {self.time_per_iter_ms:.2f}ms, "
            f"QPS: {self.qps:.2f}, "
            f"Accuracy: {self.accuracy_res} (rtol={self.conf.accuracy_rtol})"
        )
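
Extending the configuration as described above can be sketched with a pared-down stand-in. `MiniConfiguration` and its `warmup_iter` field are hypothetical names used only for illustration; they are not part of the notebook or of `torch_tensorrt`. The sketch shows how `dataclasses.replace` derives variants from one base configuration without mutating it.

```python
from dataclasses import dataclass, replace


@dataclass
class MiniConfiguration:
    # Hypothetical, reduced stand-in for the notebook's Configuration class
    batch_iter: int
    batch_size: int
    name: str = ""
    trt: bool = False
    fp16: bool = False
    # Hypothetical new field, added the same way any extension would be
    warmup_iter: int = 1


base = MiniConfiguration(batch_iter=50, batch_size=128)

# replace() returns a new instance; unspecified fields are copied from base
variants = [
    replace(base, name="CUDA Eager"),
    replace(base, name="TRT FP16 Eager", trt=True, fp16=True, warmup_iter=5),
]

for variant in variants:
    print(variant)
```

A new field with a default value keeps existing `replace(...)` calls working unchanged, which is why the sketch gives `warmup_iter` a default.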

The benchmark code runs FX path lowering and benchmarks the given model according to the specified benchmark configuration, then prints the benchmark result for each configuration at the end of the run. `benchmark_torch_function` is the function that actually times a fixed number of iterations of the given function.
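
The timing pattern (one untimed warm-up call, then a fixed number of timed iterations, then an average) can be sketched without a GPU. This is a CPU-only stand-in: the notebook's function times with `torch.cuda.Event`, while this sketch substitutes `time.perf_counter`. The name `benchmark_callable` is hypothetical.

```python
import time
from typing import Callable


def benchmark_callable(iters: int, f: Callable[[], object]) -> float:
    """Return the average wall-clock seconds per call of f over iters calls."""
    f()  # warm-up call, excluded from the measurement
    start = time.perf_counter()
    for _ in range(iters):
        f()
    elapsed = time.perf_counter() - start
    return elapsed / iters


calls = []
avg_sec = benchmark_callable(50, lambda: calls.append(sum(range(1000))))
print(f"average per call: {avg_sec * 1.0e3:.4f} ms")
print(f"total calls including warm-up: {len(calls)}")
```

Wall-clock timing like this is only meaningful for synchronous CPU work. For asynchronous GPU kernels, the notebook's approach of recording CUDA events and synchronizing before reading the elapsed time is the appropriate one.
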
FX path lowering and TensorRT engine creation are integrated into the `lower_to_trt()` API, which is defined in `fx/lower.py`.
Its signature and arguments are listed below to show its usage. It takes the original module, the input, and the lowering settings, and runs the lowering workflow to turn the module into an executable TRT engine.
```
def lower_to_trt(
    module: nn.Module,
    input,
    max_batch_size: int = 2048,
    max_workspace_size=1 << 25,
    explicit_batch_dimension=False,
    lower_precision=LowerPrecision.FP16,
    verbose_log=False,
    timing_cache_prefix="",
    save_timing_cache=False,
    cuda_graph_batch_size=-1,
    dynamic_batch=False,
) -> nn.Module:
``` 

    Args:
        module: Original module for lowering.
        input: Input for module.
        max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
        max_workspace_size: Maximum size of workspace given to TensorRT.
        explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
        lower_precision: lower_precision config given to TRTModule.
        verbose_log: Enable verbose log for TensorRT if set True.
        timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
        save_timing_cache: Update timing cache with current timing cache data if set to True.
        cuda_graph_batch_size: Cuda graph batch size, default to be -1.
        dynamic_batch: batch dimension (dim=0) is dynamic.

    Returns:
        A torch.nn.Module lowered by TensorRT.
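
The default `max_workspace_size=1 << 25` is easier to read once converted to a unit. TensorRT measures the workspace in bytes, so the default corresponds to 32 MiB. The sketch below checks that arithmetic; the helper `mib` is a hypothetical convenience, not part of `torch_tensorrt`.

```python
def mib(n: int) -> int:
    """Return n mebibytes expressed in bytes."""
    return n << 20


default_workspace = 1 << 25  # default from the lower_to_trt signature

print(default_workspace)             # size in bytes
print(default_workspace // mib(1))   # size in MiB
```

With such a helper, a larger limit could be written as `max_workspace_size=mib(256)` instead of a raw shift.
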

We tested a ResNet-18 network with an input of size [128, 3, 224, 224], laid out as [Batch, Channel, Height, Width].

In [21]:
test_model = torchvision.models.resnet18(pretrained=True)
input = [torch.rand(128, 3, 224, 224)]   
benchmark(test_model, input, 50, 128)

def benchmark_torch_function(iters: int, f, *args) -> float:
    """Estimates the average time duration for a single inference call in seconds

    If the input is batched, then the estimation is for the batched inference call.
    """
    with torch.inference_mode():
        f(*args)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    print("== Start benchmark iterations")
    with torch.inference_mode():
        start_event.record()
        for _ in range(iters):
            f(*args)
        end_event.record()
    torch.cuda.synchronize()
    print("== End benchmark iterations")
    return (start_event.elapsed_time(end_event) * 1.0e-3) / iters


def run_configuration_benchmark(
    module,
    input,
    conf: Configuration,
) -> Result:
    print(f"=== Running benchmark for: {conf}", "green")
    time = -1.0

    if conf.fp16:
        module = module.half()
        input = [i.half() for i in input]

    if not conf.trt:
        # Run eager mode benchmark
        time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
    elif not conf.jit:
        # Run lowering eager mode benchmark
        lowered_module = lower_to_trt(
            module,
            input,
            max_batch_size=conf.batch_size,
            lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
        )
        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
    else:
        print("Lowering with JIT is not available!", "red")

    result = Result(module=module, input=input, conf=conf, time_sec=time)
    return result

@torch.inference_mode()
def benchmark(
    model,
    inputs,
    batch_iter: int,
    batch_size: int,
) -> None:
    model = model.cuda().eval()
    inputs = [x.cuda() for x in inputs]

    # benchmark base configuration
    conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)

    configurations = [
        # Baseline
        replace(conf, name="CUDA Eager", trt=False),
        # FP32
        replace(
            conf,
            name="TRT FP32 Eager",
            trt=True,
            jit=False,
            fp16=False,
            accuracy_rtol=1e-3,
        ),
        # FP16
        replace(
            conf,
            name="TRT FP16 Eager",
            trt=True,
            jit=False,
            fp16=True,
            accuracy_rtol=1e-2,
        ),
    ]

    results = [
        run_configuration_benchmark(deepcopy(model), inputs, conf_)
        for conf_ in configurations
    ]

    for res in results:
        print(res.format())

=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='CUDA Eager', trt=False, jit=False, fp16=False, accuracy_rtol=-1) green
== Start benchmark iterations


== End benchmark iterations
=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='TRT FP32 Eager', trt=True, jit=False, fp16=False, accuracy_rtol=0.001) green


== Log pass <function fuse_permute_matmul at 0x7fbdfcc9f1f0> before/after graph to /tmp/tmpaayayg72
== Log pass <function fuse_permute_linear at 0x7fbe36555f70> before/after graph to /tmp/tmpdw_pq71j

Supported node types in the model:
acc_ops.conv2d: ((), {'input': torch.float32, 'weight': torch.float32})
acc_ops.batch_norm: ((), {'input': torch.float32, 'running_mean': torch.float32, 'running_var': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.relu: ((), {'input': torch.float32})
acc_ops.max_pool2d: ((), {'input': torch.float32})
acc_ops.add: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.adaptive_avg_pool2d: ((), {'input': torch.float32})
acc_ops.flatten: ((), {'input': torch.float32})
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})

Unsupported node types in the model:

Got 1 acc subgraphs and 0 non-acc subgraphs


I0627 233146.650 fx2trt.py:190] Run Module elapsed time: 0:00:00.244369


I0627 233206.570 fx2trt.py:241] Build TRT engine elapsed time: 0:00:19.918630


== Start benchmark iterations


== End benchmark iterations
=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='TRT FP16 Eager', trt=True, jit=False, fp16=True, accuracy_rtol=0.01) green


== Log pass <function fuse_permute_matmul at 0x7fbdfcc9f1f0> before/after graph to /tmp/tmpnoeblgd5
== Log pass <function fuse_permute_linear at 0x7fbe36555f70> before/after graph to /tmp/tmpyb1egsof

Supported node types in the model:
acc_ops.conv2d: ((), {'input': torch.float16, 'weight': torch.float16})
acc_ops.batch_norm: ((), {'input': torch.float16, 'running_mean': torch.float16, 'running_var': torch.float16, 'weight': torch.float16, 'bias': torch.float16})
acc_ops.relu: ((), {'input': torch.float16})
acc_ops.max_pool2d: ((), {'input': torch.float16})
acc_ops.add: ((), {'input': torch.float16, 'other': torch.float16})
acc_ops.adaptive_avg_pool2d: ((), {'input': torch.float16})
acc_ops.flatten: ((), {'input': torch.float16})
acc_ops.linear: ((), {'input': torch.float16, 'weight': torch.float16, 'bias': torch.float16})

Unsupported node types in the model:

Got 1 acc subgraphs and 0 non-acc subgraphs


I0627 233208.996 fx2trt.py:190] Run Module elapsed time: 0:00:00.217076


I0627 233244.147 fx2trt.py:241] Build TRT engine elapsed time: 0:00:35.150950


== Start benchmark iterations


== End benchmark iterations
== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='CUDA Eager', trt=False, jit=False, fp16=False, accuracy_rtol=-1)
BS: 128, Time per iter: 15.00ms, QPS: 8530.72, Accuracy: None (rtol=-1)
== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='TRT FP32 Eager', trt=True, jit=False, fp16=False, accuracy_rtol=0.001)
BS: 128, Time per iter: 7.95ms, QPS: 16098.45, Accuracy: None (rtol=0.001)
== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='TRT FP16 Eager', trt=True, jit=False, fp16=True, accuracy_rtol=0.01)
BS: 128, Time per iter: 4.36ms, QPS: 29365.31, Accuracy: None (rtol=0.01)
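
As a sanity check on the figures printed above, the sketch below recomputes the per-iteration latency from the reported QPS, using the relation `qps = batch_size / time_sec` from `Result`, and derives the speedup of each TensorRT configuration over the eager baseline. The QPS values are copied from the output; nothing here is re-measured.

```python
# Figures copied from the benchmark output above
BATCH_SIZE = 128
qps = {
    "CUDA Eager": 8530.72,
    "TRT FP32 Eager": 16098.45,
    "TRT FP16 Eager": 29365.31,
}

baseline = qps["CUDA Eager"]

# Per-iteration latency follows from qps = batch_size / time_sec
latency_ms = {name: BATCH_SIZE / q * 1.0e3 for name, q in qps.items()}

# Speedup is the throughput ratio against the eager baseline
speedup = {name: q / baseline for name, q in qps.items()}

for name in qps:
    print(f"{name}: {latency_ms[name]:.2f} ms/iter, {speedup[name]:.2f}x")
```

The recomputed latencies match the printed 15.00 ms, 7.95 ms, and 4.36 ms, and the ratios come out at roughly 1.89x for FP32 and 3.44x for FP16 on this particular run.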
