# Model Compression Toolkit (MCT) Wrapper API (PyTorch)

[Run this tutorial in Google Colab](https://colab.research.google.com/github/SonySemiconductorSolutions/mct-model-optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_mct_wrapper.ipynb)

## Overview 
In this notebook, we explain in detail how to use the MCTWrapper class from the Model Compression Toolkit (MCT).
MCTWrapper runs every quantization method through the same `quantize_and_export` call, which makes it easy to compare methods side by side.
Taking MobileNetV2 as an example, we use MCTWrapper to apply the following quantization techniques:
PTQ (Post-Training Quantization), PTQ with Mixed Precision, GPTQ (Gradient-based PTQ), and GPTQ with Mixed Precision.
Working through these methods shows the convenience and flexibility of MCTWrapper and
helps you select the most suitable quantization approach for your application.

## Summary
- **Setup**: Install and import the required libraries, and load a pre-trained MobileNetV2 model
- **Dataset Preparation**: Download the ImageNet validation set and define the representative dataset generator
- **Model Quantization using MCTWrapper**: Quantize the float model using MCTWrapper with four methods
  - **PTQ**: Apply standard post-training quantization
  - **PTQ + Mixed Precision**: On top of PTQ, select a quantization bit-width for each layer
  - **GPTQ**: Refine the quantized weights with gradient-based optimization on the representative dataset
  - **GPTQ + Mixed Precision**: On top of GPTQ, select a quantization bit-width for each layer
- **Evaluation**: Compare the Top-1 accuracy of the float model and all quantized models

## Setup

In [None]:
!pip install -q onnx==1.17.0
!pip install -q torch==2.6.0 torchvision==0.21.0
!pip install -q tqdm

In [None]:
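import importlib.util  # importing the `util` submodule explicitly makes find_spec available

# Install MCT only if it is not already present
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit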

In [None]:
from typing import Tuple
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision.datasets import ImageNet
import model_compression_toolkit as mct

Load a pre-trained MobileNetV2 model from torchvision in 32-bit floating-point precision.

In [None]:
weights = MobileNet_V2_Weights.IMAGENET1K_V2

float_model = mobilenet_v2(weights=weights)

## Dataset Preparation
### Download ImageNet validation set
Download the ImageNet dataset (validation split only).

This step may take several minutes...

**Note:** For demonstration purposes, we use the validation set for the model quantization routines. Usually, a subset of the training dataset is used, but loading it is a heavy procedure that is unnecessary for the sake of this demonstration.

In [None]:
import os

if not os.path.isdir('imagenet'):
    !mkdir imagenet
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
    !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

Extract and load the ImageNet validation set using the torchvision `datasets` module, applying the preprocessing that matches the pre-trained weights.

In [None]:
dataset = ImageNet(root='./imagenet', split='val', transform=weights.transforms())
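
`weights.transforms()` applies the preprocessing that torchvision documents for these weights: resize the shorter edge to 232 pixels, center-crop to 224x224, rescale pixel values to [0, 1], and normalize with the ImageNet per-channel mean and standard deviation. The stdlib-only sketch below mirrors just the geometry and the per-pixel arithmetic so the numbers are easy to inspect. The helper names are hypothetical, and the constants are assumed from the torchvision documentation rather than read from `weights` at runtime.

```python
from typing import List, Tuple

# Constants assumed from the torchvision docs for MobileNet_V2_Weights.IMAGENET1K_V2.
RESIZE_SIZE = 232
CROP_SIZE = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def resized_size(width: int, height: int, short_side: int = RESIZE_SIZE) -> Tuple[int, int]:
    """(width, height) after scaling so the shorter edge equals short_side, keeping aspect ratio."""
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side

def center_crop_box(width: int, height: int, crop: int = CROP_SIZE) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop

def normalize_pixel(rgb: Tuple[int, int, int]) -> List[float]:
    """Rescale an 8-bit RGB pixel to [0, 1], then apply per-channel (x - mean) / std."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD)]

# Example: a 1280x720 image is resized to 412x232, then cropped to the central 224x224 window.
w, h = resized_size(1280, 720)
print((w, h), center_crop_box(w, h), normalize_pixel((255, 128, 0)))
```

To compare these assumed constants with the real pipeline, print `weights.transforms()`, which lists the resize size, crop size, mean, and standard deviation it uses.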

## Representative Dataset
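For quantization with MCT, we need to define a representative dataset: a generator function that, at each step, yields a list containing one batch of images. Here it yields `n_iter` batches of `batch_size` images drawn at random from the validation set: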

In [None]:
batch_size = 16
n_iter = 10

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def representative_dataset_gen():
    dataloader_iter = iter(dataloader)
    for _ in range(n_iter):
        yield [next(dataloader_iter)[0]]
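
To see the generator protocol in isolation, here is a minimal sketch that replaces the ImageNet DataLoader with a plain Python list of placeholder batches; `make_representative_gen`, `fake_batches`, and `demo_batch_size` are hypothetical names for illustration only. It assumes, as the cell above does, that MCT consumes each yielded item as a list of model inputs (MobileNetV2 has a single input). Each call to the generator function starts a fresh iterator and yields `n_iter` one-element lists, so with 16 images per batch and 10 iterations a full pass yields 160 images.

```python
from typing import Callable, Iterator, Sequence, Tuple

def make_representative_gen(loader: Sequence[Tuple[list, list]], n_iter: int) -> Callable[[], Iterator[list]]:
    """Return a generator function shaped like representative_dataset_gen above.

    `loader` stands in for a DataLoader: any iterable of (images, labels) pairs.
    """
    def gen() -> Iterator[list]:
        loader_iter = iter(loader)  # fresh iterator on every call, like iter(dataloader)
        for _ in range(n_iter):
            images, _labels = next(loader_iter)  # labels are dropped; calibration only needs inputs
            yield [images]  # one list of model inputs per step
    return gen

# Placeholder batches: 12 batches of 16 dummy "images" (integers stand in for tensors).
demo_batch_size = 16
fake_batches = [([b * demo_batch_size + i for i in range(demo_batch_size)], [0] * demo_batch_size)
                for b in range(12)]

rep_gen = make_representative_gen(fake_batches, n_iter=10)
steps = list(rep_gen())
print(len(steps), sum(len(step[0]) for step in steps))
```

Because `next()` is called on a fresh iterator, the loader must provide at least `n_iter` batches; otherwise the `StopIteration` raised inside the generator surfaces as a `RuntimeError` (PEP 479).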

## Model Quantization using MCTWrapper

Below, we define one function per quantization method; each one configures and calls MCTWrapper.

By specifying the SDSP converter version, you select the quantization settings optimized for IMX500.
Here, we use the settings for SDSP Converter 3.14. For the settings available for other versions, see the [target platform capabilities in the MCT repository](https://github.com/SonySemiconductorSolutions/mct-model-optimization/tree/main/model_compression_toolkit/target_platform_capabilities).

**Note:** This tutorial sets the minimum parameters required to run MCTWrapper. For details on omitted parameters, refer to [MCT Documentation](https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/classes/Wrapper.html#ug-wrapper).

**Note:** This tutorial uses parameters chosen for a short run time rather than for accuracy, so the resulting accuracy is lower. To improve accuracy, refer to the other tutorials.

Define a function that runs PTQ on a PyTorch model.

In [None]:
def PTQ_Pytorch(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform PTQ on PyTorch model.
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration
    framework = 'pytorch'               # Target framework (PyTorch)
    method = 'PTQ'                      # Quantization method
    use_mixed_precision = False         # Disable mixed-precision quantization

    # Parameter configuration
    param_items = [
        ['sdsp_version', '3.14'],                           # Version of the SDSP converter
        ['save_model_path', './qmodel_PTQ_Pytorch.onnx']    # Path to save quantized model as ONNX format
    ]

    # Execute quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model

Define a function that runs PTQ + Mixed Precision on a PyTorch model.

In [None]:
def PTQ_Pytorch_mixed_precision(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform PTQ with Mixed Precision on PyTorch model.
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration
    framework = 'pytorch'               # Target framework (PyTorch)
    method = 'PTQ'                      # Quantization method
    use_mixed_precision = True          # Enable mixed-precision quantization

    # Parameter configuration
    param_items = [
        ['sdsp_version', '3.14'],                                         # Version of the SDSP converter
        ['num_of_images', 5],                                             # Number of images for Mixed-Precision calibration
        ['weights_compression_ratio', 0.5],                               # Compression ratio of weights for Mixed-Precision
        ['save_model_path', './qmodel_PTQ_Pytorch_mixed_precision.onnx']  # Path to save quantized model as ONNX format
    ]

    # Execute quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model

Define a function that runs GPTQ on a PyTorch model.

In [None]:
def GPTQ_Pytorch(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform GPTQ on PyTorch model.
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration
    framework = 'pytorch'               # Target framework (PyTorch)
    method = 'GPTQ'                     # Quantization method
    use_mixed_precision = False         # Disable mixed-precision quantization

    # Parameter configuration
    param_items = [
        ['sdsp_version', '3.14'],                          # Version of the SDSP converter
        ['n_epochs', 5],                                   # Number of epochs for GPTQ optimization
        ['save_model_path', './qmodel_GPTQ_Pytorch.onnx']  # Path to save quantized model as ONNX format
    ]

    # Execute quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model

Define a function that runs GPTQ + Mixed Precision on a PyTorch model.

In [None]:
def GPTQ_Pytorch_mixed_precision(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform GPTQ with Mixed Precision on PyTorch model.
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration
    framework = 'pytorch'               # Target framework (PyTorch)
    method = 'GPTQ'                     # Quantization method
    use_mixed_precision = True          # Enable mixed-precision quantization

    # Parameter configuration
    param_items = [
        ['sdsp_version', '3.14'],                                          # Version of the SDSP converter
        ['n_epochs', 5],                                                   # Number of epochs for GPTQ optimization
        ['num_of_images', 5],                                              # Number of images for Mixed-Precision calibration
        ['weights_compression_ratio', 0.5],                                # Compression ratio of weights for Mixed-Precision
        ['save_model_path', './qmodel_GPTQ_Pytorch_mixed_precision.onnx']  # Path to save quantized model as ONNX format
    ]

    # Execute quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model=float_model, 
        representative_dataset=representative_dataset_gen, 
        framework=framework, 
        method=method, 
        use_mixed_precision=use_mixed_precision, 
        param_items=param_items)
    return flag, quantized_model
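
The four functions above differ only in `method`, `use_mixed_precision`, and a few `param_items` entries. One way to remove the duplication is to describe each run as data and build the arguments from shared pieces. The sketch below is a hypothetical helper (`build_run_config`, `RUN_SPECS`, and the `*_PARAMS` lists are not part of MCT); it only assembles the keyword arguments the functions above pass to `quantize_and_export`, reusing exactly this tutorial's parameter names, values, and ordering.

```python
from typing import Any, Dict, List

# Parameter groups taken from the four functions above.
BASE_PARAMS: List[list] = [['sdsp_version', '3.14']]
GPTQ_PARAMS: List[list] = [['n_epochs', 5]]
MIXED_PRECISION_PARAMS: List[list] = [['num_of_images', 5], ['weights_compression_ratio', 0.5]]

# (run name, method, use_mixed_precision) for each of the four runs.
RUN_SPECS = [
    ('PTQ_Pytorch', 'PTQ', False),
    ('PTQ_Pytorch_mixed_precision', 'PTQ', True),
    ('GPTQ_Pytorch', 'GPTQ', False),
    ('GPTQ_Pytorch_mixed_precision', 'GPTQ', True),
]

def build_run_config(name: str, method: str, use_mixed_precision: bool) -> Dict[str, Any]:
    """Assemble framework/method/use_mixed_precision/param_items for one run (hypothetical helper)."""
    param_items = [list(p) for p in BASE_PARAMS]  # copy so runs never share inner lists
    if method == 'GPTQ':
        param_items += [list(p) for p in GPTQ_PARAMS]
    if use_mixed_precision:
        param_items += [list(p) for p in MIXED_PRECISION_PARAMS]
    param_items.append(['save_model_path', f'./qmodel_{name}.onnx'])
    return {
        'framework': 'pytorch',
        'method': method,
        'use_mixed_precision': use_mixed_precision,
        'param_items': param_items,
    }

configs = {name: build_run_config(name, method, mp) for name, method, mp in RUN_SPECS}
print(configs['GPTQ_Pytorch_mixed_precision']['param_items'])
```

With a config in hand, a run could then be started as `mct.wrapper.mct_wrapper.MCTWrapper().quantize_and_export(float_model=float_model, representative_dataset=representative_dataset_gen, **configs['GPTQ_Pytorch'])`, which passes the same arguments as `GPTQ_Pytorch` above.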

### Run Quantization
Now we quantize the float model with each of the four methods, using the functions defined above.

In [None]:
# Basic PTQ
flag, quantized_model_ptq = PTQ_Pytorch(float_model)

In [None]:
# PTQ with Mixed Precision
flag, quantized_model_ptq_mixed_precision = PTQ_Pytorch_mixed_precision(float_model)

In [None]:
# GPTQ
flag, quantized_model_gptq = GPTQ_Pytorch(float_model)

In [None]:
# GPTQ with Mixed Precision
flag, quantized_model_gptq_mixed_precision = GPTQ_Pytorch_mixed_precision(float_model)
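
Each cell above reassigns `flag` without checking it, so a failed run would only surface later, during evaluation. A minimal sketch of collecting the `(success_flag, quantized_model)` pairs and reporting failures follows. `run_all` is a hypothetical helper, and the two stand-in functions only mimic the return shape of `PTQ_Pytorch` and the other functions so the snippet runs without MCT.

```python
from typing import Any, Callable, Dict, List, Tuple

QuantizeFn = Callable[[Any], Tuple[bool, Any]]

def run_all(float_model: Any, runs: Dict[str, QuantizeFn]) -> Tuple[Dict[str, Any], List[str]]:
    """Run each quantization function; return successful models and the names of failed runs."""
    models: Dict[str, Any] = {}
    failed: List[str] = []
    for name, fn in runs.items():
        ok, qmodel = fn(float_model)
        if ok:
            models[name] = qmodel
        else:
            failed.append(name)
    return models, failed

# Stand-ins with the same (flag, model) return shape as the functions defined above.
def fake_success(model: Any) -> Tuple[bool, Any]:
    return True, f'quantized({model})'

def fake_failure(model: Any) -> Tuple[bool, Any]:
    return False, None

models, failed = run_all('float_model', {'PTQ_Pytorch': fake_success, 'GPTQ_Pytorch': fake_failure})
if failed:
    print('Quantization failed for:', ', '.join(failed))
```

In the notebook itself you would pass the real functions instead, for example `run_all(float_model, {'PTQ_Pytorch': PTQ_Pytorch, 'GPTQ_Pytorch': GPTQ_Pytorch})`.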

## Evaluation
Define an evaluation function that computes the Top-1 accuracy of a PyTorch model (float or quantized) on the validation set.

In [None]:
def evaluate(model: torch.nn.Module, testloader: DataLoader, mode: str) -> float:
    """
    Evaluate PyTorch model accuracy using a DataLoader.
    
    This function performs complete accuracy evaluation by:
    - Moving model and data to available device (GPU/CPU)
    - Running inference in evaluation mode (no gradient computation)
    - Computing Top-1 accuracy across the entire validation set
    - Providing progress tracking during evaluation
    
    Args:
        model: PyTorch model to evaluate (float or quantized)
        testloader: DataLoader containing validation dataset
        mode: String identifier for logging (e.g., 'Float', 'PTQ_Pytorch')
    
    Returns:
        float: Top-1 accuracy percentage
    """
    # Determine best available device for inference
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    
    # Perform inference without gradient computation for efficiency
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass to get predictions
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # index of the highest logit is the Top-1 class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # Calculate and display accuracy
    val_acc = (100 * correct / total)
    print(mode + ' Accuracy: %.2f%%' % val_acc)
    return val_acc
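
The core of `evaluate` is the Top-1 rule: a prediction counts as correct when the index of the largest logit equals the label, and accuracy is `100 * correct / total` accumulated over all batches. The framework-free sketch below reproduces that arithmetic with plain lists, which can be handy for sanity-checking the metric; the `top1_accuracy` and `argmax` helpers and the toy logits are illustrative only.

```python
from typing import Iterable, Sequence, Tuple

Batch = Tuple[Sequence[Sequence[float]], Sequence[int]]

def argmax(row: Sequence[float]) -> int:
    """Index of the largest value; Python's max() returns the first one on ties."""
    return max(range(len(row)), key=lambda i: row[i])

def top1_accuracy(batches: Iterable[Batch]) -> float:
    """Top-1 accuracy in percent, accumulated over batches as in evaluate()."""
    correct = 0
    total = 0
    for logits, labels in batches:
        predicted = [argmax(row) for row in logits]
        total += len(labels)
        correct += sum(int(p == y) for p, y in zip(predicted, labels))
    return 100 * correct / total

# Two toy batches: 3 of the 4 samples have their largest logit at the true label.
toy_batches = [
    ([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 0]),
    ([[0.2, 0.1, 0.9], [0.3, 0.4, 0.2]], [2, 0]),
]
print('Toy Accuracy: %.2f%%' % top1_accuracy(toy_batches))  # prints "Toy Accuracy: 75.00%"
```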

Create a DataLoader for evaluation with a larger batch size to speed up inference. Shuffling is disabled because sample order does not affect accuracy.

In [None]:
val_dataloader = DataLoader(dataset, batch_size=50, shuffle=False)

Finally, let's evaluate each model.

In [None]:
# Original floating-point PyTorch model
evaluate(float_model, val_dataloader, 'Float')

In [None]:
# PTQ model
evaluate(quantized_model_ptq, val_dataloader, 'PTQ_Pytorch')

In [None]:
# PTQ + Mixed Precision model
evaluate(quantized_model_ptq_mixed_precision, val_dataloader, 'PTQ_Pytorch_mixed_precision')

In [None]:
# GPTQ model
evaluate(quantized_model_gptq, val_dataloader, 'GPTQ_Pytorch')

In [None]:
# GPTQ + Mixed Precision quantized model
evaluate(quantized_model_gptq_mixed_precision, val_dataloader, 'GPTQ_Pytorch_mixed_precision')
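
Each `evaluate` call returns the Top-1 accuracy, so storing the return values makes it easy to compare every method against the float baseline. A small sketch follows; `summarize` is a hypothetical helper, and the numbers in `placeholder_results` are made-up inputs for illustration, not results from this tutorial.

```python
from typing import Dict, List, Tuple

def summarize(results: Dict[str, float], baseline: str = 'Float') -> List[Tuple[str, float, float]]:
    """Return (name, accuracy, drop vs. baseline in percentage points), best accuracy first."""
    base = results[baseline]
    rows = [(name, acc, base - acc) for name, acc in results.items() if name != baseline]
    return sorted(rows, key=lambda r: r[1], reverse=True)

# Placeholder numbers only. In the notebook, fill this dict with evaluate() return values,
# e.g. results['PTQ_Pytorch'] = evaluate(quantized_model_ptq, val_dataloader, 'PTQ_Pytorch').
placeholder_results = {'Float': 80.0, 'PTQ_Pytorch': 70.0, 'GPTQ_Pytorch': 75.0}
for name, acc, drop in summarize(placeholder_results):
    print(f'{name}: {acc:.2f}% (drop {drop:.2f} pts)')
```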

## Conclusion

In this tutorial, we demonstrated how to quantize a pre-trained model with four methods using MCTWrapper, each with only a few lines of code, and compared their accuracy against the float model.


## Copyrights

Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
