
FIX: make the 20 newsgroups loader explicitly decode latin1 content
ogrisel committed Feb 28, 2012
1 parent 932c155 commit b071cb0
Showing 2 changed files with 19 additions and 6 deletions.
17 changes: 16 additions & 1 deletion sklearn/datasets/base.py
@@ -64,7 +64,8 @@ def clear_data_home(data_home=None):
 
 
 def load_files(container_path, description=None, categories=None,
-               load_content=True, shuffle=True, random_state=0):
+               load_content=True, shuffle=True, charset=None,
+               charset_error='strict', random_state=0):
     """Load text files with categories as subfolder names.
 
     Individual samples are assumed to be files stored a two levels folder
@@ -115,6 +116,18 @@ def load_files(container_path, description=None, categories=None,
         in the data structure returned. If not, a filenames attribute
         gives the path to the files.
 
+    charset : string or None (default is None)
+        If None, do not try to decode the content of the files (e.g. for
+        images or other non-text content).
+        If not None, charset to use to decode text files if load_content is
+        True.
+
+    charset_error : {'strict', 'ignore', 'replace'}
+        Instruction on what to do if a byte sequence is given to analyze that
+        contains characters not of the given `charset`. By default, it is
+        'strict', meaning that a UnicodeDecodeError will be raised. Other
+        values are 'ignore' and 'replace'.
+
     shuffle : bool, optional (default=True)
         Whether or not to shuffle the data: might be important for models that
         make the assumption that the samples are independent and identically
@@ -166,6 +179,8 @@ def load_files(container_path, description=None, categories=None,
 
     if load_content:
         data = [open(filename).read() for filename in filenames]
+        if charset is not None:
+            data = [d.decode(charset, charset_error) for d in data]
         return Bunch(data=data,
                      filenames=filenames,
                      target_names=target_names,
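A quick usage sketch of the new parameters (Python 2 semantics, since this commit predates the Python 3 port; the corpus folder and file contents below are hypothetical, not part of this diff): charset selects the codec applied to each file's raw bytes, and charset_error is passed through as the errors argument of str.decode.

    from sklearn.datasets import load_files

    # Hypothetical layout: one subfolder per category, e.g.
    #   my_corpus/comp.graphics/0001.txt
    #   my_corpus/sci.space/0002.txt
    bunch = load_files('my_corpus', charset='latin1', charset_error='strict')

    # With charset set, bunch.data holds unicode strings instead of raw bytes.
    print(bunch.data[0][:80])

    # The charset_error modes follow str.decode semantics:
    raw = 'caf\xe9'                 # latin-1 bytes that are invalid UTF-8
    raw.decode('latin1')            # u'caf\xe9'
    raw.decode('utf-8', 'replace')  # u'caf\ufffd'
    raw.decode('utf-8', 'ignore')   # u'caf'
    # raw.decode('utf-8', 'strict') raises UnicodeDecodeError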
8 changes: 3 additions & 5 deletions sklearn/datasets/twenty_newsgroups.py
@@ -67,9 +67,7 @@
 
 
 def download_20newsgroups(target_dir, cache_path):
-    """ Download the 20Newsgroups data and convert is in a zipped pickle
-    storage.
-    """
+    """Download the 20 newsgroups data and store it as a zipped pickle."""
     archive_path = os.path.join(target_dir, ARCHIVE_NAME)
     train_path = os.path.join(target_dir, TRAIN_FOLDER)
     test_path = os.path.join(target_dir, TEST_FOLDER)
@@ -88,8 +86,8 @@ def download_20newsgroups(target_dir, cache_path):
 
     # Store a zipped pickle
     cache = dict(
-        train=load_files(train_path),
-        test=load_files(test_path)
+        train=load_files(train_path, charset='latin1'),
+        test=load_files(test_path, charset='latin1')
     )
     open(cache_path, 'wb').write(pickle.dumps(cache).encode('zip'))
     shutil.rmtree(target_dir)
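For context, encode('zip') on a Python 2 byte string routes through the zlib codec, so the cache written here is a zlib-compressed pickle. A minimal standalone sketch of the round trip (the cache contents and file name are illustrative, not from the diff):

    import pickle
    import zlib

    cache = dict(train=['doc one'], test=['doc two'])

    # Equivalent of pickle.dumps(cache).encode('zip') on Python 2:
    compressed = zlib.compress(pickle.dumps(cache))
    open('cache.pkz', 'wb').write(compressed)

    # The matching load path decompresses before unpickling:
    restored = pickle.loads(zlib.decompress(open('cache.pkz', 'rb').read()))
    assert restored == cache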
