Skip to content

Commit

Permalink
Fix two S3 bugs (#307)
Browse files Browse the repository at this point in the history
* avoid using mutable values for defaults, fix #304

* get rid of unneeded variable

* respect endpoint_url, fix #305

* minor updates to make unit tests pass
  • Loading branch information
mpenkov committed Apr 26, 2019
1 parent 052ff93 commit 07a205f
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 13 deletions.
63 changes: 63 additions & 0 deletions integration-tests/test_minio.py
@@ -0,0 +1,63 @@
import logging
import boto3

from smart_open import open

#
# These are publicly available via play.min.io
#
KEY_ID = 'Q3AM3UQ867SPQQA43P2F'
SECRET_KEY = 'zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG'
ENDPOINT_URL = 'https://play.min.io:9000'


def read_boto3():
"""Read directly using boto3."""
session = get_minio_session()
s3 = session.resource('s3', endpoint_url=ENDPOINT_URL)

obj = s3.Object('smart-open-test', 'README.rst')
data = obj.get()['Body'].read()
logging.info('read %d bytes via boto3', len(data))
return data


def read_smart_open():
url = 's3://Q3AM3UQ867SPQQA43P2F:zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG@play.min.io:9000@smart-open-test/README.rst' # noqa

#
# If the default region is not us-east-1, we need to construct our own
# session. This is because smart_open will create a session in the default
# region, which _must_ be us-east-1 for minio to work.
#
tp = {}
if get_default_region() != 'us-east-1':
logging.info('injecting custom session')
tp['session'] = get_minio_session()
with open(url, transport_params=tp) as fin:
text = fin.read()
logging.info('read %d characters via smart_open', len(text))
return text


def get_minio_session():
return boto3.Session(
region_name='us-east-1',
aws_access_key_id=KEY_ID,
aws_secret_access_key=SECRET_KEY,
)


def get_default_region():
return boto3.Session().region_name


def main():
logging.basicConfig(level=logging.INFO)
from_boto3 = read_boto3()
from_smart_open = read_smart_open()
assert from_boto3.decode('utf-8') == from_smart_open


if __name__ == '__main__':
main()
30 changes: 22 additions & 8 deletions smart_open/s3.py
Expand Up @@ -73,8 +73,8 @@ def open(
buffer_size=DEFAULT_BUFFER_SIZE,
min_part_size=DEFAULT_MIN_PART_SIZE,
session=None,
resource_kwargs=dict(),
multipart_upload_kwargs=dict(),
resource_kwargs=None,
multipart_upload_kwargs=None,
):
"""Open an S3 object for reading or writing.
Expand All @@ -93,7 +93,7 @@ def open(
session: object, optional
The S3 session to use when working with boto3.
resource_kwargs: dict, optional
Keyword arguments to use when creating a new resource. For writing only.
Keyword arguments to use when accessing the S3 resource for reading or writing.
multipart_upload_kwargs: dict, optional
Additional parameters to pass to boto3's initiate_multipart_upload function.
For writing only.
Expand All @@ -103,6 +103,11 @@ def open(
if mode not in MODES:
raise NotImplementedError('bad mode: %r expected one of %r' % (mode, MODES))

if resource_kwargs is None:
resource_kwargs = {}
if multipart_upload_kwargs is None:
multipart_upload_kwargs = {}

if mode == READ_BINARY:
fileobj = SeekableBufferedInputBase(
bucket_id,
Expand Down Expand Up @@ -165,7 +170,7 @@ def seek(self, position):
#
try:
self._body.close()
except AttributeError as e:
except AttributeError:
pass

if position == self._content_length == 0 or position == self._content_length:
Expand All @@ -190,9 +195,12 @@ def read(self, size=-1):

class BufferedInputBase(io.BufferedIOBase):
def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=dict()):
line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=None):
if session is None:
session = boto3.Session()
if resource_kwargs is None:
resource_kwargs = {}

s3 = session.resource('s3', **resource_kwargs)
self._object = s3.Object(bucket, key)
self._raw_reader = RawReader(self._object)
Expand Down Expand Up @@ -324,9 +332,11 @@ class SeekableBufferedInputBase(BufferedInputBase):
Implements the io.BufferedIOBase interface of the standard library."""

def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=dict()):
line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=None):
if session is None:
session = boto3.Session()
if resource_kwargs is None:
resource_kwargs = {}
s3 = session.resource('s3', **resource_kwargs)
self._object = s3.Object(bucket, key)
self._raw_reader = SeekableRawReader(self._object)
Expand Down Expand Up @@ -393,15 +403,19 @@ def __init__(
key,
min_part_size=DEFAULT_MIN_PART_SIZE,
session=None,
resource_kwargs=dict(),
multipart_upload_kwargs=dict(),
resource_kwargs=None,
multipart_upload_kwargs=None,
):
if min_part_size < MIN_MIN_PART_SIZE:
logger.warning("S3 requires minimum part size >= 5MB; \
multipart upload may fail")

if session is None:
session = boto3.Session()
if resource_kwargs is None:
resource_kwargs = {}
if multipart_upload_kwargs is None:
multipart_upload_kwargs = {}

s3 = session.resource('s3', **resource_kwargs)

Expand Down
37 changes: 34 additions & 3 deletions smart_open/smart_open_lib.py
Expand Up @@ -62,6 +62,7 @@

_ISSUE_189_URL = 'https://github.com/RaRe-Technologies/smart_open/issues/189'

_DEFAULT_S3_HOST = 's3.amazonaws.com'

_COMPRESSOR_REGISTRY = {}

