In [2]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import quantize_dynamic
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

# Function to evaluate the network
def evaluate_network(network, data_loader, criterion):
    """Return ``(accuracy, mean_loss)`` of ``network`` over every batch of ``data_loader``.

    The network is switched to eval mode (and left in it) and gradients are
    disabled. Both metrics are averaged over the samples the loader actually
    yields. This differs from ``len(data_loader.dataset)`` when the loader uses
    ``drop_last=True`` or a sampler that covers only part of the dataset.

    Args:
        network: Model called on each ``inputs`` batch; the index of the maximum
            along dim 1 of its output is taken as the predicted class.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss assumed to return the batch mean (e.g. the default
            ``nn.CrossEntropyLoss``); it is re-weighted by batch size here.

    Returns:
        Tuple ``(accuracy, mean_loss)`` where accuracy is in ``[0, 1]``.

    Raises:
        ValueError: If the loader yields no samples.
    """
    network.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = network(inputs)
            batch_size = inputs.size(0)
            # Undo the criterion's mean reduction so batches of unequal size
            # contribute proportionally to the overall average.
            total_loss += criterion(outputs, labels).item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot evaluate")
    return total_correct / total_samples, total_loss / total_samples

# Function to calculate SQNR
def calculate_sqnr(full_precision_output, quantized_output):
    """Signal-to-quantization-noise ratio in dB.

    Computes ``10 * log10(mean(x**2) / mean((x - x_q)**2))``, where ``x`` is
    the full-precision output and ``x_q`` the quantized one. Both tensors are
    promoted to float64 so that integer inputs work and the ratio is not
    limited by float32 precision.

    Args:
        full_precision_output: Reference tensor.
        quantized_output: Tensor compared against the reference
            (broadcast-compatible with it).

    Returns:
        The SQNR as a Python float, or ``float('inf')`` when the two outputs
        are identical (zero noise power).
    """
    reference = full_precision_output.double()
    quantized = quantized_output.double()
    signal_power = torch.mean(reference ** 2)
    noise_power = torch.mean((reference - quantized) ** 2)
    if noise_power == 0:
        # No quantization error at all: the ratio is unbounded.
        return float("inf")
    # Stay in tensor land: torch.log10 only accepts tensors, not Python floats.
    return (10 * torch.log10(signal_power / noise_power)).item()


In [3]:
def generate_sensitivity_list(network, data_loader, bit_widths, criterion, default_bit_width=8):
    """Rank every (layer, bit width) candidate by the accuracy it leaves behind.

    For each ``nn.Conv2d`` / ``nn.Linear`` layer and each candidate bit width
    other than ``default_bit_width``, the function does the following:

    * Fake-quantizes only that one layer's weights (symmetric, per-tensor
      quantize-then-dequantize).
    * Evaluates the whole network with ``evaluate_network``.
    * Restores the layer's float weights before the next trial, so every entry
      measures a single isolated change.

    ``quantize_dynamic`` is deliberately not used here. It only offers
    qint8/float16 dtypes, so it cannot simulate 2- or 4-bit layers. It also
    converts every module of the given types at once, not one named layer.

    Args:
        network: Float model to probe. It is not modified; a deep copy is used,
            so neither its weights nor its train/eval mode change.
        data_loader: Loader forwarded to ``evaluate_network``.
        bit_widths: Candidate bit widths; every evaluated one must be >= 2.
        criterion: Loss forwarded to ``evaluate_network``.
        default_bit_width: Reference precision that is skipped (the original
            behaviour hard-coded 8).

    Returns:
        List of ``(layer_name, bit_width, accuracy)`` tuples sorted by accuracy,
        highest first, i.e. least sensitive candidates first.

    Raises:
        ValueError: If a candidate bit width below 2 is requested.
    """
    def _fake_quantize(weight, bits):
        # Snap to the 2**bits - 1 symmetric levels {-qmax, ..., qmax} * scale and
        # return floats, so the float model simulates `bits`-bit weights.
        qmax = 2 ** (bits - 1) - 1
        scale = weight.abs().max() / qmax
        if scale == 0:  # all-zero weights are unchanged by quantization
            return weight.clone()
        return torch.clamp(torch.round(weight / scale), -qmax, qmax) * scale

    candidate_bits = [b for b in bit_widths if b != default_bit_width]
    too_small = [b for b in candidate_bits if b < 2]
    if too_small:
        raise ValueError(f"bit widths must be >= 2, got {too_small}")

    # network.state_dict() returns live tensor references, so it cannot serve as
    # a restore point; probe a private deep copy instead.
    working = copy.deepcopy(network)
    sensitivity_list = []
    for name, layer in working.named_modules():
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue
        float_weight = layer.weight.detach().clone()
        for bit_width in candidate_bits:
            with torch.no_grad():
                layer.weight.copy_(_fake_quantize(float_weight, bit_width))
            performance, _ = evaluate_network(working, data_loader, criterion)
            with torch.no_grad():
                layer.weight.copy_(float_weight)
            sensitivity_list.append((name, bit_width, performance))

    sensitivity_list.sort(key=lambda entry: entry[2], reverse=True)
    return sensitivity_list


