Skip to content

Commit

Permalink
[topi] block sparse dense on cuda (apache#5746)
Browse files Browse the repository at this point in the history
  • Loading branch information
ceruleangu authored and Trevor Morris committed Jun 12, 2020
1 parent f12956b commit e7a3b38
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 21 deletions.
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
Expand Up @@ -50,3 +50,4 @@
from .conv3d_ndhwc_tensorcore import *
from .dense_tensorcore import *
from .correlation import *
from .sparse import *
94 changes: 94 additions & 0 deletions topi/python/topi/cuda/sparse.py
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Sparse operators"""
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from ..util import traverse_inline
from .. import nn


@autotvm.register_topi_compute("sparse_dense.cuda")
def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Parameters
----------
cfg: ConfigEntity
The config for this template
data : tvm.te.Tensor
2-D with shape [M, K], float32
weight_data : tvm.te.Tensor
1-D with shape [nnz] (CSR) or
3-D with shape [num_blocks, bs_r, bs_c] (BSR)
weight_indices : tvm.te.Tensor
1-D with shape [nnz] (CSR) or
1-D with shape [num_blocks] (BSR)
weight_indptr : tvm.te.Tensor
1-D with shape [N + 1] (CSR) or
1-D with shape [(N + 1) // bs_r] (BSR)
Returns
-------
output : tvm.te.Tensor
2-D with shape [M, N]
"""
# pylint:disable=unused-argument
return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr)


@autotvm.register_topi_schedule("sparse_dense.cuda")
def schedule_sparse_dense(cfg, outs):
"""Create schedule for sparse dense"""
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "sparse_dense_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
out = s.outputs[0].output(0)
(_, c) = s[y_bsrmm].op.reduce_axis

(m_o, n_o) = s[out].op.axis
s[out].bind(m_o, te.thread_axis("blockIdx.x"))
s[out].bind(n_o, te.thread_axis("blockIdx.y"))
s[y_bsrmm].compute_at(s[out], n_o)

thread_x = te.thread_axis("threadIdx.x")

cfg.define_split("tile_c", c, num_outputs=2)
if cfg.is_fallback:
cfg["tile_c"] = SplitEntity([-1, 8])
_, ci = cfg['tile_c'].apply(s, y_bsrmm, c)

y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
tx = s[y_bsrmm].op.reduce_axis[0]
s[y_bsrmm].bind(tx, thread_x)
s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
s[out].set_store_predicate(thread_x.var.equal(0))

traverse_inline(s, outs[0].op, _callback)
return s
2 changes: 1 addition & 1 deletion topi/python/topi/nn/sparse.py
Expand Up @@ -30,7 +30,7 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
Parameters
----------
x : tvm.te.Tensor
data : tvm.te.Tensor
2-D with shape [M, K], float32
weight_data : tvm.te.Tensor
Expand Down
70 changes: 50 additions & 20 deletions topi/tests/python/test_topi_sparse.py
Expand Up @@ -26,6 +26,12 @@
import time
import scipy.sparse as sp

_sparse_dense_implement = {
"generic": (topi.nn.sparse_dense, topi.generic.schedule_sparse_dense),
"cuda": (topi.cuda.sparse_dense, topi.cuda.schedule_sparse_dense),
"x86": (topi.nn.sparse_dense, topi.x86.schedule_sparse_dense)
}

def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
dtype = 'float32'
Expand Down Expand Up @@ -293,16 +299,28 @@ def test_sparse_dense_bsr():
W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = te.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np),
tvm.nd.array(W_sp_np.data),
tvm.nd.array(W_sp_np.indices),
tvm.nd.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement)
with tvm.target.create(device):
Y = fcompute(X, W_data, W_indices, W_indptr)
s = fschedule([Y])
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
func(tvm.nd.array(X_np, ctx=ctx),
tvm.nd.array(W_sp_np.data, ctx=ctx),
tvm.nd.array(W_sp_np.indices, ctx=ctx),
tvm.nd.array(W_sp_np.indptr, ctx=ctx),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)

for device in ['llvm', 'cuda']:
check_device(device)

def test_sparse_dense_bsr_randomized():
for _ in range(20):
Expand All @@ -322,16 +340,28 @@ def test_sparse_dense_bsr_randomized():
W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = te.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.nd.array(X_np),
tvm.nd.array(W_sp_np.data),
tvm.nd.array(W_sp_np.indices),
tvm.nd.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement)
with tvm.target.create(device):
Y = fcompute(X, W_data, W_indices, W_indptr)
s = fschedule([Y])
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx)
func(tvm.nd.array(X_np, ctx=ctx),
tvm.nd.array(W_sp_np.data, ctx=ctx),
tvm.nd.array(W_sp_np.indices, ctx=ctx),
tvm.nd.array(W_sp_np.indptr, ctx=ctx),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)

for device in ['llvm', 'cuda']:
check_device(device)


def test_sparse_dense():
Expand Down

0 comments on commit e7a3b38

Please sign in to comment.