In [None]:
import torchvision.models as models
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Weight bit-width shared by the quantization pipeline. It is read by
# replace_to_quant_module (UniformQuantizer n_bits), pack_weights (values per
# int32 word and unsigned offset) and export_quantized_model (zero-point offset
# 2**(n_bits-1)).
n_bits = 4

In [None]:
def bn_folding(conv, bn):
    """Fold an eval-mode BatchNorm into the preceding convolution, in place.

    With s = gamma / sqrt(running_var + eps):
        W' = W * s   (per output channel)
        b' = (b - running_mean) * s + beta
    so that ``conv'(x) == bn(conv(x))`` when ``bn`` uses its running statistics.

    ``conv.weight`` is overwritten in place; ``conv.bias`` is overwritten, or
    created if the conv had none. A non-affine BN (``weight``/``bias`` is None)
    is treated as gamma = 1 and beta = 0.

    Args:
        conv: convolution (or any layer whose weight's dim 0 is the output channel).
        bn: batch-norm layer that follows ``conv``.

    Returns:
        ``nn.Identity()``, meant to replace ``bn`` in the parent module.
    """
    # Pure parameter surgery: no autograd graph should be recorded here.
    with torch.no_grad():
        mu = bn.running_mean
        var = bn.running_var
        gamma = bn.weight if bn.weight is not None else torch.ones_like(mu)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(mu)

        w = conv.weight
        b = conv.bias if conv.bias is not None else torch.zeros_like(mu)

        scale = gamma / torch.sqrt(var + bn.eps)
        # Broadcast the per-output-channel scale over all remaining weight dims.
        w_new = w * scale.view(-1, *([1] * (w.dim() - 1)))
        b_new = (b - mu) * scale + beta

        conv.weight.copy_(w_new)
        if conv.bias is not None:
            conv.bias.copy_(b_new)
        else:
            conv.bias = nn.Parameter(b_new)

    return nn.Identity()

In [None]:
def fuse_resnet_module(model):
    """Fold every BatchNorm of a ResNet-style model into its preceding conv, in place.

    Handles the stem (``conv1``/``bn1``), every block in ``layer1`` .. ``layer4``
    that exists on the model (CIFAR ResNets stop at ``layer3``), the block's
    ``conv{k}``/``bn{k}`` pairs for k = 1..3, and a ``downsample`` Sequential
    of the form ``(conv, bn, ...)``.

    Only ``nn.BatchNorm2d`` layers are folded. Anything already replaced by
    ``nn.Identity`` is skipped, so calling this twice is harmless.

    Returns:
        The same ``model`` (switched to eval mode, since folding uses running stats).
    """
    model.eval()

    stem_bn = getattr(model, 'bn1', None)
    if isinstance(stem_bn, nn.BatchNorm2d):
        model.bn1 = bn_folding(model.conv1, stem_bn)

    for layer_name in ('layer1', 'layer2', 'layer3', 'layer4'):
        layer = getattr(model, layer_name, None)
        if layer is None:
            continue
        for i, block in enumerate(layer):
            print(f'Fusing {layer_name}_{i}')

            for k in (1, 2, 3):
                bn = getattr(block, f'bn{k}', None)
                if isinstance(bn, nn.BatchNorm2d):
                    setattr(block, f'bn{k}', bn_folding(getattr(block, f'conv{k}'), bn))

            # downsample[0] is the 1x1 conv, downsample[1] its BatchNorm.
            downsample = getattr(block, 'downsample', None)
            if (isinstance(downsample, nn.Sequential) and len(downsample) >= 2
                    and isinstance(downsample[1], nn.BatchNorm2d)):
                downsample[1] = bn_folding(downsample[0], downsample[1])

    print(" Folding completed")
    return model

In [None]:
import torch.nn.functional as F

def round_ste(x):
    """Round to the nearest integer in the forward pass; pass gradients straight through.

    The rounding residual is detached, so d(output)/d(x) == 1 (straight-through estimator).
    """
    rounding_residual = x.round() - x
    return x + rounding_residual.detach()