In [4]:
def find_mixed_precision_config(network, data_loader, sensitivity_list, criterion, gamma):
    """Greedily lower layer precisions in list order while accuracy stays >= ``gamma``.

    Entries are applied in ``sensitivity_list`` order, and each entry is handled
    as follows:

    * Each ``(layer_name, bit_width, _)`` entry fake-quantizes that layer's
      original float weights to ``bit_width`` bits (symmetric, per-tensor
      quantize-then-dequantize).
    * If the resulting accuracy drops below ``gamma``, only that entry is undone
      (earlier accepted entries are kept) and the search stops.
    * A layer listed several times is always re-quantized from its float
      weights, so quantization error never compounds.

    Args:
        network: Float model; it is not modified (a deep copy is quantized).
        data_loader: Loader forwarded to ``evaluate_network``.
        sensitivity_list: ``(layer_name, bit_width, score)`` tuples whose names
            come from ``network.named_modules()``, e.g. the output of
            ``generate_sensitivity_list``.
        criterion: Loss forwarded to ``evaluate_network``.
        gamma: Minimum absolute accuracy (as returned by ``evaluate_network``)
            that a configuration must keep.

    Returns:
        ``(quantized_network, baseline_performance)``: the deep copy with every
        accepted entry applied, and the accuracy of the unmodified network.

    Raises:
        ValueError: If an entry requests a bit width below 2.
        KeyError: If an entry names a module the network does not have.
    """
    def _fake_quantize(weight, bits):
        # Snap to the 2**bits - 1 symmetric levels {-qmax, ..., qmax} * scale and
        # return floats, so the float model simulates `bits`-bit weights.
        if bits < 2:
            raise ValueError(f"bit width must be >= 2, got {bits}")
        qmax = 2 ** (bits - 1) - 1
        scale = weight.abs().max() / qmax
        if scale == 0:  # all-zero weights are unchanged by quantization
            return weight.clone()
        return torch.clamp(torch.round(weight / scale), -qmax, qmax) * scale

    # network.state_dict() returns live tensor references, so it cannot serve as
    # a restore point; quantize a private deep copy and leave the caller's model alone.
    working = copy.deepcopy(network)
    baseline_performance, _ = evaluate_network(working, data_loader, criterion)

    modules = dict(working.named_modules())
    float_weights = {}  # layer name -> weights before any quantization
    for name, bit_width, _ in sensitivity_list:
        layer = modules[name]
        if name not in float_weights:
            float_weights[name] = layer.weight.detach().clone()
        accepted_weight = layer.weight.detach().clone()
        with torch.no_grad():
            layer.weight.copy_(_fake_quantize(float_weights[name], bit_width))
        current_performance, _ = evaluate_network(working, data_loader, criterion)
        if current_performance < gamma:
            # Budget violated: roll back only this step and stop searching.
            with torch.no_grad():
                layer.weight.copy_(accepted_weight)
            break

    return working, baseline_performance


In [5]:
# Define ResNet model, data loader, and parameters.
# Seed torch's RNG so the shuffled sample order and the randomly initialised
# conv1/fc layers below are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load a subset of the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset_indices = list(range(0, 1000))  # Use only the first 1000 samples
subset_train_dataset = Subset(train_dataset, subset_indices)
train_loader = DataLoader(subset_train_dataset, batch_size=64, shuffle=True)

# Initialize ResNet model and adjust for MNIST.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
# `weights=...`; kept for compatibility with older versions.
network = models.resnet18(pretrained=True)
# conv1 and fc are replaced by freshly initialised layers and nothing is trained in
# this cell, so accuracy is presumably near chance and below `gamma` below.
network.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  # Adjust for single channel input
network.fc = nn.Linear(network.fc.in_features, 10)  # Adjust for 10 classes in MNIST

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define bit-width candidates and performance budget
bit_widths = [2, 4, 8]  # Example bit-width candidates
gamma = 0.8  # Minimum absolute accuracy (in [0, 1]) a quantized config must keep

# Phase 1: Generate sensitivity list
sensitivity_list = generate_sensitivity_list(network, train_loader, bit_widths, criterion)

# Phase 2: Find mixed precision configuration
quantized_network, baseline_performance = find_mixed_precision_config(network, train_loader, sensitivity_list, criterion, gamma)

# Evaluate the quantized network's performance
quantized_performance, quantized_loss = evaluate_network(quantized_network, train_loader, criterion)

# Print the results
print("Baseline Network Performance:", baseline_performance)
print("Quantized Network Performance:", quantized_performance)
print("Quantized Network:", quantized_network)


