In [1]:
import os
import argparse

from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader

from onnxutils.common import OnnxModel
from onnxutils.onnx2torch import convert

In [2]:
# Root directory of the calibration snippets. It can be overridden with the
# UNIMODEL_DATASET_PATH environment variable so the notebook is not tied to a
# single user's home directory; the original location remains the default.
dataset_path = os.environ.get(
    'UNIMODEL_DATASET_PATH',
    '/home/ycgao/Workdir/dataset/unimodel_calibrate',
)
# ONNX model that is loaded for quantization, and the file the quantized
# model is exported to.
model_path = 'unimodel.optimized.onnx'
qmodel_path = 'unimodel.quantized.onnx'

# Prepare Model

In [3]:
# Load the ONNX graph from disk.
onnx_model = OnnxModel.from_file(model_path)

# Every node whose name is empty receives a name generated by the session.
# NOTE(review): presumably the converter needs each node to be addressable by
# name -- confirm against onnxutils.
with onnx_model.session() as sess:
    for node in onnx_model.proto().graph.node:
        if node.name != '':
            continue
        node.name = sess.unique_name()

# Turn the ONNX model into its torch counterpart.
torch_model = convert(onnx_model)

# Prepare Dataset

In [4]:
from onnxutils.common import DatasetUtils

In [5]:
class UnimodelDataset:
    """Map-style dataset over calibration snippets stored as raw ``.bin`` files.

    A snippet is the sub-directory of ``root_dir`` named after its integer
    index (``"0"``, ``"1"``, ...). It holds one ``<name>.bin`` file per entry
    of ``fields``.
    """

    # One entry per model input: (name, shape the flat file is reshaped to,
    # element dtype used to read the file).
    fields = [
        ("imgs", [10, 3, 576, 960], np.float32),
        ("fused_projection", [1, 10, 4, 4], np.float32),
        ("pose", [1, 4, 4], np.float32),
        ("prev_pose_inv", [1, 4, 4], np.float32),
        ("extrinsics", [1, 10, 4, 4], np.float32),
        ("norm_intrinsics", [1, 10, 4, 4], np.float32),
        ("distortion_coeff", [1, 10, 6], np.float32),
        ("prev_bev_feats", [1, 48, 60, 77], np.float32),
        ("sdmap_encode", [1, 9, 128, 160], np.float32),
        ("mpp", [1, 3, 224, 384], np.float32),
        ("mpp_pose_state", [1, 6], np.float32),
        ("mpp_valid", [1, 1], np.float32),
        ("prev_feat_stride16", [1, 48, 36, 60], np.float32),
        ("sdmap_mat", [1, 4, 240, 400], np.float32),
    ]

    def __init__(self, root_dir):
        """Remember ``root_dir`` and list its entries.

        The listing only determines the dataset length; items are located by
        their index-named directory, not by listing order.
        """
        self.root_dir = root_dir
        self.snippets = os.listdir(self.root_dir)

    def load_item(self, path):
        """Read every field file under ``path`` into a ``{name: tensor}`` dict.

        Args:
            path: directory containing one ``<name>.bin`` file per field.

        Returns:
            dict mapping each field name to a tensor of the declared shape.
        """
        item = {}
        for name, shape, dtype in UnimodelDataset.fields:
            flat = np.fromfile(os.path.join(path, f"{name}.bin"), dtype=dtype)
            item[name] = torch.from_numpy(flat.reshape(shape))
        return item

    def __len__(self):
        """Number of entries found in ``root_dir`` at construction time."""
        return len(self.snippets)

    def __getitem__(self, idx):
        """Load the snippet stored in the directory named ``str(idx)``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Raising it (rather than failing later on a missing file)
                lets plain iteration over the dataset terminate.
        """
        size = len(self)
        position = idx + size if idx < 0 else idx
        if not 0 <= position < size:
            raise IndexError(
                f"snippet index {idx} out of range for dataset of size {size}")
        return self.load_item(os.path.join(self.root_dir, str(position)))


