diff --git a/Jenkinsfile b/Jenkinsfile index df39672c5ed2..e41cb7217de4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -215,9 +215,9 @@ del /Q *.7z // Python unittest for CPU def python_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest" - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train" } } @@ -225,7 +225,7 @@ def python_ut(docker_type) { // both CPU and GPU def python_gpu_ut(docker_type) { timeout(time: max_time, unit: 'MINUTES') { - sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu" + sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu" sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu" } } diff --git a/benchmark/python/sparse_op.py b/benchmark/python/sparse_op.py new file mode 100644 index 000000000000..0aef3bc3ae31 --- /dev/null +++ b/benchmark/python/sparse_op.py @@ -0,0 +1,191 @@ +import ctypes + +from mxnet.test_utils import * +import scipy.sparse as sp +import os +import time +import argparse + +from mxnet.base import check_call, _LIB + +parser = argparse.ArgumentParser(description="Benchmark sparse operators", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet') +args = parser.parse_args() + + +def get_avazu(data_dir): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists('avazu-app.t')): + import urllib + zippath = os.path.join(data_dir, "avazu-app.t.bz2") + url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2" + urllib.urlretrieve(url, zippath) + # decompress + os.system("bzip2 -d avazu-app.t.bz2") + os.chdir("..") + + +def test_dot_real(): + def get_iter(path, data_shape, batch_size): + data_train = mx.io.LibSVMIter(data_libsvm=path, + data_shape=data_shape, + batch_size=batch_size) + data_iter = iter(data_train) + return data_iter + data_dir = os.path.join(os.getcwd(), 'data') + get_avazu(data_dir) + path = os.path.join(data_dir, 'avazu-app.t') + # TODO(haibin) get file size automatically + size = 336490781 >> 20 + + # model + batch_size = 512 + feature_dim = 1000000 + data_shape = (feature_dim, ) + train_iter = get_iter(path, data_shape, batch_size) + + k = 500 + weight = mx.nd.random_uniform(low=0, high=1, shape=(feature_dim, k)) + weight.wait_to_read() + + # start workload + start = time.time() + results = [] + num_batch = 0 + for batch in train_iter: + data = train_iter.getdata() + results.append(mx.nd.dot(data, weight)) + num_batch += 1 + for result in results: + result.wait_to_read() + + end = time.time() + cost = end - start + print(size / cost, cost, num_batch, num_batch / cost) + + +def test_dot_synthetic(): + """benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density. + `t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost + of dot(dns, dns), with the same matrix except that it is in default storage type. + """ + def measure_cost_forward_baseline(repeat, dot, lhs, rhs): + start = time.time() + for i in range(repeat): + dot(lhs, rhs) + end = time.time() + diff = end - start + return diff / repeat + + def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs): + start = time.time() + for i in range(repeat): + dot(transpose(lhs), rhs) + end = time.time() + diff = end -start + return diff / repeat + + def measure_cost(repeat, f, *args, **kwargs): + # start bench + start = time.time() + results = [] + for i in range(repeat): + results.append(f(*args, **kwargs)) + for result in results: + result.wait_to_read() + end = time.time() + diff = end - start + return diff / repeat + + def bench_dot_forward(m, k, n, density, ctx, repeat): + set_default_context(ctx) + dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx) + data_shape = (m, k) + csr_data = rand_ndarray(data_shape, 'csr', density) + dns_data = csr_data.to_dense() + rhs_dns_np = dns.asnumpy() + lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy + lhs_dns_np = lhs_csr_sp.todense() + + data = [dns_data, csr_data] + costs = [] + for d in data: + dns.wait_to_read() + d.wait_to_read() + cost = measure_cost(repeat, mx.nd.dot, d, dns) + costs.append(cost / repeat) + ratio = costs[1] / costs[0] + + costs_baseline = [] + cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np) + costs_baseline.append(cost) + cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np) + costs_baseline.append(cost) + ratio_baseline = costs_baseline[1] / costs_baseline[0] + fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f" + print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio, + costs_baseline[1], costs_baseline[0], ratio_baseline)) + + def bench_dot_backward(m, k, n, density, ctx, repeat): + set_default_context(ctx) + dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx) + data_shape = (m, k) + csr_data = rand_ndarray(data_shape, 'csr', density) + dns_data = csr_data.to_dense() + rhs_dns_np = dns.asnumpy() + lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) + lhs_dns_np = lhs_csr_sp.todense() + + data = [dns_data, csr_data] + costs = [] + for d in data: + dns.wait_to_read() + d.wait_to_read() + cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True) + costs.append(cost) + ratio = costs[1] / costs[0] + + costs_baseline = [] + cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np) + costs_baseline.append(cost) + cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np) + costs_baseline.append(cost) + ratio_baseline = costs_baseline[1] / costs_baseline[0] + fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f" + print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio, + costs_baseline[1], costs_baseline[0], ratio_baseline)) + + print("A = sparse NDArray of shape(m, k)") + print("B = dense NDArray of shape(k, n)") + print("dot_forward\tdot(csr, dns)") + print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense' + '\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense') + + check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads))) + # TODO(haibin) make these runtime options + m = 512 + k = [50000, 100000] + n = [50, 100] + density = [0.05, 0.02, 0.01, 0.005, 0.001] + num_repeat = 10 + # contexts = [mx.cpu(), mx.gpu(0)] + contexts = [mx.cpu()] + for i in range(2): + for ctx in contexts: + for den in density: + bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat) + + print("dot_backward\tdot(csr.T, dns)") + print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense' + '\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense') + for i in range(2): + for ctx in contexts: + for den in density: + bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat) + +if __name__ == "__main__": + test_dot_real() + test_dot_synthetic() diff --git a/dmlc-core b/dmlc-core index a6c5701219e6..3f919c0d850c 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit a6c5701219e635fea808d264aefc5b03c3aec314 +Subproject commit 3f919c0d850cab959aada246dcf305c9b6ab5a7d diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 90270f776456..da08a8ff76f3 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -246,6 +246,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, int delay_alloc, int dtype, NDArrayHandle *out); + + +/*! + * \brief create an empty sparse NDArray with specified shape and data type + * \param storage_type the storage type of the ndarray + * \param shape the pointer to the shape + * \param ndim the dimension of the shape + * \param dev_type device type, specify device we want to take + * \param dev_id the device id of the specific device + * \param delay_alloc whether to delay allocation until + * the narray is first mutated + * \param dtype data type of created array + * \param num_aux the number of aux data to support this ndarray + * \param aux_type data type of the aux data for the created array + * \param aux_ndims the dimension of the shapes of aux data + * \param aux_shape the shapes of aux data + * \param out the returning handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out); + /*! * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes @@ -358,6 +390,19 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out); + +/*! + * \brief Slice the NDArray with non-default storage along axis 0. + * \param handle the handle to the NDArray + * \param slice_begin The beginning index of slice + * \param slice_end The ending index of slice + * \param out The NDArrayHandle of sliced NDArray + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArraySliceEx(NDArrayHandle handle, + mx_uint slice_begin, + mx_uint slice_end, + NDArrayHandle out); /*! * \brief Index the NDArray along axis 0. * \param handle the handle to the NDArray @@ -368,6 +413,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out); + +/*! + * \brief get the storage type of the array + */ +MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type); + /*! * \brief Reshape the NDArray. * \param handle the handle to the narray @@ -406,6 +458,26 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, */ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, int *out_dtype); + +/*! + * \brief get the type of the ith aux data in NDArray + * \param handle the handle to the narray + * \param i the index of the aux data + * \param out_type pointer holder to get type of aux data + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type); + +// Get the ith aux data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out); + +// Get the data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out); /*! * \brief get the context of the NDArray * \param handle the handle to the narray @@ -1003,6 +1075,25 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, mx_uint *aux_type_size, const int **aux_type_data, int *complete); + + + + +/*! + * \brief infer storage type of unknown input types given the known one. + */ +MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_storage_type_data, + mx_uint *in_storage_type_size, + const int **in_storage_type_data, + mx_uint *out_storage_type_size, + const int **out_storage_type_data, + mx_uint *aux_storage_type_size, + const int **aux_storage_type_data, + int *complete); + //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- @@ -1167,6 +1258,9 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, const mx_uint num_provided_arg_dtypes, const char** provided_arg_dtype_names, const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, const mx_uint num_shared_arg_names, const char** shared_arg_name_list, int* shared_buffer_len, diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 40bd60f5f405..5856b87cf859 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -115,6 +115,7 @@ class Executor { const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& param_names, std::vector* in_args, diff --git a/include/mxnet/io.h b/include/mxnet/io.h index b4429a951920..2d90f500f7c8 100644 --- a/include/mxnet/io.h +++ b/include/mxnet/io.h @@ -44,6 +44,19 @@ class IIterator : public dmlc::DataIter { } }; // class IIterator +/*! + * \brief iterator type + * \param DType data type + */ +template +class SparseIIterator : public IIterator { + public: + /*! \brief storage type of the data or label */ + virtual const NDArrayStorageType GetStorageType(bool is_data) const = 0; + /*! \brief shape of the data or label */ + virtual const TShape GetShape(bool is_data) const = 0; +}; // class SparseIIterator + /*! \brief a single data instance */ struct DataInst { /*! \brief unique id for instance */ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 504fd5e7676e..7c080279d5f5 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -28,8 +28,22 @@ #endif namespace mxnet { +// forward declarations +class NDArray; + +namespace op { +template +void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst); + +template +void CastStorageComputeImpl(mshadow::Stream *s, const NDArray& input, const NDArray& output); +}; + +namespace ndarray { +template +void Copy(const TBlob &from, TBlob *to, Context from_ctx, Context to_ctx, RunContext ctx); +}; -// forward declaration namespace autograd { class AGNode; @@ -53,6 +67,27 @@ class AGNodeEntry { class AutogradRuntime; } // namespace autograd +// enum for storage types +#define CSR_IND_PTR_TYPE mshadow::kInt32 +#define CSR_IDX_DTYPE mshadow::kInt32 +#define ROW_SPARSE_IDX_TYPE mshadow::kInt32 +// FIXME int64_t is not available mshadow +namespace csr { +enum CSRAuxType {kIndPtr, kIdx}; +} + +namespace rowsparse { +enum RowSparseAuxType {kIdx}; +} + +enum NDArrayStorageType { + kUndefinedStorage = -1, // undefined storage + kDefaultStorage, // dense + kRowSparseStorage, // row sparse + kCSRStorage, // csr +}; + + /*! * \brief ndarray interface */ @@ -73,10 +108,55 @@ class NDArray { */ NDArray(const TShape &shape, Context ctx, bool delay_alloc = false, int dtype = mshadow::default_type_flag) - : ptr_(std::make_shared(shape.Size(), ctx, delay_alloc, dtype)), + : ptr_(std::make_shared(shape, ctx, delay_alloc, dtype)), shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); +#endif + } + /*! \brief constructor for NDArray with storage type + */ + NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx, + bool delay_alloc = true, int dtype = mshadow::default_type_flag, + std::vector aux_types = {}, std::vector aux_shapes = {}, + TShape storage_shape = TShape(mshadow::Shape1(0))) + : shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) { + // Assign default aux types if not given + if (aux_types.size() == 0) { + if (stype == kRowSparseStorage) { + aux_types = {ROW_SPARSE_IDX_TYPE}; + } else if (stype == kCSRStorage) { + aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + // Assign default shapes if not given + // unknown shapes are intialized as {0} such that Size() would return 0 + if (aux_shapes.size() == 0) { + if (stype == kRowSparseStorage) { + aux_shapes = {TShape(mshadow::Shape1(0))}; + } else if (stype == kCSRStorage) { + // aux shapes for indptr and indices + aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + if (storage_shape.Size() == 0) { + if (stype == kRowSparseStorage) { + storage_shape = shape; + storage_shape[0] = aux_shapes[rowsparse::kIdx][0]; + } else if (stype == kCSRStorage) { + storage_shape = aux_shapes[csr::kIdx]; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + ptr_ = std::make_shared(stype, storage_shape, ctx, delay_alloc, + dtype, aux_types, aux_shapes); +#if MKL_EXPERIMENTAL == 1 + Mkl_mem_ = std::make_shared(); #endif } /*! @@ -85,28 +165,94 @@ class NDArray { * make sure the memory region is available through out the life of NDArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at + * \param shared_var the same var handle shared with others. + It will not be deleted during destruction. */ - NDArray(const TBlob &data, int dev_id) - : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), + NDArray(const TBlob &data, int dev_id, Engine::VarHandle shared_var = nullptr) + : ptr_(std::make_shared(data, dev_id, shared_var)), shape_(data.shape_), dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); #endif } + /*! - * \return the shape of current NDArray + * \return the shape of current NDArray. */ inline const TShape& shape() const { return shape_; } + /*! + * \return the shape of underlying chunk which stores the NDArray values. + * For default storage, it is the same as shape(). For row-sparse storage, it is the shape of + * the tensor which stores the non-zero values. + */ + inline const TShape &storage_shape() const { + CHECK(ptr_ != nullptr); + return ptr_->storage_shape; + } + + /*! + * \brief For sparse operations, the storage shape is an estimated value + * in the beginning for allocating enough capacity for the final result. + * After the operation is done, the exact size of the shape is known + * and need to be reset using this function. For example, adding + * two CSRs with nnz1 and nnz2 as their numbers of non-zero values, respectively, + * would allocate the array of size nnz1+nnz2 first and get the final + * nnz that is smaller than nnz1+nnz2. Therefore, the storage shape's size + * needs to be shrunk from nnz1+nnz2 to nnz. + */ + inline void SetStorageShape(const TShape& sshape) { + CHECK(storage_type() != kDefaultStorage); + ptr_->storage_shape = sshape; + } + + /*! + * \return the shape of aux data at ith index. If it doesn't exist, return an empty one. + */ + inline const TShape aux_shape(size_t i) const { + CHECK(storage_type() != kDefaultStorage); + return ptr_->aux_shapes[i]; + } + + /*! + * \brief For a sparse operation on a csr matrix for example, + * the size of the column index array + * is an estimated value in the beginning for allocating enough capacity + * for the final result. After the operation is done, the exact size of + * the shape is known and need to be reset using this function. + */ + inline void SetAuxShape(size_t i, const TShape& shape) const { + ptr_->aux_shapes[i] = shape; + } + /*! * \return the data TBlob */ inline const TBlob& data() const { - CheckAndAlloc(); + if (storage_type() == kDefaultStorage) CheckAndAlloc(); SetTBlob(); return tblob_; } + /*! + * \return the aux TBlob + */ + inline TBlob aux_data(size_t i) const { + auto stype = storage_type(); + TBlob res; + auto shape = aux_shape(i); + auto type = aux_type(i); + MSHADOW_TYPE_SWITCH(type, DType, { + auto dptr = static_cast(ptr_->aux_handles[i].dptr); + CHECK(stype == kRowSparseStorage || stype == kCSRStorage) + << "Unexpected storage type: " << stype; + res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type); + }); +#if MKL_EXPERIMENTAL == 1 + res.Mkl_mem_ = Mkl_mem_; +#endif + return res; + } /*! * \return the context of NDArray, this function is only valid when the NDArray is not empty */ @@ -119,6 +265,28 @@ class NDArray { inline int dtype() const { return dtype_; } + inline int aux_type(size_t i) const { + CHECK(!is_none()); + return ptr_->aux_types[i]; + } + /*! + * \return the number of aux data used for given storage type + */ + static size_t NumAuxData(NDArrayStorageType stype) { + size_t num = 0; + switch (stype) { + case kDefaultStorage: num = 0; break; + case kCSRStorage: num = 2; break; + case kRowSparseStorage: num = 1; break; + default: LOG(FATAL) << "Unknown storage type" << stype; break; + } + return num; + } + + inline NDArrayStorageType storage_type() const { + if (is_none()) return kUndefinedStorage; + return ptr_->storage_type; + } /*! \return whether this ndarray is not initialized */ inline bool is_none() const { return ptr_.get() == nullptr; @@ -127,6 +295,18 @@ class NDArray { bool fresh_out_grad() const; /*! \return updated grad state in entry_ */ void set_fresh_out_grad(bool state) const; + // returns true if a sparse ndarray's aux_data and storage are initialized + inline bool storage_initialized() const { + if (is_none()) return false; + auto stype = storage_type(); + CHECK_NE(stype, kDefaultStorage); + if (stype == kRowSparseStorage || stype == kCSRStorage) { + return aux_shape(0).Size() != 0; + } else { + LOG(FATAL) << "Unknown storage type"; + } + return true; + } /*! * \brief Block until all the pending write operations with respect * to current NDArray are finished, and read can be performed. @@ -260,17 +440,38 @@ class NDArray { void SyncCopyToCPU(void *data, size_t size) const; /*! * \brief Slice a NDArray - * \param begin begin index in first dim - * \param end end index in first dim + * \param begin begin index in first dim (inclusive) + * \param end end index in first dim (exclusive) * \return sliced NDArray */ NDArray Slice(index_t begin, index_t end) const; + + /*! + * \brief Slice a NDArray with non-default storage + * \param begin begin index in first dim (inclusive) + * \param end end index in first dim (exclusive) + * \return sliced NDArray + */ + void SliceEx(index_t begin, index_t end, NDArray *dst) const; /*! * \brief Index a NDArray * \param idx the index * \return idx-th sub array NDArray */ NDArray At(index_t idx) const; + // Wrap the tblob of aux data into an NDArray which shares the same variable with the + // current one. + inline const NDArray aux_ndarray(size_t i) const { + CHECK_NE(storage_type(), kDefaultStorage); + CHECK(i < ptr_->aux_shapes.size()); + return NDArray(aux_data(i), ctx().dev_id, var()); + } + // Wrap the tblob of data into an NDArray which shares the same variable with the + // current one. + inline const NDArray data_ndarray() const { + CHECK_NE(storage_type(), kDefaultStorage); + return NDArray(data(), ctx().dev_id, var()); + } /*! * \brief Create a NDArray that shares memory with current one * The new array must have smaller memory size than the current array. @@ -279,6 +480,7 @@ class NDArray { * \return NDArray in new shape and type. */ inline NDArray AsArray(const TShape &shape, int dtype) const { + CHECK_EQ(storage_type(), kDefaultStorage) << "Not implemented yet"; CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_), shape.Size() * mshadow::mshadow_sizeof(dtype)) << "NDArray.AsArray: target memory size is bigger"; @@ -312,8 +514,25 @@ class NDArray { * This is an internal function used by system that normal user should not use */ inline void CheckAndAlloc() const { + CHECK_EQ(storage_type(), kDefaultStorage); ptr_->CheckAndAlloc(); } + /* ! + * \brief Alloc memory for non-default storage + * aux_shape is only known at run time + */ + inline void CheckAndAlloc(const std::vector &aux_shapes) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_); + } + inline void CheckAndAllocData(const TShape &storage_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocData(storage_shape, dtype_); + } + inline void CheckAndAllocAuxData(size_t i, const TShape &aux_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocAuxData(i, aux_shape); + } /*! * \brief Save list of ndarray into the Stream.x * \param fo The stream of output. @@ -336,43 +555,105 @@ class NDArray { private: friend class autograd::AutogradRuntime; /*! \brief the real data chunk that backs NDArray */ + // shandle is used to store the actual values in the NDArray + // aux_handles store the aux data(such as indices) if it's needed by non-default storage. struct Chunk { - /*! \brief storage handlefrom storage engine */ + /*! \brief storage handle from storage engine. + for non-default storage, shandle stores the data(value) array. + */ Storage::Handle shandle; + /*! \brief storage handles for aux data (e.g index) + for row_sparse, aux_handles[0] = indices + for csr, aux_handles[0] = indptr, aux_handles[1] = indices + */ + std::vector aux_handles; /*! \brief variable from engine */ Engine::VarHandle var; /*! * \brief if this is true, this means the data do not come * from Storage, and do not need to be freed */ + /*! \brief construct from static data */ bool static_data; - /*! \brief whether allocation is delayed */ + /*! \brief whether data allocation is delayed. This doesn't indicate whether aux data + allocation is delayed. */ bool delay_alloc; + // the type of the storage. The storage_type is never kUndefinedStorage once the chunk + // is constructed. + NDArrayStorageType storage_type = kDefaultStorage; + /*! \brief type of aux */ + std::vector aux_types; + // context of data + Context ctx; + // The shape of the chunk data. + // This might not be the same shape as the NDArray, since the storage may be sparse. + // The default value for storage_shape is {0} when an empty non-default NDArray is created. + TShape storage_shape; + // The shape of aux data. The default value for the shape depends on the type of storage. + // If aux_shapes[i].Size() is zero, aux data i is empty. + std::vector aux_shapes; + // \brief skip the deletion of var handle. Usually set when shared_var is present. + bool skip_delete_var = false; + /*! \brief default cosntructor */ - Chunk() : static_data(true), delay_alloc(false) { - var = Engine::Get()->NewVariable(); - } - /*! \brief construct from static data */ - Chunk(const TBlob &data, int dev_id) - : static_data(true), - delay_alloc(false) { - var = Engine::Get()->NewVariable(); + Chunk() : static_data(true), delay_alloc(false) {} +/* if (data.dev_mask() == cpu::kDevMask) { shandle.ctx = Context::CPU(); } else { CHECK_EQ(data.dev_mask(), gpu::kDevMask); shandle.ctx = Context::GPU(dev_id); +*/ + /*! \brief construct a new chunk */ + Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype) + : static_data(false), delay_alloc(true), ctx(ctx_) { + auto size = shape.Size(); + storage_shape = shape; + var = Engine::Get()->NewVariable(); + shandle.size = size * mshadow::mshadow_sizeof(dtype); + shandle.ctx = ctx_; + if (!delay_alloc_) this->CheckAndAlloc(); + } + + Chunk(const TBlob &data, int dev_id, Engine::VarHandle shared_var) + : static_data(true), delay_alloc(false) { + CHECK(storage_type == kDefaultStorage); + // init var + if (shared_var == nullptr) { + var = Engine::Get()->NewVariable(); + } else { + skip_delete_var = true; + var = shared_var; } + // init ctx + if (data.dev_mask() == cpu::kDevMask) { + ctx = Context::CPU(); + } else { + CHECK_EQ(data.dev_mask(), gpu::kDevMask); + ctx = Context::GPU(dev_id); + } + // init shandle + shandle.ctx = ctx; shandle.dptr = data.dptr_; shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_); + storage_shape = data.shape_; } - /*! \brief construct a new chunk */ - Chunk(uint64_t size, Context ctx, bool delay_alloc_, int dtype) - : static_data(false), delay_alloc(true) { - var = Engine::Get()->NewVariable(); - shandle.size = size * mshadow::mshadow_sizeof(dtype); + // Constructor for a non-default storage chunk + Chunk(NDArrayStorageType storage_type_, const TShape &storage_shape_, Context ctx_, + bool delay_alloc_, int dtype, const std::vector &aux_types_, + const std::vector &aux_shapes_) + : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_), + aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_), + aux_shapes(aux_shapes_) { shandle.ctx = ctx; - if (!delay_alloc_) this->CheckAndAlloc(); + var = Engine::Get()->NewVariable(); + // aux_handles always reflect the correct number of aux data + for (size_t i = 0; i < aux_shapes.size(); i++) { + CheckAndAllocAuxData(i, aux_shapes[i]); + } + if (!delay_alloc) { + CheckAndAllocData(storage_shape, dtype); + } } /*! \brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { @@ -381,22 +662,98 @@ class NDArray { delay_alloc = false; } } - /*! \brief destructor */ - ~Chunk() { - if (static_data || delay_alloc) { - Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var); + inline void CheckAndAlloc(const TShape &shape, const std::vector &aux_shapes, + int dtype) { + // calculate size, perform allocation + if (kRowSparseStorage == storage_type) { + // For row sparse, aux_shape indicates the number of rows to allocate + auto aux_shape = aux_shapes[rowsparse::kIdx]; + CHECK_EQ(shape.ndim(), 2) << "High dim RowSparse not yet implemented"; + CheckAndAllocAuxData(rowsparse::kIdx, aux_shape); + TShape storage_shape(shape); + storage_shape[0] = aux_shape[0]; + CheckAndAllocData(storage_shape, dtype); + } else if (kCSRStorage == storage_type) { + CheckAndAllocAuxData(csr::kIndPtr, aux_shapes[csr::kIndPtr]); + CheckAndAllocAuxData(csr::kIdx, aux_shapes[csr::kIdx]); + CheckAndAllocData(aux_shapes[csr::kIdx], dtype); } else { - Storage::Handle h = this->shandle; - Engine::Get()->DeleteVariable([h](RunContext s) { - Storage::Get()->Free(h); - }, shandle.ctx, var); + LOG(FATAL) << "Storage type " << storage_type << " not implemented for CheckAndAlloc"; } } + // create storage handle for data based on shape and dtype, assuming ctx is set + // storage shape is also updated + // if data is already allocated, try reuse the storage. Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocData(const TShape &shape, int dtype) { + CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data"; + auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype); + if (shandle.size < dbytes) { + // free storage if necessary and alloc again + if (shandle.size > 0) Storage::Get()->Free(shandle); + // init storage + shandle = Storage::Get()->Alloc(dbytes, ctx); + } + // init shape + storage_shape = shape; + // delay_alloc is only set when data storage handle is present + delay_alloc = false; + } + // create storage handle for aux data based on shape + // this function assumes ctx, aux shapes and aux types are set + // aux shape is also updated + // if aux data is already allocated, try reuse the storage. Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocAuxData(size_t i, const TShape &shape) { + CHECK_EQ(shape.ndim(), 1) << "shape must be 1D in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kUndefinedStorage) + << "storage type cannot be kUndefinedStorage in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kDefaultStorage) + << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData"; + if (aux_handles.size() <= i) { + aux_handles.resize(i + 1); + } + size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); + if (aux_handles[i].size < aux_bytes) { + // free storage if necessary and alloc again + if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]); + // init aux storage + aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx); + } + // init shape + aux_shapes[i] = shape; + } + /*! \brief destructor */ + ~Chunk() { + if (skip_delete_var) return; + bool skip_free = static_data || delay_alloc; + Storage::Handle h = this->shandle; + std::vector aux_h = this->aux_handles; + Engine::Get()->DeleteVariable([h, aux_h, skip_free](RunContext s) { + if (skip_free == false) { + Storage::Get()->Free(h); + for (size_t i = 0; i < aux_h.size(); i++) { + if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]); + } + } + }, shandle.ctx, var); + } }; void SetTBlob() const { - tblob_.dptr_ = static_cast(ptr_->shandle.dptr) + byte_offset_; - tblob_.shape_ = shape_; + CHECK(ptr_ != nullptr); + TShape shape = shape_; + char *dptr = static_cast(ptr_->shandle.dptr); + auto stype = storage_type(); + if (stype == kDefaultStorage) { + dptr += byte_offset_; + } else if (stype == kCSRStorage || stype == kRowSparseStorage) { + shape = storage_shape(); + } else { + LOG(FATAL) << "unknown storage type " << stype; + } + tblob_.dptr_ = dptr; + tblob_.shape_ = shape; tblob_.type_flag_ = dtype_; tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id); #if MKL_EXPERIMENTAL == 1 @@ -404,11 +761,12 @@ class NDArray { #endif } + #if MKL_EXPERIMENTAL == 1 std::shared_ptr Mkl_mem_; #endif /*! \brief internal data of NDArray */ - std::shared_ptr ptr_; + std::shared_ptr ptr_{nullptr}; /*! \brief shape of current NDArray */ TShape shape_; /*! \brief byte offset in chunk */ @@ -435,11 +793,112 @@ class NDArray { * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. + * \param alloc_output whether to allocate memory for the output ndarray * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. */ void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); +// Make a copy of a CSR NDArray +template +inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source storage is not initialized, fill destination with zeros + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + // TODO(haibin) implement FillZerosCsrImpl + // op::FillZerosCsrImpl(s, to); + return; + } + // Allocate storage + to->CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr)); + to->CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx)); + to->CheckAndAllocData(from.aux_shape(csr::kIdx)); + // FIXME This is a naive implementation for CSR copy. It, however, is + // not efficient when the source CSR is sliced. In that case, we're copying + // a superset of values and indices of the slice. + // Ideally, we should truncate the values and indices array, and adjust indptr + // accordingly. + TBlob val = to->data(); + TBlob indptr = to->aux_data(csr::kIndPtr); + TBlob idx = to->aux_data(csr::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIndPtr), &indptr, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a row-sparse NDArray +template +inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source is zeros, fill destination with zeros, too + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + op::FillZerosRspImpl(s, to); + return; + } + auto aux_shape = from.aux_shape(rowsparse::kIdx); + to->CheckAndAlloc({aux_shape}); + TBlob val = to->data(); + TBlob idx = to->aux_data(rowsparse::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(rowsparse::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a dense NDArray +template +inline void CopyFromToDnsImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + TBlob tmp = to->data(); + ndarray::Copy(from.data(), &tmp, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of an NDArray based on storage type +template +void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace std; + using namespace mshadow; + // if storage type doesn't match, cast the storage first + auto from_stype = from.storage_type(); + auto to_stype = to->storage_type(); + NDArray casted_nd; + if (from_stype != to_stype) { + TShape shape = from.shape(); + auto from_ctx = from.ctx(); + auto s = ctx.get_stream(); + // TODO(haibin) inplace conversion + if (to_stype == kDefaultStorage) { + casted_nd = NDArray(shape, from_ctx); + } else { + casted_nd = NDArray(to_stype, shape, from_ctx); + } + op::CastStorageComputeImpl(s, from, casted_nd); + } else { + casted_nd = from; + } + if (to_stype == kDefaultStorage) { + CopyFromToDnsImpl(casted_nd, to, ctx); + } else if (to_stype == kRowSparseStorage) { + CopyFromToRspImpl(casted_nd, to, ctx); + } else if (to_stype == kCSRStorage) { + CopyFromToCsrImpl(casted_nd, to, ctx); + } else { + LOG(FATAL) << "unknown storage type" << to_stype; + } + if (is_same::value || is_same::value) { + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + } +} /*! * \brief Perform elementwise sum over each data from source, store result into out. diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 316a90fe0841..bf9961c8234e 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -7,7 +7,6 @@ #ifndef MXNET_OP_ATTR_TYPES_H_ #define MXNET_OP_ATTR_TYPES_H_ - #include #include @@ -18,6 +17,9 @@ #include "./operator.h" #include "./ndarray.h" +#define FCOMP_EX_CPU "FComputeEx" +#define FCOMP_EX_GPU "FComputeEx" + namespace mxnet { using nnvm::NodeAttrs; @@ -61,6 +63,17 @@ using FCompute = std::function& inputs, const std::vector& req, const std::vector& outputs)>; +/*! + * \brief Resiger an NDArray compute function for simple stateless forward only operator + * + * \note Register under "FComputeEx" and "FComputeEx" + * Dispatched only when operators process non-default storage inputs or outputs + */ +using FComputeEx = std::function& inputs, + const std::vector& req, + const std::vector& outputs)>; } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 1b765233947d..e236a9cf313b 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -23,11 +23,11 @@ class Storage { /*! * \brief Pointer to the data. */ - void* dptr; + void* dptr{nullptr}; /*! * \brief Size of the storage. */ - size_t size; + size_t size{0}; /*! * \brief Context information about device and ID. */ diff --git a/mshadow b/mshadow index c037b06ddd81..bbde96541478 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit c037b06ddd810d39322cd056650f8b1f4763dd9d +Subproject commit bbde96541478cd93fe9d617e8d1d955c264bac1d diff --git a/nnvm b/nnvm index 7796ac76ccea..2e3561500de9 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit 7796ac76ccea1fba31afc32056c83f6da38b6c57 +Subproject commit 2e3561500de99a0c173f3bc7b1a6c2b31435d6d9 diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index ff5f6cd6be7e..768d9ede2643 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -8,6 +8,7 @@ from . import base from . import contrib from . import ndarray +from . import sparse_ndarray from . import name # use mx.sym as short for symbol from . import symbol as sym @@ -18,6 +19,7 @@ from . import operator # use mx.nd as short for mx.ndarray from . import ndarray as nd +from . import sparse_ndarray as sparse_nd # use mx.rnd as short for mx.random from . import random as rnd from . import random diff --git a/python/mxnet/contrib/autograd.py b/python/mxnet/contrib/autograd.py index e56361efdb1f..b20e1eb0f086 100644 --- a/python/mxnet/contrib/autograd.py +++ b/python/mxnet/contrib/autograd.py @@ -7,6 +7,8 @@ import functools from ..base import _LIB, check_call, string_types from ..base import mx_uint, NDArrayHandle, c_array +# pylint: disable= unused-import +from ..sparse_ndarray import SparseNDArray from ..ndarray import NDArray, zeros_like from ..symbol import _GRAD_REQ_MAP diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 6b9aab2de6f1..3991319ff13a 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -11,6 +11,7 @@ from .base import mx_uint, NDArrayHandle, ExecutorHandle from .base import check_call, c_array, py_str from .ndarray import NDArray +from .sparse_ndarray import _ndarray_cls from . import ndarray as nd # those functions are not used here, we just import them to keep backward compatibility @@ -90,7 +91,9 @@ def _get_outputs(self): handles = ctypes.POINTER(NDArrayHandle)() check_call(_LIB.MXExecutorOutputs(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) - return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] + num_output = out_size.value + outputs = [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(num_output)] + return outputs def forward(self, is_train=False, **kwargs): """Calculate the outputs specified by the bound symbol. diff --git a/python/mxnet/io.py b/python/mxnet/io.py index ec3c25f54d30..b728f50838a8 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -13,6 +13,7 @@ from .base import mx_real_t from .base import check_call, build_param_doc as _build_param_doc from .ndarray import NDArray +from .sparse_ndarray import _ndarray_cls from .ndarray import array from .ndarray import concatenate @@ -752,12 +753,12 @@ def iter_next(self): def getdata(self): hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + return _ndarray_cls(hdl, False) def getlabel(self): hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + return _ndarray_cls(hdl, False) def getindex(self): index_size = ctypes.c_uint64(0) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index ab07421caffd..3384be7947ac 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -48,7 +48,7 @@ def updater_handle(key, lhs_handle, rhs_handle, _): class KVStore(object): """A key-value store for synchronization of values, over multiple devices.""" - def __init__(self, handle): + def __init__(self, handle, name2idx=None): """Initializes a new KVStore. Parameters @@ -58,6 +58,7 @@ def __init__(self, handle): """ assert isinstance(handle, KVStoreHandle) self.handle = handle + self.name2idx = name2idx if name2idx is not None else {} self._updater = None self._updater_func = None @@ -395,7 +396,7 @@ def _send_command_to_servers(self, head, body): check_call(_LIB.MXKVStoreSendCommmandToServers( self.handle, mx_uint(head), c_str(body))) -def create(name='local'): +def create(name='local', name2idx=None): """Creates a new KVStore. For single machine training, there are two commonly used types: @@ -435,4 +436,4 @@ def create(name='local'): handle = KVStoreHandle() check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(handle))) - return KVStore(handle) + return KVStore(handle, name2idx=name2idx) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 189f301e91f7..b90500d4a9c5 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -37,7 +37,7 @@ 'eval_metric', 'locals']) -def _create_kvstore(kvstore, num_device, arg_params): +def _create_kvstore(kvstore, num_device, arg_params, name2idx=None): """Create kvstore This function select and create a proper kvstore if given the kvstore type. @@ -61,8 +61,8 @@ def _create_kvstore(kvstore, num_device, arg_params): # no need to use kv for single device and single machine kv = None else: - kv = kvs.create(kvstore) - if kvstore == 'local': + kv = kvs.create(kvstore, name2idx=name2idx) + if kvstore is 'local': # automatically select a proper local max_size = max(np.prod(param.shape) for param in arg_params.values()) @@ -85,25 +85,50 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, if update_on_kvstore: kvstore.pull(idx, param_on_devs, priority=-idx) -def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore): - """Perform update of param_arrays from grad_arrays on kvstore.""" - for index, pair in enumerate(zip(param_arrays, grad_arrays)): +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, + stype_dict=None, param_names=None): + """Perform update of param_arrays from grad_arrays on kvstore. + If `param_names` is None or kvstore doesn't have a `name2idx` dictionary, + the index of a param is determined by the order it appears in `param_arrays`. """ + stype_dict = {} if stype_dict is None else stype_dict + for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + index = i + if param_names is not None: + name = param_names[i] + index = index if name not in kvstore.name2idx else kvstore.name2idx[name] + # cast storage type if stype doesn't match + if name in stype_dict: + for i, grad in enumerate(grad_list): + stype = stype_dict[name] + if grad_list[i].storage_type != stype: + grad_list[i] = nd.cast_storage(grad, stype) # push gradient, priority is negative index kvstore.push(index, grad_list, priority=-index) # pull back the weights kvstore.pull(index, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, - kvstore=None): + kvstore=None, stype_dict=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" - for index, pair in enumerate(zip(param_arrays, grad_arrays)): + stype_dict = {} if stype_dict is None else stype_dict + for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + # cast storage type if stype doesn't match + if param_names is not None and param_names[i] in stype_dict: + for i, grad in enumerate(grad_list): + stype = stype_dict[param_names[i]] + if grad_list[i].storage_type != stype: + grad_list[i] = nd.cast_storage(grad, stype) + index = i if kvstore: + if param_names is not None: + name = param_names + index = index if name not in kvstore.name2idx else kvstore.name2idx[name] # push gradient, priority is negative index kvstore.push(index, grad_list, priority=-index) # pull back the sum gradients, to the same locations. diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index 820841087a9c..c78daa1137c8 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -849,9 +849,17 @@ def get_input_grads(self, merge_multi_context=True): """ raise NotImplementedError() - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to the installed optimizer and the gradients computed - in the previous forward-backward batch. + in the previous forward-backward batch. The storage type of parameters is casted according + to `storage_type_dict`, if provided. + + Parameters + ---------- + storage_type_dict: dict of str to str + Defaults to ``None``. Desired storage types of parameters for parameter update. If the + parameter gradient is not of desired storage type, its storage type will be casted + before the update. Examples -------- diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index 11922ddafb56..ae10e8e401d0 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -399,13 +399,13 @@ def backward(self, out_grads=None): assert self.binded and self.params_initialized self._curr_module.backward(out_grads=out_grads) - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to installed optimizer and the gradient computed in the previous forward-backward cycle. """ assert self.binded and self.params_initialized and self.optimizer_initialized self._params_dirty = True - self._curr_module.update() + self._curr_module.update(storage_type_dict=storage_type_dict) def get_outputs(self, merge_multi_context=True): """Gets outputs from a previous forward computation. diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index fef5c507d7e8..a0eb19dafccc 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -454,8 +454,12 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', if self._params_dirty: self._sync_params_from_devices() + name2idx = {} + for idx, name in enumerate(self._exec_group.param_names): + name2idx[name] = idx + (kvstore, update_on_kvstore) = \ - _create_kvstore(kvstore, len(self._context), self._arg_params) + _create_kvstore(kvstore, len(self._context), self._arg_params, name2idx=name2idx) batch_size = self._exec_group.batch_size if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type: @@ -558,7 +562,7 @@ def backward(self, out_grads=None): assert self.binded and self.params_initialized self._exec_group.backward(out_grads=out_grads) - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. @@ -572,7 +576,9 @@ def update(self): if self._update_on_kvstore: _update_params_on_kvstore(self._exec_group.param_arrays, self._exec_group.grad_arrays, - self._kvstore) + self._kvstore, + stype_dict=storage_type_dict, + param_names=self._param_names) else: _update_params(self._exec_group.param_arrays, self._exec_group.grad_arrays, diff --git a/python/mxnet/module/python_module.py b/python/mxnet/module/python_module.py index f46ea280aaff..82dcb06aa020 100644 --- a/python/mxnet/module/python_module.py +++ b/python/mxnet/module/python_module.py @@ -110,7 +110,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non """ pass - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. Currently we do nothing here. Subclass should override this method if contains parameters. diff --git a/python/mxnet/module/sequential_module.py b/python/mxnet/module/sequential_module.py index 21e30fb3b0ce..383286642e0c 100644 --- a/python/mxnet/module/sequential_module.py +++ b/python/mxnet/module/sequential_module.py @@ -344,14 +344,14 @@ def backward(self, out_grads=None): out_grads = module.get_input_grads() - def update(self): + def update(self, storage_type_dict=None): """Updates parameters according to installed optimizer and the gradient computed in the previous forward-backward cycle. """ assert self.binded and self.params_initialized and self.optimizer_initialized for module in self._modules: - module.update() + module.update(storage_type_dict=storage_type_dict) def get_outputs(self, merge_multi_context=True): """Gets outputs from a previous forward computation. diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 8900843f5937..8e8d3ffebbd4 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -19,7 +19,7 @@ import numpy as np from .base import _LIB, string_types, numeric_types from .base import c_array, py_str, c_str, mx_real_t, _Null # pylint: disable=unused-import -from .base import mx_uint, NDArrayHandle, check_call, OpHandle +from .base import mx_uint, NDArrayHandle, check_call from .base import ctypes2buffer from .context import Context from . import _ndarray_internal as _internal @@ -54,7 +54,6 @@ np.uint8 : 3, np.int32 : 4 } - _DTYPE_MX_TO_NP = { 0 : np.float32, 1 : np.float64, @@ -62,7 +61,18 @@ 3 : np.uint8, 4 : np.int32 } -# pylint: enable= no-member +_STORAGE_TYPE_ID_TO_STR = { + -1 : 'undefined', + 0 : 'default', + 1 : 'row_sparse', + 2 : 'csr', +} +_STORAGE_TYPE_STR_TO_ID = { + 'undefined' : -1, + 'default' : 0, + 'row_sparse' : 1, + 'csr' : 2, +} def _new_empty_handle(): """Returns a new empty handle. @@ -106,6 +116,11 @@ def waitall(): """ check_call(_LIB.MXNDArrayWaitAll()) +def _storage_type(handle): + storage_type = ctypes.c_int(0) + check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(storage_type))) + return _STORAGE_TYPE_ID_TO_STR[storage_type.value] + class NDArray(NDArrayBase): """An array object representing a multidimensional, homogeneous array of fixed-size items. @@ -119,6 +134,9 @@ def __repr__(self): return '<%s %s @%s>' % (self.__class__.__name__, shape_info, self.context) + def __reduce__(self): + return (NDArray, (None,), self.__getstate__()) + def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ return add(self, other) @@ -629,7 +647,6 @@ def wait_to_read(self): """ check_call(_LIB.MXNDArrayWaitToRead(self.handle)) - @property def ndim(self): """Returns the number of dimensions of this array @@ -664,6 +681,7 @@ def shape(self): self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) return tuple(pdata[:ndim.value]) + @property def size(self): """Number of elements in the array. @@ -725,6 +743,10 @@ def dtype(self): self.handle, ctypes.byref(mx_dtype))) return _DTYPE_MX_TO_NP[mx_dtype.value] + @property + def storage_type(self): + return _storage_type(self.handle) + @property # pylint: disable= invalid-name, undefined-variable def T(self): @@ -949,6 +971,13 @@ def backward(self, out_grad=None, retain_graph=False): c_array(NDArrayHandle, ograd_handles), ctypes.c_int(retain_graph))) + def to_csr(self): + # pylint: disable=undefined-variable + return cast_storage(self, storage_type='csr') + + def to_rsp(self): + # pylint: disable=undefined-variable + return cast_storage(self, storage_type='row_sparse') def onehot_encode(indices, out): """One-hot encoding indices into matrix out. @@ -2406,37 +2435,5 @@ def %s(%s): ndarray_function.__module__ = 'mxnet.ndarray' return ndarray_function - -# pylint: enable=too-many-locals, invalid-name -def _init_ndarray_module(ndarray_class, root_namespace): - """List and add all the ndarray functions to current module.""" - _set_ndarray_class(ndarray_class) - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = _sys.modules["%s.ndarray" % root_namespace] - module_internal = _sys.modules["%s._ndarray_internal" % root_namespace] - module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] - for name in op_names: - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - function = _make_ndarray_function(hdl, name) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.ndarray' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) - -_init_ndarray_module(NDArray, "mxnet") - # from .base import add_fileline_to_docstring # add_fileline_to_docstring(__name__) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py new file mode 100644 index 000000000000..79351b1eb371 --- /dev/null +++ b/python/mxnet/sparse_ndarray.py @@ -0,0 +1,654 @@ +# coding: utf-8 +"""SparseNDArray API of mxnet.""" +from __future__ import absolute_import +from __future__ import division +try: + from __builtin__ import slice as py_slice +except ImportError: + from builtins import slice as py_slice + +import ctypes +import warnings + +import os as _os +import sys as _sys + +# import operator +import numpy as np +from .base import _LIB, numeric_types +from .base import c_array, py_str, mx_real_t, c_str +from .base import mx_uint, NDArrayHandle, check_call, OpHandle +from .context import Context +from . import _ndarray_internal as _internal +from . import ndarray +from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import _STORAGE_TYPE_STR_TO_ID +from .ndarray import NDArray, _storage_type, _make_ndarray_function + +# Use different verison of SymbolBase +# When possible, use cython to speedup part of computation. +# pylint: disable=unused-import +try: + if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + elif _sys.version_info >= (3, 0): + from ._cy3.ndarray import NDArrayBase, _set_ndarray_class + else: + from ._cy2.ndarray import NDArrayBase, _set_ndarray_class +except ImportError: + if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + +# pylint: enable=unused-import +_STORAGE_AUX_TYPES = { + 'row_sparse': [np.int32], + 'csr': [np.int32, np.int32] +} + +def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None): + """Return a new handle with specified storage type, shape, dtype and context. + + Empty handle is only used to hold results + + Returns + ------- + handle + A new empty ndarray handle + """ + hdl = NDArrayHandle() + aux_type_ids = [int(_DTYPE_NP_TO_MX[np.dtype(aux_t).type]) for aux_t in aux_types] + aux_shapes = [(0,) for aux_t in aux_types] if aux_shapes is None else aux_shapes + aux_shape_lens = [len(aux_shape) for aux_shape in aux_shapes] + aux_shapes = sum(aux_shapes, ()) + num_aux = mx_uint(len(aux_types)) + check_call(_LIB.MXNDArrayCreateSparseEx( + ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[storage_type])), + c_array(mx_uint, shape), + mx_uint(len(shape)), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + ctypes.c_int(int(delay_alloc)), + ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), + num_aux, + c_array(ctypes.c_int, aux_type_ids), + c_array(mx_uint, aux_shape_lens), + c_array(mx_uint, aux_shapes), + ctypes.byref(hdl))) + return hdl + +class SparseNDArray(NDArray): + """An array object representing a multidimensional, homogeneous array of +fixed-size items, stored in sparse format. See CSRNDArray and RowSparseNDArray +for more details. + + """ + + def __reduce__(self): + raise Exception('Not implemented for SparseND yet!') + # return SparseNDArray, (None,), self.__getstate__() + + def __add__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __iadd__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __radd__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __isub__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rsub__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __imul__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rmul__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rdiv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __idiv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rtruediv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __itruediv__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __pow__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __rpow__(self, other): + raise Exception('Not implemented for SparseND yet!') + + def __getstate__(self): + raise Exception('Not implemented for SparseND yet!') + + def __setstate__(self, state): + raise Exception('Not implemented for SparseND yet!') + + def __setitem__(self, key, value): + """x.__setitem__(i, y) <=> x[i]=y + + Set self[key] to value. Only slice [:] is supported. + + Parameters + ---------- + key : slice + The indexing key. + value : NDArray or numpy.ndarray + The value to set. + + Examples + -------- + >>> src = mx.sparse_nd.row_sparse(data, indices, (3,3)) + >>> src.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign SparseNDArray with same storage type + >>> x = mx.sparse_nd.zeros('row_sparse', (3,3)) + >>> x[:] = src + >>> x.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign NDArray to SparseNDArray + >>> x[:] = mx.nd.ones((3,3)) + >>> x.asnumpy() + array([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]], dtype=float32) + """ + if not self.writable: + raise ValueError('Failed to assign to a readonly NDArray') + if isinstance(key, py_slice): + if key.step is not None or key.start is not None or key.stop is not None: + raise ValueError('Assignment with slicing not supported in SparseNDArray.') + if isinstance(value, NDArray): + # avoid copying to itself + if value.handle is not self.handle: + value.copyto(self) + elif isinstance(value, numeric_types): + raise Exception("Assigning numeric types to SparseNDArray not supported yet.") + elif isinstance(value, (np.ndarray, np.generic)): + # TODO(haibin) Implement _sync_copyfrom for sparse ndarray to avoid an extra copy + warnings.warn('Assigning non-NDArray object to SparseNDArray is not efficient', + RuntimeWarning) + tmp = ndarray.array(value) + tmp.copyto(self) + else: + raise TypeError('type %s not supported' % str(type(value))) + else: + assert(isinstance(key, (int, tuple))) + raise Exception('SparseNDArray only supports [:] for assignment') + + def __getitem__(self, key): + """x.__getitem__(i) <=> x[i] + + Returns a sliced view of this array. + + Parameters + ---------- + key : int or slice + Indexing key. + + Examples + -------- + >>> x[:] = mx.nd.arange(0,6).reshape((2,3)) + >>> x.asnumpy() + array([[ 0., 1., 2.], + [ 3., 4., 5.]], dtype=float32) + >>> x[1:2].asnumpy() + array([[ 3., 4., 5.]], dtype=float32) + """ + stype = self.storage_type + if stype != 'csr': + raise Exception("__getitem__ for " + str(stype) + " not implemented yet") + if isinstance(key, int): + raise Exception("Not implemented yet") + if isinstance(key, py_slice): + if key.step is not None: + raise ValueError('NDArray only supports continuous slicing on axis 0') + if key.start is not None or key.stop is not None: + return self._slice(key.start, key.stop) + else: + return self + if isinstance(key, tuple): + raise ValueError('Multi-dimension indexing is not supported') + + def _sync_copyfrom(self, source_array): + raise Exception('Not implemented for SparseND yet!') + + def _slice(self, start, stop): + """Returns a read-only SparseNDArray slice that shares memory with current one. + To create a writable slice, please use ``mx.nd.slice`` instead. Currently only + `csr` storage type is supported. + + Parameters + ---------- + start : int + Starting index of slice. + stop : int + Finishing index of slice. + + Example + ---------- + >>> indptr = np.array([0, 2, 3, 6]) + >>> indices = np.array([0, 2, 2, 0, 1, 2]) + >>> data = np.array([1, 2, 3, 4, 5, 6]) + >>> a = mx.sparse_nd.csr(data, indptr, indices, (3, 3)) + >>> a.asnumpy() + array([[1, 0, 2], + [0, 0, 3], + [4, 5, 6]]) + + >>> a[1:2].asnumpy() + array([[0, 0, 3]]) + + """ + stype = self.storage_type + assert(stype == 'csr'), "_slice for " + str(stype) + " not implemented yet" + warnings.warn('slicing SparseNDArray is not efficient', RuntimeWarning) + shape = list(self.shape) + shape[0] = stop - start + handle = _new_alloc_handle(self.storage_type, tuple(shape), self.context, + True, self.dtype, self.aux_types) + start = mx_uint(start) if start else mx_uint(0) + stop = mx_uint(stop) if stop else mx_uint(self.shape[0]) + + check_call(_LIB.MXNDArraySliceEx(self.handle, start, stop, handle)) + ret = _ndarray_cls(handle=handle, writable=False) + return ret + + def _at(self, idx): + raise Exception('at operator for SparseND is not supported.') + + def reshape(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def broadcast_to(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def _aux_type(self, i): + """Data-type of the array’s ith aux data. + + Returns + ------- + numpy.dtype + This SparseNDArray's aux data type. + """ + aux_type = ctypes.c_int() + check_call(_LIB.MXNDArrayGetAuxType(self.handle, i, ctypes.byref(aux_type))) + return _DTYPE_MX_TO_NP[aux_type.value] + + @property + def values(self): + """The values array of the SparseNDArray. This is a read-only view of the values array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's values array. + """ + return self._data() + + + @property + def _num_aux(self): + ''' The number of aux data used to help store the sparse ndarray. + ''' + return len(_STORAGE_AUX_TYPES[self.storage_type]) + + @property + # pylint: disable= invalid-name, undefined-variable + def T(self): + raise Exception('Transpose is not supported for SparseNDArray.') + + @property + def aux_types(self): + """The data types of the aux data for the SparseNDArray. + """ + aux_types = [] + num_aux = self._num_aux + for i in range(num_aux): + aux_types.append(self._aux_type(i)) + return aux_types + + def asnumpy(self): + """Return a dense ``numpy.ndarray`` object with value copied from this array + + """ + return self.to_dense().asnumpy() + + def astype(self, dtype): + raise Exception('Not implemented for SparseND yet!') + + def copyto(self, other): + """Copies the value of this array to another array. + + If ``other`` is a ``NDArray`` object, then ``other.shape`` and + ``self.shape`` should be the same. This function copies the value from + ``self`` to ``other``. + + If ``other`` is a context, a new ``NDArray`` will be first created on + the target context, and the value of ``self`` is copied. + + Parameters + ---------- + other : NDArray or Context + The destination array or context. + + Returns + ------- + NDArray + The copied array. If ``other`` is an ``NDArray``, then the return value + and ``other`` will point to the same ``NDArray``. + """ + if isinstance(other, NDArray): + if other.handle is self.handle: + warnings.warn('You are attempting to copy an array to itself', RuntimeWarning) + return + return _internal._copyto(self, out=other) + elif isinstance(other, Context): + hret = _ndarray_cls(_new_alloc_handle(self.storage_type, self.shape, other, + True, self.dtype, self.aux_types)) + return _internal._copyto(self, out=hret) + else: + raise TypeError('copyto does not support type ' + str(type(other))) + + def to_dense(self): + return to_dense(self) + + def _aux_data(self, i, writable=False): + """ Get an NDArray referencing the ith aux data array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl))) + return NDArray(hdl, writable) + + def _data(self, writable=False): + """ Get an NDArray referencing the value array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl))) + return NDArray(hdl, writable) + +class CSRNDArray(SparseNDArray): + """A CSRNDArray represents a NDArray as three separate arrays: `values`, + `indptr` and `indices`. It uses the standard CSR representation where the column indices for + row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored + in values[indptr[i]:indptr[i+1]]. + + """ + + @property + def indices(self): + """The indices array of the SparseNDArray. This is a read-only view of the indices array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indices array. + """ + return self._aux_data(1) + + @property + def indptr(self): + """The indptr array of the SparseNDArray with `csr` storage type. + This is a read-only view of the indptr array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indptr array. + """ + return self._aux_data(0) + +class RowSparseNDArray(SparseNDArray): + """A RowSparseNDArray is typically used to represent a subset of a larger + NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values + in indices are the indices in the first dimension of the slices that have been extracted from + the larger NDArray. The indices are expected to be sorted in ascending order. + + The corresponding NDArray ``dense`` with `default` storage represented by a ``rsp`` + RowSparseNDArray + + ``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]`` + + RowSparseNDArray is used principally in the definition of gradients for operations + that have sparse gradients (e.g. SparseEmbedding). + """ + + @property + def indices(self): + """The indices array of the SparseNDArray. This is a read-only view of the indices array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indices array. + """ + return self._aux_data(0) + +def _prepare_src_array(src, dtype, default_dtype): + if isinstance(src, NDArray): + dtype = src.dtype if dtype is None else dtype + else: + dtype = default_dtype if dtype is None else dtype + if not isinstance(src, np.ndarray): + try: + src = np.array(src, dtype=dtype) + except: + raise TypeError('values must be array like object') + return src, dtype + +def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): + """Creates a 2D array with compressed sparse row format. + + Parameters + ---------- + values: array_like + An object exposing the array interface, with shape [nnz], where D0 is the number of + non-zero entries. + indptr: array_like + An object exposing the array interface, with shape [D0 + 1]. The first element in indptr + should always be zero. + indices: array_like + An object exposing the array interface, with shape [nnz]. + ctx : Context, optional + Device context (default is the current default context). + dtype : str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indptr_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indptr.dtype`` + if `indptr` is an `NDArray`, `int32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. + + Returns + ------- + CSRNDArray + A `CSRNDArray` with the `csr` storage representation. + """ + storage_type = 'csr' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + values, dtype = _prepare_src_array(values, dtype, mx_real_t) + indptr, indptr_type = _prepare_src_array(indptr, indptr_type, + _STORAGE_AUX_TYPES[storage_type][0]) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][1]) + # verify types + assert('int' in str(indptr_type) or 'long' in str(indptr_type)) + assert('int' in str(indices_type) or 'long' in str(indices_type)) + # verify shapes + aux_shapes = [indptr.shape, indices.shape] + assert(values.ndim == 1) + assert(indptr.ndim == 1) + assert(indices.ndim == 1) + assert(len(shape) == 2) + result = CSRNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indptr_type, indices_type], aux_shapes)) + # assign indptr, indices and values + values_ref = result._data(True) + indptr_ref = result._aux_data(0, True) + indices_ref = result._aux_data(1, True) + values_ref[:] = values + indptr_ref[:] = indptr + indices_ref[:] = indices + return result + +def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None): + """Creates a row sparse array with a set of tensor slices at given indices. + + Parameters + ---------- + values: array_like + An object exposing the array interface, with shape [D0, D1, .. Dn], where D0 is + the number of rows with non-zeros entries. + indices: array_like + An object exposing the array interface, with shape [D0]. + ctx : Context, optional + Device context (default is the current default context). + dtype : str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. + + Returns + ------- + RowSparseNDArray + An `RowSparseNDArray` with the `row_sparse` storage representation. + """ + storage_type = 'row_sparse' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + values, dtype = _prepare_src_array(values, dtype, mx_real_t) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][0]) + # verify types + assert('int' in str(indices_type) or 'long' in str(indices_type)) + # verify shapes + assert(values.ndim == len(shape)) + assert(indices.ndim == 1) + result = RowSparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indices_type], [indices.shape])) + # assign indices and values + values_ref = result._data(True) + indices_ref = result._aux_data(0, True) + values_ref[:] = values + indices_ref[:] = indices + return result + +def to_dense(source): + """ Return a dense array representation of this SparseNDArray. + + Returns + ------- + NDArray + The dense array with default storage + """ + return ndarray.cast_storage(source, storage_type='default') + +def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): + """Return a new array of given shape and type, filled with zeros. + + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array + storage_type: string + The storage type of the empty array, such as 'row_sparse', 'csr', etc + ctx : Context, optional + An optional device context (default is the current default context) + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`) + aux_types: list of numpy.dtype, optional + An optional type for the aux data for SparseNDArray (default values depends + on the storage type) + + Returns + ------- + SparseNDArray + A created array + Examples + -------- + >>> mx.sparse_nd.zeros('csr', (1,2), mx.gpu(0)) + + >>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + if aux_types is None: + if storage_type == 'row_sparse' or storage_type == 'csr': + aux_types = _STORAGE_AUX_TYPES[storage_type] + else: + raise Exception("unknown storage type") + assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type])) + out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types)) + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out) + +def _ndarray_cls(handle, writable=True): + stype = _storage_type(handle) + if stype == 'default': + return NDArray(handle, writable=writable) + elif stype == 'csr': + return CSRNDArray(handle, writable=writable) + elif stype == 'row_sparse': + return RowSparseNDArray(handle, writable=writable) + else: + raise Exception("unknown storage type") + +# pylint: enable=too-many-locals, invalid-name +def _init_ndarray_module(ndarray_class, root_namespace): + """List and add all the ndarray functions to current module.""" + _set_ndarray_class(ndarray_class) + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + op_names = [] + for i in range(size.value): + op_names.append(py_str(plist[i])) + + module_obj = _sys.modules["%s.ndarray" % root_namespace] + module_internal = _sys.modules["%s._ndarray_internal" % root_namespace] + module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] + for name in op_names: + hdl = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = _make_ndarray_function(hdl, name) + if function.__name__.startswith('_contrib_'): + function.__name__ = function.__name__[9:] + function.__module__ = 'mxnet.contrib.ndarray' + setattr(module_contrib, function.__name__, function) + elif function.__name__.startswith('_'): + setattr(module_internal, function.__name__, function) + else: + setattr(module_obj, function.__name__, function) + +_init_ndarray_module(_ndarray_cls, "mxnet") diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 14203e59862d..e752eb541648 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -19,6 +19,8 @@ from .context import Context, cpu from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from .name import NameManager # pylint: disable=unused-import +from .ndarray import _STORAGE_TYPE_ID_TO_STR, _STORAGE_TYPE_STR_TO_ID +from .sparse_ndarray import _ndarray_cls from .executor import Executor from . import _symbol_internal as _internal from .attribute import AttrScope @@ -721,6 +723,89 @@ def list_auxiliary_states(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] + def infer_storage_type(self, *args, **kwargs): + """Infer the storage type of outputs and arguments of given known types of arguments. + + User can either pass in the known types in positional way or keyword argument way. + Tuple of Nones is returned if there is not enough information passed in. + An error will be raised if there is inconsistency found in the known types passed in. + + Parameters + ---------- + *args : + Provide type of arguments in a positional way. + Unknown type can be marked as None + + **kwargs : + Provide keyword arguments of known types. + + Returns + ------- + arg_storage_types : list of numpy.dtype or None + List of types of arguments. + The order is in the same order as list_arguments() + out_storage_types : list of numpy.dtype or None + List of types of outputs. + The order is in the same order as list_outputs() + aux_storage_types : list of numpy.dtype or None + List of types of outputs. + The order is in the same order as list_auxiliary_states() + """ + # pylint: disable=too-many-locals + if len(args) != 0 and len(kwargs) != 0: + raise ValueError('Can only specify known argument \ + types either by positional or kwargs way.') + sdata = [] + if len(args) != 0: + keys = None + for s in args: + if s is not None: + if s not in _STORAGE_TYPE_STR_TO_ID or not isinstance(s, basestring): + raise TypeError('Argument need to be one of '+str(_STORAGE_TYPE_STR_TO_ID)) + sdata.append(_STORAGE_TYPE_STR_TO_ID[s]) + else: + sdata.append(_STORAGE_TYPE_STR_TO_ID['undefined']) + else: + keys = [] + for k, v in kwargs.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + keys.append(c_str(k)) + sdata.append(_STORAGE_TYPE_STR_TO_ID[v]) + arg_storage_type_size = mx_uint() + arg_storage_type_data = ctypes.POINTER(ctypes.c_int)() + out_storage_type_size = mx_uint() + out_storage_type_data = ctypes.POINTER(ctypes.c_int)() + aux_storage_type_size = mx_uint() + aux_storage_type_data = ctypes.POINTER(ctypes.c_int)() + complete = ctypes.c_int() + check_call(_LIB.MXSymbolInferStorageType( + self.handle, + mx_uint(len(sdata)), + c_array(ctypes.c_char_p, keys), + c_array(ctypes.c_int, sdata), + ctypes.byref(arg_storage_type_size), + ctypes.byref(arg_storage_type_data), + ctypes.byref(out_storage_type_size), + ctypes.byref(out_storage_type_data), + ctypes.byref(aux_storage_type_size), + ctypes.byref(aux_storage_type_data), + ctypes.byref(complete))) + if complete.value != 0: + arg_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[arg_storage_type_data[i]] \ + for i in range(arg_storage_type_size.value)] + out_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[out_storage_type_data[i]] \ + for i in range(out_storage_type_size.value)] + aux_storage_types = [ + _STORAGE_TYPE_ID_TO_STR[aux_storage_type_data[i]] \ + for i in range(aux_storage_type_size.value)] + return (arg_storage_types, out_storage_types, aux_storage_types) + else: + return (None, None, None) + # pylint: enable=too-many-locals + + def infer_type(self, *args, **kwargs): """Infers the type of all arguments and all outputs, given the known types for some arguments. @@ -1160,8 +1245,9 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): raise TypeError('Only accept list of NDArrays or dict of str to NDArray') return c_array(NDArrayHandle, arg_handles), arg_arrays - def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, - shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs): + def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=None, + group2ctx=None, shared_arg_names=None, shared_exec=None, + shared_buffer=None, **kwargs): """Bind current symbol to get an executor, allocate all the arguments needed. Allows specifying data types. @@ -1203,6 +1289,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype + storage_type_dict : Dict of str->str + Input storage type dictionary, name->storage_type + group2ctx : Dict of string to mx.Context The dict mapping the `ctx_group` attribute to the context assignment. @@ -1217,7 +1306,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer : Dict of string to `NDArray` The dict mapping argument names to the `NDArray` that can be reused for initializing the current executor. This buffer will be checked for reuse if one argument name - of the current executor is not found in `shared_arg_names`. + of the current executor is not found in `shared_arg_names`. The `NDArray`s are + expected have default storage type. kwargs : Dict of str->shape Input shape dictionary, name->shape @@ -1227,6 +1317,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, executor : mxnet.Executor The generated executor """ + # data types num_provided_arg_types = 0 provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types @@ -1242,6 +1333,22 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names) provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data) + # storage types + num_provided_arg_stypes = 0 + # provided storage type argument names + provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)() + provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types + if storage_type_dict is not None: + provided_arg_stype_names = [] + provided_arg_stype_data = [] + for k, v in storage_type_dict.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + provided_arg_stype_names.append(c_str(k)) + provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v])) + num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names)) + provided_arg_stype_names = c_array(ctypes.c_char_p, provided_arg_stype_names) + provided_arg_stype_data = c_array(ctypes.c_int, provided_arg_stype_data) + provided_arg_shape_data = [] # shape data # argument shape index in sdata, # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg @@ -1315,6 +1422,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer_names = [] shared_buffer_handles = [] for k, v in shared_buffer.items(): + assert(v.storage_type == 'default'), \ + "shared_buffer is expected to only contain NDArrays with default storage" shared_buffer_names.append(c_str(k)) shared_buffer_handles.append(v.handle) shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names) @@ -1354,6 +1463,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, num_provided_arg_types, provided_arg_type_names, provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, mx_uint(len(shared_arg_name_list)), c_array(ctypes.c_char_p, shared_arg_name_list), ctypes.byref(shared_buffer_len), @@ -1383,11 +1495,12 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer[k] = v # create in_args, arg_grads, and aux_states for the current executor - arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)] - grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i])) + arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) \ + for i in range(num_in_args.value)] + grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i])) if arg_grad_handles[i] is not None else None for i in range(num_in_args.value)] - aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) + aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i])) for i in range(num_aux_states.value)] executor = Executor(exe_handle, self, ctx, grad_req, group2ctx) @@ -1638,7 +1751,8 @@ def reshape(self, shape): """ return reshape(self, shape=shape) -def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs): +def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, + init=None, storage_type=None, **kwargs): """Creates a symbolic variable with specified name. Example usage: @@ -1692,6 +1806,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini if not isinstance(init, string_types): init = init.dumps() attr['__init__'] = init + if storage_type is not None: + attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[storage_type]) for k, v in kwargs.items(): if k.startswith('__') and k.endswith('__'): attr[k] = str(v) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 3ab44d0917a1..f9f596694182 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -10,11 +10,13 @@ import os import errno import logging +import scipy.sparse as sp import numpy as np import numpy.testing as npt +import numpy.random as rnd import mxnet as mx from .context import Context -from .ndarray import array +from .ndarray import array, _STORAGE_TYPE_STR_TO_ID from .symbol import Symbol try: import requests @@ -66,6 +68,51 @@ def random_arrays(*shapes): return arrays +def random_sample(population, k): + """Return a k length list of the elements chosen from the population sequence.""" + assert 0 <= k <= len(population) + population_copy = population[:] + np.random.shuffle(population_copy) + return population_copy[0:k] + + +# TODO(haibin) also include types in arguments +def rand_sparse_ndarray(shape, storage_type, density=None): + """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """ + density = rnd.rand() if density is None else density + if storage_type == 'row_sparse': + # TODO(haibin) support high dim sparse ndarray + assert(len(shape) < 3) + prod = np.prod(shape) + num_cols = int(prod / shape[0]) + # sample index + idx_sample = rnd.rand(shape[0]) + indices = np.argwhere(idx_sample < density).flatten() + if indices.shape[0] == 0: + result = mx.sparse_nd.zeros('row_sparse', shape) + return result, (np.array([]), np.array([], dtype='int32')) + # generate random values + val = rnd.rand(indices.shape[0], num_cols) + arr = mx.sparse_nd.row_sparse(val, indices, shape, indices_type=np.int32) + return arr, (val, indices) + elif storage_type == 'csr': + assert(len(shape) == 2) + csr = sp.rand(shape[0], shape[1], density=density, format='csr') + result = mx.sparse_nd.csr(csr.data, csr.indptr, csr.indices, shape) + return result, (csr.indptr, csr.indices, csr.data) + else: + assert(False), "unknown storage type" + +def rand_ndarray(shape, storage_type, density=None): + if storage_type == 'default': + arr = mx.nd.array(random_arrays(shape)) + else: + arr, _ = rand_sparse_ndarray(shape, storage_type, density=density) + return arr + +def rand_shape_2d(): + return (rnd.randint(1, 10), rnd.randint(1, 10)) + def np_reduce(dat, axis, keepdims, numpy_reduce_func): """Compatible reduce for old version of NumPy. @@ -297,7 +344,8 @@ def _parse_location(sym, location, ctx): % (str(set(sym.list_arguments())), str(set(location.keys())))) else: location = {k: v for k, v in zip(sym.list_arguments(), location)} - location = {k: mx.nd.array(v, ctx=ctx) for k, v in location.items()} + location = {k: mx.nd.array(v, ctx=ctx) if isinstance(v, np.ndarray) \ + else v for k, v in location.items()} return location @@ -418,7 +466,8 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2, - atol=None, grad_nodes=None, use_forward_train=True, ctx=None): + atol=None, grad_nodes=None, use_forward_train=True, ctx=None, + grad_stype_dict=None): """Verify an operation by checking backward pass via finite difference method. Based on Theano's `theano.gradient.verify_grad` [1] @@ -435,7 +484,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto - if type is dict of str -> numpy.ndarray maps the name of arguments to the corresponding numpy.ndarray. *In either case, value of all the arguments must be provided.* - aux_states : ist or tuple or dict, optional + aux_states : list or tuple or dict, optional The auxiliary states required when generating the executor for the symbol. numeric_eps : float, optional Delta for the finite difference method that approximates the gradient. @@ -447,6 +496,8 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto Whether to use is_train=True when computing the finite-difference. ctx : Context, optional Check the gradient computation on the specified device. + grad_stype_dict : dict of str->str, optional + Storage type dictionary for gradient ndarrays. References --------- ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py @@ -470,7 +521,7 @@ def random_projection(shape): location_npy = {k:v.asnumpy() for k, v in location.items()} aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) if aux_states is not None: - aux_states_npy = {k:v.asnumpy() for k, v in aux_states.items()} + aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()} else: aux_states_npy = None if grad_nodes is None: @@ -497,6 +548,11 @@ def random_projection(shape): + [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))]) args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + if grad_stype_dict is not None: + assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" + for k, v in grad_stype_dict.items(): + if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default': + args_grad[k] = mx.nd.cast_storage(args_grad[k], storage_type=v) executor = out.bind(ctx, grad_req=grad_req, args=location, args_grad=args_grad, aux_states=aux_states) @@ -588,8 +644,8 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, g[:] = 0 executor.forward(is_train=False) - outputs = [x.asnumpy() for x in executor.outputs] + outputs = [x.asnumpy() for x in executor.outputs] for output_name, expect, output in zip(sym.list_outputs(), expected, outputs): assert_almost_equal(expect, output, rtol, atol, ("EXPECTED_%s"%output_name, "FORWARD_%s"%output_name)) @@ -657,14 +713,29 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()} - args_grad_data = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + # args_grad_data should be casted to storage type if hinted + # TODO(haibin) this is a temporary solution for testing. remove later + attrs = sym.attr_dict() + args_grad_data = {} + for k, v in args_grad_npy.items(): + attr = attrs.get(k, {}) + grad_stype = attr.get('grad_stype_hint', None) + nd = mx.nd.array(v, ctx=ctx) + if grad_stype is not None: + out = mx.nd.cast_storage(nd, storage_type=grad_stype) + args_grad_data[k] = out + else: + args_grad_data[k] = nd + if isinstance(grad_req, str): grad_req = {k:grad_req for k in sym.list_arguments()} elif isinstance(grad_req, (list, tuple)): grad_req = {k:v for k, v in zip(sym.list_arguments(), grad_req)} - executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) + executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, + aux_states=aux_states, grad_req=grad_req) executor.forward(is_train=True) + if isinstance(out_grads, (tuple, list)): out_grads = [mx.nd.array(v, ctx=ctx) for v in out_grads] elif isinstance(out_grads, (dict)): diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 9d60c8615027..91ac04021268 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -154,6 +154,39 @@ int MXNDArrayCreateEx(const mx_uint *shape, API_END(); } +int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out) { + API_BEGIN(); + std::vector aux_types; + std::vector aux_shapes; + auto shape_start = aux_shape; + for (size_t i = 0; i < num_aux; i++) { + // types + aux_types.push_back(aux_type[i]); + // shapes + aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]); + shape_start += aux_ndims[i]; + } + *out = new NDArray( + NDArrayStorageType(storage_type), + TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0, + dtype, aux_types, aux_shapes); + API_END(); +} + + int MXNDArrayLoadFromRawBytes(const void *buf, size_t size, NDArrayHandle *out) { @@ -287,6 +320,16 @@ int MXNDArraySlice(NDArrayHandle handle, API_END_HANDLE_ERROR(delete ptr); } +int MXNDArraySliceEx(NDArrayHandle handle, + mx_uint slice_begin, + mx_uint slice_end, + NDArrayHandle out) { + NDArray *ptr = static_cast(out); + API_BEGIN(); + static_cast(handle)->SliceEx(slice_begin, slice_end, ptr); + API_END(); +} + int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out) { @@ -333,6 +376,18 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, API_END_HANDLE_ERROR(delete ptr); } +int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + if (!arr->is_none()) { + *out_storage_type = arr->storage_type(); + } else { + *out_storage_type = kUndefinedStorage; + } + API_END(); +} + int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata) { @@ -382,6 +437,32 @@ int MXNDArrayGetDType(NDArrayHandle handle, API_END(); } +int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out_type = arr->aux_type(i); + API_END(); +} + +int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->aux_ndarray(i)); + API_END(); +} + +int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->data_ndarray()); + API_END(); +} + int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id) { diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index d8857f80635d..f2cad238a71b 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -58,6 +58,8 @@ struct MXAPIThreadLocalEntry { std::vector arg_shapes, out_shapes, aux_shapes; /*! \brief result holder for returning type flags */ std::vector arg_types, out_types, aux_types; + /*! \brief result holder for returning storage types */ + std::vector arg_storage_types, out_storage_types, aux_storage_types; /*! \brief result holder for returning shape dimensions */ std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; /*! \brief result holder for returning shape pointer */ diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ca49402ecf7e..a335209cd9fa 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -173,6 +173,9 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes * \param provided_arg_dtype_names argument name list of provided dtypes * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec * \param shared_arg_name_list parameter name list passed from _bind_ith_exec * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec @@ -205,6 +208,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, const mx_uint num_provided_arg_dtypes, const char** provided_arg_dtype_names, const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, const mx_uint num_shared_arg_names, const char** shared_arg_name_list, int* shared_buffer_len, @@ -255,6 +261,23 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, } } + // setup arg_stype_map + std::unordered_map arg_stype_map; + if (nullptr == provided_arg_stypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__storage_type__")) { + arg_stype_map[arg_name] = kDefaultStorage; + } + } + } else { // use user input type_dict + // create stype map for in_args and aux_states + arg_stype_map.reserve(num_provided_arg_stypes); + for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) { + arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i]; + } + } + // create default ctx Context ctx = Context::Create(static_cast(dev_type), dev_id); // create ctx map @@ -395,9 +418,10 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector aux_state_vec; *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, - aux_state_ctx_vec, arg_shape_map, arg_dtype_map, grad_req_type_vec, - shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec, - use_shared_buffer? &shared_buffer_map : nullptr, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, reinterpret_cast(shared_exec_handle)); // copy ndarray ptrs to ret->handles so that front end diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 0be1d3574dd9..8d190597ab0b 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * \file c_api_symbolic.cc + * \file c_api_ndarray.cc * \brief C API of mxnet */ @@ -16,6 +16,8 @@ #include "../common/utils.h" #include "../ndarray/autograd.h" +#define IMPERATIVE_EXEC_DEBUG 0 + using namespace mxnet; using mxnet::autograd::AutogradRuntime; @@ -122,16 +124,18 @@ void SetContext(Context* p_ctx, ctx = Context::CPU(); } } - +// Set the shape, dtype and storage type void SetShapeType(const nnvm::Op* op, const nnvm::NodeAttrs& attrs, const Context& ctx, const std::vector& ndinputs, const int& infered_num_outputs, - std::vector* p_ndoutputs) { + std::vector* p_ndoutputs, + int* dispatch_stype) { std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); // infer shape std::vector& in_shapes = ret->arg_shapes; @@ -167,9 +171,41 @@ void SetShapeType(const nnvm::Op* op, CHECK(infertype[op](attrs, &in_types, &out_types)); CHECK_EQ(out_types.size(), static_cast(infered_num_outputs)); + // infer storage type + auto& in_storage_types = ret->arg_storage_types; + auto& out_storage_types = ret->out_storage_types; + in_storage_types.clear(); + out_storage_types.clear(); + + for (auto& i : ndinputs) { + in_storage_types.push_back(i.storage_type()); + } + for (auto& i : ndoutputs) { + out_storage_types.push_back(i.storage_type()); + } + if (inferstorage.count(op)) { + CHECK(inferstorage[op](attrs, &in_storage_types, &out_storage_types)); + CHECK_EQ(out_storage_types.size(), static_cast(infered_num_outputs)); + } else { +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FInferStorageType not present."; +#endif + } + + bool contains_non_default = common::ContainsNonDefaultStorage(in_storage_types); + contains_non_default |= common::ContainsNonDefaultStorage(out_storage_types); + int kNonDefaultStorage = -2; + *dispatch_stype = contains_non_default ? kNonDefaultStorage : kDefaultStorage; + for (int i = 0; i < infered_num_outputs; ++i) { + NDArrayStorageType storage_type = static_cast(out_storage_types[i]); if (ndoutputs[i].is_none()) { - ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + // If failed to infer the storage type, assume the output storage is dense + if (storage_type == kDefaultStorage || out_storage_types[i] == kUndefinedStorage) { + ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + } else { + ndoutputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]); + } } else { CHECK_EQ(ndoutputs[i].shape(), out_shapes[i]) << i << "th output has invalid shape. " @@ -216,23 +252,20 @@ void SetDependency(std::vector *p_read_vars, } CHECK_LE(ntmp, 1) << "Only support 1 temp space request"; } - - for (auto& i : ndinputs) { - read_vars.push_back(i.var()); - } - for (auto& i : ndoutputs) { - write_vars.push_back(i.var()); - } + for (auto& i : ndinputs) read_vars.emplace_back(i.var()); + for (auto& i : ndoutputs) write_vars.emplace_back(i.var()); if (mutate.count(op)) { auxidx = mutate[op](attrs); std::sort(auxidx.begin(), auxidx.end()); - for (auto & i : auxidx) { - write_vars.push_back(ndinputs[i].var()); + for (auto& i : auxidx) { + auto var = ndinputs[i].var(); + write_vars.push_back(var); } } Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars); } + void PushFCompute(const FCompute& fn, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, @@ -242,23 +275,61 @@ void PushFCompute(const FCompute& fn, const std::vector& requested, const std::vector& ndinputs, const std::vector& ndoutputs) { + using namespace common; bool is_train = AutogradRuntime::Get()->IsTraining(); Engine::Get()->PushAsync( [ctx, attrs, fn, ndinputs, ndoutputs, requested, is_train]( RunContext rctx, engine::CallbackOnComplete on_complete) { std::vector input_blobs, output_blobs; - for (auto& i : ndinputs) { - input_blobs.push_back(i.data()); - } - for (auto& i : ndoutputs) { - output_blobs.push_back(i.data()); - } + std::vector temp_in; + std::vector temp_out; OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; - std::vector req(output_blobs.size(), kWriteTo); - fn(attrs, opctx, input_blobs, req, output_blobs); + if (ctx.dev_mask() == gpu::kDevMask) { +#if MXNET_USE_CUDA + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fn(attrs, opctx, input_blobs, req, output_blobs); + // cast to original storage type, if necessary + CastNonDefaultStorage(ndoutputs, temp_out, opctx); + rctx.get_stream()->Wait(); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fn(attrs, opctx, input_blobs, req, output_blobs); + CastNonDefaultStorage(ndoutputs, temp_out, opctx); + } + on_complete(); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); +} + +void PushFComputeEx(const FComputeEx& fn, + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& ndinputs, + const std::vector& ndoutputs) { + Engine::Get()->PushAsync( + [ctx, attrs, fn, ndinputs, ndoutputs, requested]( + RunContext rctx, + engine::CallbackOnComplete on_complete) { + std::vector input_blobs, output_blobs; + OpContext opctx{false, rctx, + engine::CallbackOnComplete(), + requested}; + std::vector req(ndoutputs.size(), kWriteTo); + fn(attrs, opctx, ndinputs, req, ndoutputs); if (ctx.dev_mask() == gpu::kDevMask) { rctx.get_stream()->Wait(); } @@ -327,8 +398,6 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, NDArrayHandle *inputs, int *num_outputs, NDArrayHandle **outputs) { - static auto& fcpu = nnvm::Op::GetAttr("FCompute"); - static auto& fgpu = nnvm::Op::GetAttr("FCompute"); static auto& ndfunc = nnvm::Op::GetAttr("FNDArrayFunction"); static auto& createop = nnvm::Op::GetAttr("FCreateLayerOp"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); @@ -337,20 +406,23 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, int infered_num_outputs; int num_visible_outputs; - SetNumOutputs(op, attrs, num_inputs, - &infered_num_outputs, &num_visible_outputs); + SetNumOutputs(op, attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); std::vector ndinputs, ndoutputs; SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs, - num_outputs, infered_num_outputs, num_visible_outputs, outarray); + num_outputs, infered_num_outputs, num_visible_outputs, outarray); if (ndfunc.count(op)) { ndfunc[op](attrs, ndinputs, &ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "NDArray function executed."; +#endif } else { // TODO(piiswrong): infer ctx Context ctx; + int storage_type; SetContext(&ctx, attrs, num_inputs, ndinputs, infered_num_outputs, ndoutputs); - SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs); + SetShapeType(op, attrs, ctx, ndinputs, infered_num_outputs, &ndoutputs, &storage_type); std::vector read_vars, write_vars; std::vector requested; @@ -358,20 +430,24 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, SetDependency(&read_vars, &write_vars, &requested, &auxidx, op, attrs, ctx, ndinputs, ndoutputs); - FCompute fn; - if (ctx.dev_mask() == cpu::kDevMask && fcpu.count(op)) { - fn = fcpu[op]; - } else if (ctx.dev_mask() == gpu::kDevMask && fgpu.count(op)) { - fn = fgpu[op]; - } - - if (fn) { + FCompute fn = common::GetFCompute(op, ctx); + FComputeEx fcomp_ex = common::GetFComputeEx(op, ctx, storage_type); + if (fcomp_ex) { + PushFComputeEx(fcomp_ex, op, attrs, ctx, read_vars, write_vars, requested, + ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FComputeEx executed."; +#endif + } else if (fn) { if (AutogradRuntime::Get()->IsTraining()) { AutogradRuntime::Get()->RecordImperativeFCompute(op, attrs, &ndinputs, &ndoutputs); } PushFCompute(fn, op, attrs, ctx, read_vars, write_vars, requested, ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "FCompute executed."; +#endif } else if (createop.count(op)) { std::shared_ptr opr( createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types)); @@ -381,11 +457,14 @@ void ImperativeInvokeImpl(const nnvm::NodeAttrs& attrs, } PushOperator(opr, op, attrs, ctx, read_vars, write_vars, requested, auxidx, ndinputs, ndoutputs); +#if IMPERATIVE_EXEC_DEBUG + LOG(INFO) << "CreateOp executed."; +#endif } else { LOG(FATAL) << "Operator " << op->name << " cannot be run; requires at least one of" - << " FCompute, NDArrayFunction, FCreateOperator be registered"; + << " FCompute, FComputeEx NDArrayFunction, FCreateOperator be registered"; } } diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index cad9e604df60..f4737fa8b3e2 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -512,6 +512,58 @@ int MXSymbolInferShapePartial(SymbolHandle sym, &succ); } +// TODO(haibin) refactor with infer_type +int MXSymbolInferStorageType(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_storage_type_data, + mx_uint *in_storage_type_size, + const int **in_storage_type_data, + mx_uint *out_storage_type_size, + const int **out_storage_type_data, + mx_uint *aux_storage_type_size, + const int **aux_storage_type_data, + int *complete) { + nnvm::Symbol *s = static_cast(sym); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Graph g = Symbol2Graph(*s); + nnvm::StorageTypeVector arg_storage_types(g.indexed_graph().input_nodes().size(), + kUndefinedStorage); + if (keys == nullptr && num_args != 0) { + std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); + CHECK_LE(num_args, read_only_args.size()); + for (mx_uint i = 0; i < num_args; ++i) { + arg_storage_types[read_only_args[i]] = arg_storage_type_data[i]; + } + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = arg_storage_type_data[i]; + } + mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_storage_types, "InferStorageType"); + } + + g = nnvm::pass::InferStorageType(std::move(g), arg_storage_types, "__storage_type__"); + // copy back + CopyAttr(g.indexed_graph(), g.GetAttr("storage_type"), + &(ret->arg_storage_types), &(ret->out_storage_types), &(ret->aux_storage_types)); + + *in_storage_type_size = static_cast(ret->arg_storage_types.size()); + *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); + *out_storage_type_size = static_cast(ret->out_storage_types.size()); + *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); + *in_storage_type_size = static_cast(ret->arg_storage_types.size()); + *in_storage_type_data = dmlc::BeginPtr(ret->arg_storage_types); + *out_storage_type_size = static_cast(ret->out_storage_types.size()); + *out_storage_type_data = dmlc::BeginPtr(ret->out_storage_types); + *aux_storage_type_size = static_cast(ret->aux_storage_types.size()); + *aux_storage_type_data = dmlc::BeginPtr(ret->aux_storage_types); + *complete = (g.GetAttr("storage_type_num_unknown_nodes") == 0); + API_END(); +} + + int MXSymbolInferType(SymbolHandle sym, mx_uint num_args, const char** keys, diff --git a/src/common/utils.h b/src/common/utils.h index 789b4d14b9f2..5b80c4dcaa29 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -18,11 +18,142 @@ #include #include +#include +#include +#include namespace mxnet { +// forward declaration +namespace op { +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +} + namespace common { #if DMLC_USE_CXX11 +/* + * \brief Get the corresponding tensor blobs from default storage NDArrays. + * If any NDArray is of non-default storage, it is casted to default storage and + * the temporary NDArrays are stored in `temps`. When storage_fallback is false, + * and `MXNET_EXEC_STORAGE_FALLBACK` == 0, storage fallback is disallowed. + * \return true if any input is casted + */ +template +inline bool GetDefaultBlobs(const std::vector& nds, + std::vector *blobs, + std::vector *temps, + const OpContext& ctx, + bool storage_fallback = false) { + bool casted = false; + if (storage_fallback == false) { + storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true); + } + for (auto& nd : nds) { + if (nd.storage_type() != kDefaultStorage) { + if (storage_fallback == false) { + LOG(FATAL) << "Storage type conversion detected during execution. " + << "You are probably executing an operator which " + << "doesn't support NDArray inputs with non-default storage."; + } + NDArray temp(nd.shape(), nd.ctx(), false); + op::CastStorageComputeImpl(ctx.get_stream(), nd, temp); + temps->push_back(temp); + blobs->push_back(temp.data()); + casted = true; + } else { + blobs->push_back(nd.data()); + } + } + return casted; +} + +template +inline void GetOutputBlobs(const std::vector& nds, + std::vector *blobs) { + for (auto& nd : nds) { + blobs->push_back(nd.data()); + } +} + +/* + * \brief Cast the NDArrays in `src` according to the storage types of the NDArrays + * in `dst`. The ones with default storage in `dst` are ignored. + * When storage_fallback is false, and `MXNET_EXEC_STORAGE_FALLBACK` == 0, + * storage fallback is disallowed. + */ +template +inline void CastNonDefaultStorage(const std::vector& dst, + const std::vector& src, + const OpContext& ctx, + bool storage_fallback = false) { + CHECK_GE(dst.size(), src.size()); + if (src.size() == 0) return; + if (storage_fallback == false) { + storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true); + } + size_t src_idx = 0; + for (size_t i = 0; i < dst.size(); i++) { + auto stype = dst[i].storage_type(); + if (stype != kDefaultStorage) { + if (storage_fallback == false) { + LOG(FATAL) << "Storage type conversion detected during execution. " + << "You are probably executing an operator which " + << "doesn't support NDArray inputs with non-default storage."; + } + op::CastStorageComputeImpl(ctx.get_stream(), src[src_idx++], dst[i]); + } + } + CHECK_EQ(src_idx, src.size()) << "Not all src NDArrays are casted"; +} + +// Check if any storage type is not default storage +inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) { + for (auto& i : vstorage) { + if (i != kUndefinedStorage && i != kDefaultStorage) return true; + } + return false; +} + +inline bool ContainsDefaultStorage(const std::vector& ndarrays) { + for (auto &nd : ndarrays) { + if (nd.storage_type() == kDefaultStorage) { + return true; + } + } + return false; +} + +inline FCompute GetFCompute(const Op* op, Context ctx) { + static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); + static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); + if (ctx.dev_mask() == cpu::kDevMask) { + return fcompute_cpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fcompute_gpu.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + +inline FComputeEx GetFComputeEx(const Op* op, Context ctx, int stype) { + static auto& fcpu = nnvm::Op::GetAttr(FCOMP_EX_CPU); + static auto& fgpu = nnvm::Op::GetAttr(FCOMP_EX_GPU); + if (stype == kDefaultStorage) return nullptr; + if (ctx.dev_mask() == cpu::kDevMask) { + return fcpu.get(op, nullptr); + } else if (ctx.dev_mask() == gpu::kDevMask) { + return fgpu.get(op, nullptr); + } + LOG(FATAL) << "Unknown device mask"; + return nullptr; +} + + // heuristic to dermine number of threads per GPU inline int GetNumThreadPerGPU() { // This is resource efficient option. diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 16b55adc15e8..0d718df41c9e 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -8,11 +8,15 @@ #include #include #include "./exec_pass.h" +#include "../common/utils.h" #if MXNET_USE_MKL2017 == 1 #include #include "../operator/mkl/mkl_memory-inl.h" #include "../operator/mkl/mkl_util-inl.h" #endif + +#define EXEC_ATTACH_OP_DEBUG 0 + namespace mxnet { namespace op { @@ -24,9 +28,33 @@ namespace exec { // forward executor class ForwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { + using namespace common; op_ctx.run_ctx = rctx; - op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); + + // If any input ndarray contains non-default storage, + // we need to cast it to default storage and setup the tblobs again. For example, + // if any of the input ndarray changes, the updated value won't be reflected in the temporary + // ndarray with default storage. + in_data_.clear(); out_data_.clear(); aux_data_.clear(); + temp_in_.clear(); temp_out_.clear(); temp_aux_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + GetDefaultBlobs(in_array_, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(aux_array_, &aux_data_, &temp_aux_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(in_array_, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(aux_array_, &aux_data_, &temp_aux_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + op_->Forward(op_ctx, in_data_, req, out_data_, aux_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); + } #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); @@ -35,18 +63,14 @@ class ForwardOpExecutor : public OpExecutor { } void Setup() override { - in_data_.clear(); aux_data_.clear(); + // We need to tell whether in NDArray is input or aux for (size_t i = 0; i < in_array.size(); ++i) { if (!std::binary_search(aux_index_.begin(), aux_index_.end(), i)) { - in_data_.push_back(in_array[i].data()); + in_array_.emplace_back(in_array[i]); } else { - aux_data_.push_back(in_array[i].data()); + aux_array_.emplace_back(in_array[i]); } } - out_data_.resize(out_array.size()); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), [](const NDArray& nd) { - return nd.data(); - }); } Operator::ExecType exec_type() const override { return op_->exec_type(); @@ -62,12 +86,14 @@ class ForwardOpExecutor : public OpExecutor { std::shared_ptr op_; std::vector aux_index_; std::vector in_data_, out_data_, aux_data_; + std::vector in_array_, aux_array_, temp_in_, temp_aux_, temp_out_; }; // backward executor class BackwardOpExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { + // TODO(haibin) support storage fallback for BackwardOpExecutor op_ctx.run_ctx = rctx; op_->Backward(op_ctx, out_grad_, in_data_, out_data_, req, in_grad_, aux_data_); @@ -135,23 +161,36 @@ class BackwardOpExecutor : public OpExecutor { // fcompute executor executor class FComputeExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { + using namespace common; op_ctx.run_ctx = rctx; - fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + // setup blobs + // TODO(haibin) avoid repeating this if all inputs are already in default-storage. + { + in_data_.clear(); out_data_.clear(); + temp_in_.clear(); temp_out_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); + } + } #if MKL_EXPERIMENTAL == 1 mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif } - void Setup() override { - in_data_.resize(in_array.size()); - out_data_.resize(out_array.size()); - auto get_blob = [](const NDArray& nd) { - return nd.data(); - }; - std::transform(in_array.begin(), in_array.end(), in_data_.begin(), get_blob); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), get_blob); - } + void Setup() override {} Operator::ExecType exec_type() const override { return Operator::kSync; } @@ -159,28 +198,41 @@ class FComputeExecutor : public OpExecutor { : fcompute_(fcompute), attrs_(attrs) { } - static FCompute GetFCompute(const Op* op, Context ctx) { - static auto& fcompute_cpu = nnvm::Op::GetAttr("FCompute"); - static auto& fcompute_gpu = nnvm::Op::GetAttr("FCompute"); - if (ctx.dev_mask() == cpu::kDevMask) { - return fcompute_cpu.get(op, nullptr); - } else if (ctx.dev_mask() == gpu::kDevMask) { - return fcompute_gpu.get(op, nullptr); - } else { - LOG(FATAL) << "Unknown device mask"; - return nullptr; - } - } - private: FCompute fcompute_; NodeAttrs attrs_; std::vector in_data_, out_data_; + std::vector temp_in_, temp_out_; +}; + +// fcomputend executor +class FComputeExExecutor : public OpExecutor { + public: + void Run(RunContext rctx, bool is_gpu) override { + op_ctx.run_ctx = rctx; + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + } + void Setup() override { + in_data_ = in_array; + out_data_ = out_array; + } + Operator::ExecType exec_type() const override { + return Operator::kSync; + } + explicit FComputeExExecutor(FComputeEx fcompute, const NodeAttrs& attrs) + : fcompute_(fcompute), attrs_(attrs) { + } + + private: + FComputeEx fcompute_; + NodeAttrs attrs_; + std::vector in_data_, out_data_; }; // pass to attach operator executors Graph AttachOpExecs(Graph g) { using nnvm::DTypeVector; + using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::FMutateInputs; @@ -193,6 +245,7 @@ Graph AttachOpExecs(Graph g) { const auto& vctx = g.GetAttr("context"); const auto& saved_opr = g.GetAttr< std::unordered_map>>("saved_opr"); + const auto& dispatch_stypes = g.GetAttr("dispatch_stypes"); // get the graph const auto& idx = g.indexed_graph(); @@ -206,7 +259,12 @@ Graph AttachOpExecs(Graph g) { if (fmutate_inputs.count(inode.source->op())) { mutate_index = fmutate_inputs[inode.source->op()](inode.source->attrs); } - FCompute fcompute = FComputeExecutor::GetFCompute(inode.source->op(), vctx[i]); + FCompute fcompute = common::GetFCompute(inode.source->op(), vctx[i]); + FComputeEx fcompute_ex = + common::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stypes[i]); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "dispatch storage type = " << dispatch_stypes[i]; +#endif if (fcreate_layer_op.count(inode.source->op())) { std::vector ishape; std::vector itype; @@ -222,19 +280,33 @@ Graph AttachOpExecs(Graph g) { inode.source->attrs, vctx[i], ishape, itype)); } ret[i] = std::make_shared(opr, mutate_index); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "ForwardOp for op " << inode.source->op()->name; +#endif } else if (is_layer_backward.get(inode.source->op(), false)) { CHECK_GE(inode.control_deps.size(), 1); uint32_t fwd_id = inode.control_deps[0]; CHECK(vctx[fwd_id] == vctx[i]); CHECK(ret[fwd_id] != nullptr); + CHECK_EQ(dispatch_stypes[i], kDefaultStorage) + << "BackwardOp doesn't handle non-default storage yet"; ret[i] = std::make_shared( dynamic_cast(ret[fwd_id].get())->op_, mxnet::op::OpPropGetOpProperty(inode.source->attrs), mutate_index); +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "BackwardOp for op " << inode.source->op()->name; +#endif + } else if (fcompute_ex != nullptr) { +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "FComputeEx for op " << inode.source->op()->name; +#endif + ret[i] = std::make_shared(fcompute_ex, inode.source->attrs); } else if (fcompute != nullptr) { +#if EXEC_ATTACH_OP_DEBUG + LOG(INFO) << "FCompute for op " << inode.source->op()->name; +#endif ret[i] = std::make_shared(fcompute, inode.source->attrs); - } else { - LOG(INFO) << "FCompute not registered " << inode.source->op()->name; } } g.attrs["op_execs"] = std::make_shared(ret); diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 8df6a3c5d3bb..20535be320d9 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -19,6 +19,12 @@ namespace exec { /*! \brief reuse graph definition */ using nnvm::Graph; +const int kBadStorageID = -1; +const int kExternalStorageID = -2; +const int kDynamicStorageID = -3; + +const int kNonDefaultStorage = -2; + /*! * \brief executor to execute an operator * This is a graph executor dependent interface @@ -26,7 +32,7 @@ using nnvm::Graph; */ class OpExecutor { public: - /*! \brief input arrays */ + /*! \brief input data arrays, which may be either input or aux */ std::vector in_array; /*! \brief output data arrays */ std::vector out_array; @@ -47,7 +53,7 @@ class OpExecutor { * This function call do not synchronize the stream. * \param rctx The runtime context passed in by environment. */ - virtual void Run(RunContext rctx) = 0; + virtual void Run(RunContext rctx, bool is_gpu) = 0; /*! \return the execution type */ virtual Operator::ExecType exec_type() const = 0; }; diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index d60c5e46e52c..de8411a7be95 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -12,6 +12,7 @@ #include "./exec_pass.h" #include "./graph_executor.h" #include "../engine/profiler.h" +#include "../common/utils.h" namespace mxnet { namespace exec { @@ -29,6 +30,30 @@ GraphExecutor::~GraphExecutor() { } } +inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. + vec->emplace_back(stype, shape, ctx, true, dtype); + } +} void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, num_forward_nodes_); } @@ -442,21 +467,25 @@ void GraphExecutor::Init(nnvm::Symbol symbol, data_entry_.resize(idx.num_node_entries()); nnvm::ShapeVector arg_shapes; nnvm::DTypeVector arg_dtypes; + nnvm::StorageTypeVector arg_stypes; for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& arg_name = idx[nid].source->attrs.name; + size_t eid = idx.entry_id(nid, 0); if (mutable_nodes.count(nid)) { CHECK_LT(aux_top, aux_states.size()); - data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top]; + data_entry_[eid] = aux_states[aux_top]; arg_shapes.push_back(aux_states[aux_top].shape()); arg_dtypes.push_back(aux_states[aux_top].dtype()); + arg_stypes.push_back(aux_states[aux_top].storage_type()); aux_state_map_.emplace(arg_name, aux_states[aux_top]); ++aux_top; } else { CHECK_LT(arg_top, in_args.size()); - data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top]; + data_entry_[eid] = in_args[arg_top]; arg_shapes.push_back(in_args[arg_top].shape()); arg_dtypes.push_back(in_args[arg_top].dtype()); + arg_stypes.push_back(in_args[arg_top].storage_type()); in_arg_map_.emplace(arg_name, in_args[arg_top]); if (kNullOp != grad_req_types[arg_top]) { grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]); @@ -464,6 +493,10 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } ++arg_top; } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << " as stype " + << data_entry_[eid].storage_type() << " (input)"; +#endif } // expand arg_shapes and arg_dtypes to contain backward inputs @@ -480,6 +513,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol, HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } + // TODO(haibin) better error message for infer_storage + g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); // Initialize the rest attributes of the graph. // This function can be called by regular bind @@ -496,6 +531,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -513,22 +549,37 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const uint32_t eid = idx.entry_id(nid, 0); const TShape& inferred_shape = inferred_shapes[eid]; const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; const std::string& arg_name = idx[nid].source->attrs.name; if (mutable_nodes.count(nid)) { // aux_states - aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], false, inferred_dtype); - aux_state_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); ++aux_top; +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign aux entry\t" << eid << "\t as stype " << inferred_stype; +#endif } else { // in_args - in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); - in_arg_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); data_entry_[eid] = in_arg_vec->back(); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << "\tas stype " << inferred_stype; +#endif + // Get the storage type for grad if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { - arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], false, inferred_dtype); - arg_grad_vec->back() = 0; + // Init based on storage type + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas stype " << grad_stype; +#endif grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); } @@ -540,33 +591,40 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, /*! * \brief If the requested ndarray's shape size is less than - * the corresponding shared_data_array's shape size, reuse - * the memory allocation; otherwise, create a zero ndarray. + * the corresponding shared_data_array's shape size and the + * storage type is default storage, reuse the memory allocation + * in shared_buffer; otherwise, create a zero ndarray. */ NDArray ReshapeOrCreate(const std::string& name, const TShape& dest_arg_shape, const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, const Context& ctx, std::unordered_map* shared_buffer) { + if (dest_arg_dtype != kDefaultStorage) { + return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + } auto it = shared_buffer->find(name); if (it != shared_buffer->end()) { if (it->second.shape().Size() >= dest_arg_shape.Size()) { // memory can be reused CHECK_EQ(it->second.dtype(), dest_arg_dtype) << "Requested arg array's dtype does not match the reusable ndarray"; + CHECK_EQ(it->second.storage_type(), kDefaultStorage) + << "shared_buffer should only contain NDArrays with default storage type."; return it->second.Reshape(dest_arg_shape); } else { LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape << ", which is larger than already allocated shape " << it->second.shape() << ". Need to re-allocate. Consider putting default bucket key to be " << "the bucket taking the largest input for better memory sharing."; - it->second = NDArray(dest_arg_shape, ctx, false, dest_arg_dtype); - it->second = 0; + // the NDArrays in shared_buffer are guaranteed to be of default storage + it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); return it->second; } // arg_array.shape().Size() >= arg_shape.Size() } else { - auto p = shared_buffer->emplace(name, NDArray(dest_arg_shape, ctx, false, dest_arg_dtype)); - p.first->second = 0; - return p.first->second; + auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + shared_buffer->emplace(name, ret); + return ret; } // if (it != shared_buffer->end()) } @@ -579,6 +637,7 @@ NDArray ReshapeOrCreate(const std::string& name, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -598,9 +657,12 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const uint32_t eid = idx.entry_id(nid, 0); const TShape& inferred_shape = inferred_shapes[eid]; const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; const std::string& arg_name = idx[nid].source->attrs.name; - if (mutable_nodes.count(nid)) { // aux_states - if (nullptr != shared_exec) { + // aux_states + if (mutable_nodes.count(nid)) { + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec->aux_state_map().at(arg_name).storage_type() == kDefaultStorage) { const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); CHECK_EQ(inferred_shape, aux_nd.shape()) << "Inferred shape does not match shared_exec.aux_array's shape." @@ -614,16 +676,18 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, << arg_name << " for the current executor"; aux_state_vec->emplace_back(aux_nd); } else { - aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], - false, inferred_dtype); - aux_state_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); } // if (has_shared_exec) data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); ++aux_top; - } else { // in_args + } else { // in_args and grad for in_args if (shared_arg_names.count(arg_name)) { // model parameter - if (nullptr != shared_exec) { + // model parameter + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec->in_arg_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); CHECK_EQ(inferred_shape, in_arg_nd.shape()) << "Inferred shape does not match shared_exec.arg_array's shape" @@ -636,33 +700,43 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, " be resued for creating NDArray of the argument" << arg_name << " for the current executor"; in_arg_vec->emplace_back(in_arg_nd); - if (kNullOp == grad_req_types[arg_top]) { - arg_grad_vec->emplace_back(); - } else { + } else { + // doesn't have shared_exec, or non-default storage + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); + } + // gradient for model parameter + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + if (nullptr != shared_exec && grad_stype == kDefaultStorage && + shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); - grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); - } // if (kNullOp == grad_req_types[arg_top]) - } else { // !has shared_exec - in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); - in_arg_vec->back() = 0; - if (kNullOp == grad_req_types[arg_top]) { - arg_grad_vec->emplace_back(); } else { - arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], - false, inferred_dtype); - arg_grad_vec->back() = 0; - grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); - } // if (kNullOp == grad_req_types[arg_top]) - } // if (has_shared_exec) + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); + } + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } } else { // !shared_arg_names.count(arg_name) + // model parameter in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, - in_arg_ctxes[arg_top], shared_buffer)); + inferred_stype, in_arg_ctxes[arg_top], + shared_buffer)); + // gradient for model parameter if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape, - inferred_dtype, arg_grad_ctxes[arg_top], - shared_buffer)); + inferred_dtype, grad_stype, + arg_grad_ctxes[arg_top], shared_buffer)); grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); } // if (kNullOp == grad_req_types[arg_top]) } // if (shared_arg_names.count(arg_name)) @@ -685,14 +759,35 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { const auto& idx = g.indexed_graph(); + // dispatch based on stype per operator + const auto& vstorage_type = g.GetAttr("storage_type"); + nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); + for (size_t nid = 0; nid < idx.num_nodes(); nid++) { + const auto& inode = idx[nid]; + auto num_outputs = inode.source->num_outputs(); + auto num_inputs = inode.inputs.size(); + nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); + for (size_t i = 0; i < num_inputs; i++) { + auto e = inode.inputs[i]; + vs[i] = vstorage_type[idx.entry_id(e)]; + CHECK_NE(vs[i], kUndefinedStorage); + } + for (uint32_t i = 0; i < num_outputs; ++i) { + uint32_t eid = idx.entry_id(nid, i); + vs[i + num_inputs] = vstorage_type[eid]; + } + bool contains_non_default = common::ContainsNonDefaultStorage(vs); + dispatch_stypes[nid] = contains_non_default ? kNonDefaultStorage : kDefaultStorage; + } + g.attrs["dispatch_stypes"] = std::make_shared(std::move(dispatch_stypes)); + + // data entries for output gradients for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second; } { // memory allocator - const int kBadStorageID = -1; - const int kExternalStorageID = -2; nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; @@ -702,6 +797,9 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, data_entry_[eid] = kv.second; arg_storage_id[eid] = kExternalStorageID; } + for (size_t i = 0; i < idx.num_node_entries(); i++) { + if (vstorage_type[i] != kDefaultStorage) arg_storage_id[i] = kDynamicStorageID; + } g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); g = nnvm::ApplyPass(g, "PlanMemory"); } @@ -759,6 +857,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, @@ -778,6 +877,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::IndexedGraph& idx = g.indexed_graph(); nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + nnvm::DTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; @@ -789,6 +889,10 @@ void GraphExecutor::Init(nnvm::Symbol symbol, if (arg_dtype_map.end() != it2) { arg_dtypes[i] = it2->second; } + auto it3 = arg_stype_map.find(name); + if (arg_stype_map.end() != it3) { + arg_stypes[i] = it3->second; + } } g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { @@ -801,17 +905,21 @@ void GraphExecutor::Init(nnvm::Symbol symbol, HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } + // TODO(jun/haibin) check if InferShape is successful, and give warnings instead of segfault later + g = nnvm::pass::InferStorageType(g, arg_stypes, "__storage_type__"); // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. if (nullptr == shared_buffer) { // regular simple bind InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); } else { // simple bind using shared data arrays and shared_exec InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, shared_arg_names, shared_exec, shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); @@ -864,6 +972,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, // initialize the memory of each entries void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { using nnvm::DTypeVector; + using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::StorageVector; // get the graph @@ -872,20 +981,29 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { const auto& vdtype = graph_.GetAttr("dtype"); const auto& vshape = graph_.GetAttr("shape"); const auto& vstorage = graph_.GetAttr("storage_id"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); const auto& vctx = graph_.GetAttr("context"); CHECK_EQ(idx.num_node_entries(), vshape.size()); CHECK_EQ(idx.num_node_entries(), vdtype.size()); CHECK_EQ(idx.num_node_entries(), vstorage.size()); CHECK_EQ(data_entry_.size(), vshape.size()); std::vector data_context(idx.num_node_entries()); + std::vector data_storage_type(idx.num_node_entries(), kUndefinedStorage); for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) { - data_context[idx.entry_id(nid, i)] = vctx[nid]; + auto eid = idx.entry_id(nid, i); + data_context[eid] = vctx[nid]; + CHECK_NE(vstorage_type[nid], kUndefinedStorage); + data_storage_type[eid] = (NDArrayStorageType) vstorage_type[nid]; } } // information about the pool - using PoolEntry = std::pair; + struct PoolEntry { + Context ctx; + size_t bytes; + NDArrayStorageType stype; + }; std::vector pool_info; // assign array to head gradient @@ -893,26 +1011,36 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { uint32_t nid = idx.input_nodes().at(i); uint32_t oid = head_grad_map_.at(idx[nid].source); uint32_t eid = idx.entry_id(idx.outputs()[oid]); + NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid]; CHECK_NE(vshape[eid].ndim(), 0U); CHECK_NE(vdtype[eid], -1); - data_entry_[idx.entry_id(nid, 0)] = - NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + auto data_eid = idx.entry_id(nid, 0); + // initialize based on storage_type + if (stype != kDefaultStorage) { + data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]); + } else { + data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit head_g entry\t" << data_eid << "\tas stype " << stype; +#endif } // get maximum bytes in each pool for (size_t i = 0; i < vshape.size(); ++i) { if (!data_entry_[i].is_none()) continue; size_t bytes = vshape[i].Size() * mshadow::mshadow_sizeof(vdtype[i]); int storage_id = vstorage[i]; + // skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID if (storage_id < 0) continue; size_t sid = static_cast(storage_id); if (sid >= pool_info.size()) { - pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0)}); + pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage}); } PoolEntry& info = pool_info[sid]; - if (info.second == 0) { - info = PoolEntry{data_context[i], bytes}; + if (info.bytes == 0) { + info = PoolEntry{data_context[i], bytes, data_storage_type[i]}; } else { - info.second = std::max(info.second, bytes); + info.bytes = std::max(info.bytes, bytes); } } // construct the re-use pool, if needed @@ -933,13 +1061,14 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { sorted_pool_index.push_back(i); } auto pool_comparator = [&pool_info](int lhs, int rhs){ - return pool_info[lhs].second > pool_info[rhs].second; + return pool_info[lhs].bytes > pool_info[rhs].bytes; }; std::sort(sorted_pool_index.begin(), sorted_pool_index.end(), pool_comparator); for (size_t i : sorted_pool_index) { - const Context& ctx = pool_info[i].first; - size_t bytes = pool_info[i].second; + const Context& ctx = pool_info[i].ctx; + size_t bytes = pool_info[i].bytes; + NDArrayStorageType storage_type = pool_info[i].stype; bool allocated = false; for (auto it = free_pool.lower_bound(bytes); it != free_pool.end(); ++it) { if (it->second.ctx() == ctx && it->first >= bytes) { @@ -964,15 +1093,22 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { } CHECK_EQ(data_pool_.size(), pool_info.size()); // assign the data entries - for (size_t i = 0; i < data_entry_.size(); ++i) { // avoid pre-allocated arrays if (!data_entry_[i].is_none()) continue; // assign allocated array by storage id int storage_id = vstorage[i]; - CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; - const NDArray& src = data_pool_.at(storage_id); - data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + auto storage_type = (NDArrayStorageType) vstorage_type[i]; + if (storage_type == kDefaultStorage) { + CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; + const NDArray& src = data_pool_.at(storage_id); + data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + } else { + data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit data entry\t" << i << "\tas stype " << storage_type; +#endif } } @@ -987,11 +1123,28 @@ void GraphExecutor::InitCachedOps() { const auto& vctx = graph_.GetAttr("context"); const auto& addto_entry = graph_.GetAttr >("addto_entry"); const auto& skip_plus_node = graph_.GetAttr >("skip_plus_node"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); op_nodes_.resize(idx.num_nodes()); // setup the array and requirements. for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; +#if EXECUTOR_DEBUG + if (inode.source->is_variable()) { + LOG(INFO) << "node " << nid << " var"; + } else { + LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name; + auto exec = op_execs[nid]; + for (const auto& e : inode.inputs) { + auto eid = idx.entry_id(e); + LOG(INFO) << "\t\tinput " << eid << " stype: " << vstorage_type[eid]; + } + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + uint32_t eid = idx.entry_id(nid, index); + LOG(INFO) << "\t\toutput " << eid << " stype: " << vstorage_type[eid]; + } + } +#endif if (inode.source->is_variable()) continue; #if MXNET_USE_PROFILER op_nodes_[nid].opr_name = inode.source->op()->name.c_str(); @@ -1068,7 +1221,7 @@ void GraphExecutor::InitCachedOps() { if (is_async) { exec->op_ctx.async_on_complete = on_complete; } - exec->Run(ctx); + exec->Run(ctx, is_gpu); // call on complete only if it is async op if (!is_async) { if (is_gpu) { @@ -1213,6 +1366,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid << " - " << seg_op.topo_end - 1; #endif Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling); nid = seg_op.topo_end - 1; @@ -1225,6 +1381,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { if (op_nodes_[nid].skip_exec_node) continue; opnode.exec->op_ctx.is_train = is_train; if (opnode.exec->exec_type() == Operator::kCrossDeviceCopy) { +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid << " for CrossDeviceCopy"; +#endif CHECK_EQ(inode.inputs.size(), 1U); CHECK_EQ(opnode.exec->in_array.size(), 1U); CHECK_EQ(opnode.exec->out_array.size(), 1U); @@ -1234,6 +1393,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid; #endif Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); } else { @@ -1298,7 +1460,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, RunContext ctx, Engine::CallbackOnComplete on_complete) { // Run all opr in the sub-graph for (auto &exec : exec_list) { - exec->Run(ctx); + exec->Run(ctx, is_gpu); } if (is_gpu) { #if MXNET_USE_CUDA @@ -1333,6 +1495,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_args, @@ -1343,7 +1506,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, auto exec = new exec::GraphExecutor(); exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, - arg_shape_map, arg_dtype_map, + arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, shared_arg_names, in_args, arg_grads, aux_states, shared_buffer, shared_exec); diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index d5a4e8c3aa6c..308eddba8b80 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -19,6 +19,8 @@ #include #include "./exec_pass.h" +#define EXECUTOR_DEBUG 0 + namespace mxnet { using NodeOperatorMap = std::unordered_map& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, @@ -126,6 +129,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -138,6 +142,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -186,7 +191,8 @@ class GraphExecutor : public Executor { std::vector op_nodes_; // internal data entry of each node std::vector data_entry_; - // internal data pool of allocated entries + // internal data pool of allocated entries. + // these allocated entries can be used for static memory sharing between executors. std::vector data_pool_; // output arrays std::vector output_arrays_; diff --git a/src/executor/inplace_addto_detect_pass.cc b/src/executor/inplace_addto_detect_pass.cc index 75a2608313aa..1a0bc9cb40a6 100644 --- a/src/executor/inplace_addto_detect_pass.cc +++ b/src/executor/inplace_addto_detect_pass.cc @@ -44,6 +44,8 @@ Graph DetectInplaceAddTo(Graph g) { uint32_t eid_rhs = idx.entry_id(inode.inputs[1]); if (ref_count[eid_rhs] != 1) continue; if (inode.inputs[0].node_id >= inode.inputs[1].node_id) continue; + // TODO(haibin) support inplace addto for Dynamic Storage + if (storage_id[eid_rhs] == kDynamicStorageID) continue; CHECK_NE(storage_id[eid_rhs], sid); storage_id[eid_rhs] = sid; addto_entry[eid_rhs] = 1; diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index a51e24503785..91488c065033 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -23,7 +23,7 @@ namespace io { class BatchLoader : public IIterator { public: explicit BatchLoader(IIterator *base): - base_(base), head_(1), num_overflow_(0) { + head_(1), num_overflow_(0), base_(base) { } virtual ~BatchLoader(void) { @@ -34,7 +34,7 @@ class BatchLoader : public IIterator { std::vector > kwargs_left; // init batch param, it could have similar param with kwargs_left = param_.InitAllowUnknown(kwargs); - // Init space for out_ + // Init space for out out_.inst_index = new unsigned[param_.batch_size]; out_.batch_size = param_.batch_size; out_.data.clear(); @@ -51,6 +51,7 @@ class BatchLoader : public IIterator { } head_ = 1; } + virtual bool Next(void) { out_.num_batch_padd = 0; out_.batch_size = param_.batch_size; @@ -110,23 +111,25 @@ class BatchLoader : public IIterator { return out_; } - private: + protected: /*! \brief batch parameters */ BatchParam param_; /*! \brief output data */ TBlobBatch out_; - /*! \brief base iterator */ - IIterator *base_; /*! \brief on first */ int head_; /*! \brief number of overflow instances that readed in round_batch mode */ int num_overflow_; + /*! \brief tensor to hold data */ + std::vector data_; + + private: + /*! \brief base iterator */ + IIterator *base_; /*! \brief data shape */ std::vector shape_; /*! \brief unit size */ std::vector unit_size_; - /*! \brief tensor to hold data */ - std::vector data_; // initialize the data holder by using from the first batch. inline void InitData(const DataInst& first_batch) { shape_.resize(first_batch.data.size()); diff --git a/src/io/iter_libsvm.cc b/src/io/iter_libsvm.cc new file mode 100644 index 000000000000..aad54160ec13 --- /dev/null +++ b/src/io/iter_libsvm.cc @@ -0,0 +1,258 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file iter_libsvm.cc + * \brief define a LibSVM Reader to read in arrays + */ +#include +#include +#include +#include +#include +#include "./iter_sparse_prefetcher.h" +#include "./iter_sparse_batchloader.h" + +namespace mxnet { +namespace io { +// LibSVM parameters +struct LibSVMIterParam : public dmlc::Parameter { + /*! \brief path to data libsvm file */ + std::string data_libsvm; + /*! \brief data shape */ + TShape data_shape; + /*! \brief path to label libsvm file */ + std::string label_libsvm; + /*! \brief label shape */ + TShape label_shape; + // declare parameters + DMLC_DECLARE_PARAMETER(LibSVMIterParam) { + DMLC_DECLARE_FIELD(data_libsvm) + .describe("The input LibSVM file or a directory path."); + DMLC_DECLARE_FIELD(data_shape) + .describe("The shape of one example."); + DMLC_DECLARE_FIELD(label_libsvm).set_default("NULL") + .describe("The input LibSVM file or a directory path. " + "If NULL, all labels will be read from ``data_libsvm``."); + index_t shape1[] = {1}; + DMLC_DECLARE_FIELD(label_shape).set_default(TShape(shape1, shape1 + 1)) + .describe("The shape of one label."); + } +}; + +class LibSVMIter: public SparseIIterator { + public: + LibSVMIter() {} + virtual ~LibSVMIter() {} + + // intialize iterator loads data in + virtual void Init(const std::vector >& kwargs) { + param_.InitAllowUnknown(kwargs); + data_parser_.reset(dmlc::Parser::Create(param_.data_libsvm.c_str(), + 0, 1, "libsvm")); + CHECK_EQ(param_.data_shape.ndim(), 1) << "dimension of data_shape is expected to be 1"; + if (param_.label_libsvm != "NULL") { + label_parser_.reset(dmlc::Parser::Create(param_.label_libsvm.c_str(), + 0, 1, "libsvm")); + CHECK_GT(param_.label_shape.Size(), 1) + << "label_shape is not expected to be (1,) when param_.label_libsvm is set."; + } else { + CHECK_EQ(param_.label_shape.Size(), 1) + << "label_shape is expected to be (1,) when param_.label_libsvm is NULL"; + } + // both data and label are of CSRStorage in libsvm format + if (param_.label_shape.Size() > 1) { + out_.data.resize(6); + } else { + // only data is of CSRStorage in libsvm format. + out_.data.resize(4); + } + } + + virtual void BeforeFirst() { + data_parser_->BeforeFirst(); + if (label_parser_.get() != nullptr) { + label_parser_->BeforeFirst(); + } + data_ptr_ = label_ptr_ = 0; + data_size_ = label_size_ = 0; + inst_counter_ = 0; + end_ = false; + } + + virtual bool Next() { + if (end_) return false; + while (data_ptr_ >= data_size_) { + if (!data_parser_->Next()) { + end_ = true; return false; + } + data_ptr_ = 0; + data_size_ = data_parser_->Value().size; + } + out_.index = inst_counter_++; + CHECK_LT(data_ptr_, data_size_); + const auto data_row = data_parser_->Value()[data_ptr_++]; + // data, indices and indptr + out_.data[0] = AsDataBlob(data_row); + out_.data[1] = AsIdxBlob(data_row); + out_.data[2] = AsIndPtrPlaceholder(data_row); + + if (label_parser_.get() != nullptr) { + while (label_ptr_ >= label_size_) { + CHECK(label_parser_->Next()) + << "Data LibSVM's row is smaller than the number of rows in label_libsvm"; + label_ptr_ = 0; + label_size_ = label_parser_->Value().size; + } + CHECK_LT(label_ptr_, label_size_); + const auto label_row = label_parser_->Value()[label_ptr_++]; + // data, indices and indptr + out_.data[3] = AsDataBlob(label_row); + out_.data[4] = AsIdxBlob(label_row); + out_.data[5] = AsIndPtrPlaceholder(label_row); + } else { + out_.data[3] = AsScalarLabelBlob(data_row); + } + return true; + } + + virtual const DataInst &Value(void) const { + return out_; + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + if (is_data) return kCSRStorage; + return param_.label_shape.Size() > 1 ? kCSRStorage : kDefaultStorage; + } + + virtual const TShape GetShape(bool is_data) const { + if (is_data) return param_.data_shape; + return param_.label_shape; + } + + private: + inline TBlob AsDataBlob(const dmlc::Row& row) { + const real_t* ptr = row.value; + TShape shape(mshadow::Shape1(row.length)); + return TBlob((real_t*) ptr, shape, cpu::kDevMask); // NOLINT(*) + } + + inline TBlob AsIdxBlob(const dmlc::Row& row) { + const uint32_t* ptr = row.index; + TShape shape(mshadow::Shape1(row.length)); + return TBlob((int32_t*) ptr, shape, cpu::kDevMask, CSR_IDX_DTYPE); // NOLINT(*) + } + + inline TBlob AsIndPtrPlaceholder(const dmlc::Row& row) { + return TBlob(nullptr, mshadow::Shape1(0), cpu::kDevMask, CSR_IND_PTR_TYPE); + } + + inline TBlob AsScalarLabelBlob(const dmlc::Row& row) { + const real_t* ptr = row.label; + return TBlob((real_t*) ptr, mshadow::Shape1(1), cpu::kDevMask); // NOLINT(*) + } + + LibSVMIterParam param_; + // output instance + DataInst out_; + // internal instance counter + unsigned inst_counter_{0}; + // at end + bool end_{false}; + // label parser + size_t label_ptr_{0}, label_size_{0}; + size_t data_ptr_{0}, data_size_{0}; + std::unique_ptr > label_parser_; + std::unique_ptr > data_parser_; +}; + + +DMLC_REGISTER_PARAMETER(LibSVMIterParam); + +MXNET_REGISTER_IO_ITER(LibSVMIter) +.describe(R"code(Returns the LibSVM file iterator. This iterator is experimental and +should be used with care. + +The input data is similar to libsvm file format, except that the indices are expected to be +zero-based instead of one-based. Details of the libsvm format are available at +`https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/` + +In this function, the `data_shape` parameter is used to set the shape of each line of the data. +The dimension of both `data_shape` and `label_shape` are expected to be 1. + +When `label_libsvm` is set to ``NULL``, both data and label are read from the same file specified +by `data_libsvm`. Otherwise, data is read from `data_libsvm` and label from `label_libsvm`, +in this case, if `data_libsvm` contains label, it will ignored. + +The `LibSVMIter` only support `round_batch` parameter set to ``True`` for now. So, if `batch_size` +is 3 and there are 4 total rows in libsvm file, 2 more examples +are consumed at the first round. If `reset` function is called after first round, +the call is ignored and remaining examples are returned in the second round. + +If ``data_libsvm = 'data/'`` is set, then all the files in this directory will be read. + +Examples:: + + // Contents of libsvm file ``data.t``. + 1.0 0:0.5 2:1.2 + -2.0 + -3.0 0:0.6 1:2.4 2:1.2 + 4 2:-1.2 + + // Creates a `LibSVMIter` with `batch_size`=3. + LibSVMIter = mx.io.LibSVMIter(data_libsvm = 'data.t', data_shape = (3,), + batch_size = 3) + + // The first batch (data and label) + [[ 0.5 0. 1.2 ] + [ 0. 0. 0. ] + [ 0.6 2.4 1.2 ]] + + [ 1. -2. -3.] + + // The second batch (data and label) + [[ 0. 0. -1.2 ] + [ 0.5 0. 1.2 ] + [ 0. 0. 0. ]] + + [ 4. 1. -2.] + + // Contents of libsvm file ``label.t`` + 1.0 + -2.0 0:0.125 + -3.0 2:1.2 + 4 1:1.0 2:-1.2 + + // Creates a `LibSVMIter` with specified label file + LibSVMIter = mx.io.LibSVMIter(data_libsvm = 'data.t', data_shape = (3,), + label_libsvm = 'label.t', label_shape = (3,), batch_size = 3) + + // Two batches of data read from the above iterator are as follows(data and label): + // The first batch + [[ 0.5 0. 1.2 ] + [ 0. 0. 0. ] + [ 0.6 2.4 1.2 ]] + + [[ 0. 0. 0. ] + [ 0.125 0. 0. ] + [ 0. 0. 1.2 ]] + + // The second batch + [[ 0. 0. -1.2 ] + [ 0.5 0. 1.2 ] + [ 0. 0. 0. ]] + + [[ 0. 1. -1.2 ] + [ 0. 0. 0. ] + [ 0.125 0. 0. ]] + +)code" ADD_FILELINE) +.add_arguments(LibSVMIterParam::__FIELDS__()) +.add_arguments(BatchParam::__FIELDS__()) +.add_arguments(PrefetcherParam::__FIELDS__()) +.set_body([]() { + return new SparsePrefetcherIter( + new SparseBatchLoader( + new LibSVMIter())); + }); + +} // namespace io +} // namespace mxnet diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 9050ef2d1b38..3eb85b12c077 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -28,8 +28,7 @@ namespace io { class PrefetcherIter : public IIterator { public: explicit PrefetcherIter(IIterator* base) - : loader_(base), out_(nullptr) { - } + : loader_(base), out_(nullptr) {} ~PrefetcherIter() { while (recycle_queue_.size() != 0) { @@ -38,21 +37,24 @@ class PrefetcherIter : public IIterator { delete batch; } delete out_; - iter_.Destroy(); + iter.Destroy(); } - virtual void Init(const std::vector >& kwargs) { + void InitParams(const std::vector >& kwargs) { std::vector > kwargs_left; // init image rec param kwargs_left = param_.InitAllowUnknown(kwargs); - // use the kwarg to init batch loader - loader_->Init(kwargs); // maximum prefetch threaded iter internal size const int kMaxPrefetchBuffer = 16; // init thread iter - iter_.set_max_capacity(kMaxPrefetchBuffer); + iter.set_max_capacity(kMaxPrefetchBuffer); + } - iter_.Init([this](DataBatch **dptr) { + virtual void Init(const std::vector >& kwargs) { + InitParams(kwargs); + // use the kwarg to init batch loader + loader_->Init(kwargs); + iter.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); if (*dptr == nullptr) { @@ -91,7 +93,7 @@ class PrefetcherIter : public IIterator { } virtual void BeforeFirst(void) { - iter_.BeforeFirst(); + iter.BeforeFirst(); } virtual bool Next(void) { @@ -106,9 +108,9 @@ class PrefetcherIter : public IIterator { arr.WaitToWrite(); } recycle_queue_.pop(); - iter_.Recycle(&old_batch); + iter.Recycle(&old_batch); } - return iter_.Next(&out_); + return iter.Next(&out_); } virtual const DataBatch &Value(void) const { return *out_; @@ -117,16 +119,16 @@ class PrefetcherIter : public IIterator { protected: /*! \brief prefetcher parameters */ PrefetcherParam param_; - /*! \brief internal batch loader */ - std::unique_ptr > loader_; + /*! \brief backend thread */ + dmlc::ThreadedIter iter; private: + /*! \brief internal batch loader */ + std::unique_ptr > loader_; /*! \brief output data */ DataBatch *out_; /*! \brief queue to be recycled */ std::queue recycle_queue_; - /*! \brief backend thread */ - dmlc::ThreadedIter iter_; }; } // namespace io } // namespace mxnet diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h new file mode 100644 index 000000000000..81c2359d547f --- /dev/null +++ b/src/io/iter_sparse_batchloader.h @@ -0,0 +1,184 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse_batchloader.h + * \brief define a batch adapter to create sparse tblob batch + */ +#ifndef MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ +#define MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./inst_vector.h" +#include "./image_iter_common.h" +#include "./iter_batchloader.h" + +namespace mxnet { +namespace io { + +/*! \brief create a batch iterator from single instance iterator */ +class SparseBatchLoader : public BatchLoader, public SparseIIterator { + public: + explicit SparseBatchLoader(SparseIIterator *base): + BatchLoader(base), sparse_base_(base) { + } + + virtual ~SparseBatchLoader(void) {} + + inline void Init(const std::vector >& kwargs) { + BatchLoader::Init(kwargs); + data_stype_ = sparse_base_->GetStorageType(true); + label_stype_ = sparse_base_->GetStorageType(false); + if (param_.round_batch == 0) { + LOG(FATAL) << "sparse batch loader doesn't support round_batch == false yet"; + } + } + + virtual void BeforeFirst(void) { + BatchLoader::BeforeFirst(); + } + + virtual bool Next(void) { + out_.num_batch_padd = 0; + out_.batch_size = param_.batch_size; + this->head_ = 0; + // if overflown from previous round, directly return false, until before first is called + if (num_overflow_ != 0) return false; + index_t top = 0; + inst_cache_.clear(); + while (sparse_base_->Next()) { + inst_cache_.emplace_back(sparse_base_->Value()); + if (inst_cache_.size() >= param_.batch_size) break; + } + // no more data instance + if (inst_cache_.size() == 0) { + return false; + } + if (inst_cache_.size() < param_.batch_size) { + CHECK_GT(param_.round_batch, 0); + num_overflow_ = 0; + sparse_base_->BeforeFirst(); + for (; inst_cache_.size() < param_.batch_size; ++num_overflow_) { + CHECK(sparse_base_->Next()) << "number of input must be bigger than batch size"; + inst_cache_.emplace_back(sparse_base_->Value()); + } + } + out_.num_batch_padd = num_overflow_; + CHECK_EQ(inst_cache_.size(), param_.batch_size); + this->InitDataFromBatch(); + MSHADOW_INT_TYPE_SWITCH(CSR_IND_PTR_TYPE, IType, { + for (size_t j = 0; j < inst_cache_.size(); j++) { + const auto& d = inst_cache_[j]; + out_.inst_index[top] = d.index; + size_t unit_size = 0; + for (size_t i = 0; i < d.data.size(); ++i) { + // indptr tensor + if (IsIndPtr(i)) { + auto indptr = data_[i].get(); + if (j == 0) indptr[0] = 0; + indptr[j + 1] = indptr[j] + (IType) unit_size; + offsets_[i] = j; + } else { + // indices and values tensor + unit_size = d.data[i].shape_.Size(); + MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, { + const auto begin = offsets_[i]; + const auto end = offsets_[i] + unit_size; + mshadow::Copy(data_[i].get().Slice(begin, end), + d.data[i].get_with_shape(mshadow::Shape1(unit_size))); + }); + offsets_[i] += unit_size; + } + } + } + }); + return true; + } + + virtual const TBlobBatch &Value(void) const { + return BatchLoader::Value(); + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + return sparse_base_->GetStorageType(is_data); + } + + virtual const TShape GetShape(bool is_data) const { + TShape inst_shape = sparse_base_->GetShape(is_data); + std::vector shape_vec; + shape_vec.push_back(param_.batch_size); + for (index_t dim = 0; dim < inst_shape.ndim(); ++dim) { + shape_vec.push_back(inst_shape[dim]); + } + return TShape(shape_vec.begin(), shape_vec.end()); + } + + private: + /*! \brief base sparse iterator */ + SparseIIterator *sparse_base_; + /*! \brief data instances */ + std::vector inst_cache_; + /*! \brief data storage type */ + NDArrayStorageType data_stype_; + /*! \brief data label type */ + NDArrayStorageType label_stype_; + /*! \brief tensor offset for slicing */ + std::vector offsets_; + + // check whether ith position is the indptr tensor for a CSR tensor + inline bool IsIndPtr(size_t i) { + auto data_num_aux = NDArray::NumAuxData(data_stype_); + auto label_num_aux = NDArray::NumAuxData(label_stype_); + auto label_indptr_offset = data_num_aux + 1 + label_num_aux; + // data indptr + if (i == data_num_aux && data_stype_ == kCSRStorage) { + return true; + } + // label indptr + if (i == label_indptr_offset && label_stype_ == kCSRStorage && data_stype_ == kCSRStorage) { + return true; + } + return false; + } + + // initialize the data holder by using from the batch + inline void InitDataFromBatch() { + CHECK(data_stype_ == kCSRStorage || label_stype_ == kCSRStorage); + CHECK_GT(inst_cache_.size(), 0); + out_.data.clear(); + offsets_.clear(); + + size_t total_size = inst_cache_[0].data.size(); + data_.resize(total_size); + offsets_.resize(total_size, 0); + std::vector vec_sizes(total_size, 0); + // accumulate the memory required for a batch + for (size_t i = 0; i < total_size; ++i) { + size_t size = 0; + // vec_size for indptr + if (IsIndPtr(i)) { + size = param_.batch_size + 1; + } else { + for (const auto &d : inst_cache_) size += d.data[i].shape_.Size(); + } + vec_sizes[i] = size; + } + + CHECK_EQ(vec_sizes[0], vec_sizes[1]); + for (size_t i = 0; i < total_size; ++i) { + int src_type_flag = inst_cache_[0].data[i].type_flag_; + // init object attributes + TShape dst_shape(mshadow::Shape1(vec_sizes[i])); + data_[i].resize(mshadow::Shape1(vec_sizes[i]), src_type_flag); + CHECK(data_[i].dptr_ != nullptr); + out_.data.push_back(TBlob(data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag)); + } + } +}; // class BatchLoader +} // namespace io +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ diff --git a/src/io/iter_sparse_prefetcher.h b/src/io/iter_sparse_prefetcher.h new file mode 100644 index 000000000000..6b2d22573e98 --- /dev/null +++ b/src/io/iter_sparse_prefetcher.h @@ -0,0 +1,134 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse_prefetcher.h + * \brief define a prefetcher using threaditer to keep k batch fetched + */ +#ifndef MXNET_IO_ITER_SPARSE_PREFETCHER_H_ +#define MXNET_IO_ITER_SPARSE_PREFETCHER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./inst_vector.h" +#include "./image_iter_common.h" +#include "./iter_prefetcher.h" + +namespace mxnet { +namespace io { +// iterator on sparse data +class SparsePrefetcherIter : public PrefetcherIter { + public: + explicit SparsePrefetcherIter(SparseIIterator* base) + : PrefetcherIter(base), sparse_loader_(base) {} + + ~SparsePrefetcherIter() {} + + virtual void Init(const std::vector >& kwargs) { + PrefetcherIter::InitParams(kwargs); + // use the kwarg to init batch loader + sparse_loader_->Init(kwargs); + iter.Init([this](DataBatch **dptr) { + if (!sparse_loader_->Next()) return false; + const TBlobBatch& batch = sparse_loader_->Value(); + if (*dptr == nullptr) { + // allocate databatch + *dptr = new DataBatch(); + (*dptr)->num_batch_padd = batch.num_batch_padd; + // (*dptr)->data.at(0) => data + // (*dptr)->data.at(1) => label + (*dptr)->data.resize(2); + (*dptr)->index.resize(batch.batch_size); + size_t data_iter = 0; + for (size_t i = 0; i < (*dptr)->data.size(); ++i) { + bool is_data = i == 0; + auto stype = this->GetStorageType(is_data); + auto dtype = param_.dtype ? param_.dtype.value() : batch.data[data_iter].type_flag_; + if (stype == kDefaultStorage) { + (*dptr)->data.at(i) = NDArray(batch.data[data_iter].shape_, + Context::CPU(), false, dtype); + } else { + (*dptr)->data.at(i) = NDArray(stype, this->GetShape(is_data), + Context::CPU(), false, dtype); + } + data_iter += NDArray::NumAuxData(stype) + 1; + } + } + // copy data over + size_t data_iter = 0; + for (size_t i = 0; i < (*dptr)->data.size(); ++i) { + auto& nd = ((*dptr)->data)[i]; + auto stype = nd.storage_type(); + auto& data_i = ((*dptr)->data)[i]; + if (stype == kDefaultStorage) { + CopyFromTo(data_i.data(), batch.data[data_iter]); + } else if (stype == kCSRStorage) { + auto& values = batch.data[data_iter]; + auto& indices = batch.data[data_iter + 1]; + auto& indptr = batch.data[data_iter + 2]; + // allocate memory + CHECK_EQ(indices.shape_.Size(), values.shape_.Size()); + nd.CheckAndAllocAuxData(csr::kIdx, indices.shape_); + nd.CheckAndAllocData(values.shape_); + nd.CheckAndAllocAuxData(csr::kIndPtr, indptr.shape_); + // copy values, indices and indptr + CopyFromTo(data_i.data(), values); + CopyFromTo(data_i.aux_data(csr::kIdx), indices); + CopyFromTo(data_i.aux_data(csr::kIndPtr), indptr); + } else { + LOG(FATAL) << "Storage type not implemented: " << stype; + } + data_iter += NDArray::NumAuxData(stype) + 1; + (*dptr)->num_batch_padd = batch.num_batch_padd; + } + if (batch.inst_index) { + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, + (*dptr)->index.begin()); + } + return true; + }, + [this]() { sparse_loader_->BeforeFirst(); }); + } + + virtual void BeforeFirst(void) { + PrefetcherIter::BeforeFirst(); + } + + virtual bool Next(void) { + return PrefetcherIter::Next(); + } + virtual const DataBatch &Value(void) const { + return PrefetcherIter::Value(); + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + return sparse_loader_->GetStorageType(is_data); + } + + virtual const TShape GetShape(bool is_data) const { + return sparse_loader_->GetShape(is_data); + } + + private: + /*! \brief internal sparse batch loader */ + SparseIIterator* sparse_loader_; + + inline void CopyFromTo(TBlob dst, const TBlob src) { + MSHADOW_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Copy(dst.FlatTo1D(), src.FlatTo1D()); + }); + } +}; +} // namespace io +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_PREFETCHER_H_ diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 6f1795d6f368..44178d305c4a 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -12,6 +12,7 @@ #include #include #include "./ndarray_function.h" +#include "../operator/tensor/matrix_op-inl.h" #include "./autograd.h" #if MXNET_USE_OPENCV @@ -26,6 +27,8 @@ namespace mxnet { NDArray NDArray::Reshape(const TShape &shape) const { using namespace autograd; + CHECK(storage_type() == kDefaultStorage) << "Reshape for storage type " << + storage_type() << " is not implemented yet"; if (AutogradRuntime::Get()->IsTraining()) { CHECK_GE(shape_.Size(), shape.Size()) << "NDArray.Reshape: target shape must have must have the same size as " @@ -56,12 +59,14 @@ NDArray NDArray::Reshape(const TShape &shape) const { } } - NDArray NDArray::Slice(index_t begin, index_t end) const { using namespace autograd; + using namespace mshadow; NDArray ret = *this; CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_[0], end) << "Slice end index out of range"; + auto stype = storage_type(); + CHECK_EQ(stype, kDefaultStorage); size_t length = shape_.ProdShape(1, shape_.ndim()); MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { ret.byte_offset_ += begin * length * sizeof(DType); @@ -88,8 +93,69 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { } } +void NDArray::SliceEx(index_t begin, index_t end, NDArray *ret) const { + using namespace autograd; + using namespace mshadow; + CHECK(!is_none()) << "NDArray is not initialized"; + CHECK_GE(shape_[0], end) << "Slice end index out of range"; + auto stype = storage_type(); + CHECK_NE(stype, kDefaultStorage); + if (stype == kCSRStorage) { + using namespace csr; + ret->shape_[0] = end - begin; + NDArray src = *this; + // destination NDArray shares the same variable + ret->ptr_->var = var(); + Engine::Get()->PushSync([src, ret, begin, end](RunContext ctx) { + NDArray dst = *ret; + // create a new chunk for dst NDArray + NDArray::Chunk chunk = *src.ptr_; + // void indptr storage handle + chunk.aux_handles[kIndPtr] = Storage::Handle(); + // shape for indptr is end - begin + 1 + chunk.CheckAndAllocAuxData(kIndPtr, Shape1(end - begin + 1)); + if (src.ctx().dev_mask() == cpu::kDevMask) { + MSHADOW_INT_TYPE_SWITCH(src.aux_type(kIndPtr), IType, { + MSHADOW_TYPE_SWITCH(src.dtype(), DType, { + // create new indptr + const IType* src_indptr = src.aux_data(kIndPtr).dptr(); + IType* dst_indptr = static_cast (chunk.aux_handles[kIndPtr].dptr); + op::SliceCsrIndPtrImpl(begin, end, ctx, src_indptr, dst_indptr); + // advance idx and values pointers (CPU implementation) + // TODO(haibin) refactor for GPU implementation later + IType offset = src_indptr[begin]; + IType* idx = static_cast(chunk.aux_handles[kIdx].dptr); + DType* values = static_cast(chunk.shandle.dptr); + chunk.aux_handles[kIdx].dptr = idx + offset; + chunk.shandle.dptr = values + offset; + // update storage shape and aux shape (CPU implementation) + auto nnz = dst_indptr[end - begin]; + chunk.aux_shapes[kIdx] = Shape1(nnz); + chunk.storage_shape = Shape1(nnz); + chunk.static_data = true; + chunk.skip_delete_var = true; + // update dst chunk + *dst.ptr_ = chunk; + }); + }); + } else { +#if MXNET_USE_CUDA + LOG(FATAL) << "SliceEx CSR not implemented yet"; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } + }, ctx(), {}, {var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + } else { + LOG(FATAL) << "Slice not yet implemented for storage " << stype; + } + // TODO(haibin) support auto_grad for SliceEx +} NDArray NDArray::At(index_t idx) const { + CHECK(storage_type() == kDefaultStorage) << "Storage type " + << storage_type() << " doesn't support At()"; NDArray ret = this->Slice(idx, idx+1); if (shape_.ndim() > 1) { return ret.Reshape(TShape(shape_.data()+1, shape_.data()+shape_.ndim())); @@ -212,11 +278,11 @@ void BinaryOp(const NDArray &lhs, // redirect everything to mshadow operations switch (lhs.ctx().dev_mask()) { case cpu::kDevMask: { - Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), const_vars, {ret.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { + TBlob tmp = ret.data(); + ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + }, lhs.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #if MXNET_USE_CUDA @@ -242,6 +308,7 @@ void SetValueOp(const real_t &rhs, NDArray *out) { switch (ret.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { + CHECK(ret.storage_type() == kDefaultStorage); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); }, ret.ctx(), {}, {ret.var()}, @@ -313,6 +380,7 @@ void ScalarOp(const NDArray &lhs, } } + void CopyFromTo(const NDArray &from, NDArray *to, int priority) { if (from.var() == to->var()) { // skip to copy to itself @@ -327,44 +395,33 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { NDArray ret = *to; int a = from.ctx().dev_mask(); int b = to->ctx().dev_mask(); - std::vector const_vars; if (from.var() != ret.var()) const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kNormal, priority, PROFILER_MESSAGE("CopyCPU2CPU")); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU")); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU")); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, from.dtype() != ret.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2GPU")); diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index 28524b73d0dd..aad80fd4360a 100644 --- a/src/ndarray/ndarray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -12,27 +12,28 @@ // macro to help specialize evaluation function #ifndef DECL_TERNARY -#define DECL_TERNARY(XPU, OP, FUN) \ - template<> \ - void Eval(const TBlob &lhs, const TBlob &mhs, \ - const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, mhs, rhs, ret, ctx); \ +#define DECL_TERNARY(XPU, OP, FUN) \ + template<> \ + void Eval(const TBlob &lhs, const TBlob &mhs, \ + const TBlob &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, mhs, rhs, ret, ctx); \ } #endif #ifndef DECL_BINARY -#define DECL_BINARY(XPU, OP, FUN) \ - template<> \ +#define DECL_BINARY(XPU, OP, FUN) \ + template<> \ void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + FUN(lhs, rhs, ret, ctx); \ } #endif #ifndef DECL_SCALAR -#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ - template<> \ - void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ +#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ + template<> \ + void Eval(const TBlob &lhs, const real_t &rhs, \ + TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ } #endif @@ -44,10 +45,11 @@ namespace mxnet { namespace ndarray { + // true implementation template -inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalBinary_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -61,10 +63,9 @@ inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, }); } - template -inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalOneHot_(const TBlob &index, const TBlob &rhs, + TBlob *ret, RunContext ctx) { LOG(INFO) << "The operator onehot_encode is deprecated; use one_hot instead."; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -81,8 +82,8 @@ inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, } template -inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); // TODO(eric): support mixed type choose, i.e. int index and float rhs. @@ -98,8 +99,8 @@ inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, } template -inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->get(s) @@ -109,8 +110,8 @@ inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob } template -inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, - TBlob *ret, RunContext ctx) { +void EvalScalar_(const TBlob &lhs, const real_t &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -130,7 +131,7 @@ inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, template<> void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, - TBlob *ret, RunContext ctx) { + TBlob *ret, RunContext ctx) { typedef DEVICE xpu; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -145,12 +146,11 @@ void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max } template<> -void EvalRandom( - const real_t &a, - const real_t &b, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +void EvalRandom(const real_t &a, + const real_t &b, + const Resource &resource, + TBlob *ret, + RunContext ctx) { typedef DEVICE xpu; mshadow::Stream *s = ctx.get_stream(); switch (ret->type_flag_) { @@ -426,6 +426,7 @@ DECL_SCALAR(DEVICE, Plus, EvalScalar_, true) DECL_SCALAR(DEVICE, Minus, EvalScalar_, true) DECL_SCALAR(DEVICE, Mul, EvalScalar_, true) DECL_SCALAR(DEVICE, Div, EvalScalar_, true) + // for reverse seq DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index def38126d08c..f4315b62a6a8 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -17,6 +17,7 @@ #include #include #include "./operator_common.h" +#include "../common/utils.h" namespace mxnet { namespace op { @@ -53,6 +54,42 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, return true; } +// Only inferring output storage types from input for now +template +inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + auto deduce = [&](std::vector *vec, const char *name, AttrType& result, + bool fallback) { + auto &v = *vec; + for (size_t i = 0; i < vec->size(); ++i) { + if (v[i] == kUndefinedStorage) { + // if input type is unknown, assume it's default storage + CHECK(assign(&v[i], kDefaultStorage)); + } else if (assign(&result, v[i]) == false && fallback) { + result = kDefaultStorage; + } + } + }; + AttrType dattr = kUndefinedStorage; + deduce(in_attrs, "input", dattr, enable_fallback); + if (reverse_infer) { + LOG(FATAL) << "not implemented yet"; + } + auto write = [&](std::vector *vec, const char *name) { + for (size_t i = 0; i < vec->size(); ++i) { + CHECK(assign(&(*vec)[i], dattr)) + << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << name << ": " << "expected " << dattr << ", got " << (*vec)[i]; + } + }; + if (is_none(dattr)) dattr = kDefaultStorage; + write(out_attrs, "output"); + return true; +} + template inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -73,6 +110,29 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +template +inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs); +} + +inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; + auto &in = *in_attrs; + auto &out = *out_attrs; + CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known"; + if (in[0] == kUndefinedStorage) in[0] = in[1]; + if (out[0] == kUndefinedStorage) out[0] = in[1]; + return true; +} + // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; @@ -105,6 +165,22 @@ struct ElemwiseGradUseNone { } }; +// TODO(haibin) this is a temporary function for debugging purpose. Remove later. +template +void print_info(const mshadow::Tensor& tensor, const std::string& name) { + std::cout << "Tensor " << name << " with shape ("; + int len = 1; + for (int i = 0; i < dim; i++) { + len *= tensor.shape_[i]; + std::cout << tensor.shape_[i] << ","; + if (i == dim - 1) std::cout << ")"; + } + std::cout << std::endl; + for (int j = 0; j < len; j ++) std::cout << tensor.dptr_[j] << " "; + std::cout << std::endl; +} + + } // namespace op } // namespace mxnet diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 9b5dcfe3d3b1..6a9ee30f1b04 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -7,6 +7,7 @@ #ifndef MXNET_OPERATOR_MXNET_OP_H_ #define MXNET_OPERATOR_MXNET_OP_H_ +#include #include #include @@ -22,6 +23,8 @@ const float PI = 3.14159265358979323846; using std::isnan; #endif +template +int get_num_threads(const int N); #ifdef __CUDACC__ #define CUDA_KERNEL_LOOP(i, n) \ @@ -37,8 +40,18 @@ inline int cuda_get_num_blocks(const int N) { using namespace mshadow::cuda; return std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum); } + +template<> +inline int get_num_threads(const int N) { + using namespace mshadow::cuda; + return kBaseThreadNum * cuda_get_num_blocks(N); +} #endif // __CUDACC__ +template<> +inline int get_num_threads(const int N) { + return omp_get_max_threads(); +} /*! \brief operator request type switch */ #define MXNET_ASSIGN_REQ_SWITCH(req, ReqType, ...) \ diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index a43d092bceb6..ecfb9c76acb3 100755 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -11,12 +11,15 @@ #include #include #include +#include +#include #include #include #include #include #include #include "../common/cuda_utils.h" +#include "../common/utils.h" namespace mxnet { namespace op { @@ -315,6 +318,24 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } +template +void FCompExFallback(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + FCompute fcompute, + const std::string& fname) { + using namespace common; + std::vector in_blobs, out_blobs; + std::vector temp_in, temp_out; + GetDefaultBlobs(inputs, &in_blobs, &temp_in, ctx, true); + GetDefaultBlobs(outputs, &out_blobs, &temp_out, ctx, true); + fcompute(attrs, ctx, in_blobs, req, out_blobs); + CastNonDefaultStorage(outputs, temp_out, ctx, true); +} + + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 85091c008ab4..83a4a9cfccbb 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -84,6 +84,87 @@ inline void SGDUpdate(const nnvm::NodeAttrs& attrs, }); } +/*! \brief kernel for sparse sgd + */ +template +struct SGDDnsRspKernel { + // DType is the output data type + // IType is row sparse idx type + // i is the ith row in row sparse gradient + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out, const DType* weight, + const IType* grad_idx, const DType *grad_val, + const DType clip_gradient, const DType lr, + const DType wd, const DType rescale_grad) { + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient)); + } else { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr * rescale_grad) * grad_val[grad_i]); + } + } + } +}; + +template +inline void SGDUpdateDnsRspImpl(const SGDParam& param, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + Stream* s = ctx.get_stream(); + auto &weight = inputs[0]; + auto &grad = inputs[1]; + auto &out = outputs[0]; + CHECK_EQ(weight.storage_type(), kDefaultStorage); + CHECK_EQ(grad.storage_type(), kRowSparseStorage); + if (!grad.storage_initialized()) return; + + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + auto weight_data = weight.data().FlatTo2D(s); + auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto grad_val = grad.data().FlatTo2D(s); + auto out_data = out.data().FlatTo2D(s); + auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; + auto width = weight.shape().ProdShape(1, weight.shape().ndim()); + mxnet_op::Kernel, xpu>::Launch(s, num_rows, width, + out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +template +inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + const SGDParam& param = nnvm::get(attrs.parsed); + auto weight_stype = inputs[0].storage_type(); + auto grad_stype = inputs[1].storage_type(); + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) { + SGDUpdateDnsRspImpl(param, ctx, inputs, req, outputs); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, SGDUpdate, "SGDUpdate"); + } +} + struct SGDMomParam : public dmlc::Parameter { float lr; float momentum; @@ -153,6 +234,88 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, }); } +template +struct SGDMomDnsRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out_data, + DType* mom_data, const DType* weight_data, const IType* grad_idx, + const DType* grad_data, const DType param_clip_gradient, const DType param_momentum, + const DType param_lr, const DType param_wd, const DType param_rescale_grad) { + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (param_clip_gradient >= 0.0f) { + mom_data[data_i] = param_momentum * mom_data[data_i] + - param_lr * param_wd * weight_data[data_i] + - param_lr * + mshadow_op::clip::Map(param_rescale_grad * grad_data[grad_i], + param_clip_gradient); + } else { + mom_data[data_i] = param_momentum * mom_data[data_i] + - param_lr * param_wd * weight_data[data_i] + - param_lr * param_rescale_grad * grad_data[grad_i]; + } + KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]); + } + } +}; + +template +inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + auto &weight = inputs[0]; + auto &grad = inputs[1]; + auto &mom = inputs[2]; + auto &out = outputs[0]; + if (!grad.storage_initialized()) return; + + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + auto weight_data = weight.data().FlatTo2D(s); + auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto grad_val = grad.data().FlatTo2D(s); + auto mom_data = mom.data().FlatTo2D(s); + auto out_data = out.data().FlatTo2D(s); + auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; + auto width = weight.shape().ProdShape(1, weight.shape().ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + static_cast(param.clip_gradient), static_cast(param.momentum), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +template +inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const SGDMomParam& param = nnvm::get(attrs.parsed); + auto weight_stype = inputs[0].storage_type(); + auto grad_stype = inputs[1].storage_type(); + auto mom_stype = inputs[2].storage_type(); + + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage && + mom_stype == kDefaultStorage) { + SGDMomUpdateDnsRspDnsImpl(param, ctx, inputs, req, outputs); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && + mom_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, + SGDMomUpdate, "SGDMomUpdate"); + } +} + struct AdamParam : public dmlc::Parameter { float lr; float beta1; diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 9ec6aacaafac..5464d03b215f 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -22,6 +22,9 @@ It updates the weights using:: weight = weight - learning_rate * gradient +If gradients are stored with `row_sparse` storage, +where update is applied only to rows whose gradient has non-zero entries. + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -29,6 +32,7 @@ It updates the weights using:: .set_attr("FInferShape", ElemwiseShape<2, 1>) .set_attr("FInferType", ElemwiseType<2, 1>) .set_attr("FCompute", SGDUpdate) +.set_attr(FCOMP_EX_CPU, SGDUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_arguments(SGDParam::__FIELDS__()); @@ -52,6 +56,9 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. +If gradients are stored with `row_sparse` storage, +only rows whose gradients contain non-zero entries are updated (for both weight and momentum). + )code" ADD_FILELINE) .set_num_inputs(3) .set_num_outputs(1) @@ -63,12 +70,12 @@ Where the parameter ``momentum`` is the decay rate of momentum estimates at each return std::vector{2}; }) .set_attr("FCompute", SGDMomUpdate) +.set_attr(FCOMP_EX_CPU, SGDMomUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("mom", "NDArray-or-Symbol", "Momentum") .add_arguments(SGDMomParam::__FIELDS__()); - NNVM_REGISTER_OP(adam_update) .describe(R"code(Update function for Adam optimizer. Adam is seen as a generalization of AdaGrad. diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 2b2667ec317b..bf0cc570e1f4 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -10,10 +10,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(sgd_update) -.set_attr("FCompute", SGDUpdate); +.set_attr("FCompute", SGDUpdate) +.set_attr(FCOMP_EX_GPU, SGDUpdateEx); NNVM_REGISTER_OP(sgd_mom_update) -.set_attr("FCompute", SGDMomUpdate); +.set_attr("FCompute", SGDMomUpdate) +.set_attr(FCOMP_EX_GPU, SGDMomUpdateEx); NNVM_REGISTER_OP(adam_update) .set_attr("FCompute", AdamUpdate); diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc index 0d0a1d8b5df0..f6f8f429d99e 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc @@ -105,6 +105,7 @@ Example:: .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); + NNVM_REGISTER_OP(_backward_broadcast_mul) .set_num_inputs(3) .set_num_outputs(2) diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 6062febe2d9e..9317720f127a 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -10,10 +10,10 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" -#include "../mxnet_op.h" namespace mxnet { namespace op { @@ -123,6 +123,115 @@ void BinaryBackwardUseNone_(const nnvm::NodeAttrs& attrs, } } +// TODO(haibin) This is a single-thread inefficient implementation +// Binary Compute between two row-sparse ndarray +// This implementation only works on CPU +template +void BinaryComputeRspRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + auto &lhs = inputs[0]; + auto &rhs = inputs[1]; + auto &output = outputs[0]; + + bool init_l = lhs.storage_initialized(); + bool init_r = rhs.storage_initialized(); + // both inputs are zeros + if (!init_l && !init_r) return; + // one of the input is zeros + if (!init_l || !init_r) { + NDArray out(output); + CopyFromToRspImpl(!init_l ? rhs : lhs, &out, ctx.run_ctx); + return; + } + // Memory Estimation: This is (roughly) the number of result rows. We still + // need to subtract the number of common rows + unsigned int num_rows_l = lhs.aux_shape(rowsparse::kIdx).Size(); + unsigned int num_rows_r = rhs.aux_shape(rowsparse::kIdx).Size(); + output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)}); + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(lhs.aux_type(rowsparse::kIdx), IType, { + // Indices + auto indices_l = lhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_r = rhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_out = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + // Data + auto data_l = lhs.data().FlatTo2D(s); + auto data_r = rhs.data().FlatTo2D(s); + auto out = output.data().FlatTo2D(s); + + // TODO(haibin) A more appropriate way: Copy to output, then apply ops + size_t iter_l = 0; + size_t iter_r = 0; + size_t iter_out = 0; + int32_t num_common_rows = 0; + while (iter_l < num_rows_l && iter_r < num_rows_r) { + auto idx_l = indices_l[iter_l]; + auto idx_r = indices_r[iter_r]; + if (idx_l == idx_r) { + // Same row + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + out[iter_out] += data_r[iter_r++]; + num_common_rows++; + } else if (idx_l < idx_r) { + // Left only + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + } else { + // Right only + indices_out[iter_out] = idx_r; + mshadow::Copy(out[iter_out], data_r[iter_r++], s); + } + iter_out++; + } + // Copying over the rest of the rows + while (iter_l < num_rows_l) { + indices_out[iter_out] = indices_l[iter_l]; + mshadow::Copy(out[iter_out++], data_l[iter_l++], s); + } + while (iter_r < num_rows_r) { + indices_out[iter_out] = indices_r[iter_r]; + mshadow::Copy(out[iter_out++], data_r[iter_r++], s); + } + auto new_shape = output.aux_shape(rowsparse::kIdx); + new_shape[0] -= num_common_rows; + output.SetAuxShape(rowsparse::kIdx, new_shape); + }); + }); +} + +template +void BinaryComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + if (typeid(OP) == typeid(mshadow::op::plus)) { + // If any input is dense, fallback to FCompute + // TODO(haibin) implement dns + rsp in a separate kernel + if (common::ContainsDefaultStorage(inputs)) { + FCompExFallback(attrs, ctx, inputs, req, outputs, + BinaryCompute, "BinaryCompute"); + return; + } + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + CHECK_EQ(inputs[1].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + BinaryComputeRspRsp(attrs, ctx, inputs, req, outputs); + return; + } else { + LOG(FATAL) << "Not implemented"; + } +} + template void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -134,6 +243,55 @@ void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, }); } +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[1].storage_type(), kRowSparseStorage); + CHECK(typeid(LOP) == typeid(mshadow_op::identity)); + CHECK(typeid(ROP) == typeid(mshadow_op::identity)); + TShape shape = inputs[0].aux_shape(rowsparse::kIdx); + outputs[0].CheckAndAlloc({shape}); + outputs[1].CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + MSHADOW_TYPE_SWITCH(outputs[0].aux_type(rowsparse::kIdx), IType, { + auto lgrad_idx = outputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto rgrad_idx = outputs[1].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto ograd_idx = inputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto lgrad = outputs[0].data().FlatTo1D(s); + Tensor rgrad = outputs[1].data().FlatTo1D(s); + Tensor ograd = inputs[0].data().FlatTo1D(s); + ASSIGN_DISPATCH(lgrad, req[0], F(ograd)); + ASSIGN_DISPATCH(rgrad, req[1], F(ograd)); + ASSIGN_DISPATCH(lgrad_idx, req[0], F(ograd_idx)); + ASSIGN_DISPATCH(rgrad_idx, req[1], F(ograd_idx)); + }); + }); +} +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto stype = inputs[0].storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Not implemented yet"; + BinaryBackwardUseNoneRsp(attrs, ctx, inputs, req, outputs); + // TODO(haibin) fallback for kDefaultStorage +} + template void BinaryBackwardUseNoneWithHalf2(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -214,7 +372,7 @@ void BinaryBackwardUseInWithHalf2(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ - .add_argument("lhs", "NDArray-or-Symbol", "first input") \ + .add_argument("lhs", "NDArray-or-Symbol", "first input") \ .add_argument("rhs", "NDArray-or-Symbol", "second input") } // namespace op diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index be4c1d88e983..8bf0d2e10c01 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -12,7 +12,9 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .add_alias("_add").add_alias("_plus").add_alias("_Plus") .describe("Adds arguments element-wise.") .set_attr("FCompute", BinaryCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}); +.set_attr(FCOMP_EX_CPU, BinaryComputeEx) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_add"}) +.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); // specialized gradient add function to do add to optimization // this must differ from elemwise_add to prevent add to optimization in forward pass. @@ -28,7 +30,10 @@ NNVM_REGISTER_OP(_backward_add) return std::vector >{{0, 0}, {0, 1}}; }) .set_attr("FCompute", BinaryBackwardUseNone); + mshadow_op::identity>) +.set_attr(FCOMP_EX_CPU, + BinaryBackwardUseNoneEx) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); MXNET_OPERATOR_REGISTER_BINARY(_sub) .add_alias("_minus").add_alias("_Minus") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index ff432380d6d1..cb30d78e2d8e 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(elemwise_add) -.set_attr("FCompute", BinaryComputeWithHalf2); +.set_attr("FCompute", BinaryComputeWithHalf2) +.set_attr(FCOMP_EX_GPU, BinaryComputeEx); NNVM_REGISTER_OP(_grad_add) .set_attr("FCompute", BinaryComputeWithHalf2); @@ -17,7 +18,9 @@ NNVM_REGISTER_OP(_grad_add) NNVM_REGISTER_OP(_backward_add) .set_attr("FCompute", BinaryBackwardUseNoneWithHalf2); + mshadow_op::identity, mshadow_op::identity>) +.set_attr(FCOMP_EX_GPU, + BinaryBackwardUseNoneEx); NNVM_REGISTER_OP(_sub) .set_attr("FCompute", BinaryComputeWithHalf2); diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index 073bbe16d491..9cdd56e66646 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -124,7 +124,9 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_attr("FIgnoreInputs", [](const NodeAttrs& attrs) { return std::vector(1, 1); }) .set_attr("FCompute", IdentityCompute) +.set_attr(FCOMP_EX_CPU, IdentityLikeRhsComputeEx) .set_attr("FInferShape", ElemwiseShape<2, 1>) +.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { @@ -169,6 +171,27 @@ NNVM_REGISTER_OP(_backward_cast) .set_attr("TIsBackward", true) .set_attr("FCompute", CastCompute); +// TODO(haibin) declare backward op for cast storage +// Only support cast to default storage now +// Other types require add infer_storage type pass +DMLC_REGISTER_PARAMETER(CastStorageParam); +NNVM_REGISTER_OP(cast_storage) +.describe(R"code(Casts tensor storage type to the new type. +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", CastStorageInferStorageType) +.set_attr("FCompute", IdentityCompute) +// _backward pass +// .set_attr("FGradient", ElemwiseGradUseNone{"negative"}) +.set_attr(FCOMP_EX_CPU, CastStorageComputeEx) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(CastStorageParam::__FIELDS__()); + + // negative MXNET_OPERATOR_REGISTER_UNARY(negative) .MXNET_DESCRIBE("Negate src") diff --git a/src/operator/tensor/elemwise_unary_op.cu b/src/operator/tensor/elemwise_unary_op.cu index 746b39fe4c8c..2084f5d3f5c4 100644 --- a/src/operator/tensor/elemwise_unary_op.cu +++ b/src/operator/tensor/elemwise_unary_op.cu @@ -35,7 +35,9 @@ NNVM_REGISTER_OP(make_loss) // identity output as first input, but attributes are constrainted to be like rhs NNVM_REGISTER_OP(_identity_with_attr_like_rhs) -.set_attr("FCompute", IdentityCompute); +.set_attr("FCompute", IdentityCompute) +.set_attr(FCOMP_EX_GPU, IdentityLikeRhsComputeEx); + NNVM_REGISTER_OP(Cast) .set_attr("FCompute", CastCompute); @@ -43,6 +45,10 @@ NNVM_REGISTER_OP(Cast) NNVM_REGISTER_OP(_backward_cast) .set_attr("FCompute", CastCompute); +NNVM_REGISTER_OP(cast_storage) +.set_attr("FCompute", IdentityCompute) +.set_attr(FCOMP_EX_GPU, CastStorageComputeEx); + // negative NNVM_REGISTER_OP(negative) .set_attr("FCompute", UnaryCompute); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 97a7e36535f0..996a25d5a647 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -9,19 +9,22 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../special_functions-inl.h" +#include "../mxnet_op.h" +#include "./broadcast_reduce-inl.h" namespace mxnet { namespace op { template void UnaryLaunch(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; Stream *s = ctx.get_stream(); @@ -77,6 +80,54 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, }); } +template +void IdentityComputeRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto &input = inputs[0]; + auto &output = outputs[0]; + CHECK_NE(req[0], kNullOp) << "kNullOp in IdentityComputeEx not supported yet"; + CHECK_NE(req[0], kWriteInplace) << "kWriteInplace in IdentityComputeEx not supported yet"; + if (!input.storage_initialized()) return; + TShape shape = input.aux_shape(rowsparse::kIdx); + output.CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(output.aux_type(rowsparse::kIdx), AuxType, { + auto out_d = output.data().FlatTo1D(s); + auto out_aux = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto in_aux = input.aux_data(rowsparse::kIdx).FlatTo1D(s); + ASSIGN_DISPATCH(out_d, req[0], + F(input.data().FlatTo1D(s))); + ASSIGN_DISPATCH(out_aux, req[0], F(in_aux)); + }); + }); +} + +template +void IdentityLikeRhsComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + size_t rhs_idx = 1; + NDArrayStorageType stype = inputs[rhs_idx].storage_type(); + if (stype == kRowSparseStorage) { + IdentityComputeRsp(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not implemented yet"; + } +} + struct CastParam : public dmlc::Parameter { // use int for enumeration int dtype; @@ -154,6 +205,376 @@ struct relu_grad { }; } // namespace kernel_launch_op +struct CastStorageParam : public dmlc::Parameter { + int storage_type; + DMLC_DECLARE_PARAMETER(CastStorageParam) { + DMLC_DECLARE_FIELD(storage_type) + .add_enum("default", kDefaultStorage) + .add_enum("row_sparse", kRowSparseStorage) + .add_enum("csr", kCSRStorage) + .describe("Output storage type."); + } +}; + +/*! + * \brief This is the kernel for initializing row_idx array + * of a RSP matrix. Each thread checks a row of the matrix, + * if non-zero elements are found, mark this row as non-zero + * by row_idx[cur_row_id] = cur_row_id. Otherwise, + * row_idx[cur_row_id] = num_rows. + */ +struct FillRspRowIdx { + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* arr, + const int num_rows, const int num_cols) { + row_idx[i] = num_rows; + const int offset = i * num_cols; + for (int j = 0; j < num_cols; ++j) { + if (arr[offset+j] != 0) { + row_idx[i] = i; + break; + } + } + } +}; + +/*! + * \brief Kernel for marking row_idx of a RSP matrix per row + */ +struct MarkRspRowIdx { + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* data, + const index_t num_cols) { + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = 0; // mark as zero for zero row + } else { + row_idx[i] = 1; // mark as one for non-zero row + } + } +}; + +struct CopyDnsToRsp{ + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, DType* rsp_data, + const DType* dns_data, const int num_rows, const int num_cols) { + int j = 0; + int offset = i * num_cols; + for (; j < num_cols; ++j) { + if (dns_data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = num_rows; + } else { + row_idx[i] = i; + for (j = 0; j < num_cols; ++j) { + rsp_data[offset+j] = dns_data[offset+j]; + } + } + } +}; + +/*! + * \brief + * CPU implementation of casting a dns tensor to rsp type. + */ +inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDArray* rsp) { + CHECK(rsp != nullptr); + CHECK_EQ(rsp->storage_type(), kRowSparseStorage); + CHECK_EQ(dns.shape_, rsp->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + rsp->CheckAndAllocAuxData(rowsparse::kIdx, mshadow::Shape1(num_rows)); + TBlob row_idx_blob = rsp->aux_data(rowsparse::kIdx); + RType* row_idx = row_idx_blob.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, row_idx, + dns.dptr(), num_cols); + index_t nnr = 0; + nnr = std::accumulate(row_idx, row_idx+num_rows, nnr); + rsp->SetAuxShape(rowsparse::kIdx, mshadow::Shape1(nnr)); + if (0 == nnr) return; + rsp->CheckAndAllocData(mshadow::Shape2(nnr, num_cols)); + mshadow::Tensor dns_data = dns.FlatTo2D(s); + mshadow::Tensor rsp_data = rsp->data().FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < num_rows; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], dns_data[i], s); + ++idx; + } + } + }); + }); +} + +// TODO(haibin) Use memcopy instead will be much faster than assigning each individual element +struct CastStorageRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data, + DType* dns, const index_t invalid_rid) { + auto rid = idx[i]; + // skip invalid rows + if (rid == invalid_rid) return; + auto dns_offset = rid * width; + auto rsp_offset = i * width; + for (size_t col = 0; col < width; col++) { + dns[dns_offset + col] = data[rsp_offset + col]; + } + } +}; + +/*! + * \brief This function assumes that the meomry for dns has been allocated already + * since the shape is known at binding stage. + */ +template +void CastStorageRspDnsImpl(mshadow::Stream* s, const NDArray& rsp, TBlob* dns) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(rsp.storage_type(), kRowSparseStorage); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { + MSHADOW_INT_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { + // assign zeros + mxnet_op::Kernel::Launch(s, dns->Size(), dns->dptr()); + if (rsp.storage_initialized()) { + // copy over row by row + auto in_idx = rsp.aux_data(rowsparse::kIdx).FlatTo1D(s).dptr_; + auto in_data = rsp.data().FlatTo2D(s).dptr_; + auto out_data = dns->FlatTo2D(s).dptr_; + auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); + auto rsp_shape = rsp.shape(); + auto invalid_rid = rsp_shape[0]; + auto width = rsp_shape.ProdShape(1, rsp_shape.ndim()); + mxnet_op::Kernel::Launch(s, num_rows, width, in_idx, in_data, + out_data, invalid_rid); + } + }); + }); +} + +/*! + * \brief This is the kernel for initializing the indptr in a csr tensor. + */ +struct FillCsrIndPtr { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param indptr indptr of the csr tensor + * \param dns the dns tensor + * \param num_rows + * \param num_cols + */ + template + MSHADOW_XINLINE static void Map(int i, IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + indptr[i+1] = 0; + const int offset = i * num_cols; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + ++indptr[i+1]; + } + } + } +}; + +/*! + * \brief This is the kernel for initializing the col_idx and value array + * of the csr tensor + */ +struct FillCsrColIdxAndVals { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param val value array of the csr + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param dns the dns tensor + * \param num_rows number of rows of the dns + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* val, CType* col_idx, + const IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + const int offset = i * num_cols; + int k = indptr[i]; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + val[k] = dns[offset+j]; + col_idx[k] = j; + ++k; + } + } + } +}; + +/*! + * \brief + * CPU implementation of casting a dns tensor to csr type. + */ +inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDArray* csr) { + CHECK(csr != nullptr); + CHECK_EQ(csr->storage_type(), kCSRStorage); + CHECK_EQ(dns.shape_.ndim(), 2); + CHECK_EQ(dns.shape_, csr->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1)); + IType* indptr = csr->aux_data(csr::kIndPtr).dptr(); + DType* dns_data = dns.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, indptr, + dns_data, num_rows, num_cols); + // single thread to accumulate indptr + // indptr[num_rows] indicates the number of non-zero elements + indptr[0] = 0; + for (index_t i = 0; i < num_rows; ++i) { + indptr[i+1] += indptr[i]; + } + // allocate column idx array and value array + csr->CheckAndAllocAuxData(csr::kIdx, + mshadow::Shape1(static_cast(indptr[num_rows]))); + csr->CheckAndAllocData(mshadow::Shape1(static_cast(indptr[num_rows]))); + // fill col_idx and value arrays of the csr + mxnet_op::Kernel::Launch(s, num_rows, + csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), + indptr, dns_data, num_rows, num_cols); + }); + }); + }); +} + +/*! + * \brief This is the kernel for copying csr.data to its corresponding dns tensor. + */ +struct CopyCsrDataToDns { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param dns_data data blob of the dns tensor + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param csr_data data blob of the csr tensor + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* dns_data, const CType* col_idx, + const IType* indptr, const DType* csr_data, + const int num_cols) { + const int offset = i * num_cols; + for (auto j = indptr[i]; j < indptr[i+1]; ++j) { + dns_data[offset+col_idx[j]] = csr_data[j]; + } + } +}; + +/*! + * \brief Casts a csr tensor to dns format. + */ +template +void CastStorageCsrDnsImpl(mshadow::Stream* s, const NDArray& csr, TBlob* dns) { + CHECK(dns != nullptr); + CHECK_EQ(csr.storage_type(), kCSRStorage); + CHECK_EQ(dns->shape_.ndim(), 2); + CHECK_EQ(dns->shape_, csr.shape()); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns->shape_[0]; + const index_t num_cols = dns->shape_[1]; + DType* dns_data = dns->dptr(); + mxnet_op::Kernel::Launch(s, dns->shape_.Size(), dns_data); + if (!csr.storage_initialized()) return; + const IType* indptr = csr.aux_data(csr::kIndPtr).dptr(); + const CType* col_idx = csr.aux_data(csr::kIdx).dptr(); + const DType* csr_data = csr.data().dptr(); + mxnet_op::Kernel::Launch(s, num_rows, dns_data, + col_idx, indptr, csr_data, num_cols); + }); + }); + }); +} + +inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE(in_attrs->at(0), kUndefinedStorage) + << "src ndarray's storage type must be specified"; + const CastStorageParam& param = nnvm::get(attrs.parsed); + CHECK_NE(param.storage_type, kUndefinedStorage) + << "dst ndarray's storage type must be specified"; + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.storage_type); + return true; +} + +// TODO(junwu) Implement GPU version for these functions +// and move them to a .cuh file +#ifdef __CUDACC__ +inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDArray* rsp) { + LOG(FATAL) << "CastStorageDnsRspImpl gpu version is not implemented."; +} + +inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDArray* csr) { + LOG(FATAL) << "CastStorageDnsCsrImpl gpu version is not implemented."; +} +#endif + +template +void CastStorageComputeImpl(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + using namespace mshadow; + using namespace mshadow::expr; + const auto src_stype = input.storage_type(); + const auto dst_stype = output.storage_type(); + if (src_stype == kRowSparseStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageRspDnsImpl(s, input, &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kRowSparseStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsRspImpl(s, input.data(), &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kCSRStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsCsrImpl(s, input.data(), &ret); + } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageCsrDnsImpl(s, input, &ret); + } else { + LOG(FATAL) << "Not implemented"; + } +} + +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + CastStorageComputeImpl(s, inputs[0], outputs[0]); +} + #define MXNET_OPERATOR_REGISTER_UNARY(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ @@ -168,4 +589,5 @@ struct relu_grad { } // namespace op } // namespace mxnet + #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_H_ diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 5f010fdfc62c..8cf00c0eb7b4 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -86,6 +86,40 @@ NNVM_REGISTER_OP(_backward_Embedding) .set_attr("TIsBackward", true) .set_attr("FCompute", EmbeddingOpBackward); +NNVM_REGISTER_OP(SparseEmbedding) +.describe(R"code(Maps integer indices to vector representations (embeddings) with sparse weight update +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "weight"}; + }) +.set_attr("FInferShape", EmbeddingOpShape) +.set_attr("FInferType", EmbeddingOpType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", EmbeddingOpForward) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds, + {n->inputs[0]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.") +.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") +.add_arguments(EmbeddingParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_SparseEmbedding) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", SparseEmbeddingBackwardStorageType) +.set_attr("FComputeEx", SparseEmbeddingOpBackwardEx); +// TODO(haibin) handle dense case +// .set_attr("FCompute", EmbeddingOpBackward); NNVM_REGISTER_OP(take) .describe(R"code(Takes elements from an input array along the given axis. @@ -230,5 +264,46 @@ Examples:: .add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value") .add_arguments(OneHotParam::__FIELDS__()); +NNVM_REGISTER_OP(sparse_retain) +.describe(R"code(pick rows specified by user input index array from a row sparse matrix +and save them in the output sparse matrix. + +Example:: + + data = [[1, 2], [3, 4], [5, 6]] + indices = [0, 1, 3] + shape = (4, 2) + rsp_in = row_sparse(data, indices) + to_retain = [0, 3] + rsp_out = sparse_retain(rsp_in, to_retain) + rsp_out.values = [[1, 2], [5, 6]] + rsp_out.indices = [0, 3] + +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "indices"}; + }) +.set_attr("FInferShape", SparseRetainOpShape) +.set_attr("FInferType", SparseRetainOpType) +.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) +.set_attr("FComputeEx", SparseRetainOpForwardEx) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_sparse_retain", n, ograds, + {n->inputs[sr::kIdx]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", "The input array for sparse_retain operator.") +.add_argument("indices", "NDArray-or-Symbol", "The index array of rows ids that will be retained."); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 287ec25d70be..4378bd574932 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -26,6 +26,12 @@ NNVM_REGISTER_OP(batch_take) NNVM_REGISTER_OP(one_hot) .set_attr("FCompute", OneHotOpForward); +NNVM_REGISTER_OP(sparse_retain) +.set_attr("FComputeEx", SparseRetainOpForwardEx); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5fd6e81d0b2f..81b219f7c2c9 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -315,6 +316,133 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, }); } +template +struct EmbeddingBackwardRsp { + template + // each thread i is responsible for target gradient row ids in [segment_start, segment_end) + MSHADOW_XINLINE static void Map(int i, const size_t width, IType* dst_idx, DType* dst_val, + const IType* idx, const size_t num_idx, const DType* src, + const size_t segment_len, const size_t num_rows) { + auto req_type = req; + size_t segment_start = i * segment_len; + size_t segment_end = (i + 1) * segment_len; + for (size_t y = 0; y < num_idx; y++) { + size_t j = idx[y]; + if (j >= num_rows) j = num_rows - 1; + if (j < segment_start || j >= segment_end) continue; + dst_idx[j] = j; + for (size_t k = 0; k < width; k++) { + if (req_type == kWriteTo) req_type = kAddTo; + KERNEL_ASSIGN(dst_val[j * width + k], req_type, src[y * width + k]); + } + } + } +}; + +/* + * for sparse embedding, the storage type for weight gradient is row_sparse. + * we don't care about the storage type for data gradient, since it is not + * differentiable. + */ +inline bool SparseEmbeddingBackwardStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ((*in_attrs)[0], kDefaultStorage); + CHECK_EQ((*in_attrs)[1], kDefaultStorage); + (*out_attrs)[0] = kRowSparseStorage; + (*out_attrs)[1] = kRowSparseStorage; + return true; +} + +template +void SparseEmbeddingOpBackwardDnsDnsRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + if (req[1] == kNullOp) return; + // check storage types + auto idx = inputs[1]; // idx shape (d1, d2 .. dk) + auto grad = inputs[0]; // grad shape (d1, d2, .. dk, out_dim) + auto output = outputs[1]; // weight shape (in_dim, out_dim) + CHECK_EQ(idx.storage_type(), kDefaultStorage); + CHECK_EQ(grad.storage_type(), kDefaultStorage); + CHECK_EQ(output.dtype(), grad.dtype()); + CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) << "Index type doesn't match"; + // CHECK_EQ(req[embedding::kData], kNullOp) + // << "Embedding layer doesn't support calculate data gradient" << req[embedding::kData]; + + const TShape& ishape = idx.shape(); + const TShape& oshape = grad.shape(); + + Stream *s = ctx.get_stream(); + CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) + << "embedding input index and gradient row sparse type doesn't match!"; + // Alloc dense output + unsigned int num_rows = output.shape()[0]; + output.CheckAndAlloc({mshadow::Shape1(num_rows)}); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(idx.dtype(), IType, { + MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { + // input embedding indice, each idx in [0, input_dim) + auto idx_data = idx.data().FlatTo1D(s); + auto grad_data = grad.data().get_with_shape( + Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); + auto output_idx = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto output_val = output.data().FlatTo2D(s); + int num_threads = omp_get_num_threads(); + size_t width = output.shape()[1]; + size_t segment_len = (num_rows + num_threads - 1) / num_threads; + // fill indices with invalid row ids + Kernel::Launch(s, num_rows, output_idx.dptr_, + static_cast(num_rows)); + // fill zeros if needed + if (req_type == kWriteTo) { + Kernel::Launch(s, output_val.shape_.Size(), output_val.dptr_); + } + Kernel, xpu>::Launch(s, num_threads, width, + output_idx.dptr_, + output_val.dptr_, idx_data.dptr_, + ishape.Size(), grad_data.dptr_, + segment_len, num_rows); + }); + }); + }); +} + +// todo replace xpu with cpu +template +void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + // CHECK_EQ(req[embedding::kData], kNullOp) + // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; + // idx shape (d1, d2 .. dk) + auto idx_stype = inputs[1].storage_type(); + // grad shape (d1, d2, .. dk, out_dim) + auto grad_stype = inputs[0].storage_type(); + // weight shape (in_dim, out_dim) + auto output_stype = outputs[1].storage_type(); + if (idx_stype == kDefaultStorage && grad_stype == kDefaultStorage && + output_stype == kRowSparseStorage) { + SparseEmbeddingOpBackwardDnsDnsRsp(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not implemented"; + } +} + namespace take_ { // to avoid name conflict enum TakeOpInputs {kArr, kIdx}; enum TakeOpOutputs {kOut}; @@ -667,6 +795,199 @@ void OneHotOpForward(const nnvm::NodeAttrs& attrs, }); } +/*! + * \brief sparse retain namespace + */ +namespace sr { +enum SparseRetainOpInputs {kArr, kIdx}; +enum SparseRetainOpOutputs {kOut}; +} // namespace sr + +inline bool SparseRetainOpShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U) + << "sparse_retain operator takes 2 arguments (" << in_attrs->size() << " given)"; + CHECK_EQ(out_attrs->size(), 1U); + + TShape tshape((*in_attrs)[sr::kArr]); + shape_assign(&tshape, (*out_attrs)[sr::kOut]); + SHAPE_ASSIGN_CHECK(*in_attrs, sr::kArr, tshape); + SHAPE_ASSIGN_CHECK(*out_attrs, sr::kOut, tshape); + return true; +} + +inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE((*in_attrs)[sr::kIdx], -1) << "Index type must be set for sparse_retain operator"; + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[sr::kArr]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[sr::kOut]); + return (*in_attrs)[0] != -1; +} + +inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + if (kRowSparseStorage == in_attrs->at(sr::kArr)) { + out_attrs->at(sr::kOut) = kRowSparseStorage; + } + return true; +} + +inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + out_attrs->at(sr::kArr) = kRowSparseStorage; + out_attrs->at(sr::kIdx) = kDefaultStorage; + return true; +} + +struct SparseRetainRspForward { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, RType* out_idx, + const DType* in_data, const RType* in_idx, + const IType* idx, const size_t nnr, + const size_t num_cols) { + const RType irow = idx[i]; + int j = -1, left = 0, right = nnr - 1; + while (left <= right) { + int m = left + (right - left) / 2; + const auto in_idx_m = in_idx[m]; + if (in_idx_m == irow) { + j = m; + break; + } else if (in_idx_m < irow) { + left = m + 1; + } else { + right = m - 1; + } + } + out_idx[i] = idx[i]; + if (j >= 0) { + const size_t in_offset = j * num_cols; + const size_t out_offset = i * num_cols; + for (size_t k = 0; k < num_cols; ++k) { + out_data[out_offset+k] = in_data[in_offset+k]; + } + } + } +}; + +template +void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'"; + + CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain operator only takes row sparse NDArray as input"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain operator only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage) + << "sparse_retain operator only outputs row sparse NDArray"; + + const NDArray& input_nd = inputs[sr::kArr]; + const TBlob idx_data = inputs[sr::kIdx].data(); + + if (req[sr::kOut] == kNullOp + || !input_nd.storage_initialized() + || idx_data.Size() == 0U) return; + + const TBlob input_data = input_nd.data(); + if (input_data.shape_[0] == 0) return; + const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx); + + NDArray output_nd = outputs[sr::kOut]; + output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); + TBlob output_data = output_nd.data(); + TBlob output_idx = output_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type + MSHADOW_INT_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + Kernel::Launch(s, output_data.Size(), output_data.dptr()); + Kernel::Launch(s, idx_data.Size(), output_data.dptr(), + output_idx.dptr(), input_data.dptr(), input_idx.dptr(), + idx_data.dptr(), input_data.shape_[0], input_data.shape_[1]); + }); + }); + }); +} + +template +struct SparseRetainRspBackward { + template + MSHADOW_XINLINE static void Map(int i, DType* in_grad, RType* in_grad_idx, + const DType* out_grad, const IType* idx, + const size_t num_cols) { + const RType irow = idx[i]; + in_grad_idx[i] = irow; + const size_t out_offset = irow * num_cols; + const size_t in_offset = i * num_cols; + for (size_t j = 0; j < num_cols; ++j) { + KERNEL_ASSIGN(in_grad[in_offset+j], req, out_grad[out_offset+j]); + } + } +}; + +template +void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_NE(req[sr::kArr], kWriteInplace); + CHECK_EQ(req[sr::kIdx], kNullOp) + << "sparse_retain does not support calculating gradients of indices"; + + CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as ograd"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain backward only outputs row sparse NDArray as grad of input"; + + const TBlob out_grad_data = inputs[sr::kOut].data(); + const TBlob idx_data = inputs[sr::kIdx].data(); + + NDArray in_grad_nd = outputs[sr::kArr]; + in_grad_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); + TBlob in_grad_data = in_grad_nd.data(); + TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, { // output data type + MSHADOW_INT_TYPE_SWITCH(in_grad_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, { + Kernel, xpu>::Launch( + s, in_grad_idx.Size(), in_grad_data.dptr(), in_grad_idx.dptr(), + out_grad_data.dptr(), idx_data.dptr(), out_grad_data.shape_[1]); + }); + }); + }); + }); +} + } // namespace op } // namespace mxnet #ifdef __CUDACC__ diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 16f71fc7e4e3..a5827330a61f 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -21,6 +21,7 @@ NNVM_REGISTER_OP(_zeros) .set_attr("FInferShape", InitShape) .set_attr("FInferType", InitType) .set_attr("FCompute", FillCompute) +.set_attr(FCOMP_EX_CPU, FillComputeZerosEx) .add_arguments(InitOpParam::__FIELDS__()); NNVM_REGISTER_OP(_ones) diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu index a798f26db60d..bcb10f70b3c3 100644 --- a/src/operator/tensor/init_op.cu +++ b/src/operator/tensor/init_op.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_zeros) -.set_attr("FCompute", FillCompute); +.set_attr("FCompute", FillCompute) +.set_attr(FCOMP_EX_GPU, FillComputeZerosEx); NNVM_REGISTER_OP(_ones) .set_attr("FCompute", FillCompute); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 5ce132d4bebf..ca61f9bba460 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -15,6 +15,8 @@ #include #include #include "../elemwise_op_common.h" +#include "../mxnet_op.h" + namespace mxnet { namespace op { @@ -111,7 +113,6 @@ inline bool InitType(const nnvm::NodeAttrs& attrs, return true; } - template void FillCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -127,6 +128,51 @@ void FillCompute(const nnvm::NodeAttrs& attrs, }); } +// Fill a rsp NDArray with zeros by updating the aux shape. +template +void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + auto storage_shape = dst->storage_shape(); + storage_shape[0] = 0; + dst->SetAuxShape(rowsparse::kIdx, TShape(mshadow::Shape1(0))); + dst->SetStorageShape(storage_shape); +} + +// Fill a CSR NDArray with zeros by updating the aux shape. +template +void FillZerosCsrImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + TShape new_shape(mshadow::Shape1(0)); + dst->SetAuxShape(csr::kIndPtr, new_shape); + dst->SetAuxShape(csr::kIdx, new_shape); + dst->SetStorageShape(new_shape); +} + +// This operator never needs to fall back, since there's no input NDArray +template +void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs.size(), 1); + CHECK_EQ(inputs.size(), 0); + auto stype = outputs[0].storage_type(); + if (stype == kRowSparseStorage) { + NDArray nd(outputs[0]); + FillZerosRspImpl(s, &nd); + } else if (stype == kCSRStorage) { + NDArray nd(outputs[0]); + FillZerosCsrImpl(s, &nd); + } else { + LOG(FATAL) << "storage type not implemented."; + } +} template void RangeCompute(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index cdc8819da18e..05fba76d0ff3 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../mxnet_op.h" @@ -476,6 +477,242 @@ void DotBackward_(const nnvm::NodeAttrs& attrs, } } +inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + out_attrs->at(0) = kDefaultStorage; + return true; +} + +inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 2U); + out_attrs->at(0) = kDefaultStorage; + out_attrs->at(1) = kDefaultStorage; + return true; +} + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrDnsDns { + /*! + * \brief This function represents performing an inner product between a row of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_cols number of columns of output + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols) { + const int irow = i / num_cols; // row id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { + const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs + sum += data_l[j] * data_r[cur_col*num_cols+icol]; + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrTransDnsDns { + /*! + * \brief This function represents performing an inner product between a column of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_rows_l number of rows of lhs + * \param num_cols number of columns of outputs + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const int num_rows_l, + const int num_cols) { + const int irow = i / num_cols; // col id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (int k = 0; k < num_rows_l; ++k) { + const IType low = indptr_l[k]; + const IType high = indptr_l[k+1]; + if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; + int j = -1, l = low, r = high - 1; + while (l <= r) { + int m = l + (r - l) / 2; + if (col_idx_l[m] == irow) { + j = m; break; + } + if (col_idx_l[m] < irow) { + l = m + 1; + } else { + r = m - 1; + } + } + if (j >= 0) { + sum += data_l[j] * data_r[k*num_cols+icol]; + } + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows, const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto val = data_l[k]; + const size_t offset_r = col_idx_l[k] * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrTransDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +template +void DotCsrDnsDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kDefaultStorage); + CHECK_EQ(ret->storage_type(), kDefaultStorage); + if (!lhs.storage_initialized()) return; + + mshadow::Stream *s = ctx.get_stream(); + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob data_out = ret->data(); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + } else { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + data_out.shape_[0], data_out.shape_[1]); + } + } else { // gpu parallelization by output elements + if (trans_lhs) { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], + data_out.shape_[1]); + }); + } else { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), rhs.shape()[1]); + }); + } + } + }); + }); + }); +} + +template +void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const DotParam& param = nnvm::get(attrs.parsed); + NDArray ret = outputs[1]; + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0], req[1], !param.transpose_a, &ret); +} + inline bool DotShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -519,6 +756,57 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, return true; } +template +void DotForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; + + NDArray ret = outputs[0]; // get rid of the const qualifier + if (inputs[0].storage_type() == kCSRStorage + && inputs[1].storage_type() == kDefaultStorage + && outputs[0].storage_type() == kDefaultStorage) { + DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else { // TODO(junwu): add fallback + LOG(FATAL) << "Not supported dot operation for lhs.storage_type = " + << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type() + << ", out.storage_type = " << outputs[0].storage_type(); + } +} + +template +void DotBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_EQ(kNullOp, req[0]) + << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + // TODO(junwu): check whether this CHECK is reasonable + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; + if (inputs[0].storage_type() == kDefaultStorage // ograd dns format + // dns, csr, dns => *, dns + && inputs[1].storage_type() == kCSRStorage // csr input lhs of the op + && inputs[2].storage_type() == kDefaultStorage // dns input rhs of the op + && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format + DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; + } +} + template void BatchDotForward_(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -786,6 +1074,96 @@ void Slice(const nnvm::NodeAttrs& attrs, }); } +// slice the indptr of a csr +struct SliceCsrIndPtr { + template + MSHADOW_XINLINE static void Map(int i, IType* out, const IType* in, const IType* base) { + KERNEL_ASSIGN(out[i], kWriteTo, in[i] - *base); + } +}; + +/* + * a wrapper to launch SliceCsrIndPtr kernel. + * slice [src[begin] .. src[end]) and store in dst[0, end - begin) + */ +template +void SliceCsrIndPtrImpl(const int begin, const int end, RunContext ctx, + const IType* src, IType* dst) { + using namespace mshadow; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + int indptr_len = end - begin + 1; + Kernel::Launch(s, indptr_len, dst, src + begin, src + begin); +} + +/* + * Slice a CSR NDArray + * Only implemented for CPU + */ +template +void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, + const NDArray &in, OpReqType req, const NDArray &out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace csr; + CHECK((std::is_same::value)) << "Slice for CSR input only implemented for CPU"; + if (req == kNullOp) return; + CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; + CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; + Stream *s = ctx.get_stream(); + int begin = *param.begin[0]; + int end = *param.end[0]; + int indptr_len = end - begin + 1; + out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); + if (!in.storage_initialized()) { + out.SetAuxShape(kIndPtr, Shape1(0)); + return; + } + CHECK_EQ(in.aux_type(kIndPtr), in.aux_type(kIdx)) + << "The type for indptr and indices are different. This is not implemented yet."; + // assume idx indptr share the same type + MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), IType, { + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + auto in_indptr = in.aux_data(kIndPtr).dptr(); + auto out_indptr = out.aux_data(kIndPtr).dptr(); + SliceCsrIndPtrImpl(begin, end, ctx.run_ctx, in_indptr, out_indptr); + + // retrieve nnz (CPU implementation) + int nnz = out_indptr[indptr_len - 1]; + // copy indices and values + out.CheckAndAllocAuxData(kIdx, Shape1(nnz)); + out.CheckAndAllocData(Shape1(nnz)); + auto in_idx = in.aux_data(kIdx).dptr(); + auto out_idx = out.aux_data(kIdx).dptr(); + auto in_data = in.data().dptr(); + auto out_data = out.data().dptr(); + int offset = in_indptr[begin]; + // this is also a CPU-only implementation + memcpy(out_idx, in_idx + offset, nnz * sizeof(IType)); + memcpy(out_data, in_data + offset, nnz * sizeof(DType)); + }); + }); +} + +template +void SliceEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + const SliceParam& param = nnvm::get(attrs.parsed); + auto in_stype = inputs[0].storage_type(); + CHECK_NE(in_stype, kDefaultStorage) + << "SliceEx is not expected to execute for input with default storage type"; + if (in_stype == kCSRStorage) { + SliceCsrImpl(param, ctx, inputs[0], req[0], outputs[0]); + } else { + LOG(FATAL) << "Slice not implemented for storage type" << in_stype; + } +} + inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index f3d69733a814..0e1d986291cc 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -232,6 +232,9 @@ and ``end=(e_1, e_2, ... e_n)`` indices will result in an array with the shape The resulting array's *k*-th dimension contains elements from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``. +For an input array of non-default storage type(e.g. `csr` or `row_sparse`), it only supports +slicing on the first dimension. + Example:: x = [[ 1., 2., 3., 4.], @@ -245,8 +248,10 @@ Example:: .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_slice"}) .set_attr("FCompute", Slice) +.set_attr(FCOMP_EX_CPU, SliceEx) .add_argument("data", "NDArray-or-Symbol", "Source input") .add_arguments(SliceParam::__FIELDS__()); @@ -370,7 +375,13 @@ NNVM_REGISTER_OP(dot) }) .set_attr("FInferShape", DotShape) .set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) .add_argument("lhs", "NDArray-or-Symbol", "The first input") .add_argument("rhs", "NDArray-or-Symbol", "The second input") @@ -381,7 +392,13 @@ NNVM_REGISTER_OP(_backward_dot) .set_num_outputs(2) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx) .add_arguments(DotParam::__FIELDS__()); NNVM_REGISTER_OP(batch_dot) diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 96c075a7d483..2e1effb9e560 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -40,10 +40,13 @@ NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("FCompute", SliceAxisGrad_); NNVM_REGISTER_OP(dot) -.set_attr("FCompute", DotForward_); +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx); NNVM_REGISTER_OP(_backward_dot) -.set_attr("FCompute", DotBackward_); +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx); + NNVM_REGISTER_OP(batch_dot) .set_attr("FCompute", BatchDotForward_); diff --git a/tests/ci_build/install/ubuntu_install_python.sh b/tests/ci_build/install/ubuntu_install_python.sh index 0459bb9198c4..6ac615c7ee7f 100755 --- a/tests/ci_build/install/ubuntu_install_python.sh +++ b/tests/ci_build/install/ubuntu_install_python.sh @@ -6,5 +6,5 @@ apt-get update && apt-get install -y python-dev python3-dev # the version of the pip shipped with ubuntu may be too lower, install a recent version here cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py -pip2 install nose pylint numpy nose-timer requests -pip3 install nose pylint numpy nose-timer requests +pip2 install nose pylint numpy nose-timer requests scipy +pip3 install nose pylint numpy nose-timer requests scipy diff --git a/tests/cpp/include/test_ndarray_utils.h b/tests/cpp/include/test_ndarray_utils.h new file mode 100644 index 000000000000..4a99d2759c3b --- /dev/null +++ b/tests/cpp/include/test_ndarray_utils.h @@ -0,0 +1,115 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file test_utils.h + * \brief operator unit test utility functions + * \author Haibin Lin +*/ +#ifndef TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ +#define TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ + +/*#include +#include +#include +#include +#include +#include +#include +#include + +#include "../src/operator/tensor/elemwise_binary_op.h" +#include "../src/operator/tensor/elemwise_unary_op.h" +#include "../src/operator/optimizer_op-inl.h" +#include "../src/operator/tensor/init_op.h" + +using namespace mxnet; +#define TEST_DTYPE float +#define TEST_ITYPE int32_t + +void CheckDataRegion(const TBlob &src, const TBlob &dst) { + auto size = src.shape_.Size() * mshadow::mshadow_sizeof(src.type_flag_); + auto equals = memcmp(src.dptr_, dst.dptr_, size); + EXPECT_EQ(equals, 0); +} + +float RandFloat() { + float v = rand() * 1.0 / RAND_MAX; + return v; +} + +// Get an NDArray with provided indices, prepared for a RowSparse NDArray. +NDArray RspIdxND(const TShape shape, const Context ctx, const std::vector &values) { + NDArray nd(shape, ctx, false, ROW_SPARSE_IDX_TYPE); + size_t num_val = values.size(); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = values[i]; + } + }); + return nd; +} + +// Get a dense NDArray with provided values. +NDArray DnsND(const TShape shape, const Context ctx, std::vector vs) { + NDArray nd(shape, ctx, false); + size_t num_val = shape.Size(); + // generate random values + while (vs.size() < num_val) { + auto v = RandFloat(); + vs.push_back(v); + } + CHECK_EQ(vs.size(), nd.shape().Size()); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = vs[i]; + } + }); + return nd; +} + +// Get a RowSparse NDArray with provided indices and values +NDArray RspND(const TShape shape, const Context ctx, const std::vector idx, + std::vector vals) { + CHECK(shape.ndim() <= 2) << "High dimensional row sparse not implemented yet"; + index_t num_rows = idx.size(); + index_t num_cols = vals.size() / idx.size(); + // create index NDArray + NDArray index = RspIdxND(mshadow::Shape1(num_rows), ctx, idx); + CHECK_EQ(vals.size() % idx.size(), 0); + // create value NDArray + NDArray data = DnsND(mshadow::Shape2(num_rows, num_cols), ctx, vals); + // create result nd + NDArray nd(kRowSparseStorage, shape, ctx, false, mshadow::default_type_flag, + {}, {mshadow::Shape1(num_rows)}); + // assign values + NDArray nd_aux = nd.aux_ndarray(0); + NDArray nd_data = nd.data_ndarray(); + CopyFromTo(index, &nd_aux); + CopyFromTo(data, &nd_data); + return nd; +} + +// TODO(haibin) support other types +NDArray Convert(NDArrayStorageType type, NDArray src) { + CHECK_EQ(type, kDefaultStorage); + NDArray converted(src.shape(), src.ctx(), false); + Engine::Get()->PushSync([src, converted](RunContext ctx) { + // TODO provide type in attrs, which is empty now + OpContext op_ctx; + op_ctx.run_ctx = ctx; + if (src.storage_type() == kRowSparseStorage) { + std::vector inputs({src}), outputs({converted}); + op::CastStorageComputeEx({}, op_ctx, inputs, {}, outputs); + } else if (src.storage_type() == kDefaultStorage) { + std::vector inputs({src.data()}), outputs({converted.data()}); + op::IdentityCompute({}, op_ctx, inputs, {kWriteTo}, outputs); + } else { + LOG(FATAL) << "unsupported storage type"; + } + }, src.ctx(), {src.var()}, {converted.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + converted.WaitToRead(); + return converted; +}*/ +#endif // TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index 719980b5d4f5..32d60cf3e4e4 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors * \file batchnorm_test.cc - * \brief operator unit test utility functions + * \brief batchnorm operator unit test utility functions * \author Chris Olivier */ @@ -874,8 +874,8 @@ TEST(BATCH_NORM, TestIterAll) { kwargs.push_back({ "cudnn_off", "True" }); } for (TShape shape : shapes) { - for (int g1 = 0; g1 < 2U; ++g1) { - for (int g2 = 0; g2 < 2U; ++g2) { + for (int g1 = 0; g1 < 2; ++g1) { + for (int g2 = 0; g2 < 2; ++g2) { for (int type : v2_types) { MSHADOW_REAL_TYPE_SWITCH_EX( type, DType, AccReal, diff --git a/tests/cpp/operator/ndarray_test.cc b/tests/cpp/operator/ndarray_test.cc new file mode 100644 index 000000000000..f2ed30793881 --- /dev/null +++ b/tests/cpp/operator/ndarray_test.cc @@ -0,0 +1,6 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file ndarray_test.cc + * \brief ndarray unit test utility functions + * \author Haibin Lin +*/ diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk index 808b655e9dba..ec7bb55ec983 100644 --- a/tests/cpp/unittest.mk +++ b/tests/cpp/unittest.mk @@ -47,4 +47,4 @@ testclean: -include build/tests/cpp/*.d -include build/tests/cpp/operator/*.d -include build/tests/cpp/storage/*.d --include build/tests/cpp/engine/*.d \ No newline at end of file +-include build/tests/cpp/engine/*.d diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 35598bc55be8..9188dd9d933f 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -112,6 +112,37 @@ def test_incomplete_infer_concat(): assert arg_shapes['b'] == (2, 5) assert arg_shapes['d'] == (2, 15) +def test_fc_infer_type(): + mx_real_t = mx.base.mx_real_t + data = mx.symbol.Variable('data') + out = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000) + + # infer type + data_type = mx_real_t + arg_types, out_types, aux_types = out.infer_type(data=data_type) + arg_type_dict = dict(zip(out.list_arguments(), arg_types)) + assert len(out_types) == 1 + assert out_types[0] == mx_real_t + true_types = { + 'fc1_bias' : mx_real_t, + 'fc1_weight' : mx_real_t } + for k, v in true_types.items(): + assert arg_type_dict[k] == v + +def check_infer_storage(v1, v2, v1_storage, v2_storage, out_chunk): + out = mx.symbol.elemwise_add(v1, v2) + arg_storage_types, out_storage_types, aux_storage_types = out.infer_storage_type(v1=v1_storage, v2=v2_storage) + assert len(out_storage_types) == 1 + assert out_storage_types[0] == out_chunk + +def test_elemwise_add_infer_storage_type(): + v1 = mx.symbol.Variable('v1') + v2 = mx.symbol.Variable('v2') + check_infer_storage(v1, v2, 'default', 'default', 'default') + check_infer_storage(v1, v2, 'default', 'row_sparse', 'default') + check_infer_storage(v1, v2, 'row_sparse', 'default', 'default') + check_infer_storage(v1, v2, 'row_sparse', 'row_sparse', 'row_sparse') + if __name__ == "__main__": test_mlp2_infer_shape() test_mlp2_infer_error() @@ -121,3 +152,4 @@ def test_incomplete_infer_concat(): test_incomplete_infer_slicechannel() test_incomplete_infer_convolution() test_incomplete_infer_concat() + test_elemwise_add_infer_storage_type() diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 5fe61b185041..4cbb4f19e40a 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -1,5 +1,6 @@ # pylint: skip-file import mxnet as mx +from mxnet.test_utils import * import numpy as np import os, gzip import pickle as pickle @@ -88,7 +89,43 @@ def test_NDArrayIter(): else: assert(labelcount[i] == 100) +''' +def test_libsvm(): + #TODO(haibin) automatic the test instead of hard coded test + cwd = os.getcwd() + data_path = os.path.join(cwd, 'data.t') + label_path = os.path.join(cwd, 'label.t') + with open(data_path, 'w') as fout: + fout.write('1.0 0:0.5 2:1.2\n') + fout.write('-2.0\n') + fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') + fout.write('4 2:-1.2\n') + + with open(label_path, 'w') as fout: + fout.write('1.0\n') + fout.write('-2.0 0:0.125\n') + fout.write('-3.0 2:1.2\n') + fout.write('4 1:1.0 2:-1.2\n') + + data_dir = os.path.join(os.getcwd(), 'data') + f = (data_path, label_path, (3,), (3,), 3) + data_train = mx.io.LibSVMIter(data_libsvm=f[0], + label_libsvm=f[1], + data_shape=f[2], + label_shape=f[3], + batch_size=f[4]) + + first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) + second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) + i = 0 + for batch in iter(data_train): + expected = first.asnumpy() if i == 0 else second.asnumpy() + assert_almost_equal(data_train.getdata().asnumpy(), expected) + i += 1 +''' + if __name__ == "__main__": test_NDArrayIter() test_MNISTIter() test_Cifar10Rec() + # test_libsvm() diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 9f3cff8e1265..470312352b0e 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -1,9 +1,12 @@ import mxnet as mx import mxnet.ndarray as nd +from mxnet.test_utils import * import numpy as np from functools import reduce from mxnet.module.executor_group import DataParallelExecutorGroup +import numpy.random as rnd +import scipy def test_module_dtype(): dtype = np.float16 @@ -262,7 +265,6 @@ def mean_abs(x): break assert(mon_result_counts == [2, 2, 1, 6, 6, 4]) - def test_executor_group(): def get_rnn_sym(num_layers, num_words, num_hidden, num_embed, seq_len): stack = mx.rnn.SequentialRNNCell() @@ -374,6 +376,73 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N test_shared_exec_group(exec_grp_shared=exec_group1, exec_grp_created=exec_group2, shared_arg_names=shared_arg_names, extra_args=extra_args) +def test_module_fm(): + mx.random.seed(11) + rnd.seed(11) + def fm_model(k, feature_dim, storage_type='default'): + initializer = mx.initializer.Normal(sigma=0.01) + x = mx.symbol.Variable("data", storage_type=storage_type) + v = mx.symbol.Variable("v", shape=(feature_dim, k), init=initializer) + + w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=initializer) + w1 = mx.symbol.dot(x, w1_weight) + + v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1) + x_s = mx.symbol.square(data=x) + bd = 0.5 * mx.symbol.negative(data=mx.symbol.broadcast_mul(x_s, v_s)) + + w2 = mx.symbol.dot(x, v) + w2_squared = 0.5 * mx.symbol.square(data=w2) + + w_all = mx.symbol.Concat(w1, w2_squared, bd, dim=1) + model = mx.symbol.sum(data=w_all, axis=1, keepdims=True) + y = mx.symbol.Variable("out_label") + model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") + return model + + ctx = default_context() + k = 5 + feature_dim = 20 + model = fm_model(k, feature_dim, 'csr') + + num_batches = 8 + batch_size = 25 + import scipy.sparse as sp + scipy_data = sp.rand(num_batches * batch_size, feature_dim, + density=0.5, format='csr') + dns_label = mx.nd.ones((num_batches * batch_size,1)) + csr_data = mx.sparse_nd.csr(scipy_data.data, scipy_data.indptr, scipy_data.indices, + (num_batches * batch_size, feature_dim)) + data = csr_data + + train_iter = mx.io.NDArrayIter(data=data, + label={'out_label':dns_label}, + batch_size=batch_size) + + # create module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['out_label']) + # allocate memory by given the input data and lable shapes + mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) + # initialize parameters by uniform random numbers + mod.init_params(initializer=mx.init.Uniform(scale=.1)) + # use Sparse SGD with learning rate 0.1 to train + mod.init_optimizer(optimizer='sgd') + # use accuracy as the metric + metric = mx.metric.create('MSE') + # train 5 epoch, i.e. going over the data iter one pass + storage_type_dict = {'v' : 'row_sparse'} + + for epoch in range(10): + train_iter.reset() + metric.reset() + for batch in train_iter: + mod.forward(batch, is_train=True) # compute predictions + mod.update_metric(metric, batch.label) # accumulate prediction accuracy + mod.backward() # compute gradients + mod.update(storage_type_dict) # update parameters + # print('Epoch %d, Training %s' % (epoch, metric.get())) + assert(metric.get()[1] < 0.2) + if __name__ == '__main__': test_module_dtype() @@ -385,3 +454,4 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N test_module_switch_bucket() test_monitor() test_executor_group() + test_module_fm() diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py index 8956c4edebac..3293ae2b0abc 100644 --- a/tests/python/unittest/test_multi_device_exec.py +++ b/tests/python/unittest/test_multi_device_exec.py @@ -1,4 +1,5 @@ import os +import numpy as np import mxnet as mx def test_ctx_group(): @@ -32,5 +33,35 @@ def test_ctx_group(): else: assert arr.context == group2ctx['stage2'] +def check_ctx_group_sparse(lhs_stype, rhs_stype): + with mx.AttrScope(ctx_group='stage1'): + lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + plus = mx.symbol.elemwise_add(lhs, rhs, name='plus') + + set_stage1 = set(plus.list_arguments()) + with mx.AttrScope(ctx_group='stage2'): + softmax = mx.symbol.SoftmaxOutput(data = plus, name = 'softmax') + + set_stage2 = set(softmax.list_arguments()) - set_stage1 + + group2ctx = { + 'stage1' : mx.cpu(1), + 'stage2' : mx.cpu(2) + } + texec = softmax.simple_bind(mx.cpu(0), group2ctx=group2ctx, lhs=(1,200), rhs=(1,200)) + + for arr, name in zip(texec.arg_arrays, softmax.list_arguments()): + if name in set_stage1: + assert arr.context == group2ctx['stage1'] + else: + assert arr.context == group2ctx['stage2'] + +def test_ctx_group_sparse(): + check_ctx_group_sparse('default', 'default') + check_ctx_group_sparse('default', 'row_sparse') + check_ctx_group_sparse('row_sparse', 'row_sparse') + if __name__ == '__main__': test_ctx_group() + test_ctx_group_sparse() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index dd38bdf98606..adf93a98f26f 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -330,6 +330,7 @@ def test_dot(): assert_almost_equal(c, C.asnumpy()) + def test_reduce(): sample_num = 200 def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 924ef351dbe5..e437b802a825 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2993,7 +2993,6 @@ def test_where_numeric_gradient(shape, same_shape): test_where_numeric_gradient((5, 7, 9), True) test_where_numeric_gradient((5, 7, 9), False) - def test_new_softmax(): for ndim in range(1, 5): for _ in range(5): diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 11ca7bed1743..6f69828ed9b1 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -30,12 +30,23 @@ def test_lr_wd_mult(): assert not mx.test_utils.almost_equal(args1['fc2_weight'], args2['fc2_weight'], 1e-1) -def compare_optimizer(opt1, opt2, shape): - w1 = mx.random.uniform(shape=shape, ctx=default_context()) - g1 = mx.random.uniform(shape=shape, ctx=default_context()) - - w2 = w1.copyto(default_context()) - g2 = g1.copyto(default_context()) +def compare_optimizer(opt1, opt2, shape, w_stype='default', g_stype='default'): + if w_stype == 'default': + w2 = mx.random.uniform(shape=shape, ctx=default_context()) + w1 = w2.copyto(default_context()) + elif w_stype == 'row_sparse': + w2 = rand_ndarray(shape, w_stype) + w1 = rand_ndarray(shape, w_stype).to_dense() + else: + raise Exception("type not supported yet") + if g_stype == 'default': + g2 = mx.random.uniform(shape=shape, ctx=default_context()) + g1 = g2.copyto(default_context()) + elif g_stype == 'row_sparse': + g2 = rand_ndarray(shape, g_stype) + g1 = g2.copyto(default_context()).to_dense() + else: + raise Exception("type not supported yet") state1 = opt1.create_state(0, w1) state2 = opt2.create_state(0, w2) @@ -130,6 +141,97 @@ def test_sgd(): for kwarg in kwargs: compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape) +class PySparseSGD(mx.optimizer.Optimizer): + """python reference implemenation of sgd""" + def __init__(self, learning_rate=0.01, momentum=0.0, **kwargs): + super(PySparseSGD, self).__init__(learning_rate=learning_rate, **kwargs) + self.momentum = momentum + + def create_state(self, index, weight): + """Create additional optimizer state: momentum + + Parameters + ---------- + weight : NDArray + The weight data + + """ + if self.momentum == 0.0: + return None + else: + return mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) + + def update(self, index, weight, grad, state): + """Update the parameters. + + Parameters + ---------- + index : int + An unique integer key used to index the parameters + + weight : NDArray + weight ndarray + + grad : NDArray + grad ndarray + + state : NDArray or other objects returned by init_state + The auxiliary state used in optimization. + """ + lr = self._get_lr(index) + wd = self._get_wd(index) + self._update_count(index) + num_rows = weight.shape[0] + if self.momentum == 0.0: + # Update on a per row basis, skip all-zero rows + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + weight[row] = ((1 - lr*wd)*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, + -self.clip_gradient, self.clip_gradient)) + else: + weight[row] = (1 - lr*wd)*weight[row] - lr*self.rescale_grad*grad[row] + else: + mom = state + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + mom[row] = (self.momentum*mom[row] - lr*wd*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, -self.clip_gradient, self.clip_gradient)) + weight[row] += mom[row] + else: + mom[row] = self.momentum*mom[row] - lr*wd*weight[row] - lr*self.rescale_grad*grad[row] + weight[row] += mom[row] + +def test_sparse_sgd(): + mx.random.seed(0) + opt1 = PySparseSGD + opt2 = mx.optimizer.SGD + shape = (3, 4) + kwargs = [{}, + {'momentum': 0.9}, + {'clip_gradient': 0.5}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14}, + {'rescale_grad': 0.8}, + {'clip_gradient': 0.5, 'wd': 0.07}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03}, + {'rescale_grad': 0.8, 'wd': 0.05}, + {'clip_gradient': 0.5, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'momentum': 0.9}, + {'clip_gradient': 0.5, 'wd': 0.07, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] + for kwarg in kwargs: + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='default', g_stype='row_sparse') + # ADAM class PyAdam(mx.optimizer.Optimizer): @@ -354,3 +456,4 @@ def test_rms(): test_adam() test_rms() test_sgd() + test_sparse_sgd() diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py new file mode 100644 index 000000000000..fc27b80f4530 --- /dev/null +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -0,0 +1,276 @@ +import os +import mxnet as mx +import numpy as np +import pickle as pkl +from mxnet.test_utils import * +from numpy.testing import assert_allclose +import numpy.random as rnd + +def assert_fcompex(f, *args, **kwargs): + prev_val = mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", "0", "1") + f(*args, **kwargs) + mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", prev_val) + +def sparse_nd_ones(shape, stype): + return mx.nd.cast_storage(mx.nd.ones(shape), storage_type=stype) + +def check_sparse_nd_elemwise_binary(shapes, storage_types, f, g): + # generate inputs + nds = [] + for i, storage_type in enumerate(storage_types): + if storage_type == 'row_sparse': + nd, _ = rand_sparse_ndarray(shapes[i], storage_type) + elif storage_type == 'default': + nd = mx.nd.array(random_arrays(shapes[i]), dtype = np.float32) + else: + assert(False) + nds.append(nd) + # check result + test = f(nds[0], nds[1]) + assert_almost_equal(test.asnumpy(), g(nds[0].asnumpy(), nds[1].asnumpy())) + +def test_sparse_nd_elemwise_add(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.elemwise_add + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default'] * 2, op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default', 'row_sparse'], op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['row_sparse', 'row_sparse'], op, g) + +# Test a operator which doesn't implement FComputeEx +def test_sparse_nd_elementwise_fallback(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.add_n + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + check_sparse_nd_elemwise_binary(shape, ['default'] * 2, op, g) + check_sparse_nd_elemwise_binary(shape, ['default', 'row_sparse'], op, g) + check_sparse_nd_elemwise_binary(shape, ['row_sparse', 'row_sparse'], op, g) + +def test_sparse_nd_zeros(): + def check_sparse_nd_zeros(stype, shape): + zero = mx.nd.zeros(shape) + sparse_zero = mx.sparse_nd.zeros('row_sparse', shape) + assert_almost_equal(sparse_zero.asnumpy(), zero.asnumpy()) + + shape = rand_shape_2d() + check_sparse_nd_zeros('row_sparse', shape) + check_sparse_nd_zeros('csr', shape) + + +def test_sparse_nd_copy(): + def check_sparse_nd_copy(from_stype, to_stype): + shape = rand_shape_2d() + from_nd = rand_ndarray(shape, from_stype) + # copy to ctx + to_ctx = from_nd.copyto(default_context()) + # copy to stype + to_nd = rand_ndarray(shape, to_stype) + to_nd = from_nd.copyto(to_nd) + assert np.sum(np.abs(from_nd.asnumpy() != to_ctx.asnumpy())) == 0.0 + assert np.sum(np.abs(from_nd.asnumpy() != to_nd.asnumpy())) == 0.0 + + check_sparse_nd_copy('row_sparse', 'row_sparse') + check_sparse_nd_copy('row_sparse', 'default') + check_sparse_nd_copy('default', 'row_sparse') + check_sparse_nd_copy('default', 'csr') + +def check_sparse_nd_prop_rsp(): + storage_type = 'row_sparse' + shape = rand_shape_2d() + nd, (v, idx) = rand_sparse_ndarray(shape, storage_type) + assert(nd._num_aux == 1) + assert(nd.indices.dtype == np.int32) + assert(nd.storage_type == 'row_sparse') + assert_almost_equal(nd.indices.asnumpy(), idx) + +def test_sparse_nd_basic(): + def check_rsp_creation(values, indices, shape): + rsp = mx.sparse_nd.row_sparse(values, indices, shape) + dns = mx.nd.zeros(shape) + dns[1] = mx.nd.array(values[0]) + dns[3] = mx.nd.array(values[1]) + assert_almost_equal(rsp.asnumpy(), dns.asnumpy()) + indices = mx.nd.array(indices).asnumpy() + assert_almost_equal(rsp.indices.asnumpy(), indices) + + def check_csr_creation(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + assert_almost_equal(csr.indptr.asnumpy(), indptr) + assert_almost_equal(csr.indices.asnumpy(), indices) + assert_almost_equal(csr.values.asnumpy(), values) + + shape = (4,2) + values = np.random.rand(2,2) + indices = np.array([1,3]) + check_rsp_creation(values, indices, shape) + + values = mx.nd.array(np.random.rand(2,2)) + indices = mx.nd.array([1,3], dtype='int32') + check_rsp_creation(values, indices, shape) + + values = [[0.1, 0.2], [0.3, 0.4]] + indices = [1,3] + check_rsp_creation(values, indices, shape) + + check_csr_creation(shape) + check_sparse_nd_prop_rsp() + +def test_sparse_nd_setitem(): + def check_sparse_nd_setitem(storage_type, shape, dst): + x = mx.sparse_nd.zeros(storage_type, shape) + x[:] = dst + dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst + assert same(x.asnumpy(), dst_nd.asnumpy()) + + shape = rand_shape_2d() + for stype in ['row_sparse', 'csr']: + # ndarray assignment + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, 'default')) + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, stype)) + # numpy assignment + check_sparse_nd_setitem(stype, shape, np.ones(shape)) + +def test_sparse_nd_slice(): + def check_sparse_nd_csr_slice(shape): + storage_type = 'csr' + A, _ = rand_sparse_ndarray(shape, storage_type) + A2 = A.asnumpy() + start = rnd.randint(0, shape[0] - 1) + end = rnd.randint(start + 1, shape[0]) + assert same(A[start:end].asnumpy(), A2[start:end]) + + shape = (rnd.randint(2, 10), rnd.randint(1, 10)) + check_sparse_nd_csr_slice(shape) + +def test_sparse_nd_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x == y + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 == x + assert (z.asnumpy() == np.ones(shape)).all() + +def test_sparse_nd_not_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x != y + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 != x + assert (z.asnumpy() == np.zeros(shape)).all() + +def test_sparse_nd_greater(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x > y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y > 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 > y + assert (z.asnumpy() == np.zeros(shape)).all() + +def test_sparse_nd_greater_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = x >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 1 + assert (z.asnumpy() == np.ones(shape)).all() + +def test_sparse_nd_lesser(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = y < x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 < y + assert (z.asnumpy() == np.ones(shape)).all() + z = y < 0 + assert (z.asnumpy() == np.zeros(shape)).all() + +def test_sparse_nd_lesser_equal(): + stype = 'csr' + shape = rand_shape_2d() + x = mx.sparse_nd.zeros(stype, shape) + y = sparse_nd_ones(shape, stype) + z = y <= x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 <= y + assert (z.asnumpy() == np.ones(shape)).all() + z = y <= 0 + assert (z.asnumpy() == np.zeros(shape)).all() + z = 1 <= y + assert (z.asnumpy() == np.ones(shape)).all() + +def test_sparse_nd_binary(): + N = 100 + def check_binary(fn): + for _ in range(N): + ndim = 2 + oshape = np.random.randint(1, 6, size=(ndim,)) + bdim = 2 + lshape = list(oshape) + rshape = list(oshape[ndim-bdim:]) + for i in range(bdim): + sep = np.random.uniform(0, 1) + if sep < 0.33: + lshape[ndim-i-1] = 1 + elif sep < 0.66: + rshape[bdim-i-1] = 1 + lhs = np.random.normal(0, 1, size=lshape) + rhs = np.random.normal(0, 1, size=rshape) + lhs_nd = mx.nd.array(lhs).to_csr() + rhs_nd = mx.nd.array(rhs).to_csr() + assert_allclose(fn(lhs, rhs), + fn(lhs_nd, rhs_nd).asnumpy(), + rtol=1e-4, atol=1e-4) + + #check_binary(lambda x, y: x + y) + check_binary(lambda x, y: x - y) + check_binary(lambda x, y: x * y) + check_binary(lambda x, y: x / y) + check_binary(lambda x, y: x > y) + check_binary(lambda x, y: x < y) + check_binary(lambda x, y: x >= y) + check_binary(lambda x, y: x <= y) + check_binary(lambda x, y: x == y) + +def test_sparse_nd_negate(): + npy = np.random.uniform(-10, 10, rand_shape_2d()) + arr = mx.nd.array(npy).to_csr() + assert_almost_equal(npy, arr.asnumpy()) + assert_almost_equal(-npy, (-arr).asnumpy()) + + # a final check to make sure the negation (-) is not implemented + # as inplace operation, so the contents of arr does not change after + # we compute (-arr) + assert_almost_equal(npy, arr.asnumpy()) + +def test_sparse_nd_output_fallback(): + shape = (10, 10) + out = mx.sparse_nd.zeros('row_sparse', shape) + mx.nd.random_normal(shape=shape, out=out) + assert(np.sum(out.asnumpy()) != 0) + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py new file mode 100644 index 000000000000..d625dfa7906b --- /dev/null +++ b/tests/python/unittest/test_sparse_operator.py @@ -0,0 +1,203 @@ +from mxnet.test_utils import * + + +def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None): + lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + if lhs_grad_stype is not None: + lhs._set_attr(grad_stype_hint=str(lhs_grad_stype)) + if rhs_grad_stype is not None: + rhs._set_attr(grad_stype_hint=str(rhs_grad_stype)) + + lhs_nd = rand_ndarray(shape, lhs_stype) + rhs_nd = rand_ndarray(shape, rhs_stype) + lhs_np = lhs_nd.asnumpy() + rhs_np = rhs_nd.asnumpy() + + out_np = lhs_np + rhs_np + test = mx.symbol.elemwise_add(lhs, rhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} + check_symbolic_forward(test, location, [out_np]) + check_numeric_gradient(test, location) + check_symbolic_backward(test, location, [out_np], [out_np, out_np]) + + +def test_elemwise_add_ex(): + shape = rand_shape_2d() + check_elemwise_add_ex('default', 'default', shape) + check_elemwise_add_ex('default', 'row_sparse', shape) + check_elemwise_add_ex('row_sparse', 'default', shape) + check_elemwise_add_ex('row_sparse', 'row_sparse', shape, + lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse') + + +# TODO(haibin) randomize this test +def test_elemwise_add_ex_multiple_stages(): + # prep data + shape = (4, 2) + ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]]) + sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]]) + + val1 = mx.nd.array([[5, 10]]); + val2 = mx.nd.array([[5, 10]]); + idx1 = mx.nd.array([0], dtype=np.int32); + idx2 = mx.nd.array([1], dtype=np.int32); + sp_nd1 = mx.sparse_nd.row_sparse(val1, idx1, shape) + sp_nd2 = mx.sparse_nd.row_sparse(val2, idx2, shape) + ds_nd = mx.nd.array(ds_np) + + # sparse + sparse = sparse + sp_data1 = mx.symbol.Variable('sp_data1', storage_type='row_sparse') + sp_data2 = mx.symbol.Variable('sp_data2', storage_type='row_sparse') + ds_data = mx.symbol.Variable('ds_data') + plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus') + # sparse + dense = dense + test = mx.symbol.elemwise_add(plus, ds_data) + check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np]) + + arr_grads = [mx.nd.zeros(shape) for i in range(3)] + exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, args_grad=arr_grads) + exec_test.forward(is_train=True) + assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np) + exec_test.backward(out_grads=exec_test.outputs) + assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) + + +# TODO(haibin) also add test for backward pass. Check if exception is thrown +def test_cast_storage_ex(): + def test_rsp_to_dns(shape): + rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse') + dns_out = mx.nd.cast_storage(rsp, storage_type='default') + dns_expected = np.zeros(shape, dtype=default_dtype()) + if row_idx is not None: + for k, v in enumerate(row_idx): + dns_expected[v, :] = data[k] + assert same(dns_out.asnumpy(), dns_expected) + + def test_dns_to_rsp(shape): + dns_in = rand_ndarray(shape, 'default') + rsp_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='row_sparse') + ret = mx.nd.cast_storage(rsp_out, storage_type='default') + assert same(ret.asnumpy(), dns_in.asnumpy()) + + def test_csr_to_dns(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + mx_dns = csr.to_dense() + np_dns = sp.csr_matrix((values, indices, indptr), shape).todense() + assert_almost_equal(mx_dns.asnumpy(), np_dns) + + def test_dns_to_csr(dns_in): + dns_in = np.array(dns_in) + csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='csr') + ret = mx.nd.cast_storage(csr_out, storage_type='default') + assert same(ret.asnumpy(), dns_in) + + shape = rand_shape_2d() + test_rsp_to_dns(shape) + test_dns_to_rsp(shape) + test_csr_to_dns((4, 4)) + test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]]) + +def test_sparse_dot(): + def test_dot_csr_dns(csr_shape, dns_shape, trans_csr): + dns1 = rand_ndarray(csr_shape, 'default') + dns2 = rand_ndarray(dns_shape, 'default') + csr = mx.nd.cast_storage(dns1, storage_type='csr') + out = mx.nd.dot(csr, dns2, transpose_a=trans_csr) + assert out.storage_type == 'default' + out_expected = mx.nd.dot(dns1, dns2, transpose_a=trans_csr) + out_np = out_expected.asnumpy() + backward_trans = not trans_csr + rhs_backward_grad = mx.nd.dot(dns1, out_expected, transpose_a=backward_trans).asnumpy() + assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) + + # test symbolic forward + lhs = mx.symbol.Variable('lhs', storage_type='csr') + rhs = mx.symbol.Variable('rhs', storage_type='default') + test = mx.symbol.dot(lhs, rhs, transpose_a=trans_csr) + location = {'lhs': csr, 'rhs': dns2} + expected = {'rhs': rhs_backward_grad} + # dot(lhs, rhs) + check_symbolic_forward(test, location, [out_expected.asnumpy()], rtol=1e-3, atol=1e-4) + check_symbolic_backward(test, location, [out_np], expected, + grad_req={'lhs': 'null', 'rhs': 'write'}, + rtol=1e-3, atol=1e-4) + + lhs_shape = rand_shape_2d() + test_dot_csr_dns(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), False) + test_dot_csr_dns(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), True) + + +def test_sparse_embedding(): + in_dim = 10 + out_dim = 4 + batch = 24 + + data = mx.sym.Variable("data", dtype=np.int32) + embed = mx.sym.SparseEmbedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") + exe_test = embed.simple_bind(default_context(), grad_req={'data': 'null', 'embed_weight': 'write'}, + data=(batch,)) + arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) + grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) + np_data = np.random.randint(low=0, high=in_dim, size=batch) + np_weight = np.random.uniform(-0.01, 0.01, arg_map["embed_weight"].shape) + np_onehot = np.zeros((batch, in_dim)) + np_onehot[np.arange(batch), np_data] = 1.0 + # forward + arg_map["data"][:] = np_data + arg_map["embed_weight"][:] = np_weight + exe_test.forward(is_train=True) + assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, np_weight)) + # backward + np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape) + grad = mx.nd.zeros(np_grad.shape) + grad[:] = np_grad + exe_test.backward([grad]) + assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), atol=1e-5) + + +def test_sparse_slice(): + def check_csr_slice(shape, slice_input): + storage_type = 'csr' + A, _ = rand_sparse_ndarray(shape, storage_type) + B = A._slice(1, shape[0] - 1) if slice_input else A + np = B.asnumpy() + begin = rnd.randint(0, B.shape[0] - 1) + end = rnd.randint(begin + 1, B.shape[0]) + nd_slice = mx.nd.crop(B, begin=begin, end=end) + assert same(nd_slice.asnumpy(), np[begin:end]), (nd_slice.asnumpy(), np[begin:end]) + + shape = (rnd.randint(7, 15), rnd.randint(1, 10)) + check_csr_slice(shape, True) + check_csr_slice(shape, False) + + +def test_sparse_retain(): + for _ in range(10): + shape = rand_shape_2d() + num_rows = shape[0] + rsp, _ = rand_sparse_ndarray(shape=shape, storage_type='row_sparse', density=0.5) + length = np.random.randint(1, num_rows + 1) + idx = random_sample(list(range(0, num_rows)), length) + idx.sort() + dns = rsp.asnumpy() + tensor_retained_expected = np.zeros(shape) + for i in idx: + tensor_retained_expected[i][:] = dns[i] + indices = mx.nd.array(idx) + rsp_retained = mx.nd.sparse_retain(rsp, indices=indices) + assert same(tensor_retained_expected, rsp_retained.asnumpy()) + + # check numeric gradient + data = mx.symbol.Variable('data') + idx = mx.symbol.Variable('indices') + sym = mx.sym.sparse_retain(data=data, indices=idx) + check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'}) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index cff4196b6043..d0ee09312cd4 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -109,11 +109,11 @@ if [ ${TASK} == "python_test" ]; then python -m nose tests/python/doctest || exit -1 python3 -m nose tests/python/doctest || exit -1 else - nosetests tests/python/unittest || exit -1 - nosetests3 tests/python/unittest || exit -1 - nosetests3 tests/python/train || exit -1 - nosetests tests/python/doctest || exit -1 - nosetests3 tests/python/doctest || exit -1 + nosetests -v tests/python/unittest || exit -1 + nosetests3 -v tests/python/unittest || exit -1 + nosetests3 -v tests/python/train || exit -1 + nosetests -v tests/python/doctest || exit -1 + nosetests3 -v tests/python/doctest || exit -1 fi exit 0 fi diff --git a/tests/travis/setup.sh b/tests/travis/setup.sh index ec071009bda5..7c9d137b8269 100755 --- a/tests/travis/setup.sh +++ b/tests/travis/setup.sh @@ -15,8 +15,8 @@ if [ ${TRAVIS_OS_NAME} == "osx" ]; then brew install ImageMagick brew install swig if [ ${TASK} == "python_test" ]; then - python -m pip install --user nose numpy cython - python3 -m pip install --user nose numpy cython + python -m pip install --user nose numpy cython scipy + python3 -m pip install --user nose numpy cython scipy fi fi