References

- [Inference PyTorch Bert Model with ONNX Runtime on GPU](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb)
- [transformers to onnx](https://huggingface.co/docs/transformers/v4.25.1/en/serialization#export-to-onnx)

First, install the dependencies. The ONNX Runtime Python package comes in two separate builds, one for CPU and one for GPU:

- onnxruntime
- onnxruntime-gpu

In [1]:
!pip install onnxruntime-gpu

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


Check that ONNX Runtime is installed correctly (version, device, and available execution providers):

In [2]:
import onnxruntime
print(onnxruntime.__version__)
print(onnxruntime.get_device())
print(onnxruntime.get_available_providers())

1.12.0
GPU
['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
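`get_available_providers()` lists what this build supports. ONNX Runtime uses the providers passed to `InferenceSession` in priority order and assigns nodes to a later provider when an earlier one cannot run them. A common pattern is therefore to list CUDA first with CPU as a fallback. The following is a minimal sketch. The helper name `choose_providers` is hypothetical. It leaves out `TensorrtExecutionProvider` only to mirror the sessions created later in this notebook.

```python
def choose_providers(available, preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Return the preferred execution providers that are actually available, in priority order."""
    chosen = [p for p in preferred if p in available]
    # Keep the CPU provider as a last resort so the returned list is never empty.
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen

# Usage in this notebook (requires onnxruntime):
# providers = choose_providers(onnxruntime.get_available_providers())
# session = onnxruntime.InferenceSession("onnx/model_torch.onnx", providers=providers)
```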


As before, this notebook uses the `BertForMaskedLM` model.

In [132]:
import torch
import numpy
from transformers import BertTokenizer
enc = BertTokenizer.from_pretrained('bert-base-uncased')

masked_sentences = ['Paris is the [MASK] of France.', 
                    'The primary [MASK] of the United States is English.', 
                    'A baseball game consists of at least nine [MASK].', 
                    'Topology is a branch of [MASK] concerned with the properties of geometric objects that remain unchanged under continuous transformations.']
pos_masks = [4, 3, 9, 6]

inputs = enc(masked_sentences, return_tensors="np", padding='max_length', max_length=128, truncation=True)
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
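The `pos_masks = [4, 3, 9, 6]` list above is counted by hand (index 0 is `[CLS]`). It silently goes wrong if a sentence changes. The following sketch derives the positions from the token ids instead. The helper `mask_positions` is hypothetical. `enc.mask_token_id` is the tokenizer's standard attribute for the `[MASK]` id. The sketch assumes each sentence contains a mask token and returns the first one per row.

```python
import numpy as np

def mask_positions(input_ids, mask_token_id):
    """Return the column index of the first mask token in each row of a 2-D id array."""
    input_ids = np.asarray(input_ids)
    is_mask = input_ids == mask_token_id
    if not is_mask.any(axis=1).all():
        raise ValueError("every row must contain at least one mask token")
    # argmax on a boolean array returns the index of the first True in each row
    return is_mask.argmax(axis=1).tolist()

# Usage with the tokenizer output above:
# pos_masks = mask_positions(inputs["input_ids"], enc.mask_token_id)
```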

In [187]:
from transformers import BertForMaskedLM
origin_model = BertForMaskedLM.from_pretrained("bert-base-uncased", torchscript=True).eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Converting to an ONNX model

The model can be converted directly with the `transformers.onnx` command-line tool. Here I pass `--feature=masked-lm` so that the exported model matches the `BertForMaskedLM` class.

In [165]:
# Converting locally still reports an error: validation says the absolute difference exceeds 1e-5
!python -m transformers.onnx --model=bert-base-uncased --feature=masked-lm onnx/

Framework not requested. Using torch to export to ONNX.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using framework PyTorch: 1.12.0a0+8a1a93a
Overriding 1 configuration item(s)
	- use_cache -> False
Validating ONNX model...
	-[✓] ONNX model output names match reference model ({'logits'})
	- Validating ONNX Model output "logits":
		-[✓] (3, 9, 30522) matches (3, 9, 30522)
		-[x] values not close enough (atol:

## Converting the model with torch.onnx.export

In [184]:
for key, val in inputs.items():
    print(key, val.dtype, val.shape)

input_ids int64 (4, 128)
token_type_ids int64 (4, 128)
attention_mask int64 (4, 128)


In [195]:
input_ids = torch.randint(0, 100, (4, 128), device="cpu", dtype=torch.int64)
attention_mask = torch.randint(0, 2, (4, 128), device="cpu", dtype=torch.int64)
token_type_ids = torch.randint(0, 2, (4, 128), device="cpu", dtype=torch.int64)

traced_origin_model = torch.jit.trace(origin_model, [input_ids, attention_mask, token_type_ids])
traced_origin_model(input_ids, attention_mask, token_type_ids)

(tensor([[[-7.1718, -7.2419, -7.6383,  ..., -7.2194, -7.9306, -2.7548],
          [-7.2182, -7.3509, -7.7468,  ..., -7.3301, -7.9822, -2.9581],
          [-7.1483, -7.2817, -7.6816,  ..., -7.2920, -7.9250, -2.8913],
          ...,
          [-7.1612, -7.3135, -7.6708,  ..., -7.1626, -7.9145, -2.7564],
          [-7.1650, -7.3228, -7.6832,  ..., -7.1629, -7.9353, -2.7415],
          [-7.2170, -7.3762, -7.7244,  ..., -7.2227, -7.9912, -2.7825]],
 
         [[-7.0209, -7.0549, -7.4369,  ..., -7.1138, -7.6105, -2.9733],
          [-7.0949, -7.2220, -7.5740,  ..., -7.2443, -7.7694, -3.1173],
          [-7.0031, -7.1278, -7.4964,  ..., -7.1937, -7.7031, -3.0237],
          ...,
          [-6.9061, -7.0161, -7.3499,  ..., -6.9274, -7.5017, -3.0070],
          [-6.9416, -7.0478, -7.3914,  ..., -6.9721, -7.5609, -2.8657],
          [-6.9180, -7.0231, -7.3659,  ..., -6.9642, -7.5726, -2.8618]],
 
         [[-7.0116, -7.0444, -7.4240,  ..., -7.1217, -7.5758, -3.0891],
          [-7.3173, -7.4623,

In [196]:
# Export the ONNX model
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
torch.onnx.export(
    traced_origin_model,
    args=(input_ids, attention_mask, token_type_ids),
    f="./onnx/model_torch.onnx",
    export_params=True,
    verbose=True,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["logits"],
    opset_version=14,
    do_constant_folding=True,
    # Mark the batch and sequence dimensions as dynamic
    dynamic_axes={
        "input_ids": symbolic_names,
        "attention_mask" : symbolic_names,
        "token_type_ids" : symbolic_names,
    }
)
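The `dynamic_axes` above covers only the three inputs. The exporter therefore makes up names for the output's dynamic dimensions, which show up as `Addlogits_dim_0` and `Addlogits_dim_1` when the session is inspected below. Listing `logits` as well should give the output readable axis names. The following is a small sketch. The helper `build_dynamic_axes` is hypothetical and would be passed as `dynamic_axes=build_dynamic_axes(input_names, ["logits"])`.

```python
def build_dynamic_axes(input_names, output_names, axes=None):
    """Map every input and output name to the same symbolic batch/sequence axes."""
    if axes is None:
        axes = {0: "batch_size", 1: "max_seq_len"}
    # Copy the dict per name so that editing one entry does not change the others.
    return {name: dict(axes) for name in list(input_names) + list(output_names)}

# build_dynamic_axes(["input_ids", "attention_mask", "token_type_ids"], ["logits"])
```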



In [197]:
import onnx
onnx.checker.check_model(onnx.load("./onnx/model_torch.onnx"))

# Loading the ONNX model

In [198]:
from onnxruntime import InferenceSession

# Load the ONNX model on the CUDA execution provider
session = InferenceSession("onnx/model_torch.onnx", providers=["CUDAExecutionProvider"])

In [199]:
print("Inputs:")
print([x.name for x in session.get_inputs()])
print([x.shape for x in session.get_inputs()])
print([x.type for x in session.get_inputs()])

print("Outputs:")
print([x.name for x in session.get_outputs()])
print([x.shape for x in session.get_outputs()])
print([x.type for x in session.get_outputs()])


Inputs:
['input_ids', 'attention_mask', 'token_type_ids']
[['batch_size', 'max_seq_len'], ['batch_size', 'max_seq_len'], ['batch_size', 'max_seq_len']]
['tensor(int64)', 'tensor(int64)', 'tensor(int64)']
Outputs:
['logits']
[['Addlogits_dim_0', 'Addlogits_dim_1', 30522]]
['tensor(float)']
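The session matches inputs by name and rejects names it does not know. All three inputs are declared as `tensor(int64)`. The tokenizer already returns `int64` here (see the dtype check above). On platforms where NumPy's default integer is 32-bit, however, it may not. The following sketch builds the feed defensively. The helper `make_feed` is hypothetical.

```python
import numpy as np

def make_feed(encoded, input_names, dtype=np.int64):
    """Build an ONNX Runtime input_feed: keep only the session's inputs and cast them to `dtype`."""
    missing = [name for name in input_names if name not in encoded]
    if missing:
        raise KeyError(f"missing model inputs: {missing}")
    # session.run matches inputs by name, so the key order does not matter
    return {name: np.asarray(encoded[name], dtype=dtype) for name in input_names}

# Usage: feed = make_feed(inputs, [x.name for x in session.get_inputs()])
```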


In [200]:
# Run inference. Note that the session takes NumPy arrays as inputs
outputs = session.run(output_names=["logits"], input_feed=dict(inputs))
outputs[0].shape

(4, 128, 30522)

In [201]:
most_likely_token_ids = [numpy.argmax(outputs[0][i, pos, :]) for i, pos in enumerate(pos_masks)]
print(most_likely_token_ids)
unmasked_tokens = enc.decode(most_likely_token_ids).split(' ')
unmasked_sentences = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens)]
for sentence in unmasked_sentences:
    print(sentence)

[3007, 2653, 7202, 5597]
Paris is the capital of France.
The primary language of the United States is English.
A baseball game consists of at least nine innings.
Topology is a branch of mathematics concerned with the properties of geometric objects that remain unchanged under continuous transformations.


In [205]:
# Compare against the original PyTorch model
inputs_pt = enc(masked_sentences, return_tensors="pt", padding='max_length', max_length=128, truncation=True)
with torch.no_grad():
    outputs_origin = origin_model(**inputs_pt)

most_likely_token_ids = [torch.argmax(outputs_origin[0][i, pos, :]) for i, pos in enumerate(pos_masks)]
print(most_likely_token_ids)
unmasked_tokens = enc.decode(most_likely_token_ids).split(' ')
unmasked_sentences = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens)]
for sentence in unmasked_sentences:
    print(sentence)

[tensor(3007), tensor(2653), tensor(7202), tensor(5597)]
Paris is the capital of France.
The primary language of the United States is English.
A baseball game consists of at least nine innings.
Topology is a branch of mathematics concerned with the properties of geometric objects that remain unchanged under continuous transformations.


In [207]:
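# Something is off: the difference is fairly large (ONNX Runtime on GPU vs PyTorch on CPU), and even a tolerance of 1e-3 is not met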
a = outputs[0]
b = outputs_origin[0].detach().numpy()
numpy.allclose(a, b, rtol=1e-03, atol=1e-3)

False
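`numpy.allclose` checks every position, including the padding that fills each row up to `max_length=128`. Those positions are not used for any prediction, and one outlier anywhere is enough to make the check fail. The following sketch may help locate the difference. It reports the largest gap on real tokens only, plus whether both models agree on the top token at the mask positions. The helper `compare_logits` is hypothetical. In this notebook it would be called as `compare_logits(a, b, inputs["attention_mask"], pos_masks)`.

```python
import numpy as np

def compare_logits(a, b, attention_mask, mask_positions=None):
    """Compare two (batch, seq, vocab) logit arrays, ignoring padded positions."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    valid = np.asarray(attention_mask).astype(bool)   # (batch, seq), True on real tokens
    diff = np.abs(a - b)[valid]                        # (n_real_tokens, vocab)
    report = {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        # fraction of real-token positions where both models pick the same top token
        "argmax_agreement": float(np.mean(a.argmax(-1)[valid] == b.argmax(-1)[valid])),
    }
    if mask_positions is not None:
        rows = np.arange(len(mask_positions))
        top_a = a[rows, mask_positions].argmax(-1)
        top_b = b[rows, mask_positions].argmax(-1)
        report["mask_predictions_match"] = bool(np.all(top_a == top_b))
    return report
```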

In [206]:
b[0, 0, :]

array([-6.5416336, -6.5075865, -6.5212126, ..., -5.8960814, -5.7351847,
       -3.8943403], dtype=float32)

In [203]:
a[0, 0, :]

array([-6.5419493, -6.50757  , -6.521825 , ..., -5.8963876, -5.7355714,
       -3.8943367], dtype=float32)

# Benchmarking

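Note: sequence length has a large effect on model latency. From what I have seen so far, ONNX Runtime is better suited to short sequences.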
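The following sketch puts numbers on the sequence-length observation. It times one call per sequence length and reports the median latency. Several parts are assumptions. `sweep_seq_lengths` and `make_onnx_gpu_call` are hypothetical helpers. Tokenization happens once, outside the timed loop, unlike the `call_*` functions below, which tokenize on every call. `sync` is optional because `timeGraph` below calls `torch.cuda.synchronize()` unconditionally, which needs a CUDA device.

```python
import statistics
import time

def sweep_seq_lengths(make_call, seq_lens, num_loops=50, warmup=10, sync=None):
    """Time make_call(seq_len)() for each length; return {seq_len: median latency in seconds}."""
    results = {}
    for seq_len in seq_lens:
        call = make_call(seq_len)
        for _ in range(warmup):
            call()
        timings = []
        for _ in range(num_loops):
            start = time.perf_counter()
            call()
            if sync is not None:
                sync()  # e.g. torch.cuda.synchronize when timing asynchronous GPU work
            timings.append(time.perf_counter() - start)
        results[seq_len] = statistics.median(timings)
    return results

# Hypothetical usage with session_gpu from the next cell:
# def make_onnx_gpu_call(seq_len):
#     feed = dict(enc(masked_sentences, return_tensors="np", padding="max_length",
#                     max_length=seq_len, truncation=True))
#     return lambda: session_gpu.run(["logits"], feed)
# print(sweep_seq_lengths(make_onnx_gpu_call, [32, 64, 128, 256]))
```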

In [212]:
# Prepare the models and the functions that call them
session_cpu = InferenceSession("onnx/model_torch.onnx", providers=["CPUExecutionProvider"])
session_gpu = InferenceSession("onnx/model_torch.onnx", providers=["CUDAExecutionProvider"])
origin_model_cpu = BertForMaskedLM.from_pretrained("bert-base-uncased").eval()
origin_model_gpu = BertForMaskedLM.from_pretrained("bert-base-uncased").cuda().eval()

# All four functions should return the same thing: the logits as a NumPy array
def call_onnx_cpu():
    inputs = enc(masked_sentences, return_tensors="np", padding='max_length', max_length=64, truncation=True)
    return session_cpu.run(output_names=["logits"], input_feed=dict(inputs))[0]

def call_onnx_gpu():
    inputs = enc(masked_sentences, return_tensors="np", padding='max_length', max_length=64, truncation=True)
    return session_gpu.run(output_names=["logits"], input_feed=dict(inputs))[0]

def call_torch_cpu():
    inputs_pt = enc(masked_sentences, return_tensors="pt", padding='max_length', max_length=64, truncation=True)
    with torch.no_grad():
        return origin_model_cpu(**inputs_pt)[0].numpy()

def call_torch_gpu():
    inputs_pt = enc(masked_sentences, return_tensors="pt", padding='max_length', max_length=64, truncation=True)
    inputs_pt_gpu = {k: v.cuda() for k, v in inputs_pt.items()}
    with torch.no_grad():
        return origin_model_gpu(**inputs_pt_gpu)[0].cpu().numpy()


In [48]:
import timeit
def timeGraph(call_func, num_loops=50):
    print("Warm up ...")
    for _ in range(20):
        call_func()

    # Synchronize: CUDA calls are asynchronous by default
    torch.cuda.synchronize()

    print("Start timing ...")
    timings = []
    for i in range(num_loops):
        start_time = timeit.default_timer()
        call_func()
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        timings.append(end_time - start_time)
        # print("Iteration {}: {:.6f} s".format(i, end_time - start_time))

    return timings

In [47]:
import numpy as np
def printStats(graphName, timings, batch_size):
    times = np.array(timings)
    steps = len(times)
    speeds = batch_size / times
    time_mean = np.mean(times)
    time_med = np.median(times)
    time_99th = np.percentile(times, 99)
    time_std = np.std(times, ddof=0)
    speed_mean = np.mean(speeds)
    speed_med = np.median(speeds)

    msg = ("\n%s =================================\n"
            "batch size=%d, num iterations=%d\n"
            "  Median text batches/second: %.1f, mean: %.1f\n"
            "  Median latency: %.6f, mean: %.6f, 99th_p: %.6f, std_dev: %.6f\n"
            ) % (graphName,
                batch_size, steps,
                speed_med, speed_mean,
                time_med, time_mean, time_99th, time_std)
    print(msg)

In [125]:
timings = timeGraph(call_onnx_cpu)

printStats("BERT ONNX CPU", timings, 1)

Warm up ...
Start timing ...

batch size=1, num iterations=50
  Median text batches/second: 42.1, mean: 41.9
  Median latency: 0.023745, mean: 0.023971, 99th_p: 0.027929, std_dev: 0.001565



In [126]:
timings = timeGraph(call_torch_cpu)

printStats("BERT TORCH CPU", timings, 1)

Warm up ...
Start timing ...

batch size=1, num iterations=50
  Median text batches/second: 32.4, mean: 32.5
  Median latency: 0.030900, mean: 0.030915, 99th_p: 0.035542, std_dev: 0.002183



In [213]:
timings = timeGraph(call_onnx_gpu)

printStats("BERT ONNX GPU", timings, 1)

Warm up ...
Start timing ...

batch size=1, num iterations=50
  Median text batches/second: 67.7, mean: 67.3
  Median latency: 0.014773, mean: 0.015007, 99th_p: 0.018654, std_dev: 0.001541



In [211]:
timings = timeGraph(call_torch_gpu)

printStats("BERT TORCH GPU", timings, 1)

Warm up ...
Start timing ...

batch size=1, num iterations=50
  Median text batches/second: 82.3, mean: 82.2
  Median latency: 0.012153, mean: 0.012248, 99th_p: 0.014901, std_dev: 0.001057



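# TODO: ONNX Runtime on CPU is somewhat faster than PyTorch on CPU (median latency 23.7 ms vs 30.9 ms), but on GPU it is noticeably slower (14.8 ms vs 12.2 ms). Why?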

In [106]:
inputs = enc(masked_sentences, return_tensors="np", padding='max_length', max_length=128)

io_binding = session_gpu.io_binding()
for key, val in inputs.items():
    io_binding.bind_cpu_input(key, val)
io_binding.bind_output('logits')
session_gpu.run_with_iobinding(io_binding)
logits = io_binding.copy_outputs_to_cpu()[0]
logits

array([[[ -6.5419493,  -6.50757  ,  -6.521825 , ...,  -5.8963876,
          -5.7355714,  -3.8943367],
        [ -9.013971 ,  -9.046798 ,  -9.060098 , ...,  -8.257676 ,
          -8.033421 ,  -6.1780944],
        [ -8.652058 ,  -9.085825 ,  -8.771672 , ...,  -7.4472957,
          -5.338148 ,  -9.655459 ],
        ...,
        [ -8.809074 ,  -9.010075 ,  -8.932358 , ...,  -8.146495 ,
          -9.273098 ,  -5.312276 ],
        [ -8.758413 ,  -8.874108 ,  -8.888932 , ...,  -8.339445 ,
          -9.370628 ,  -4.932161 ],
        [ -8.783895 ,  -9.026108 ,  -8.919375 , ...,  -8.316187 ,
          -8.880913 ,  -6.533851 ]],

       [[ -6.6346483,  -6.600541 ,  -6.5921936, ...,  -5.9096456,
          -5.8169036,  -4.149713 ],
        [-12.518587 , -12.860065 , -12.95737  , ..., -13.159752 ,
         -10.315203 , -14.825867 ],
        [ -7.4197383,  -7.990819 ,  -7.946541 , ...,  -8.570905 ,
          -5.710056 , -11.684243 ],
        ...,
        [ -6.4807057,  -6.4920845,  -6.5437775, ...,  

In [109]:
def call_test():
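    # Note: max_length=128 here, while call_onnx_gpu above uses max_length=64, so this timing is not directly comparable with it
    inputs = enc(masked_sentences, return_tensors="np", padding='max_length', max_length=128)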

    io_binding = session_gpu.io_binding()
    for key, val in inputs.items():
        io_binding.bind_cpu_input(key, val)
    io_binding.bind_output('logits')
    session_gpu.run_with_iobinding(io_binding)
    logits = io_binding.copy_outputs_to_cpu()[0]

    return logits


timings = timeGraph(call_test)

printStats("BERT ONNX GPU", timings, 1)

Warm up ...
Start timing ...

batch size=1, num iterations=50
  Median text batches/second: 35.3, mean: 35.2
  Median latency: 0.028292, mean: 0.028495, 99th_p: 0.031135, std_dev: 0.001489



In [108]:
# TODO: this output is wrong
io_binding = session_gpu.io_binding()
for key, val in inputs.items():
    X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(val, 'cuda', 0)
    io_binding.bind_input(name=key, device_type=X_ortvalue.device_name(), device_id=0, element_type=val.dtype, shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
io_binding.bind_output("logits")
session_gpu.run_with_iobinding(io_binding)
logits = io_binding.copy_outputs_to_cpu()[0]
logits

array([[[-4.934188  , -5.0961747 , -5.1510315 , ..., -4.956326  ,
         -6.8429976 ,  0.82923543],
        [-4.7176423 , -4.9170785 , -4.97676   , ..., -4.660313  ,
         -6.5653915 ,  0.71815693],
        [-4.355648  , -4.5349092 , -4.6175747 , ..., -4.3954697 ,
         -6.210213  ,  0.9891236 ],
        ...,
        [-4.6253095 , -4.8404446 , -4.8514004 , ..., -4.715632  ,
         -6.5525537 ,  0.5545111 ],
        [-4.5998197 , -4.8183794 , -4.844186  , ..., -4.71098   ,
         -6.552678  ,  0.6832975 ],
        [-4.7034082 , -4.911319  , -4.9586306 , ..., -4.7842    ,
         -6.604789  ,  0.66530704]],

       [[-4.7924905 , -4.9325333 , -5.0450907 , ..., -4.6667776 ,
         -6.526173  ,  0.33888003],
        [-4.5583034 , -4.7254114 , -4.8429117 , ..., -4.3630447 ,
         -6.245224  ,  0.24502823],
        [-4.17154   , -4.3237524 , -4.46185   , ..., -4.1069174 ,
         -5.8931756 ,  0.5509099 ],
        ...,
        [-4.4804244 , -4.6598744 , -4.725569  , ..., -
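A likely reason the output above is wrong is that `X_ortvalue` is rebound on every loop iteration. The `OrtValue` from the previous iteration then loses its last reference, and its GPU buffer can be released, even though `bind_input` only received the raw pointer from `data_ptr()`. The next `ortvalue_from_numpy` call may reuse that memory, so the bound inputs would no longer hold the token ids. This is a hypothesis based on how the loop is written and has not been verified here. The benchmark cell below uses the same loop pattern.

The following sketch keeps every `OrtValue` referenced and binds it with `IOBinding.bind_ortvalue_input`. The helper `bind_device_inputs` is hypothetical. Its `make_ortvalue` hook exists only so that it can be exercised without a GPU build.

```python
import numpy as np

def bind_device_inputs(io_binding, inputs, device_type="cuda", device_id=0, make_ortvalue=None):
    """Copy each input array to the device, bind it by name, and return the OrtValues.

    Keep the returned dict alive until run_with_iobinding() has finished; otherwise the
    device buffers may be freed (and reused) while the binding still points at them.
    """
    if make_ortvalue is None:
        import onnxruntime  # imported lazily so the helper can be tested without onnxruntime-gpu
        make_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy
    keep_alive = {}
    for name, array in inputs.items():
        ort_value = make_ortvalue(np.ascontiguousarray(array), device_type, device_id)
        io_binding.bind_ortvalue_input(name, ort_value)
        keep_alive[name] = ort_value  # hold a reference for the duration of the run
    return keep_alive

# Usage in this notebook (requires onnxruntime-gpu):
# io_binding = session_gpu.io_binding()
# device_inputs = bind_device_inputs(io_binding, dict(inputs))
# io_binding.bind_output("logits")
# session_gpu.run_with_iobinding(io_binding)
# logits = io_binding.copy_outputs_to_cpu()[0]
```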

In [82]:
def call_test():
    io_binding = session_gpu.io_binding()
    for key, val in inputs.items():
        X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(val, 'cuda', 0)
        io_binding.bind_input(name=key, device_type=X_ortvalue.device_name(), device_id=0, element_type=val.dtype, shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
    io_binding.bind_output("logits")
    session_gpu.run_with_iobinding(io_binding)
    logits = io_binding.copy_outputs_to_cpu()[0]
    return logits

timings = timeGraph(call_test)

printStats("BERT ONNX GPU", timings, 1)

Warm up ...
Start timing ...

batch size=1, num iterations=50
  Median text batches/second: 65.8, mean: 65.3
  Median latency: 0.015200, mean: 0.015383, 99th_p: 0.018203, std_dev: 0.001020

