In [21]:
##################################################################################
# This notebook is split into three parts:
# 1. the tokenizer
# 2. the transformer
# 3. the pooling layer


# from multiprocessing.pool import Pool
import numpy as np
import onnxruntime
import psutil
from sympy import im
from transformers import (AutoConfig, AutoModel, AutoTokenizer)
import os
import json
from sentence_transformers.models import Pooling

from sentence_transformers import SentenceTransformer as sbert

from tqdm import tqdm
import torch as t 



In [2]:

##################################################################################
# Transformer and tokenizer setup

big_model_path = "../models/paraphrase-multilingual-MiniLM-L12-v2"

# modules.json lives at the model root; the 'path' of an entry is joined onto
# the root to locate that module's files.
modules_json_path = os.path.join(big_model_path, 'modules.json')
with open(modules_json_path) as fIn:
    modules_config = json.loads(fIn.read())

# NOTE(review): entry 0 is taken to be the transformer module; this relies on
# the ordering inside modules.json -- confirm.
tf_from_s_path = os.path.join(big_model_path, modules_config[0].get('path'))


# Basic parameters (not referenced by the cells visible in this notebook).
max_seq_length, doc_stride, max_query_length = 128, 128, 64
# Overwrite flag: re-export the ONNX model and re-download the latest script
# on every run of the notebook.
enable_overwrite = True
# Number of samples to run inference on; should be large enough for a stable
# latency measurement.
total_samples = 1000


# Load the pretrained config, tokenizer and model.
config_class = AutoConfig
model_class = AutoModel
tokenizer_class = AutoTokenizer

cache_dir = os.path.join(".", "cache_models")
config = config_class.from_pretrained(tf_from_s_path, cache_dir=cache_dir)
tokenizer = tokenizer_class.from_pretrained(
    tf_from_s_path,
    do_lower_case=True,
    cache_dir=cache_dir,
)
model_transformer = model_class.from_pretrained(
    tf_from_s_path,
    from_tf=False,
    config=config,
    cache_dir=cache_dir,
)


In [3]:

##################################################################################
# Inference with ONNX Runtime on CUDA

# Location of the already-exported ONNX model.
output_dir = os.path.join("..", "onnx_models")
export_model_path = os.path.join(output_dir, 'Multilingual_MiniLM_L12.onnx')
# Only used to build the file name of the optimized model below.
device_name = 'gpu'
sess_options = onnxruntime.SessionOptions()
# With this path set, creating the session also serializes the optimized graph
# to disk (this is what triggers the "Serializing optimized model" warning in
# the cell output).
sess_options.optimized_model_filepath = os.path.join(
    output_dir, "optimized_model_{}.onnx".format(device_name))
# Please change the value according to best setting in Performance Test Tool result.
# Currently: one intra-op thread per logical CPU.
sess_options.intra_op_num_threads = psutil.cpu_count(logical=True)
# Request the CUDA execution provider for the session.
session = onnxruntime.InferenceSession(
    export_model_path, sess_options, providers=['CUDAExecutionProvider'])


2022-02-25 14:01:58.432835797 [W:onnxruntime:, inference_session.cc:1407 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.


In [4]:
##################################################################################
# Pooling module
# The 'path' of entry 1 in modules.json is joined onto the model root and the
# pooling layer is loaded from that directory.
# NOTE(review): assumes entry 1 is the pooling module -- confirm against modules.json.
pooling_model_path = os.path.join(big_model_path, modules_config[1].get('path'))
pooling_model = Pooling.load(pooling_model_path)
# pooling_model_path

In [25]:

##################################################################################
# Inference function


def inferpart(session, sentences, pooling_model, max_length=512):
    """Embed sentences with the ONNX transformer followed by the pooling layer.

    The text is tokenized with the notebook-level ``tokenizer`` (a global that
    must already be defined), the resulting tensors are run through
    ``session``, and the first session output is pooled together with the
    attention mask.

    Args:
        session: Object exposing ``get_inputs()`` and ``run(output_names, feeds)``
            (an ``onnxruntime.InferenceSession`` in this notebook).
        sentences: A single string or a sized collection of strings. A bare
            string, or any object without ``__len__``, is wrapped in a list.
        pooling_model: Object whose ``forward(features=...)`` accepts a dict with
            ``'token_embeddings'`` and ``'attention_mask'`` and returns a mapping
            that holds ``'sentence_embedding'``.
        max_length: Truncation length handed to the tokenizer. Defaults to 512,
            the value that used to be hard-coded here.

    Returns:
        The ``'sentence_embedding'`` entry of the pooling output, or ``None``
        when the pooling output has no such key.
    """
    if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
        sentences = [sentences]

    inputs = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )

    # Feed only the tensors the ONNX graph declares as inputs: session.run
    # rejects unknown feed names, and a tokenizer may emit extras
    # (e.g. token_type_ids) that the exported graph does not take.
    graph_input_names = {node.name for node in session.get_inputs()}
    ort_inputs = {
        name: tensor.cpu().numpy()
        for name, tensor in inputs.items()
        if name in graph_input_names
    }
    ort_outputs = session.run(None, ort_inputs)

    # The first output is treated as the per-token embeddings. Convert it
    # explicitly to float32 instead of going through the legacy t.Tensor
    # constructor.
    token_embeddings = t.as_tensor(ort_outputs[0], dtype=t.float32)
    pooled = pooling_model.forward(features={
        'token_embeddings': token_embeddings,
        'attention_mask': inputs.get('attention_mask'),
    })
    return pooled.get('sentence_embedding')


# Throughput benchmark for the ONNX path: 2000 single-sentence calls, with
# tqdm reporting iterations per second. The results are collected into `_`
# (presumably only so the cell does not echo them).
_ = [inferpart(session=session, sentences = ['您好'], pooling_model=pooling_model) for i in tqdm(range(2000))]

100%|██████████| 2000/2000 [00:05<00:00, 340.29it/s]


In [27]:
# Baseline: the stock sentence-transformers implementation, loaded on CUDA.
model_sbert_raw = sbert(big_model_path, device='cuda')

# Same 2000-iteration single-sentence loop as the ONNX benchmark, for comparison.
_ = [model_sbert_raw.encode(['您好'],device='cuda') for i in tqdm(range(2000))]

100%|██████████| 2000/2000 [00:25<00:00, 77.36it/s]
