In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim



In [2]:
# 데이터 변환 설정
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
])

# 데이터 로더 설정
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = cifar100_resnet56(pretrained=True)
# model = model.to(device)
from mobilenetv2 import mobilenet_v2
model = mobilenet_v2(pretrained=True)


In [18]:
import time
import numpy as np

model.eval()
model = model.to(device)
correct = 0
total = 0
results = []
with torch.no_grad():
    for inputs, labels in testloader:
        torch.cuda.synchronize()
        inputs, labels = inputs.to(device), labels.to(device)
        start = time.time()
        outputs = model(inputs)
        torch.cuda.synchronize()
        end = time.time()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        results.append(end-start)

infer_time = np.mean(results)
print(infer_time)
print(f"Accuracy on test set: {100 * correct / total:.2f}%")

0.01758119124400465
Accuracy on test set: 93.91%


In [19]:
from aimet_torch.v2.batch_norm_fold import fold_all_batch_norms

_ = fold_all_batch_norms(model, input_shapes=(1, 3, 32, 32))

In [20]:
from aimet_common.defs import QuantScheme
from aimet_torch.v1.quantsim import QuantizationSimModel

dummy_input = torch.rand(1, 3, 32, 32)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)

dummy_input = dummy_input.cuda()

sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=4)

2024-12-23 06:05:37,657 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-23 06:05:37,675 - Quant - INFO - Unsupported op type Squeeze
2024-12-23 06:05:37,675 - Quant - INFO - Unsupported op type Mean
2024-12-23 06:05:37,677 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default


In [21]:
use_cuda = True


In [22]:
import time
import numpy as np
correct = 0
total = 0
results = []
with torch.no_grad():
    for inputs, labels in testloader:
        torch.cuda.synchronize()
        inputs, labels = inputs.to(device), labels.to(device)
        start = time.time()
        outputs = sim.model(inputs)
        torch.cuda.synchronize()
        end = time.time()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        results.append(end-start)

infer_time = np.mean(results)
print(infer_time)
print(f"Accuracy on test set: {100 * correct / total:.2f}%")

0.5548770970936063
Accuracy on test set: 90.84%


In [23]:
def pass_calibration_data(sim_model, use_cuda):
    data_loader = testloader
    batch_size = data_loader.batch_size

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    sim_model.eval()
    samples = 1000

    batch_cntr = 0
    with torch.no_grad():
        for input_data, target_data in data_loader:

            inputs_batch = input_data.to(device)
            sim_model(inputs_batch)

            batch_cntr += 1
            if (batch_cntr * batch_size) > samples:
                break

In [24]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)

In [25]:
import time
import numpy as np
correct = 0
total = 0
results = []
with torch.no_grad():
    for inputs, labels in testloader:
        torch.cuda.synchronize()
        inputs, labels = inputs.to(device), labels.to(device)
        start = time.time()
        outputs = sim.model(inputs)
        torch.cuda.synchronize()
        end = time.time()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        results.append(end-start)

infer_time = np.mean(results)
print(infer_time)
print(f"Accuracy on test set: {100 * correct / total:.2f}%")

0.016245310819601712
Accuracy on test set: 90.76%


In [26]:
from aimet_torch.v1.adaround.adaround_weight import Adaround, AdaroundParameters
import os

data_loader = testloader
params = AdaroundParameters(data_loader=data_loader, num_batches=1, default_num_iterations=64)

dummy_input = torch.rand(1, 3, 32, 32)
if use_cuda:
    dummy_input = dummy_input.cuda()

os.makedirs('./output/', exist_ok=True)
ada_model = Adaround.apply_adaround(model, dummy_input, params,
                                    path="output", 
                                    filename_prefix='adaround', 
                                    default_param_bw=4,
                                    default_quant_scheme=QuantScheme.post_training_tf_enhanced)

2024-12-23 06:06:59,835 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-23 06:06:59,853 - Quant - INFO - Unsupported op type Squeeze
2024-12-23 06:06:59,853 - Quant - INFO - Unsupported op type Mean
2024-12-23 06:06:59,856 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default
2024-12-23 06:07:00,458 - Utils - INFO - Caching 1 batches from data loader at path location: /tmp/tmpztxrh4s_


                                       

2024-12-23 06:07:00,469 - Quant - INFO - Started Optimizing weight rounding of module: features.0.0


                                               

2024-12-23 06:07:01,178 - Quant - INFO - Started Optimizing weight rounding of module: features.1.conv.0.0


                                               

2024-12-23 06:07:01,864 - Quant - INFO - Started Optimizing weight rounding of module: features.1.conv.1


                                               

2024-12-23 06:07:02,454 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.0.0


                                               

2024-12-23 06:07:02,620 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.1.0


                                                

2024-12-23 06:07:03,298 - Quant - INFO - Started Optimizing weight rounding of module: features.2.conv.2


                                                

2024-12-23 06:07:04,071 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.0.0


                                                

2024-12-23 06:07:04,752 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.1.0


                                                

2024-12-23 06:07:05,492 - Quant - INFO - Started Optimizing weight rounding of module: features.3.conv.2


                                                

2024-12-23 06:07:06,197 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.0.0


                                                

