Commit 831d78b

return bytes instead of stream, read once

shashanksingh28 committed Sep 10, 2019
1 parent f4ca32b
Showing 2 changed files with 40 additions and 40 deletions.
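The commit swaps the stream-returning helper _open_openml_url for _openml_url_bytes, which reads each HTTP response exactly once and hands callers plain bytes. A minimal standalone sketch of the difference, with an in-memory payload standing in for the network response (illustrative only, not code from this commit):

    import hashlib
    from io import BytesIO

    payload = b"illustrative payload"

    # Stream style: hashing for checksum validation consumes the one-shot
    # stream, so a fresh BytesIO copy must be rebuilt for the caller.
    stream = BytesIO(payload)
    checksum = hashlib.md5(stream.read()).hexdigest()
    stream = BytesIO(payload)  # the original is exhausted

    # Bytes style: hash the object and return it as-is; bytes are
    # immutable and can be decoded or re-read any number of times.
    checksum = hashlib.md5(payload).hexdigest()
    text = payload.decode("utf-8")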
70 changes: 35 additions & 35 deletions sklearn/datasets/openml.py
@@ -1,9 +1,7 @@
 import gzip
 import json
 import os
-from io import BytesIO
 import hashlib
-import shutil
 from os.path import join
 from warnings import warn
 from contextlib import closing
@@ -63,7 +61,7 @@ def wrapper():
     return decorator


-def _open_openml_url(openml_path, data_home, expected_md5_checksum=None):
+def _openml_url_bytes(openml_path, data_home, expected_md5_checksum=None):
     """
     Returns a resource from OpenML.org. Caches it to data_home if required.
@@ -79,49 +77,47 @@ def _open_openml_url(openml_path, data_home, expected_md5_checksum=None):
     Returns
     -------
-    result : stream
-        A stream to the OpenML resource
+    result : bytes
+        Byte content of resource
     """
     def is_gzip(_fsrc):
         return _fsrc.info().get('Content-Encoding', '') == 'gzip'

     req = Request(_OPENML_PREFIX + openml_path)
     req.add_header('Accept-encoding', 'gzip')

-    def _md5_validated_stream(input_stream, md5_checksum):
+    def _md5_validated_bytes(bytes_content, md5_checksum):
         """
-        Consume binary stream to validate checksum,
-        return a new stream with same content

         Parameters
         ----------
-        input_stream : io.BufferedIOBase
-            Input stream with a read() method to get content in bytes
+        bytes_content : bytes

         md5_checksum: str
-            Expected md5 checksum
+            Expected md5 checksum of bytes

         Returns
         -------
-        BytesIO stream with the same content as input_stream for consumption
+        bytes
         """
-        with closing(input_stream):
-            bytes_content = input_stream.read()
-            actual_md5_checksum = hashlib.md5(bytes_content).hexdigest()
-            if md5_checksum != actual_md5_checksum:
-                raise ValueError("md5checksum: {} does not match expected: "
-                                 "{}".format(actual_md5_checksum,
-                                             md5_checksum))
-            return BytesIO(bytes_content)
+        actual_md5_checksum = hashlib.md5(bytes_content).hexdigest()
+        if md5_checksum != actual_md5_checksum:
+            raise ValueError("md5checksum: {} does not match expected: "
+                             "{}".format(actual_md5_checksum,
+                                         md5_checksum))
+        return bytes_content

     if data_home is None:
         fsrc = urlopen(req)
         if is_gzip(fsrc):
             fsrc = gzip.GzipFile(fileobj=fsrc, mode='rb')
+        bytes_content = fsrc.read()
         if expected_md5_checksum:
-            # validating checksum reads and consumes the stream
-            return _md5_validated_stream(fsrc, expected_md5_checksum)
-        return fsrc
+            return _md5_validated_bytes(bytes_content, expected_md5_checksum)
+        return bytes_content

     local_path = _get_local_path(openml_path, data_home)
     if not os.path.exists(local_path):
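A runnable sketch of the transparent gzip handling in the hunk above (is_gzip plus gzip.GzipFile(fileobj=...)); the in-memory buffer standing in for the urlopen response is an assumption for illustration:

    import gzip
    from io import BytesIO

    raw = b"x,y\n1,2\n"
    compressed = BytesIO(gzip.compress(raw))  # stands in for a gzip'd response

    # Mirror of the helper's unwrapping: wrap the response so that a single
    # .read() yields the decompressed payload as bytes.
    fsrc = gzip.GzipFile(fileobj=compressed, mode='rb')
    bytes_content = fsrc.read()
    assert bytes_content == raw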
@@ -135,18 +131,23 @@ def _md5_validated_stream(input_stream, md5_checksum):
             with closing(urlopen(req)) as fsrc:
                 if is_gzip(fsrc):  # unzip it for checksum validation
                     fsrc = gzip.GzipFile(fileobj=fsrc, mode='rb')
+                bytes_content = fsrc.read()
                 if expected_md5_checksum:
-                    fsrc = _md5_validated_stream(fsrc, expected_md5_checksum)
+                    bytes_content = _md5_validated_bytes(bytes_content,
+                                                         expected_md5_checksum)
                 with gzip.GzipFile(local_path, 'wb') as fdst:
-                    shutil.copyfileobj(fsrc, fdst)
+                    fdst.write(bytes_content)
         except Exception:
             if os.path.exists(local_path):
                 os.unlink(local_path)
             raise
+    else:
+        with gzip.GzipFile(local_path, "rb") as gzip_file:
+            bytes_content = gzip_file.read()

-    # XXX: First time, decompression will not be necessary (by using fsrc), but
-    # it will happen nonetheless
-    return gzip.GzipFile(local_path, 'rb')
+    return bytes_content


 def _get_json_content_from_openml_api(url, error_message, raise_if_error,
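The branch above persists the validated bytes gzip-compressed at local_path, and the new else branch decompresses them on a cache hit. A small round-trip sketch, with a temporary file standing in for the real cache path:

    import gzip
    import os
    import tempfile

    bytes_content = b"@relation demo\n@data\n"
    local_path = os.path.join(tempfile.mkdtemp(), "cached.gz")

    with gzip.GzipFile(local_path, 'wb') as fdst:       # cache miss: store
        fdst.write(bytes_content)

    with gzip.GzipFile(local_path, 'rb') as gzip_file:  # cache hit: load
        assert gzip_file.read() == bytes_content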
@@ -183,8 +184,7 @@ def _get_json_content_from_openml_api(url, error_message, raise_if_error,

     @_retry_with_clean_cache(url, data_home)
     def _load_json():
-        with closing(_open_openml_url(url, data_home)) as response:
-            return json.loads(response.read().decode("utf-8"))
+        return json.loads(_openml_url_bytes(url, data_home).decode("utf-8"))

     try:
         return _load_json()
@@ -489,16 +489,16 @@ def _download_data_arff(file_id, sparse, data_home, encode_nominal=True,

     @_retry_with_clean_cache(url, data_home)
     def _arff_load():
-        with closing(_open_openml_url(url, data_home, expected_md5_checksum)) \
-                as response:
-            if sparse is True:
-                return_type = _arff.COO
-            else:
-                return_type = _arff.DENSE_GEN
-
-            arff_file = _arff.loads(response.read().decode('utf-8'),
-                                    encode_nominal=encode_nominal,
-                                    return_type=return_type)
+        bytes_content = _openml_url_bytes(url, data_home,
+                                          expected_md5_checksum)
+        if sparse is True:
+            return_type = _arff.COO
+        else:
+            return_type = _arff.DENSE_GEN
+
+        arff_file = _arff.loads(bytes_content.decode('utf-8'),
+                                encode_nominal=encode_nominal,
+                                return_type=return_type)
         return arff_file

     return _arff_load()
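For context on what _arff_load does with the returned bytes: a hedged sketch using the standalone liac-arff package, which scikit-learn vendors as sklearn.externals._arff; the tiny ARFF document is made up, and arff.DENSE is used in place of the generator-based DENSE_GEN for simplicity:

    import arff  # pip install liac-arff

    bytes_content = (b"@RELATION demo\n"
                     b"@ATTRIBUTE x NUMERIC\n"
                     b"@ATTRIBUTE y NUMERIC\n"
                     b"@DATA\n"
                     b"1.0,2.0\n")

    arff_file = arff.loads(bytes_content.decode('utf-8'),
                           return_type=arff.DENSE)
    print(arff_file['data'])  # [[1.0, 2.0]]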
10 changes: 5 additions & 5 deletions sklearn/datasets/tests/test_openml.py
@@ -12,7 +12,7 @@

 from sklearn import config_context
 from sklearn.datasets import fetch_openml
-from sklearn.datasets.openml import (_open_openml_url,
+from sklearn.datasets.openml import (_openml_url_bytes,
                                      _get_data_description_by_id,
                                      _download_data_arff,
                                      _get_local_path,
@@ -922,13 +922,13 @@ def test_open_openml_url_cache(monkeypatch, gzip_response, tmpdir):
     openml_path = sklearn.datasets.openml._DATA_FILE.format(data_id)
     cache_directory = str(tmpdir.mkdir('scikit_learn_data'))
     # first fill the cache
-    response1 = _open_openml_url(openml_path, cache_directory)
+    response1 = _openml_url_bytes(openml_path, cache_directory)
     # assert file exists
     location = _get_local_path(openml_path, cache_directory)
     assert os.path.isfile(location)
     # redownload, to utilize cache
-    response2 = _open_openml_url(openml_path, cache_directory)
-    assert response1.read() == response2.read()
+    response2 = _openml_url_bytes(openml_path, cache_directory)
+    assert response1 == response2


 @pytest.mark.parametrize('gzip_response', [True, False])
@@ -949,7 +949,7 @@ def _mock_urlopen(request):
     monkeypatch.setattr(sklearn.datasets.openml, 'urlopen', _mock_urlopen)

     with pytest.raises(ValueError, match="Invalid request"):
-        _open_openml_url(openml_path, cache_directory)
+        _openml_url_bytes(openml_path, cache_directory)

     assert not os.path.exists(location)
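The error-path test works because the openml module looks up urlopen through its own namespace, letting monkeypatch.setattr substitute a fake. A self-contained sketch of the same pattern done by hand on urllib.request (an illustration, not the scikit-learn test itself):

    import urllib.request

    def fetch(url):
        # resolves urlopen via the module attribute, like the openml helpers
        with urllib.request.urlopen(url) as resp:
            return resp.read()

    def _mock_urlopen(request):
        raise ValueError("Invalid request")

    original = urllib.request.urlopen
    urllib.request.urlopen = _mock_urlopen   # what monkeypatch.setattr does
    try:
        fetch("https://www.openml.org/whatever")
    except ValueError as exc:
        assert "Invalid request" in str(exc)
    finally:
        urllib.request.urlopen = original    # monkeypatch-style teardown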
