In [1]:
dataset_path = '/mnt/edisk/dataset/unimodel_calibrate'
model_path = 'unimodel.optimized.onnx'
qmodel_path = 'unimodel.quantized.onnx'

In [2]:
import os
import pickle

from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader

from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantizeBase, FakeQuantize

from onnxutils.quantization import symbolic_trace, ModuleQuantizer, compute_metric, mse_kernel, cosine_kernel, snr_kernel

from onnxutils.common import DatasetUtils
from onnxutils.onnx import OnnxModel
from onnxutils.onnx2torch import convert
from onnxutils.onnx2torch.scatter_nd import TorchScatterNd

from unimodel_pipeline import UnimodelDataset

# Prepare

In [3]:
onnx_model = OnnxModel.from_file(model_path)
with onnx_model.session() as sess:
    for node in onnx_model.proto().graph.node:
        if node.name == '':
            node.name = sess.unique_name()
onnx_model.topological_sort()
torch_model = convert(onnx_model)

dataset = UnimodelDataset(dataset_path, torch_model.onnx_mapping.inputs)

# Quantization

In [4]:
graph_model = pickle.loads(pickle.dumps(torch_model))
graph_model = symbolic_trace(graph_model, skipped_module_classes=[TorchScatterNd])

quantizer = ModuleQuantizer()
graph_model = quantizer.quantize(graph_model, [
    {
        'name': 'imgs',
        'activation': FakeQuantize.with_args(observer=HistogramObserver),
    },
    {
        'module_name': 'backbone/stage1/conv/Conv',
        'weight': FakeQuantize.with_args(observer=PerChannelMinMaxObserver),
    },
    {
        'module_name': 'backbone/stage1/act/Relu',
        'activation': FakeQuantize.with_args(observer=HistogramObserver),
    },
])

graph_model.print_readable()

  x = nn.functional.upsample_bilinear(x, self.sizes)


class GraphModule(torch.nn.Module):
    def forward(self, distortion_coeff, extrinsics, fused_projection, imgs, mpp, mpp_pose_state, mpp_valid, norm_intrinsics, pose, prev_bev_feats, prev_feat_stride16, prev_pose_inv, sdmap_encode, sdmap_mat):
        # No stacktrace found for following nodes
        fq0 = self.fq0(imgs);  imgs = None
        
         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())
        mppmodule_encoder_backbone_stem_stem_0_conv = getattr(self, "MPPModule_encoder/backbone/stem/stem/0/Conv")(mpp);  mpp = None
        
         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())
        random_861d96e0_d0f0_11ef_aeac_8586c47ece25_64 = self.random_861d96e0_d0f0_11ef_aeac_8586c47ece25_64(mppmodule_encoder_backbone_stem_stem_0_conv);  

'class GraphModule(torch.nn.Module):\n    def forward(self, distortion_coeff, extrinsics, fused_projection, imgs, mpp, mpp_pose_state, mpp_valid, norm_intrinsics, pose, prev_bev_feats, prev_feat_stride16, prev_pose_inv, sdmap_encode, sdmap_mat):\n        # No stacktrace found for following nodes\n        fq0 = self.fq0(imgs);  imgs = None\n        \n         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = \'\'.join(CapturedTraceback.extract().format())\n        mppmodule_encoder_backbone_stem_stem_0_conv = getattr(self, "MPPModule_encoder/backbone/stem/stem/0/Conv")(mpp);  mpp = None\n        \n         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = \'\'.join(CapturedTraceback.extract().format())\n        random_861d96e0_d0f0_11ef_aeac_8586c47ece25_64 = self.random_861d96e0_d0f0_11ef_aeac_8586c47ece25_64(mppmodule_encoder_backbone_stem_s

In [5]:
# calibration
for m in graph_model.modules():
    if isinstance(m, FakeQuantizeBase):
        m.disable_fake_quant()

dataloader = DataLoader(
    DatasetUtils.take_front(
        DatasetUtils.transform(
            dataset,
            lambda items: tuple(x.to('cuda') for x in items)
        ),
        256
    ),
    batch_size=None
)
graph_model.eval().to('cuda')
for data in tqdm(dataloader):
    graph_model(*data)

for m in graph_model.modules():
    if isinstance(m, FakeQuantizeBase):
        m.enable_fake_quant()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:25<00:00, 10.06it/s]


# Analysis

In [9]:
dataloader = DataLoader(
    DatasetUtils.take_front(
        DatasetUtils.transform(
            dataset,
            lambda items: tuple(x.to('cuda') for x in items)
        ),
        256
    ),
    batch_size=None
)
torch_model.eval().to('cuda')
graph_model.eval().to('cuda')

for data in dataloader:
    gt = torch_model(*data)
    pred = graph_model(*data)

    for name, metric, val0, val1 in zip(torch_model.onnx_mapping.outputs, compute_metric(gt, pred, cosine_kernel), gt, pred):
        print(name, metric)
    break

allsign_flatten_bboxes [0.9999966025352478, 0.9999951124191284, 0.9999938607215881]
allsign_flatten_cls_scores [0.9955963492393494, 0.9944125413894653, 0.9909151196479797]
allsign_flatten_objectness [0.9991427659988403, 0.9987731575965881, 0.998445451259613]
bev2d_cls_result [0.9997857809066772]
bev2d_semantic_cls_result [0.9945343732833862]
det3d_bbox_score [0.9997599124908447]
det3d_bboxes_attr_scores [0.9999975562095642]
det3d_bboxes_category [0.9741077423095703]
det3d_bboxes_subcategory [0.9997507333755493]
det3d_decoded_bboxes [0.9999534487724304]
ego_velocity [1.0]
feats_0 [0.999281644821167, 0.999181866645813, 0.9992108345031738, 0.9992364645004272, 0.9991506934165955, 0.9992476105690002, 0.9991976618766785, 0.9996046423912048, 0.9991903305053711, 0.9990318417549133]
feats_1 [0.9990729689598083, 0.9988433122634888, 0.9986827969551086, 0.9990419745445251, 0.9986850619316101, 0.9988800287246704, 0.998879075050354, 0.9987078309059143, 0.9989297389984131, 0.9984086751937866]
feats_2