In [18]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as trans
from model import SemanticSegmentationModel, dice_loss
from cityscapes_dataset import CitySegDataset
from torch.ao.quantization import quantize_fx
from tinynn.graph.quantization.quantizer import QATQuantizer

# Load the original model

In [3]:
original_model = SemanticSegmentationModel.load_from_checkpoint("model.ckpt")
original_model = original_model.eval()

# Fuze the batchnorms

In [4]:
model = quantize_fx.fuse_fx(original_model)

In [5]:
model

GraphModule(
  (encoder): Module(
    (features): Module(
      (0): Module(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (2): SiLU(inplace=True)
      )
      (1): Module(
        (0): Module(
          (block): Module(
            (0): Module(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
              (2): SiLU(inplace=True)
            )
            (1): Module(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (scale_activation): Sigmoid()
            )
            (2): Module(
              (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
        )
      )
      (2): Module(
        (0): Module(
          (block): Module(
            (0): Module(
       

# Quantize the model

### Prepare the quantizer

In [6]:
quantizer = QATQuantizer(
    model,
    torch.randn(1, 3, 256, 256),
    work_dir="quant_output",
    config={
        "asymmetric": True,
        "backend": "qnnpack",
        "disable_requantization_for_cat": True,
        "per_tensor": True,
    },
)
model_with_quantizer = quantizer.quantize()
model_with_quantizer.eval()

QGraphModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (encoder_features_0_0): Conv2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation

### Post training quantization calibration

In [7]:
def calibration(model, num_iteration, dataloader):
    iteration_num = num_iteration
    count = 0
    for data in dataloader:
        images, labels = data
        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()
        model(images)
        count += 1
        if count >= iteration_num:
            break
    return model

In [8]:
ds = CitySegDataset(
    root="data/cityscapes_data",
    split="val",
    transforms_img=trans.Compose([trans.ToTensor(), trans.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]),
)
testloader = DataLoader(dataset=ds, batch_size=32, shuffle=False, num_workers=8)

model_with_quantizer = calibration(model_with_quantizer, 50, testloader)

# Evaluate and compare the models

In [15]:
def step(model, batch):
        x, y = batch
        batch_size, num_classes, width, height = y.shape
        out = model(x)  # the output is (35x64x64) so we need to iterpolate it to (35x256x256)
        out = torch.nn.functional.interpolate(
            out, size=(width, height), mode="bilinear", align_corners=True
        )  # in the FFNet repo this is exactly how they trasform the output.
        out = F.softmax(out, dim=1)
        loss = dice_loss(out, y)

        # pixel accuracy
        y_idx = y.argmax(dim=1)
        out_idx = out.argmax(dim=1)
        y_equal_pred = out_idx == y_idx
        pix_accuracy = y_equal_pred.sum() / (batch_size * width * height)

        # mean accuracy
        class_accuracies = []
        for cls in range(num_classes):
            mask = y_idx == cls
            if mask.sum() > 0:
                class_accuracies.append((y_equal_pred * mask).sum() / mask.sum())
        mean_accuracy = sum(class_accuracies) / len(class_accuracies)

        # IoU
        ious = []
        for cls in range(num_classes):
            pred_mask = out_idx == cls
            target_mask = y_idx == cls

            intersection = (pred_mask & target_mask).sum(dim=(1, 2))
            union = (pred_mask | target_mask).sum(dim=(1, 2))

            union = torch.where(union == 0, torch.tensor(1.0).to(union.device), union)

            iou = intersection.float() / union.float()
            ious.append(iou)
        iou = torch.stack(ious).mean().item()

        metrics = {
            "pixel_accuracy": pix_accuracy, 
            "mean_accuracy": mean_accuracy,
            "iou": iou,
        }
        return loss, metrics

def eval_model(model, dataloader):
    val_loss = 0
    val_iou = 0
    val_pix_acc = 0
    val_mean_acc = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            if torch.cuda.is_available():
                images = images.cuda()
                labels = labels.cuda()
            loss, metrics = step(model, (images, labels))
            val_loss += loss
            val_iou += metrics["iou"]
            val_pix_acc += metrics["pixel_accuracy"]
            val_mean_acc += metrics["mean_accuracy"]

    val_loss /= len(dataloader)
    val_iou /= len(dataloader)
    val_pix_acc /= len(dataloader)
    val_mean_acc /= len(dataloader)

    print(f" Validation Loss: {val_loss:.3f} Pixel accuracy: {val_pix_acc:.2f} Mean class accuracy: {val_mean_acc:.2f} Mean IoU: {val_iou:.2f}")

In [16]:
print("Original model")
eval_model(original_model, testloader)
print("Quantized model")
eval_model(model_with_quantizer, testloader)

Original model
 Validation Loss: 0.127 Pixel accuracy: 0.84 Mean class accuracy: 0.40 Mean IoU: 0.20
Quantized model
 Validation Loss: 0.136 Pixel accuracy: 0.83 Mean class accuracy: 0.37 Mean IoU: 0.19


In [19]:
def print_model_size(path):
    file_size = os.path.getsize(path)
    file_size_mb = file_size / (1024 * 1024)
    print(f"Model size: {file_size_mb:.2f} MB")

print("Original model size")
print_model_size("model.ckpt")
print("Quantized model size")
print_model_size("quant_output/graphmodule_q.pth")

Original model size
Model size: 70.01 MB
Quantized model size
Model size: 17.54 MB


As we can see the drop in performance, though visible, is not really that significant. On the other hand, the model size decrease is very much considerable. The model size has decreased by a factor of around 4 (from 70 MB to 17.54 MB). We can therefore conclude that we have successfully quantized the model