In [1]:
%load_ext autoreload
%autoreload 2
import anvil
import anvil.adaround

import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import copy
import torch.optim as optim
import os, pathlib

base_path = pathlib.Path("/home/mpuscian/Desktop/repozytoria/MINI_projects/anvil/models/")
model_path = base_path.joinpath("cifar_model2.pth")
adaround_model_path = base_path.joinpath("adaround_model.pth")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = resnet18(weights=None)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
model.fc = nn.Linear(512, 10)

adaround_model = copy.deepcopy(model)
adaround_model.load_state_dict(torch.load(model_path, map_location=device))

# Stwórz wejście testowe (np. batch 1 obrazka)
sample_input = torch.randn(1, 3, 32, 32)

In [49]:
def test():
    module = adaround_model.conv1

    captured_input = None

    def hook_fn(module, input, output):
        nonlocal captured_input
        captured_input = input[0].detach()

    hook = module.register_forward_hook(hook_fn)

    with torch.no_grad():
        _ = adaround_model(sample_input)

    hook.remove()

    if captured_input is None:
        raise RuntimeError(f"Nie udało się przechwycić wejścia do warstwy conv1")

    # Kwantyzacja wag
    quantized_weights = anvil.adaround.adaround_layer(module, captured_input)

    return quantized_weights, captured_input

In [50]:
qw, cap_inp = test()

final h_alpha: tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.

In [72]:
qw[0][0]

tensor([[-0.1177, -0.0938, -0.1177],
        [ 0.1894, -0.0598, -0.0787],
        [-0.1001, -0.1315,  0.1869]])

In [74]:
sc, zp = anvil.adaround.get_qparams(qw[0][0], qmin=-128, qmax=127)

In [76]:
torch.round(qw[0][0] / sc + zp).clamp(-128, 127).to(torch.int8)

tensor([[-117,  -98, -117],
        [ 127,  -70,  -86],
        [-103, -128,  126]], dtype=torch.int8)

# Layer quantization test

In [52]:
inputs = cap_inp
num_iterations=1000
beta_range=(20, 2)
reg_param=0.01
per_channel=True

layer = adaround_model.conv1

signed = False  # Można rozpoznać np. po typie aktywacji
qmin, qmax = (-128, 127) if signed else (0, 255)

weight = layer.weight.detach()
alpha = nn.Parameter(torch.zeros_like(weight))
scale_w, zp_w = anvil.adaround.get_qparams(weight, qmin, qmax, per_channel=True, channel_axis=0)

optimizer = torch.optim.Adam([alpha], lr=1e-2)
best_loss = float("inf")
best_alpha = alpha.data.clone()

# Kwantyzacja wejścia (per-tensor)
scale_in, zp_in = anvil.adaround.get_qparams(inputs, qmin, qmax, per_channel=False)
inputs_q = anvil.adaround.quantize_tensor(inputs, scale_in, zp_in, qmin, qmax)

for step in range(num_iterations):
    optimizer.zero_grad()

    # Kwantyzacja wag
    weight_q = anvil.adaround.adaround_weight(weight / scale_w + zp_w, alpha)
    weight_q = scale_w * (weight_q - zp_w)

    # Forward oryginalny vs. kwantyzowany
    out_fp = layer(inputs)
    out_q = F.conv2d(inputs_q, weight_q, bias=layer.bias, stride=layer.stride,
                        padding=layer.padding, dilation=layer.dilation, groups=layer.groups)

    loss_data = F.mse_loss(out_q, out_fp)

    beta = beta_range[0] * (1 - step / num_iterations) + beta_range[1] * (step / num_iterations)
    h_alpha = torch.clamp(torch.sigmoid(alpha) * 1.2 - 0.1, 0, 1) #according to up and down
    reg = torch.sum(1 - torch.abs(2 * h_alpha - 1) ** beta)

    loss = loss_data + reg_param * reg
    loss.backward()
    optimizer.step()

    if loss.item() < best_loss:
        best_loss = loss.item()
        best_alpha = alpha.data.clone()

# Finalizacja
h_alpha = torch.clamp(torch.sigmoid(best_alpha) * 1.2 - 0.1, 0, 1)
print(f"final h_alpha: {h_alpha}")
final_w_q = scale_w * (torch.floor(weight / scale_w + zp_w) + h_alpha - zp_w)

final h_alpha: tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.

In [68]:
inputs_q[0][0][0][0]

tensor(0.9440)

In [3]:
tensor = torch.Tensor([-456., 328., 555., 678., 155., 676.])
tensor

tensor([-456.,  328.,  555.,  678.,  155.,  676.])

In [7]:
s = (tensor.max() - tensor.min())/255

In [29]:
torch.floor(tensor/s).to(torch.int8)

tensor([-103,   73,  124, -104,   34, -104], dtype=torch.int8)