In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import resnet18
from tqdm import tqdm

In [2]:
def reconstruct_block(block, quantizer, data, epochs=10, lr=1e-3, batch_size=32):
    """Fit ``quantizer`` so that ``block`` behaves the same on quantized and FP32 inputs.

    For every mini-batch the loss is the MSE between ``block(quantizer(batch))`` and the
    full-precision reference ``block(batch)``. Only the quantizer's ``s`` and ``V`` are
    trained (Adam). The block's parameters are frozen for the duration of the call and
    their ``requires_grad`` flags are restored afterwards.

    Args:
        block: ``nn.Module`` whose output is matched; it is switched to eval mode.
        quantizer: module with learnable ``s`` and ``V`` parameters, applied to the block
            inputs. It may replace ``V`` with a new Parameter during a forward pass (e.g. on
            the first call or a new input shape). When that happens the ``V`` optimizer is
            rebuilt so the live tensor is the one being trained.
        data: tensor of block inputs, split into mini-batches along dim 0.
        epochs: number of passes over ``data``.
        lr: Adam learning rate.
        batch_size: mini-batch size (previously hard-coded to 32).
    """
    block.eval()
    criterion = nn.MSELoss()

    # Freeze the block so backward() only fills the quantizer's gradients; otherwise
    # .grad tensors would accumulate on every block weight across all steps.
    saved_flags = [(param, param.requires_grad) for param in block.parameters()]
    for param, _ in saved_flags:
        param.requires_grad_(False)

    # `s` keeps one Adam state for the whole run. The V optimizer is (re)created whenever
    # the quantizer swaps in a new V tensor: an optimizer built over the old placeholder V
    # would never update the tensor actually used in the forward pass.
    s_optimizer = optim.Adam([quantizer.s], lr=lr)
    v_optimizer, tracked_v = None, None
    try:
        for _ in range(epochs):
            for batch in data.split(batch_size):
                with torch.no_grad():
                    output_fp32 = block(batch)
                output_quant = block(quantizer(batch))

                if quantizer.V is not tracked_v:
                    tracked_v = quantizer.V
                    v_optimizer = optim.Adam([tracked_v], lr=lr)

                loss = criterion(output_quant, output_fp32)

                s_optimizer.zero_grad()
                v_optimizer.zero_grad()
                loss.backward()
                s_optimizer.step()
                v_optimizer.step()
    finally:
        for param, flag in saved_flags:
            param.requires_grad_(flag)

In [3]:
class GenieQuantizer(nn.Module):
    """Learnable uniform quantizer with soft (sigmoid-relaxed) rounding.

    ``quantize(x)`` returns ``s * clamp(floor(x / s) + sigmoid(V), qmin, qmax)``.
    ``s`` is a learnable scalar step size. ``V`` is a learnable per-element rounding
    offset with the same shape as ``x``. ``[qmin, qmax]`` is the integer grid of a
    ``bitwidth``-bit code: two's-complement range when ``signed``, else
    ``[0, 2**bitwidth - 1]``.

    Args:
        bitwidth: number of bits of the integer code.
        signed: use a signed integer grid (default True).
    """

    def __init__(self, bitwidth=4, signed=True):
        super().__init__()
        self.s = nn.Parameter(torch.tensor(1.0))  # step size, starts at 1.0
        # Placeholder; replaced by a zeros tensor shaped like the input on first use.
        self.V = nn.Parameter(torch.zeros(1))
        self.bitwidth = bitwidth
        self.signed = signed
        self.is_initialized = False

    def forward(self, x):
        return self.quantize(x)

    def int_range(self):
        """Return the inclusive ``(qmin, qmax)`` integer range for this bitwidth."""
        if self.signed:
            return -(2 ** (self.bitwidth - 1)), 2 ** (self.bitwidth - 1) - 1
        return 0, 2 ** self.bitwidth - 1

    def quantize(self, x):
        """Quantize ``x``, (re)creating ``V`` as zeros when its shape does not match ``x``.

        Re-creating ``V`` replaces the Parameter object and discards any learned offsets,
        so optimizers holding the old ``V`` must be rebuilt.
        """
        if not self.is_initialized or self.V.shape != x.shape:
            # Re-initialize V to match input shape
            self.V = nn.Parameter(torch.zeros_like(x))
            self.is_initialized = True

        B = torch.floor(x / self.s.detach())  # Eq. 9: integer part, computed with a detached s
        qmin, qmax = self.int_range()
        # Eq. 10, with the soft code clamped to the range representable in `bitwidth` bits
        # (previously `bitwidth` was stored but never applied, so codes were unbounded).
        return self.s * torch.clamp(B + torch.sigmoid(self.V), qmin, qmax)

In [4]:
def get_blocks(model):
    """Split a torchvision ResNet into the stages used for block-wise reconstruction.

    The blocks follow the order of torchvision's ResNet forward pass:
    stem (conv1 -> bn1 -> relu -> maxpool), layer1..layer4, then the head
    (avgpool -> flatten -> fc). Applying them in sequence therefore reproduces the
    model's output. The stem previously omitted ``maxpool``, so every later block was
    calibrated on activations with twice the spatial size it sees at inference.

    The blocks wrap the model's own submodules (no copies), so changing a block's
    weights changes ``model``.

    Returns:
        list of 6 modules: [stem, layer1, layer2, layer3, layer4, head].
    """
    stem = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool)
    stages = [model.layer1, model.layer2, model.layer3, model.layer4]
    head = nn.Sequential(model.avgpool, nn.Flatten(start_dim=1), model.fc)
    return [stem, *stages, head]

