In [1]:
pip install torch


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [29]:
import torch
import torch.nn as nn
import torch.quantization as quant
from torch.quantization.observer import MovingAverageMinMaxObserver, default_weight_observer

class CustomObserver(MovingAverageMinMaxObserver):
    def calculate_qparams(self):
        scale, _ = super().calculate_qparams()
        zero_point = torch.tensor(0, dtype=torch.int32)
        return scale, zero_point

class QuantAwareModel(nn.Module):
    def __init__(self):
        super(QuantAwareModel, self).__init__()
        self.conv1 = nn.Conv1d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(16)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm1d(32)
        self.relu2 = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool1d((1))
        self.fc = nn.Linear(32, 10)
        
        self.quant1 = quant.QuantStub()  # Quantizes the input
        self.dequant1 = quant.DeQuantStub()  # Dequantizes the output
        self.quant2 = quant.QuantStub()  # Quantizes the input
        self.dequant2 = quant.DeQuantStub()  # Dequantizes the output
        self.quant3 = quant.QuantStub()  # Quantizes the input
        self.dequant3 = quant.DeQuantStub()  # Dequantizes the output
        
        self.int_output1 = None
        self.int_output2 = None

    def forward(self, x):
        x = self.quant1(x)  # Quantize the input
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.dequant1(x)
        x = self.quant2(x)  # Quantize the input
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.dequant2(x)
        print(x.shape)
        x = self.pool(x)
        print(x.shape)
        x = torch.flatten(x, 1)
        x = self.quant3(x)  # Quantize the input
        print(x.shape)
        x = self.fc(x)
        x = self.dequant3(x)  # Dequantize the output
        return x

# Create the model
model = QuantAwareModel()

