In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torch.utils.data import TensorDataset, DataLoader

In [2]:
class WeightQuantizer(nn.Module):
    def __init__(self, weights, bits=8):
        super().__init__()
        self.weights = weights
        self.bits = bits
        self.scale = nn.Parameter(torch.ones(1))
        self.zero_point = nn.Parameter(torch.zeros(1))
        
    def forward(self):
        q_min, q_max = 0, 2**self.bits - 1
        scale = (self.weights.max() - self.weights.min()) / (q_max - q_min)
        quantized = torch.clamp((self.weights / scale).round(), q_min, q_max)
        return scale * quantized + self.zero_point

In [3]:
class ActivationQuantizer(nn.Module):
    def __init__(self, bits=8):
        super().__init__()
        self.bits = bits
        self.scale = nn.Parameter(torch.ones(1))
        self.zero_point = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        q_min, q_max = 0, 2**self.bits - 1
        scale = x.abs().max() / q_max
        quantized = torch.clamp((x / scale).round(), q_min, q_max)
        return scale * quantized + self.zero_point

In [4]:
class QuantizedLinear(nn.Module):
    def __init__(self, original_linear, weight_bits=8, activation_bits=8):
        super().__init__()
        self.original_linear = original_linear
        self.weight_quantizer = WeightQuantizer(original_linear.weight, weight_bits)
        self.activation_quantizer = ActivationQuantizer(activation_bits)
        
    def forward(self, x):
        x = self.activation_quantizer(x)
        self.original_linear.weight.data = self.weight_quantizer()
        return self.original_linear(x)

In [5]:
class QuantizedConv2d(nn.Module):
    def __init__(self, original_conv, weight_bits=8, activation_bits=8):
        super().__init__()
        self.original_conv = original_conv
        self.weight_quantizer = WeightQuantizer(original_conv.weight, weight_bits)
        self.activation_quantizer = ActivationQuantizer(activation_bits)
        
    def forward(self, x):
        x = self.activation_quantizer(x)
        self.original_conv.weight.data = self.weight_quantizer()
        return self.original_conv(x)

In [6]:
def quantize_resnet18(weight_bits=8, activation_bits=8):
    model = resnet18(pretrained=True)
    model.eval()
    
    for name, module in model.named_modules():
        parent = model
        names = name.split('.')
        for n in names[:-1]:
            parent = getattr(parent, n)
        
        if isinstance(module, nn.Conv2d):
            setattr(parent, names[-1], QuantizedConv2d(module, weight_bits, activation_bits))
        elif isinstance(module, nn.Linear):
            setattr(parent, names[-1], QuantizedLinear(module, weight_bits, activation_bits))
        elif isinstance(module, nn.BatchNorm2d):
            # Quantize BatchNorm parameters (optional)
            module.weight.data = WeightQuantizer(module.weight, weight_bits)()
            module.bias.data = WeightQuantizer(module.bias, weight_bits)()
    
    return model

In [7]:
#Loading data generated from Genie D
synthetic_data = torch.load('dataset_checkpoint_final.pt', map_location=torch.device('cpu'))
synthetic_data = synthetic_data['dataset']
synthetic_data = torch.tensor(synthetic_data)

  synthetic_data = torch.tensor(synthetic_data)


In [8]:
def calibrate_model(model, data_loader):
    model.eval()
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(data_loader):
            data = batch_data[0] if isinstance(batch_data, (list, tuple)) else batch_data
            if batch_idx % 10 == 0:
                print(f"Calibrating batch {batch_idx}")
            model(data)
    return model

In [9]:
def main():
    weight_bits = 4
    activation_bits = 4
    batch_size = 32
    
    data_loader = DataLoader(TensorDataset(synthetic_data), batch_size=batch_size)
    
    print("Quantizing ResNet18...")
    quant_model = quantize_resnet18(weight_bits, activation_bits)
    
    print("Calibrating...")
    quant_model = calibrate_model(quant_model, data_loader)
    
    torch.save(quant_model.state_dict(), f"resnet18_quantized_W{weight_bits}A{activation_bits}.pth")
    torch.save(quant_model, "quantized_resnet18_full_.pth")
    print(f"Quantized model saved!")

if __name__ == "__main__":
    main()



Quantizing ResNet18...
Calibrating...
Calibrating batch 0
Calibrating batch 10
Calibrating batch 20
Calibrating batch 30
Quantized model saved!
