ENH svmlight chunk loader (#935)
ogrisel committed Jun 16, 2017
1 parent 7238b46 commit a39c8ab
Showing 4 changed files with 181 additions and 17 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new.rst
@@ -202,11 +202,16 @@ Enhancements

- Prevent cast from float32 to float64 in
:class:`linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
:class:`sklearn.linear_model.Ridge` when using svd, sparse_cg, cholesky or lsqr solvers
by :user:`Joan Massich <massich>`, :user:`Nicolas Cordier <ncordier>`

- Add ``max_train_size`` parameter to :class:`model_selection.TimeSeriesSplit`
:issue:`8282` by :user:`Aman Dalmia <dalmia>`.

- Make it possible to load a chunk of an svmlight formatted file by
passing a range of bytes to :func:`datasets.load_svmlight_file`.
:issue:`935` by :user:`Olivier Grisel <ogrisel>`.

Bug fixes
.........

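As a quick illustration of the new whats_new entry above, here is a minimal usage sketch of the byte-range loading API; the file path, feature count, and byte values are made-up assumptions rather than part of the commit:

from sklearn.datasets import load_svmlight_file

# n_features must be given explicitly when offset/length are used, since a
# single chunk may not contain samples for every feature.
# Note: zero_based="auto" silently falls back to True in that case.
X_chunk, y_chunk = load_svmlight_file(
    "data.svmlight",   # hypothetical file path
    n_features=100,    # assumed number of features in the full file
    offset=1024,       # seek past the first 1024 bytes, then drop the partial line
    length=4096,       # stop once the read position passes offset + length bytes
)
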
15 changes: 13 additions & 2 deletions sklearn/datasets/_svmlight_format.pyx
@@ -26,7 +26,7 @@ cdef bytes COLON = u':'.encode('ascii')
@cython.boundscheck(False)
@cython.wraparound(False)
def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
bint query_id):
bint query_id, long long offset, long long length):
cdef array.array data, indices, indptr
cdef bytes line
cdef char *hash_ptr
@@ -35,6 +35,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
cdef Py_ssize_t i
cdef bytes qid_prefix = b('qid')
cdef Py_ssize_t n_features
cdef long long offset_max = offset + length if length > 0 else -1

# Special-case float32 but use float64 for everything else;
# the Python code will do further conversions.
@@ -52,6 +53,12 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
else:
labels = array.array("d")

if offset > 0:
f.seek(offset)
# drop the current line that might be truncated and is to be
# fetched by another call
f.readline()

for line in f:
# skip comments
line_cstr = line
@@ -90,7 +97,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
idx = int(idx_s)
if idx < 0 or not zero_based and idx == 0:
raise ValueError(
"Invalid index %d in SVMlight/LibSVM data file." % idx)
"Invalid index %d in SVMlight/LibSVM data file." % idx)
if idx <= prev_idx:
raise ValueError("Feature indices in SVMlight/LibSVM data "
"file should be sorted and unique.")
@@ -106,4 +113,8 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
array.resize_smart(indptr, len(indptr) + 1)
indptr[len(indptr) - 1] = len(data)

if offset_max != -1 and f.tell() > offset_max:
# Stop here and let another call deal with the following.
break

return (dtype, data, indices, indptr, labels, query)
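
The chunking protocol implemented above is small enough to restate in plain Python. The following is an illustrative sketch only, not the Cython implementation, and it assumes f is a file object opened in binary mode (as the Python-level loader does) so that tell() remains usable while iterating:

def read_chunk_lines(f, offset=0, length=-1):
    # Yield the complete lines that belong to the byte range
    # [offset, offset + length); length <= 0 means "until end of file".
    offset_max = offset + length if length > 0 else -1
    if offset > 0:
        f.seek(offset)
        # Drop the line at the seek position: it may be truncated and is
        # fully read by the call handling the previous byte range.
        f.readline()
    for line in f:
        yield line
        if offset_max != -1 and f.tell() > offset_max:
            # Past the threshold: the next call picks up from here.
            break
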
70 changes: 56 additions & 14 deletions sklearn/datasets/svmlight_format.py
@@ -31,7 +31,8 @@


def load_svmlight_file(f, n_features=None, dtype=np.float64,
multilabel=False, zero_based="auto", query_id=False):
multilabel=False, zero_based="auto", query_id=False,
offset=0, length=-1):
"""Load datasets in the svmlight / libsvm format into sparse CSR matrix
This format is a text-based format, with one sample per line. It does
@@ -76,6 +77,8 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
bigger sliced dataset: each subset might not have examples of
every feature, hence the inferred shape might vary from one
slice to another.
n_features is only required if ``offset`` or ``length`` are passed a
non-default value.
multilabel : boolean, optional, default False
Samples may have several labels each (see
@@ -88,7 +91,10 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
If set to "auto", a heuristic check is applied to determine this from
the file contents. Both kinds of files occur "in the wild", but they
are unfortunately not self-identifying. Using "auto" or True should
always be safe.
always be safe when no ``offset`` or ``length`` is passed.
If ``offset`` or ``length`` are passed, the "auto" mode falls back
to ``zero_based=True`` to avoid having the heuristic check yield
inconsistent results on different segments of the file.
query_id : boolean, default False
If True, will return the query_id array for each file.
@@ -97,6 +103,15 @@ def load_svmlight_file(f, n_features=None, dtype=np.float64,
Data type of dataset to be loaded. This will be the data type of the
output numpy arrays ``X`` and ``y``.
offset : integer, optional, default 0
Ignore the first ``offset`` bytes by seeking forward, then
discarding the following bytes up until the next new line
character.
length : integer, optional, default -1
If strictly positive, stop reading any new line of data once the
position in the file has reached the (offset + length) bytes threshold.
Returns
-------
X : scipy.sparse matrix of shape (n_samples, n_features)
@@ -129,7 +144,7 @@ def get_data():
X, y = get_data()
"""
return tuple(load_svmlight_files([f], n_features, dtype, multilabel,
zero_based, query_id))
zero_based, query_id, offset, length))


def _gen_open(f):
@@ -149,15 +164,18 @@ def _gen_open(f):
return open(f, "rb")


def _open_and_load(f, dtype, multilabel, zero_based, query_id):
def _open_and_load(f, dtype, multilabel, zero_based, query_id,
offset=0, length=-1):
if hasattr(f, "read"):
actual_dtype, data, ind, indptr, labels, query = \
_load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
_load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
offset, length)
# XXX remove closing when Python 2.7+/3.1+ required
else:
with closing(_gen_open(f)) as f:
actual_dtype, data, ind, indptr, labels, query = \
_load_svmlight_file(f, dtype, multilabel, zero_based, query_id)
_load_svmlight_file(f, dtype, multilabel, zero_based, query_id,
offset, length)

# convert from array.array, give data the right dtype
if not multilabel:
@@ -172,7 +190,8 @@ def _open_and_load(f, dtype, multilabel, zero_based, query_id):


def load_svmlight_files(files, n_features=None, dtype=np.float64,
multilabel=False, zero_based="auto", query_id=False):
multilabel=False, zero_based="auto", query_id=False,
offset=0, length=-1):
"""Load dataset from multiple files in SVMlight format
This function is equivalent to mapping load_svmlight_file over a list of
@@ -216,7 +235,10 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
If set to "auto", a heuristic check is applied to determine this from
the file contents. Both kinds of files occur "in the wild", but they
are unfortunately not self-identifying. Using "auto" or True should
always be safe.
always be safe when no offset or length is passed.
If offset or length are passed, the "auto" mode falls back
to zero_based=True to avoid having the heuristic check yield
inconsistent results on different segments of the file.
query_id : boolean, defaults to False
If True, will return the query_id array for each file.
@@ -225,6 +247,15 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
Data type of dataset to be loaded. This will be the data type of the
output numpy arrays ``X`` and ``y``.
offset : integer, optional, default 0
Ignore the first offset bytes by seeking forward, then
discarding the following bytes up until the next new line
character.
length : integer, optional, default -1
If strictly positive, stop reading any new line of data once the
position in the file has reached the (offset + length) bytes threshold.
Returns
-------
[X1, y1, ..., Xn, yn]
@@ -245,16 +276,27 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
--------
load_svmlight_file
"""
r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id))
if (offset != 0 or length > 0) and zero_based == "auto":
# disable heuristic search to avoid getting inconsistent results on
# different segments of the file
zero_based = True

if (offset != 0 or length > 0) and n_features is None:
raise ValueError(
"n_features is required when offset or length is specified.")

r = [_open_and_load(f, dtype, multilabel, bool(zero_based), bool(query_id),
offset=offset, length=length)
for f in files]

if (zero_based is False
or zero_based == "auto" and all(np.min(tmp[1]) > 0 for tmp in r)):
for ind in r:
indices = ind[1]
if (zero_based is False or
zero_based == "auto" and all(len(tmp[1]) and np.min(tmp[1]) > 0
for tmp in r)):
for _, indices, _, _, _ in r:
indices -= 1

n_f = max(ind[1].max() for ind in r) + 1
n_f = max(ind[1].max() if len(ind[1]) else 0 for ind in r) + 1

if n_features is None:
n_features = n_f
elif n_features < n_f:
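In the same spirit as the tests added below, here is a hedged sketch of how offset/length can split a single file into contiguous byte ranges and reassemble the full dataset; the helper name load_in_chunks and the chunk count are illustrative assumptions, not part of this commit:

import os

import numpy as np
import scipy.sparse as sp

from sklearn.datasets import load_svmlight_file


def load_in_chunks(path, n_features, n_chunks=3):
    # Split the file into n_chunks contiguous byte ranges.  A row cut by a
    # boundary is finished by the chunk starting before it and skipped by the
    # next one, so no sample is lost or duplicated.  With offset/length,
    # zero_based="auto" falls back to True, so indexing is consistent across
    # chunks.
    size = os.path.getsize(path)
    starts = [i * size // n_chunks for i in range(n_chunks)]
    parts = []
    for i, start in enumerate(starts):
        length = starts[i + 1] - start if i + 1 < n_chunks else -1
        parts.append(load_svmlight_file(path, n_features=n_features,
                                        offset=start, length=length))
    X = sp.vstack([X_i for X_i, _ in parts])
    y = np.concatenate([y_i for _, y_i in parts])
    return X, y
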
108 changes: 107 additions & 1 deletion sklearn/datasets/tests/test_svmlight_format.py
@@ -1,3 +1,4 @@
from __future__ import division
from bz2 import BZ2File
import gzip
from io import BytesIO
@@ -13,8 +14,10 @@
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raises_regex
from sklearn.utils.testing import raises
from sklearn.utils.testing import assert_in
from sklearn.utils.fixes import sp_version

import sklearn
from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
@@ -401,4 +404,107 @@ def test_load_with_long_qid():
f.seek(0)
X, y = load_svmlight_file(f, query_id=False, zero_based=True)
assert_array_equal(y, true_y)
assert_array_equal(X.toarray(), true_X)
assert_array_equal(X.toarray(), true_X)


def test_load_zeros():
f = BytesIO()
true_X = sp.csr_matrix(np.zeros(shape=(3, 4)))
true_y = np.array([0, 1, 0])
dump_svmlight_file(true_X, true_y, f)

for zero_based in ['auto', True, False]:
f.seek(0)
X, y = load_svmlight_file(f, n_features=4, zero_based=zero_based)
assert_array_equal(y, true_y)
assert_array_equal(X.toarray(), true_X.toarray())


def test_load_with_offsets():
def check_load_with_offsets(sparsity, n_samples, n_features):
rng = np.random.RandomState(0)
X = rng.uniform(low=0.0, high=1.0, size=(n_samples, n_features))
if sparsity:
X[X < sparsity] = 0.0
X = sp.csr_matrix(X)
y = rng.randint(low=0, high=2, size=n_samples)

f = BytesIO()
dump_svmlight_file(X, y, f)
f.seek(0)

size = len(f.getvalue())

# put some split marks that are likely to fall anywhere within a row
mark_0 = 0
mark_1 = size // 3
length_0 = mark_1 - mark_0
mark_2 = 4 * size // 5
length_1 = mark_2 - mark_1

# load the original sparse matrix into 3 independent CSR matrices
X_0, y_0 = load_svmlight_file(f, n_features=n_features,
offset=mark_0, length=length_0)
X_1, y_1 = load_svmlight_file(f, n_features=n_features,
offset=mark_1, length=length_1)
X_2, y_2 = load_svmlight_file(f, n_features=n_features,
offset=mark_2)

y_concat = np.concatenate([y_0, y_1, y_2])
X_concat = sp.vstack([X_0, X_1, X_2])
assert_array_equal(y, y_concat)
assert_array_almost_equal(X.toarray(), X_concat.toarray())

# Generate a uniformly random sparse matrix
for sparsity in [0, 0.1, .5, 0.99, 1]:
for n_samples in [13, 101]:
for n_features in [2, 7, 41]:
yield check_load_with_offsets, sparsity, n_samples, n_features


def test_load_offset_exhaustive_splits():
rng = np.random.RandomState(0)
X = np.array([
[0, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 0, 6],
[1, 2, 3, 4, 0, 6],
[0, 0, 0, 0, 0, 0],
[1, 0, 3, 0, 0, 0],
[0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0],
])
X = sp.csr_matrix(X)
n_samples, n_features = X.shape
y = rng.randint(low=0, high=2, size=n_samples)
query_id = np.arange(n_samples) // 2

f = BytesIO()
dump_svmlight_file(X, y, f, query_id=query_id)
f.seek(0)

size = len(f.getvalue())

# load the same data in 2 parts with all the possible byte offsets to
# locate the split so as to test for particular boundary cases
for mark in range(size):
if sp_version < (0, 14) and (mark == 0 or mark > size - 100):
# old scipy does not support sparse matrices with 0 rows.
continue
f.seek(0)
X_0, y_0, q_0 = load_svmlight_file(f, n_features=n_features,
query_id=True, offset=0,
length=mark)
X_1, y_1, q_1 = load_svmlight_file(f, n_features=n_features,
query_id=True, offset=mark,
length=-1)
q_concat = np.concatenate([q_0, q_1])
y_concat = np.concatenate([y_0, y_1])
X_concat = sp.vstack([X_0, X_1])
assert_array_equal(y, y_concat)
assert_array_equal(query_id, q_concat)
assert_array_almost_equal(X.toarray(), X_concat.toarray())


def test_load_with_offsets_error():
assert_raises_regex(ValueError, "n_features is required",
load_svmlight_file, datafile, offset=3, length=3)
