Skip to content

Commit

Permalink
squash
Browse files Browse the repository at this point in the history
merge with 38f7c55

compiles on GPU

update check alloc:

Checkpoint. Pass elem-sum gpu test

bug fix for copyfromto. sparse sgd test pass on gpu

inefficient implementation for csr copy

update submodule

fix lint

Simple bind with infer storage type (apache#32)

* Symbol binding for sparse tensor development. (apache#31)

* Initial checkin

* Add init functions for simple bind in graph_executor

* Add simple_bind c_api

* Add simple bind c-api

* Assign zeros to in_args, arg_grads, and aux_states

* Add simple_bind2 python interface

* Fix python interface bugs

* Interface changes

* Fix

* Fix core dump

* Add bind_ith_exec c_api

* Change simple_bind2

* Fix seg fault

* Finish simple_bind

* Change _bind_ith_exec

* Refactor simple_bind initialization flow for bind

* Consolidate bind and simple_bind graph init flow

* Fix bug

* Clean up

* Add comments

* Clean up

* Clean up

* Minor correction

* Rename APIs in graph executor

* Refactor

* Rebase

* Delete deprecated functions

* Move more front-end work to backend

* Bug fix

* Fix failed tests

* Minor fix

* Fix lint

* Fix lint

* Revert unnecessary changes

* Revert

* Revert

* Clean up

* Fix lint

Conflicts:
	python/mxnet/symbol.py
	src/executor/graph_executor.cc

* Add inferstorage to graph executor

* re-enable tests for sparse embedding with simple_bind

* type switch fix in sparse embedding"
;

change `default` to `default_storage` for cast storage op (apache#33)

* change default to default_storage

* disable cpp test build temporarily

attempt to fix windows build error, and fix lint (apache#34)

update nnvm submodule (apache#37)

Scipy build (apache#38)

* update nnvm submodule

* add scipy pip install for dockerfile

Python3 unit tests (apache#39)

* change xrange to range for python3 compatiblity"

* remove more xrange from tests

replace long with int for python3 (apache#40)

fix the rest of TShape constructor errors (apache#41)

fix lint (apache#42)

fix wrong usage of mshadow::Shape1" (apache#43)

implementation for Csr slice on cpu (apache#36)

* CPU implementation for CSR

remove seg_len from csr slice

add some docs for slice csr

change indptr, values, etc to be private member

bug fix in sparse embedding

update nnvm submoduel

fix lint

update unit test for sparse nd"

* add const for SliceCsrIndPtr kernel

Fix sparse dot according to the new RSP definition (apache#35)

* Fix csr dot dns

* Fix sparse dot

* Add fallback and test cases for dot(csr, dns)=dns

* Add int type switch

* Fix

* Fix

* Fix

update mshadow submodule (apache#44)

Fix dns to rsp (apache#46)

fix lint (apache#47)

add runtime storage fallback detection" (apache#48)

* add runtime storage fallback detection"

* replace cast storage ex with cast storage impl

Fm example (apache#45)

* update csr slice logic to avoid confusion. add more exmaples.

* add hint to module.update

* more testcases(fallback) for sparse_nd

* add to_csr() and to_rsp() method. More unit test (fallback now)

* add fm test. fix lint

* register sparse sgd under Optim.SGD

* update dmlc-core submoduel

* change indptr to _indptr temporarily. add const ref to fname

fix lint

fix lint; (apache#51)

Guard gpu cast storage (apache#50)

* Clean up

* Fix typo

Rearrange unit test files (apache#52)

fix lint. add scipy for python_test. fix scipy.sparse import error. fix truediv for python3

fix travis test (apache#54)

* remove pyc files

* add verbose for travis nosetests

cleanup some testing code and enums (apache#57)

* update Makefile

* refactor test_sparse_operator

* change `default_storage` back to `default`

* remove unused cpp tests

port libsvm parser to mxnet as libsvm iter (apache#55)

* copied csv iter to libsvm iter

test

libsvm iter draft

handle round batch == false for csr batch loader

code refactoring

add get stype, shape interface to iiter

separate class for sparse iter

add missing file

fix mem corruption'

rename variables

add comments

also read label from libsvm

add test. update docs. update submodule

Conflicts:
	python/mxnet/sparse_ndarray.py

* update submodule

* fix lint

* update test

* revert naming change

add benchmark scritp for dot (apache#59)

* add benchmark scritp for dot

add gpu option for bench

add get_data funciton for benchmark

print t_sparse, too;

add comment

change nnz to dnesity

add backward

* add comment

update fm test (apache#62)

introduce CSRNDarray and rowsparseNDarray to python frontend api (apache#58)

* introduce CSRNDarray and rowsparseNDarray to python frontend api

* temporarily disable fm_module test

fix lint (apache#64)

fix typo. disable libsvm io test (apache#65)

Improve dot (apache#61)

* Init checkin

* Fix

* Adjust dot parallelization methods

* Set num_omp_threads for benchmark from command line

* Fix omp thread number

* Clean up

* Add scipy as dot baseline

* Fix format

sparse_retain op (apache#66)

* Initial checkin

* Fix bugs

* Add unit test for sparse_retain

* Add example and modify test

add storage cast for outputs that have non-default storage (apache#67)

fix gpu build (apache#69)

Fix test_sparse_retain python3 issue (apache#68)

revert nnvm version
  • Loading branch information
eric-haibin-lin committed Jun 10, 2017
1 parent d75ef8e commit f98912b
Show file tree
Hide file tree
Showing 82 changed files with 5,910 additions and 375 deletions.
6 changes: 3 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,17 @@ del /Q *.7z
// Python unittest for CPU
def python_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train"
}
}

// GPU test has two parts. 1) run unittest on GPU, 2) compare the results on
// both CPU and GPU
def python_gpu_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu"
}
}
Expand Down
191 changes: 191 additions & 0 deletions benchmark/python/sparse_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import ctypes

from mxnet.test_utils import *
import scipy.sparse as sp
import os
import time
import argparse

from mxnet.base import check_call, _LIB

parser = argparse.ArgumentParser(description="Benchmark sparse operators",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
args = parser.parse_args()


def get_avazu(data_dir):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists('avazu-app.t')):
import urllib
zippath = os.path.join(data_dir, "avazu-app.t.bz2")
url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2"
urllib.urlretrieve(url, zippath)
# decompress
os.system("bzip2 -d avazu-app.t.bz2")
os.chdir("..")


def test_dot_real():
def get_iter(path, data_shape, batch_size):
data_train = mx.io.LibSVMIter(data_libsvm=path,
data_shape=data_shape,
batch_size=batch_size)
data_iter = iter(data_train)
return data_iter
data_dir = os.path.join(os.getcwd(), 'data')
get_avazu(data_dir)
path = os.path.join(data_dir, 'avazu-app.t')
# TODO(haibin) get file size automatically
size = 336490781 >> 20

# model
batch_size = 512
feature_dim = 1000000
data_shape = (feature_dim, )
train_iter = get_iter(path, data_shape, batch_size)

k = 500
weight = mx.nd.random_uniform(low=0, high=1, shape=(feature_dim, k))
weight.wait_to_read()

# start workload
start = time.time()
results = []
num_batch = 0
for batch in train_iter:
data = train_iter.getdata()
results.append(mx.nd.dot(data, weight))
num_batch += 1
for result in results:
result.wait_to_read()

end = time.time()
cost = end - start
print(size / cost, cost, num_batch, num_batch / cost)


def test_dot_synthetic():
"""benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density.
`t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost
of dot(dns, dns), with the same matrix except that it is in default storage type.
"""
def measure_cost_forward_baseline(repeat, dot, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(lhs, rhs)
end = time.time()
diff = end - start
return diff / repeat

def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(transpose(lhs), rhs)
end = time.time()
diff = end -start
return diff / repeat

def measure_cost(repeat, f, *args, **kwargs):
# start bench
start = time.time()
results = []
for i in range(repeat):
results.append(f(*args, **kwargs))
for result in results:
result.wait_to_read()
end = time.time()
diff = end - start
return diff / repeat

def bench_dot_forward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.to_dense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns)
costs.append(cost / repeat)
ratio = costs[1] / costs[0]

costs_baseline = []
cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[1] / costs_baseline[0]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f"
print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio,
costs_baseline[1], costs_baseline[0], ratio_baseline))

def bench_dot_backward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.to_dense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True)
costs.append(cost)
ratio = costs[1] / costs[0]

costs_baseline = []
cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[1] / costs_baseline[0]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f"
print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio,
costs_baseline[1], costs_baseline[0], ratio_baseline))

print("A = sparse NDArray of shape(m, k)")
print("B = dense NDArray of shape(k, n)")
print("dot_forward\tdot(csr, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense'
'\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense')

check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
# TODO(haibin) make these runtime options
m = 512
k = [50000, 100000]
n = [50, 100]
density = [0.05, 0.02, 0.01, 0.005, 0.001]
num_repeat = 10
# contexts = [mx.cpu(), mx.gpu(0)]
contexts = [mx.cpu()]
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat)

print("dot_backward\tdot(csr.T, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense'
'\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense')
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat)

if __name__ == "__main__":
test_dot_real()
test_dot_synthetic()
94 changes: 94 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
int delay_alloc,
int dtype,
NDArrayHandle *out);


/*!
* \brief create an empty sparse NDArray with specified shape and data type
* \param storage_type the storage type of the ndarray
* \param shape the pointer to the shape
* \param ndim the dimension of the shape
* \param dev_type device type, specify device we want to take
* \param dev_id the device id of the specific device
* \param delay_alloc whether to delay allocation until
* the narray is first mutated
* \param dtype data type of created array
* \param num_aux the number of aux data to support this ndarray
* \param aux_type data type of the aux data for the created array
* \param aux_ndims the dimension of the shapes of aux data
* \param aux_shape the shapes of aux data
* \param out the returning handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *shape,
mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
* \param buf the head of the raw bytes
Expand Down Expand Up @@ -358,6 +390,19 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle *out);

/*!
* \brief Slice the NDArray with non-default storage along axis 0.
* \param handle the handle to the NDArray
* \param slice_begin The beginning index of slice
* \param slice_end The ending index of slice
* \param out The NDArrayHandle of sliced NDArray
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySliceEx(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle out);
/*!
* \brief Index the NDArray along axis 0.
* \param handle the handle to the NDArray
Expand All @@ -368,6 +413,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
NDArrayHandle *out);

/*!
* \brief get the storage type of the array
*/
MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
int *out_storage_type);

/*!
* \brief Reshape the NDArray.
* \param handle the handle to the narray
Expand Down Expand Up @@ -406,6 +458,26 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
int *out_dtype);

/*!
* \brief get the type of the ith aux data in NDArray
* \param handle the handle to the narray
* \param i the index of the aux data
* \param out_type pointer holder to get type of aux data
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
mx_uint i,
int *out_type);

// Get the ith aux data blob wrapped in an NDArray
MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
mx_uint i,
NDArrayHandle *out);

// Get the data blob wrapped in an NDArray
MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle,
NDArrayHandle *out);
/*!
* \brief get the context of the NDArray
* \param handle the handle to the narray
Expand Down Expand Up @@ -1003,6 +1075,25 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete);




/*!
* \brief infer storage type of unknown input types given the known one.
*/
MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_storage_type_data,
mx_uint *in_storage_type_size,
const int **in_storage_type_data,
mx_uint *out_storage_type_size,
const int **out_storage_type_data,
mx_uint *aux_storage_type_size,
const int **aux_storage_type_data,
int *complete);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
Expand Down Expand Up @@ -1167,6 +1258,9 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
Expand Down
1 change: 1 addition & 0 deletions include/mxnet/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Executor {
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
Expand Down
13 changes: 13 additions & 0 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ class IIterator : public dmlc::DataIter<DType> {
}
}; // class IIterator

/*!
* \brief iterator type
* \param DType data type
*/
template<typename DType>
class SparseIIterator : public IIterator<DType> {
public:
/*! \brief storage type of the data or label */
virtual const NDArrayStorageType GetStorageType(bool is_data) const = 0;
/*! \brief shape of the data or label */
virtual const TShape GetShape(bool is_data) const = 0;
}; // class SparseIIterator

/*! \brief a single data instance */
struct DataInst {
/*! \brief unique id for instance */
Expand Down
Loading

0 comments on commit f98912b

Please sign in to comment.