In [5]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

#FIX for module

In [6]:
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
from mobilenetv2 import mobilenet_v2
model = mobilenet_v2(pretrained=True)

In [None]:
import time
import numpy as np

def evaluation(model, dataloader):
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    correct = 0
    total = 0
    results = []
    with torch.no_grad():
        for inputs, labels in dataloader:

            inputs, labels = inputs.to(device), labels.to(device)
            
            start = time.time()
            outputs = model(inputs)
            
            torch.cuda.synchronize()
            end = time.time()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            results.append(end-start)

    infer_time = np.mean(results)
    print(infer_time)
    print(f"Accuracy on test set: {100 * correct / total:.2f}%")

In [9]:
evaluation(model, testloader)

0.013391564163980605
Accuracy on test set: 93.91%


In [10]:
from aimet_torch.v2.batch_norm_fold import fold_all_batch_norms

_ = fold_all_batch_norms(model, input_shapes=(1, 3, 32, 32))

2024-12-23 07:21:28,982 - root - INFO - AIMET


In [13]:
from aimet_common.defs import QuantScheme
from aimet_torch.v1.quantsim import QuantizationSimModel

dummy_input = torch.rand(1, 3, 32, 32) 

dummy_input = dummy_input.cuda()

sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=4)

2024-12-23 07:22:05,098 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-23 07:22:05,116 - Quant - INFO - Unsupported op type Squeeze
2024-12-23 07:22:05,117 - Quant - INFO - Unsupported op type Mean
2024-12-23 07:22:05,119 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default


In [14]:
use_cuda = True

In [15]:
evaluation(sim.model, testloader)

0.5496138439902777
Accuracy on test set: 90.84%


In [None]:
def pass_calibration_data(sim_model, use_cuda):
    data_loader = trainloader
    batch_size = data_loader.batch_size

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    sim_model.eval()
    samples = 1000

    batch_cntr = 0
    with torch.no_grad():
        for input_data, target_data in data_loader:

            inputs_batch = input_data.to(device)
            sim_model(inputs_batch)

            batch_cntr += 1
            if (batch_cntr * batch_size) > samples:
                break

In [17]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)

In [18]:
evaluation(sim.model, testloader)

0.016102712365645398
Accuracy on test set: 90.76%


In [19]:
from aimet_torch.v1.adaround.adaround_weight import Adaround, AdaroundParameters
import os

data_loader = testloader
params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=64)

dummy_input = torch.rand(1, 3, 32, 32)
if use_cuda:
    dummy_input = dummy_input.cuda()

os.makedirs('./output/', exist_ok=True)
ada_model = Adaround.apply_adaround(model, dummy_input, params,
                                    path="output", 
                                    filename_prefix='adaround', 
                                    default_param_bw=4,
                                    default_quant_scheme=QuantScheme.post_training_tf_enhanced)

2024-12-23 07:23:11,740 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-23 07:23:11,759 - Quant - INFO - Unsupported op type Squeeze
2024-12-23 07:23:11,759 - Quant - INFO - Unsupported op type Mean
2024-12-23 07:23:11,761 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default
2024-12-23 07:23:12,333 - Utils - INFO - Caching 1 batches from data loader at path location: /tmp/tmpnjdjmwqv


                                       

2024-12-23 07:23:12,381 - Quant - INFO - Started Optimizing weight rounding of module: features.0.0


                                               

2024-12-23 07:23:12,928 - Quant - INFO - Started Optimizing weight rounding of module: features.1.conv.0.0


                                               

2024-12-23 07:23:13,094 - Quant - INFO - Started Optimizing weight rounding of module: features.1.conv.1


                                               

2024-12-23 07:23:13,221 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.0.0


                                               

2024-12-23 07:23:13,370 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.1.0


                                                

2024-12-23 07:23:13,552 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.2


                                                

2024-12-23 07:23:13,710 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.0.0


                                                

2024-12-23 07:23:13,912 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.1.0


                                                

2024-12-23 07:23:14,188 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.2


                                                

2024-12-23 07:23:14,427 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.0.0


                                                

2024-12-23 07:23:14,678 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.1.0


                                                

2024-12-23 07:23:14,934 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.2


                                                

2024-12-23 07:23:15,129 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.0.0


                                                

2024-12-23 07:23:15,288 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.1.0


                                                

2024-12-23 07:23:15,446 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.2


                                                

2024-12-23 07:23:15,607 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.0.0


                                                

2024-12-23 07:23:15,765 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.1.0


                                                

2024-12-23 07:23:15,927 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.2


                                                

2024-12-23 07:23:16,088 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.0.0


                                                

2024-12-23 07:23:16,251 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.1.0


                                                

2024-12-23 07:23:16,410 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.2


                                                

2024-12-23 07:23:16,561 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.0.0


                                                

2024-12-23 07:23:16,720 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.1.0


                                                

2024-12-23 07:23:16,876 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.2


                                                

2024-12-23 07:23:17,035 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.0.0


                                                

2024-12-23 07:23:17,194 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.1.0


                                                

2024-12-23 07:23:17,354 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.2


                                                

2024-12-23 07:23:17,513 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.0.0


                                                

2024-12-23 07:23:17,674 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.1.0


                                                

2024-12-23 07:23:17,838 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.2


                                                

2024-12-23 07:23:18,000 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.0.0


                                                

2024-12-23 07:23:18,165 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.1.0


                                                

2024-12-23 07:23:18,330 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.2


                                                

2024-12-23 07:23:18,500 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.0.0


                                                

2024-12-23 07:23:18,677 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.1.0


                                                

2024-12-23 07:23:18,849 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.2


                                                

2024-12-23 07:23:19,025 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.0.0


                                                

2024-12-23 07:23:19,201 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.1.0


                                                 

2024-12-23 07:23:19,378 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.2


                                                 

2024-12-23 07:23:19,555 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.0.0


                                                 

2024-12-23 07:23:19,738 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.1.0


                                                 

2024-12-23 07:23:19,914 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.2


                                                 

2024-12-23 07:23:20,090 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.0.0


                                                 

2024-12-23 07:23:20,275 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.1.0


                                                 

2024-12-23 07:23:20,458 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.2


                                                 

2024-12-23 07:23:20,649 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.0.0


                                                 

2024-12-23 07:23:20,836 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.1.0


                                                 

2024-12-23 07:23:21,018 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.2


                                                 

2024-12-23 07:23:21,202 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.0.0


                                                 

2024-12-23 07:23:21,393 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.1.0


                                                 

2024-12-23 07:23:21,579 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.2


                                                 

2024-12-23 07:23:21,772 - Quant - INFO - Started Optimizing weight rounding of module: features.18.0


                                                 

2024-12-23 07:23:21,971 - Quant - INFO - Started Optimizing weight rounding of module: classifier.1


100%|██████████| 141/141 [00:09<00:00, 14.38it/s]

2024-12-23 07:23:22,152 - Quant - INFO - Completed Adarounding Model





In [20]:
sim = QuantizationSimModel(model=ada_model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_output_bw=8, 
                           default_param_bw=4)

sim.set_and_freeze_param_encodings(encoding_path=os.path.join("output", 'adaround.encodings'))

sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)

2024-12-23 07:23:22,404 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-23 07:23:22,422 - Quant - INFO - Unsupported op type Squeeze
2024-12-23 07:23:22,422 - Quant - INFO - Unsupported op type Mean
2024-12-23 07:23:22,424 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default


  sim.set_and_freeze_param_encodings(encoding_path=os.path.join("output", 'adaround.encodings'))


In [21]:
evaluation(sim.model, testloader)

0.01608066015605685
Accuracy on test set: 92.52%
