In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim


In [49]:
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [50]:
from mobilenetv2 import mobilenet_v2
# from resnet import resnet18
# model = resnet18(pretrained=True)
model = mobilenet_v2(pretrained=True)

In [51]:
import time
import numpy as np

def evaluation(model, dataloader):
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    correct = 0
    total = 0
    results = []
    with torch.no_grad():
        for inputs, labels in dataloader:

            inputs, labels = inputs.to(device), labels.to(device)
            
            start = time.time()
            outputs = model(inputs)
            
            torch.cuda.synchronize()
            end = time.time()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            results.append(end-start)

    infer_time = np.mean(results)
    print(infer_time)
    print(f"Accuracy on test set: {100 * correct / total:.2f}%")

In [52]:
evaluation(model, testloader)

0.010886967936648598
Accuracy on test set: 93.91%


In [53]:
from aimet_common.defs import QuantScheme
from aimet_torch.v1.quantsim import QuantizationSimModel

dummy_input = torch.rand(1, 3, 32, 32) 

dummy_input = dummy_input.cuda()

sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=4)

2024-12-26 00:54:35,434 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-26 00:54:35,454 - Quant - INFO - Unsupported op type Squeeze
2024-12-26 00:54:35,454 - Quant - INFO - Unsupported op type Mean
2024-12-26 00:54:35,457 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default


In [54]:
print(sim)

-------------------------
Quantized Model Report
-------------------------
----------------------------------------------------------
Layer: features.0.0
  Input[0]: bw=8, encoding-present=False
  -------
  Param[weight]: bw=4, encoding-present=False
  -------
  Output[0]: Not quantized
  -------
----------------------------------------------------------
Layer: features.0.1
  Input[0]: Not quantized
  -------
  Param[weight]: Not quantized
  -------
  Param[bias]: Not quantized
  -------
  Output[0]: Not quantized
  -------
----------------------------------------------------------
Layer: features.0.2
  Input[0]: Not quantized
  -------
  Output[0]: bw=8, encoding-present=False
  -------
----------------------------------------------------------
Layer: features.1.conv.0.0
  Input[0]: Not quantized
  -------
  Param[weight]: bw=4, encoding-present=False
  -------
  Output[0]: Not quantized
  -------
----------------------------------------------------------
Layer: features.1.conv.0.1
  

In [55]:
use_cuda = True

In [56]:
def pass_calibration_data(sim_model, use_cuda):
    data_loader = trainloader
    batch_size = data_loader.batch_size

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    sim_model.eval()
    samples = 1000

    batch_cntr = 0
    with torch.no_grad():
        for input_data, target_data in data_loader:

            inputs_batch = input_data.to(device)
            sim_model(inputs_batch)

            batch_cntr += 1
            if (batch_cntr * batch_size) > samples:
                break

In [57]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)

In [58]:
evaluation(sim.model, testloader)

0.021007561985450455
Accuracy on test set: 89.85%


In [59]:
from aimet_torch.v1.adaround.adaround_weight import Adaround, AdaroundParameters
import os

data_loader = testloader
params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=64)

dummy_input = torch.rand(1, 3, 32, 32)
if use_cuda:
    dummy_input = dummy_input.cuda()

os.makedirs('./output/', exist_ok=True)
ada_model = Adaround.apply_adaround(model, dummy_input, params,
                                    path="output", 
                                    filename_prefix='adaround', 
                                    default_param_bw=4,
                                    default_quant_scheme=QuantScheme.post_training_tf_enhanced)

2024-12-26 00:54:42,723 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-26 00:54:42,741 - Quant - INFO - Unsupported op type Squeeze
2024-12-26 00:54:42,742 - Quant - INFO - Unsupported op type Mean
2024-12-26 00:54:42,746 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default
2024-12-26 00:54:43,364 - Utils - INFO - Caching 1 batches from data loader at path location: /tmp/tmpe6h05_ej


                                       

2024-12-26 00:54:43,417 - Quant - INFO - Started Optimizing weight rounding of module: features.0.0


                                               

2024-12-26 00:54:43,956 - Quant - INFO - Started Optimizing weight rounding of module: features.1.conv.0.0


                                               

