[Vulkan][Topi] Parametrizing additional topi tests, marking vulkan failures (apache#8904)

* [Pytest] Fixed TestTargetAutoParametrization in cases where LLVM is disabled.

* [UnitTests][Vulkan] Improved robustness of test_tir_intrin::test_clz

Previously, the test would fail during the build because support for Int64 primitives
wasn't declared in the `"vulkan"` target.  Now it uses the `"vulkan -from_device=0"` target
and marks the test as xfail if the current target doesn't support Int64 (a sketch of this
check follows the list below).

* [UnitTest][Topi] Parametrized several unit tests, identified vulkan failures (see the parametrization sketch after this list)

- Parametrized topi modules
  - test_topi_conv1d_transpose_ncw.py
  - test_topi_conv2d_nhwc.py
  - test_topi_correlation.py
  - test_topi_loss.py
  - test_topi_math.py
  - test_topi_reduce.py
  - test_topi_softmax.py
  - test_topi_sort.py
  - test_topi_unique.py
  - test_topi_vision.py

- Unit Tests fixed

  - `test_topi_loss::test_nll_loss`, failure due to `supports_float64`
    not being passed from the target to the codegen.

- Known Vulkan failures (tracked in apache#8903)

  - test_topi_math.py::test_ewise, ["tan", "erf", "isnan", "isfinite", "isinf"]

    Unimplemented CallNode operations

  - test_topi_reduce.py::test_reduce_map, ["sum", "any", "all"]

    Fails during codegen with an unexpected data-type size.

  - test_topi_vision.py::test_proposal

    Marked test_proposal as xfail on vulkan; it currently has a type error
    between bool/int8.

  - test_topi_conv1d_transpose_ncw.py::test_conv1d_transpose_ncw

    Incorrect numeric output, with a few elements outside the allowed
    tolerance; occurs only on the vulkan backend.

  - test_topi_softmax.py::test_softmax

    Marked float64 operations as xfail in vulkan, because GLSL.std.450
    only supports 16/32-bit floats.
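
A minimal sketch of the Int64 capability check described above (hedged: the helper name and
message are illustrative, not copied from test_tir_intrin.py; it assumes the vulkan target
exposes a `supports_int64` attribute, which `-from_device=0` fills in from the running device):

import pytest
import tvm


def xfail_if_no_int64(target_str="vulkan -from_device=0"):
    # Query the actual device's capabilities rather than assuming the bare
    # "vulkan" defaults, then xfail instead of erroring out during codegen.
    target = tvm.target.Target(target_str)
    if target.kind.name == "vulkan" and not target.attrs.get("supports_int64", False):
        pytest.xfail("Vulkan target does not declare Int64 support")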
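
And a minimal sketch of the parametrization idiom the converted topi tests follow (the test
name, shapes, and assertion are illustrative; the `tvm.testing` helpers are the ones used in
the diffs below):

import numpy as np

import tvm
import tvm.testing

# Each tuple becomes one parametrized test case; the names on the left are pytest fixtures.
batch, in_size = tvm.testing.parameters(
    (1, 32),
    (4, 128),
)

dtype = tvm.testing.parameter("float32")


@tvm.testing.fixture(cache_return_value=True)
def ref_data(batch, in_size, dtype):
    # Cached per parameter combination, so the reference data is generated
    # once and reused across all enabled targets.
    return np.random.uniform(size=(batch, in_size)).astype(dtype)


@tvm.testing.known_failing_targets("vulkan")
def test_example(target, dev, ref_data, dtype):
    # `target` and `dev` are supplied by tvm.testing for every enabled target;
    # known_failing_targets marks the vulkan cases as xfail rather than skipping them.
    a = tvm.nd.array(ref_data, dev)
    assert a.dtype == dtype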
Lunderberg authored and ylc committed Sep 29, 2021
1 parent 3552c83 commit 56c4f37
Showing 13 changed files with 1,035 additions and 1,128 deletions.
3 changes: 3 additions & 0 deletions src/target/spirv/spirv_support.cc
@@ -72,6 +72,9 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) {
   if (target->GetAttr<Bool>("supports_float16")) {
     supports_float16 = target->GetAttr<Bool>("supports_float16").value();
   }
+  if (target->GetAttr<Bool>("supports_float64")) {
+    supports_float64 = target->GetAttr<Bool>("supports_float64").value();
+  }
   if (target->GetAttr<Bool>("supports_int8")) {
     supports_int8 = target->GetAttr<Bool>("supports_int8").value();
   }
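
With the lines added above, float64 support declared on the target now reaches the SPIR-V
codegen.  A hedged illustration from the Python side (assuming, as the C++ above implies,
that the vulkan target kind registers `supports_float64` as a boolean option):

import tvm

# Declare float64 support explicitly on the target string...
target = tvm.target.Target("vulkan -supports_float64=1")
print(target.attrs["supports_float64"])

# ...or, preferably, let TVM fill in the attributes of the first Vulkan
# device (requires a Vulkan device to be present).
target = tvm.target.Target("vulkan -from_device=0")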
161 changes: 89 additions & 72 deletions tests/python/topi/python/test_topi_conv1d_transpose_ncw.py
@@ -15,90 +15,107 @@
# specific language governing permissions and limitations
# under the License.
"""Test code for transposed convolution."""
import numpy as np

import itertools
import os

import numpy as np

import tvm
from tvm import te
from tvm import topi
import tvm.testing
import tvm.topi.testing
from tvm.contrib.pickle_memoize import memoize

from tvm import te, topi
from tvm.topi.utils import get_const_tuple
import tvm.testing

_conv1d_transpose_ncw_implement = {
"generic": (topi.nn.conv1d_transpose_ncw, topi.generic.schedule_conv1d_transpose_ncw),
"gpu": (topi.cuda.conv1d_transpose_ncw, topi.cuda.schedule_conv1d_transpose_ncw),
}


def verify_conv1d_transpose_ncw(
batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding
(
batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding,
output_padding,
) = tvm.testing.parameters(
(1, 3, 224, 32, 5, 1, 0, (0,)),
(1, 3, 224, 32, 7, 1, 2, (0,)),
(1, 3, 224, 32, 5, 2, 1, (0,)),
(1, 3, 224, 32, 5, 2, 1, (1,)),
(1, 3, 224, 32, 5, 2, 0, (0,)),
(1, 32, 32, 128, 5, 1, 0, (0,)),
(1, 32, 32, 128, 5, 2, 1, (0,)),
(1, 1, 1024, 1, 512, 1, 256, (0,)),
(1, 1, 1024, 1, 512, 2, 256, (0,)),
(1, 1, 1024, 1, 512, 5, 256, (0,)),
(1, 1, 1024, 1, 512, 5, 256, (3,)),
(1, 2, 1024, 1, 128, 128, 0, (0,)),
(1, 1, 1024, 2, 128, 128, 0, (0,)),
(1, 1, 1024, 2, 2, 2, 0, (0,)),
(1, 1, 10, 1, 5, 1, (0, 3), (0,)),
(1, 1, 10, 1, 5, 1, (1, 3), (0,)),
(1, 1, 10, 1, 5, 1, (2, 3), (0,)),
(1, 257, 128, 1, 512, 128, 256, (0,)),
)

dtype = tvm.testing.parameter("float32")


@tvm.testing.fixture(cache_return_value=True)
def ref_data(
dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding
):
dtype = "float32"
a_shape = (batch, in_channel, in_size)
w_shape = (in_channel, num_filter, kernel)

a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = tvm.topi.testing.conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np


@tvm.testing.known_failing_targets("vulkan")
def test_conv1d_transpose_ncw(
target,
dev,
ref_data,
dtype,
stride,
padding,
output_padding,
):
in_width = in_size
A = te.placeholder((batch, in_channel, in_width), name="A")
W = te.placeholder((in_channel, num_filter, kernel), name="W")

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_conv1d_transpose.verify_conv1d_transpose_ncw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = tvm.topi.testing.conv1d_transpose_ncw_python(
a_np, w_np, stride, padding, output_padding
)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np

a_np, w_np, b_np, c_np = get_ref_data()

def check_target(target, dev):
dev = tvm.device(target, 0)
with tvm.target.Target(target):
fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv1d_transpose_ncw_implement)
B = fcompute(A, W, stride, padding, A.dtype, output_padding)
C = topi.nn.relu(B)
s1 = fschedule([B])
s2 = fschedule([C])
a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)

func1 = tvm.build(s1, [A, W, B], target)
func2 = tvm.build(s2, [A, W, C], target)
func1(a, w, b)
func2(a, w, c)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

for target, dev in tvm.testing.enabled_targets():
check_target(target, dev)


@tvm.testing.uses_gpu
def test_conv1d_transpose_ncw():
verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 1, 0, (0,))
verify_conv1d_transpose_ncw(1, 3, 224, 32, 7, 1, 2, (0,))
verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1, (0,))
verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 1, (1,))
verify_conv1d_transpose_ncw(1, 3, 224, 32, 5, 2, 0, (0,))
verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 1, 0, (0,))
verify_conv1d_transpose_ncw(1, 32, 32, 128, 5, 2, 1, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 1, 256, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,))
verify_conv1d_transpose_ncw(1, 2, 1024, 1, 128, 128, 0, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 2, 128, 128, 0, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 2, 2, 2, 0, (0,))
verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3), (0,))
verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3), (0,))
verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3), (0,))
verify_conv1d_transpose_ncw(1, 257, 128, 1, 512, 128, 256, (0,))

a_np, w_np, b_np, c_np = ref_data

A = te.placeholder(a_np.shape, name="A", dtype=dtype)
W = te.placeholder(w_np.shape, name="W", dtype=dtype)

with tvm.target.Target(target):
fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv1d_transpose_ncw_implement)
B = fcompute(A, W, stride, padding, A.dtype, output_padding)
C = topi.nn.relu(B)
s1 = fschedule([B])
s2 = fschedule([C])
a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)

func1 = tvm.build(s1, [A, W, B], target)
func2 = tvm.build(s2, [A, W, C], target)
func1(a, w, b)
func2(a, w, c)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)


if __name__ == "__main__":
test_conv1d_transpose_ncw()
sys.exit(pytest.main(sys.argv))
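
For reference, the `ref_data` fixture above wraps the NumPy reference implementation in
`tvm.topi.testing`; a standalone call looks roughly like this (a sketch; shapes follow the
first parameter tuple, NCW layout):

import numpy as np

import tvm.topi.testing

a_np = np.random.uniform(size=(1, 3, 224)).astype("float32")  # (batch, in_channel, in_size)
w_np = np.random.uniform(size=(3, 32, 5)).astype("float32")   # (in_channel, num_filter, kernel)

# stride=1, padding=0, output_padding=(0,), matching the first parameter tuple above.
b_np = tvm.topi.testing.conv1d_transpose_ncw_python(a_np, w_np, 1, 0, (0,))
c_np = np.maximum(b_np, 0)  # the relu stage that the test also checks
print(b_np.shape)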
92 changes: 43 additions & 49 deletions tests/python/topi/python/test_topi_conv2d_nhwc.py
@@ -27,8 +27,8 @@


 _conv2d_nhwc_implement = {
-    "llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
-    "cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
+    "generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
+    "gpu": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
     "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
     "arm_cpu": (
         topi.arm_cpu.conv2d_nhwc_spatial_pack,
@@ -45,61 +45,55 @@
"hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc),
}

dtype = tvm.testing.parameter("float32")

def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
in_height = in_width = in_size

A = te.placeholder((batch, in_height, in_width, in_channel), name="A")
W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W")
batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters(
(1, 256, 32, 256, 3, 1, "SAME", 1),
(4, 128, 16, 128, 5, 2, "SAME", 1),
(4, 128, 16, 256, 5, 2, "SAME", 1),
(1, 256, 32, 256, 3, 1, "VALID", 1),
(1, 256, 32, 256, 3, 1, "VALID", 1),
(4, 128, 16, 128, 5, 2, "VALID", 1),
(4, 128, 16, 256, 5, 2, "VALID", 1),
(1, 128, 16, 256, 3, 2, (0, 0, 1, 1), 1),
(1, 128, 16, 256, 3, 2, (1, 1, 2, 2), 1),
(1, 128, 16, 128, 5, 2, (3, 3, 2, 2), 1),
(1, 128, 16, 256, 3, 2, (0, 1, 2, 3), 1),
(1, 256, 32, 256, 3, 1, "SAME", 2),
(1, 256, 32, 256, 3, 1, (1, 1, 2, 2), 2),
)

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_conv2d_nhwc.verify_nhwc.v2")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
return a_np, w_np, b_np
@tvm.testing.fixture(cache_return_value=True)
def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation):
in_height = in_width = in_size
a_shape = (batch, in_height, in_width, in_channel)
w_shape = (kernel, kernel, in_channel, num_filter)

a_np, w_np, b_np = get_ref_data()
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
return a_np, w_np, b_np

def check_device(target, dev):
print("Running on target: %s" % target)
with tvm.target.Target(target):
fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv2d_nhwc_implement)
B = fcompute(A, W, stride, padding, dilation, dtype)
s = fschedule([B])
a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
func = tvm.build(s, [A, W, B], target)
func(a, w, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)

for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)
def test_conv2d_nhwc(target, dev, ref_data, dtype, stride, padding, dilation):
a_np, w_np, b_np = ref_data

A = te.placeholder(a_np.shape, name="A", dtype=dtype)
W = te.placeholder(w_np.shape, name="W", dtype=dtype)

@tvm.testing.uses_gpu
def test_conv2d_nhwc():
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "SAME")
verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "SAME")
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID")
verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 0, 1, 1))
verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (1, 1, 2, 2))
verify_conv2d_nhwc(1, 128, 16, 128, 5, 2, (3, 3, 2, 2))
verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 1, 2, 3))
# dilation = 2
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, (1, 1, 2, 2), dilation=2)
with tvm.target.Target(target):
fcompute, fschedule = tvm.topi.testing.dispatch(target, _conv2d_nhwc_implement)
B = fcompute(A, W, stride, padding, dilation, dtype)
s = fschedule([B])
a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
func = tvm.build(s, [A, W, B], target)
func(a, w, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)


if __name__ == "__main__":
test_conv2d_nhwc()
sys.exit(pytest.main(sys.argv))
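
The key change to the implement table above is replacing the "llvm"/"cuda" entries with
"generic"/"gpu", so `tvm.topi.testing.dispatch` can resolve a compute/schedule pair for any
enabled target (vulkan included) rather than only LLVM and CUDA.  A minimal sketch of that
dispatch, assuming an LLVM-enabled build:

import tvm
import tvm.topi.testing
from tvm import te, topi

_impl = {
    "generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
    "gpu": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
}

target = "llvm"
with tvm.target.Target(target):
    # dispatch() matches the target's keys against the table and falls back
    # to "generic", so every enabled target resolves to some implementation.
    fcompute, fschedule = tvm.topi.testing.dispatch(target, _impl)
    A = te.placeholder((1, 16, 16, 8), name="A")
    W = te.placeholder((3, 3, 8, 8), name="W")
    B = fcompute(A, W, 1, "SAME", 1, "float32")
    s = fschedule([B])
    func = tvm.build(s, [A, W, B], target)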
