Allowing boto keys to be passed directly to smart_open (#35) #38

Merged: 4 commits, Dec 18, 2015
9 changes: 6 additions & 3 deletions CHANGELOG.rst
@@ -1,8 +1,11 @@
* 1.3.1, 16th December 2015
* 1.4.0, 16th December 2015

- Updated smart_open to accept an instance of boto.s3.key.Key (PR #38, @asieira)

- Disable multiprocessing if unavailable. Allows to run on Google Compute Engine. (PR #41, @nikicc)
- Httpretty updated to allow LC_ALL=C locale config. (PR #39, @jsphpl)
* 1.3.1, 16th December 2015

- Disable multiprocessing if unavailable. Allows to run on Google Compute Engine. (PR #41, @nikicc)
- Httpretty updated to allow LC_ALL=C locale config. (PR #39, @jsphpl)

* 1.3.0, 19th September 2015

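The 1.4.0 entry above is the heart of this PR: smart_open can now be handed an already-constructed boto.s3.key.Key instead of an S3 URI. A minimal sketch of the two call styles, assuming a hypothetical bucket my_bucket and credentials available to boto in the usual way:

    import boto
    import smart_open

    # Style 1: let smart_open parse the URI and construct the boto key itself
    # (credentials may also be embedded in the URI).
    with smart_open.smart_open('s3://my_bucket/lines.txt') as fin:
        for line in fin:
            print(line)

    # Style 2 (new in 1.4.0): build the key yourself, e.g. from a connection
    # carrying non-default credentials, and pass it in directly.
    conn = boto.connect_s3(aws_access_key_id='my_key_id',
                           aws_secret_access_key='my_key_secret')
    key = conn.get_bucket('my_bucket').get_key('lines.txt')
    with smart_open.smart_open(key) as fin:
        for line in fin:
            print(line)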
134 changes: 79 additions & 55 deletions smart_open/smart_open_lib.py
@@ -63,11 +63,12 @@ def smart_open(uri, mode="rb", **kw):

The `uri` can be either:

1. local filesystem (compressed ``.gz`` or ``.bz2`` files handled automatically):
1. a URI for the local filesystem (compressed ``.gz`` or ``.bz2`` files handled automatically):
`./lines.txt`, `/home/joe/lines.txt.gz`, `file:///home/joe/lines.txt.bz2`
2. Amazon's S3 (can also supply credentials inside the URI):
2. a URI for HDFS: `hdfs:///some/path/lines.txt`
3. a URI for Amazon's S3 (can also supply credentials inside the URI):
`s3://my_bucket/lines.txt`, `s3://my_aws_key_id:key_secret@my_bucket/lines.txt`
3. HDFS: `hdfs:///some/path/lines.txt`
4. an instance of the boto.s3.key.Key class.

Examples::

@@ -76,6 +77,12 @@ def smart_open(uri, mode="rb", **kw):
... for line in fin:
... print line

>>> # you can also use a boto.s3.key.Key instance directly:
>>> key = boto.connect_s3().get_bucket("my_bucket").get_key("my_key")
>>> with smart_open.smart_open(key) as fin:
... for line in fin:
... print line

>>> # stream line-by-line from an HDFS file
>>> for line in smart_open.smart_open('hdfs:///user/hadoop/my_file.txt'):
... print line
@@ -96,41 +103,58 @@ def smart_open(uri, mode="rb", **kw):
... fout.write("good bye!\n")
Contributor: Add the new boto key usage example to the docstring

Contributor Author: Done.


"""
# simply pass-through if already a file-like
if not isinstance(uri, six.string_types) and hasattr(uri, 'read'):
return uri

# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
parsed_uri = ParseUri(uri)

if parsed_uri.scheme in ("file", ):
# local files -- both read & write supported
# compression, if any, is determined by the filename extension (.gz, .bz2)
return file_smart_open(parsed_uri.uri_path, mode)

if mode in ('r', 'rb'):
if parsed_uri.scheme in ("s3", "s3n"):
return S3OpenRead(parsed_uri, **kw)
elif parsed_uri.scheme in ("hdfs", ):
return HdfsOpenRead(parsed_uri, **kw)
elif parsed_uri.scheme in ("webhdfs", ):
return WebHdfsOpenRead(parsed_uri, **kw)
else:
raise NotImplementedError("read mode not supported for %r scheme", parsed_uri.scheme)
elif mode in ('w', 'wb'):
if parsed_uri.scheme in ("s3", "s3n"):
# validate mode parameter
if not isinstance(mode, six.string_types):
raise TypeError('mode should be a string')
if not mode in ('r', 'rb', 'w', 'wb'):
raise NotImplementedError('unknown file mode %s' % mode)

if isinstance(uri, six.string_types):
Contributor: Rename the uri input parameter to uri_or_key, as it can be either now.

Contributor Author: That could create a rather serious backwards-compatibility problem, couldn't it? If any existing code calls smart_open(uri='...'), the rename would break it. So however logically consistent your suggestion is, I would recommend against it unless we do a major version bump for smart_open.

What do you think?

Contributor: Sure!
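A quick, purely hypothetical illustration of the compatibility concern raised above: because the signature accepts **kw, renaming the first parameter silently swallows the old keyword and the call fails, even though positional callers keep working.

    def smart_open_renamed(uri_or_key, mode="rb", **kw):
        """Stand-in for smart_open with the parameter renamed as proposed."""
        return uri_or_key, mode

    # Positional callers are unaffected by the rename:
    smart_open_renamed("s3://my_bucket/lines.txt")

    # Existing keyword callers break: 'uri' lands in **kw and
    # uri_or_key is never supplied, so the call raises TypeError.
    try:
        smart_open_renamed(uri="s3://my_bucket/lines.txt")
    except TypeError as err:
        print(err)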

# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
parsed_uri = ParseUri(uri)

if parsed_uri.scheme in ("file", ):
# local files -- both read & write supported
# compression, if any, is determined by the filename extension (.gz, .bz2)
return file_smart_open(parsed_uri.uri_path, mode)
elif parsed_uri.scheme in ("s3", "s3n"):
s3_connection = boto.connect_s3(aws_access_key_id=parsed_uri.access_id, aws_secret_access_key=parsed_uri.access_secret)
outbucket = s3_connection.get_bucket(parsed_uri.bucket_id)
outkey = boto.s3.key.Key(outbucket)
outkey.key = parsed_uri.key_id
return S3OpenWrite(outbucket, outkey, **kw)
bucket = s3_connection.get_bucket(parsed_uri.bucket_id)
if mode in ('r', 'rb'):
key = bucket.get_key(parsed_uri.key_id)
if key is None:
raise KeyError(parsed_uri.key_id)
return S3OpenRead(key, **kw)
else:
key = bucket.get_key(parsed_uri.key_id, validate=False)
if key is None:
raise KeyError(parsed_uri.key_id)
return S3OpenWrite(key, **kw)
elif parsed_uri.scheme in ("hdfs", ):
if mode in ('r', 'rb'):
return HdfsOpenRead(parsed_uri, **kw)
else:
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
elif parsed_uri.scheme in ("webhdfs", ):
return WebHdfsOpenWrite(parsed_uri, **kw)
if mode in ('r', 'rb'):
return WebHdfsOpenRead(parsed_uri, **kw)
else:
return WebHdfsOpenWrite(parsed_uri, **kw)
else:
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme)
elif isinstance(uri, boto.s3.key.Key):
# handle case where we are given an S3 key directly
if mode in ('r', 'rb'):
return S3OpenRead(uri)
elif mode in ('w', 'wb'):
return S3OpenWrite(uri)
elif hasattr(uri, 'read'):
# simply pass-through if already a file-like
return uri
else:
raise NotImplementedError("unknown file mode %s" % mode)
raise TypeError('don\'t know how to handle uri %s' % repr(uri))
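In the rewritten routing above, smart_open now dispatches on the type of its first argument: a string is parsed as a URI and routed by scheme, a boto.s3.key.Key is wrapped directly in S3OpenRead or S3OpenWrite, and anything exposing a read attribute is passed through untouched. A small sketch of the three paths; the bucket and key names are hypothetical and assumed to exist:

    import io

    import boto
    import smart_open

    # 1. A string: parsed by ParseUri and routed by scheme ("file", "s3", "hdfs", ...).
    fin = smart_open.smart_open("s3://my_bucket/lines.txt", "rb")

    # 2. A boto key: used as-is, no URI parsing and no new S3 connection.
    key = boto.connect_s3().get_bucket("my_bucket").get_key("lines.txt")
    fin_from_key = smart_open.smart_open(key, "rb")

    # 3. Anything file-like: returned unchanged.
    buf = io.BytesIO(b"already open\n")
    assert smart_open.smart_open(buf) is buf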


class ParseUri(object):
@@ -214,25 +238,17 @@ class S3OpenRead(object):
Implement streamed reader from S3, as an iterable & context manager.

"""
def __init__(self, parsed_uri):
if parsed_uri.scheme not in ("s3", "s3n"):
raise TypeError("can only process S3 files")
self.parsed_uri = parsed_uri
s3_connection = boto.connect_s3(
aws_access_key_id=parsed_uri.access_id,
aws_secret_access_key=parsed_uri.access_secret)
self.read_key = s3_connection.get_bucket(parsed_uri.bucket_id).lookup(parsed_uri.key_id)
if self.read_key is None:
raise KeyError(parsed_uri.key_id)
def __init__(self, read_key):
if not hasattr(read_key, "bucket") and not hasattr(read_key, "name") and not hasattr(read_key, "read") \
and not hasattr(read_key, "close"):
raise TypeError("can only process S3 keys")
self.read_key = read_key
self.line_generator = s3_iter_lines(self.read_key)

def __iter__(self):
s3_connection = boto.connect_s3(
aws_access_key_id=self.parsed_uri.access_id,
aws_secret_access_key=self.parsed_uri.access_secret)
key = s3_connection.get_bucket(self.parsed_uri.bucket_id).lookup(self.parsed_uri.key_id)
key = self.read_key.bucket.get_key(self.read_key.name)
if key is None:
raise KeyError(self.parsed_uri.key_id)
raise KeyError(self.read_key.name)

return s3_iter_lines(key)
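S3OpenRead now receives the boto key itself rather than a parsed URI, and merely duck-type-checks that the object looks like a key (bucket, name, read, close). A sketch of constructing it directly, with a hypothetical bucket and key and boto credentials coming from the environment:

    import boto
    import smart_open

    key = boto.connect_s3().get_bucket("my_bucket").get_key("lines.txt")

    with smart_open.S3OpenRead(key) as fin:
        for line in fin:
            print(line.rstrip(b"\n"))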

@@ -268,6 +284,12 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
self.read_key.close()

def __str__(self):
return "%s<key: %s>" % (
self.__class__.__name__, self.read_key
)



class HdfsOpenRead(object):
"""
@@ -395,22 +417,24 @@ class S3OpenWrite(object):
Context manager for writing into S3 files.

"""
def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
def __init__(self, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
"""
Streamed input is uploaded in chunks, as soon as `min_part_size` bytes are
accumulated (50MB by default). The minimum chunk size allowed by AWS S3
is 5MB.

"""
self.outbucket = outbucket
if not hasattr(outkey, "bucket") and not hasattr(outkey, "name"):
raise TypeError("can only process S3 keys")

self.outkey = outkey
self.min_part_size = min_part_size

if min_part_size < 5 * 1024 ** 2:
logger.warning("S3 requires minimum part size >= 5MB; multipart upload may fail")

# initialize multipart upload
self.mp = self.outbucket.initiate_multipart_upload(self.outkey, **kw)
self.mp = self.outkey.bucket.initiate_multipart_upload(self.outkey, **kw)

# initialize stats
self.lines = []
@@ -419,8 +443,8 @@ def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
self.parts = 0

def __str__(self):
return "%s<bucket: %s, key: %s, min_part_size: %s>" % (
self.__class__.__name__, self.outbucket, self.outkey, self.min_part_size,
return "%s<key: %s, min_part_size: %s>" % (
self.__class__.__name__, self.outkey, self.min_part_size,
)

def write(self, b):
@@ -467,14 +491,14 @@ def close(self):
# when the input is completely empty => abort the upload, no file created
# TODO: or create the empty file some other way?
logger.info("empty input, ignoring multipart upload")
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)

def __enter__(self):
return self

def _termination_error(self):
logger.exception("encountered error while terminating multipart upload; attempting cancel")
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
logger.info("cancel completed")

def __exit__(self, type, value, traceback):
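S3OpenWrite likewise now takes only the destination key and reaches the bucket through key.bucket to start the multipart upload; written data is buffered and uploaded as soon as min_part_size bytes have accumulated (50MB by default, 5MB being the S3 minimum). A sketch with hypothetical names, using the default part size:

    import boto
    import smart_open

    bucket = boto.connect_s3().get_bucket("my_bucket")
    key = bucket.new_key("output/lines.txt")   # new_key() sets key.bucket for us

    with smart_open.S3OpenWrite(key, min_part_size=50 * 1024 ** 2) as fout:
        for i in range(1000):
            fout.write(u"line %d\n" % i)       # unicode is encoded to utf8 before upload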
28 changes: 16 additions & 12 deletions smart_open/tests/test_smart_open.py
@@ -136,20 +136,20 @@ def test_webhdfs_read(self):
def test_s3_boto(self, mock_s3_iter_lines, mock_boto):
"""Is S3 line iterator called correctly?"""
# no credentials
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
smart_open_object.__iter__()
mock_boto.connect_s3.assert_called_with(aws_access_key_id=None, aws_secret_access_key=None)

# with credential
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
smart_open_object.__iter__()
mock_boto.connect_s3.assert_called_with(aws_access_key_id="access_id", aws_secret_access_key="access_secret")

# lookup bucket, key; call s3_iter_lines
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
smart_open_object.__iter__()
mock_boto.connect_s3().get_bucket.assert_called_with("mybucket")
mock_boto.connect_s3().get_bucket().lookup.assert_called_with("mykey")
mock_boto.connect_s3().get_bucket().get_key.assert_called_with("mykey")
self.assertTrue(mock_s3_iter_lines.called)
Contributor: Please uncomment the tests, as we need to merge them in.



@@ -176,12 +176,12 @@ def test_s3_iter_moto(self):
fout.write(expected[-1])

# connect to fake s3 and read from the fake key we filled above
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
output = [line.rstrip(b'\n') for line in smart_open_object]
self.assertEqual(output, expected)

# same thing but using a context manager
with smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) as smart_open_object:
with smart_open.smart_open("s3://mybucket/mykey") as smart_open_object:
output = [line.rstrip(b'\n') for line in smart_open_object]
self.assertEqual(output, expected)

@@ -196,7 +196,7 @@ def test_s3_read_moto(self):
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
fout.write(content)

smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
self.assertEqual(content[:6], smart_open_object.read(6))
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes

@@ -216,7 +216,7 @@ def test_s3_seek_moto(self):
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
fout.write(content)

smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
self.assertEqual(content[:6], smart_open_object.read(6))
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes

@@ -344,10 +344,11 @@ def test_write_01(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket
test_string = u"žluťoučký koníček".encode('utf8')

# write into key
with smart_open.S3OpenWrite(mybucket, mykey) as fin:
with smart_open.S3OpenWrite(mykey) as fin:
fin.write(test_string)

# read key and test content
@@ -363,10 +364,11 @@ def test_write_01a(self):
conn.create_bucket("mybucket")
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.bucket = mybucket
mykey.name = "testkey"

try:
with smart_open.S3OpenWrite(mybucket, mykey) as fin:
with smart_open.S3OpenWrite(mykey) as fin:
fin.write(None)
except TypeError:
pass
@@ -383,8 +385,9 @@ def test_write_02(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket

smart_open_write = smart_open.S3OpenWrite(mybucket, mykey)
smart_open_write = smart_open.S3OpenWrite(mykey)
with smart_open_write as fin:
fin.write(u"testžížáč")
self.assertEqual(fin.total_size, 14)
@@ -399,9 +402,10 @@ def test_write_03(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket

# write
smart_open_write = smart_open.S3OpenWrite(mybucket, mykey, min_part_size=10)
smart_open_write = smart_open.S3OpenWrite(mykey, min_part_size=10)
with smart_open_write as fin:
fin.write(u"test") # implicit unicode=>utf8 conversion
self.assertEqual(fin.chunk_bytes, 4)
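Because S3OpenWrite no longer receives the bucket as a separate argument, the hand-built keys in the write tests above now set key.bucket explicitly. A minimal sketch of that round trip using moto's 2015-era mock_s3 decorator; the bucket and key names are illustrative:

    import boto
    import moto
    import smart_open
    from boto.s3.key import Key

    @moto.mock_s3
    def roundtrip_example():
        conn = boto.connect_s3()
        conn.create_bucket("mybucket")
        mybucket = conn.get_bucket("mybucket")

        mykey = Key()
        mykey.name = "testkey"
        mykey.bucket = mybucket      # required: the bucket is now taken from the key

        with smart_open.S3OpenWrite(mykey) as fout:
            fout.write(u"žluťoučký koníček".encode('utf8'))

        # read it back through the high-level entry point
        with smart_open.smart_open("s3://mybucket/testkey") as fin:
            return [line for line in fin]

    print(roundtrip_example())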