2024-12-26 00:54:44,130 - Quant - INFO - Started Optimizing weight rounding of module: features.1.conv.1


                                               

2024-12-26 00:54:44,306 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.0.0


                                               

2024-12-26 00:54:44,523 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.1.0


                                                

2024-12-26 00:54:44,733 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.2


                                                

2024-12-26 00:54:44,919 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.0.0


                                                

2024-12-26 00:54:45,126 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.1.0


                                                

2024-12-26 00:54:45,399 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.2


                                                

2024-12-26 00:54:45,621 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.0.0


                                                

2024-12-26 00:54:45,837 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.1.0


                                                

2024-12-26 00:54:46,056 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.2


                                                

2024-12-26 00:54:46,224 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.0.0


                                                

2024-12-26 00:54:46,391 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.1.0


                                                

2024-12-26 00:54:46,569 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.2


                                                

2024-12-26 00:54:46,759 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.0.0


                                                

2024-12-26 00:54:46,930 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.1.0


                                                

2024-12-26 00:54:47,109 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.2


                                                

2024-12-26 00:54:47,286 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.0.0


                                                

2024-12-26 00:54:47,459 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.1.0


                                                

2024-12-26 00:54:47,632 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.2


                                                

2024-12-26 00:54:47,799 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.0.0


                                                

2024-12-26 00:54:47,971 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.1.0


                                                

2024-12-26 00:54:48,143 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.2


                                                

2024-12-26 00:54:48,322 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.0.0


                                                

2024-12-26 00:54:48,494 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.1.0


                                                

2024-12-26 00:54:48,672 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.2


                                                

2024-12-26 00:54:48,851 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.0.0


                                                

2024-12-26 00:54:49,032 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.1.0


                                                

2024-12-26 00:54:49,213 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.2


                                                

2024-12-26 00:54:49,397 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.0.0


                                                

2024-12-26 00:54:49,580 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.1.0


                                                

2024-12-26 00:54:49,766 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.2


                                                

2024-12-26 00:54:49,960 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.0.0


                                                

2024-12-26 00:54:50,157 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.1.0


                                                

2024-12-26 00:54:50,359 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.2


                                                

2024-12-26 00:54:50,572 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.0.0


                                                

2024-12-26 00:54:50,769 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.1.0


                                                 

2024-12-26 00:54:50,977 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.2


                                                 

2024-12-26 00:54:51,191 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.0.0


                                                 

2024-12-26 00:54:51,394 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.1.0


                                                 

2024-12-26 00:54:51,595 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.2


                                                 

2024-12-26 00:54:51,800 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.0.0


                                                 

2024-12-26 00:54:52,005 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.1.0


                                                 

2024-12-26 00:54:52,207 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.2


                                                 

2024-12-26 00:54:52,419 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.0.0


                                                 

2024-12-26 00:54:52,626 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.1.0


                                                 

2024-12-26 00:54:52,834 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.2


                                                 

2024-12-26 00:54:53,048 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.0.0


                                                 

2024-12-26 00:54:53,261 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.1.0


                                                 

2024-12-26 00:54:53,474 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.2


                                                 

2024-12-26 00:54:53,700 - Quant - INFO - Started Optimizing weight rounding of module: features.18.0


                                                 

2024-12-26 00:54:53,926 - Quant - INFO - Started Optimizing weight rounding of module: classifier.1


100%|██████████| 141/141 [00:10<00:00, 13.10it/s]

2024-12-26 00:54:54,144 - Quant - INFO - Completed Adarounding Model





In [60]:
sim = QuantizationSimModel(model=ada_model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_output_bw=8, 
                           default_param_bw=4)

sim.set_and_freeze_param_encodings(encoding_path=os.path.join("output", 'adaround.encodings'))

sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)

2024-12-26 00:54:54,437 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-26 00:54:54,456 - Quant - INFO - Unsupported op type Squeeze
2024-12-26 00:54:54,456 - Quant - INFO - Unsupported op type Mean
2024-12-26 00:54:54,460 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default


  sim.set_and_freeze_param_encodings(encoding_path=os.path.join("output", 'adaround.encodings'))


In [61]:
evaluation(sim.model, testloader)

0.02107233940800534
Accuracy on test set: 91.80%
