# Quantitizing the  model `wambugu71/crop_leaf_diseases_vit`


Convert the Hugging Face model "wambugu71/crop_leaf_diseases_vit" to a  ONNX format suitable for mobile deployment.

## Install necessary libraries

In [1]:
%pip install transformers onnx onnxruntime  optimum[onnxruntime] onnxruntime-tools

Collecting onnx
  Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.23.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Collecting onnxruntime-tools
  Downloading onnxruntime_tools-1.7.0-py3-none-any.whl.metadata (14 kB)
Collecting optimum[onnxruntime]
  Downloading optimum-2.0.0-py3-none-any.whl.metadata (14 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting optimum-onnx[onnxruntime] (from optimum[onnxruntime])
  Downloading optimum_onnx-0.0.1-py3-none-any.whl.metadata (4.7 kB)
Collecting py3nvml (from onnxruntime-tools)
  Downloading py3nvml-0.2.7-py3-none-any.whl.metadata (13 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Collecting transformers
  Downloading transformers-4.55.4-py3-none

## Load the model and tokenizer


In [2]:
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from onnxruntime.quantization import quantize_dynamic, QuantType
import onnxruntime as ort
import numpy as np
from PIL import Image
model_name = "wambugu71/crop_leaf_diseases_vit"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model.eval()

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/22.1M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/325 [00:00<?, ?B/s]



ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=192, out_features=192, bias=True)
              (key): Linear(in_features=192, out_features=192, bias=True)
              (value): Linear(in_features=192, out_features=192, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=192, out_features=192, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=192, out_features=768, bias=True)
            (intermedi

## Convert to onnx


In [3]:
dummy_input = torch.randn(
    1, 3, feature_extractor.size['height'], feature_extractor.size['width']
)


onnx_path = "crop_leaf_diseases_vit.onnx"

torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=['pixel_values'],
    output_names=['output'],
    dynamic_axes={
        'pixel_values': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    opset_version=17,

)

print(f"Model exported to ONNX at {onnx_path}")


  torch.onnx.export(
  if num_channels != self.num_channels:
  if height != self.image_size[0] or width != self.image_size[1]:


Model exported to ONNX at crop_leaf_diseases_vit.onnx


Lets quantitize  to  `int8`

In [None]:
!python -m onnxruntime.quantization.preprocess --input crop_leaf_diseases_vit.onnx --output crop_leaf_diseases_ken_vit.onnx

In [None]:
%%writefile run.py
import argparse
import numpy as np
import onnxruntime
import time
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

import data_reader


def benchmark(model_path):
    session = onnxruntime.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name

    total = 0.0
    runs = 10
    input_data = np.zeros((1, 3, 224, 224), np.float32)
    # Warming up
    _ = session.run([], {input_name: input_data})
    for i in range(runs):
        start = time.perf_counter()
        _ = session.run([], {input_name: input_data})
        end = (time.perf_counter() - start) * 1000
        total += end
        print(f"{end:.2f}ms")
    total /= runs
    print(f"Avg: {total:.2f}ms")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_model", required=True, help="input model")
    parser.add_argument("--output_model", required=True, help="output model")
    parser.add_argument(
        "--calibrate_dataset", default="./test_images", help="calibration data set"
    )
    parser.add_argument(
        "--quant_format",
        default=QuantFormat.QDQ,
        type=QuantFormat.from_string,
        choices=list(QuantFormat),
    )
    parser.add_argument("--per_channel", default=False, type=bool)
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    input_model_path = args.input_model
    output_model_path = args.output_model
    calibration_dataset_path = args.calibrate_dataset
    dr = data_reader.DataReader(
        calibration_dataset_path, input_model_path
    )

    # Calibrate and quantize model
    # Turn off model optimization during quantization
    quantize_static(
        input_model_path,
        output_model_path,
        dr,
        quant_format=args.quant_format,
        per_channel=args.per_channel,
        weight_type=QuantType.QInt8,
    )
    print("Calibrated and quantized model saved.")

    print("benchmarking fp32 model...")
    benchmark(input_model_path)

    print("benchmarking int8 model...")
    benchmark(output_model_path)


if __name__ == "__main__":
    main()

Overwriting run.py


In [None]:
%%writefile data_reader.py
import numpy
import onnxruntime
import os
from onnxruntime.quantization import CalibrationDataReader
from PIL import Image


def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0):
    """
    Loads a batch of images and preprocess them
    parameter images_folder: path to folder storing images
    parameter height: image height in pixels
    parameter width: image width in pixels
    parameter size_limit: number of images to load. Default is 0 which means all images are picked.
    return: list of matrices characterizing multiple images
    """
    image_names = os.listdir(images_folder)
    if size_limit > 0 and len(image_names) >= size_limit:
        batch_filenames = [image_names[i] for i in range(size_limit)]
    else:
        batch_filenames = image_names
    unconcatenated_batch_data = []

    for image_name in batch_filenames:
        image_filepath = images_folder + "/" + image_name
        pillow_img = Image.new("RGB", (width, height))
        pillow_img.paste(Image.open(image_filepath).resize((width, height)))
        input_data = numpy.float32(pillow_img) - numpy.array(
            [123.68, 116.78, 103.94], dtype=numpy.float32
        )
        nhwc_data = numpy.expand_dims(input_data, axis=0)
        nchw_data = nhwc_data.transpose(0, 3, 1, 2)  # ONNX Runtime standard
        unconcatenated_batch_data.append(nchw_data)
    batch_data = numpy.concatenate(
        numpy.expand_dims(unconcatenated_batch_data, axis=0), axis=0
    )
    return batch_data


class DataReader(CalibrationDataReader):
    def __init__(self, calibration_image_folder: str, model_path: str):
        self.enum_data = None

        # Use inference session to get input shape.
        session = onnxruntime.InferenceSession(model_path, None)
        (_, _, height, width) = session.get_inputs()[0].shape

        # Convert image to input data
        self.nhwc_data_list = _preprocess_images(
            calibration_image_folder, height, width, size_limit=0
        )
        self.input_name = session.get_inputs()[0].name
        self.datasize = len(self.nhwc_data_list)

    def get_next(self):
        if self.enum_data is None:
            self.enum_data = iter(
                [{self.input_name: nhwc_data} for nhwc_data in self.nhwc_data_list]
            )
        return next(self.enum_data, None)

    def rewind(self):
        self.enum_data = None

Writing data_reader.py


In [None]:
!mkdir -p test_images

In [None]:
!python run.py --input_model crop_leaf_diseases_ken_vit.onnx --output_model crop_leaf_diseases_ken_vit.quant.onnx --calibrate_dataset ./test_images/

Calibrated and quantized model saved.
benchmarking fp32 model...
59.24ms
55.68ms
55.96ms
55.15ms
55.44ms
58.55ms
90.11ms
83.72ms
86.56ms
88.29ms
Avg: 68.87ms
benchmarking int8 model...
119.88ms
117.14ms
120.08ms
126.71ms
113.69ms
117.11ms
112.96ms
108.63ms
143.80ms
269.11ms
Avg: 134.91ms


In [None]:
onnx_model_path = "/content/crop_leaf_diseases_ken_vit.quant.onnx"
session = ort.InferenceSession(onnx_model_path)

# Get input name (should be 'pixel_values')
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name


image_path = "/content/test_images/Potato_healthy-26-_0_9285.jpg"  # replace with your image file
image = Image.open(image_path).convert("RGB")

inputs = feature_extractor(images=image, return_tensors="np")

# Extract input array
pixel_values = inputs["pixel_values"]  # shape [1, 3, H, W], dtype=float32


outputs = session.run([output_name], {input_name: pixel_values})
logits = outputs[0]  # shape [1, num_classes]


predicted_class_idx = np.argmax(logits, axis=1)[0]
print("Predicted class index:", predicted_class_idx)

Predicted class index: 7


  Quantitized to int8 still has lost most accuracy.  `Reccomended original` for full accuracy.

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
from huggingface_hub import create_repo

repo_name = "wambugu71/crop_leaf_diseases_vit_onnx"
create_repo(repo_name, exist_ok=True)
print(f"Repository '{repo_name}' created or already exists.")

Repository 'wambugu71/crop_leaf_diseases_vit_onnx' created or already exists.


In [7]:
from huggingface_hub import upload_file

repo_name = "wambugu71/crop_leaf_diseases_vit_onnx"
onnx_model_path = "/content/crop_leaf_diseases_vit.onnx"
repo_file_path = "crop_leaf_diseases_vit.onnx"

upload_file(
    path_or_fileobj=onnx_model_path,
    path_in_repo=repo_file_path,
    repo_id=repo_name,
    commit_message="Add initial ONNX model"
)

print(f"Pushed {onnx_model_path} to {repo_name}/{repo_file_path}")

crop_leaf_diseases_vit.onnx:   0%|          | 0.00/22.3M [00:00<?, ?B/s]

Pushed /content/crop_leaf_diseases_vit.onnx to wambugu71/crop_leaf_diseases_vit_onnx/crop_leaf_diseases_vit.onnx
