In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import resnet18
from tqdm import tqdm

In [None]:
def reconstruct_block(block, quantizer, data, epochs=100, lr=1e-3,
                      batch_size=32, log_every=20, plot=True):
    """Fit a quantizer's scale `s` and offsets `V` by matching a block's outputs.

    For every batch the loss is the MSE between ``block(quantizer(batch))`` and
    the full-precision ``block(batch)``; only ``quantizer.s`` and ``quantizer.V``
    are optimised (Adam). The block's own parameters are frozen during the fit
    and their ``requires_grad`` flags are restored afterwards.

    NOTE(review): the quantizer is applied to the block *input*; Genie/AdaRound
    style reconstruction normally quantizes the block *weights* — TODO confirm.

    Args:
        block: module whose outputs are matched; put into eval mode.
        quantizer: module with `s` and `V` parameters that maps a batch to its
            quantized version (e.g. GenieQuantizer).
        data: tensor of block inputs, batched along dim 0.
        epochs: number of passes over `data`.
        lr: Adam learning rate.
        batch_size: samples per optimisation step (capped at ``len(data)``).
        log_every: print progress every this many epochs; 0 disables logging.
        plot: if True, draw the per-epoch loss curve with matplotlib.

    Returns:
        List with the mean loss of each epoch.

    Raises:
        ValueError: if `data` is empty or `batch_size` < 1.
    """
    if len(data) == 0:
        raise ValueError("reconstruct_block needs at least one sample in `data`")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    block.eval()

    # Use only full, equally-shaped batches. V holds one offset per element of
    # a batch, and a ragged last batch would make the quantizer re-create V.
    # That new V would be detached from the optimizer.
    batch_size = min(batch_size, len(data))
    usable = len(data) - len(data) % batch_size
    batches = data[:usable].split(batch_size)

    # Materialise V with the batch shape *before* Adam captures it. Otherwise
    # the optimizer would hold the placeholder V that forward() replaces.
    with torch.no_grad():
        quantizer(batches[0])
    optimizer = optim.Adam([quantizer.s, quantizer.V], lr=lr)
    criterion = nn.MSELoss()

    # Freeze the block: only the quantizer is trained, and gradients must not
    # pile up in the block's .grad buffers. Gradients still flow to the input.
    block_params = list(block.parameters())
    saved_flags = [p.requires_grad for p in block_params]
    block.requires_grad_(False)

    loss_history = []
    try:
        for epoch in tqdm(range(epochs), desc="Reconstructing block"):
            epoch_loss = 0.0
            for batch in batches:
                with torch.no_grad():
                    output_fp32 = block(batch)
                output_quant = block(quantizer(batch))
                loss = criterion(output_quant, output_fp32)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(batches)
            loss_history.append(avg_loss)

            if log_every and (epoch + 1) % log_every == 0:
                print(f"\nEpoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.6f}")
                print(f"Quantizer scale (s): {quantizer.s.item():.4f}")
                print(f"Quantizer offset (V) mean: {quantizer.V.mean().item():.4f}")
    finally:
        for param, flag in zip(block_params, saved_flags):
            param.requires_grad_(flag)

    if plot:
        import matplotlib.pyplot as plt  # optional; only needed for the figure

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(loss_history)
        ax.set(title="Training Loss History", xlabel="Epoch", ylabel="MSE Loss")
        ax.grid(True)
        plt.show()

    return loss_history

In [None]:
class GenieQuantizer(nn.Module):
    """Learnable-step-size quantizer with soft (sigmoid) rounding offsets.

    The input is split into an integer part ``B = floor(x / s)`` (Eq. 9, with
    `s` detached, so B carries no gradient) and a learnable fractional offset
    ``sigmoid(V)``. The output is ``s * clamp(B + sigmoid(V), qmin, qmax)``
    (Eq. 10), so values are restricted to the ``2**bitwidth`` integer levels.

    Args:
        bitwidth: number of bits; determines the clamp range.
        signed: if True the integer range is [-2**(b-1), 2**(b-1) - 1],
            otherwise [0, 2**b - 1].

    Raises:
        ValueError: if `bitwidth` < 1.
    """

    def __init__(self, bitwidth=4, signed=True):
        super().__init__()
        if bitwidth < 1:
            raise ValueError(f"bitwidth must be >= 1, got {bitwidth}")
        self.s = nn.Parameter(torch.tensor(1.0))  # learnable step size
        # Placeholder; quantize() replaces it with one offset per input element.
        self.V = nn.Parameter(torch.zeros(1))
        self.bitwidth = bitwidth
        self.signed = signed
        self.is_initialized = False

    def forward(self, x):
        """Alias for :meth:`quantize`."""
        return self.quantize(x)

    def quantize(self, x):
        """Return the fake-quantized version of `x` (same shape as `x`).

        V is (re)created as zeros shaped like `x` on first use and whenever the
        shape of `x` changes. This replaces the Parameter object, so an optimizer
        built earlier will keep the old V.
        """
        if not self.is_initialized or self.V.shape != x.shape:
            self.V = nn.Parameter(torch.zeros_like(x))
            self.is_initialized = True

        if self.signed:
            qmin = -(2 ** (self.bitwidth - 1))
            qmax = 2 ** (self.bitwidth - 1) - 1
        else:
            qmin, qmax = 0, 2 ** self.bitwidth - 1

        B = torch.floor(x / self.s.detach())  # Eq. 9
        levels = torch.clamp(B + torch.sigmoid(self.V), qmin, qmax)
        return self.s * levels  # Eq. 10, clipped to the bitwidth's range