class UniformQuantizer(nn.Module):
    """Uniform fake-quantizer on a signed integer grid, with STE rounding.

    Args:
        n_bits: bit-width; the integer grid is [-2**(n_bits-1), 2**(n_bits-1) - 1].
        symmetric: True -> zero_point = 0 and delta = max|x| / max_q.
            False -> the (zero-including) [min, max] range of x is mapped onto
            the full grid, with zero_point = min_q - round(min / delta).
        channel_wise: compute one (delta, zero_point) per slice along dim 0
            (e.g. per output channel of a weight). They are stored shaped
            (C, 1, ..., 1) to broadcast against x.
        is_weight: stored for reference only; not read by this class.

    ``delta``/``zero_point`` are initialised from the first tensor passed to
    ``forward``, or explicitly via ``init_quantization_params``. Quantization is
    q = clamp(round(x / delta) + zero_point, min_q, max_q); output = delta * (q - zero_point).
    """

    # Lower bound for delta, so an all-zero tensor/channel cannot give delta == 0 (x / 0 -> NaN).
    _MIN_DELTA = 1e-8

    def __init__(self, n_bits=4, symmetric=True, channel_wise=False, is_weight=False):
        super().__init__()

        self.n_bits = n_bits
        self.symmetric = symmetric
        self.channel_wise = channel_wise
        self.is_weight = is_weight  # weight or activation (informational)

        # Grid bounds and quantization params are buffers so they follow .to(device)
        # and are stored in the state_dict once set.
        self.register_buffer('max_q', torch.tensor(2 ** (n_bits - 1) - 1))
        self.register_buffer('min_q', torch.tensor(-2 ** (n_bits - 1)))
        self.register_buffer('delta', None)
        self.register_buffer('zero_point', None)

    def init_quantization_params(self, x):
        """Set ``delta`` and ``zero_point`` from the range of ``x``."""
        x_clone = x.clone().detach()

        # Channel-wise is used for weights, whose ranges differ a lot per output channel.
        if self.channel_wise:
            x_flat = x_clone.reshape(x_clone.shape[0], -1)

        if self.symmetric:
            if self.channel_wise:
                max_w = x_flat.abs().max(1)[0]
            else:
                max_w = torch.max(torch.abs(x_clone))
            delta = (max_w / self.max_q).clamp(min=self._MIN_DELTA).to(x.device)
            zero_point = torch.zeros_like(delta)
        else:
            if self.channel_wise:
                x_min, x_max = x_flat.min(1)[0], x_flat.max(1)[0]
            else:
                x_min, x_max = x_clone.min(), x_clone.max()
            # Include 0 in the range so that 0.0 is exactly representable.
            x_min = torch.clamp(x_min, max=0.0)
            x_max = torch.clamp(x_max, min=0.0)
            delta = ((x_max - x_min) / (self.max_q - self.min_q)).clamp(min=self._MIN_DELTA).to(x.device)
            zero_point = self.min_q - torch.round(x_min.to(x.device) / delta)

        if self.channel_wise:
            # (C,) -> (C, 1, ..., 1) so the params broadcast over x of any rank.
            per_channel_shape = (-1,) + (1,) * (x.dim() - 1)
            delta = delta.view(per_channel_shape)
            zero_point = zero_point.view(per_channel_shape)

        self.delta = delta
        self.zero_point = zero_point

    def forward(self, x):
        if self.delta is None:
            self.init_quantization_params(x)

        x_int = round_ste(x / self.delta) + self.zero_point
        x_quant = torch.clamp(x_int, self.min_q, self.max_q)

        return self.delta * (x_quant - self.zero_point)


In [None]:
class AdaRoundQuantizer(nn.Module):
    """AdaRound learnable rounding on top of a calibrated ``UniformQuantizer``.

    Each weight is floor(w / delta) + h, with h = clamp(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1)
    (the rectified sigmoid) while ``soft_targets`` is True. When it is False, h
    is the hard value 1[h(alpha) >= 0.5] = 1[alpha >= 0]. The integer code is
    clamped and dequantized with the same convention as ``UniformQuantizer``:
    delta * (clamp(w_int + zero_point, min_q, max_q) - zero_point).

    Args:
        uq: calibrated quantizer providing ``delta``, ``zero_point``, ``min_q``, ``max_q``.
        x: weight tensor used to initialise ``alpha`` so that h(alpha) equals the
            fractional part of x / delta (clipped to [0.01, 0.99]).
    """

    def __init__(self, uq: 'UniformQuantizer', x):
        super().__init__()

        self.uq = uq
        self.alpha = None  # AdaRound parameter (V in the paper)
        self.soft_targets = True  # True: soft (learnable) rounding; False: hard rounding

        # Fixed stretch constants of the rectified sigmoid, stored as buffers.
        self.register_buffer('gamma', torch.tensor(-0.1))
        self.register_buffer('zeta', torch.tensor(1.1))

        self.init_alpha(x.detach().clone())

    def init_alpha(self, x):
        """Initialise ``alpha`` by inverting the rectified sigmoid at the fractional part of x / delta."""
        w_floor = torch.floor(x / self.uq.delta)

        rest = (x / self.uq.delta) - w_floor
        rest = torch.clamp(rest, 0.01, 0.99)

        sig_inv = (rest - self.gamma) / (self.zeta - self.gamma)
        sig_inv = torch.clamp(sig_inv, 0.01, 0.99)

        alpha = torch.log(sig_inv / (1 - sig_inv))
        self.alpha = nn.Parameter(alpha)

    def rectified_sigmoid(self):
        """h(alpha) = clamp(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1)  (AdaRound Eq. 23)."""
        return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1)

    def forward(self, x):
        w_floor = torch.floor(x / self.uq.delta)

        # Soft mode: rounding offset in [0, 1]; hard mode: 0/1.
        if self.soft_targets:
            rounding_val = self.rectified_sigmoid()
        else:
            # h(alpha) >= 0.5  <=>  sigmoid(alpha) >= 0.5  <=>  alpha >= 0 (alpha is a logit).
            rounding_val = (self.alpha >= 0).float()

        w_int = w_floor + rounding_val
        zero_point = self.uq.zero_point
        w_code = torch.clamp(w_int + zero_point, self.uq.min_q, self.uq.max_q)

        return self.uq.delta * (w_code - zero_point)

