[NVPTX] libdevice support, enable NVPTX backend in topi tests (apache…
masahi authored and sergei-mironov committed Aug 8, 2018
1 parent 1ee51b6 commit cea2b43
Showing 22 changed files with 121 additions and 32 deletions.
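This commit links the NVVM libdevice bitcode into modules built for the nvptx target, adds intrinsic lowering rules that dispatch floor, ceil, round, trunc, exp, fma, log, sqrt, pow, and tanh to their __nv_* libdevice counterparts, and enables the nvptx backend across the topi tests.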
2 changes: 1 addition & 1 deletion apps/benchmark/gpu_imagenet_bench.py
@@ -25,7 +25,7 @@ def main():
                         choices=['resnet', 'mobilenet'],
                         help="The model type.")
     parser.add_argument('--target', type=str, required=True,
-                        choices=['cuda', 'rocm', 'opencl', 'metal'],
+                        choices=['cuda', 'rocm', 'opencl', 'metal', 'nvptx'],
                         help="Compilation target.")
     parser.add_argument('--opt-level', type=int, default=1, help="Level of optimization.")
     parser.add_argument('--num-iter', type=int, default=1000, help="Number of iteration during benchmark.")
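With this change the benchmark accepts nvptx as a compilation target, e.g. --model resnet --target nvptx.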
17 changes: 15 additions & 2 deletions src/codegen/llvm/codegen_nvptx.cc
@@ -121,6 +121,20 @@ class CodeGenNVPTX : public CodeGenLLVM {
     // Additional optimization hook to tweak the builder.
   }
 
+  void Optimize() final {
+    for (auto& f : *module_) {
+      auto fname = static_cast<std::string>(f.getName());
+      if (fname.substr(0, 4) != "__nv") continue;
+      // Strip unused __nv_* libdevice functions: marking a definition
+      // available_externally lets LLVM inline it at call sites and then
+      // discard the unreferenced body. Adapted from Halide's runtime linker.
+      if (!f.isDeclaration() && !f.hasFnAttribute(llvm::Attribute::NoInline)) {
+        f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
+      }
+    }
+    CodeGenLLVM::Optimize();
+  }
+
 protected:
   void InitTarget(llvm::TargetMachine* tm) final {
     // Maximum vector lane = float4
@@ -179,8 +193,7 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
       }
       mlib->setTargetTriple(tm->getTargetTriple().str());
       mlib->setDataLayout(tm->createDataLayout());
-      // TODO(tqchen) libdevice linking not yet working.
-      // cg->AddLinkModule(std::move(mlib));
+      cg->AddLinkModule(std::move(mlib));
     }
   }
   std::unique_ptr<llvm::Module> module = cg->Finish();
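With the TODO resolved, BuildNVPTX now links the libdevice module into the generated LLVM module; the Optimize() override above then drops whichever __nv_* bodies remain unreferenced after inlining.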
64 changes: 64 additions & 0 deletions src/codegen/llvm/intrin_rule_nvptx.cc
@@ -0,0 +1,64 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file intrin_rule_nvptx.cc
+ */
+#ifdef TVM_LLVM_VERSION
+
+#include <tvm/ir.h>
+#include <tvm/expr.h>
+#include <tvm/api_registry.h>
+#include <sstream>
+
+namespace tvm {
+namespace codegen {
+
+inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
+  Expr e = args[0];
+  using namespace ir;
+  const Call* call = e.as<Call>();
+  CHECK(call != nullptr);
+  CHECK(call->type.bits() == 32 || call->type.bits() == 64) << "Only support float32 or float64.";
+  std::ostringstream intrinsic_name;
+  intrinsic_name << "__nv_" << call->name;
+  if (call->type.bits() == 32) intrinsic_name << "f";
+  *rv = Call::make(call->type, intrinsic_name.str(), call->args,
+                   Call::PureExtern);
+}
+
+namespace llvm {
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
+.set_body(DispatchExternLibDevice);
+
+}  // namespace llvm
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_LLVM_VERSION
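The rules above map TVM math intrinsics onto libdevice, appending an f suffix for float32 (exp lowers to __nv_expf on float32 and to __nv_exp on float64). A minimal sketch of exercising this path, assuming a CUDA-capable GPU and an LLVM-enabled TVM build at this revision (the schedule and the myexp name are illustrative, not part of the commit):

import numpy as np
import tvm

n = 1024
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute((n,), lambda i: tvm.exp(A[i]), name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

# Building for "nvptx" fires the tvm.intrin.rule.nvptx.exp rule above,
# which rewrites exp into a PureExtern call to __nv_expf; the libdevice
# bitcode linked in BuildNVPTX provides the definition.
fexp = tvm.build(s, [A, B], "nvptx", name="myexp")

ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fexp(a, b)
np.testing.assert_allclose(b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)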
2 changes: 1 addition & 1 deletion src/codegen/llvm/intrin_rule_rocm.cc
@@ -1,6 +1,6 @@
 /*!
  * Copyright (c) 2017 by Contributors
- * \file intrin_rule_llvm.cc
+ * \file intrin_rule_rocm.cc
  */
 #ifdef TVM_LLVM_VERSION
 
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/depthwise_conv2d.py
@@ -158,7 +158,7 @@ def _schedule(temp, Filter, DepthwiseConv2d):
         # num_thread here could be 728, it is larger than cuda.max_num_threads
         num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
         target = tvm.target.current_target()
-        if target and target.target_name != "cuda":
+        if target and (target.target_name not in ["cuda", "nvptx"]):
             num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
         s[Output].reorder(xoc, b, h, w, xic)
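As with cuda, the nvptx target now keeps the thread count chosen by the schedule here rather than clamping it to target.max_num_threads.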
2 changes: 2 additions & 0 deletions topi/tests/python/test_topi_broadcast.py
@@ -30,6 +30,7 @@ def check_device(device):
     check_device("cuda")
     check_device("metal")
     check_device("rocm")
+    check_device("nvptx")
 
 
 def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
@@ -85,6 +86,7 @@ def check_device(device):
     check_device("cuda")
     check_device("metal")
     check_device("rocm")
+    check_device("nvptx")
 
 def test_broadcast_to():
     verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_conv2d_hwcn.py
@@ -52,7 +52,7 @@ def check_device(device):
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 
5 changes: 3 additions & 2 deletions topi/tests/python/test_topi_conv2d_nchw.py
@@ -44,16 +44,17 @@ def check_device(device):
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        no_unroll_explicit = device in ["cuda", "nvptx", "rocm"]
         with tvm.build_config(auto_unroll_max_step=1400,
-                              unroll_explicit=(device != "cuda")):
+                              unroll_explicit=not no_unroll_explicit):
             func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
             func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
             func1(a, w, b)
             func2(a, w, c)
             np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
             np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 
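The rewrite generalizes the old device != "cuda" check: cuda, nvptx, and rocm now all leave unrolling to the downstream compiler (unroll_explicit=False), while the remaining backends keep unrolling explicitly in the generated code.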
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_conv2d_transpose_nchw.py
@@ -51,7 +51,7 @@ def check_device(device):
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_dense.py
@@ -45,7 +45,7 @@ def check_device(device):
         f(a, b, c, d)
         np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_dense():
2 changes: 2 additions & 0 deletions topi/tests/python/test_topi_depthwise_conv2d.py
@@ -93,6 +93,7 @@ def get_ref_data():
     check_device("metal")
     check_device("rocm")
     check_device("vulkan")
+    check_device("nvptx")
 
 
 def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
@@ -184,6 +185,7 @@ def get_ref_data():
     check_device("metal")
     check_device("rocm")
     check_device("vulkan")
+    check_device("nvptx")
 
 def test_depthwise_conv2d():
     print("testing nchw")
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_depthwise_conv2d_back_input.py
@@ -87,6 +87,7 @@ def get_ref_data():
     check_device("metal")
     check_device("rocm")
     check_device("vulkan")
+    check_device("nvptx")
 
 def test_topi_depthwise_conv2d_backward_input_nhwc():
     verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_depthwise_conv2d_back_weight.py
@@ -80,6 +80,7 @@ def get_ref_data():
     check_device("metal")
     check_device("rocm")
     check_device("vulkan")
+    check_device("nvptx")
 
 def test_topi_depthwise_conv2d_backward_weight_nhwc():
     verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_l2norm.py
@@ -31,7 +31,7 @@ def check_device(device):
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_l2_normalize():
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_lrn.py
@@ -30,7 +30,7 @@ def check_device(device):
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_lrn():
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_math.py
@@ -39,7 +39,7 @@ def check_device(device):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx']:
         check_device(device)
 
 
4 changes: 2 additions & 2 deletions topi/tests/python/test_topi_pooling.py
@@ -63,7 +63,7 @@ def check_device(device):
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_pool():
@@ -104,7 +104,7 @@ def check_device(device):
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_global_pool():
18 changes: 11 additions & 7 deletions topi/tests/python/test_topi_reduce.py
@@ -25,12 +25,11 @@ def _my_npy_argmin(arr, axis, keepdims):
     return arr.argmin(axis=axis).reshape(out_shape)
 
 
-def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
+def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32"):
     # Build the logic and compile the function
-    dat_dtype = "float32"
-    A = tvm.placeholder(shape=in_shape, name="A", dtype=dat_dtype)
+    A = tvm.placeholder(shape=in_shape, name="A", dtype=dtype)
     A1 = topi.sqrt(topi.exp(A))
-    out_dtype = "float32"
+    out_dtype = dtype
     if type == "sum":
         B = topi.sum(A1, axis=axis, keepdims=keepdims)
     elif type == "max":
@@ -57,8 +56,8 @@ def check_device(device):
 
         foo = tvm.build(s, [A, B], device, name=type)
         # Test
-        in_npy = np.random.uniform(size=in_shape).astype(np.float32)
-        in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
+        in_npy = np.random.uniform(size=in_shape).astype(dtype)
+        in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
         if type == "sum":
             out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
         elif type == "max":
@@ -91,7 +90,7 @@ def check_device(device):
             np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
         else:
             np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
-    for device in ["cuda", "opencl", "metal", "llvm", "rocm", "vulkan"]:
+    for device in ["cuda", "opencl", "metal", "llvm", "rocm", "vulkan", "nvptx"]:
        check_device(device)
 
 
@@ -128,6 +127,11 @@ def test_reduce_map():
                           axis=None,
                           keepdims=False,
                           type="sum")
+    verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
+                          axis=(1, 2, 3),
+                          keepdims=True,
+                          type="sum",
+                          dtype="float64")
 
 if __name__ == "__main__":
     test_reduce_map()
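The added float64 reduction exercises the 64-bit dispatch in DispatchExternLibDevice: for 64-bit types no f suffix is appended, so sqrt and exp lower to __nv_sqrt and __nv_exp.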
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_relu.py
@@ -27,7 +27,7 @@ def check_device(device):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_resize.py
@@ -40,7 +40,7 @@ def check_device(device):
 
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
 
-    for device in ['llvm', 'cuda', 'vulkan']:
+    for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_resize():
15 changes: 8 additions & 7 deletions topi/tests/python/test_topi_softmax.py
@@ -7,8 +7,8 @@
 import logging
 from topi.util import get_const_tuple
 
-def verify_softmax(m, n):
-    A = tvm.placeholder((m, n), name='A')
+def verify_softmax(m, n, dtype="float32"):
+    A = tvm.placeholder((m, n), dtype=dtype, name='A')
     B = topi.nn.softmax(A)
     # confirm lower works
     s = tvm.create_schedule([B.op])
@@ -32,16 +32,16 @@ def check_device(device):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_softmax():
     verify_softmax(32, 10)
     verify_softmax(3, 4)
+    verify_softmax(32, 10, "float64")
 
 
-def verify_log_softmax(m, n):
-    A = tvm.placeholder((m, n), name='A')
+def verify_log_softmax(m, n, dtype="float32"):
+    A = tvm.placeholder((m, n), dtype=dtype, name='A')
     B = topi.nn.log_softmax(A)
     # confirm lower works
     s = tvm.create_schedule([B.op])
@@ -63,13 +63,14 @@ def check_device(device):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ["cuda", "opencl", "metal", "rocm", "vulkan"]:
+    for device in ["cuda", "opencl", "metal", "rocm", "vulkan", "nvptx"]:
         check_device(device)
 
 
 def test_log_softmax():
     verify_log_softmax(32, 10)
     verify_log_softmax(3, 4)
+    verify_log_softmax(32, 10, "float64")
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_upsampling.py
@@ -41,7 +41,7 @@ def check_device(device):
 
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'cuda', 'vulkan']:
+    for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
         check_device(device)
 
 def test_upsampling():
