In [None]:
!pip install -Uq transformers
!pip install -Uq bitsandbytes

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from transformers import (
    TimmWrapperImageProcessor,
    TimmWrapperForImageClassification,
    BitsAndBytesConfig,
)
from transformers.image_utils import load_image
import torch

In [None]:
# Checkpoint shared by both model variants loaded below.
checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"

# Unquantized reference model, moved to the GPU after loading.
model = TimmWrapperForImageClassification.from_pretrained(checkpoint)
model = model.to("cuda")

# Same checkpoint loaded through bitsandbytes with 8-bit weights.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = TimmWrapperForImageClassification.from_pretrained(
    checkpoint,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
)

In [None]:
# Preprocessor configured from the same checkpoint as the models.
image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)

# Download the sample picture and turn it into model inputs.
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
image = load_image(image_url)
inputs = image_processor(image)

In [None]:
def inference(model, model_inputs=None, top_k=5):
    """Run one forward pass and print the most probable labels.

    Args:
        model: Callable model. It is invoked as ``model(**model_inputs)`` and
            must return an object exposing ``logits``; class names are read
            from ``model.config.id2label``.
        model_inputs: Mapping of keyword arguments for the forward pass.
            When omitted, the notebook-level ``inputs`` variable is used, so
            the original ``inference(model)`` call keeps working.
        top_k: How many predictions to print (default 5).

    Only the first row of the logits (index 0 along the first dimension) is
    reported.
    """
    if model_inputs is None:
        # Fall back to the global prepared in the preprocessing cell.
        model_inputs = inputs

    with torch.inference_mode():
        logits = model(**model_inputs).logits

    # Keep probabilities as fractions in [0, 1]; the percentage conversion is
    # done once, at formatting time, so the printed number matches the "%" sign.
    probabilities = logits.softmax(dim=1)
    top_probabilities, top_class_indices = torch.topk(probabilities, k=top_k)

    id2label = model.config.id2label

    for idx, prob in zip(top_class_indices[0], top_probabilities[0]):
        # The "%" presentation type multiplies by 100 and appends the sign.
        print(f"Label: {id2label[idx.item()]:20} Score: {prob.item():.2%}")

In [None]:
# Print the top predictions of the unquantized model for the sample image.
inference(model)

Label: remote control, remote Score: 0.35%
Label: tabby, tabby cat     Score: 0.27%
Label: Egyptian cat         Score: 0.13%
Label: tiger cat            Score: 0.11%
Label: rule, ruler          Score: 0.00%


In [None]:
# Print the top predictions of the 8-bit model for the same image, for comparison.
inference(model_8bit)

Label: remote control, remote Score: 0.33%
Label: tabby, tabby cat     Score: 0.29%
Label: Egyptian cat         Score: 0.13%
Label: tiger cat            Score: 0.11%
Label: rule, ruler          Score: 0.00%


In [None]:
# Measure the footprint of each model variant and report the relative saving.
original_footprint = model.get_memory_footprint()
quantized_footprint = model_8bit.get_memory_footprint()

# Share of the original footprint removed by quantization, as a percentage.
reduction_percent = (original_footprint - quantized_footprint) / original_footprint * 100

# Footprints are divided by 1e6 and labelled as MB in the printed report.
print(f"Memory footprint of the original model: {original_footprint / 1e6:.2f} MB")
print(f"Memory footprint of the quantized model: {quantized_footprint / 1e6:.2f} MB")
print(f"Reduction in memory usage: {reduction_percent:.2f}%")

Memory footprint of the original model: 346.27 MB
Memory footprint of the quantized model: 88.20 MB
Reduction in memory usage: 74.53%
