Skip to content

Commit

Permalink
BroadcastTo関連のコードを移動し、pads関連のエクスポート時のコードがおかしかったのを修正
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 25, 2018
1 parent c46d264 commit 1f54698
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 319 deletions.
3 changes: 3 additions & 0 deletions build-tools/code_generator/function_types.yaml
Expand Up @@ -232,6 +232,9 @@ Transpose:
Broadcast:
float: [float]
half: [Half]
BroadcastTo:
float: [float]
half: [Half]
OneHot:
float: [int, float]
half: [int, Half]
Expand Down
18 changes: 18 additions & 0 deletions build-tools/code_generator/functions.yaml
Expand Up @@ -1827,6 +1827,24 @@ Array Manipulation:
outputs:
y:
doc: Broadcasted N-D array
BroadcastTo:
snake_name: broadcast_to
doc: |2
Broadcasting ND-array to the specified buffer.
inputs:
x:
doc: N-D array
y:
doc: N-D array
arguments:
axis:
doc: Target axis to start broadcasting. If this is not set, broadcast will try to fit y to x starting from the last dimension
type: int64
default: -1
outputs:
z:
doc: Broadcasted N-D array
OneHot:
snake_name: one_hot
doc: |2
Expand Down
18 changes: 18 additions & 0 deletions python/src/nnabla/utils/converter/functions.yaml
Expand Up @@ -1827,6 +1827,24 @@ Array Manipulation:
outputs:
y:
doc: Broadcasted N-D array
BroadcastTo:
snake_name: broadcast_to
doc: |2
Broadcasting ND-array to the specified buffer.
inputs:
x:
doc: N-D array
y:
doc: N-D array
arguments:
axis:
doc: Target axis to start broadcasting. If this is not set, broadcast will try to fit y to x starting from the last dimension
type: int64
default: -1
outputs:
z:
doc: Broadcasted N-D array
OneHot:
snake_name: one_hot
doc: |2
Expand Down
6 changes: 3 additions & 3 deletions python/src/nnabla/utils/converter/onnx/exporter.py
Expand Up @@ -113,7 +113,7 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
# Copy kernel, stride, and pads values
k = onnx.helper.make_attribute("kernel_shape", mpp.kernel.dim)
s = onnx.helper.make_attribute("strides", mpp.stride.dim)
p = onnx.helper.make_attribute("pads", mpp.pad.dim*2)
p = onnx.helper.make_attribute("pads", mpp.pad.dim[:]*2)
n.attribute.extend([k, s, p])
nl.append(n)
elif func.type == "Convolution":
Expand All @@ -138,7 +138,7 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
weight_shape.dim[weight_base:])
d = onnx.helper.make_attribute("dilations", cp.dilation.dim)
s = onnx.helper.make_attribute("strides", cp.stride.dim)
p = onnx.helper.make_attribute("pads", cp.pad.dim*2)
p = onnx.helper.make_attribute("pads", cp.pad.dim[:]*2)
g = onnx.helper.make_attribute("group", cp.group)
n.attribute.extend([k, d, s, p, g])
nl.append(n)
Expand Down Expand Up @@ -170,7 +170,7 @@ def convert_to_nodes(func, variables, input_types, output_types, broadcast_targe
# Copy kernel, stride, and pads values
k = onnx.helper.make_attribute("kernel_shape", app.kernel.dim)
s = onnx.helper.make_attribute("strides", app.stride.dim)
p = onnx.helper.make_attribute("pads", app.pad.dim*2)
p = onnx.helper.make_attribute("pads", app.pad.dim[:]*2)
n.attribute.extend([k, s, p])
nl.append(n)
elif func.type == "BatchNormalization":
Expand Down

0 comments on commit 1f54698

Please sign in to comment.