# Dataset over the calibration snippets located under dataset_path.
dataset = UnimodelDataset(root_dir=dataset_path)

# Quantize Model

In [6]:
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping, get_default_qconfig_mapping
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import ReuseInputObserver, NoopObserver, HistogramObserver, MinMaxObserver
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig, ConvertCustomConfig

from torch.nn.intrinsic.modules.fused import ConvReLU2d

from onnxutils.onnx2torch.scatter_nd import TorchScatterNd
from onnxutils.onnx2torch.converter import normalize_module_name
from onnxutils.onnx2torch.utils import OnnxMapping

def fix_onnx_mapping(torch_model):
    """Recursively attach an ``onnx_mapping`` to fused ``ConvReLU2d`` modules.

    Children are processed first. Then, if ``torch_model`` itself is a
    ``ConvReLU2d`` whose sub-modules ``'0'`` and ``'1'`` both carry an
    ``onnx_mapping``, the fused module receives a new mapping built from the
    inputs of the first and the outputs of the second. Any other module is
    left untouched.

    Args:
        torch_model: module to walk; modified in place.
    """
    for child in torch_model.children():
        fix_onnx_mapping(child)

    if not isinstance(torch_model, ConvReLU2d):
        return

    conv, relu = (getattr(torch_model, slot) for slot in ('0', '1'))
    if not (hasattr(conv, 'onnx_mapping') and hasattr(relu, 'onnx_mapping')):
        return

    torch_model.onnx_mapping = OnnxMapping(
        inputs=conv.onnx_mapping.inputs,
        outputs=relu.onnx_mapping.outputs,
    )

def quantize_model(torch_model, qconfig_mapping, dataset, loss_fn=None):
    """Quantize ``torch_model`` with FX graph mode after a calibration pass.

    The model is put in eval mode on the CPU and prepared with ``prepare_fx``.
    Every item of ``dataset`` is then run through the prepared model on CUDA
    with gradients disabled. Finally the model is moved back to the CPU and
    converted with ``convert_fx``.

    Args:
        torch_model: model exposing ``onnx_mapping.inputs``; switched to eval
            mode and moved to the CPU in place.
        qconfig_mapping: forwarded to ``prepare_fx``.
        dataset: indexable collection; each item is indexed by input name to
            obtain a tensor, and ``dataset[0]`` is handed to ``prepare_fx`` as
            the example inputs.
        loss_fn: accepted but never used by this function.

    Returns:
        The converted model, on the CPU, after ``fix_onnx_mapping`` was
        applied to it.
    """
    def inputs_on_cuda(item):
        # Positional model inputs, in the order given by the ONNX mapping.
        return tuple(item[name].cuda() for name in torch_model.onnx_mapping.inputs)

    prepare_cfg = (
        PrepareCustomConfig()
        .set_non_traceable_module_classes([TorchScatterNd])
        .set_preserved_attributes(['onnx_mapping'])
    )

    torch_model.eval().cpu()
    observed = prepare_fx(torch_model, qconfig_mapping, dataset[0], prepare_cfg)

    # batch_size=None disables automatic batching: items are fed one at a time.
    calibration_loader = DataLoader(
        DatasetUtils.transform(dataset, inputs_on_cuda),
        batch_size=None,
    )
    observed.cuda()
    with torch.no_grad():
        for model_inputs in tqdm(calibration_loader):
            observed(*model_inputs)

    observed.cpu()
    convert_cfg = ConvertCustomConfig().set_preserved_attributes(['onnx_mapping'])
    quantized = convert_fx(observed, convert_custom_config=convert_cfg)
    fix_onnx_mapping(quantized)
    return quantized