In [None]:
class QuantModule(nn.Module):
    """Wraps an ``nn.Conv2d`` so its forward pass uses fake-quantized weights.

    Args:
        org_module: convolution whose weight, bias and hyper-parameters are reused.
        weight_quantizer: module mapping the FP weight to its quantized (dequantized)
            value, e.g. an ``AdaRoundQuantizer``.

    ``use_quantization`` switches between the quantized path and the original conv.

    Raises:
        NotImplementedError: if ``org_module.padding_mode`` is not ``'zeros'``
            (``F.conv2d`` below only implements zero padding).
    """

    def __init__(self, org_module: nn.Conv2d, weight_quantizer):
        super().__init__()
        padding_mode = getattr(org_module, 'padding_mode', 'zeros')
        if padding_mode != 'zeros':
            raise NotImplementedError(
                f"QuantModule only supports padding_mode='zeros', got {padding_mode!r}")

        self.org_module = org_module
        self.weight_quantizer = weight_quantizer
        self.use_quantization = True  # quantization on/off switch
        # Static conv hyper-parameters. The bias is deliberately NOT cached here:
        # it is read from org_module on every call, so a bias that is added or
        # replaced after wrapping (e.g. ``conv.bias = nn.Parameter(...)``) is used.
        self.conv_type = dict(stride=org_module.stride, padding=org_module.padding,
                              dilation=org_module.dilation, groups=org_module.groups)

    def forward(self, x):
        if not self.use_quantization:
            return self.org_module(x)

        # Quantize the weight (AdaRound acts here) and convolve with it.
        w_q = self.weight_quantizer(self.org_module.weight)
        return F.conv2d(x, w_q, self.org_module.bias, **self.conv_type)

In [None]:
def replace_to_quant_module(model, skip_first=True, bits=None):
    """Recursively replace every ``nn.Conv2d`` in ``model`` with a ``QuantModule``.

    Each conv gets a per-output-channel symmetric ``UniformQuantizer``, which is
    calibrated on the conv weight and wrapped in an ``AdaRoundQuantizer``.

    Args:
        model: module modified in place.
        skip_first: leave the top-level ``conv1``/``bn1`` (the stem) unquantized.
            Only applies at the top level; recursive calls pass False.
        bits: weight bit-width; defaults to the module-level ``n_bits``.

    Existing ``QuantModule``s are left alone, so re-running this does not wrap
    an already-wrapped conv a second time.
    """
    if bits is None:
        bits = n_bits

    # Snapshot the children: setattr below replaces entries while we iterate.
    for name, module in list(model.named_children()):

        if skip_first and name in ["conv1", "bn1"]:
            continue

        if isinstance(module, QuantModule):
            continue

        if isinstance(module, nn.Conv2d):
            # Build the UniformQuantizer, calibrate it on the weight once, then wrap it with AdaRound.
            uq = UniformQuantizer(n_bits=bits, symmetric=True, channel_wise=True, is_weight=True)
            uq.init_quantization_params(module.weight)
            ada_quantizer = AdaRoundQuantizer(uq, module.weight)

            setattr(model, name, QuantModule(module, ada_quantizer))

        elif next(module.children(), None) is not None:
            # Container module: recurse into its children.
            replace_to_quant_module(module, skip_first=False, bits=bits)