Expand Down Expand Up @@ -205,7 +206,7 @@ def open(
closefd=True,
opener=None,
ignore_ext=False,
transport_params=dict(),
transport_params=None,
):
r"""Open the URI object, returning a file-like object.
Expand Down Expand Up @@ -292,6 +293,9 @@ def open(
if not isinstance(mode, six.string_types):
raise TypeError('mode should be a string')

if transport_params is None:
transport_params = {}

fobj = _shortcut_open(
uri,
mode,
Expand Down Expand Up @@ -399,7 +403,6 @@ def smart_open(uri, mode="rb", **kw):
url = kw.pop('host')
if not url.startswith('http'):
url = 'http://' + url
transport_params['multipart_upload_kwargs'].update(endpoint_url=url)
transport_params['resource_kwargs'].update(endpoint_url=url)

if 's3_upload' in kw and kw['s3_upload']:
Expand Down Expand Up @@ -592,10 +595,38 @@ def _s3_open_uri(parsed_uri, mode, transport_params):
aws_secret_access_key=parsed_uri.access_secret,
)

#
# There are two explicit ways the user can provide the endpoint URI:
#
# 1. Via the URL. The protocol is implicit, and we assume HTTPS in this case.
# 2. Via the resource_kwargs and multipart_upload_kwargs endpoint_url parameter.
#
# Again, these are not mutually exclusive: the user can specify both. We
# have to pick one to proceed, however, and we go with 2.
#
if parsed_uri.host != _DEFAULT_S3_HOST:
endpoint_url = 'https://%s:%d' % (parsed_uri.host, parsed_uri.port)
_override_endpoint_url(transport_params, endpoint_url)

kwargs = _check_kwargs(smart_open_s3.open, transport_params)
return smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs)


def _override_endpoint_url(tp, url):
try:
resource_kwargs = tp['resource_kwargs']
except KeyError:
resource_kwargs = tp['resource_kwargs'] = {}

if resource_kwargs.get('endpoint_url'):
logger.warning(
'ignoring endpoint_url parsed from URL because it conflicts '
'with transport_params.resource_kwargs.endpoint_url. '
)
else:
resource_kwargs.update(endpoint_url=url)


def _my_urlsplit(url):
"""This is a hack to prevent the regular urlsplit from splitting around question marks.
Expand Down Expand Up @@ -723,7 +754,7 @@ def _parse_uri_s3x(parsed_uri):
assert parsed_uri.scheme in smart_open_s3.SUPPORTED_SCHEMES

port = 443
host = boto.config.get('s3', 'host', 's3.amazonaws.com')
host = boto.config.get('s3', 'host', _DEFAULT_S3_HOST)
ordinary_calling_format = False
#
# These defaults tell boto3 to look for credentials elsewhere
Expand Down
53 changes: 51 additions & 2 deletions smart_open/tests/test_smart_open.py
Expand Up @@ -12,9 +12,7 @@
import logging
import tempfile
import os
import sys
import hashlib
import unittest

import boto3
import mock
Expand Down Expand Up @@ -104,6 +102,16 @@ def test_s3_uri_has_atmark_in_key_name2(self):
self.assertEqual(parsed_uri.host, "hostname")
self.assertEqual(parsed_uri.port, 1234)

def test_s3_uri_has_atmark_in_key_name3(self):
parsed_uri = smart_open_lib._parse_uri("s3://accessid:access/secret@hostname@mybucket/dir/my@ke@y")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "dir/my@ke@y")
self.assertEqual(parsed_uri.access_id, "accessid")
self.assertEqual(parsed_uri.access_secret, "access/secret")
self.assertEqual(parsed_uri.host, "hostname")
self.assertEqual(parsed_uri.port, 443)

def test_s3_handles_fragments(self):
uri_str = 's3://bucket-name/folder/picture #1.jpg'
parsed_uri = smart_open_lib._parse_uri(uri_str)
Expand Down Expand Up @@ -1287,6 +1295,47 @@ def test_write_text_gzip(self):
actual = fin.read()
self.assertEqual(text, actual)

@mock.patch('smart_open.s3.SeekableBufferedInputBase')
def test_transport_params_is_not_mutable(self, mock_open):
smart_open.open('s3://access_key:secret_key@host@bucket/key')
smart_open.open('s3://bucket/key')

#
# The first call should have a non-null session, because the session
# keys were explicitly specified in the URL. The second call should
# _not_ have a session.
#
self.assertIsNone(mock_open.call_args_list[1][1]['session'])
self.assertIsNotNone(mock_open.call_args_list[0][1]['session'])

@mock.patch('smart_open.s3.SeekableBufferedInputBase')
def test_respects_endpoint_url_read(self, mock_open):
url = 's3://key_id:secret_key@play.min.io:9000@smart-open-test/README.rst'
smart_open.open(url)

expected = {'endpoint_url': 'https://play.min.io:9000'}
self.assertEqual(mock_open.call_args[1]['resource_kwargs'], expected)

@mock.patch('smart_open.s3.BufferedOutputBase')
def test_respects_endpoint_url_write(self, mock_open):
url = 's3://key_id:secret_key@play.min.io:9000@smart-open-test/README.rst'
smart_open.open(url, 'wb')

expected = {'endpoint_url': 'https://play.min.io:9000'}
self.assertEqual(mock_open.call_args[1]['resource_kwargs'], expected)


def function(a, b, c, foo='bar', baz='boz'):
pass


class CheckKwargsTest(unittest.TestCase):
def test(self):
kwargs = {'foo': 123, 'bad': False}
expected = {'foo': 123}
actual = smart_open.smart_open_lib._check_kwargs(function, kwargs)
self.assertEqual(expected, actual)


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit 07a205f

Please sign in to comment.