In [1]:
import os
import sys
import pickle
import time
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import transformers
from transformers import BertPreTrainedModel,BertModel,BertForSequenceClassification,BertTokenizer

In [2]:
model_path='/Users/zhangsongpo/Downloads/bert-base-chinese'
max_length = 256

In [3]:
tokenizer = BertTokenizer.from_pretrained(model_path)

In [4]:
class MyBert(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.num_labels = num_labels

        self.bert = BertModel.from_pretrained(model_path)
        classifier_dropout = 0.2
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(768, num_labels)

        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits

In [5]:
my_bert = MyBert(num_labels=2)

Some weights of the model checkpoint at /Users/zhangsongpo/Downloads/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
tmp_text = 'guideline hand hygiene health care setting recommendation healthcare infection control practice advisory committee hicpac shea ape idsa hand hygiene task force prepared john boyce md didier pittet md hospital saint raphael new haven connecticut university geneva geneva switzerland material report originate national center infectious disease james hughes md director division healthcare quality promotion steve solomon md acting director summary guideline hand hygiene health care setting health care worker hcw review data regarding handwash hand antisepsi health care setting addition specific recommendation promote improve hand hygiene practice reduce transmission pathogenic microorganism patient personnel health care setting report review study publish cdc guideline garner js favero cdc guideline handwash hospital environmental control infect control ape guideline larson el ape guidelines committee ape guideline handwash hand antisepsi health care setting infect control issue depth review hand hygiene practice hcw level adherence personnel recommend handwash practice factor adverse affecting adherence new study vivo efficacy alcohol base hand rub low incidence dermatitis associate use review recent study demonstrate value multidisciplinary hand hygiene promotion program potential role alcohol base hand rub improve hand hygiene practice summarize recommendation concerning related issue e use surgical hand antiseptic hand lotion cream wearing artificial fingernail part review scientific data regarding hand hygiene guideline hand hygiene health care setting recommendation healthcare infection control practice advisory committee hicpac shea ape idsa hand hygiene task force prepared john boyce md didier pittet md hospital saint raphael new haven connecticut university geneva geneva switzerland material report originate national center infectious disease james hughes md director division healthcare quality promotion steve'
sample_text = []
for _ in range(100):
    start_index = random.randint(0,200)
    text_len = random.randint(45, 500)
    sample_text.append(tmp_text[start_index:(start_index+text_len)])

In [30]:
inputs = tokenizer(text=tmp_text[:256],
                   return_tensors="pt",
                   padding=True,
                   truncation=True,
                   max_length=max_length
                  )

In [33]:
inputs['input_ids']

tensor([[  101, 11724,  8762, 12126,  8168,   150,  8179, 10006, 10600, 10168,
         10614,  9738,  9107,  8847,  9479, 11839,  8521,  8361, 10168, 11014,
          8217,  9568,  9116,  8809,  9470, 12183,  8877,  9145, 11233,  9428,
          8134, 11104, 12729,  8913, 11057,  9202,  9374,  8139,  9392,  8154,
          8231,  8606, 12126,  8168,   150,  8179, 10006, 10600,  8346,  8998,
          9019, 11685,  8797,  9749,  8675, 10447,  8328, 11399,  9796, 11588,
          8180, 10091,  9786,  8165, 11399, 10537, 10367, 10242,  8178, 10484,
         12619, 12465, 10361,  8178,  8343,  9531,  8171, 12280,  8317,  9194,
          8736,   102]])

In [26]:
%%time
s0 = time.time()
res = my_bert(**inputs)
print(time.time() - s0)

0.29541802406311035
CPU times: user 249 ms, sys: 53.6 ms, total: 303 ms
Wall time: 296 ms


In [27]:
%%time
s0 = time.time()
for _ in sample_text:
    res = my_bert(**inputs)
all_time = time.time() - s0
print(all_time, all_time / len(sample_text))

KeyboardInterrupt: 

tensor([[0.6906, 0.7076]], grad_fn=<SigmoidBackward>)

In [9]:
output_names = ['logits']
dynamic_axes = {'input_ids': [0, 1],'attention_mask': [0, 1],'token_type_ids': [0, 1],}

In [10]:
torch.onnx.export(my_bert,
                  f='./mybert.onnx',
                  args=tuple(inputs.values()),
                  input_names=list(inputs),
                  output_names=output_names,
                  dynamic_axes=dynamic_axes,
                  opset_version=10)

  'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))
  'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))
  'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))
  position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]


In [11]:
import onnx

onnx_model = onnx.load('./mybert.onnx')
onnx.checker.check_model(onnx_model)

In [34]:
import onnxruntime

In [35]:
onxx_model_path = r'./mybert.onnx'
options = onnxruntime.SessionOptions()
session = onnxruntime.InferenceSession(onxx_model_path, options)

In [36]:
inputs = tokenizer(text=tmp_text[:256],
                   return_tensors="pt",
                   padding=True,
                   truncation=True,
                   max_length=max_length
                  )
inputs_onnx = {k: v.cpu().detach().numpy() for k, v in inputs.items()}

In [43]:
%%time
s0 = time.time()
res = session.run(None, inputs_onnx)
print(time.time() - s0)

0.06377792358398438
CPU times: user 241 ms, sys: 2.97 ms, total: 244 ms
Wall time: 63.9 ms


In [42]:
%%time
s0 = time.time()
for _ in sample_text:
    res = session.run(None, inputs_onnx)
all_time = time.time() - s0
print(all_time, all_time / len(sample_text))

KeyboardInterrupt: 

In [24]:
from os import environ
from psutil import cpu_count

# Constants from the performance optimization available in onnxruntime
# It needs to be done before importing onnxruntime
environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True)) # OMP 的线程数
environ["OMP_WAIT_POLICY"] = 'ACTIVE'

from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers


def create_model_for_provider(model_path: str, provider: str) -> InferenceSession: 
  
    assert provider in get_all_providers(), f"provider {provider} not found, {get_all_providers()}"

    # Few properties that might have an impact on performances (provided by MS)
    options = SessionOptions()
    options.intra_op_num_threads = 1
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

  # Load the model as a graph and prepare the CPU backend 
    session = InferenceSession(model_path, options, providers=[provider])
    session.disable_fallback()
    
    return session

In [25]:
session_cpu = create_model_for_provider(onxx_model_path, "CPUExecutionProvider") # 使用 优化过的 onnx

In [27]:
%%time
s0 = time.time()
res = session_cpu.run(None, inputs_onnx)
print(time.time() - s0)

0.17477822303771973
CPU times: user 171 ms, sys: 4.02 ms, total: 175 ms
Wall time: 175 ms


In [28]:
%%time
s0 = time.time()
for _ in sample_text:
    res = session_cpu.run(None, inputs_onnx)
all_time = time.time() - s0
print(all_time, all_time / len(sample_text))

17.469280004501343 0.17469280004501342
CPU times: user 16.9 s, sys: 209 ms, total: 17.1 s
Wall time: 17.5 s
