# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 13
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero


@parse_args('v', 'i', 'none')
def softmax(g, input, dim, dtype=None):
    softmax = g.op('Softmax', input, axis_i=dim)
    if dtype and dtype.node().kind() != 'prim::Constant':
        parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
        softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
    return softmax


@parse_args('v', 'i', 'none')
def log_softmax(g, input, dim, dtype=None):
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != 'prim::Constant':
        parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
        return_op = g.op("Cast", return_op, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
    return return_op
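
# Usage sketch for the two symbolics above (not part of this module; a minimal,
# hypothetical export that exercises them). In opset 13, ONNX Softmax and
# LogSoftmax reduce along a single axis, so `dim` maps directly to `axis_i`:
#
#     import torch
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return torch.log_softmax(torch.softmax(x, dim=1), dim=1)
#
#     torch.onnx.export(M(), torch.randn(2, 3), "m.onnx", opset_version=13)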


@parse_args('v', 'v', 'i')
def frobenius_norm(g, self, dim=None, keepdim=False):
    dim_val = sym_help._maybe_get_const(dim, 'is')
    if not sym_help._is_value(dim_val) and len(dim_val) == 0:
        return g.op("ReduceL2", self, keepdims_i=0)
    sqr = g.op('Mul', self, self)
    sumsqr = sym_help._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
    return g.op('Sqrt', sumsqr)
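
# Illustration of the two lowering paths above (a hypothetical sketch): with no
# dims, the norm collapses to a single ReduceL2 over all axes; with dims, it is
# built explicitly as Sqrt(ReduceSum(Mul(x, x))):
#
#     import torch
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return torch.norm(x, p="fro", dim=(0, 1), keepdim=True)
#
#     torch.onnx.export(M(), torch.randn(3, 4), "fro.onnx", opset_version=13)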


@parse_args('v', 'v', 'i', 'i')
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    if not sym_help._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if sym_help._is_packed_list(split_size_or_sizes) and \
                len(sym_help._unpack_list(split_size_or_sizes)) == _outputs:
            split_sizes = [sym_help._unsqueeze_helper(g, v, [0]) for v in sym_help._unpack_list(split_size_or_sizes)]
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for i in range(_outputs):
                end = g.op("Add", start, split_sizes[i])  # split_sizes is a list of same length as _outputs
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res
        return [g.op("SequenceAt", split_out, g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)))
                for i in range(_outputs)]

    split_val = split_size_or_sizes.node()['value']
    if split_val.dim() > 0:
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
    split_size = sym_help._get_const(split_size_or_sizes, 'i', 'split_size')

    size = self.type().sizes()[dim]
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    splits = g.op("Constant", value_t=torch.tensor(splits))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
    return split(g, self, split_sizes, dim, _outputs)


def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None):
    return split(g, self, split_size_or_sizes, dim, _outputs)


def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None):
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


@parse_args('v', 'i', 'i')
def unbind(g, self, dim=0, _outputs=None):
    if _outputs is None:
        return g.op("SplitToSequence",
                    self,
                    g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
                    axis_i=dim, keepdims_i=0)

    splits = g.op("Constant", value_t=torch.tensor([1] * _outputs))
    outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
    outputs = [outputs] if _outputs == 1 else outputs
    squeezed_outputs = [g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim]))) for out in outputs]
    return squeezed_outputs
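
# Usage sketch (hypothetical): with a statically known number of outputs,
# unbind becomes a Split into size-1 slices followed by a Squeeze on `dim`;
# otherwise it becomes a single SplitToSequence with keepdims=0:
#
#     import torch
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             a, b, c = torch.unbind(x, dim=0)
#             return a + b + c
#
#     torch.onnx.export(M(), torch.randn(3, 4), "unbind.onnx", opset_version=13)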


# Emitted from `torch.nonzero(x, as_tuple=True)`
def nonzero_numpy(g, input, _outputs=None):
    return unbind(g, nonzero(g, input), 1, _outputs=_outputs)
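
# Usage sketch (hypothetical): `as_tuple=True` returns one 1-D index tensor per
# input dimension, which is exactly NonZero followed by unbind along dim 1:
#
#     import torch
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             rows, cols = torch.nonzero(x, as_tuple=True)
#             return rows, cols
#
#     torch.onnx.export(M(), torch.randn(3, 4), "nz.onnx", opset_version=13)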


@parse_args('v', 'v', 'v', 'i')
def where(g, condition, self=None, other=None, _outputs=None):
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if condition.type().scalarType() != 'Bool':
        condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx['Bool'])
    if self is None:
        condition = nonzero(g, condition)
        return sym_help._unbind_helper(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs)
    return g.op("Where", condition, self, other)


def _reduce_op_symbolic(onnx_op_name):
    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return g.op(onnx_op_name, self, keepdims_i=0)
        else:
            keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
            return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)
    return symbolic


def _reduce_with_dtype(onnx_op, name):
    symbolic = _reduce_op_symbolic(onnx_op)

    @overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @parse_args('v', 'none')
        def reduce_nodim(g, self, dtype):
            if dtype.node().kind() != 'prim::Constant':
                return _unimplemented(name, "dtype")
            return symbolic(g, self)

        @parse_args('v', 'v', 'i', 'none')
        def reduce_dim(g, self, dim, keepdim, dtype):
            if dtype.node().kind() != 'prim::Constant':
                return _unimplemented(name, "dtype")
            return symbolic(g, self, dim, keepdim)
        return reduce_nodim, reduce_dim
    return reduce


sum = _reduce_with_dtype('ReduceSum', 'sum')
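
# Why `sum` is re-registered here: ONNX ReduceSum-13 takes `axes` as a tensor
# input rather than an attribute, so the dim-reduce path in _reduce_op_symbolic
# forwards `dim` as a second graph input. A minimal, hypothetical export:
#
#     import torch
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return x.sum(dim=1, keepdim=True)
#
#     torch.onnx.export(M(), torch.randn(2, 3), "sum.onnx", opset_version=13)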


@parse_args('v', 'i', 'i', 'i')
def unsafe_chunk(g, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return g.op("SplitToSequence",
                    self,
                    g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
                    axis_i=dim, keepdims_i=0)

    size = sym_help._get_tensor_dim_size(self, dim)
    if size is None:
        return _unimplemented('unsafe_chunk', 'unknown dimension size')
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)

    # TODO: So far we don't have a module using this method. We'll keep
    # this as a constant unless we see a request of dynamics in any
    # user's modules.
    splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)