2024-12-23 06:07:06,910 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.1.0


                                                

2024-12-23 06:07:07,642 - Quant - INFO - Started Optimizing weight rounding of module: features.4.conv.2


                                                

2024-12-23 06:07:08,406 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.0.0


                                                

2024-12-23 06:07:09,106 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.1.0


                                                

2024-12-23 06:07:09,837 - Quant - INFO - Started Optimizing weight rounding of module: features.5.conv.2


                                                

2024-12-23 06:07:10,548 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.0.0


                                                

2024-12-23 06:07:11,229 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.1.0


                                                

2024-12-23 06:07:11,870 - Quant - INFO - Started Optimizing weight rounding of module: features.6.conv.2


                                                

2024-12-23 06:07:12,032 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.0.0


                                                

2024-12-23 06:07:12,707 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.1.0


                                                

2024-12-23 06:07:12,866 - Quant - INFO - Started Optimizing weight rounding of module: features.7.conv.2


                                                

2024-12-23 06:07:13,476 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.0.0


                                                

2024-12-23 06:07:14,180 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.1.0


                                                

2024-12-23 06:07:14,887 - Quant - INFO - Started Optimizing weight rounding of module: features.8.conv.2


                                                

2024-12-23 06:07:15,560 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.0.0


                                                

2024-12-23 06:07:15,720 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.1.0


                                                

2024-12-23 06:07:16,353 - Quant - INFO - Started Optimizing weight rounding of module: features.9.conv.2


                                                

2024-12-23 06:07:17,044 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.0.0


                                                

2024-12-23 06:07:17,794 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.1.0


                                                

2024-12-23 06:07:18,494 - Quant - INFO - Started Optimizing weight rounding of module: features.10.conv.2


                                                

2024-12-23 06:07:19,234 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.0.0


                                                

2024-12-23 06:07:19,921 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.1.0


                                                

2024-12-23 06:07:20,528 - Quant - INFO - Started Optimizing weight rounding of module: features.11.conv.2


                                                

2024-12-23 06:07:20,709 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.0.0


                                                

2024-12-23 06:07:21,401 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.1.0


                                                

2024-12-23 06:07:22,101 - Quant - INFO - Started Optimizing weight rounding of module: features.12.conv.2


                                                

2024-12-23 06:07:22,850 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.0.0


                                                

2024-12-23 06:07:23,506 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.1.0


                                                 

2024-12-23 06:07:24,141 - Quant - INFO - Started Optimizing weight rounding of module: features.13.conv.2


                                                 

2024-12-23 06:07:24,333 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.0.0


                                                 

2024-12-23 06:07:25,039 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.1.0


                                                 

2024-12-23 06:07:25,716 - Quant - INFO - Started Optimizing weight rounding of module: features.14.conv.2


                                                 

2024-12-23 06:07:26,307 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.0.0


                                                 

2024-12-23 06:07:26,514 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.1.0


                                                 

2024-12-23 06:07:27,195 - Quant - INFO - Started Optimizing weight rounding of module: features.15.conv.2


                                                 

2024-12-23 06:07:27,894 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.0.0


                                                 

2024-12-23 06:07:28,614 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.1.0


                                                 

2024-12-23 06:07:29,343 - Quant - INFO - Started Optimizing weight rounding of module: features.16.conv.2


                                                 

2024-12-23 06:07:30,055 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.0.0


                                                 

2024-12-23 06:07:30,777 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.1.0


                                                 

2024-12-23 06:07:31,490 - Quant - INFO - Started Optimizing weight rounding of module: features.17.conv.2


                                                 

2024-12-23 06:07:32,219 - Quant - INFO - Started Optimizing weight rounding of module: features.18.0


                                                 

2024-12-23 06:07:32,944 - Quant - INFO - Started Optimizing weight rounding of module: classifier.1


100%|██████████| 141/141 [00:33<00:00,  4.25it/s]

2024-12-23 06:07:33,650 - Quant - INFO - Completed Adarounding Model





In [27]:
sim = QuantizationSimModel(model=ada_model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_output_bw=8, 
                           default_param_bw=4)

sim.set_and_freeze_param_encodings(encoding_path=os.path.join("output", 'adaround.encodings'))

sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=use_cuda)

2024-12-23 06:07:36,770 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.10/dist-packages/aimet_common/quantsim_config/default_config.json
2024-12-23 06:07:36,788 - Quant - INFO - Unsupported op type Squeeze
2024-12-23 06:07:36,788 - Quant - INFO - Unsupported op type Mean
2024-12-23 06:07:36,791 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:default


  sim.set_and_freeze_param_encodings(encoding_path=os.path.join("output", 'adaround.encodings'))


In [28]:
import time
import numpy as np
correct = 0
total = 0
results = []
with torch.no_grad():
    for inputs, labels in testloader:
        torch.cuda.synchronize()
        inputs, labels = inputs.to(device), labels.to(device)
        start = time.time()
        outputs = sim.model(inputs)
        torch.cuda.synchronize()
        end = time.time()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        results.append(end-start)

infer_time = np.mean(results)
print(infer_time)
print(f"Accuracy on test set: {100 * correct / total:.2f}%")

0.016973688632627076
Accuracy on test set: 92.51%
