This notebook illustrates how to:
1. Compile a `transformers` model to ONNX format (the method generalizes to custom models)
1. Optimize the ONNX model into a TensorRT engine (which pairs well with Triton Inference Server)

How to convert a model to ONNX

<!-- 1. If supported by `optimum`, use it!
1. If a complex torch module, follow best-practice example in `optimum` to compile. -->

```.mermaid
graph LR
    F{Is complex model?}
    F -- yes --> A
    F -- no --> G[use torch.onnx.export]
    A{Is complex model NOT supported by optimum?}
    A -- yes --> C[Find most similar model in optimum] --> D[Follow best-practice example]
    A -- no --> E[Use optimum]
```

(An alternative to `optimum`'s export function is to call `torch.onnx.export` directly, but you may then run into gotchas that `optimum` already handles for you.)

Refer to the [optimum contribution guide](https://github.com/huggingface/optimum/blob/d21256c2964945fc3fe4623f7befb21082b69a25/docs/source/exporters/onnx/usage_guides/contribute.mdx#L56).

## Definitions & Imports

In [2]:
from pathlib import Path

import numpy as np
import onnxruntime as ort
import torch
from optimum.exporters import TasksManager
from optimum.exporters.onnx import export
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from src.config import models_dir
from src.utils_compile import (
    DefaultConfigInput,
    check_model_name,
    generate_config_pbtxt,
)

In [3]:
# Hugging Face checkpoint, plus the base name shared by every model-repository entry below
HF_MODEL_NAME = "ProsusAI/finbert"
E2E_MODEL_NAME = check_model_name("finbert")

# ONNX model: <models_dir>/<name>-model/<version>/model.onnx, config.pbtxt beside the version dir
MODEL_NAME = f"{E2E_MODEL_NAME}-model"
MODEL_VERSION = 1
MODEL_DIR = models_dir / MODEL_NAME
MODEL_VERSION_DIR = MODEL_DIR / str(MODEL_VERSION)
PTH_MODEL_ONNX = MODEL_VERSION_DIR / "model.onnx"
PTH_CONFIG = MODEL_DIR / "config.pbtxt"

# TensorRT engine: <models_dir>/<name>-trt-model/<version>/model.plan
MODEL_NAME_TRT = f"{E2E_MODEL_NAME}-trt-model"
MODEL_VERSION_TRT = 1
MODEL_TRT_DIR = models_dir / MODEL_NAME_TRT
MODEL_VERSION_TRT_DIR = MODEL_TRT_DIR / str(MODEL_VERSION_TRT)
PTH_MODEL_TRT = MODEL_VERSION_TRT_DIR / "model.plan"

# Tokenizer packaged as its own "model": <models_dir>/<name>-tokenizer/<version>/tokenizer_data
TOKENIZER_MODEL_DIR = models_dir / f"{E2E_MODEL_NAME}-tokenizer"
TOKENIZER_MODEL_VERSION = 1
TOKENIZER_MODEL_VERSION_DIR = TOKENIZER_MODEL_DIR / str(TOKENIZER_MODEL_VERSION)
PTH_TOKENIZER_MODEL_DATA = TOKENIZER_MODEL_VERSION_DIR / "tokenizer_data"

# End-to-end pipeline (an "ensemble" model in Nvidia terminology): <models_dir>/<name>/config.pbtxt
E2E_MODEL_DIR = models_dir / E2E_MODEL_NAME
E2E_PTH_CONFIG = E2E_MODEL_DIR / "config.pbtxt"

# Create every directory the later cells write into; safe to re-run.
for _repo_dir in (MODEL_VERSION_DIR, MODEL_VERSION_TRT_DIR, TOKENIZER_MODEL_VERSION_DIR, E2E_MODEL_DIR):
    _repo_dir.mkdir(parents=True, exist_ok=True)

# Opset handed to `export(...)` below (alternative: `onnx_config.DEFAULT_ONNX_OPSET`).
# 20 is the highest opset `torch.onnx.export()` supports for torch==2.5.1.
onnx_opset = 20

## Export Torch model to ONNX

(and save tokenizer as "model" too)

In [5]:
model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

## Save tokenizer as a "model"

tokenizer.save_pretrained(PTH_TOKENIZER_MODEL_DATA)


## Compile the ONNX model

onnx_config_constructor = TasksManager.get_exporter_config_constructor(
    "onnx",
    model,
    task="text-classification",  # NOTE: change to others where applicable e.g. "summarization" (Refer to: `TasksManager.get_all_tasks()`)
    library_name="transformers",  # NOTE: change to others where applicable e.g. "sentence_transformers"
)
onnx_config = onnx_config_constructor(model.config)

# Force int32 ids: none of {input_ids, attention_mask, token_type_ids} is expected to exceed 2^31.
onnx_config.int_dtype = "int32"

# Model inputs, in the positional order the PyTorch forward call below uses.
ONNX_INPUT_NAMES = ("input_ids", "attention_mask", "token_type_ids")

inputs = tokenizer("Stocks rallied and the British pound gained.", return_tensors="pt")
with torch.no_grad():
    outputs_orig = model(*(inputs[name] for name in ONNX_INPUT_NAMES))


onnx_inputs, onnx_outputs = export(model, onnx_config, PTH_MODEL_ONNX, opset=onnx_opset)


print("running inference via ONNX runtime")


ort_session = ort.InferenceSession(str(PTH_MODEL_ONNX))
# Cast to int32 to match the `int_dtype` the graph was exported with.
outputs = ort_session.run(
    None,
    {name: inputs[name].numpy().astype(np.int32) for name in ONNX_INPUT_NAMES},
)
# Parity check on the ABSOLUTE difference: a signed `(a - b) < tol` test would
# accept arbitrarily large negative deviations.
np.testing.assert_allclose(outputs[0], outputs_orig.logits.numpy(), rtol=0, atol=1e-3)

# print(f"logits: {outputs_orig.logits}")
# print(f"probability (pos, neg, neutral): {torch.softmax(outputs_orig.logits, dim=1)}")

Using framework PyTorch: 2.5.1+cu124
Overriding 1 configuration item(s)
	- use_cache -> False


running inference via ONNX runtime


## [WIP] Auto-generate `config.pbtxt` files

In [None]:
# ## Generate config.pbtxt
# TODO Automation still work in progress
# _ = generate_config_pbtxt(
#     DefaultConfigInput(
#         model_name=MODEL_NAME,
#         max_batch_size=16,  # MUST be >= preferred_batch_size
#         preferred_batch_size=[8, 16],
#         max_queue_delay_microseconds=100,
#     ),
#     pth_config=PTH_CONFIG,
# )

## Compile ONNX to TensorRT

In [6]:
import tensorrt as trt

# TensorRT logger that reports messages of WARNING severity and above
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Bounds for the optimization profile over the dynamic (batch, sequence) dims
MAX_BATCH_SIZE = 16  # largest batch the engine will accept
MIN_SEQUENCE_LENGTH = 8  # shortest sequence the engine will accept
OPT_SEQUENCE_LENGTH = 128  # the "opt" sequence length in the profile (typical input)
MAX_SEQUENCE_LENGTH = 512  # longest sequence the engine will accept

# Builder settings
FP16_MODE = True  # request FP16 precision (if the hardware supports it)
INT8_MODE = False  # INT8 needs calibration data, so it stays off here
WORKSPACE_SIZE = 1024**3  # builder workspace limit: 1 GiB

In [7]:
## Builds a TensorRT engine from an ONNX model with dynamic shape support

# `create_network` takes a BITMASK of creation flags, so an enum member must be
# shifted into place (`1 << int(flag)`); passing the enum itself uses its raw
# integer value as the mask and selects a different bit.
# Reduced-precision builder flags (FP16/INT8) only apply to weakly typed
# networks, so strong typing is requested only when neither is enabled.
# NOTE(review): confirm against the installed TensorRT version's docs.
use_reduced_precision = FP16_MODE or INT8_MODE
network_flags = 0 if use_reduced_precision else 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)

with (
    trt.Builder(TRT_LOGGER) as builder,
    builder.create_network(network_flags) as network,
    trt.OnnxParser(network, TRT_LOGGER) as parser,
):

    # Configure builder
    config = builder.create_builder_config()
    try:
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, WORKSPACE_SIZE)
    except AttributeError:  # older TensorRT without memory-pool API
        config.max_workspace_size = WORKSPACE_SIZE

    if FP16_MODE:
        config.set_flag(trt.BuilderFlag.FP16)
    if INT8_MODE:
        config.set_flag(trt.BuilderFlag.INT8)

    # Parse ONNX model. The handle is named `f_onnx` (not `model`) so the
    # PyTorch `model` from the export cell is not overwritten. Abort on failure:
    # building from a partially parsed network is meaningless.
    with PTH_MODEL_ONNX.open("rb") as f_onnx:
        if not parser.parse(f_onnx.read()):
            parse_errors = "\n".join(str(parser.get_error(i)) for i in range(parser.num_errors))
            raise RuntimeError(f"Failed to parse the ONNX file:\n{parse_errors}")

    # Handle dynamic shapes with an optimization profile
    profile = builder.create_optimization_profile()
    dynamic_inputs = ["input_ids", "attention_mask", "token_type_ids"]

    # Look inputs up by NAME: the graph's input order need not match `dynamic_inputs`.
    network_inputs = {network.get_input(i).name: network.get_input(i) for i in range(network.num_inputs)}

    for input_name in dynamic_inputs:
        if input_name not in network_inputs:
            raise KeyError(f"Input {input_name} not found in network inputs {list(network_inputs)}")
        tensor_shape = network_inputs[input_name].shape

        # Handle dynamic dimensions (assumes rank-2 (batch, sequence) inputs)
        if tensor_shape[0] == -1:
            profile.set_shape(
                input_name,
                (1, MIN_SEQUENCE_LENGTH),  # Min: Batch=1, Min Seq Len
                (MAX_BATCH_SIZE // 2, OPT_SEQUENCE_LENGTH),  # Opt: Half Batch, Opt Seq Len
                (MAX_BATCH_SIZE, MAX_SEQUENCE_LENGTH),  # Max: Full Batch, Max Seq Len
            )
            print(
                f"Set profile for {input_name}: Min=(1, {MIN_SEQUENCE_LENGTH}), "
                f"Opt=({MAX_BATCH_SIZE // 2}, {OPT_SEQUENCE_LENGTH}), "
                f"Max=({MAX_BATCH_SIZE}, {MAX_SEQUENCE_LENGTH})"
            )
        else:
            print(f"Warning: Input {input_name} does not have dynamic dimensions.")

    config.add_optimization_profile(profile)

    # Build the serialized engine; stop here if it failed, since there is
    # nothing to save or deserialize.
    print("Building TensorRT serialized engine. This may take a while...")
    serialized_engine = builder.build_serialized_network(network, config)
    if not serialized_engine:
        raise RuntimeError("Failed to build the serialized engine.")
    print("Serialized engine built successfully!")
    with open(PTH_MODEL_TRT, "wb") as f:
        f.write(serialized_engine)
    print(f"Engine saved at {PTH_MODEL_TRT}")

    # Deserialize engine for verification (Optional)
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    if engine:
        print("Engine deserialized successfully!")
    else:
        print("Failed to deserialize engine.")

Set profile for input_ids: Min=(1, 8), Opt=(8, 128), Max=(16, 512)
Set profile for attention_mask: Min=(1, 8), Opt=(8, 128), Max=(16, 512)
Set profile for token_type_ids: Min=(1, 8), Opt=(8, 128), Max=(16, 512)
Building TensorRT serialized engine. This may take a while...
Serialized engine built successfully!
Engine saved at /home/ss/work/test-triton/models/finbert-trt-model/1/model.plan
Engine deserialized successfully!
