-
Notifications
You must be signed in to change notification settings - Fork 26
/
test_ops.py
103 lines (86 loc) · 2.72 KB
/
test_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import logging
import shrub
import tflite2onnx as t2o
# Module-level side effect: as soon as this test module is imported, configure
# log formatting through shrub's helper at DEBUG verbosity.
shrub.util.formatLogging(logging.DEBUG)
def end2end_test(model_name, use_layout):
    """Convert a TFLite test model to ONNX and compare both runtimes' outputs.

    ``<model_name>.tflite`` is read from ``../assets/tests`` relative to the
    directory containing this file. It is converted with
    ``tflite2onnx.convert`` into ``<model_name>.onnx``. Inputs generated by
    shrub are then run through both the ONNX and the TFLite model, and the
    results are compared with ``shrub.network.cmpTensors``.

    Args:
        model_name: Base file name of the test model, without extension.
        use_layout: Layout string forwarded to ``shrub.onnx.run`` and to
            ``cmpTensors(useLayout=...)``; the tests in this file pass
            ``'NCHW'`` or ``'NHWC'``.

    Raises:
        AssertionError: If ``cmpTensors`` reports that the outputs differ.
            The message names the model and layout, so a failure inside a
            loop over many models is identifiable.
    """
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    tflm_dir = os.path.abspath(os.path.join(cur_dir, '..', 'assets', 'tests'))
    tflm_path = os.path.join(tflm_dir, model_name + '.tflite')
    # NOTE(review): the ONNX file is a relative path, so it is written into the
    # current working directory; artifacts land wherever the tests are run.
    onnx_name = model_name + '.onnx'

    t2o.convert(tflm_path, onnx_name)

    m = shrub.tflite.parse(tflm_path)
    # Presumably fills m.inputs with generated test data -- confirm in shrub.
    m.genInput()

    onnx_ret = shrub.onnx.run(onnx_name, m.inputs, use_layout)
    tflite_ret = shrub.tflite.run(tflm_path, m.inputs)
    assert shrub.network.cmpTensors(onnx_ret, tflite_ret, useLayout=use_layout), (
        'ONNX and TFLite outputs differ for model %r (layout %s)'
        % (model_name, use_layout))
def test_ops_implicit_layout():
    """Run end-to-end checks on ops that stop layout propagation.

    Every model below goes through ``end2end_test`` with ``use_layout='NCHW'``,
    in the listed order; the loop stops at the first failing model.
    """
    layout = 'NCHW'
    models = (
        'avgpooling.float32',
        'avgpool-concat.float32',
        'conv.float32',
        'conv-dilation.float32',
        'conv-quant-fp16.float32',
        'conv-relu.float32',
        'conv-relu6.float32',
        'conv-stride.float32',
        'depthwise-conv.float32',
        'depthwise-conv-stride.float32',
        'fullyconnected.float32',
        'fullyconnected-relu6.float32',
        'maxpooling.float32',
        'resize-bilinear.float32',
        'resize-nearest-neighbor.float32',
        'conv-reshape.float32',
        'reshape-conv.float32',
        'conv-reshape-multiple-conv.float32',
        'transposeconv-samepad-stride2.float32',
        'transposeconv-samepad.float32',
        'transposeconv-validpad-stride2.float32',
        'transposeconv-validpad.float32',
    )
    for model in models:
        end2end_test(model, layout)
def test_ops_post_propagation():
    """Run end-to-end checks on ops that need post-propagation handling.

    Every model below goes through ``end2end_test`` with ``use_layout='NHWC'``,
    in the listed order; the loop stops at the first failing model.
    """
    layout = 'NHWC'
    models = (
        'concat.float32',
        'mean.float32',
        'padding.float32',
        'reshape.float32',
        'softmax.float32',
        'split.float32',
        'stridedslice-beginmask.float32',
        'stridedslice-endmask.float32',
        'stridedslice-stride.float32',
        'stridedslice.float32',
        'transpose.float32',
        'mirror-pad.int32',
    )
    for model in models:
        end2end_test(model, layout)
def test_ops_layout_transparent():
    """Run end-to-end checks on ops that should be transparent to layout.

    Every model below goes through ``end2end_test`` with ``use_layout='NHWC'``,
    in the listed order; the loop stops at the first failing model.
    """
    layout = 'NHWC'
    models = (
        'abs.float32',
        'add.float32',
        'add-relu.float32',
        'mul.float32',
        'relu6.float32',
        'relu.float32',
        'prelu.float32',
        'sigmoid.float32',
        'sub.float32',
        'abs-sqrt.float32',
        'relu6-power.float32',
        'squared-diff.float32',
        'abs-add-rsqrt.float32',
    )
    for model in models:
        end2end_test(model, layout)
if __name__ == '__main__':
    # Allow running the whole suite directly (without pytest), in file order.
    for suite in (test_ops_implicit_layout,
                  test_ops_post_propagation,
                  test_ops_layout_transparent):
        suite()