Skip to content

Commit

Permalink
AffineのNNPをONNXに変換できるよう修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Mar 14, 2018
1 parent 68a7b9f commit e6b9699
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion python/src/nnabla/utils/converter/onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,22 @@ def convert_to_nodes(func, variables):
p = onnx.helper.make_attribute("perm", tp.axes)
n.attribute.extend([p])
elif func.type == "Affine":
ap = func.affine_param
flatten_postfix = "_flatten"
# Broadcast tensor C by default since it's usually a 1D vector
b = onnx.helper.make_attribute("broadcast", 1)
# When base_axis is set, we need to flatten the input to 2D based on the axis
n.attribute.extend([b])
# We need to flatten tensor A to 2D based on the base_axis
x = func.input[0]
flout = x+flatten_postfix
fl = onnx.helper.make_node(
"Flatten",
[x],
[flout])
n.input[0] = flout # rewire input data
a = onnx.helper.make_attribute("axis", ap.base_axis)
fl.attribute.extend([a])
nl.append(fl)
nl.append(n)
return nl

Expand Down

0 comments on commit e6b9699

Please sign in to comment.