In [1]:
dataset_path = '/mnt/edisk/dataset/unimodel_calibrate'
model_path = 'base.onnx'
qmodel_path = 'base.quantized.onnx'

In [2]:
import os
import pickle

from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader

from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantizeBase, FakeQuantize

from onnxutils.quantization import symbolic_trace, ModuleQuantizer, compute_metric, mse_kernel, cosine_kernel, snr_kernel

from onnxutils.common import DatasetUtils
from onnxutils.onnx import OnnxModel
from onnxutils.onnx2torch import convert, normalize_module_name
from onnxutils.onnx2torch.scatter_nd import TorchScatterNd
from onnxutils.onnx2torch.binary_math_operations import TorchBinaryOp

from unimodel_pipeline import UnimodelDataset

# Prepare

In [3]:
onnx_model = OnnxModel.from_file(model_path)
with onnx_model.session() as sess:
    for node in onnx_model.proto().graph.node:
        if node.name == '':
            node.name = sess.unique_name()
onnx_model.topological_sort()
torch_model = convert(onnx_model)

dataset = UnimodelDataset(dataset_path, torch_model.onnx_mapping.inputs)

In [4]:
def shared_fq(cls):
    fq = cls()
    def wrapper():
        return fq
    return wrapper

qconfig_mappings = {}
for node in onnx_model.nodes():
    if node.op_type() == 'Conv':
        qconfig_mappings[normalize_module_name(node.name())] = {
            'activation': FakeQuantize.with_args(observer=HistogramObserver),
            'weight': FakeQuantize.with_args(observer=PerChannelMinMaxObserver),
        }
    elif node.op_type() == 'Relu':
        qconfig_mappings[normalize_module_name(node.name())] = {
            'activation': FakeQuantize.with_args(observer=HistogramObserver)
        }
        maybe_conv_node = onnx_model.get_node_by_output(node.inputs()[0])
        if maybe_conv_node.op_type() == 'Conv':
            qconfig_mappings[normalize_module_name(maybe_conv_node.name())].pop('activation')
    elif node.op_type() == 'Add':
        fq_cls = shared_fq(FakeQuantize.with_args(observer=HistogramObserver))
        qconfig_mappings[normalize_module_name(node.name())] = {
            'activation': fq_cls
        }
        for input_name in node.inputs():
            maybe_node = onnx_model.get_node_by_output(input_name)
            if normalize_module_name(maybe_node.name()) in qconfig_mappings:
                qconfig_mappings[normalize_module_name(maybe_node.name())]['activation'] = fq_cls

for item in qconfig_mappings.items():
    print(item)

qconfigs = [
    {'module_name': name} | qconfig
    for name, qconfig in qconfig_mappings.items()
] + [{'name': 'imgs', 'activation': FakeQuantize.with_args(observer=HistogramObserver)}]
for item in qconfigs:
    print(item)

('backbone/stage1/conv/Conv', {'weight': functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>){}})
('backbone/stage1/act/Relu', {'activation': functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.ao.quantization.observer.HistogramObserver'>){}})
('backbone/stage2/stage2/0/conv/Conv', {'weight': functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>){}})
('backbone/stage2/stage2/0/act/Relu', {'activation': <function shared_fq.<locals>.wrapper at 0x71f8f6168430>})
('backbone/stage2/stage2/1/conv/conv/0/conv/Conv', {'weight': functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>){}})
('backbone/stage2/stage2/1/conv/conv/0/act/Relu', {'activation': fu

In [5]:
graph_model = pickle.loads(pickle.dumps(torch_model))
graph_model = symbolic_trace(graph_model, skipped_module_classes=[TorchScatterNd, TorchBinaryOp])

quantizer = ModuleQuantizer()
graph_model = quantizer.quantize(graph_model, qconfigs)

graph_model.print_readable()

class GraphModule(torch.nn.Module):
    def forward(self, imgs):
        # No stacktrace found for following nodes
        fq0 = self.fq0(imgs);  imgs = None
        
         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())
        backbone_stage1_conv_conv = getattr(self, "backbone/stage1/conv/Conv")(fq0);  fq0 = None
        
         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())
        backbone_stage1_act_relu = getattr(self, "backbone/stage1/act/Relu")(backbone_stage1_conv_conv);  backbone_stage1_conv_conv = None
        
        # No stacktrace found for following nodes
        fq1 = self.fq1(backbone_stage1_act_relu);  backbone_stage1_act_relu = None
        
         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:2

'class GraphModule(torch.nn.Module):\n    def forward(self, imgs):\n        # No stacktrace found for following nodes\n        fq0 = self.fq0(imgs);  imgs = None\n        \n         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = \'\'.join(CapturedTraceback.extract().format())\n        backbone_stage1_conv_conv = getattr(self, "backbone/stage1/conv/Conv")(fq0);  fq0 = None\n        \n         # File: /opt/miniconda3/lib/python3.10/site-packages/torch/fx/proxy.py:219 in create_proxy, code: proxy.node.stack_trace = \'\'.join(CapturedTraceback.extract().format())\n        backbone_stage1_act_relu = getattr(self, "backbone/stage1/act/Relu")(backbone_stage1_conv_conv);  backbone_stage1_conv_conv = None\n        \n        # No stacktrace found for following nodes\n        fq1 = self.fq1(backbone_stage1_act_relu);  backbone_stage1_act_relu = None\n        \n         # File: /opt/miniconda3/lib/python3.10/site-packages/

In [6]:
# calibration
for m in graph_model.modules():
    if isinstance(m, FakeQuantizeBase):
        m.disable_fake_quant()

dataloader = DataLoader(
    DatasetUtils.take_front(
        DatasetUtils.transform(
            dataset,
            lambda items: tuple(x.to('cuda') for x in items)
        ),
        256
    ),
    batch_size=None
)
graph_model.eval().to('cuda')
for data in tqdm(dataloader):
    graph_model(*data)

for m in graph_model.modules():
    if isinstance(m, FakeQuantizeBase):
        m.enable_fake_quant()

100%|████████████████████████████████████████████████| 256/256 [01:49<00:00,  2.35it/s]


# Analysis

In [7]:
dataloader = DataLoader(
    DatasetUtils.take_front(
        DatasetUtils.transform(
            dataset,
            lambda items: tuple(x.to('cuda') for x in items)
        ),
        256
    ),
    batch_size=None
)
torch_model.eval().to('cuda')
graph_model.eval().to('cuda')

for data in dataloader:
    gt = torch_model(*data)
    pred = graph_model(*data)

    for name, metric, val0, val1 in zip(torch_model.onnx_mapping.outputs, compute_metric(gt, pred, cosine_kernel), gt, pred):
        print(name, metric)
    break

feats_0 0.9939883947372437


# Export

In [9]:
finalized_model = quantizer.finalize(graph_model).to('cuda')

torch.onnx.export(
    finalized_model,
    tuple(next(iter(dataloader))),
    qmodel_path,
    input_names=torch_model.onnx_mapping.inputs,
    output_names=torch_model.onnx_mapping.outputs,
)
# torch.onnx.export(
#     torch_model,
#     tuple(next(iter(dataloader))),
#     model_path,
#     input_names=torch_model.onnx_mapping.inputs,
#     output_names=torch_model.onnx_mapping.outputs,
# )