In [1]:
import torch, gc
import onnxruntime
from fastT5 import export_and_get_onnx_model
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
from fastT5 import (OnnxT5, get_onnx_runtime_sessions,
                    generate_onnx_representation, quantize)

# PyTorch baseline

In [2]:
# The tokenizer comes from the Langboat/mengzi-t5-base checkpoint, while the
# weights come from the local GuwenNet directory (presumably a model fine-tuned
# from that base checkpoint — confirm).
base_checkpoint = "Langboat/mengzi-t5-base"
guwen_checkpoint = "../serving/trained_model/GuwenNet"
tokenizer = T5Tokenizer.from_pretrained(base_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(guwen_checkpoint)

In [3]:
SAMPLE_TEXT = "先帝开创的事业没有完成一半，却中途去世了。现在天下分裂成三个国家。蜀汉民力困乏，这实在是危急存亡的时候啊。"


def run_model(model, text=SAMPLE_TEXT, max_length=100, tok=None):
    """Run ``model.generate`` on ``text`` and return the decoded output string.

    This is the unit of work timed by the ``%timeit`` cells in this notebook,
    so every backend (PyTorch, ONNX, quantized ONNX) is measured on the same input.

    Args:
        model: object with a Hugging Face-style ``generate(input_ids=...,
            attention_mask=..., max_length=...)`` method
            (``T5ForConditionalGeneration`` or ``OnnxT5`` here).
        text: source sentence to encode; defaults to the benchmark sentence.
        max_length: maximum generated length, forwarded to ``generate``.
        tok: tokenizer to use; ``None`` falls back to the notebook-level
            ``tokenizer`` global.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    tok = tokenizer if tok is None else tok
    encoded = tok(text, return_tensors='pt')
    generated = model.generate(input_ids=encoded['input_ids'],
                               attention_mask=encoded['attention_mask'],
                               max_length=max_length)
    # A single input string yields a batch of one. Index the batch instead of
    # calling .squeeze(), so a one-token result stays a 1-D sequence rather
    # than collapsing to a 0-d scalar.
    return tok.decode(generated[0], skip_special_tokens=True)

In [4]:
# Baseline: time the PyTorch model on the benchmark sentence. This uses 10
# repeats versus 50 in the ONNX cells, so its spread comes from fewer samples.
%timeit -r 10 run_model(model)

1.82 s ± 31.2 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)


In [5]:
# Drop the PyTorch model before the ONNX export below. gc.collect() forces a
# collection pass now instead of whenever the collector next runs. The trailing
# ';' hides the unreachable-object count gc.collect() returns, which is just
# output noise here.
del model
gc.collect();

9

# ONNX model (unquantized)

In [6]:
model_or_model_path = '../serving/trained_model/GuwenNet'

# Step 1: export the Hugging Face T5 checkpoint to ONNX files.
onnx_model_paths = generate_onnx_representation(model_or_model_path)

# Open ONNX Runtime sessions on the exported files and wrap them in OnnxT5,
# which run_model() drives through the same generate() call as the PyTorch model.
model_sessions = get_onnx_runtime_sessions(onnx_model_paths)
model = OnnxT5(model_or_model_path, model_sessions)

Exporting to onnx... |################################| 3/3

In [7]:
# Time the unquantized ONNX model on the same input as the PyTorch baseline.
%timeit -r 50 run_model(model)

1.15 s ± 47.5 ms per loop (mean ± std. dev. of 50 runs, 1 loop each)


In [8]:
# Release the unquantized ONNX model before quantizing. `model_sessions` holds
# the same ONNX Runtime sessions that were passed into OnnxT5, so it must be
# deleted too. Otherwise `del model` frees only the wrapper and the sessions stay
# alive until the next cell rebinds the name. `onnx_model_paths` is kept because
# the quantization step reads it. The trailing ';' hides gc.collect()'s return value.
del model, model_sessions
gc.collect();

0

# Quantized ONNX model

In [9]:
# Step 2 (recommended): quantize the exported ONNX files to reduce model size
# and speed up inference.
quant_model_paths = quantize(onnx_model_paths)

# Steps 3-4: open ONNX Runtime sessions on the quantized files and wrap them
# in an OnnxT5 model.
model_sessions = get_onnx_runtime_sessions(quant_model_paths)
model = OnnxT5(model_or_model_path, model_sessions)

Quantizing... |################################| 3/3

In [10]:
# Time the quantized ONNX model on the same input as the previous benchmarks.
%timeit -r 50 run_model(model)

687 ms ± 42.7 ms per loop (mean ± std. dev. of 50 runs, 1 loop each)