In [None]:
class EfficientDataSaver:
    """Pre-allocated CPU buffers filled by forward/backward hooks on one block.

    ``hook_fn`` (forward hook) records the block's input and output.
    ``hook_backward`` (full backward hook) records the gradient w.r.t. the block
    output. Rows are written in call order; a batch that only partially fits
    contributes its leading rows, and once a buffer is full further data is
    ignored. ``idx`` / ``backward_idx`` count the rows written so far.
    """

    def __init__(self, total_samples, in_c, in_h, in_w, out_c, out_h, out_w):
        self.inputs = torch.zeros((total_samples, in_c, in_h, in_w), dtype=torch.float32)
        self.outputs = torch.zeros((total_samples, out_c, out_h, out_w), dtype=torch.float32)
        self.grads = torch.zeros((total_samples, out_c, out_h, out_w), dtype=torch.float32)
        self.idx = 0
        self.backward_idx = 0

    @staticmethod
    def _rows_to_take(start, batch_size, capacity):
        """Number of rows of a batch that still fit into a buffer (0 once it is full)."""
        return max(0, min(batch_size, capacity - start))

    def hook_fn(self, module, input, output):
        n = self._rows_to_take(self.idx, input[0].shape[0], self.inputs.shape[0])
        if n == 0:
            return None

        # Store the head of a partially fitting batch instead of dropping it, so
        # the buffer is fully populated even when total_samples is not a multiple
        # of the batch size (unfilled rows would otherwise stay zero).
        end = self.idx + n
        self.inputs[self.idx:end] = input[0][:n].detach().cpu()
        self.outputs[self.idx:end] = output[:n].detach().cpu()
        self.idx = end
        return None

    def hook_backward(self, module, grad_input, grad_output):
        if grad_output[0] is None:
            return None

        n = self._rows_to_take(self.backward_idx, grad_output[0].shape[0], self.grads.shape[0])
        if n > 0:
            end = self.backward_idx + n
            self.grads[self.backward_idx:end] = grad_output[0][:n].detach().cpu()
            self.backward_idx = end

        # Returning None leaves the gradients flowing through the block unchanged.
        return None

In [None]:
from torch.optim import Adam
import numpy as np

def block_reconstruction(block, cali_inputs, cali_outputs, cali_grads, quantizers, batch_size=64, iters=1000,
                         lr=1e-3, reg_weight=1e-4, b_start=20, b_end=2):
    """Optimise the AdaRound ``alpha`` of ``quantizers`` so ``block`` reproduces its FP outputs.

    Per mini-batch:
        rec   = sum(((block(x) - y) * g) ** 2) / batch      (Fisher-weighted output error)
        reg   = sum_q sum(1 - |2 h(alpha_q) - 1| ** beta)      (pushes h towards 0 or 1)
        total = rec + reg_weight * reg
    ``beta`` is annealed linearly from ``b_start`` to ``b_end`` across iterations.
    Each iteration is one full (shuffled) pass over the calibration set.

    Args:
        block: module whose forward uses the quantizers (in soft mode while training).
        cali_inputs / cali_outputs / cali_grads: block inputs, FP target outputs and
            output gradients, aligned along dim 0.
        quantizers: objects exposing ``alpha``, ``soft_targets`` and ``rectified_sigmoid()``.
        batch_size, iters, lr, reg_weight, b_start, b_end: optimisation settings.

    Side effects: ``soft_targets`` is True on every quantizer during training and
    False (hard rounding) afterwards.
    """
    # Collect the trainable AdaRound parameters and switch to soft rounding.
    params = []
    for q in quantizers:
        params.append(q.alpha)
        q.soft_targets = True

    if not params:
        # Nothing to optimise (no quantized layers); Adam would reject an empty list.
        return

    # Mini-batches are moved to wherever the optimised parameters live.
    train_device = params[0].device
    optimizer = Adam(params, lr=lr)

    # Calibration tensors stay on the CPU; only the current mini-batch is moved.
    dataset = TensorDataset(cali_inputs.cpu(), cali_outputs.cpu(), cali_grads.cpu())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    beta_scheduler = np.linspace(b_start, b_end, iters)

    total_loss = loss_rec = None
    for i in range(iters):
        beta = beta_scheduler[i]

        for x_batch, y_batch, g_batch in loader:
            cur_input = x_batch.to(train_device)   # block input
            cur_target = y_batch.to(train_device)  # FP output (target)
            cur_grad = g_batch.to(train_device)    # output gradient (Fisher weight)

            optimizer.zero_grad()

            out_quant = block(cur_input)

            # Normalise by the samples actually in this batch; the last batch may be smaller.
            delta = out_quant - cur_target
            loss_rec = (delta * cur_grad).pow(2).sum() / x_batch.shape[0]

            loss_reg = 0
            for q in quantizers:
                soft_val = q.rectified_sigmoid()
                loss_reg += (1.0 - (2 * soft_val - 1).abs().pow(beta)).sum()

            total_loss = loss_rec + reg_weight * loss_reg

            total_loss.backward()
            optimizer.step()

        if i % 200 == 0 and total_loss is not None:
            print(f"Iter {i}: Total {total_loss.item():.4f} (Rec {loss_rec.item():.10f})")

    # Training finished: switch to hard rounding.
    for q in quantizers:
        q.soft_targets = False

