In [1]:
import importlib
import utils

# Reload utils module
importlib.reload(utils)

# import 
from utils import *

In [2]:
# Float32 baseline model.
# map_location=device lets a checkpoint that was saved on a GPU load on a
# CPU-only machine (and vice versa); without it torch.load tries to restore
# tensors onto the device they were saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (on torch >= 1.13 consider weights_only=True).
model = ResNet18().to(device)
model.load_state_dict(torch.load("resnet18_float32.pth", map_location=device))

<All keys matched successfully>

In [None]:
class STEQuant(torch.autograd.Function):
    """Signed 8-bit fake quantization with LSQ (learned step size) gradients.

    Forward: ``clamp(round(x / scale), -128, 127) * scale``.

    Backward, with ``v = x / scale``:
      * ``x``: straight-through estimator where ``v`` lies inside
        ``[-128, 127]``; zero gradient where ``v`` was clipped.
      * ``scale``: the LSQ step-size gradient ``round(v) - v`` inside the
        range, the clip bound (``-128`` / ``127``) outside it, summed over all
        elements and multiplied by ``1 / sqrt(x.numel() * 127)`` so its
        magnitude stays comparable to the weight gradients.

    ``x`` and ``scale`` must be tensors; ``scale`` must hold a single element
    (per-tensor quantization).
    """

    QN = -128.0
    QP = 127.0

    @staticmethod
    def forward(ctx, x, scale):
        # Inputs are saved (not x / scale) so backward can recompute v cheaply.
        ctx.save_for_backward(x, scale)
        q = torch.clamp(torch.round(x / scale), STEQuant.QN, STEQuant.QP)
        return q * scale

    @staticmethod
    def backward(ctx, grad_output):
        x, scale = ctx.saved_tensors
        v = x / scale
        below = v < STEQuant.QN
        above = v > STEQuant.QP
        inside = ~(below | above)

        # STE: pass the gradient through only where the value was not clipped.
        grad_x = grad_output * inside.to(grad_output.dtype)

        # Per-element d(q * scale) / d(scale).
        dscale = torch.where(
            below,
            torch.full_like(v, STEQuant.QN),
            torch.where(above, torch.full_like(v, STEQuant.QP), torch.round(v) - v),
        )
        grad_factor = (x.numel() * STEQuant.QP) ** -0.5
        grad_scale = (grad_output * dscale).sum() * grad_factor
        # Match scale's shape, dtype and device (scale may be a 0-d tensor).
        return grad_x, grad_scale.reshape(scale.shape).to(scale)
    
