Skip to content

Commit

Permalink
Extend support for exporting reshape to onnx.
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#16632

Differential Revision: D14020906

Pulled By: ezyang

fbshipit-source-id: 168616873044b980145a3554dab942bdec19efb2
  • Loading branch information
BowenBao authored and facebook-github-bot committed Feb 11, 2019
1 parent e661dc2 commit 4335aac
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 0 deletions.
48 changes: 48 additions & 0 deletions test/onnx/expect/TestOperators.test_reshape.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "0.4"
graph {
node {
input: "0"
output: "1"
op_type: "Flatten"
attribute {
name: "axis"
i: 1
type: INT
}
}
name: "torch-jit-export"
input {
name: "0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 9
}
74 changes: 74 additions & 0 deletions test/onnx/expect/TestOperators.test_reshape_as.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "0.4"
graph {
node {
output: "1"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 3
dims: 1
dims: 2
dims: 1
data_type: 1
raw_data: "\220\202l>\204m1?\363\241\210\276\374k\013@\\-\321=\353\234\204\276"
}
type: TENSOR
}
}
node {
input: "1"
output: "2"
op_type: "Shape"
}
node {
input: "0"
input: "2"
output: "3"
op_type: "Reshape"
}
name: "torch-jit-export"
input {
name: "0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "3"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 9
}
9 changes: 9 additions & 0 deletions test/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def test_view(self):
x = torch.tensor([0.0], requires_grad=True)
self.assertONNX(lambda x: x.view(1, 1), x)

def test_reshape(self):
x = torch.tensor([0.0], requires_grad=True)
self.assertONNX(lambda x: x.reshape(1, 1), x)

def test_reshape_as(self):
x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(3, 1, 2, 1, requires_grad=True)
self.assertONNX(lambda x: x.reshape_as(y), x)

def test_index(self):
x = torch.tensor([[0.0]], requires_grad=True)
self.assertONNX(lambda x: x[0], x)
Expand Down
9 changes: 9 additions & 0 deletions torch/onnx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ def _reshape_from_tensor(g, input, shape):
return g.op('Reshape', input, shape)


def reshape(g, self, shape):
return view(g, self, shape)


def reshape_as(g, self, other):
shape = g.op('Shape', other)
return reshape(g, self, shape)


def add(g, self, other, alpha=None):
# default alpha arg is to allow no-alpha add (aten add st overload no alpha)
if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
Expand Down

0 comments on commit 4335aac

Please sign in to comment.