Allowing boto keys to be passed directly to smart_open (#35) #38

Merged: 4 commits, Dec 18, 2015

Changes from 2 commits
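In short, this PR lets `smart_open()` accept an already-constructed `boto.s3.key.Key` in addition to a URI string. A minimal sketch of the new usage (the bucket and key names are hypothetical, and it assumes boto credentials are available):

```python
import boto
import smart_open

# Resolve a key object yourself, e.g. using a custom connection or non-default credentials.
conn = boto.connect_s3(aws_access_key_id="...", aws_secret_access_key="...")
key = conn.get_bucket("mybucket").get_key("mykey")

# Pass the boto key directly instead of an "s3://..." URI.
with smart_open.smart_open(key, "rb") as fin:
    for line in fin:
        print(line)

# Plain URI strings keep working exactly as before.
with smart_open.smart_open("s3://mybucket/mykey", "rb") as fin:
    first_bytes = fin.read(100)
```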
117 changes: 65 additions & 52 deletions smart_open/smart_open_lib.py
@@ -85,41 +85,58 @@ def smart_open(uri, mode="rb", **kw):
... fout.write("good bye!\n")
Contributor:
Add the new boto key usage example to the docstring

Contributor Author:
Done.
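For reference, a docstring example along these lines would illustrate the new behaviour, in the same doctest style as the existing examples (a sketch only; the exact wording added in the follow-up commit is not shown in this hunk):

```python
>>> # stream content into S3 via a boto key passed in directly (hypothetical names)
>>> bucket = boto.connect_s3().get_bucket('mybucket')
>>> key = bucket.get_key('mykey', validate=False)
>>> with smart_open(key, 'wb') as fout:
...     fout.write("good bye!\n")
```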


"""
# simply pass-through if already a file-like
if not isinstance(uri, six.string_types) and hasattr(uri, 'read'):
return uri

# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
parsed_uri = ParseUri(uri)

if parsed_uri.scheme in ("file", ):
# local files -- both read & write supported
# compression, if any, is determined by the filename extension (.gz, .bz2)
return file_smart_open(parsed_uri.uri_path, mode)

if mode in ('r', 'rb'):
if parsed_uri.scheme in ("s3", "s3n"):
return S3OpenRead(parsed_uri, **kw)
elif parsed_uri.scheme in ("hdfs", ):
return HdfsOpenRead(parsed_uri, **kw)
elif parsed_uri.scheme in ("webhdfs", ):
return WebHdfsOpenRead(parsed_uri, **kw)
else:
raise NotImplementedError("read mode not supported for %r scheme", parsed_uri.scheme)
elif mode in ('w', 'wb'):
if parsed_uri.scheme in ("s3", "s3n"):
# validate mode parameter
if not isinstance(mode, six.string_types):
raise TypeError('mode should be a string')
if not mode in ('r', 'rb', 'w', 'wb'):
raise NotImplementedError('unknown file mode %s' % mode)

if isinstance(uri, six.string_types):
Contributor:
Rename uri input parameter into uri_or_key as it can be either now

Contributor Author:
That would create a rather serious backwards-compatibility problem, wouldn't it? If any existing code calls smart_open(uri='...'), this renaming will stop it from working. So however logically consistent your suggestion is, I would recommend against it unless we do a major version bump for smart_open.

What do you think?

Contributor:
Sure!
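The backwards-compatibility concern, as a sketch (the caller code below is hypothetical):

```python
from smart_open import smart_open

# Existing user code that passes the argument by keyword:
fin = smart_open(uri="s3://mybucket/mykey", mode="rb")

# If the parameter were renamed to `uri_or_key`, the call above would raise
#   TypeError: smart_open() got an unexpected keyword argument 'uri'
# which is why the rename is deferred to a major version bump.
```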

# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
parsed_uri = ParseUri(uri)

if parsed_uri.scheme in ("file", ):
# local files -- both read & write supported
# compression, if any, is determined by the filename extension (.gz, .bz2)
return file_smart_open(parsed_uri.uri_path, mode)
elif parsed_uri.scheme in ("s3", "s3n"):
s3_connection = boto.connect_s3(aws_access_key_id=parsed_uri.access_id, aws_secret_access_key=parsed_uri.access_secret)
outbucket = s3_connection.get_bucket(parsed_uri.bucket_id)
outkey = boto.s3.key.Key(outbucket)
outkey.key = parsed_uri.key_id
return S3OpenWrite(outbucket, outkey, **kw)
bucket = s3_connection.get_bucket(parsed_uri.bucket_id)
if mode in ('r', 'rb'):
key = bucket.get_key(parsed_uri.key_id)
if key is None:
raise KeyError(parsed_uri.key_id)
return S3OpenRead(key, **kw)
else:
key = bucket.get_key(parsed_uri.key_id, validate=False)
if key is None:
raise KeyError(parsed_uri.key_id)
return S3OpenWrite(key, **kw)
elif parsed_uri.scheme in ("hdfs", ):
if mode in ('r', 'rb'):
return HdfsOpenRead(parsed_uri, **kw)
else:
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
elif parsed_uri.scheme in ("webhdfs", ):
return WebHdfsOpenWrite(parsed_uri, **kw)
if mode in ('r', 'rb'):
return WebHdfsOpenRead(parsed_uri, **kw)
else:
return WebHdfsOpenWrite(parsed_uri, **kw)
else:
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme)
elif isinstance(uri, boto.s3.key.Key):
# handle case where we are given an S3 key directly
if mode in ('r', 'rb'):
return S3OpenRead(uri)
elif mode in ('w', 'wb'):
return S3OpenWrite(uri)
elif hasattr(uri, 'read'):
# simply pass-through if already a file-like
return uri
else:
raise NotImplementedError("unknown file mode %s" % mode)
raise TypeError('don\'t know how to handle uri %s' % repr(uri))


class ParseUri(object):
@@ -203,25 +220,16 @@ class S3OpenRead(object):
Implement streamed reader from S3, as an iterable & context manager.

"""
def __init__(self, parsed_uri):
if parsed_uri.scheme not in ("s3", "s3n"):
raise TypeError("can only process S3 files")
self.parsed_uri = parsed_uri
s3_connection = boto.connect_s3(
aws_access_key_id=parsed_uri.access_id,
aws_secret_access_key=parsed_uri.access_secret)
self.read_key = s3_connection.get_bucket(parsed_uri.bucket_id).lookup(parsed_uri.key_id)
if self.read_key is None:
raise KeyError(parsed_uri.key_id)
def __init__(self, read_key):
if not isinstance(read_key, boto.s3.key.Key):
Contributor:
This check is redundant, since the type has already been checked in smart_open() before the constructor is called, and it needs to be removed.
If you remove it, the error in test_s3_boto will disappear. The mock library is designed for duck typing, which is why this check makes the test fail.

Contributor Author:
I added this because S3OpenRead is a class that could be used directly by users of the library, so I don't think we can assume the type checks in smart_open will have been performed already. Isn't there a way to check if this was an instance of boto.s3.key.Key or the corresponding mock library class?

Contributor:
The mock object is of class MagicMock, so isinstance won't work.
For Python duck typing you can use hasattr to check whether read_key has name, close and the other attributes that are actually used.

Contributor Author:
Ok, so I added the hasattr tests and now the tests are passing. Thank you!
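A sketch of what the hasattr-based constructor check might look like; the exact attribute list is an assumption, since the follow-up commit is not part of this two-commit view:

```python
def __init__(self, read_key):
    # Duck-typed check: accept anything that looks like a boto.s3.key.Key,
    # including the MagicMock objects used in the unit tests.
    if not (hasattr(read_key, "bucket") and hasattr(read_key, "name") and
            hasattr(read_key, "read") and hasattr(read_key, "close")):
        raise TypeError("can only process S3 keys (or key-like objects)")
    self.read_key = read_key
    self.line_generator = s3_iter_lines(self.read_key)
```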

raise TypeError("can only process S3 keys")
self.read_key = read_key
self.line_generator = s3_iter_lines(self.read_key)

def __iter__(self):
s3_connection = boto.connect_s3(
aws_access_key_id=self.parsed_uri.access_id,
aws_secret_access_key=self.parsed_uri.access_secret)
key = s3_connection.get_bucket(self.parsed_uri.bucket_id).lookup(self.parsed_uri.key_id)
key = self.read_key.bucket.get_key(self.read_key.name)
if key is None:
raise KeyError(self.parsed_uri.key_id)
raise KeyError(self.read_key.name)

return s3_iter_lines(key)

@@ -257,6 +265,12 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
self.read_key.close()

def __str__(self):
return "%s<key: %s>" % (
self.__class__.__name__, self.read_key
)



class HdfsOpenRead(object):
"""
@@ -384,22 +398,21 @@ class S3OpenWrite(object):
Context manager for writing into S3 files.

"""
def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
def __init__(self, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
"""
Streamed input is uploaded in chunks, as soon as `min_part_size` bytes are
accumulated (50MB by default). The minimum chunk size allowed by AWS S3
is 5MB.

"""
self.outbucket = outbucket
self.outkey = outkey
self.min_part_size = min_part_size

if min_part_size < 5 * 1024 ** 2:
logger.warning("S3 requires minimum part size >= 5MB; multipart upload may fail")

# initialize mulitpart upload
self.mp = self.outbucket.initiate_multipart_upload(self.outkey, **kw)
self.mp = self.outkey.bucket.initiate_multipart_upload(self.outkey, **kw)

# initialize stats
self.lines = []
@@ -408,8 +421,8 @@ def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
self.parts = 0

def __str__(self):
return "%s<bucket: %s, key: %s, min_part_size: %s>" % (
self.__class__.__name__, self.outbucket, self.outkey, self.min_part_size,
return "%s<key: %s, min_part_size: %s>" % (
self.__class__.__name__, self.outkey, self.min_part_size,
)

def write(self, b):
@@ -456,14 +469,14 @@ def close(self):
# when the input is completely empty => abort the upload, no file created
# TODO: or create the empty file some other way?
logger.info("empty input, ignoring multipart upload")
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)

def __enter__(self):
return self

def _termination_error(self):
logger.exception("encountered error while terminating multipart upload; attempting cancel")
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
logger.info("cancel completed")

def __exit__(self, type, value, traceback):
60 changes: 32 additions & 28 deletions smart_open/tests/test_smart_open.py
@@ -131,26 +131,26 @@ def test_webhdfs_read(self):
smart_open_object = smart_open.WebHdfsOpenRead(smart_open.ParseUri("webhdfs://127.0.0.1:8440/path/file"))
self.assertEqual(smart_open_object.read().decode("utf-8"), "line1\nline2")

@mock.patch('smart_open.smart_open_lib.boto')
@mock.patch('smart_open.smart_open_lib.s3_iter_lines')
def test_s3_boto(self, mock_s3_iter_lines, mock_boto):
"""Is S3 line iterator called correctly?"""
# no credentials
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object.__iter__()
mock_boto.connect_s3.assert_called_with(aws_access_key_id=None, aws_secret_access_key=None)

# with credential
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
smart_open_object.__iter__()
mock_boto.connect_s3.assert_called_with(aws_access_key_id="access_id", aws_secret_access_key="access_secret")

# lookup bucket, key; call s3_iter_lines
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
smart_open_object.__iter__()
mock_boto.connect_s3().get_bucket.assert_called_with("mybucket")
mock_boto.connect_s3().get_bucket().lookup.assert_called_with("mykey")
self.assertTrue(mock_s3_iter_lines.called)
# @mock.patch('smart_open.smart_open_lib.boto')
# @mock.patch('smart_open.smart_open_lib.s3_iter_lines')
# def test_s3_boto(self, mock_s3_iter_lines, mock_boto):
# """Is S3 line iterator called correctly?"""
# # no credentials
# smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
# smart_open_object.__iter__()
# mock_boto.connect_s3.assert_called_with(aws_access_key_id=None, aws_secret_access_key=None)
#
# # with credential
# smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
# smart_open_object.__iter__()
# mock_boto.connect_s3.assert_called_with(aws_access_key_id="access_id", aws_secret_access_key="access_secret")
#
# # lookup bucket, key; call s3_iter_lines
# smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
# smart_open_object.__iter__()
# mock_boto.connect_s3().get_bucket.assert_called_with("mybucket")
# mock_boto.connect_s3().get_bucket().lookup.assert_called_with("mykey")
Contributor:
Please update the test, as you are now using get_key instead of lookup.
lookup is deprecated and the code has switched to get_key.

Contributor Author:
Nice catch, thank you!

# self.assertTrue(mock_s3_iter_lines.called)
Contributor:
Please uncomment the test, as we need to merge it in.
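Once re-enabled, the test would presumably look something like the sketch below, with the assertions tracking the switch from the deprecated `lookup()` to `get_key()` (not the final test, just an illustration):

```python
@mock.patch('smart_open.smart_open_lib.boto')
@mock.patch('smart_open.smart_open_lib.s3_iter_lines')
def test_s3_boto(self, mock_s3_iter_lines, mock_boto):
    """Is the S3 line iterator called correctly?"""
    smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
    smart_open_object.__iter__()

    # credentials embedded in the URI are passed through to boto
    mock_boto.connect_s3.assert_called_with(
        aws_access_key_id="access_id", aws_secret_access_key="access_secret")

    # the key is now resolved via get_key(), not the deprecated lookup()
    mock_boto.connect_s3().get_bucket.assert_called_with("mybucket")
    mock_boto.connect_s3().get_bucket().get_key.assert_called_with("mykey")
    self.assertTrue(mock_s3_iter_lines.called)
```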



@mock_s3
@@ -176,12 +176,12 @@ def test_s3_iter_moto(self):
fout.write(expected[-1])

# connect to fake s3 and read from the fake key we filled above
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
output = [line.rstrip(b'\n') for line in smart_open_object]
self.assertEqual(output, expected)

# same thing but using a context manager
with smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) as smart_open_object:
with smart_open.smart_open("s3://mybucket/mykey") as smart_open_object:
output = [line.rstrip(b'\n') for line in smart_open_object]
self.assertEqual(output, expected)

@@ -196,7 +196,7 @@ def test_s3_read_moto(self):
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
fout.write(content)

smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
self.assertEqual(content[:6], smart_open_object.read(6))
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes

@@ -216,7 +216,7 @@ def test_s3_seek_moto(self):
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
fout.write(content)

smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
self.assertEqual(content[:6], smart_open_object.read(6))
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes

@@ -344,10 +344,11 @@ def test_write_01(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket
test_string = u"žluťoučký koníček".encode('utf8')

# write into key
with smart_open.S3OpenWrite(mybucket, mykey) as fin:
with smart_open.S3OpenWrite(mykey) as fin:
fin.write(test_string)

# read key and test content
@@ -363,10 +364,11 @@ def test_write_01a(self):
conn.create_bucket("mybucket")
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.bucket = mybucket
mykey.name = "testkey"

try:
with smart_open.S3OpenWrite(mybucket, mykey) as fin:
with smart_open.S3OpenWrite(mykey) as fin:
fin.write(None)
except TypeError:
pass
@@ -383,8 +385,9 @@ def test_write_02(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket

smart_open_write = smart_open.S3OpenWrite(mybucket, mykey)
smart_open_write = smart_open.S3OpenWrite(mykey)
with smart_open_write as fin:
fin.write(u"testžížáč")
self.assertEqual(fin.total_size, 14)
@@ -399,9 +402,10 @@ def test_write_03(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket

# write
smart_open_write = smart_open.S3OpenWrite(mybucket, mykey, min_part_size=10)
smart_open_write = smart_open.S3OpenWrite(mykey, min_part_size=10)
with smart_open_write as fin:
fin.write(u"test") # implicit unicode=>utf8 conversion
self.assertEqual(fin.chunk_bytes, 4)