class LSQFakeQuant(nn.Module):
    """LSQ (learned step size) fake quantizer for activations.

    ``scale`` is a learnable step size. It starts at a dummy value and is
    replaced, on the first forward pass in training mode, by the LSQ
    heuristic ``2 * mean(|x|) / sqrt(Qp)`` computed from that batch.

    Once initialized, ``x`` is fake-quantized with ``STEQuant`` in both
    training and eval mode, so evaluation measures the quantized network.
    In eval mode before initialization the dummy scale is meaningless, so
    ``x`` is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.1))  # learnable, dummy value
        # Persistent buffer: saved in state_dict, so a reloaded quantizer
        # keeps its calibrated scale instead of re-initializing.
        self.register_buffer('initialized', torch.tensor(False))

    def forward(self, x):
        if not self.initialized:
            if not self.training:
                return x
            self._init_scale(x)
        return STEQuant.apply(x, self.scale)

    @torch.no_grad()
    def _init_scale(self, x):
        """Set ``scale`` from the batch ``x`` and mark the quantizer initialized."""
        Qp = 127.0
        init_scale = 2 * x.abs().mean() / (Qp ** 0.5)
        # An all-zero batch would give scale 0, and x / 0 produces NaN.
        self.scale.copy_(init_scale.clamp_min(1e-8))
        self.initialized.fill_(True)
    
# LSQ QAT Model
class LSQQATModel(nn.Module):
    """ResNet-style network with LSQ fake-quantizers on its main activations.

    One ``LSQFakeQuant`` follows the stem (conv1 -> bn1 -> relu) and one
    follows each of ``layer1``..``layer4``; ``avgpool`` and ``fc`` run
    without an activation quantizer. The forward pass is rebuilt from the
    base model's sub-modules and uses only ``conv1``, ``bn1``, ``relu``,
    ``layer1``-``layer4``, ``avgpool`` and ``fc``.

    NOTE(review): any other step in ``base_model.forward`` (e.g. a
    ``maxpool`` after the stem) is not replayed here -- confirm this matches
    the ResNet18 definition in utils.
    """

    _STAGES = ("layer1", "layer2", "layer3", "layer4")

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        # Registered in the order fq_conv1, fq_layer1 .. fq_layer4 so that
        # parameters() / state_dict() ordering is unchanged.
        self.fq_conv1 = LSQFakeQuant()
        for stage in self._STAGES:
            setattr(self, f"fq_{stage}", LSQFakeQuant())

    def forward(self, x):
        net = self.base
        out = self.fq_conv1(net.relu(net.bn1(net.conv1(x))))
        for stage in self._STAGES:
            block = getattr(net, stage)
            quantizer = getattr(self, f"fq_{stage}")
            out = quantizer(block(out))
        pooled = net.avgpool(out)
        return net.fc(pooled.flatten(1))

In [17]:
# Wrap a *copy* of the float model: QAT fine-tunes the wrapped modules in
# place, and `model` should remain the untouched float baseline so that
# re-running the QAT cells always starts from the same weights.
import copy

qat_model = LSQQATModel(copy.deepcopy(model)).to(device)

In [18]:
# Weight learnable scale (LSQ).
#
# Every weight with more than one dimension (conv / linear kernels) gets its
# own learnable step size, initialised with the LSQ heuristic
# 2 * mean(|w|) / sqrt(Qp). Quantization is applied through
# torch.nn.utils.parametrize instead of overwriting `p.data`:
#   * the optimizer keeps updating the float weights (re-quantizing them in
#     place after every step snaps small updates back onto the grid, so the
#     weights stop moving), and
#   * STEQuant runs inside the autograd graph, so the step sizes in
#     `weight_scale_params` are part of the computation and can get gradients.
from torch.nn.utils import parametrize

Qp = 127.0


class _LSQWeightQuant(nn.Module):
    """Parametrization that returns ``STEQuant.apply(weight, scale)``."""

    def __init__(self, scale):
        super().__init__()
        # Kept in a plain list so nn.Module does not register `scale` as a
        # model parameter: it reaches the optimizer via `weight_scale_params`,
        # and passing the same tensor twice would make Adam update it twice.
        self.scale_ref = [scale]

    def forward(self, weight):
        return STEQuant.apply(weight, self.scale_ref[0])


def _iter_quant_weights(model):
    """Yield ``(key, module, float_weight)`` for every module owning a >1-D ``weight``.

    ``key`` is the parameter's dotted name with the ``base.`` wrapper prefix
    removed (e.g. ``conv1.weight``). ``float_weight`` is the un-quantized
    tensor, also for modules whose weight is already parametrized.
    """
    for mod_name, module in model.named_modules():
        if parametrize.is_parametrized(module, "weight"):
            weight = module.parametrizations.weight.original
        else:
            weight = getattr(module, "weight", None)
        if not isinstance(weight, torch.Tensor) or weight.dim() <= 1:
            continue
        full_name = f"{mod_name}.weight" if mod_name else "weight"
        yield full_name.replace('base.', ''), module, weight


# One learnable scale per quantized weight, created on the weight's device.
weight_scale_params = {}
for key, _, weight in _iter_quant_weights(qat_model):
    mean_abs = weight.detach().abs().mean().item()
    # Floor guards an all-zero weight, whose scale 0 would make w / s = NaN.
    init_scale = max(2 * mean_abs / (Qp ** 0.5), 1e-8)
    weight_scale_params[key] = nn.Parameter(torch.tensor(init_scale, device=weight.device))


# LSQ weight quantization function
def apply_lsq_weight_quant(model, scale_params):
    """Fake-quantize every >1-D weight of ``model`` with its scale in ``scale_params``.

    The first call registers a parametrization on each weight; later calls
    only re-point the existing parametrizations at ``scale_params``, so
    calling this after every optimizer step is cheap and idempotent. The
    underlying float weights are never modified; ``module.weight`` returns
    the fake-quantized tensor.

    Raises:
        KeyError: if a quantizable weight has no entry in ``scale_params``.
    """
    # Materialise the list first: registering a parametrization adds
    # sub-modules to the tree being walked.
    for key, module, _ in list(_iter_quant_weights(model)):
        scale = scale_params[key]
        existing = None
        if parametrize.is_parametrized(module, "weight"):
            existing = next(
                (p for p in module.parametrizations.weight if hasattr(p, "scale_ref")),
                None,
            )
        if existing is None:
            parametrize.register_parametrization(module, "weight", _LSQWeightQuant(scale))
        else:
            existing.scale_ref[0] = scale


apply_lsq_weight_quant(qat_model, weight_scale_params)

In [19]:
# QAT fine-tuning.
# Adam optimizes the network parameters together with the per-layer weight
# step sizes; apply_lsq_weight_quant is called after every optimizer step and
# evaluate(qat_model) is printed once per epoch.
trainable = [*qat_model.parameters(), *weight_scale_params.values()]
optimizer = optim.Adam(trainable, lr=1e-4)
criterion = nn.CrossEntropyLoss()
epochs = 30

for epoch in range(epochs):
    qat_model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = qat_model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        apply_lsq_weight_quant(qat_model, weight_scale_params)

    accuracy = evaluate(qat_model)
    print(f"Epoch {epoch+1} Acc: {accuracy:.2f}%")

AttributeError: 'Tensor' object has no attribute 'fill'