In [None]:
def get_blocks(model):
    blocks = []
    # For ResNet
    blocks.append(nn.Sequential(model.conv1, model.bn1, model.relu))
    blocks.append(model.layer1)
    blocks.append(model.layer2)
    blocks.append(model.layer3)
    blocks.append(model.layer4)
    blocks.append(nn.Sequential(
    model.avgpool,
    nn.Flatten(start_dim=1), 
    model.fc
))
    return blocks

In [None]:
def adaround(weights, quantizer, synthetic_data, lr=1e-3):
    """Optimise only the rounding offsets ``quantizer.V`` for a weight tensor.

    Takes one Adam step per item yielded by `synthetic_data`, minimising the MSE
    between ``quantizer(weights)`` and `weights`. The integer part B and the step
    size `s` are not updated.

    Args:
        weights: tensor to quantize; used detached, so it is never modified.
        quantizer: module with a `V` parameter mapping `weights` to its
            quantized version (e.g. GenieQuantizer).
        synthetic_data: iterable; only its length (number of steps) is used.
        lr: Adam learning rate.

    Returns:
        List of per-step loss values (empty if `synthetic_data` is empty).
    """
    target = weights.detach()
    # Run once so a lazily-shaped V exists before the optimizer captures it.
    with torch.no_grad():
        quantizer(target)
    optimizer = optim.Adam([quantizer.V], lr=lr)  # only update V, not B or s

    losses = []
    for _ in synthetic_data:
        # NOTE(review): this objective ignores the data sample; AdaRound-style
        # objectives usually compare layer outputs on data — TODO confirm intent.
        loss = nn.functional.mse_loss(quantizer(target), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

In [None]:
 #run when you get synthetic data from Genie D
synthetic_data = torch.load('dataset_checkpoint_final.pt', map_location=torch.device('cpu'))
synthetic_data = synthetic_data['dataset']

In [None]:
# torch.as_tensor reuses an existing tensor instead of copying it (torch.tensor
# copies and warns), so re-running this cell is cheap and idempotent.
synthetic_data = torch.as_tensor(synthetic_data)
# The ResNet blocks consume image batches shaped (N, C, H, W).
assert synthetic_data.ndim == 4, f"expected (N, C, H, W), got {tuple(synthetic_data.shape)}"
synthetic_data.shape

In [None]:
model = resnet18(pretrained=True).eval()
blocks = get_blocks(model)

# Calibrate one quantizer per block. Each block is reconstructed on the outputs
# of the full-precision blocks before it. Those activations are carried in
# `block_inputs`, so `synthetic_data` stays intact and the cell can be re-run.
# NOTE(review): this cell previously called `reconstruct_block_with_adaround`,
# which is not defined in this notebook; `reconstruct_block` is used instead.
block_inputs = synthetic_data
quantizers = []
for block in blocks:
    quantizer = GenieQuantizer(bitwidth=4)
    reconstruct_block(block, quantizer, block_inputs)
    quantizers.append(quantizer)
    with torch.no_grad():
        block_inputs = block(block_inputs)

In [None]:
def apply_quantization(block, quantizer):
    """Overwrite, in place, every parameter of `block` whose name contains "weight".

    Each such parameter is replaced by ``quantizer.quantize(param)``. The name
    match also covers normalisation-layer weights. Biases and other parameters
    are left untouched.

    NOTE(review): with GenieQuantizer, quantize() re-creates V whenever the tensor
    shape differs from the last one it saw, so offsets learned during
    reconstruction are not reused here — TODO confirm intent.
    """
    # no_grad: the quantized value needs no autograd graph (it would otherwise
    # track the learnable step size). copy_ writes into the existing storage
    # instead of swapping `.data`.
    with torch.no_grad():
        for name, param in block.named_parameters():
            if "weight" in name:
                param.copy_(quantizer.quantize(param))

In [None]:
# Quantize the weights of the calibrated block in place.
# NOTE(review): `block` and `quantizer` are whatever the calibration loop left in
# scope after its last iteration, so only the final block (the avgpool/fc head
# from get_blocks) is quantized here — TODO: quantize every block with the
# quantizer calibrated for it.
apply_quantization(block, quantizer)

In [None]:
# Persist the weights alone (state_dict) and the whole pickled module.
# NOTE(review): the "60epochs" tag is hard-coded and may not match the epoch
# count actually used during calibration — confirm before publishing.
weights_path = "quantized_resnet18_60epochs.pth"
full_model_path = "quantized_resnet18_full_60epochs.pth"
torch.save(model.state_dict(), weights_path)
torch.save(model, full_model_path)