In [13]:
from itertools import islice

import torch
import torchvision.transforms as trans
from tinynn.graph.quantization.quantizer import QATQuantizer
from torch.ao.quantization import quantize_fx
from torch.utils.data import DataLoader

from cityscapes_dataset import CitySegDataset
from model import SemanticSegmentationModel, dice_loss

In [2]:
# Restore the trained Lightning module from its checkpoint and switch it to
# inference mode in one expression (nn.Module.eval() returns the module itself).
original_model = SemanticSegmentationModel.load_from_checkpoint("model.ckpt").eval()

In [3]:
# FX graph-mode fusion: symbolically trace `original_model` and fold fusable
# layer patterns. In the printed repr of the result, index (1) between each
# Conv2d and SiLU is gone, presumably a BatchNorm folded into the conv — confirm
# against the unfused model. The result is a torch.fx GraphModule.
# NOTE(review): torch's fuse_fx expects an eval-mode model; `original_model` was
# put in eval mode when it was loaded. The GraphModule presumably keeps only the
# traced forward, not LightningModule helpers such as `_step` — verify before
# calling them on `model`.
model = quantize_fx.fuse_fx(original_model)

In [4]:
# Display the fused GraphModule's module hierarchy to check which layers were fused.
model

GraphModule(
  (encoder): Module(
    (features): Module(
      (0): Module(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (2): SiLU(inplace=True)
      )
      (1): Module(
        (0): Module(
          (block): Module(
            (0): Module(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
              (2): SiLU(inplace=True)
            )
            (1): Module(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (scale_activation): Sigmoid()
            )
            (2): Module(
              (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
        )
      )
      (2): Module(
        (0): Module(
          (block): Module(
            (0): Module(
       

In [11]:
# tinynn quantization settings: asymmetric, per-tensor ranges for the
# qnnpack backend, without requantizing the inputs of concatenations.
quant_config = {
    "asymmetric": True,
    "backend": "qnnpack",
    "disable_requantization_for_cat": True,
    "per_tensor": True,
}

# Example input used to trace the fused model. Its shape is
# (batch, channels, height, width). The first conv takes 3 channels; the
# 256x256 spatial size is assumed to match the dataset images — TODO confirm.
dummy_input = torch.randn(1, 3, 256, 256)

quantizer = QATQuantizer(model, dummy_input, work_dir="quant_output", config=quant_config)

# quantize() returns a new module with QuantStub / FakeQuantize wrappers
# inserted. Eval mode is set before calibration; the observers stay enabled
# (observer_enabled=1 in the printed repr).
model_with_quantizer = quantizer.quantize()
model_with_quantizer.eval()

QGraphModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (encoder_features_0_0): Conv2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation

In [14]:
def calibration(model, num_iteration, dataloader, device=None):
    """Run up to ``num_iteration`` batches through ``model`` to calibrate its observers.

    Only the forward pass matters. Observers inside the fake-quantized model,
    e.g. the ``MovingAverageMinMaxObserver`` instances in its repr, presumably
    record activation ranges as data flows through. Labels are never used.

    Args:
        model: Callable, typically an ``nn.Module``, invoked as ``model(images)``.
        num_iteration: Maximum number of batches to process. A value of ``0`` or
            less processes no batches.
        dataloader: Iterable yielding ``(images, labels)`` pairs.
        device: Where to move the images. By default this is the device of the
            model's first parameter. If the model has no parameters, it is CUDA
            when available, otherwise CPU.

    Returns:
        The same ``model`` object, whose observer state has been updated in place.
    """
    if device is None:
        params = getattr(model, "parameters", None)
        first_param = next(params(), None) if callable(params) else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Calibration needs no gradients; an autograd graph would only cost memory and time.
    with torch.no_grad():
        # islice stops without fetching an extra batch from the (worker-backed) loader.
        for images, _labels in islice(dataloader, max(num_iteration, 0)):
            model(images.to(device))
    return model

In [15]:
# Normalize each RGB channel with mean 0.5 / std 0.5. ToTensor's [0, 1] output
# (for uint8 images) becomes [-1, 1].
img_transform = trans.Compose(
    [
        trans.ToTensor(),
        trans.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
ds = CitySegDataset(root="data/cityscapes_data", split="val", transforms_img=img_transform)
testloader = DataLoader(dataset=ds, batch_size=32, shuffle=False, num_workers=8)

# Number of batches fed through the fake-quantized model to collect activation ranges.
CALIBRATION_BATCHES = 50
# NOTE(review): calibration uses the same "val" split that is scored later;
# consider a subset of the training split so the quantization ranges are not
# fitted on the evaluation data.
model_with_quantizer = calibration(model_with_quantizer, CALIBRATION_BATCHES, testloader)

In [17]:
def eval_model(model, dataloader, device=None):
    """Evaluate ``model`` over ``dataloader``, print the averaged loss/metrics and return them.

    The model must expose a Lightning-style ``_step((images, labels))`` that
    returns ``(loss, metrics)``. ``metrics`` must contain the keys ``"iou"``,
    ``"pixel_accuracy"`` and ``"mean_accuracy"``. The caller is responsible
    for putting the model in eval mode.

    Averages are weighted by batch size, so a short final batch does not count
    as much as a full one.
    NOTE(review): this assumes ``_step`` returns per-batch means — confirm
    against the model implementation.

    Args:
        model: Object with a ``_step`` method, typically the LightningModule.
        dataloader: Iterable of ``(images, labels)`` batches; ``images.shape[0]``
            is taken as the batch size.
        device: Where to move the batches. By default this is the device of the
            model's first parameter. If the model has no parameters, it is CUDA
            when available, otherwise CPU.

    Returns:
        dict mapping ``"loss"``, ``"pixel_accuracy"``, ``"mean_accuracy"`` and
        ``"iou"`` to Python floats.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        params = getattr(model, "parameters", None)
        first_param = next(params(), None) if callable(params) else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    metric_keys = ("iou", "pixel_accuracy", "mean_accuracy")
    totals = {"loss": 0.0, **{key: 0.0 for key in metric_keys}}
    num_samples = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            batch_size = images.shape[0]
            loss, metrics = model._step((images, labels))
            # float() works for both 0-d tensors and plain Python numbers.
            totals["loss"] += float(loss) * batch_size
            for key in metric_keys:
                totals[key] += float(metrics[key]) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("eval_model: dataloader yielded no samples; cannot average metrics")

    averages = {key: total / num_samples for key, total in totals.items()}

    print(
        f" Validation Loss: {averages['loss']:.3f}"
        f" Pixel accuracy: {averages['pixel_accuracy']:.2f}"
        f" Mean class accuracy: {averages['mean_accuracy']:.2f}"
        f" Mean IoU: {averages['iou']:.2f}"
    )
    return averages

In [None]:
# Baseline: score the unquantized Lightning model on the validation split.
# NOTE(review): eval_model calls the LightningModule's `_step`. The traced tinynn
# QGraphModule (`model_with_quantizer`) presumably does not define `_step`, so
# scoring the quantized model needs a wrapper that reuses `_step`'s loss/metric
# logic around the quantized forward.
print("Original model")
original_metrics = eval_model(original_model, testloader)

