In [1]:
num_features = 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from gptq import GPTQ

# Define a small neural network
class SmallNet(nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()
        self.fc1 = nn.Linear(num_features, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 20)
        self.fc4 = nn.Linear(20, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

# Function to quantize all layers in the model with layer-specific inputs
def quantize_model_with_propagation(model, input_data, bits=4):
    activations = input_data  # Start with the original input
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            print(f"Quantizing layer: {name}")
            gptq = GPTQ(layer, bits=bits)
            
            # Add activations (layer-specific inputs) for Hessian calculation
            gptq.add_batch(activations.detach().cpu().numpy())
            gptq.quantize()  # Perform GPTQ-based quantization
            
            # Replace the original weights with quantized weights
            layer.weight.data = torch.tensor(gptq.get_quantized_weights(), dtype=layer.weight.dtype)
            if layer.bias is not None:
                layer.bias.data = torch.tensor(gptq.get_quantized_bias(), dtype=layer.bias.dtype)
            print(f"Layer {name} quantized successfully.\n")
            
            # Forward pass through the quantized layer to get new activations
            with torch.no_grad():
                activations = torch.relu(layer(activations))  # Update activations


In [3]:
import numpy as np

# Generate training data
def generate_data(samples=1000, features=num_features):
    x = np.random.rand(samples, features).astype(np.float32)
    y = np.mean(x, axis=1, keepdims=True).astype(np.float32)
    return torch.tensor(x), torch.tensor(y)

# Create datasets and dataloaders
x_train, y_train = generate_data(samples=1000)
x_test, y_test = generate_data(samples=200)

train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [4]:
# Train the SmallNet to learn the mean function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SmallNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # Forward pass
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")


Epoch [1/10], Loss: 0.1781
Epoch [2/10], Loss: 0.0579
Epoch [3/10], Loss: 0.0354
Epoch [4/10], Loss: 0.0284
Epoch [5/10], Loss: 0.0176
Epoch [6/10], Loss: 0.0062
Epoch [7/10], Loss: 0.0015
Epoch [8/10], Loss: 0.0009
Epoch [9/10], Loss: 0.0006
Epoch [10/10], Loss: 0.0004


In [5]:
# Evaluate the trained model
model.eval()
with torch.no_grad():
    test_loss = 0.0
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        test_loss += loss.item()

print(f"Test Loss: {test_loss / len(test_loader):.4f}")


Test Loss: 0.0004


In [6]:
# Test the original model
model.eval()
with torch.no_grad():
    original_outputs = []
    for x_batch, _ in test_loader:
        x_batch = x_batch.to(device)
        original_outputs.append(model(x_batch))

# Concatenate original outputs for comparison
original_outputs = torch.cat(original_outputs, dim=0)


In [7]:
# Instantiate the model
model = SmallNet()
model.eval()

# Create dummy input data
dummy_input = torch.randn(100, 2)  # 100 samples, 10 features

# Quantize all layers in the model
quantize_model_with_propagation(model, dummy_input, bits=4)


Quantizing layer: fc1
Estimating Hessian diagonal for layer with shape (20, 2)...
Hessian diagonal estimated.
Starting GPTQ quantization for layer with shape (20, 2)...
GPTQ quantization using Cholesky inverse completed.
Layer fc1 quantized successfully.

Quantizing layer: fc2
Estimating Hessian diagonal for layer with shape (30, 20)...
Hessian diagonal estimated.
Starting GPTQ quantization for layer with shape (30, 20)...
Hessian is not positive definite. Falling back to diagonal inverse.
GPTQ quantization using Cholesky inverse completed.
Layer fc2 quantized successfully.

Quantizing layer: fc3
Estimating Hessian diagonal for layer with shape (20, 30)...
Hessian diagonal estimated.
Starting GPTQ quantization for layer with shape (20, 30)...
Hessian is not positive definite. Falling back to diagonal inverse.
GPTQ quantization using Cholesky inverse completed.
Layer fc3 quantized successfully.

Quantizing layer: fc4
Estimating Hessian diagonal for layer with shape (1, 20)...
Hessian di

In [8]:
# Test the quantized model
with torch.no_grad():
    quantized_outputs = []
    for x_batch, _ in test_loader:
        x_batch = x_batch.to(device)
        quantized_outputs.append(model(x_batch))

# Concatenate quantized outputs for comparison
quantized_outputs = torch.cat(quantized_outputs, dim=0)


In [9]:
# Compute Mean Squared Error (MSE)
mse = torch.mean((original_outputs - quantized_outputs) ** 2)
print("Mean Squared Error (MSE) between original and quantized outputs:", mse.item())


Mean Squared Error (MSE) between original and quantized outputs: 0.44415464997291565


In [11]:
list(zip(original_outputs, quantized_outputs))

[(tensor([0.9018]), tensor([-0.1396])),
 (tensor([0.2360]), tensor([-0.1674])),
 (tensor([0.5717]), tensor([-0.1489])),
 (tensor([0.4785]), tensor([-0.1557])),
 (tensor([0.6948]), tensor([-0.1459])),
 (tensor([0.1646]), tensor([-0.1736])),
 (tensor([0.8627]), tensor([-0.1407])),
 (tensor([0.3160]), tensor([-0.1645])),
 (tensor([0.7842]), tensor([-0.1432])),
 (tensor([0.2346]), tensor([-0.1666])),
 (tensor([0.2289]), tensor([-0.1691])),
 (tensor([0.5551]), tensor([-0.1504])),
 (tensor([0.3892]), tensor([-0.1612])),
 (tensor([0.4153]), tensor([-0.1594])),
 (tensor([0.2706]), tensor([-0.1642])),
 (tensor([0.2444]), tensor([-0.1675])),
 (tensor([0.4993]), tensor([-0.1549])),
 (tensor([0.2443]), tensor([-0.1662])),
 (tensor([0.1657]), tensor([-0.1722])),
 (tensor([0.7272]), tensor([-0.1448])),
 (tensor([0.3481]), tensor([-0.1615])),
 (tensor([0.7397]), tensor([-0.1467])),
 (tensor([0.5269]), tensor([-0.1547])),
 (tensor([0.8105]), tensor([-0.1425])),
 (tensor([0.3670]), tensor([-0.1619])),
