# Custom Layer Usage Example - multiclass_nms Layer

## Overview

In this tutorial we illustrate how to combine a custom layer with model quantization using the Model Compression Toolkit (MCT) library.
Using a simple object detection model as an example, we will apply post-training quantization (PTQ), then attach a custom non-maximum suppression (NMS) layer to the quantized model.

## Setup

### Install & import relevant packages

In [78]:
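!pip install -q torch
!pip install -q onnx
!pip install -q model_compression_toolkit
# Provides the sony_custom_layers package imported below; skip if it is already installed.
!pip install -q sony-custom-layers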

In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Iterator, List
import model_compression_toolkit as mct
from sony_custom_layers.pytorch.nms import multiclass_nms
from sony_custom_layers.pytorch.nms.nms_with_indices import multiclass_nms_with_indices

## Model Quantization

### Create Model Instance

We start by creating a simple object-detection model instance as an example. You can replace it with your own model or with a pre-trained one; just make sure the model is supported by MCT.

In [80]:
class ObjectDetector(nn.Module):
    def __init__(self, num_classes=2, max_detections=20):
        super().__init__()
        self.max_detections = max_detections

        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.bbox_reg = nn.Conv2d(32, 4 * max_detections, kernel_size=1)
        self.class_reg = nn.Conv2d(32, num_classes * max_detections, kernel_size=1)

    def forward(self, x):
        batch_size = x.size(0)
        features = self.backbone(x)
        H_prime = features.shape[2]
        W_prime = features.shape[3]
        
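        # Box head: (B, 4 * max_detections, H', W') -> (B, max_detections, 4), averaged over spatial positions
        bbox = self.bbox_reg(features)
        bbox = bbox.view(batch_size, self.max_detections, 4, H_prime * W_prime).mean(dim=3)
        # Class head: average over spatial positions (dim=3), then softmax over the class axis (dim=2),
        # giving scores of shape (B, max_detections, num_classes)
        class_probs = self.class_reg(features).view(batch_size, self.max_detections, -1, H_prime * W_prime)
        class_probs = F.softmax(class_probs.mean(dim=3), dim=2)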

        return bbox, class_probs

model = ObjectDetector()
model.eval()

ObjectDetector(
  (backbone): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (bbox_reg): Conv2d(32, 80, kernel_size=(1, 1), stride=(1, 1))
  (class_reg): Conv2d(32, 40, kernel_size=(1, 1), stride=(1, 1))
)
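As a rough check of the feature-map size that `forward` works with, the sketch below applies the standard output-size formula for convolution and pooling layers, `floor((n + 2p - k) / s) + 1`, to the backbone printed above. The helper name `out_size` is made up for this illustration, and the 64x64 input size is an assumption that matches the random calibration images used later in this tutorial.

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


# Backbone layers that can change the spatial size: (kernel, stride, padding).
backbone_layers = [
    (3, 1, 1),  # Conv2d(3, 16, kernel_size=3, padding=1)
    (2, 2, 0),  # MaxPool2d(2, 2)
    (3, 1, 1),  # Conv2d(16, 32, kernel_size=3, padding=1)
    (2, 2, 0),  # MaxPool2d(2, 2)
]

size = 64  # assumed square input: H = W = 64
for kernel, stride, padding in backbone_layers:
    size = out_size(size, kernel, stride, padding)

h_prime = w_prime = size
print(h_prime, w_prime)   # 16 16
print(h_prime * w_prime)  # 256 spatial positions per feature map
```

Under these assumptions each head output is reshaped over 256 spatial positions, and the box head reduces to a tensor of shape `(batch, 20, 4)` for the default `max_detections=20`.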

### Post-Training Quantization using Model Compression Toolkit

We're all set to use MCT's post-training quantization. 
To begin, we'll define a representative dataset generator. Please note that for demonstration purposes, we will generate random data of the desired image shape instead of using real images.  
Then, we will apply PTQ on our model using the dataset generator we have created.

In [81]:
NUM_ITERS = 20
BATCH_SIZE = 32

def get_representative_dataset(n_iter: int):
    """
    Create a representative dataset generator. The generator yields lists holding a single
        torch tensor batch of shape: [Batch, C, H, W].
    Args:
        n_iter: number of iterations for MCT to calibrate on
    Returns:
        A representative dataset generator
    """
    def representative_dataset() -> Iterator[List]:
        for _ in range(n_iter):
            yield [torch.rand(BATCH_SIZE, 3, 64, 64)]

    return representative_dataset

representative_data_generator = get_representative_dataset(n_iter=NUM_ITERS)

quant_model, _ = mct.ptq.pytorch_post_training_quantization(model, representative_data_gen=representative_data_generator)
print('Quantized model is ready')

Statistics Collection: 20it [00:02,  6.80it/s]



Running quantization parameters search. This process might take some time, depending on the model size and the selected quantization methods.



Calculating quantization parameters: 100%|██████████| 14/14 [00:00<00:00, 50.77it/s]

Weights_memory: 8880.0, Activation_memory: 65536.0, Total_memory: 74416.0, BOPS: 569083166720

Please run your accuracy evaluation on the exported quantized model to verify it's accuracy.
Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:
FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md
Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md
Quantized model is ready
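One plausible reading of the `Weights_memory` and `Activation_memory` figures in the log above is sketched here. It counts the convolution kernel weights of the model and the largest intermediate feature map, assuming one byte (8 bits) per value, a single 64x64 input image, and that biases are not counted. These are assumptions made for illustration, not a description of how MCT computes its report.

```python
# (in_channels, out_channels, kernel_height, kernel_width) of each Conv2d in ObjectDetector
conv_layers = {
    "backbone.0": (3, 16, 3, 3),
    "backbone.3": (16, 32, 3, 3),
    "bbox_reg": (32, 80, 1, 1),
    "class_reg": (32, 40, 1, 1),
}


def kernel_params(shape):
    """Number of kernel weights in a convolution (bias excluded)."""
    c_in, c_out, k_h, k_w = shape
    return c_in * c_out * k_h * k_w


BYTES_PER_VALUE = 1  # assumption: 8-bit weights and activations

weights_bytes = sum(kernel_params(s) for s in conv_layers.values()) * BYTES_PER_VALUE
print(weights_bytes)  # 8880

# Largest feature map for one 64x64 image: output of the first conv, 16 channels at 64x64.
activation_bytes = 16 * 64 * 64 * BYTES_PER_VALUE
print(activation_bytes)  # 65536
```

Both numbers coincide with the values printed in the log, which suggests the reading is reasonable for this small model.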





## Custom Layer Stitching

Now that we have a quantized model, we can attach a custom layer to it. In this example we add an NMS layer by creating a model wrapper that applies NMS to the quantized model's outputs. You can reuse this wrapper for your own model.

In [82]:
class PostProcessWrapper(nn.Module):
    def __init__(self,
                 model: nn.Module,
                 score_threshold: float = 0.001,
                 iou_threshold: float = 0.7,
                 max_detections: int = 300):

        super(PostProcessWrapper, self).__init__()
        self.model = model
        self.score_threshold = score_threshold
        self.iou_threshold = iou_threshold
        self.max_detections = max_detections

    def forward(self, images):
        # model inference
        outputs = self.model(images)

        boxes = outputs[0]
        scores = outputs[1]
        nms = multiclass_nms(boxes=boxes, scores=scores, score_threshold=self.score_threshold,
                             iou_threshold=self.iou_threshold, max_detections=self.max_detections)
        # If you need NMS with indices, replace the call above with:
        #     nms = multiclass_nms_with_indices(boxes=boxes, scores=scores, score_threshold=self.score_threshold,
        #                                       iou_threshold=self.iou_threshold, max_detections=self.max_detections)
        return nms

device = "cuda" if torch.cuda.is_available() else "cpu"
quant_model_with_nms = PostProcessWrapper(model=quant_model,
                                          score_threshold=0.001,
                                          iou_threshold=0.7,
                                          max_detections=300).to(device=device)
print('Quantized model with NMS is ready')

Quantized model with NMS is ready
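For intuition about what `score_threshold`, `iou_threshold`, and `max_detections` control, here is a minimal pure-Python sketch of greedy per-class NMS for a single image. It is not the `sony_custom_layers` implementation and does not reproduce its batched tensor interface or output format. The function names, the `(x_min, y_min, x_max, y_max)` box layout, and the toy data are all assumptions made for this illustration.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # assumed layout: (x_min, y_min, x_max, y_max)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_multiclass_nms(boxes: Sequence[Box],
                          scores: Sequence[Sequence[float]],
                          score_threshold: float,
                          iou_threshold: float,
                          max_detections: int) -> List[Tuple[int, int, float]]:
    """Return (box_index, class_index, score) triples kept by per-class greedy NMS."""
    # 1. Keep only (box, class) candidates whose score exceeds the threshold.
    candidates = [(score, box_idx, cls)
                  for box_idx, row in enumerate(scores)
                  for cls, score in enumerate(row)
                  if score > score_threshold]
    # 2. Visit candidates from the highest score to the lowest.
    candidates.sort(key=lambda item: -item[0])
    kept: List[Tuple[int, int, float]] = []
    for score, box_idx, cls in candidates:
        # 3. Stop once the detection budget is used up.
        if len(kept) >= max_detections:
            break
        # 4. Suppress a candidate that overlaps a kept box of the same class too much.
        if all(kept_cls != cls or iou(boxes[box_idx], boxes[kept_idx]) <= iou_threshold
               for kept_idx, kept_cls, _ in kept):
            kept.append((box_idx, cls, score))
    return kept


# Toy data: boxes 0 and 1 overlap heavily (IoU = 81/119, about 0.68); box 2 is far away.
boxes = [(0.0, 0.0, 10.0, 10.0), (1.0, 1.0, 11.0, 11.0), (20.0, 20.0, 30.0, 30.0)]
scores = [[0.9, 0.05], [0.8, 0.1], [0.0005, 0.7]]  # one row per box, one column per class

kept = greedy_multiclass_nms(boxes, scores, score_threshold=0.001,
                             iou_threshold=0.5, max_detections=300)
print(kept)  # [(0, 0, 0.9), (2, 1, 0.7), (1, 1, 0.1)]
```

In this toy run, box 1 is suppressed for class 0 because it overlaps the higher-scoring box 0, while box 2 is dropped for class 0 by the score threshold. Raising `iou_threshold` to 0.7, the value used in the wrapper above, would keep both overlapping boxes, since their IoU of about 0.68 would no longer exceed the threshold.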


### Model Export

Finally, we can export the quantized model with the NMS layer to an ONNX file. Please make sure `save_model_path` points to the location where you want the file to be written.

In [83]:
mct.exporter.pytorch_export_model(model=quant_model_with_nms,
                                  save_model_path='./qmodel_with_nms.onnx',
                                  repr_dataset=representative_data_generator)

Exporting onnx model with MCTQ quantizers: ./qmodel_with_nms.onnx


Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.