In [1]:
import pathlib
import os
import pickle

import optimum
import optimum.onnxruntime
import datasets
import torch
import torch.nn as nn

import segmentador

%load_ext autoreload
%autoreload 2


# Output directory (relative to the notebook's working directory) that
# receives every quantized artifact written below.
QUANTIZED_MODELS_DIR = "quantized_models"

# Create the directory, including missing parents; a pre-existing directory
# is not an error.
os.makedirs(QUANTIZED_MODELS_DIR, exist_ok=True)

## LSTM Quantization

In [None]:
# Locations handed to the LSTM segmenter: the checkpoint to read the model
# from and the tokenizer directory.
lstm_checkpoint_uri = "../pretrained_segmenter_model/512_6000_1_lstm/checkpoints/epoch=3-step=3591.ckpt"
lstm_tokenizer_uri = "../tokenizers/6000_subwords"

# Instantiate the LSTM segmenter, requesting the CPU device.
segmenter_lstm = segmentador.LSTMSegmenter(
    device="cpu",
    uri_tokenizer=lstm_tokenizer_uri,
    uri_model=lstm_checkpoint_uri,
)

In [None]:
# Grab the underlying model object exposed by the segmenter.
model_lstm = segmenter_lstm.model

# Dynamic quantization targeting every nn.LSTM and nn.Linear submodule,
# with qint8 as the quantized dtype. The result is bound to a separate name.
quantized_model_lstm = torch.quantization.quantize_dynamic(
    model=model_lstm,
    qconfig_spec={nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)

# Print the resulting module tree for inspection.
print(quantized_model_lstm)

In [None]:
# Destination file for the quantized LSTM weights, inside the shared
# quantized-models directory.
quantized_lstm_weights_path = os.path.join(QUANTIZED_MODELS_DIR, "q_512_6000_1_lstm.pt")

# Only the state dict of the quantized model is written, pickled with the
# highest protocol this interpreter supports.
torch.save(
    quantized_model_lstm.state_dict(),
    quantized_lstm_weights_path,
    pickle_protocol=pickle.HIGHEST_PROTOCOL,
)

## BERT Quantization

In [3]:
# File locations used by the BERT quantization cells, all placed inside
# QUANTIZED_MODELS_DIR: the ONNX model path, the quantized ONNX model path,
# and the pickled ONNX config path (in that order).
(
    onnx_model_path,
    onnx_quantized_model_output_path,
    onnx_config_path,
) = (
    os.path.join(QUANTIZED_MODELS_DIR, artifact_filename)
    for artifact_filename in (
        "4_6000_layer_model.onnx",
        "q_4_6000_layer_model_per_channel.onnx",
        "q_4_6000_layer_model_per_channel_config.pickle",
    )
)

In [2]:
# Locations handed to the BERT segmenter: the pretrained model directory and
# the tokenizer directory.
bert_model_uri = "../pretrained_segmenter_model/4_6000_layer_model/"
bert_tokenizer_uri = "../tokenizers/6000_subwords"

# Instantiate the BERT segmenter, requesting the CPU device.
segmenter_bert = segmentador.BERTSegmenter(
    device="cpu",
    uri_tokenizer=bert_tokenizer_uri,
    uri_model=bert_model_uri,
)

In [7]:
# Quantization recipe: the arm64 preset, non-static (is_static=False) and
# per-channel.
qconfig = optimum.onnxruntime.configuration.AutoQuantizationConfig.arm64(
    per_channel=True,
    is_static=False,
)

# The quantizer is built from the segmenter's model and tokenizer objects.
quantizer = optimum.onnxruntime.ORTQuantizer(
    tokenizer=segmenter_bert.tokenizer,
    model=segmenter_bert.model,
)

# Run the export with the recipe above, passing both the ONNX model path and
# the quantized-model output path.
quantizer.export(
    quantization_config=qconfig,
    onnx_quantized_model_output_path=onnx_quantized_model_output_path,
    onnx_model_path=onnx_model_path,
)

# Pickle the quantizer's ONNX config next to the model files.
# NOTE(review): `_onnx_config` is a private attribute of the quantizer;
# confirm it is still available if the optimum version changes.
with open(onnx_config_path, "wb") as config_file:
    pickle.dump(quantizer._onnx_config, config_file)

In [7]:
# Build the quantized-ONNX segmenter from the artifacts written above: the
# quantized model file, the pickled ONNX config, and the tokenizer directory.
segmenter_bert_quantized = segmentador.QONNXBERTSegmenter(
    uri_onnx_config=onnx_config_path,
    uri_tokenizer="../tokenizers/6000_subwords/",
    uri_model=onnx_quantized_model_output_path,
)

In [8]:
# Arrow file holding the test split.
test_split_arrow_uri = (
    "../data/refined_datasets/df_tokenized_split_0_120000_6000/"
    "combined_test_48_parts_1036_instances/dataset.arrow"
)

# Keep only the first 10 rows.
# NOTE(review): slicing a `datasets.Dataset` appears to return a plain dict
# of columns rather than a Dataset — confirm this is what later cells expect.
df = datasets.Dataset.from_file(test_split_arrow_uri)[:10]

In [10]:
# Smoke test: run the quantized segmenter on a short sample and display the
# returned segments as the cell output.
sample_text = "Artigo 7: hello Artigo 8 Hello again!"
segmenter_bert_quantized(sample_text)

['Artigo 7 : hello', 'Artigo 8 Hello again!']