[MRG] Modifies T-SNE for sparse matrix #10206

Closed
60 commits
b14d689
initial commit
thechargedneutron Nov 26, 2017
8f10631
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Nov 26, 2017
a224e68
changes added
thechargedneutron Nov 28, 2017
390cb86
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Nov 28, 2017
92d2912
changes added
thechargedneutron Nov 29, 2017
b8fc385
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Nov 29, 2017
fb9a76e
changes added
thechargedneutron Nov 30, 2017
9e109a9
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Nov 30, 2017
500f0c4
check for sparse added
thechargedneutron Nov 30, 2017
2d985bd
changes added
thechargedneutron Nov 30, 2017
fdfbf11
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Nov 30, 2017
4c2cf9a
check for non-negativity done
thechargedneutron Nov 30, 2017
164a232
tests added
thechargedneutron Nov 30, 2017
960d2dd
added backport of getnnz
thechargedneutron Dec 1, 2017
1e53a47
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 1, 2017
c51e60f
added sparse in error message
thechargedneutron Dec 1, 2017
5d8e59d
ensure_min_sample restored
thechargedneutron Dec 4, 2017
c394077
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 4, 2017
a11bd8e
backport for downcast_intp_index added
thechargedneutron Dec 4, 2017
c4005fb
pep8 fialures corrected
thechargedneutron Dec 4, 2017
9a47517
spelling corrected
thechargedneutron Dec 4, 2017
5fc3f04
newline added
thechargedneutron Dec 4, 2017
ef809a6
_swap removed
thechargedneutron Dec 4, 2017
016cb39
corrections added
thechargedneutron Dec 4, 2017
abcbb91
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 4, 2017
6c8d913
swap removed
thechargedneutron Dec 4, 2017
833726d
changes added
thechargedneutron Dec 4, 2017
fc9c138
checking of sparse matrix changed
thechargedneutron Dec 4, 2017
aa0da86
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 4, 2017
2a7edfc
indentation corrected
thechargedneutron Dec 4, 2017
4084ad0
spelling corrected
thechargedneutron Dec 5, 2017
49f275c
distance check moved inside precomputed distance check
thechargedneutron Dec 7, 2017
509e573
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 7, 2017
dcd7b79
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 11, 2017
8c933fd
tests removed
thechargedneutron Dec 14, 2017
1511104
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 14, 2017
768c3d3
started adding tests
thechargedneutron Dec 14, 2017
c7d1270
added non negative test in csr matrix
thechargedneutron Dec 16, 2017
38d639d
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 16, 2017
bb97e05
test for high perplexity added
thechargedneutron Dec 17, 2017
4107d64
redundant code removed
thechargedneutron Dec 17, 2017
31be9b2
restored with better error message
thechargedneutron Dec 17, 2017
e0dd45c
import statement added
thechargedneutron Dec 17, 2017
a582b16
tests added and error message corrected
thechargedneutron Dec 17, 2017
acf180a
minor mistakes corrected
thechargedneutron Dec 17, 2017
d4851fe
changes added
thechargedneutron Dec 17, 2017
ec7ac78
pep8 corrected
thechargedneutron Dec 17, 2017
b03c37c
changes added
thechargedneutron Dec 18, 2017
dd404e4
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 19, 2017
27f41fc
dtype corrected
thechargedneutron Dec 19, 2017
c472657
query_is_train logic added
thechargedneutron Dec 20, 2017
7c7a240
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Dec 20, 2017
acb9c2e
tests added
thechargedneutron Jan 14, 2018
eba780a
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Jan 14, 2018
d105231
pep8 corrected
thechargedneutron Jan 14, 2018
ab82925
Tests for sparse support in tSNE and Neighbors (#4)
jnothman Jan 16, 2018
e6bc8aa
previous review comments
thechargedneutron Jan 16, 2018
cb3127d
number of neighbors mismatch changed
thechargedneutron Jan 17, 2018
788d890
initial changes
thechargedneutron Jan 20, 2018
76b0e69
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
thechargedneutron Jan 20, 2018
29 changes: 15 additions & 14 deletions sklearn/manifold/t_sne.py
@@ -25,6 +25,7 @@
from . import _barnes_hut_tsne
from ..externals.six import string_types
from ..utils import deprecated
from ..utils.fixes import getnnz


MACHINE_EPSILON = np.finfo(np.double).eps
@@ -639,29 +640,22 @@ def _fit(self, X, skip_num_points=0):
raise ValueError("'method' must be 'barnes_hut' or 'exact'")
if self.angle < 0.0 or self.angle > 1.0:
raise ValueError("'angle' must be between 0.0 - 1.0")
if self.method == 'barnes_hut':
X = check_array(X, accept_sparse=['csr'], ensure_min_samples=2,
dtype=[np.float32, np.float64])
else:
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
dtype=[np.float32, np.float64])
if self.metric == "precomputed":
if isinstance(self.init, string_types) and self.init == 'pca':
raise ValueError("The parameter init=\"pca\" cannot be "
"used with metric=\"precomputed\".")
if X.shape[0] != X.shape[1]:
raise ValueError("X should be a square distance matrix")
if np.any(X < 0):
if X.min() < 0:
raise ValueError("All distances should be positive, the "
"precomputed distances given as X is not "
"correct")
if self.method == 'barnes_hut' and sp.issparse(X):
raise TypeError('A sparse matrix was passed, but dense '
'data is required for method="barnes_hut". Use '
'X.toarray() to convert to a dense numpy array if '
'the array is small enough for it to fit in '
'memory. Otherwise consider dimensionality '
'reduction techniques (e.g. TruncatedSVD)')
if self.method == 'barnes_hut':
X = check_array(X, ensure_min_samples=2,
dtype=[np.float32, np.float64])
else:
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
dtype=[np.float32, np.float64])
if self.method == 'barnes_hut' and self.n_components > 3:
raise ValueError("'n_components' should be inferior to 4 for the "
"barnes_hut algorithm as it relies on "
@@ -714,6 +708,13 @@ def _fit(self, X, skip_num_points=0):
if self.verbose:
print("[t-SNE] Computing {} nearest neighbors...".format(k))

if self.metric == "precomputed" and sp.issparse(X):
if np.any(getnnz(X, axis=1) < k):
raise ValueError("{} neighbors per sample are required for"
" current perplexity. Decrease "
"perplexity, or provide more precomputed"
" distances per sample".format(k))

# Find the nearest neighbors for every point
knn = NearestNeighbors(algorithm='auto', n_neighbors=k,
metric=self.metric)
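
The row-count check added in this hunk can be tried outside scikit-learn. The following is a minimal sketch, not the PR's code: `has_enough_neighbors` is a hypothetical helper, the formula for `k` is an assumption (the hunk only shows that `k` exists), and SciPy's own `getnnz` method stands in for the backported `getnnz` from `utils.fixes`.

```python
import numpy as np
import scipy.sparse as sp


def has_enough_neighbors(D, perplexity):
    """Hypothetical helper: does every row of D store at least k distances?"""
    n_samples = D.shape[0]
    # Assumed definition of k; it is not visible in the diff above.
    k = min(n_samples - 1, int(3.0 * perplexity + 1))
    # Number of stored entries per row (explicit zeros count as stored).
    stored_per_row = D.tocsr().getnnz(axis=1)
    return bool(np.all(stored_per_row >= k)), k


# Same matrix as in test_high_perplexity_precomputed_sparse_distances.
dist = sp.csr_matrix(np.array([[1., 0., 0.],
                               [0., 1., 0.],
                               [0., 0., 0.]]))
ok, k = has_enough_neighbors(dist, perplexity=30.0)
print(ok, k)
```

Under that assumed formula, three samples give `k = 2`, and no row of this matrix stores two distances, which is consistent with the "2 neighbors per sample are required" message the test expects.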
74 changes: 50 additions & 24 deletions sklearn/manifold/tests/test_t_sne.py
@@ -2,9 +2,11 @@
from sklearn.externals.six.moves import cStringIO as StringIO
import numpy as np
import scipy.sparse as sp
import pytest

from sklearn.neighbors import BallTree
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import kneighbors_graph
from sklearn.utils.testing import assert_less_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_almost_equal
@@ -265,14 +267,15 @@ def test_optimization_minimizes_kl_divergence():
assert_less_equal(kl_divergences[2], kl_divergences[1])


def test_fit_csr_matrix():
@pytest.mark.parametrize('method', ['exact', 'barnes_hut'])
def test_fit_csr_matrix(method):
# X can be a sparse matrix.
random_state = check_random_state(0)
X = random_state.randn(100, 2)
X[(np.random.randint(0, 100, 50), np.random.randint(0, 2, 50))] = 0.0
X_csr = sp.csr_matrix(X)
tsne = TSNE(n_components=2, perplexity=10, learning_rate=100.0,
random_state=0, method='exact')
random_state=0, method=method)
X_embedded = tsne.fit_transform(X_csr)
assert_almost_equal(trustworthiness(X_csr, X_embedded, n_neighbors=1), 1.0,
decimal=1)
@@ -307,20 +310,54 @@ def test_too_few_iterations():
np.array([[0.0], [0.0]]))


def test_non_square_precomputed_distances():
# Precomputed distance matrices must be square matrices.
@pytest.mark.parametrize('method, retype', [
('exact', np.asarray),
('barnes_hut', np.asarray),
('barnes_hut', sp.csr_matrix),
])
@pytest.mark.parametrize('D, message_regex', [
([[0.0], [1.0]], ".* square distance matrix"),
([[0., -1.], [1., 0.]], ".* positive.*"),
])
def test_bad_precomputed_distances(method, D, retype, message_regex):
tsne = TSNE(metric="precomputed", method=method)
assert_raises_regexp(ValueError, message_regex,
tsne.fit_transform, retype(D))


def test_exact_no_precomputed_sparse():
tsne = TSNE(metric='precomputed', method='exact')
assert_raises_regexp(TypeError, 'sparse',
tsne.fit_transform,
sp.csr_matrix([[0, 5], [5, 0]]))


def test_high_perplexity_precomputed_sparse_distances():
# Perplexity should be less than 50
dist = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 0.]])
bad_dist = sp.csr_matrix(dist)
tsne = TSNE(metric="precomputed")
assert_raises_regexp(ValueError, ".* square distance matrix",
tsne.fit_transform, np.array([[0.0], [1.0]]))
assert_raises_regexp(ValueError, "2 neighbors per sample are "
"required .*perplexity.*precomputed distance.*",
tsne.fit_transform, bad_dist)


def test_non_positive_precomputed_distances():
# Precomputed distance matrices must be positive.
bad_dist = np.array([[0., -1.], [1., 0.]])
for method in ['barnes_hut', 'exact']:
tsne = TSNE(metric="precomputed", method=method)
assert_raises_regexp(ValueError, "All distances .*precomputed.*",
tsne.fit_transform, bad_dist)
def test_sparse_precomputed_distance():
"""Make sure that TSNE works identically for sparse and dense matrix"""
random_state = check_random_state(0)
X = random_state.randn(100, 2)

D_sparse = kneighbors_graph(X, n_neighbors=99, mode='distance')
D = pairwise_distances(X)
assert sp.issparse(D_sparse)
assert_almost_equal(D_sparse.A, D)

tsne = TSNE(metric="precomputed", random_state=0)
Xt_dense = tsne.fit_transform(D)

for fmt in ['csr', 'lil']:
Xt_sparse = tsne.fit_transform(D_sparse.asformat(fmt))
assert_almost_equal(Xt_dense, Xt_sparse)


def test_non_positive_computed_distances():
@@ -555,17 +592,6 @@ def test_reduction_to_one_component():
assert(np.all(np.isfinite(X_embedded)))


def test_no_sparse_on_barnes_hut():
# No sparse matrices allowed on Barnes-Hut.
random_state = check_random_state(0)
X = random_state.randn(100, 2)
X[(np.random.randint(0, 100, 50), np.random.randint(0, 2, 50))] = 0.0
X_csr = sp.csr_matrix(X)
tsne = TSNE(n_iter=199, method='barnes_hut')
assert_raises_regexp(TypeError, "A sparse matrix was.*",
tsne.fit_transform, X_csr)


def test_64bit():
# Ensure 64bit arrays are handled correctly.
random_state = check_random_state(0)
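
The premise of `test_sparse_precomputed_distance` in this file is that a k-neighbors graph with `n_samples - 1` neighbors stores every off-diagonal distance, so densifying it reproduces the full pairwise distance matrix. A small sketch of that premise, assuming six random points instead of the test's 100 and using `.toarray()` in place of the `.A` shorthand:

```python
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import kneighbors_graph

rng = np.random.RandomState(0)
X = rng.randn(6, 2)

# Dense reference: all pairwise Euclidean distances.
D_dense = pairwise_distances(X)

# Sparse graph keeping every other sample as a neighbor.
D_sparse = kneighbors_graph(X, n_neighbors=X.shape[0] - 1, mode='distance')

same = np.allclose(D_sparse.toarray(), D_dense)
stored_per_row = D_sparse.getnnz(axis=1)
print(same, stored_per_row)
```

The diagonal is left implicit in the sparse graph, which matches the dense matrix only because a sample's distance to itself is zero.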
47 changes: 39 additions & 8 deletions sklearn/neighbors/base.py
@@ -18,6 +18,7 @@
from ..metrics import pairwise_distances
from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS
from ..utils import check_X_y, check_array, _get_n_jobs, gen_even_slices
from ..utils.fixes import getnnz
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..externals import six
@@ -359,17 +360,48 @@ class from an array representing our data set and ask who's
X, self._fit_X, self.effective_metric_, n_jobs=n_jobs,
**self.effective_metric_params_)

neigh_ind = np.argpartition(dist, n_neighbors - 1, axis=1)
neigh_ind = neigh_ind[:, :n_neighbors]
# argpartition doesn't guarantee sorted order, so we sort again
neigh_ind = neigh_ind[
sample_range, np.argsort(dist[sample_range, neigh_ind])]
if issparse(dist):
print "Dist being printed \n"
print dist.toarray()
print dist.indices
Contributor Author

https://gist.github.com/thechargedneutron/c2f95ac46a9ab61beb0de8a91a9a533f
Running the gist above (with the tests provided by @jnothman) on this branch produces the following dist matrix:

[[0.    0.    0.    0.    0.    0.    0.56968608    0.    0.44086692    0. ]
 [0.    0.    0.71768226    0.    0.    0.71600003    0.    0.    0.    0. ]
 [0.    0.71768226    0.    0.    0.    0.    0.51642623    0.    0.    0. ]
 [0.    0.    0.    0.    0.    0.31561065    0.    0.51627111    0.    0. ]
 [0.    0.    0.    0.    0.    0.    0.44107968    0.    0.    0.51790786]
 [0.    0.    0.    0.31561065    0.    0.    0.    0.38322922    0.    0.]
 [0.    0.    0.51642623    0.    0.44107968    0.    0.    0.    0.    0.]
 [0.    0.    0.    0.    0.    0.38322922    0.    0.    0.    0.40501049]
 [0.44086692    0.    0.    0.    0.88847197    0.    0.    0.    0.    0.]
 [0.    0.    0.    0.    0.    0.50529526    0.    0.40501049    0.    0.]]

The corresponding dist.indices are:

[0 8 6 1 5 2 2 6 1 3 5 7 4 6 9 5 3 7 6 4 2 7 5 9 8 0 4 9 7 5]

And the expected neighbor indices are:

[[8 6 4]
 [5 2 9]
 [6 1 4]
 [5 7 4]
 [6 9 5]
 [3 7 9]
 [4 2 0]
 [5 9 3]
 [0 4 6]
 [7 5 4]]

But in the first row, for example, how are we supposed to get 4 as an index when the sparse matrix has no stored entry at that position (as the indices output shows)? Am I missing something?
I know this issue is taking too long. Sorry about that.

Member

@TomDLT TomDLT Jan 23, 2018

This dist matrix gives you 2 neighbors, or 3 if you count self, which is stored as explicit zeros.
The expected indices correspond to 3 neighbors, without self.

The fix is to replace nn.kneighbors_graph(X_train.copy(), mode='distance') in the gist with nn.kneighbors_graph(None, mode='distance'), which is the correct syntax to get 3 neighbors without self.

In base.py you also have to uncomment:

                if query_is_train:
                    # this is done to add self as nearest neighbor
                    neigh_ind = np.concatenate((sample_range, neigh_ind),
                                               axis=1)
                    neigh_ind = neigh_ind[:, :-1]
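
The difference between the two calls discussed in this comment can be checked directly. This is a rough sketch on random data rather than the gist itself; `row_has_self` is a hypothetical helper introduced only for the illustration.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(42)
X_train = rng.random_sample((10, 4))
nn = NearestNeighbors(n_neighbors=3).fit(X_train)

# Query is the training set itself: self is excluded from the neighbors.
graph_no_self = nn.kneighbors_graph(None, mode='distance')
# Query is passed explicitly: each sample is its own nearest neighbor.
graph_with_self = nn.kneighbors_graph(X_train.copy(), mode='distance')


def row_has_self(graph):
    """Hypothetical helper: is column i among the stored entries of row i?"""
    graph = graph.tocsr()
    return [bool(i in graph.indices[graph.indptr[i]:graph.indptr[i + 1]])
            for i in range(graph.shape[0])]


self_in_no_self = row_has_self(graph_no_self)
self_in_with_self = row_has_self(graph_with_self)
print(self_in_no_self)
print(self_in_with_self)
```

Both graphs store three entries per row, but in the second one a slot is taken by the sample itself, so only two other neighbors remain per row.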

Member

@TomDLT TomDLT Jan 23, 2018

Oh, my mistake: you are testing the explicit diagonal case...

The expected indices are wrong.

Contributor Author

@TomDLT Is the fix you suggested above for the gist still valid (now that you've said there was a mistake in that comment)? I am not familiar with generating a CSR matrix from kneighbors_graph.

if np.any(getnnz(dist, axis=1) < n_neighbors - query_is_train):
raise ValueError("Not enough neighbors in sparse "
"precomputed matrix to get {} "
"nearest neighbors"
.format(n_neighbors - query_is_train))
neigh_ind = np.full((dist.shape[0], dist.shape[1]), np.inf, dtype=np.int)
for i in range(0, dist.shape[0]):
row = np.full(dist.shape[1], np.inf)
data_col = dist.indices[dist.indptr[i]:dist.indptr[i + 1]]
data_values = dist.data[dist.indptr[i]:dist.indptr[i + 1]]
row[data_col] = data_values
neigh_ind[i] = np.argsort(row)
neigh_ind = neigh_ind[:, :n_neighbors]
'''
if query_is_train and np.sum([(num in neigh_ind[num]) for num in range(neigh_ind.shape[0])]) == 0:
# this is done to add self as nearest neighbor
neigh_ind = np.concatenate((sample_range, neigh_ind),
axis=1)
neigh_ind = neigh_ind[:, :-1]
'''
else:
neigh_ind = np.argpartition(dist, n_neighbors - 1, axis=1)
neigh_ind = neigh_ind[:, :n_neighbors]
# argpartition doesn't guarantee sorted order, so we sort again
neigh_ind = neigh_ind[
sample_range, np.argsort(dist[sample_range, neigh_ind])]

if return_distance:
if self.effective_metric_ == 'euclidean':
result = np.sqrt(dist[sample_range, neigh_ind]), neigh_ind
if issparse(dist):
result = np.sqrt(dist[sample_range, neigh_ind]).toarray(), neigh_ind
else:
result = np.sqrt(dist[sample_range, neigh_ind]), neigh_ind
else:
result = dist[sample_range, neigh_ind], neigh_ind
if issparse(dist):
result = dist[sample_range, neigh_ind].toarray(), neigh_ind
else:
result = dist[sample_range, neigh_ind], neigh_ind
else:
result = neigh_ind
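
The sparse branch above treats entries that are not stored as "no distance available" rather than as zero. A standalone sketch of that idea follows. It is not the PR's implementation: `sparse_row_neighbors` is a hypothetical name, and the index array is allocated with `np.empty` and an integer dtype rather than filled with `np.inf`.

```python
import numpy as np
import scipy.sparse as sp


def sparse_row_neighbors(dist, n_neighbors):
    """Hypothetical helper: nearest stored columns for each row of a CSR matrix."""
    dist = dist.tocsr()
    n_rows, n_cols = dist.shape
    neigh_ind = np.empty((n_rows, n_neighbors), dtype=np.intp)
    for i in range(n_rows):
        start, stop = dist.indptr[i], dist.indptr[i + 1]
        # Entries that are not stored sort last because they stay at +inf.
        row = np.full(n_cols, np.inf)
        row[dist.indices[start:stop]] = dist.data[start:stop]
        neigh_ind[i] = np.argsort(row, kind='stable')[:n_neighbors]
    return neigh_ind


# Row 0 stores an explicit zero on the diagonal; rows 1 and 2 do not.
data = np.array([0.0, 0.5, 0.2, 0.5, 0.3, 0.2, 0.3])
indices = np.array([0, 1, 2, 0, 2, 0, 1])
indptr = np.array([0, 3, 5, 7])
dist = sp.csr_matrix((data, indices, indptr), shape=(3, 3))

neighbors = sparse_row_neighbors(dist, n_neighbors=2)
print(neighbors)
```

In this toy matrix the explicit zero makes sample 0 its own nearest neighbor, while rows 1 and 2 never return their own index because nothing is stored on their diagonal. A column with no stored entry can only be returned if a row stores fewer than `n_neighbors` entries, which is the case the `ValueError` in the hunk guards against.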

@@ -410,7 +442,6 @@ class from an array representing our data set and ask who's
# In that case mask the first duplicate.
dup_gr_nbrs = np.all(sample_mask, axis=1)
sample_mask[:, 0][dup_gr_nbrs] = False

neigh_ind = np.reshape(
neigh_ind[sample_mask], (n_samples, n_neighbors - 1))

63 changes: 59 additions & 4 deletions sklearn/neighbors/tests/test_neighbors.py
@@ -1,5 +1,6 @@
from itertools import product

import pytest
import numpy as np
from scipy.sparse import (bsr_matrix, coo_matrix, csc_matrix, csr_matrix,
dok_matrix, lil_matrix, issparse)
@@ -18,6 +19,7 @@
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_in
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raises_regex
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import assert_warns_message
@@ -108,14 +110,13 @@ def test_unsupervised_inputs():
assert_array_almost_equal(ind1, ind2)


def test_precomputed(random_state=42):
def check_precomputed(make_train_test):
"""Tests unsupervised NearestNeighbors with a distance matrix."""
# Note: smaller samples may result in spurious test success
rng = np.random.RandomState(random_state)
rng = np.random.RandomState(42)
X = rng.random_sample((10, 4))
Y = rng.random_sample((3, 4))
DXX = metrics.pairwise_distances(X, metric='euclidean')
DYX = metrics.pairwise_distances(Y, X, metric='euclidean')
DXX, DYX = make_train_test(X, Y)
for method in ['kneighbors']:
# TODO: also test radius_neighbors, but requires different assertion

@@ -163,6 +164,60 @@ def test_precomputed(random_state=42):
assert_array_almost_equal(pred_X, pred_D)


def test_precomputed_dense():
def make_train_test(X_train, X_test):
return (metrics.pairwise_distances(X_train),
metrics.pairwise_distances(X_test, X_train))

check_precomputed(make_train_test)


@pytest.mark.parametrize('fmt', ['csr', 'lil'])
def test_precomputed_sparse_implicit_diagonal(fmt):
def make_train_test(X_train, X_test):
nn = neighbors.NearestNeighbors(n_neighbors=3).fit(X_train)
return (nn.kneighbors_graph(mode='distance').asformat(fmt),
nn.kneighbors_graph(X_test, mode='distance').asformat(fmt))

check_precomputed(make_train_test)


def test_precomputed_sparse_explicit_diagonal():
def make_train_test(X_train, X_test):
nn = neighbors.NearestNeighbors(n_neighbors=3, radius=10).fit(X_train)
return (nn.kneighbors_graph(X_train.copy(), mode='distance'),
nn.kneighbors_graph(X_test, mode='distance'))

check_precomputed(make_train_test)


def test_precomputed_sparse_invalid():
# TODO:
# Ensures enough number of nearest neighbors
dist = np.array([[0., 2., 1.], [2., 0., 3.], [1., 3., 0.]])
dist_csr = csr_matrix(dist)
neigh = neighbors.NearestNeighbors(n_neighbors=1, metric="precomputed")
neigh.fit(dist_csr)
neigh.kneighbors(None, n_neighbors=1)
neigh.kneighbors(np.array([[0., 0., 0.]]), n_neighbors=1)

dist = np.array([[0., 2., 0.], [2., 0., 3.], [0., 3., 0.]])
dist_csr = csr_matrix(dist)
neigh.fit(dist_csr)
assert_raises_regex(ValueError, "Not enough neighbors in"
" .* to get 1 nearest neighbors.*", neigh.kneighbors,
None, n_neighbors=1)

# Checks error with inconsistent distance matrix
dist = np.array([[5., 2., 1.], [2., 0., 3.], [1., 3., 0.]])
dist_csr = csr_matrix(dist)
neigh = neighbors.NearestNeighbors(n_neighbors=1, metric="precomputed")
neigh.fit(dist_csr)
assert_raises_regex(ValueError, "Not a valid distance"
" .*non-negative values.*", neigh.kneighbors,
None, n_neighbors=1)


def test_precomputed_cross_validation():
# Ensure array is split correctly
rng = np.random.RandomState(0)
37 changes: 37 additions & 0 deletions sklearn/utils/fixes.py
@@ -295,3 +295,40 @@ def __getstate__(self):
self._fill_value)
else:
from numpy.ma import MaskedArray # noqa


# Remove when minimum required SciPy >= 0.17.0
def downcast_intp_index(arr):
"""
Down-cast index array to np.intp dtype if it is of a larger dtype.

Raise an error if the array contains a value that is too large for
intp.
"""
if arr.dtype.itemsize > np.dtype(np.intp).itemsize:
if arr.size == 0:
return arr.astype(np.intp)
maxval = arr.max()
minval = arr.min()
if maxval > np.iinfo(np.intp).max or minval < np.iinfo(np.intp).min:
raise ValueError("Cannot deal with arrays with indices larger "
"than the machine maximum address size "
"(e.g. 64-bit indices on 32-bit machine).")
return arr.astype(np.intp)
return arr


def getnnz(X, axis=None):
if axis is None:
return int(X.indptr[-1])
else:
if axis < 0:
axis += 2
axis, _ = axis, 1 - axis
_, N = X.shape
if axis == 0:
return np.bincount(downcast_intp_index(X.indices),
minlength=N)
elif axis == 1:
return np.diff(X.indptr)
raise ValueError('axis out of bounds')
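
The three branches of the `getnnz` backport read the counts straight from the CSR arrays. The sketch below recomputes them by hand on a small matrix and compares against SciPy's built-in `getnnz` method, which is assumed to be available in the installed SciPy.

```python
import numpy as np
import scipy.sparse as sp

X = sp.csr_matrix(np.array([[0., 2., 1.],
                            [2., 0., 0.],
                            [0., 0., 0.]]))

# axis=None: indptr[-1] is the total number of stored entries.
total = int(X.indptr[-1])
# axis=1: consecutive differences of indptr give the per-row counts.
per_row = np.diff(X.indptr)
# axis=0: counting occurrences of each column index gives per-column counts.
per_col = np.bincount(X.indices, minlength=X.shape[1])

print(total, per_row, per_col)
```

The per-row count is the quantity both `t_sne.py` and `neighbors/base.py` compare against the required number of neighbors.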