Skip to content

Commit

Permalink
add onnx model of fixed input shape for OpenCV (ShiqiYu#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
fengyuentau committed Feb 13, 2022
1 parent 86591d5 commit 1688402
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tasks/task1/exportonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def str2bool(v):
type=str, help='Trained state_dict file path to open')
parser.add_argument('-o', '--output_name', default='yunet',
type=str, help='The output ONNX file, trained parameters inside')
parser.add_argument('--input_batch_size', type=int, default=1, help='Input batch size for the output ONNX model.')
parser.add_argument('--input_height', type=int, default=120, help='Input height set for the output ONNX model.')
parser.add_argument('--input_width', type=int, default=160, help='Input width set for the output ONNX model.')
parser.add_argument('--enable_dynamic_axes', default=True,
type=str2bool, help='Enable dynamic axes for ONNX model.')
parser.add_argument('--opset_version', default=11, help='ONNX opset version to output.')
Expand Down Expand Up @@ -83,7 +86,10 @@ def load_model(model, pretrained_path, load_to_cpu):

print('Finished loading model!')

img = torch.randn(1, 3, 480, 640, requires_grad=False)
bs = args.input_batch_size
h = args.input_height
w = args.input_width
img = torch.randn(bs, 3, h, w, requires_grad=False)
img = img.to(torch.device('cpu'))

input_names = ['input']
Expand All @@ -96,6 +102,6 @@ def load_model(model, pretrained_path, load_to_cpu):
output_path = os.path.join('./onnx', args.output_name + '.onnx')
torch.onnx.export(net, img, output_path, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, opset_version=args.opset_version)
else:
output_path = os.path.join('./onnx', args.output_name + '_' + str(args.image_dim) + '.onnx')
output_path = os.path.join('./onnx', '{}_{}x{}.onnx'.format(args.output_name, h, w))
torch.onnx.export(net, img, output_path, input_names=input_names, output_names=output_names, opset_version=args.opset_version)
print('Finished exporing model to ' + output_path)
Binary file added tasks/task1/onnx/yunet_120x160.onnx
Binary file not shown.

0 comments on commit 1688402

Please sign in to comment.