In [6]:
def find_mixed_precision_config(network, data_loader, sensitivity_list, criterion, gamma):
    """Greedily lower layer precisions in list order while accuracy stays >= ``gamma``.

    Entries are applied in ``sensitivity_list`` order, and each entry is handled
    as follows:

    * Each ``(layer_name, bit_width, _)`` entry fake-quantizes that layer's
      original float weights to ``bit_width`` bits (symmetric, per-tensor
      quantize-then-dequantize).
    * If the resulting accuracy drops below ``gamma``, only that entry is undone
      (earlier accepted entries are kept) and the search stops.
    * A layer listed several times is always re-quantized from its float
      weights, so quantization error never compounds.

    Args:
        network: Float model; it is not modified (a deep copy is quantized).
        data_loader: Loader forwarded to ``evaluate_network``.
        sensitivity_list: ``(layer_name, bit_width, score)`` tuples whose names
            come from ``network.named_modules()``, e.g. the output of
            ``generate_sensitivity_list``.
        criterion: Loss forwarded to ``evaluate_network``.
        gamma: Minimum absolute accuracy (as returned by ``evaluate_network``)
            that a configuration must keep.

    Returns:
        ``(quantized_network, baseline_performance)``: the deep copy with every
        accepted entry applied, and the accuracy of the unmodified network.

    Raises:
        ValueError: If an entry requests a bit width below 2.
        KeyError: If an entry names a module the network does not have.
    """
    def _fake_quantize(weight, bits):
        # Snap to the 2**bits - 1 symmetric levels {-qmax, ..., qmax} * scale and
        # return floats, so the float model simulates `bits`-bit weights.
        if bits < 2:
            raise ValueError(f"bit width must be >= 2, got {bits}")
        qmax = 2 ** (bits - 1) - 1
        scale = weight.abs().max() / qmax
        if scale == 0:  # all-zero weights are unchanged by quantization
            return weight.clone()
        return torch.clamp(torch.round(weight / scale), -qmax, qmax) * scale

    # network.state_dict() returns live tensor references, so it cannot serve as
    # a restore point; quantize a private deep copy and leave the caller's model alone.
    working = copy.deepcopy(network)
    baseline_performance, _ = evaluate_network(working, data_loader, criterion)

    modules = dict(working.named_modules())
    float_weights = {}  # layer name -> weights before any quantization
    for name, bit_width, _ in sensitivity_list:
        layer = modules[name]
        if name not in float_weights:
            float_weights[name] = layer.weight.detach().clone()
        accepted_weight = layer.weight.detach().clone()
        with torch.no_grad():
            layer.weight.copy_(_fake_quantize(float_weights[name], bit_width))
        current_performance, _ = evaluate_network(working, data_loader, criterion)
        if current_performance < gamma:
            # Budget violated: roll back only this step and stop searching.
            with torch.no_grad():
                layer.weight.copy_(accepted_weight)
            break

    return working, baseline_performance


In [7]:
# Define ResNet model, data loader, and parameters.
# Seed torch's RNG so the shuffled sample order, the randomly initialised
# conv1/fc layers and the training run are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load a subset of the MNIST dataset for training
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset_indices = list(range(0, 1000))  # Use only the first 1000 samples
subset_train_dataset = Subset(train_dataset, subset_indices)
train_loader = DataLoader(subset_train_dataset, batch_size=64, shuffle=True)

# Held-out MNIST test images are used for sensitivity analysis, the
# mixed-precision search and the final report. The 1000 training images the
# network is fit to below would overstate accuracy and mask quantization damage.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
eval_loader = DataLoader(Subset(test_dataset, list(range(0, 1000))), batch_size=64, shuffle=False)

# Initialize ResNet model and adjust for MNIST.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
# `weights=...`; kept for compatibility with older versions.
network = models.resnet18(pretrained=True)
network.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)  # Adjust for single channel input
network.fc = nn.Linear(network.fc.in_features, 10)  # Adjust for 10 classes in MNIST

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(network.parameters(), lr=0.001)

# Train the network.
# NOTE(review): train_network is not defined in any visible cell; presumably an
# earlier cell provides it — confirm it exists on a fresh kernel.
train_network(network, train_loader, criterion, optimizer, num_epochs=5)

# Define bit-width candidates and performance budget
bit_widths = [2, 4, 8]  # Example bit-width candidates
gamma = 0.8  # Minimum absolute accuracy (in [0, 1]) a quantized config must keep

# Phase 1: Generate sensitivity list
sensitivity_list = generate_sensitivity_list(network, eval_loader, bit_widths, criterion)

# Phase 2: Find mixed precision configuration
quantized_network, baseline_performance = find_mixed_precision_config(network, eval_loader, sensitivity_list, criterion, gamma)

# Evaluate the quantized network's performance
quantized_performance, quantized_loss = evaluate_network(quantized_network, eval_loader, criterion)

# Print the results
print("Baseline Network Performance:", baseline_performance)
print("Quantized Network Performance:", quantized_performance)
print("Quantized Network:", quantized_network)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:28<00:00, 344879.17it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 118383.75it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:08<00:00, 198138.24it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4363382.68it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch 1/5, Loss: 0.5079
Epoch 2/5, Loss: 0.0797
Epoch 3/5, Loss: 0.0393
