In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

In [None]:
class Quantizer(nn.Module):
    """Uniform symmetric fake-quantizer with a learnable step size.

    Computes ``clamp(round(x / scale), -2**(bits-1), 2**(bits-1) - 1) * scale``.

    * Training mode: rounding uses a straight-through estimator (forward value is the
      rounded one, backward treats rounding as identity). Gradients therefore reach
      both ``x`` (inside the clamp range) and ``scale``.
    * Eval mode: ``scale`` is detached, so no gradient flows into it.

    Args:
        bits: bit width of the signed integer grid.
        scale_init: initial step size. Any falsy value (``None``, ``0``) falls back to 1.0.
    """

    def __init__(self, bits, scale_init=None):
        super().__init__()
        self.bits = bits
        # float() so integer inputs still produce a floating Parameter (int tensors cannot require grad).
        self.scale = nn.Parameter(torch.tensor(float(scale_init) if scale_init else 1.0))
        # Rounding variable for an external rounding regulariser.
        # NOTE(review): forward() does not read it.
        self.soft_bit = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return the fake-quantized version of ``x`` (same shape as ``x``)."""
        qmin = -2 ** (self.bits - 1)
        qmax = 2 ** (self.bits - 1) - 1
        if self.training:
            step = self.scale
            x_scaled = x / step
            # Straight-through estimator: torch.round alone has zero gradient everywhere,
            # which would block backprop to x and give scale a meaningless gradient.
            x_int = x_scaled + (torch.round(x_scaled) - x_scaled).detach()
        else:
            step = self.scale.detach()
            x_int = torch.round(x / step)
        return torch.clamp(x_int, qmin, qmax) * step

In [None]:
class QuantizableResNetBlock(nn.Module):
    """Wrap a ResNet block so its input activations and conv/linear weights are fake-quantized.

    The wrapped block keeps its full-precision weights. On every forward pass,
    quantized copies of the conv/linear weights are substituted via
    ``torch.func.functional_call``. As a result, gradients reach ``w_quant.scale``
    and the original weights are never overwritten. BatchNorm parameters are left
    unquantized.

    Args:
        block: module to wrap (e.g. a torchvision ResNet block).
        w_bits: bit width for the block's conv/linear weights (one step size shared by all of them).
        a_bits: bit width for the block's input activations.
    """

    _QUANT_LAYERS = (nn.Conv2d, nn.Linear)

    def __init__(self, block, w_bits=4, a_bits=4):
        super().__init__()
        self.block = block
        # Initialise the shared weight step so the largest |w| maps onto the top grid level.
        # A step of 1.0 would round every |w| < 0.5 to zero.
        w_scale_init = None
        weights = [m.weight.detach() for m in block.modules() if isinstance(m, self._QUANT_LAYERS)]
        if weights:
            w_absmax = max(float(w.abs().max()) for w in weights)
            w_scale_init = w_absmax / max(2 ** (w_bits - 1) - 1, 1)
        self.w_quant = Quantizer(w_bits, scale_init=w_scale_init)
        self.a_quant = Quantizer(a_bits)

    def forward(self, x):
        """Quantize ``x``, then run the block with fake-quantized conv/linear weights."""
        # Quantize activations in both train and eval mode; Quantizer handles the mode itself.
        x = self.a_quant(x)
        quantized_weights = {
            (f"{name}.weight" if name else "weight"): self.w_quant(module.weight)
            for name, module in self.block.named_modules()
            if isinstance(module, self._QUANT_LAYERS)
        }
        # Parameters not in the dict (BatchNorm, biases) come from the block itself.
        return torch.func.functional_call(self.block, quantized_weights, (x,))

In [None]:
def linear_temp_decay(current_iter, max_iter, start_temp=20, end_temp=2, warmup=0.2):
    """Temperature schedule for the rounding loss.

    Holds ``start_temp`` for the first ``warmup * max_iter`` iterations. It then decays
    linearly to ``end_temp`` at ``max_iter`` and stays at ``end_temp`` afterwards.

    Args:
        current_iter: current iteration index.
        max_iter: total number of iterations.
        start_temp: temperature during warm-up.
        end_temp: temperature reached at ``max_iter``.
        warmup: fraction of ``max_iter`` spent at ``start_temp``.

    Returns:
        The temperature for ``current_iter``.
    """
    warmup_iter = warmup * max_iter
    if current_iter < warmup_iter:
        return start_temp
    decay_span = max_iter - warmup_iter
    if decay_span <= 0:
        # Empty decay window (max_iter == 0 or warmup >= 1): the schedule has already finished.
        return end_temp
    progress = (current_iter - warmup_iter) / decay_span
    return end_temp + (start_temp - end_temp) * max(0.0, 1 - progress)

In [None]:
def _collect_quantizers(model):
    """Split every Quantizer in `model` into (weight_quantizers, activation_quantizers).

    A Quantizer registered under the attribute name ``w_quant`` counts as a weight
    quantizer. Every other Quantizer is treated as an activation quantizer.
    """
    w_quantizers, a_quantizers = [], []
    for name, module in model.named_modules():
        if not isinstance(module, Quantizer):
            continue
        if name.rsplit('.', 1)[-1] == 'w_quant':
            w_quantizers.append(module)
        else:
            a_quantizers.append(module)
    return w_quantizers, a_quantizers


def _train_with_frozen_batchnorm(model):
    """Put `model` in train mode while keeping its BatchNorm layers in eval mode.

    Quantizers then use their training path (learnable scale). BatchNorm layers
    normalise with their stored running statistics and no longer overwrite them.
    """
    model.train()
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()


def reconstruct_resnet18(teacher, student, genie_data, num_iterations=20000, batch_size=32,
                         w_lr=1e-4, a_lr=4e-5, round_loss_weight=1.0, log_every=1000):
    """Tune the Quantizer step sizes of `student` so its outputs match `teacher`'s.

    Each iteration does the following:
    * draws a random batch from `genie_data`;
    * minimises the MSE between student and teacher outputs plus a weighted
      rounding regulariser;
    * steps a cosine-annealed Adam optimizer that holds only the Quantizer
      ``scale`` parameters.

    Args:
        teacher: full-precision reference model; run in eval mode under ``torch.no_grad``.
        student: model containing Quantizer modules (weight quantizers named ``w_quant``).
        genie_data: synthetic inputs, indexed along dim 0.
        num_iterations: number of optimisation steps (also the cosine schedule length).
        batch_size: samples per step (at most ``len(genie_data)``).
        w_lr: learning rate for weight-quantizer scales.
        a_lr: learning rate for activation-quantizer scales.
        round_loss_weight: multiplier on the rounding regulariser.
        log_every: print progress every this many iterations (0/None disables).

    Returns:
        `student`, switched to eval mode.

    Assumes teacher and student live on the same device; batches are moved there.
    """
    w_quantizers, a_quantizers = _collect_quantizers(student)
    optimizer = Adam([
        {'params': [q.scale for q in w_quantizers], 'lr': w_lr},
        {'params': [q.scale for q in a_quantizers], 'lr': a_lr},
    ])
    scheduler = CosineAnnealingLR(optimizer, T_max=num_iterations)

    teacher.eval()
    # Student BN layers must not update running stats from synthetic batches, and should
    # normalise the same way the eval-mode teacher does.
    _train_with_frozen_batchnorm(student)
    device = next(student.parameters()).device

    for iteration in range(num_iterations):
        idx = torch.randperm(len(genie_data))[:batch_size]
        x = genie_data[idx].to(device)

        with torch.no_grad():
            teacher_out = teacher(x)
        student_out = student(x)
        recon_loss = (student_out - teacher_out).pow(2).mean()

        temp = linear_temp_decay(iteration, num_iterations)
        # NOTE(review): soft_bit is not in the optimizer, and Quantizer.forward as defined in
        # this notebook does not read it. This term therefore shifts the reported loss
        # without affecting the reconstruction.
        round_loss = torch.tensor(0.0, device=x.device)
        for q in w_quantizers:
            round_loss = round_loss + (1 - (2 * torch.sigmoid(q.soft_bit) - 1).abs().pow(temp)).sum()

        total_loss = recon_loss + round_loss_weight * round_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if log_every and iteration % log_every == 0:
            print(f"Iter {iteration}/{num_iterations} - Loss: {total_loss.item():.4f} "
                  f"(Recon: {recon_loss.item():.4f}, Round: {round_loss.item():.4f})")

    student.eval()
    return student

In [None]:
# Load the synthetic calibration data generated with Genie-D.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
genie_checkpoint = torch.load('dataset_checkpoint_final.pt', map_location=torch.device('cpu'))
# as_tensor avoids the extra copy (and UserWarning) that torch.tensor() incurs
# when the stored dataset is already a tensor.
synthetic_data = torch.as_tensor(genie_checkpoint['dataset'])

In [None]:
if __name__ == "__main__":
    teacher_model = resnet18(pretrained=True).eval()
    student_model = resnet18(pretrained=True)

    # Wrap every block inside the nn.Sequential stages of the ResNet (layer1..layer4 in torchvision).
    for stage in student_model.children():
        if isinstance(stage, nn.Sequential):
            # Iterate over a snapshot because entries are replaced in place.
            for i, block in enumerate(list(stage)):
                stage[i] = QuantizableResNetBlock(block, w_bits=4, a_bits=4)

    # Run reconstruction
    quantized_model = reconstruct_resnet18(teacher_model, student_model, synthetic_data)

    torch.save(quantized_model.state_dict(), "quantized_resnet18.pth")
    # Full pickled module; loading it requires the Quantizer/QuantizableResNetBlock classes to be importable.
    torch.save(quantized_model, "quantized_resnet18_full__.pth")