# LightlyTrain - Semantic Segmentation - ONNX and TensorRT Export

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lightly-ai/lightly-train/blob/main/examples/notebooks/semantic_segmentation_export.ipynb)

This notebook demonstrates how to export a semantic segmentation model to ONNX and TensorRT.

The notebook covers the following steps:
1. Install LightlyTrain
2. Export a trained EoMT model to ONNX
3. Run inference with the ONNX model
4. Export a trained EoMT model to TensorRT
5. Run inference with the TensorRT engine

> **Important**: When running on Google Colab, make sure to select a GPU runtime. To do so, go to `Runtime` > `Change runtime type` and select a GPU hardware accelerator.

## Installation

LightlyTrain can be installed directly via `pip`:

In [None]:
!pip install "lightly-train[onnx,onnxruntime,onnxslim]"

## Export to ONNX

### Load the model weights

Load the model with LightlyTrain's `load_model` function. It downloads the model weights automatically and loads them into the model.

In [None]:
import lightly_train

model = lightly_train.load_model("dinov3/vits16-eomt-coco")

### Download an example image

Download an example image for inference with the following command:

In [None]:
!wget -O image.jpg http://images.cocodataset.org/val2017/000000039769.jpg

### Preprocessing

In [None]:
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

# Load image with PIL.
image_pil = Image.open("image.jpg").convert("RGB")

# Convert PIL image to tensor for plotting.
image_tensor = pil_to_tensor(image_pil)

# Store the original image size to resize the predictions back later.
w, h = image_pil.size

# Define pre-processing transforms.
transforms = T.Compose(
    [
        T.Resize((model.image_size)),
        T.ToTensor(),
        T.Normalize(**model.image_normalize),
    ]
)

# Apply transforms for ONNX and TensorRT inference.
image_tensor_transformed = transforms(image_pil)[None]
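
The transform pipeline above resizes the image, scales pixel values to `[0, 1]`, and normalizes each channel. As a rough illustration of the last two steps, the NumPy sketch below applies the same arithmetic by hand. The `mean` and `std` values are the commonly used ImageNet statistics and are an assumption here; the notebook reads the actual values from `model.image_normalize`. The `to_normalized_chw` helper is hypothetical.

```python
import numpy as np

# Assumed ImageNet statistics, for illustration only. The notebook reads the
# real values from `model.image_normalize`.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(image_hwc_uint8: np.ndarray) -> np.ndarray:
    """Scale a uint8 HWC image to [0, 1], normalize per channel, return CHW."""
    image = image_hwc_uint8.astype(np.float32) / 255.0
    image = (image - mean) / std  # Broadcasts over the trailing channel axis.
    return image.transpose(2, 0, 1)


# A dummy 4x6 mid-gray RGB image stands in for the real photo.
dummy = np.full((4, 6, 3), 128, dtype=np.uint8)
batch = to_normalized_chw(dummy)[None]  # Add the batch dimension.
print(batch.shape, batch.dtype)
```

The leading batch dimension mirrors the `[None]` indexing used on the transformed tensor above.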

### Get the model predictions for reference

We define a helper function to visualize the predictions.
The function will be used to compare the predictions from PyTorch, ONNX and
TensorRT models.

In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_segmentation_masks


def visualize_segmentations(image, masks):
    masks = torch.stack([masks == class_id for class_id in masks.unique()])
    image_with_masks = draw_segmentation_masks(image, masks, alpha=1.0)
    fig, axs = plt.subplots(1, 2, figsize=(12, 8))
    axs[0].imshow(image.permute(1, 2, 0))
    axs[0].axis("off")
    axs[1].imshow(image_with_masks.permute(1, 2, 0))
    axs[1].axis("off")
    fig.show()
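
The first line of `visualize_segmentations` turns the class-ID map into one boolean mask per class present in the image, which is the input format `draw_segmentation_masks` expects. A small NumPy sketch of that conversion, using a made-up 2x3 class map:

```python
import numpy as np

# Made-up class-ID map containing three classes (0, 2, and 5).
class_map = np.array([[0, 0, 2], [5, 2, 2]])

# One boolean mask per class that actually occurs, stacked along axis 0.
class_ids = np.unique(class_map)
bool_masks = np.stack([class_map == class_id for class_id in class_ids])

print(class_ids)         # Sorted class IDs found in the map.
print(bool_masks.shape)  # (num_classes_present, height, width)
```

Every pixel belongs to exactly one of the stacked masks, so each class gets its own color in the overlay.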

In [None]:
# Get predictions from the PyTorch model.
masks = model.predict(image_tensor)

# Visualize predictions from the PyTorch model.
visualize_segmentations(image_tensor, masks=masks)

### Export the model to ONNX

In [None]:
# Export the PyTorch model to ONNX.
model.export_onnx(
    out="model.onnx",
    # precision="fp16", # Export model with FP16 weights for smaller size and faster inference.
)

See [`export_onnx`](https://docs.lightly.ai/train/stable/python_api/lightly_train.html#lightly_train._task_models.dinov3_eomt_semantic_segmentation.task_model.DINOv3EoMTSemanticSegmentation.export_onnx) for all available options when exporting to ONNX.


### Run inference with the ONNX model

In [None]:
import onnxruntime as ort
import torch.nn.functional as F

# Create an ONNX Runtime session.
sess = ort.InferenceSession("model.onnx")

# Get expected input dtype.
input_dtype = sess.get_inputs()[0].type
input_dtype_numpy = {
    "tensor(float)": "float32",
    "tensor(float16)": "float16",
}[input_dtype]


# Run inference.
masks_onnx = sess.run(
    output_names=None,
    input_feed={
        "images": image_tensor_transformed.numpy().astype(input_dtype_numpy),
    },
)[0][0]


# Resize ONNX predictions to original image size if necessary.
# This is only needed if the original image size is different from the model input size.
masks_onnx = torch.from_numpy(masks_onnx)
masks_onnx = (
    F.interpolate(masks_onnx.float()[None, None, ...], size=(h, w), mode="nearest")
    .squeeze(0, 1)
    .long()
)

# Visualize predictions from the ONNX model.
visualize_segmentations(image_tensor, masks=masks_onnx)
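
The resize above uses `mode="nearest"` because the masks store integer class IDs, and blending neighboring IDs would produce labels that were never predicted. The NumPy sketch below illustrates the idea with a hypothetical `resize_nearest` helper; its index computation is a simplification and may differ from `F.interpolate` at individual pixels.

```python
import numpy as np


def resize_nearest(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Hypothetical helper: nearest-neighbor resize of a 2D class-ID map."""
    in_h, in_w = mask.shape
    # Map every output row/column back to a source row/column.
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return mask[rows[:, None], cols[None, :]]


mask = np.array([[0, 7], [7, 3]])
resized = resize_nearest(mask, 4, 6)

# Nearest-neighbor resizing only ever copies existing labels.
print(np.unique(resized))

# Averaging two neighbors, by contrast, can invent a class ID.
print((mask[0, 0] + mask[0, 1]) / 2)
```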

**Note**: There might be small visual differences between the masks predicted by
the PyTorch model with `model.predict` and those from the ONNX model. This is because
the PyTorch model can process the input image at a higher resolution: instead of
resizing the input image to a fixed size (e.g., 512x512), it resizes the shorter side
of the image to a fixed size (e.g., 512 pixels) while keeping the aspect ratio. The
ONNX model, on the other hand, requires a fixed input size (e.g., 512x512).
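
To make the difference concrete, the sketch below computes both input shapes for a 640x480 image. The `shorter_side_resize_shape` helper and its rounding are assumptions for illustration, not LightlyTrain's implementation.

```python
def shorter_side_resize_shape(height: int, width: int, target: int) -> tuple[int, int]:
    """Hypothetical helper: scale so the shorter side equals `target`."""
    scale = target / min(height, width)
    return round(height * scale), round(width * scale)


height, width = 480, 640  # Example image size as (height, width).

fixed_shape = (512, 512)  # Fixed-size input, aspect ratio not preserved.
aspect_shape = shorter_side_resize_shape(height, width, target=512)

print(fixed_shape, aspect_shape)

# Ratio of pixels seen by the aspect-preserving resize vs. the fixed resize.
print(aspect_shape[0] * aspect_shape[1] / (fixed_shape[0] * fixed_shape[1]))
```

Under these assumptions, the aspect-preserving resize keeps roughly a third more pixels for a 4:3 image, which is one reason the two sets of masks can differ along object boundaries.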

## Export to TensorRT

### Requirements

TensorRT is not part of LightlyTrain's dependencies and must be installed separately. Installation depends on your OS, Python
version, GPU, and NVIDIA driver/CUDA setup. See the [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html) for more details.

On CUDA 12.x systems you can often install the Python package via:

In [None]:
!pip install tensorrt-cu12

### Export the model to TensorRT

In [None]:
# Export the PyTorch model to a TensorRT engine.
model.export_tensorrt(
    out="model.trt",
    # precision="fp16", # Export model with FP16 weights for smaller size and faster inference.
)

See [`export_tensorrt`](https://docs.lightly.ai/train/stable/python_api/lightly_train.html#lightly_train._task_models.dinov3_eomt_semantic_segmentation.task_model.DINOv3EoMTSemanticSegmentation.export_tensorrt) for all available options when exporting to TensorRT.


### Run inference with the TensorRT engine

In [None]:
import numpy as np
import tensorrt as trt
import torch


class TRT:
    def __init__(self, engine_path: str, device: str = "cuda:0", verbose: bool = False):
        self.device = torch.device(device)
        logger = trt.Logger(trt.Logger.VERBOSE if verbose else trt.Logger.INFO)
        trt.init_libnvinfer_plugins(logger, "")
        runtime = trt.Runtime(logger)

        with open(engine_path, "rb") as f:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()

        io_names = [
            self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)
        ]
        self.in_names = [
            n
            for n in io_names
            if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT
        ]
        self.out_names = [
            n
            for n in io_names
            if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
        ]

        self.buffers = {}
        self.bindings = []
        for name in io_names:
            shape = tuple(self.context.get_tensor_shape(name))
            np_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            torch_dtype = torch.from_numpy(np.empty((), dtype=np_dtype)).dtype
            buffer = torch.empty(
                shape, device=self.device, dtype=torch_dtype
            ).contiguous()
            self.buffers[name] = buffer
            self.bindings.append(buffer.data_ptr())

    @torch.no_grad()
    def __call__(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        for name in self.in_names:
            self.buffers[name].copy_(inputs[name].to(self.device))
        if not self.context.execute_v2(self.bindings):
            raise RuntimeError("TensorRT execution failed")
        return {name: self.buffers[name] for name in self.out_names}

In [None]:
# Instantiate the TensorRT model.
trt_model = TRT("model.trt")

# Run inference with the TensorRT model.
outputs_trt = trt_model({"images": image_tensor_transformed})
masks_trt = outputs_trt["masks"][0]

# Resize TensorRT predictions to original image size if necessary.
masks_trt = (
    F.interpolate(masks_trt.float()[None, None, ...], size=(h, w), mode="nearest")
    .squeeze(0, 1)
    .long()
)

# Visualize predictions from the TensorRT model.
visualize_segmentations(image_tensor, masks=masks_trt.cpu())
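
Beyond visual inspection, the masks from the different backends can be compared numerically. The sketch below defines a hypothetical `pixel_agreement` helper (not part of LightlyTrain) and runs it on toy arrays. In this notebook it could be applied to, for example, `masks.cpu().numpy()` and `masks_trt.cpu().numpy()`, assuming both masks have the same height and width.

```python
import numpy as np


def pixel_agreement(mask_a: np.ndarray, mask_b: np.ndarray) -> float:
    """Hypothetical helper: fraction of pixels with identical class IDs."""
    if mask_a.shape != mask_b.shape:
        raise ValueError(f"Shape mismatch: {mask_a.shape} vs {mask_b.shape}")
    return float((mask_a == mask_b).mean())


# Toy class-ID maps that differ in one of six pixels.
reference = np.array([[0, 0, 1], [2, 2, 1]])
candidate = np.array([[0, 0, 1], [2, 1, 1]])

print(f"{pixel_agreement(reference, candidate):.3f}")
```

A value slightly below 1.0 is expected when comparing against `model.predict`, since the exported models run on a fixed-size input.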