In [58]:
import pathlib
import os
import pickle
import collections
import timeit

import optimum
import optimum.onnxruntime
import onnxruntime
import datasets
import torch
import torch.nn as nn
import numpy as np

import segmentador
import eval_model

%load_ext autoreload
%autoreload 2

# When True, the evaluation data assembled in the loading cell is cut short
# (each feature stops growing once it holds at least 100000 items) so the
# notebook can be smoke-tested quickly.
DEV_RUN = False

# Directory that receives every quantized / optimized artifact written below.
QUANTIZED_MODELS_DIR = "quantized_models"
# Create it up front; a no-op when the directory already exists.
pathlib.Path(QUANTIZED_MODELS_DIR).mkdir(exist_ok=True, parents=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Curated test split, stored as an Arrow file.
curated_df = datasets.Dataset.from_file(
    "../data/refined_datasets/df_tokenized_split_0_120000_6000/"
    "combined_test_48_parts_1036_instances/dataset.arrow"
)

# Flatten each feature: the per-row lists of a column are concatenated into
# one long list, keyed by the feature name.
concat_curated_df = collections.defaultdict(list)

for feature_name in curated_df.features.keys():
    for row_values in curated_df[feature_name]:
        concat_curated_df[feature_name].extend(row_values)
        # In a dev run, stop growing this feature once it is large enough.
        if DEV_RUN and len(concat_curated_df[feature_name]) >= 100000:
            break

# Display the feature names next to the flattened length of each feature.
concat_curated_df.keys(), tuple(map(len, concat_curated_df.values()))

(dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask']),
 (509827, 509827, 509827, 509827))

## Creating LSTM Quantization

In [32]:
# Float (non-quantized) LSTM segmenter restored from its training checkpoint
# and placed on the CPU. It is both the source model for the quantization
# below and the float baseline in the validation section.
segmenter_lstm = segmentador.LSTMSegmenter(
    uri_model="../pretrained_segmenter_model/512_6000_1_lstm/checkpoints/epoch=3-step=3591.ckpt",
    uri_tokenizer="../tokenizers/6000_subwords",
    device="cpu",
    # NOTE(review): presumably controls how overlapping moving-window
    # predictions are merged at inference time -- confirm in `segmentador`.
    inference_pooling_operation="max",
)

In [47]:
# Float model wrapped by the segmenter. It is only read here: the quantization
# recipe is passed to `quantize_dynamic` explicitly instead of being written
# onto the module as a `qconfig` attribute, so the float baseline evaluated
# later stays untouched and this cell can be re-run safely.
model_lstm = segmenter_lstm.model

# Per-module-type quantization recipe:
#   - embeddings: weight-only quantization with float qparams (quint8 weights);
#   - LSTM / Linear: default dynamic quantization (qint8 weights).
lstm_qconfig_spec = {
    nn.Embedding: torch.quantization.float_qparams_weight_only_qconfig,
    nn.LSTM: torch.quantization.default_dynamic_qconfig,
    nn.Linear: torch.quantization.default_dynamic_qconfig,
}

# `quantize_dynamic` is not in-place by default: it converts a copy of the
# model and returns it. With a dict `qconfig_spec` the `dtype` argument is
# not needed, because every type already carries its own configuration.
quantized_model_lstm = torch.quantization.quantize_dynamic(
    model_lstm,
    lstm_qconfig_spec,
)

print(quantized_model_lstm)

_LSTMSegmenterTorchModule(
  (embeddings): QuantizedEmbedding(num_embeddings=6000, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
  (lstm): DynamicQuantizedLSTM(768, 512, batch_first=True, bidirectional=True)
  (lin_out): DynamicQuantizedLinear(in_features=1024, out_features=4, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)


In [48]:
# Persist only the quantized parameters (the state dict), not the module object.
lstm_quantized_state = quantized_model_lstm.state_dict()
lstm_quantized_checkpoint_uri = os.path.join(QUANTIZED_MODELS_DIR, "q_512_6000_1_lstm.pt")

torch.save(
    lstm_quantized_state,
    lstm_quantized_checkpoint_uri,
    pickle_protocol=pickle.HIGHEST_PROTOCOL,
)

In [55]:
# Reload the quantized LSTM from the state dict saved above. The path is built
# from QUANTIZED_MODELS_DIR (exactly like the save call) rather than being
# hard-coded, so changing the output directory cannot desynchronise the two.
segmenter_lstm_quantized = segmentador.LSTMSegmenter(
    uri_model=os.path.join(QUANTIZED_MODELS_DIR, "q_512_6000_1_lstm.pt"),
    uri_tokenizer="../tokenizers/6000_subwords",
    device="cpu",
    # NOTE(review): presumably makes the segmenter build a quantized module
    # before loading the state dict -- confirm in `segmentador`.
    quantize_weights=True,
    # The file holds only a state dict, so the architecture is passed
    # explicitly; 512 matches the hidden size printed for the quantized
    # model above.
    lstm_hidden_layer_size=512,
    lstm_num_layers=1,
    inference_pooling_operation="max",
)

## Creating BERT Quantization

In [7]:
def _artifact_path(file_name: str) -> str:
    """Return the location of ``file_name`` inside ``QUANTIZED_MODELS_DIR``."""
    return os.path.join(QUANTIZED_MODELS_DIR, file_name)


# ONNX export of the BERT segmenter, used as the input of both pipelines below.
onnx_model_path = _artifact_path("4_6000_layer_model.onnx")

# Outputs of the dynamic-quantization and graph-optimization steps.
onnx_quantized_model_output_path_dynamic = _artifact_path("q_dynamic_4_6000_layer_model.onnx")
onnx_optimized_model_output_path = _artifact_path("4_6000_layer_model_optimized.onnx")

# Pickled ONNX configs, one per pipeline, handed to the ONNX segmenters later on.
onnx_config_quant_path = _artifact_path("q_4_6000_layer_model_config.pickle")
onnx_config_opt_path = _artifact_path("4_6000_layer_model_optimization_config.pickle")

In [8]:
# Float BERT segmenter on the CPU. Its `.model` and `.tokenizer` feed the ONNX
# quantizer and optimizer below, and it is the float baseline in validation.
segmenter_bert = segmentador.BERTSegmenter(
    uri_model="../pretrained_segmenter_model/4_6000_layer_model/",
    device="cpu",
)

In [9]:
# Quantization recipe: dynamic (is_static=False), per-channel, built with the
# AVX2 preset of optimum's AutoQuantizationConfig.
qconfig = optimum.onnxruntime.configuration.AutoQuantizationConfig.avx2(
    is_static=False,
    per_channel=True,
)

quantizer = optimum.onnxruntime.ORTQuantizer(
    model=segmenter_bert.model,
    tokenizer=segmenter_bert.tokenizer,
)

# Going by the keyword names, this writes the plain ONNX export to
# `onnx_model_path` and the quantized model to the `..._dynamic` path.
quantizer.export(
    onnx_model_path=onnx_model_path,
    onnx_quantized_model_output_path=onnx_quantized_model_output_path_dynamic,
    quantization_config=qconfig,
)

# Keep the ONNX config next to the model so the ONNX segmenter can be given it
# through `uri_onnx_config`.
# NOTE(review): `_onnx_config` is a private attribute of the quantizer and may
# change between optimum versions.
with open(onnx_config_quant_path, "wb") as f_out:
    pickle.dump(quantizer._onnx_config, f_out)

In [10]:
# Segmenter backed by the dynamically quantized ONNX model exported above,
# loaded together with the ONNX config pickled by the same cell.
segmenter_bert_quantized = segmentador.QONNXBERTSegmenter(
    uri_model=onnx_quantized_model_output_path_dynamic,
    uri_tokenizer="../tokenizers/6000_subwords/",
    uri_onnx_config=onnx_config_quant_path,
)

In [63]:
# Graph-optimize the BERT ONNX export, then dynamically quantize the optimized
# graph. Classes are referenced through the `optimum.onnxruntime` package that
# is already imported at the top of the notebook (same style as the
# quantization cell), so no imports are scattered mid-notebook.

# optimization_level=99 enables all available graph optimisations
optimization_config = optimum.onnxruntime.configuration.OptimizationConfig(
    optimization_level=99,
)

optimizer = optimum.onnxruntime.ORTOptimizer(
    model=segmenter_bert.model,
    tokenizer=segmenter_bert.tokenizer,
)

optimizer.export(
    onnx_model_path=onnx_model_path,
    onnx_optimized_model_output_path=onnx_optimized_model_output_path,
    optimization_config=optimization_config,
)

# Destination of the optimized *and* quantized model.
onnx_optimized_quantized_model_output_path = os.path.join(
    QUANTIZED_MODELS_DIR,
    "q_4_6000_layer_model_optimized.onnx",
)

# `quantize_dynamic` writes the quantized model to its second argument and
# returns nothing, so its result is deliberately not bound to a name (binding
# it to a segmenter-like variable would only store None).
onnxruntime.quantization.quantize_dynamic(
    onnx_optimized_model_output_path,
    onnx_optimized_quantized_model_output_path,
    weight_type=onnxruntime.quantization.QuantType.QUInt8,
)

# Keep the ONNX config next to the models so the ONNX segmenters can be given
# it through `uri_onnx_config`.
# NOTE(review): `_onnx_config` is a private attribute of the optimizer and may
# change between optimum versions.
with open(onnx_config_opt_path, "wb") as f_out:
    pickle.dump(optimizer._onnx_config, f_out)

2022-04-06 22:25:49.659594166 [W:onnxruntime:, inference_session.cc:1546 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.


failed in shape inference <class 'AssertionError'>
failed in shape inference <class 'AssertionError'>


In [65]:
# Segmenter backed by the graph-optimized (not quantized) ONNX model.
segmenter_bert_optimized = segmentador.QONNXBERTSegmenter(
    uri_model=onnx_optimized_model_output_path,
    uri_tokenizer="../tokenizers/6000_subwords/",
    uri_onnx_config=onnx_config_opt_path,
)

# Segmenter backed by the optimized model after dynamic quantization. It is
# given the same pickled ONNX config as the optimized-only variant above.
segmenter_bert_optimized_quantized = segmentador.QONNXBERTSegmenter(
    uri_model=os.path.join(QUANTIZED_MODELS_DIR, "q_4_6000_layer_model_optimized.onnx"),
    uri_tokenizer="../tokenizers/6000_subwords/",
    uri_onnx_config=onnx_config_opt_path,
)

## Validating performance

In [30]:
def validate(
    model,
    moving_window_size: int = 1024,
    window_shift_size: float = 0.5,
    batch_size: int = 64,
    data=None,
) -> dict[str, float]:
    """Run ``model`` on the evaluation data and return its metrics and run time.

    Parameters
    ----------
    model : callable
        Segmenter to evaluate. It is called as
        ``model(data, batch_size=..., return_logits=True,
        show_progress_bar=True, window_shift_size=..., moving_window_size=...)``
        and must return an object exposing a ``logits`` attribute.

    moving_window_size : int, default=1024
        Forwarded unchanged to ``model``.

    window_shift_size : float, default=0.5
        Forwarded unchanged to ``model``.

    batch_size : int, default=64
        Forwarded unchanged to ``model``.

    data : mapping or None, default=None
        Input handed to ``model``; it must provide a ``"labels"`` entry, which
        is used as the ground truth. When None, the notebook-level
        ``concat_curated_df`` is used, which is the original behaviour.

    Returns
    -------
    metrics : dict[str, float]
        Whatever ``eval_model.compute_metrics`` returns, plus an
        ``"approx_inference_time"`` entry holding the elapsed seconds of the
        ``model`` call. Metric computation is not included in that time.
    """
    eval_data = concat_curated_df if data is None else data

    # `timeit.default_timer` is the documented alias of `time.perf_counter`;
    # it avoids relying on `timeit`'s private import of the `time` module.
    t_start = timeit.default_timer()

    logits = model(
        eval_data,
        batch_size=batch_size,
        return_logits=True,
        show_progress_bar=True,
        window_shift_size=window_shift_size,
        moving_window_size=moving_window_size,
    ).logits

    t_delta = timeit.default_timer() - t_start

    # `compute_metrics` receives a (predictions, references) pair; both sides
    # are wrapped in a one-element list here.
    metrics = eval_model.compute_metrics(
        ([logits], [eval_data["labels"]]),
    )
    metrics["approx_inference_time"] = t_delta

    return metrics

In [15]:
# Evaluate every model variant with `validate`'s defaults
# (moving_window_size=1024, window_shift_size=0.5, batch_size=64).
# Each call returns the metrics dict including "approx_inference_time".
metrics_bert = validate(segmenter_bert)
metrics_bert_quantized = validate(segmenter_bert_quantized)
metrics_bert_optimized = validate(segmenter_bert_optimized)
metrics_lstm = validate(segmenter_lstm)
metrics_lstm_quantized = validate(segmenter_lstm_quantized)

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [66]:
# Same evaluation for the graph-optimized + dynamically quantized BERT variant.
metrics_bert_optimized_quantized = validate(segmenter_bert_optimized_quantized)

  0%|          | 0/16 [00:00<?, ?it/s]

In [69]:
# Metrics of the dynamically quantized (not graph-optimized) BERT segmenter.
metrics_bert_quantized

{'per_cls_precision_0': 0.997789569280125,
 'per_cls_precision_1': 0.9220820863607074,
 'per_cls_precision_2': 0.8533333333219556,
 'per_cls_precision_3': 0.5202702702351168,
 'per_cls_recall_0': 0.9975274803711349,
 'per_cls_recall_1': 0.9386356218029173,
 'per_cls_recall_2': 0.7683073229199483,
 'per_cls_recall_3': 0.5579710144523209,
 'macro_precision': 0.8233688147994762,
 'macro_recall': 0.8156103598865804,
 'macro_f1': 0.8194712191987011,
 'overall_accuracy': 0.9953275695592821,
 'approx_inference_time': 241.41552834300091}

In [68]:
# Metrics of the graph-optimized + dynamically quantized BERT segmenter.
metrics_bert_optimized_quantized

{'per_cls_precision_0': 0.9978579078900107,
 'per_cls_precision_1': 0.9207109408249033,
 'per_cls_precision_2': 0.861185983815887,
 'per_cls_precision_3': 0.5234899328507725,
 'per_cls_recall_0': 0.9975017844396574,
 'per_cls_recall_1': 0.9412468719389171,
 'per_cls_recall_2': 0.7671068427278859,
 'per_cls_recall_3': 0.56521739126339,
 'macro_precision': 0.8258111913453934,
 'macro_recall': 0.8177682225924625,
 'macro_f1': 0.821770022559032,
 'overall_accuracy': 0.9953691885952743,
 'approx_inference_time': 207.6828866969954}

In [56]:
# LSTM float vs. quantized comparison with a larger moving window (2048
# instead of the default 1024); the other settings keep `validate`'s defaults.
metrics_lstm_max = validate(segmenter_lstm, moving_window_size=2048)
metrics_lstm_max_quantized = validate(segmenter_lstm_quantized, moving_window_size=2048)

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

In [35]:
# Metrics of the float LSTM segmenter with moving_window_size=2048.
metrics_lstm_max

{'per_cls_precision_0': 0.9983203198521369,
 'per_cls_precision_1': 0.9875860512346375,
 'per_cls_precision_2': 0.9202702702578341,
 'per_cls_precision_3': 0.6442953019701815,
 'per_cls_recall_0': 0.9994946466809136,
 'per_cls_recall_1': 0.9521270808389161,
 'per_cls_recall_2': 0.8175270107945075,
 'per_cls_recall_3': 0.6956521738626339,
 'macro_precision': 0.8876179858286974,
 'macro_recall': 0.8662002280442428,
 'macro_f1': 0.876778324315885,
 'overall_accuracy': 0.9977497974540248,
 'approx_inference_time': 45.545376046000456}

In [36]:
# Metrics of the quantized LSTM segmenter with moving_window_size=2048.
metrics_lstm_max_quantized

{'per_cls_precision_0': 0.9983174729013041,
 'per_cls_precision_1': 0.9875832486725503,
 'per_cls_precision_2': 0.920377867733868,
 'per_cls_precision_3': 0.6442953019701815,
 'per_cls_recall_0': 0.9994946466809136,
 'per_cls_recall_1': 0.9519094766609161,
 'per_cls_recall_2': 0.8187274909865699,
 'per_cls_recall_3': 0.6956521738626339,
 'macro_precision': 0.887643472819476,
 'macro_recall': 0.8664459470477585,
 'macro_f1': 0.8769166227813644,
 'overall_accuracy': 0.9977470228516253,
 'approx_inference_time': 34.70538644899716}

In [57]:
# Same variable displayed a second time. NOTE(review): going by the execution
# counters, this display follows a later re-run of the 2048-window validation,
# which would explain why the numbers differ slightly from the previous
# display -- confirm with a Restart & Run All.
metrics_lstm_max_quantized

{'per_cls_precision_0': 0.9983203198521369,
 'per_cls_precision_1': 0.9875846501117521,
 'per_cls_precision_2': 0.920377867733868,
 'per_cls_precision_3': 0.6442953019701815,
 'per_cls_recall_0': 0.9994946466809136,
 'per_cls_recall_1': 0.9520182787499161,
 'per_cls_recall_2': 0.8187274909865699,
 'per_cls_recall_3': 0.6956521738626339,
 'macro_precision': 0.8876445349169846,
 'macro_recall': 0.8664731475700084,
 'macro_f1': 0.8769310718283391,
 'overall_accuracy': 0.9977497974540248,
 'approx_inference_time': 33.45247530399865}