In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import copy
import pathlib

base_path = pathlib.Path("/home/mpuscian/Desktop/repozytoria/MINI_projects/anvil/models/")
cifar_model_path = base_path.joinpath("cifar_model.pth")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = resnet18(weights=None)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
model.fc = nn.Linear(512, 10)

adaround = copy.deepcopy(model)

adaround.load_state_dict(torch.load(cifar_model_path))
adaround.eval().to(device)

print(f"Using device: {device}")

# 2. Transforms
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

# 3. Datasets
full_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_size = int(0.9 * len(full_train_dataset))  # 45,000
val_size = len(full_train_dataset) - train_size  # 5,000
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# 4. DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=6)
val_dataset.dataset.transform = transform_test
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=6)

Using device: cuda


In [2]:
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np

val_dataset = val_loader.dataset  # your full val dataset

num_samples = len(val_dataset)
subset_size = num_samples // 5

# Randomly sample 1/5 indices
all_indices = np.arange(num_samples)
np.random.shuffle(all_indices)
subset_indices = all_indices[:subset_size]

# Create a sampler using those indices
sampler = SubsetRandomSampler(subset_indices)

# Create DataLoader with this sampler
subset_loader = DataLoader(
    val_dataset,
    batch_size=val_loader.batch_size,
    sampler=sampler,
    num_workers=val_loader.num_workers,
    pin_memory=True
)

In [12]:
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters
from aimet_torch.quantsim import QuantizationSimModel
from aimet_common.defs import QuantScheme

def calibration_data_loader():
    for i, (x, _) in enumerate(val_loader):
        yield x.to(device)
        if i >= 9:  # Limit to 10 batches, as specified in AdaroundParameters
            break


params = AdaroundParameters(
    data_loader=subset_loader,
    num_batches=min(10, len(subset_loader)),
    default_num_iterations=3000,
    default_reg_param=0.07,
    default_beta_range=(40, 2)
)

dummy_input = torch.randn(1, 3, 32, 32).to(device)

adaround = Adaround.apply_adaround(
    model=adaround,
    dummy_input=dummy_input,
    params=params,
    path=base_path.joinpath('aimet/adaround_encodings'),
    filename_prefix='resnet18',
    default_param_bw=4,
    default_quant_scheme=QuantScheme.post_training_tf_enhanced
)

2025-05-28 15:56:59,349 - Quant - INFO - Unsupported op type Squeeze
2025-05-28 15:56:59,349 - Quant - INFO - Unsupported op type Mean
2025-05-28 15:56:59,349 - Quant - INFO - Unsupported op type Unsqueeze
2025-05-28 15:56:59,349 - Quant - INFO - Unsupported op type Compress
2025-05-28 15:56:59,349 - Quant - INFO - Unsupported op type Identity
2025-05-28 15:56:59,350 - Quant - INFO - Unsupported op type Shape
2025-05-28 15:56:59,350 - Quant - INFO - Unsupported op type If
2025-05-28 15:56:59,351 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:None
2025-05-28 15:57:00,196 - Utils - INFO - Caching 8 batches from data loader at path location: /tmp/tmpovkds60d


                                                                                                                                                                                        

2025-05-28 15:57:00,200 - Quant - INFO - Started Optimizing weight rounding of module: conv1


                                                                                                                                                                                        

2025-05-28 15:57:02,247 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv1


                                                                                                                                                                                        

2025-05-28 15:57:04,531 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv2


                                                                                                                                                                                        

2025-05-28 15:57:06,837 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv1


                                                                                                                                                                                        

2025-05-28 15:57:09,152 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv2


                                                                                                                                                                                        

2025-05-28 15:57:11,492 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv1


                                                                                                                                                                                        

2025-05-28 15:57:13,891 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv2


                                                                                                                                                                                        

2025-05-28 15:57:16,159 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.downsample.0


                                                                                                                                                                                        

2025-05-28 15:57:18,530 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv1


                                                                                                                                                                                        

2025-05-28 15:57:20,790 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv2


                                                                                                                                                                                        

2025-05-28 15:57:23,085 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv1


                                                                                                                                                                                        

2025-05-28 15:57:25,329 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv2


                                                                                                                                                                                        

2025-05-28 15:57:27,564 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.downsample.0


                                                                                                                                                                                        

2025-05-28 15:57:29,743 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv1


                                                                                                                                                                                        

2025-05-28 15:57:32,000 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv2


                                                                                                                                                                                        