In [None]:
def get_quantizers_from_block(block):
    """Return the ``weight_quantizer`` of every QuantModule inside ``block`` (``block`` included).

    Order follows ``block.modules()`` (depth-first, registration order).
    """
    return [m.weight_quantizer for m in block.modules() if isinstance(m, QuantModule)]


def set_quant_state(module, use_quant=True):
    """Set ``use_quantization`` on every QuantModule inside ``module`` (``module`` included)."""
    quant_modules = (m for m in module.modules() if isinstance(m, QuantModule))
    for qm in quant_modules:
        qm.use_quantization = use_quant


def run_brecq(model, dataloader, target_block, num_samples=128, iters=1000):
    """Optimise the rounding of one block with BRECQ-style block reconstruction.

    1. Probe the block's input/output shapes with a single image.
    2. With the block's own quantization switched off, run calibration batches
       forward/backward through the whole model (cross-entropy loss). Record the
       block's inputs, FP outputs (targets) and output gradients (Fisher weights)
       for up to ``num_samples`` samples. Other blocks keep their current
       quantization state, so earlier, already-optimised blocks feed quantized inputs.
    3. Re-enable quantization and optimise the AdaRound parameters with
       ``block_reconstruction``.

    Args:
        model: full network; put in eval mode and moved to the global ``device``.
        dataloader: yields ``(images, labels)`` batches.
        target_block: submodule of ``model`` to optimise.
        num_samples: calibration samples to record.
        iters: forwarded to ``block_reconstruction``.

    Raises:
        RuntimeError: if no calibration sample was recorded.
    """
    model.eval().to(device)

    sample_img, _ = next(iter(dataloader))
    sample_img = sample_img[:1].to(device)  # a single image is enough for shapes

    block_shape = {}
    def get_shape_hook(module, input, output):
        # input[0] is the input tensor, output is the output tensor.
        block_shape['in'] = input[0].shape[1:]  # (C, H, W)
        block_shape['out'] = output.shape[1:]   # (C, H, W)

    handle = target_block.register_forward_hook(get_shape_hook)
    try:
        with torch.no_grad():
            model(sample_img)
    finally:
        handle.remove()

    in_c, in_h, in_w = block_shape['in']
    out_c, out_h, out_w = block_shape['out']

    # Input and output shapes can differ (stride, channel change), so buffers are sized separately.
    saver = EfficientDataSaver(num_samples, in_c, in_h, in_w, out_c, out_h, out_w)

    set_quant_state(target_block, use_quant=False)

    # forward hook: inputs and target outputs; backward hook: output gradients.
    h1 = target_block.register_forward_hook(saver.hook_fn)
    h2 = target_block.register_full_backward_hook(saver.hook_backward)

    criterion = nn.CrossEntropyLoss()

    try:
        print("Step 1: Collecting Data & Gradients...")
        current_samples = 0
        for imgs, labels in dataloader:
            if current_samples >= num_samples:
                break
            imgs, labels = imgs.to(device), labels.to(device)
            imgs.requires_grad_(True)

            model.zero_grad()
            loss = criterion(model(imgs), labels)
            loss.backward()

            current_samples += imgs.shape[0]
    finally:
        # Always detach the hooks and restore quantization, even if collection fails.
        h1.remove()
        h2.remove()
        set_quant_state(target_block, use_quant=True)

    # Parameter gradients from the collection pass are no longer needed.
    model.zero_grad(set_to_none=True)

    # Only rows written by both hooks are valid; any remaining rows are zero padding.
    n_valid = min(saver.idx, saver.backward_idx)
    if n_valid == 0:
        raise RuntimeError("No calibration samples were collected for the target block.")

    quantizers = get_quantizers_from_block(target_block)

    print("Step 2: Optimizing Block...")
    block_reconstruction(
        target_block,
        saver.inputs[:n_valid], saver.outputs[:n_valid], saver.grads[:n_valid],
        quantizers,
        iters=iters
    )

    print("Done!")

