Skip to content

Commit

Permalink
[TOPI][ARM] Improve injective schedule (apache#2801)
Browse files Browse the repository at this point in the history
  • Loading branch information
hlu1 authored and wweic committed Mar 20, 2019
1 parent a3b703b commit e902d18
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 2 deletions.
1 change: 1 addition & 0 deletions topi/python/topi/arm_cpu/__init__.py
Expand Up @@ -4,3 +4,4 @@
from . import depthwise_conv2d
from . import conv2d_transpose
from . import bitserial_conv2d
from . import injective
37 changes: 37 additions & 0 deletions topi/python/topi/arm_cpu/injective.py
@@ -0,0 +1,37 @@
# pylint: disable=invalid-name, unused-variable
"""Schedule for pooling operators"""
import tvm
from .. import generic

@generic.schedule_injective.register(["arm_cpu"])
def schedule_injective(outs):
"""ARM CPU schedule for injective op.
Parameters
----------
outs: Array of Tensor
The computation graph description of injective in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
x = outs[0]
if list(s[x].op.axis):
# do not vectorize for broadcast
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
s[x].vectorize(ii)
tvm.schedule.AutoInlineInjective(s)
if len(s[x].op.axis) >= 4:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 3:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 2:
s[x].parallel(s[x].op.axis[0])
return s
4 changes: 3 additions & 1 deletion topi/tests/python/test_topi_resize.py
Expand Up @@ -5,6 +5,8 @@
import topi.testing
import math

from common import get_all_backend

def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False):

if layout == 'NCHW':
Expand Down Expand Up @@ -40,7 +42,7 @@ def check_device(device):

tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)

for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
for device in get_all_backend():
check_device(device)

def test_resize():
Expand Down
4 changes: 3 additions & 1 deletion topi/tests/python/test_topi_upsampling.py
Expand Up @@ -5,6 +5,8 @@
import topi.testing
import math

from common import get_all_backend

def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"):


Expand Down Expand Up @@ -45,7 +47,7 @@ def check_device(device):

tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
for device in get_all_backend():
check_device(device)

def test_upsampling():
Expand Down

0 comments on commit e902d18

Please sign in to comment.