2025-05-28 15:57:34,289 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv1


                                                                                                                                                                                        

2025-05-28 15:57:36,642 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv2


                                                                                                                                                                                        

2025-05-28 15:57:40,569 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.downsample.0


                                                                                                                                                                                        

2025-05-28 15:57:42,762 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv1


                                                                                                                                                                                        

2025-05-28 15:57:46,829 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv2


                                                                                                                                                                                        

2025-05-28 15:57:50,917 - Quant - INFO - Started Optimizing weight rounding of module: fc


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:52<00:00,  1.13it/s]

2025-05-28 15:57:53,096 - Quant - INFO - Completed Adarounding Model





In [13]:
print(f"Model first layer device: {next(adaround.parameters()).device}")
print(f"Dummy input device: {dummy_input.device}")

Model first layer device: cuda:0
Dummy input device: cuda:0


In [17]:
def forward_pass(model, _ =  None):
    for batch in calibration_data_loader():
        model(batch)

sim = QuantizationSimModel(
    model=adaround,
    dummy_input=dummy_input,
    quant_scheme=QuantScheme.post_training_tf_enhanced,
    default_output_bw=4,
    default_param_bw=4,
)
sim.compute_encodings(forward_pass)

2025-05-28 16:00:11,718 - Quant - INFO - Unsupported op type Squeeze
2025-05-28 16:00:11,718 - Quant - INFO - Unsupported op type Mean
2025-05-28 16:00:11,718 - Quant - INFO - Unsupported op type Unsqueeze
2025-05-28 16:00:11,719 - Quant - INFO - Unsupported op type Compress
2025-05-28 16:00:11,719 - Quant - INFO - Unsupported op type Identity
2025-05-28 16:00:11,719 - Quant - INFO - Unsupported op type Shape
2025-05-28 16:00:11,719 - Quant - INFO - Unsupported op type If
2025-05-28 16:00:11,720 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:None


# Evaluation

In [15]:
def evaluate_model(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

In [8]:
unquantized_model = copy.deepcopy(model)

unquantized_model.load_state_dict(torch.load(cifar_model_path))
unquantized_model = unquantized_model.eval().to(device)

In [16]:
print(f"Unquantized model accuracy(TRAIN): {evaluate_model(unquantized_model, train_loader, device) * 100:.2f}%")
print(f"Unquantized model accuracy(TEST): {evaluate_model(unquantized_model, test_loader, device) * 100:.2f}%")
print(f"Quantized model accuracy(TRAIN): {evaluate_model(sim.model, train_loader, device) * 100:.2f}%")
print(f"Quantized model accuracy(TEST): {evaluate_model(sim.model, test_loader, device) * 100:.2f}%")

Unquantized model accuracy(TRAIN): 92.40%
Unquantized model accuracy(TEST): 82.58%
Quantized model accuracy(TRAIN): 87.08%
Quantized model accuracy(TEST): 78.85%


In [49]:
print(f"Unquantized model accuracy(TRAIN): {evaluate_model(unquantized_model, train_loader, device) * 100:.2f}%")
print(f"Unquantized model accuracy(TEST): {evaluate_model(unquantized_model, test_loader, device) * 100:.2f}%")
print(f"Quantized model accuracy(TRAIN): {evaluate_model(sim.model, train_loader, device) * 100:.2f}%")
print(f"Quantized model accuracy(TEST): {evaluate_model(sim.model, test_loader, device) * 100:.2f}%")

Unquantized model accuracy(TRAIN): 92.46%
Unquantized model accuracy(TEST): 82.58%
Quantized model accuracy(TRAIN): 92.37%
Quantized model accuracy(TEST): 82.50%


In [36]:
print(f"Unquantized model accuracy(TRAIN): {evaluate_model(unquantized_model, train_loader, device) * 100:.2f}%")
print(f"Unquantized model accuracy(TEST): {evaluate_model(unquantized_model, test_loader, device) * 100:.2f}%")
print(f"Quantized model accuracy(TRAIN): {evaluate_model(sim.model, train_loader, device) * 100:.2f}%")
print(f"Quantized model accuracy(TEST): {evaluate_model(sim.model, test_loader, device) * 100:.2f}%")

Unquantized model accuracy(TRAIN): 92.46%
Unquantized model accuracy(TEST): 82.58%
Quantized model accuracy(TRAIN): 92.38%
Quantized model accuracy(TEST): 82.48%
