Commit 36cce4e

author: wanglichun
commit message: add test and onnx_export interface
1 parent 2eb9961 commit 36cce4e

File tree: 2 files changed (+90, -0 lines)


docs/python/test.mdx

Lines changed: 42 additions & 0 deletions
@@ -104,6 +104,48 @@ test_from_raw_file(forward_function, os.path.join("assets/norm_jpg/"),
</details>

### torchpipe.utils.test.throughout

```python
def throughout(model_path, precision, post=None, input_data=None, calibrate_dir="calibrate_dir")
```

> Given a model file path (onnx or tensorrt), measure the model's throughput in torchpipe. The default configuration is max=4, instance=2.

:::tip Parameters
- **model_path** - path of the model to test. Both onnx and tensorrt models are supported: an onnx model must have the suffix .onnx and is converted to a tensorrt model automatically during testing; a tensorrt model must have the suffix .trt.
- **precision** - supports fp32, fp16, and int8 (**still being optimized**).
- **post** - postprocessor operator required by the model; commonly used in detection.
- **input_data** - input test data of type torch.Tensor. If not set, the input shape is read from the model automatically and the input is set to torch.randn(shape).
- **calibrate_dir** - path to the calibration data used for int8 quantization; the stored files must be saved with torch.save.
:::
<details><summary>Example</summary>

```python
import torchpipe
import torch
import os

## test throughput
onnx_path = os.path.join("/tmp", "resnet50.onnx")
input_data = torch.randn(1, 3, 224, 224).cuda()
precision = "fp16"
torchpipe.utils.test.throughout(onnx_path, precision, post=None, input_data=input_data)
```
<details><summary>Example of a calibration file stored in calibrate_dir</summary>

```python
import os
import random

import cv2
import torch

# preprocess_image is a user-defined preprocessing function returning a numpy array;
# calibrate_dir matches the calibrate_dir argument passed to throughout.
data = cv2.imread("../static/dog_result.jpg")
data = preprocess_image(data)
torch.save(torch.from_numpy(data), os.path.join(calibrate_dir, f"cache_{random.random()}.pt"))
```

</details>
</details>
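The calibration example above relies on a user-defined `preprocess_image`. A minimal sketch of what such a helper might look like for an ImageNet-style model (hypothetical: the mean/std values and channel layout must match your model's actual preprocessing, and resizing is assumed to happen beforehand):

```python
import numpy as np

# Hypothetical preprocessing helper; adjust to your model's real pipeline.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_image(img):
    """HWC uint8 BGR image (already resized) -> CHW float32 array."""
    img = img[:, :, ::-1].astype(np.float32) / 255.0     # BGR -> RGB, scale to [0, 1]
    img = (img - IMAGENET_MEAN) / IMAGENET_STD           # per-channel normalization
    return np.ascontiguousarray(img.transpose(2, 0, 1))  # HWC -> CHW
```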
### *torchpipe.utils.test.test_functions*

docs/tools/onnx_export.mdx

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
---
id: onnx_export
title: Export an ONNX model
type: explainer
---

torchpipe's interface for converting a PyTorch model to an ONNX model.

:::caution Validation stage
- The onnx_export API is not yet stable; more feedback is being collected.
- After conversion, verify that the output matches that of the original model.
:::

### torchpipe.utils.onnx_export

```python
def onnx_export(model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction], onnx_path, input = None, opset = 17, simplify=True):
```

> Given a PyTorch model, convert it to an ONNX model; opset=17 is used by default.

:::tip Parameters
- **model** - the PyTorch model.
- **onnx_path** - path where the ONNX model is saved.
- **input** - model input; defaults to torch.randn(1, 3, 224, 224) if not set.
- **opset** - the ONNX opset.
- **simplify** - whether to simplify the ONNX model with the onnx-simplifier library: [https://github.com/daquexian/onnx-simplifier](https://github.com/daquexian/onnx-simplifier)
:::
<details><summary>Example</summary>

```python
import os
from torchvision import models
import torch
import torchpipe

## export onnx
m = models.resnet50(weights=None).eval()
onnx_path = os.path.join("/tmp", "resnet50.onnx")
input = torch.randn(1, 3, 224, 224)
torchpipe.utils.onnx_export(m, onnx_path, input, opset=17, simplify=True)
```

</details>
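The caution above recommends verifying that the exported model still matches the original. A minimal sketch of such a check (assuming onnxruntime is installed; `verify_export` and `max_abs_diff` are illustrative helpers, not part of torchpipe):

```python
import numpy as np

def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two arrays."""
    return float(np.max(np.abs(np.asarray(a) - np.asarray(b))))

def verify_export(model, onnx_path, input, atol=1e-4):
    """Run the original PyTorch model and the exported ONNX model on the
    same input and check that the outputs agree within atol."""
    import torch
    import onnxruntime as ort  # assumed installed: pip install onnxruntime

    model = model.eval()
    with torch.no_grad():
        torch_out = model(input).cpu().numpy()

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name
    onnx_out = sess.run(None, {input_name: input.cpu().numpy()})[0]
    return max_abs_diff(torch_out, onnx_out) <= atol
```

After the export example above, `verify_export(m, onnx_path, input)` returning True indicates the two models agree on that input.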
