Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve file mode handling #559

Merged
merged 5 commits into from Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,7 @@
# Unreleased

- Fix reading empty file or seeking past end of file for s3 backend (PR [#549](https://github.com/RaRe-Technologies/smart_open/pull/549), [@jcushman](https://github.com/jcushman))
- Fix reading empty file or seeking past end of file for s3 backend (PR [#549](https://github.com/RaRe-Technologies/smart_open/pull/549), [@jcushman](https://github.com/jcushman))
- Fix handling of rt/wt mode when working with gzip compression (PR [#559](https://github.com/RaRe-Technologies/smart_open/pull/559), [@mpenkov](https://github.com/mpenkov))

# 3.0.0, 8 Oct 2020

Expand Down
71 changes: 62 additions & 9 deletions smart_open/smart_open_lib.py
Expand Up @@ -46,12 +46,6 @@

SYSTEM_ENCODING = sys.getdefaultencoding()

_TO_BINARY_LUT = {
'r': 'rb', 'r+': 'rb+', 'rt': 'rb', 'rt+': 'rb+',
'w': 'wb', 'w+': 'wb+', 'wt': 'wb', "wt+": 'wb+',
'a': 'ab', 'a+': 'ab+', 'at': 'ab', 'at+': 'ab+',
}


def _sniff_scheme(uri_as_string):
"""Returns the scheme of the URL only, as a string."""
Expand Down Expand Up @@ -218,12 +212,17 @@ def open(
# filename ---------------> bytes -------------> bytes ---------> text
# binary decompressed decode
#
binary_mode = _TO_BINARY_LUT.get(mode, mode)

try:
binary_mode = _get_binary_mode(mode)
except ValueError as ve:
raise NotImplementedError(ve.args[0])

binary = _open_binary_stream(uri, binary_mode, transport_params)
if ignore_ext:
decompressed = binary
else:
decompressed = compression.compression_wrapper(binary, mode)
decompressed = compression.compression_wrapper(binary, binary_mode)

if 'b' not in mode or explicit_encoding is not None:
decoded = _encoding_wrapper(decompressed, mode, encoding=encoding, errors=errors)
Expand All @@ -233,6 +232,60 @@ def open(
return decoded


def _get_binary_mode(mode_str):
#
# https://docs.python.org/3/library/functions.html#open
#
# The order of characters in the mode parameter appears to be unspecified.
# The implementation follows the examples, just to be safe.
#
mode = list(mode_str)
binmode = []

if 't' in mode and 'b' in mode:
raise ValueError("can't have text and binary mode at once")

counts = [mode.count(x) for x in 'rwa']
if sum(counts) > 1:
raise ValueError("must have exactly one of create/read/write/append mode")

def transfer(char):
binmode.append(mode.pop(mode.index(char)))

if 'a' in mode:
transfer('a')
elif 'w' in mode:
transfer('w')
elif 'r' in mode:
transfer('r')
else:
raise ValueError(
"Must have exactly one of create/read/write/append "
"mode and at most one plus"
piskvorky marked this conversation as resolved.
Show resolved Hide resolved
)

if 'b' in mode:
transfer('b')
elif 't' in mode:
mode.pop(mode.index('t'))
binmode.append('b')
else:
binmode.append('b')

if '+' in mode:
transfer('+')

#
# There shouldn't be anything left in the mode list at this stage.
# If there is, then either we've missed something and the implementation
# of this function is broken, or the original input mode is invalid.
#
if mode:
raise ValueError('invalid mode: %r' % mode_str)

return ''.join(binmode)


def _shortcut_open(
uri,
mode,
Expand Down Expand Up @@ -317,7 +370,7 @@ def _open_binary_stream(uri, mode, transport_params):
return uri

if not isinstance(uri, str):
raise TypeError("don't know how to handle uri %r" % uri)
raise TypeError("don't know how to handle uri %s" % repr(uri))
piskvorky marked this conversation as resolved.
Show resolved Hide resolved

scheme = _sniff_scheme(uri)
submodule = transport.get_transport(scheme)
Expand Down
2 changes: 2 additions & 0 deletions smart_open/tests/test_http.py
Expand Up @@ -5,6 +5,7 @@
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
import os
import unittest

import responses
Expand Down Expand Up @@ -38,6 +39,7 @@ def request_callback(request):
return (200, HEADERS, BYTES[start:end])


@unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason')
class HttpTest(unittest.TestCase):

@responses.activate
Expand Down
62 changes: 58 additions & 4 deletions smart_open/tests/test_smart_open.py
Expand Up @@ -10,18 +10,19 @@
import csv
import contextlib
import io
import unittest
import gzip
import hashlib
import logging
import tempfile
import os
import hashlib
import tempfile
import unittest
import warnings

import boto3
import mock
from moto import mock_s3
import responses
import gzip
import parameterizedtestcase
import pytest

import smart_open
Expand Down Expand Up @@ -384,6 +385,7 @@ def test_pathlib_monkeypath_read_gz(self):
_patch_pathlib(obj.old_impl)


@unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason')
class SmartOpenHttpTest(unittest.TestCase):
"""
Test reading from HTTP connections in various ways.
Expand Down Expand Up @@ -859,6 +861,7 @@ def test_hdfs(self, mock_subprocess):
stdout=mock_subprocess.PIPE,
)

@unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason')
@responses.activate
def test_webhdfs(self):
"""Is webhdfs line iterator called correctly"""
Expand All @@ -869,6 +872,7 @@ def test_webhdfs(self):
self.assertEqual(next(iterator).decode("utf-8"), "line1\n")
self.assertEqual(next(iterator).decode("utf-8"), "line2")

@unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason')
@responses.activate
def test_webhdfs_encoding(self):
"""Is HDFS line iterator called correctly?"""
Expand All @@ -881,6 +885,7 @@ def test_webhdfs_encoding(self):
actual = smart_open.open(input_url, encoding='utf-8').read()
self.assertEqual(text, actual)

@unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason')
@responses.activate
def test_webhdfs_read(self):
"""Does webhdfs read method work correctly"""
Expand Down Expand Up @@ -1226,6 +1231,7 @@ def test_write_bad_encoding_replace(self):
self.assertEqual(expected, actual)


@unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason')
class WebHdfsWriteTest(unittest.TestCase):
"""
Test writing into webhdfs files.
Expand Down Expand Up @@ -1325,6 +1331,14 @@ def test_write_read_bz2(self):
"""Can write and read bz2?"""
self.write_read_assertion('.bz2')

def test_gzip_text(self):
    """Can write and read text through gzip compression?"""
    payload = 'hello world'

    with tempfile.NamedTemporaryFile(suffix='.gz') as tmp:
        path = tmp.name

        # 'wt' must still compress, because the .gz suffix is honored.
        with smart_open.open(path, 'wt') as writer:
            writer.write(payload)

        # 'rt' must decompress and decode back to the original text.
        with smart_open.open(path, 'rt') as reader:
            assert reader.read() == payload


class MultistreamsBZ2Test(unittest.TestCase):
"""
Expand Down Expand Up @@ -1621,6 +1635,46 @@ def test(self):
self.assertEqual(expected, actual)


class GetBinaryModeTest(parameterizedtestcase.ParameterizedTestCase):
    """Check how _get_binary_mode maps caller-supplied modes to binary modes."""

    @parameterizedtestcase.ParameterizedTestCase.parameterize(
        ('mode', 'expected'),
        [
            ('r', 'rb'),
            ('r+', 'rb+'),
            ('rt', 'rb'),
            ('rt+', 'rb+'),
            ('r+t', 'rb+'),
            ('w', 'wb'),
            ('w+', 'wb+'),
            ('wt', 'wb'),
            ('wt+', 'wb+'),
            ('w+t', 'wb+'),
            ('a', 'ab'),
            ('a+', 'ab+'),
            ('at', 'ab'),
            ('at+', 'ab+'),
            ('a+t', 'ab+'),
        ]
    )
    def test(self, mode, expected):
        """Each valid mode maps to the expected binary mode."""
        actual = smart_open.smart_open_lib._get_binary_mode(mode)
        assert actual == expected

    #
    # Only a single parameter here: each case is a one-element tuple and
    # test_bad accepts just ``mode``, so the names must match that shape.
    #
    @parameterizedtestcase.ParameterizedTestCase.parameterize(
        ('mode', ),
        [
            ('rw', ),
            ('rwa', ),
            ('rbt', ),
            ('r++', ),
            ('+', ),
            ('x', ),
        ]
    )
    def test_bad(self, mode):
        """Each invalid mode is rejected with ValueError."""
        self.assertRaises(ValueError, smart_open.smart_open_lib._get_binary_mode, mode)


def test_backwards_compatibility_wrapper():
fpath = os.path.join(CURR_DIR, 'test_data/crime-and-punishment.txt')
expected = open(fpath, 'rb').readline()
Expand Down