In [5]:
def adaround(weights, quantizer, synthetic_data, lr=1e-3):
    """Optimize the quantizer's soft-rounding offsets ``V`` so ``quantizer(weights)`` matches ``weights``.

    Only ``V`` is handed to the optimizer; the integer part ``B`` is not trained. One Adam
    step is taken per element of ``synthetic_data``. The elements themselves are not used
    by this weight-only MSE objective; they only set the number of steps.

    Args:
        weights: tensor to quantize (detached internally; it is not modified).
        quantizer: callable module exposing a learnable ``V``. It may re-create ``V`` to
            match the shape of ``weights`` on first use.
        synthetic_data: iterable; its length sets the number of optimization steps.
        lr: Adam learning rate.

    Returns:
        The quantized weights after the last step (no autograd history).
    """
    target = weights.detach()

    # Let the quantizer (re)create V for this shape before the optimizer captures it;
    # otherwise the optimizer would hold a stale placeholder tensor.
    with torch.no_grad():
        quantizer(target)
    optimizer = optim.Adam([quantizer.V], lr=lr)  # only V, not B

    for _ in synthetic_data:
        loss = nn.functional.mse_loss(quantizer(target), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        return quantizer(target)

In [61]:
# Alternative source: load this instead once synthetic data from Genie-D is available.
#synthetic_data = torch.load("genie_d_output.pth") 

# Load the synthetic calibration set and keep only its 'dataset' entry.
# NOTE(review): torch.load unpickles the file, which can run arbitrary code;
# only load checkpoints from a trusted source.
synthetic_data = torch.load(
    'dataset_checkpoint_final.pt',
    map_location=torch.device('cpu'),
)['dataset']

In [63]:
# Normalize the loaded calibration data to a single grad-free tensor.
# torch.tensor() on an existing tensor makes a copy and emits a UserWarning, so use
# as_tensor (no copy) plus detach(). A list/tuple of per-sample tensors is stacked
# along a new leading dim.
if isinstance(synthetic_data, (list, tuple)):
    synthetic_data = torch.stack([torch.as_tensor(item) for item in synthetic_data])
synthetic_data = torch.as_tensor(synthetic_data).detach()
print(synthetic_data.shape)

tensor([[[[ 0.0071,  0.2781,  0.4939,  ..., -1.0295, -1.0029, -0.8201],
          [ 0.4856,  0.9631,  1.1108,  ..., -1.3381, -1.2670, -1.0167],
          [ 0.4402,  0.9827,  1.0418,  ..., -1.0399, -1.0685, -0.8432],
          ...,
          [ 0.8029,  1.1464,  1.0653,  ...,  0.5481,  0.5881, -0.0157],
          [ 0.7566,  1.0940,  1.0606,  ...,  0.2886,  0.2454, -0.1067],
          [ 0.4646,  0.7498,  0.7258,  ..., -0.1320, -0.1442, -0.2736]],

         [[ 0.7195,  0.9484,  1.0456,  ..., -1.1092, -1.0724, -0.6919],
          [ 0.9168,  1.0771,  1.1108,  ..., -1.5613, -1.4688, -1.0104],
          [ 0.8667,  1.0684,  1.0968,  ..., -1.3069, -1.2625, -0.7472],
          ...,
          [ 1.0567,  1.1105,  1.1066,  ...,  0.5271,  0.6663,  0.2882],
          [ 1.0506,  1.1071,  1.1080,  ...,  0.2589,  0.3131,  0.2259],
          [ 0.8690,  1.0185,  1.0220,  ...,  0.0102,  0.0187,  0.0227]],

         [[ 0.4346,  0.6293,  0.7823,  ..., -0.7303, -0.6574, -0.3668],
          [ 0.7087,  0.9355,  

  synthetic_data = torch.tensor(synthetic_data)


In [65]:
model = resnet18(pretrained=True).eval()
blocks = get_blocks(model)

# Calibrate one quantizer per block, feeding each block the FP32 output of the previous one.
# The running activations live in `block_inputs`, so `synthetic_data` keeps the original
# images and re-running this cell does not feed logits back into conv1.
quantizers = []
block_inputs = synthetic_data
for block in blocks:
    quantizer = GenieQuantizer(bitwidth=4)
    reconstruct_block(block, quantizer, block_inputs)
    quantizers.append(quantizer)
    # After calibration, the next block's inputs are this block's FP32 outputs.
    with torch.no_grad():
        block_inputs = block(block_inputs)

In [66]:
def apply_quantization(block, quantizer):
    """Replace, in place, every parameter of ``block`` whose name contains "weight".

    Each such parameter is overwritten with ``quantizer.quantize(param)``. Parameters whose
    names do not contain "weight" (e.g. biases) are left untouched.

    Runs under ``torch.no_grad()`` so no autograd graph is built through the quantizer's
    learnable parameters. ``copy_`` keeps each parameter's identity, dtype and device.
    A quantizer with shape-dependent state (e.g. one that re-creates ``V`` when the input
    shape changes) will have that state replaced by these calls.
    """
    with torch.no_grad():
        for name, param in block.named_parameters():
            if "weight" in name:
                param.copy_(quantizer.quantize(param.detach()))

In [68]:
# NOTE(review): `block` and `quantizer` are whatever the calibration loop left bound, i.e. the
# last block from get_blocks (the avgpool/flatten/fc head) and its quantizer, so only that
# block's weights are quantized here. TODO: apply each block's own quantizer to every block.
apply_quantization(block=block, quantizer=quantizer)

In [69]:
# Save only the parameters and buffers (state_dict); loading it requires rebuilding the
# architecture first (e.g. resnet18()) and then calling load_state_dict.
torch.save(model.state_dict(), "quantized_resnet18.pth")

# Optional: pickle the whole module object (architecture + weights). Loading it requires the
# defining classes to be importable and unpickles arbitrary objects; only load trusted files.
torch.save(model, "quantized_resnet18_full.pth")