def quantize_qat_model(torch_model, qconfig_mapping, dataset, optimizer=None, loss_fn=None):
    """Quantize ``torch_model`` with FX graph mode after a training pass (QAT).

    The model is put in train mode on the CPU and prepared with
    ``prepare_qat_fx``. Both the original and the prepared model are then
    moved to CUDA and one pass over ``dataset`` is run, calling ``loss_fn``
    and ``optimizer.step()`` for each item. Finally the prepared model is
    moved back to the CPU and converted with ``convert_fx``.

    Args:
        torch_model: model exposing ``onnx_mapping.inputs``; switched to train
            mode in place and left on CUDA.
        qconfig_mapping: forwarded to ``prepare_qat_fx``.
        dataset: indexable collection; each item is indexed by input name to
            obtain a tensor, and ``dataset[0]`` is handed to
            ``prepare_qat_fx`` as the example inputs.
        optimizer: object with ``zero_grad()`` and ``step()``. Despite the
            ``None`` default it is required as soon as ``dataset`` yields an
            item.
        loss_fn: callable invoked as ``loss_fn(outputs, torch_model, inputs)``
            where ``outputs`` maps output names to the prepared model's
            results. Required under the same condition as ``optimizer``.

    Returns:
        The converted model, on the CPU, after ``fix_onnx_mapping`` was
        applied to it.
    """
    # The notebook-level import only brings in prepare_fx and convert_fx, so
    # prepare_qat_fx is imported here; without it this function raised
    # NameError on its first use.
    from torch.ao.quantization.quantize_fx import prepare_qat_fx

    prepare_custom_config = PrepareCustomConfig()
    prepare_custom_config.set_non_traceable_module_classes([TorchScatterNd])
    prepare_custom_config.set_preserved_attributes(['onnx_mapping'])

    torch_model.train().cpu()
    model_prepared = prepare_qat_fx(
        torch_model,
        qconfig_mapping,
        dataset[0],
        prepare_custom_config
    )

    # batch_size=None disables automatic batching: items are fed one at a time.
    dataloader = DataLoader(
        DatasetUtils.transform(
            dataset,
            lambda item: tuple(item[x].cuda() for x in torch_model.onnx_mapping.inputs)
        ),
        batch_size=None
    )
    # The original model is moved as well because it is handed to loss_fn.
    torch_model.cuda()
    model_prepared.cuda()
    for data in tqdm(dataloader):
        vals = model_prepared(*data)
        if not isinstance(vals, (tuple, list)):
            vals = (vals,)
        vals = dict(zip(model_prepared.onnx_mapping.outputs, vals))
        optimizer.zero_grad()
        # NOTE(review): the return value is ignored and backward() is never
        # called here, so loss_fn has to populate the gradients itself for
        # optimizer.step() to have any effect -- confirm with the callers.
        loss_fn(vals, torch_model, data)
        optimizer.step()

    model_prepared.cpu()
    convert_custom_config = ConvertCustomConfig()
    convert_custom_config.set_preserved_attributes(['onnx_mapping'])
    model_converted = convert_fx(model_prepared, convert_custom_config=convert_custom_config)
    fix_onnx_mapping(model_converted)
    return model_converted

In [7]:
# Applied globally below: NoopObserver for both activations and weights.
none_qconfig = QConfig(activation=NoopObserver, weight=NoopObserver)

# Defined for experimentation; not referenced by the mapping built below.
default_qconfig = QConfig(activation=ReuseInputObserver, weight=NoopObserver)

# Configuration for the layers that are actually quantized: histogram-based
# activation observer, symmetric per-tensor qint8 min/max weight observer.
# Alternative activation observer kept for reference:
#   MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
conv2d_qconfig = QConfig(
    activation=HistogramObserver.with_args(reduce_range=True),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
    ),
)

# Everything defaults to none_qconfig; only the two listed ONNX nodes (looked
# up through their normalized torch module names) get conv2d_qconfig.
qconfig_mapping = QConfigMapping()
qconfig_mapping.set_global(none_qconfig)
qconfig_mapping.set_module_name(
    normalize_module_name('/RoutingMaskHead/up_head/up_head.1/conv1/conv/Conv'),
    conv2d_qconfig,
)
qconfig_mapping.set_module_name(
    normalize_module_name('/RoutingMaskHead/up_head/up_head.1/conv1/act/Relu'),
    conv2d_qconfig,
)