In [None]:
def validate_model(model, test_loader, device):
    """Top-1 accuracy (in percent) of ``model`` over ``test_loader``.

    Puts the model in eval mode and runs without gradient tracking. Each batch
    is ``(images, labels)``; the prediction is the arg-max over dim 1 of the output.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predictions = model(batch_images).max(dim=1).indices
            n_seen += batch_labels.shape[0]
            n_correct += (predictions == batch_labels).sum().item()

    return 100 * n_correct / n_seen

In [None]:
def optimize_full_model(model, calib_loader, **brecq_kwargs):
    """Run BRECQ block by block over a ResNet-style model, in forward order.

    Targets are the stem ``conv1`` if it was wrapped as a QuantModule, then every
    block of ``layer1`` .. ``layer4`` that exists on the model (CIFAR ResNets stop
    at ``layer3``). Extra keyword arguments are forwarded to ``run_brecq``.
    """
    target_blocks = []

    if isinstance(getattr(model, 'conv1', None), QuantModule):
        target_blocks.append(model.conv1)

    for layer_name in ('layer1', 'layer2', 'layer3', 'layer4'):
        layer = getattr(model, layer_name, None)
        if layer is not None:
            target_blocks.extend(layer)

    print(f"Total blocks to optimize: {len(target_blocks)}")

    # Optimise sequentially so later blocks see inputs from already-quantized ones.
    for i, block in enumerate(target_blocks):
        print(f"\n>>> [Step {i+1}/{len(target_blocks)}] Optimizing Block: {block.__class__.__name__}")
        run_brecq(model, calib_loader, block, **brecq_kwargs)

In [None]:
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# CIFAR-100 per-channel statistics. The previous values were CIFAR-10 statistics
# (with the widely copied 0.2023/0.1994/0.2010 std), which do not match a CIFAR-100
# model and depress the FP32 baseline.
# NOTE(review): confirm against the chenyaofo/pytorch-cifar-models training transform.
CIFAR100_MEAN = (0.5070, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# 1. Datasets: the first 1024 training images for calibration, the full test split for evaluation.
calib_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
calib_loader = DataLoader(torch.utils.data.Subset(calib_dataset, range(1024)), batch_size=32, shuffle=False)

test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# 2. FP32 baseline accuracy.
base_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet44", pretrained=True).to(device)
base_model.eval()
print("Calculating Baseline Accuracy...")
fp32_acc = validate_model(base_model, test_loader, device)
print(f"Baseline (FP32) Accuracy: {fp32_acc:.2f}%")

# 3. Prepare the quantized model (BN folding & module replacement).
#    It is built as a separate object so that base_model is not overwritten.
def disable_inplace_relu(model):
    """Set ``inplace=False`` on every ``nn.ReLU`` inside ``model``.

    The pipeline runs this before optimization. Presumably in-place ReLUs conflict
    with the forward/backward hooks used to collect block data — NOTE(review): confirm.
    """
    for relu in filter(lambda m: isinstance(m, nn.ReLU), model.modules()):
        relu.inplace = False

import copy

# Build the model to quantize as a deep copy of the FP32 baseline. BN folding
# and module replacement mutate the model in place, and copying avoids a second
# torch.hub load of the same checkpoint.
q_model = copy.deepcopy(base_model)
disable_inplace_relu(q_model)  # must run before optimization
q_model = fuse_resnet_module(q_model)
replace_to_quant_module(q_model)
# Required: the quantizer buffers (min_q/max_q, gamma/zeta) are created on the CPU.
q_model.to(device)

# Check which device the model's parameters live on.
print(f"Model device: {next(q_model.parameters()).device}")

# 4. BRECQ over the whole model. block_reconstruction shuffles the calibration
#    data, so seed the RNG to make the result reproducible.
SEED = 0
torch.manual_seed(SEED)
optimize_full_model(q_model, calib_loader)

# 5. Quantized model accuracy.
print("\nCalculating Quantized Model Accuracy...")
q_acc = validate_model(q_model, test_loader, device)

# 6. Final comparison.
print("\n" + "="*30)
print(f"FP32 Accuracy: {fp32_acc:.2f}%")
print(f"W{n_bits}A32 BRECQ Accuracy: {q_acc:.2f}%")
print(f"Accuracy Drop: {fp32_acc - q_acc:.2f}%")
print("="*30)

100%|██████████| 169M/169M [00:09<00:00, 17.8MB/s]


Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/zipball/master" to /root/.cache/torch/hub/master.zip
Downloading: "https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet44-ffe32858.pt" to /root/.cache/torch/hub/checkpoints/cifar100_resnet44-ffe32858.pt


100%|██████████| 2.64M/2.64M [00:00<00:00, 75.0MB/s]


Calculating Baseline Accuracy...
Baseline (FP32) Accuracy: 67.79%


Using cache found in /root/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


Fusing layer1_0
Fusing layer1_1
Fusing layer1_2
Fusing layer1_3
Fusing layer1_4
Fusing layer1_5
Fusing layer1_6
Fusing layer2_0
Fusing layer2_1
Fusing layer2_2
Fusing layer2_3
Fusing layer2_4
Fusing layer2_5
Fusing layer2_6
Fusing layer3_0
Fusing layer3_1
Fusing layer3_2
Fusing layer3_3
Fusing layer3_4
Fusing layer3_5
Fusing layer3_6
 Folding completed
Model device: cuda:0
Total blocks to optimize: 21

>>> [Step 1/21] Optimizing Block: BasicBlock
Step 1: Collecting Data & Gradients...
Step 2: Optimizing Block...
Iter 0: Total 0.4385 (Rec 0.0000000004)
Iter 200: Total 0.3302 (Rec 0.0000001768)
Iter 400: Total 0.2061 (Rec 0.0000004140)
Iter 600: Total 0.1614 (Rec 0.0000006563)
Iter 800: Total 0.1050 (Rec 0.0000011631)
Done!

>>> [Step 2/21] Optimizing Block: BasicBlock
Step 1: Collecting Data & Gradients...
Step 2: Optimizing Block...
Iter 0: Total 0.4384 (Rec 0.0000000001)
Iter 200: Total 0.3359 (Rec 0.0000001498)
Iter 400: Total 0.2089 (Rec 0.0000001890)
Iter 600: Total 0.1608 (Rec 0.0

In [None]:
import torch
import torch.nn.functional as F

def pack_weights(w_int8, bits=None):
    """Pack signed low-bit integer weights into int32 words.

    Args:
        w_int8: 2-D tensor ``[N, K]`` (e.g. ``[Out_padded, K]``) of integral values in
            ``[-2**(bits-1), 2**(bits-1) - 1]``; any dtype.
        bits: bit-width per value; defaults to the module-level ``n_bits``.

    Returns:
        int32 tensor ``[K, N // (32 // bits)]`` on the input's device. The input is
        transposed and offset by ``2**(bits-1)`` (made unsigned). Value ``i`` of each
        group of ``32 // bits`` consecutive columns (output channels) is stored at
        bit offset ``i * bits``.

    Raises:
        ValueError: if N is not a multiple of ``32 // bits``, or a value does not
            fit in ``bits`` bits after the offset (it would corrupt its neighbours).
    """
    if bits is None:
        bits = n_bits
    offset = 2 ** (bits - 1)
    group = 32 // bits

    w_uint = (w_int8.t().contiguous() + offset).to(torch.int32)
    K, N = w_uint.shape

    if N % group != 0:
        raise ValueError(f"pack_weights: N={N} is not a multiple of {group} values per int32 word")
    if w_uint.numel() and (int(w_uint.min()) < 0 or int(w_uint.max()) > 2 ** bits - 1):
        raise ValueError(f"pack_weights: values out of range for {bits}-bit packing")

    w_grouped = w_uint.view(K, N // group, group)
    w_packed = torch.zeros((K, N // group), dtype=torch.int32, device=w_int8.device)

    for i in range(group):
        w_packed |= w_grouped[:, :, i] << (i * bits)

    return w_packed

def quantization(w, quantizer):
    """Hard-round ``w`` with a trained AdaRound quantizer and return its integer codes.

    Uses the same convention as ``UniformQuantizer.forward``:
        q = clamp(floor(w / delta) + 1[alpha >= 0] + zero_point, min_q, max_q)
    so the dequantized weight is ``delta * (q - zero_point)``. ``alpha >= 0`` is
    the same as rectified_sigmoid(alpha) >= 0.5.

    Returns:
        A tensor shaped like ``w`` holding integral values in [min_q, max_q],
        detached from autograd.
    """
    uq = quantizer.uq
    w = w.detach()

    w_floor = torch.floor(w / uq.delta)
    rounding = (quantizer.alpha.detach() >= 0).to(w.dtype)
    w_int = w_floor + rounding

    return torch.clamp(w_int + uq.zero_point, uq.min_q, uq.max_q)

def padding_params(w, scale, zp, bias, BN=128):
    """Pad the output-channel dimension up to a multiple of ``BN``.

    Args:
        w: 2-D weight ``[Out, K]``; padded rows are zeros.
        scale, zp, bias: 1-D per-output-channel tensors ``[Out]``; padded with
            1.0 (scale), 0.0 (zero point) and 0.0 (bias).
        BN: block size the output-channel count is rounded up to.

    Returns:
        ``(w, scale, zp, bias)``, each with ``Out_padded`` entries along dim 0.
        The inputs are returned unchanged when ``Out`` is already a multiple of ``BN``.
    """
    out_ch = w.shape[0]
    pad_n = (BN - out_ch % BN) % BN

    if pad_n > 0:
        # F.pad's tuple runs from the last dim backwards: (K_left, K_right, Out_top, Out_bottom).
        w = F.pad(w, (0, 0, 0, pad_n), value=0.0)
        scale = F.pad(scale, (0, pad_n), value=1.0)
        zp = F.pad(zp, (0, pad_n), value=0.0)
        bias = F.pad(bias, (0, pad_n), value=0.0)

    return w, scale, zp, bias



def export_quantized_model(model, save_path="resnet44_w4a32.pt"):
    """Export every QuantModule's hard-rounded weights in packed low-bit form.

    For each QuantModule ``name`` the saved dict holds:
        ``{name}.w_packed``   int32 ``[K, Out_padded / (32 / n_bits)]`` (see ``pack_weights``)
        ``{name}.scale``      ``[Out_padded]`` per-channel delta (padding = 1.0)
        ``{name}.zero_point`` ``[Out_padded]`` zero point + 2**(n_bits-1), matching the
                              unsigned offset applied by ``pack_weights``
        ``{name}.bias``       ``[Out_padded]`` conv bias (zeros if the conv had none)
    where K = in_channels * kH * kW and Out is padded by ``padding_params``.
    Only the output-channel dimension is padded, despite the log message.
    NOTE(review): the consumer presumably dequantizes as scale * (w_uint - zero_point); confirm.

    Modules that are not QuantModule (e.g. the unquantized stem) are not exported.
    """
    print("Exporting Quantized Model with Dual Padding (K & N)...")
    quantized_state_dict = {}

    # Export is pure tensor bookkeeping; keep autograd out of the saved tensors.
    with torch.no_grad():
        for name, module in model.named_modules():
            if not isinstance(module, QuantModule):
                continue
            print(f"Processing: {name}")

            # 1. Quantizer and FP weight (Out, In, k, k).
            quantizer = module.weight_quantizer
            w = module.org_module.weight.detach()

            # 2. Integer codes flattened to [Out, K]; per-channel params to [Out].
            w_int = quantization(w, quantizer)
            out_ch = w_int.shape[0]
            w_flat = w_int.reshape(out_ch, -1)

            scale = quantizer.uq.delta.detach().cpu().reshape(-1)
            zp = quantizer.uq.zero_point.detach().cpu().reshape(-1)
            # Per-tensor quantizers carry a single (delta, zero_point); broadcast them
            # so every output channel has an entry.
            if scale.numel() == 1:
                scale = scale.expand(out_ch).clone()
            if zp.numel() == 1:
                zp = zp.expand(out_ch).clone()

            bias = module.org_module.bias
            bias = bias.detach().cpu() if bias is not None else torch.zeros_like(scale)

            # 3. Pad Out, pack along Out, and shift the zero point into the unsigned domain.
            w_padded, scale, zp, bias = padding_params(w_flat, scale, zp, bias)

            w_packed = pack_weights(w_padded).contiguous()  # [K, Out_padded / (32 / n_bits)]
            zp = zp + 2 ** (n_bits - 1)

            quantized_state_dict[f"{name}.w_packed"] = w_packed.cpu()
            quantized_state_dict[f"{name}.scale"] = scale
            quantized_state_dict[f"{name}.zero_point"] = zp
            quantized_state_dict[f"{name}.bias"] = bias

    torch.save(quantized_state_dict, save_path)
    print(f"Saved to {save_path}")

In [None]:
# Colab-only: mount Google Drive at /content/drive so the exported model can be
# written under /content/drive/MyDrive (see save_dir below).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# 1. Destination folder and file name for the exported model.
#    Example: to save as 'resnet_quantized.pt' inside a 'Projects' folder of
#    My Drive, point save_dir at that folder and change filename.
save_dir = '/content/drive/MyDrive/Colab Notebooks/source'
filename = 'resnet44_w4a32.pt'

# Create the folder up front: saving into a missing directory raises an error.
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, filename)

In [None]:
# Write the packed weights of the BRECQ-optimised model to save_path.
export_quantized_model(q_model, save_path=save_path)

Exporting Quantized Model with Dual Padding (K & N)...
Processing: layer1.0.conv1
Processing: layer1.0.conv2
Processing: layer1.1.conv1
Processing: layer1.1.conv2
Processing: layer1.2.conv1
Processing: layer1.2.conv2
Processing: layer1.3.conv1
Processing: layer1.3.conv2
Processing: layer1.4.conv1
Processing: layer1.4.conv2
Processing: layer1.5.conv1
Processing: layer1.5.conv2
Processing: layer1.6.conv1
Processing: layer1.6.conv2
Processing: layer2.0.conv1
Processing: layer2.0.conv2
Processing: layer2.0.downsample.0
Processing: layer2.1.conv1
Processing: layer2.1.conv2
Processing: layer2.2.conv1
Processing: layer2.2.conv2
Processing: layer2.3.conv1
Processing: layer2.3.conv2
Processing: layer2.4.conv1
Processing: layer2.4.conv2
Processing: layer2.5.conv1
Processing: layer2.5.conv2
Processing: layer2.6.conv1
Processing: layer2.6.conv2
Processing: layer3.0.conv1
Processing: layer3.0.conv2
Processing: layer3.0.downsample.0
Processing: layer3.1.conv1
Processing: layer3.1.conv2
Processing: l