Merge pull request apache#29 from heliqi/paddle_frontend
Paddle frontend
jiangjiajun committed Sep 13, 2021
2 parents a0849a9 + 629929e commit 76194aa
Showing 2 changed files with 81 additions and 39 deletions.
48 changes: 42 additions & 6 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -485,6 +485,8 @@ def convert_elementwise_op(g, op, block):
"elementwise_mul": "multiply",
"elementwise_sub": "subtract",
"elementwise_mod": "mod",
"elementwise_max": "maximum",
"elementwise_min": "minimum",
"elementwise_pow": "power",
"elementwise_floordiv": "floor_divide",
"floor_mod": "floor_mod",
@@ -948,6 +950,25 @@ def convert_mul(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_numel(g, op, block):
"""Operator converter for numel."""

input_x = g.get_node(op.input("Input")[0])
out = _op.ndarray_size(input_x)
out = _op.expand_dims(out, axis=0)
g.add_node(op.output("Out")[0], out)
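The converter relies on _op.ndarray_size returning a 0-d scalar, while Paddle's numel/size op produces a 1-D tensor of shape [1]; hence the expand_dims. A hedged NumPy sketch of that shape contract:

import numpy as np

x = np.zeros((3, 4), dtype="float32")
size_scalar = np.int64(x.size)                # analogue of _op.ndarray_size: a 0-d scalar
size_tensor = np.expand_dims(size_scalar, 0)  # Paddle expects a 1-D tensor of shape [1]
assert size_tensor.shape == (1,) and size_tensor[0] == 12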


def convert_nonzero(g, op, block):
"""Operator converter for nonzero."""

input_x = g.get_node(op.input("Condition")[0])
out = _op.transform.argwhere(input_x)
# Paddle NonZero always outputs int64
out = _op.cast(out, "int64")
g.add_node(op.output("Out")[0], out)
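argwhere returns one row of coordinates per nonzero element, and the explicit cast pins the dtype since Paddle's NonZero always emits int64. A hedged NumPy analogue:

import numpy as np

x = np.array([[1.0, 0.0], [0.0, 3.0]])
idx = np.argwhere(x).astype("int64")  # one [row, col] pair per nonzero element, as int64
assert idx.tolist() == [[0, 0], [1, 1]]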


def convert_pool2d(g, op, block):
"""Operator converter for pool2d."""

@@ -1093,6 +1114,15 @@ def convert_range(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_reciprocal(g, op, block):
"""Operator converter for reciprocal."""

x = g.get_node(op.input("X")[0])
dtype = infer_type(x).checked_type.dtype
out = _expr.const(1.0, dtype) / x
g.add_node(op.output("Out")[0], out)
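Building the constant 1 in the input's own dtype keeps the division from upcasting the result. A hedged NumPy sketch:

import numpy as np

x = np.array([2.0, 4.0], dtype="float32")
out = np.float32(1.0) / x  # constant in the input dtype, mirroring _expr.const(1.0, dtype)
assert out.dtype == np.float32 and out.tolist() == [0.5, 0.25]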


def convert_reduce(g, op, block):
"""Operator converter for reduce."""

@@ -1288,7 +1318,7 @@ def convert_scale(g, op, block):
bias_after_scale = op.attr("bias_after_scale")
x = g.get_node(op.input("X")[0])
if np.isclose(scale, 1.0) and np.isclose(bias, 0.0):
out = _op.copy(x)
out = x
else:
if np.isclose(bias, 0.0):
out = x * _expr.const(np.array(scale).astype("float32"))
@@ -1318,7 +1348,7 @@ def convert_slice(g, op, block):
"""Operator converter for slice."""

data = g.get_node(op.input("Input")[0])
dims = len(block.var(op.input("Input")[0]).shape)
dims = len(infer_shape(data))
dtype = "int64"

axes = op.attr("axes")
@@ -1343,21 +1373,20 @@ def convert_slice(g, op, block):
else:
starts = op.attr("starts")
starts = _expr.const(starts)
start_dtype = infer_type(starts).checked_type.dtype
if isinstance(starts, _expr.Expr):
starts = _op.scatter(
_op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype),
_op.const([0] * dims, dtype=start_dtype),
axes,
starts,
axis=0,
)

data_shape = shape_of(data)
ends = op.input("EndsTensor")
if ends:
ends = g.get_node(ends[0])
elif op.input("EndsTensorList"):
ends = []
data_shape = data_shape.astype(dtype)
for end_index in op.input("EndsTensorList"):
end_index = g.get_node(end_index)
if not isinstance(end_index, _expr.Expr):
@@ -1370,9 +1399,11 @@ def convert_slice(g, op, block):
ends = op.attr("ends")
ends = _expr.const(ends)
if isinstance(ends, _expr.Expr):
data_shape = shape_of(data, infer_type(ends).checked_type.dtype)
ends = _op.scatter(data_shape, axes, ends, axis=0)

out = _op.strided_slice(data, begin=starts, end=ends)
strides = _op.const([1] * dims, dtype=start_dtype)
out = _op.strided_slice(data, begin=starts, end=ends, strides=strides)
if decrease_axis:
out = _op.squeeze(out, axis=decrease_axis)
g.add_node(op.output("Out")[0], out)
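The two scatter calls above implement Paddle's partial-axes slice semantics: starts and ends are supplied only for the listed axes, so they are scattered into full-rank begin/end vectors (defaulting to 0 and the input shape respectively) before the unit-stride strided_slice. A hedged NumPy analogue of that expansion:

import numpy as np

data = np.arange(24, dtype="float32").reshape(2, 3, 4)
dims, axes = data.ndim, [1]                # slice only along axis 1
starts, ends = [1], [3]

begin = np.zeros(dims, dtype="int64")      # default start of 0 on untouched axes
begin[axes] = starts
end = np.array(data.shape, dtype="int64")  # default end = full extent
end[axes] = ends

sliced = data[tuple(slice(b, e) for b, e in zip(begin, end))]
assert sliced.shape == (2, 2, 4)           # same result as data[:, 1:3, :]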
@@ -1554,6 +1585,8 @@ def convert_unsqueeze(g, op, block):
"elementwise_mul": convert_elementwise_op,
"elementwise_sub": convert_elementwise_op,
"elementwise_mod": convert_elementwise_op,
"elementwise_max": convert_elementwise_op,
"elementwise_min": convert_elementwise_op,
"elementwise_pow": convert_elementwise_op,
"elementwise_floordiv": convert_elementwise_op,
"equal": convert_elementwise_op,
@@ -1598,6 +1631,7 @@ def convert_unsqueeze(g, op, block):
"pow": convert_pow,
"p_norm": convert_norm,
"range": convert_range,
"reciprocal": convert_reciprocal,
"reduce_all": convert_reduce,
"reduce_any": convert_reduce,
"reduce_max": convert_reduce,
@@ -1613,6 +1647,7 @@ def convert_unsqueeze(g, op, block):
"shape": convert_shape,
"sigmoid": convert_unary_op,
"sin": convert_unary_op,
"size": convert_numel,
"slice": convert_slice,
"softmax": convert_softmax,
"split": convert_split,
@@ -1625,6 +1660,7 @@ def convert_unsqueeze(g, op, block):
"tile": convert_tile,
"transpose2": convert_transpose,
"unsqueeze2": convert_unsqueeze,
"where_index": convert_nonzero,
}


72 changes: 39 additions & 33 deletions tests/python/frontend/paddlepaddle/test_forward.py
@@ -89,10 +89,9 @@ def verify_model(func, input_data, rtol=1e-5, atol=1e-5, input_shape=None):
input_name = "input{}".format(idx)
if input_shape:
shape = input_shape[idx]
input_shape_dict[input_name] = [relay.Any()] * len(shape)
else:
shape = data.shape
input_shape_dict[input_name] = shape
input_shape_dict[input_name] = shape
input_spec.append(paddle.static.InputSpec(dtype=data.dtype, shape=shape, name=input_name))
input_names.append(input_name)
if isinstance(data, np.ndarray):
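With this change, supplying input_shape keeps the static shape for paddle.static.InputSpec but hands Relay relay.Any() for every dimension, so converters with data-dependent output shapes (such as nonzero below) are exercised under fully dynamic shapes. A hedged sketch of the shape dict this builds (names are illustrative):

from tvm import relay

shape = [3, 4]                                       # static shape for the Paddle InputSpec
shape_dict = {"input0": [relay.Any()] * len(shape)}  # fully dynamic shape for the Relay importer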
@@ -159,6 +158,8 @@ def forward(self, inputs):
"log2",
"log10",
"log1p",
"numel",
"reciprocal",
"relu",
"rsqrt",
"sigmoid",
@@ -652,6 +653,7 @@ def test_forward_elemwise():
class ElemwiseOp(nn.Layer):
def __init__(self, op_name):
super(ElemwiseOp, self).__init__()
self.op_name_ = op_name
for candidate in (paddle, paddle.nn.functional):
self.func = getattr(candidate, op_name, None)
if self.func:
@@ -660,11 +662,15 @@ def __init__(self, op_name):
@paddle.jit.to_static
def forward(self, input1, input2):
y = self.func(input1, input2)
return paddle.cast(y, "int32")
if "equal" in self.op_name_ or "than" in self.op_name_:
y = paddle.cast(y, "int32")
return y
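The cast is now conditional because only the comparison ops produce bool outputs that need normalizing to int32 for the result check; maximum and minimum return the input dtype, which a blanket cast would truncate. A hedged illustration of the two dtype behaviors:

import paddle

a = paddle.to_tensor([1.5, 2.0])
b = paddle.to_tensor([2.0, 1.0])
print(paddle.maximum(a, b).dtype)  # float32 -- left untouched
print(paddle.equal(a, b).dtype)    # bool -- cast to int32 before comparison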

op_list = [
"floor_divide",
"floor_mod",
"maximum",
"minimum",
"equal",
"greater_than",
"less_equal",
@@ -766,7 +772,7 @@ def index_select2(x, index):

input_shape = [3, 10]
input_data = paddle.rand(input_shape, dtype="float32")
index = paddle.to_tensor(np.array([0, 1, 1]).astype('int32'))
index = paddle.to_tensor(np.array([0, 1, 1]).astype("int32"))
verify_model(index_select1, input_data=[input_data, index])
verify_model(index_select2, input_data=[input_data, index])

@@ -950,6 +956,31 @@ def forward(self, input1, input2):


@tvm.testing.uses_gpu
def test_forward_nonzero():
class Nonzero(nn.Layer):
def __init__(self, as_tuple=False):
super().__init__()
self.as_tuple = as_tuple

@paddle.jit.to_static
def forward(self, inputs):
return paddle.nonzero(inputs, self.as_tuple)

x1 = paddle.to_tensor([[1.0, 0.0, 0.0, 2.0], [0.0, 2.0, 0.0, 1.1], [0.0, 0.0, 3.0, 0.0]])
verify_model(Nonzero(), x1, input_shape=[[3, 4]])
verify_model(Nonzero(True), x1, input_shape=[[3, 4]])
x2 = paddle.to_tensor([0, 1, 0, 3])
verify_model(
Nonzero(),
x2,
input_shape=[
[
4,
]
],
)
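For reference, a hedged sketch of the two modes this test exercises:

import paddle

x = paddle.to_tensor([[1.0, 0.0], [0.0, 3.0]])
print(paddle.nonzero(x))        # int64 tensor of coordinates, one row per nonzero element
print(paddle.nonzero(x, True))  # as_tuple=True: per-axis index tensors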


def test_forward_norm():
class Norm1(nn.Layer):
@paddle.jit.to_static
@@ -985,17 +1016,17 @@ class Norm7(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.norm(inputs, p=float(1), axis=None, keepdim=False)

class Norm8(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.norm(inputs, p=float(2.0), axis=1, keepdim=False)

class Norm9(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.norm(inputs, p=float(-0.5), axis=[1, 2], keepdim=False)

class Norm10(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
@@ -1015,31 +1046,6 @@ def forward(self, inputs):
verify_model(Norm10(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_not_equal():
class Not_equal(nn.Layer):
@paddle.jit.to_static
def forward(self, x, y):
output = paddle.not_equal(x, y)
output = paddle.cast(output, "int32")
return output

x_shape = [10]
y_shape = [10]
x_data = paddle.randint(1, 10, x_shape, dtype="int32")
y_data = paddle.randint(1, 10, y_shape, dtype="int32")
x_data_1 = paddle.randint(1, 10, x_shape, dtype="int64")
y_data_1 = paddle.randint(1, 10, y_shape, dtype="int64")
verify_model(Not_equal(), input_data=[x_data, y_data])
verify_model(Not_equal(), input_data=[x_data_1, y_data_1])
# For broadcast
x_shape_1 = [10]
y_shape_1 = [10, 1]
x_data_2 = paddle.rand(x_shape_1, dtype="float32")
y_data_2 = paddle.rand(y_shape_1, dtype="float32")
verify_model(Not_equal(), input_data=[x_data_2, y_data_2])


@tvm.testing.uses_gpu
def test_forward_pool2d():
@paddle.jit.to_static
@@ -1472,7 +1478,7 @@ def zeros2(inputs):
test_forward_lstm()
test_forward_matmul()
test_forward_multiply()
test_forward_not_equal()
test_forward_nonzero()
test_forward_norm()
test_forward_pool2d()
test_forward_pad()