# Calibrate on the first 10 dataset items and convert.
model_converted = quantize_model(
    torch_model,
    qconfig_mapping,
    dataset=DatasetUtils.take_front(dataset, 10),
)

  x = nn.functional.upsample_bilinear(x, self.sizes)
100%|██████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.97it/s]


# Analysis

In [8]:
from onnxutils.quantization.metric import compute_metrics, print_stats

def model_infer(model, data):
    """Run ``model`` on ``data`` without gradients and name its outputs.

    Args:
        model: callable exposing ``onnx_mapping.outputs``, the ordered output
            names.
        data: iterable of positional inputs, unpacked into the call.

    Returns:
        dict pairing each name in ``model.onnx_mapping.outputs`` with the
        output at the same position. A result that is neither a tuple nor a
        list is treated as a single output.
    """
    with torch.no_grad():
        outputs = model(*data)
    outputs = outputs if isinstance(outputs, (tuple, list)) else (outputs,)
    return dict(zip(model.onnx_mapping.outputs, outputs))

In [9]:
# Both models are compared on the CPU.
model_converted.cpu()
torch_model.cpu()

# Loader over the first 2 dataset items, each turned into the tuple of CPU
# tensors ordered like the model inputs; batch_size=None yields items
# one at a time without batching.
dataloader = DataLoader(
    DatasetUtils.take_front(
        DatasetUtils.transform(
            dataset,
            lambda sample: tuple(
                sample[name].cpu() for name in torch_model.onnx_mapping.inputs
            ),
        ),
        2,
    ),
    batch_size=None,
)

# For every item: run the original and the quantized model, then print the
# per-output metrics ordered by 'snr'.
for data in dataloader:
    real = model_infer(torch_model, data)
    pred = model_infer(model_converted, data)
    analysis_reports = [
        compute_metrics(real, pred, metrics=['snr', 'mse', 'cosine']),
    ]
    print_stats(analysis_reports, sorted_metric='snr', reversed_order=True)

  x = nn.functional.upsample_bilinear(x, self.sizes)


refline_instance_impassable_lane_mask {'snr': [0.0006998606259003282], 'mse': [7.35356380232588e-08], 'cosine': [0.9996775388717651]}
refline_instance_passable_road_mask {'snr': [3.551760528353043e-05], 'mse': [1.516431666459539e-06], 'cosine': [0.9999885559082031]}
refline_instance_passable_lane_mask {'snr': [2.8107297112001106e-05], 'mse': [2.3454799702449236e-06], 'cosine': [0.9999882578849792]}
refline_instance_nearest_passable_lane_mask {'snr': [2.023341403400991e-05], 'mse': [1.521337793519706e-07], 'cosine': [0.9999995231628418]}
feats_0 {'snr': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mse': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'cosine': [1.0000076293945312, 1.0000081062316895, 1.000004768371582, 1.000006914138794, 1.0000076293945312, 1.0000059604644775, 1.0000076293945312, 1.000040888786316, 1.000005841255188, 1.0000065565109253]}
feats_2 {'snr': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'mse': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0

# Model Export

In [10]:
# Export the quantized model to qmodel_path, reusing the input/output names of
# the original model's ONNX mapping.
model_converted.cpu()
torch.onnx.export(
    model_converted,
    # dataset[0] is the {field name: tensor} dict produced by
    # UnimodelDataset.load_item, not a tuple of tensors.
    # NOTE(review): this relies on the exporter matching the dict keys to the
    # model's forward arguments -- confirm on the installed torch version.
    dataset[0],
    qmodel_path,
    input_names=torch_model.onnx_mapping.inputs,
    output_names=torch_model.onnx_mapping.outputs,
    keep_initializers_as_inputs=False
)

  output_indices = indices.reshape((-1, indices.shape[-1])).T.tolist()
  return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