# Set the QAT configuration with custom observer forcing zero_point to 0
model.qconfig = quant.QConfig(
    activation=quant.FakeQuantize.with_args(observer=CustomObserver, quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    weight=quant.FakeQuantize.with_args(observer=default_weight_observer, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

# Prepare the model for QAT
quant.prepare_qat(model, inplace=True)

# Move the model to GPU for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)




QuantAwareModel(
  (conv1): Conv1d(
    3, 16, kernel_size=(3,), stride=(1,), padding=(1,)
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): CustomObserver(min_val=inf, max_val=-inf)
    )
  )
  (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv2): Conv1d(
    16, 32, kernel_size=(3,), stride=(1,), padding=(1,)
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch

In [30]:
# Dummy training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Dummy data
inputs = torch.randn(16, 3, 32).to(device)
targets = torch.randint(0, 10, (16,)).to(device)

# Training with fake quantization
model.train()
for epoch in range(10):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Move the model back to CPU for folding Batch Normalization layers
model.to('cpu')




torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 1, Loss: 2.30914568901062
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 2, Loss: 2.308565855026245
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 3, Loss: 2.3078770637512207
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 4, Loss: 2.3063342571258545
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 5, Loss: 2.3044683933258057
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 6, Loss: 2.302889347076416
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 7, Loss: 2.300825357437134
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 8, Loss: 2.2984256744384766
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])
Epoch 9, Loss: 2.295379400253296
torch.Size([16, 32, 32])
torch.Size([16, 32, 1])
torch.Size([16, 32])


QuantAwareModel(
  (conv1): Conv1d(
    3, 16, kernel_size=(3,), stride=(1,), padding=(1,)
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0134]), zero_point=tensor([0], dtype=torch.int32)
      (activation_post_process): CustomObserver(min_val=-1.8524562120437622, max_val=1.574406385421753)
    )
  )
  (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv2): Conv1d(
    16, 32, kernel_size=(3,), stride=(1,), padding=(1,)
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0252]), zero_point=tensor([0], dtype=torch.

In [31]:
def fold_batch_norm(conv, bn):
    # Fold BatchNorm parameters into Convolution layer
    with torch.no_grad():
        scale_factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        conv.weight.copy_(conv.weight * scale_factor.reshape([-1, 1, 1]))
        if conv.bias is None:
            conv.bias = torch.nn.Parameter(torch.zeros(conv.weight.size(0), dtype=conv.weight.dtype, device=conv.weight.device))
        conv.bias.copy_((conv.bias - bn.running_mean) * scale_factor + bn.bias)
    return conv

# Fold BatchNorm layers
model.conv1 = fold_batch_norm(model.conv1, model.bn1)
model.conv2 = fold_batch_norm(model.conv2, model.bn2)

# Remove BatchNorm layers
model.bn1 = None
model.bn2 = None

# Adjust the forward pass to remove Batch Normalization layers
def new_forward(self, x):
    x = self.quant1(x)  # Quantize the input
    x = self.conv1(x)
    x = self.relu1(x)
    x = self.dequant1(x)
    x = self.quant2(x)  # Quantize the input
    x = self.conv2(x)
    x = self.relu2(x)
    x = self.dequant2(x)
    x = self.pool(x)
    x = torch.flatten(x, 1)
    x = self.quant3(x)  # Quantize the input
    x = self.fc(x)
    x = self.dequant3(x)  # Dequantize the output
    return x

model.forward = new_forward.__get__(model, QuantAwareModel)


In [32]:
# Move the model back to CPU for quantized inference
# Convert the trained model to a quantized version
model.eval()
quant.convert(model, inplace=True)




QuantAwareModel(
  (conv1): QuantizedConv1d(3, 16, kernel_size=(3,), stride=(1,), scale=0.013438677415251732, zero_point=0, padding=(1,))
  (bn1): None
  (relu1): ReLU()
  (conv2): QuantizedConv1d(16, 32, kernel_size=(3,), stride=(1,), scale=0.025234129279851913, zero_point=0, padding=(1,))
  (bn2): None
  (relu2): ReLU()
  (pool): AdaptiveAvgPool1d(output_size=1)
  (fc): QuantizedLinear(in_features=32, out_features=10, scale=0.004543437156826258, zero_point=0, qscheme=torch.per_tensor_affine)
  (quant1): Quantize(scale=tensor([0.0264]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant1): DeQuantize()
  (quant2): Quantize(scale=tensor([0.0317]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant2): DeQuantize()
  (quant3): Quantize(scale=tensor([0.0040]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant3): DeQuantize()
)

In [33]:
# Inspect the quantization parameters (scale and zero-point) for each layer
for name, module in model.named_modules():
    if hasattr(module, 'weight_fake_quant'):
        print(f"Layer: {name}")
        print(f"Weight scale: {module.weight_fake_quant.scale}")
        print(f"Weight zero-point: {module.weight_fake_quant.zero_point}")
    if hasattr(module, 'activation_post_process'):
        print(f"Layer: {name}")
        print(f"Activation scale: {module.activation_post_process.scale}")
        print(f"Activation zero-point: {module.activation_post_process.zero_point}")


print( model )

QuantAwareModel(
  (conv1): QuantizedConv1d(3, 16, kernel_size=(3,), stride=(1,), scale=0.013438677415251732, zero_point=0, padding=(1,))
  (bn1): None
  (relu1): ReLU()
  (conv2): QuantizedConv1d(16, 32, kernel_size=(3,), stride=(1,), scale=0.025234129279851913, zero_point=0, padding=(1,))
  (bn2): None
  (relu2): ReLU()
  (pool): AdaptiveAvgPool1d(output_size=1)
  (fc): QuantizedLinear(in_features=32, out_features=10, scale=0.004543437156826258, zero_point=0, qscheme=torch.per_tensor_affine)
  (quant1): Quantize(scale=tensor([0.0264]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant1): DeQuantize()
  (quant2): Quantize(scale=tensor([0.0317]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant2): DeQuantize()
  (quant3): Quantize(scale=tensor([0.0040]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant3): DeQuantize()
)


In [37]:
def hook_conv1(module, input, output):
    if output.is_quantized:
        print('hello')
        model.int_output1 = output.int_repr().cpu().numpy()  # Capture integer values after conv1

def hook_conv2(module, input, output):
    if output.is_quantized:
        model.int_output2 = output.int_repr().cpu().numpy()  # Capture integer values after conv2

# Register hooks
model.conv1.register_forward_hook(hook_conv1)
model.conv2.register_forward_hook(hook_conv2)

# Dummy input data for inference on CPU
inputs_cpu = inputs.to('cpu')

# Perform inference and print integer outputs
with torch.no_grad():
    _ = model(inputs_cpu)

print("Integer values after conv1:")
print(model.int_output1)

print("\nInteger values after conv2:")
print(model.int_output2)




hello
Integer values after conv1:
[[[33  0 57 ...  0 31  0]
  [ 0  0  7 ... 23 36  0]
  [ 0 67  5 ... 75 17  0]
  ...
  [ 0 10  0 ...  0  0 41]
  [11 24  0 ...  8  0  0]
  [ 0  7  0 ... 21  0  0]]

 [[ 0  3  0 ...  0  0  0]
  [ 0 24 50 ...  0 59 61]
  [ 0  0 18 ...  0  0 43]
  ...
  [ 0  0 43 ... 45 58  0]
  [25  2 15 ...  0  0 96]
  [ 0  0  0 ...  0  0  0]]

 [[ 1  0  0 ...  0  6  3]
  [ 0  1  0 ...  0  2  0]
  [ 0  0  9 ...  0  0  0]
  ...
  [ 0  0  0 ...  0  0  0]
  [46 13 38 ... 19  0  0]
  [ 0  0  0 ...  0  0  0]]

 ...

 [[ 0  0 24 ...  0  0  0]
  [ 5 28  7 ... 53  0 24]
  [ 0 10 22 ... 65  0  4]
  ...
  [ 0  0  0 ...  0 83  0]
  [55 33  0 ...  0  0 66]
  [ 0  0  0 ...  0  0  0]]

 [[14  0  0 ...  0  0  0]
  [ 3  0  0 ...  0  4  0]
  [45 51 10 ...  0  0 13]
  ...
  [ 0  0 41 ... 24  0  0]
  [62 12  0 ...  0 44 33]
  [ 0  0  0 ...  0  0  0]]

 [[ 0  1  0 ... 31  0  0]
  [60  8 20 ...  0 49 55]
  [ 0 24 27 ...  0 43 12]
  ...
  [ 0  0 23 ...  0 85  0]
  [ 2  5  0 ...  0 14 47]
  [ 