PR #38: Allowing boto keys to be passed directly to smart_open (#35)
Changes from all commits: c51f8dc, 76c5cf1, 0ab2d9e, 43e3192
Diff 1: the main smart_open module
@@ -63,11 +63,12 @@ def smart_open(uri, mode="rb", **kw):

     The `uri` can be either:

-    1. local filesystem (compressed ``.gz`` or ``.bz2`` files handled automatically):
+    1. a URI for the local filesystem (compressed ``.gz`` or ``.bz2`` files handled automatically):
        `./lines.txt`, `/home/joe/lines.txt.gz`, `file:///home/joe/lines.txt.bz2`
-    2. Amazon's S3 (can also supply credentials inside the URI):
+    2. a URI for HDFS: `hdfs:///some/path/lines.txt`
+    3. a URI for Amazon's S3 (can also supply credentials inside the URI):
        `s3://my_bucket/lines.txt`, `s3://my_aws_key_id:key_secret@my_bucket/lines.txt`
-    3. HDFS: `hdfs:///some/path/lines.txt`
+    4. an instance of the boto.s3.key.Key class.

     Examples::
@@ -76,6 +77,12 @@ def smart_open(uri, mode="rb", **kw):
     ... for line in fin:
     ...     print line

+    >>> # you can also use a boto.s3.key.Key instance directly:
+    >>> key = boto.connect_s3().get_bucket("my_bucket").get_key("my_key")
+    >>> with smart_open.smart_open(key) as fin:
+    ...     for line in fin:
+    ...         print line
+
     >>> # stream line-by-line from an HDFS file
     >>> for line in smart_open.smart_open('hdfs:///user/hadoop/my_file.txt'):
     ...     print line
@@ -96,41 +103,58 @@ def smart_open(uri, mode="rb", **kw):
     ... fout.write("good bye!\n")

     """
-    # simply pass-through if already a file-like
-    if not isinstance(uri, six.string_types) and hasattr(uri, 'read'):
-        return uri
-
-    # this method just routes the request to classes handling the specific storage
-    # schemes, depending on the URI protocol in `uri`
-    parsed_uri = ParseUri(uri)
-
-    if parsed_uri.scheme in ("file", ):
-        # local files -- both read & write supported
-        # compression, if any, is determined by the filename extension (.gz, .bz2)
-        return file_smart_open(parsed_uri.uri_path, mode)
-
-    if mode in ('r', 'rb'):
-        if parsed_uri.scheme in ("s3", "s3n"):
-            return S3OpenRead(parsed_uri, **kw)
-        elif parsed_uri.scheme in ("hdfs", ):
-            return HdfsOpenRead(parsed_uri, **kw)
-        elif parsed_uri.scheme in ("webhdfs", ):
-            return WebHdfsOpenRead(parsed_uri, **kw)
-        else:
-            raise NotImplementedError("read mode not supported for %r scheme", parsed_uri.scheme)
-    elif mode in ('w', 'wb'):
-        if parsed_uri.scheme in ("s3", "s3n"):
-            s3_connection = boto.connect_s3(aws_access_key_id=parsed_uri.access_id, aws_secret_access_key=parsed_uri.access_secret)
-            outbucket = s3_connection.get_bucket(parsed_uri.bucket_id)
-            outkey = boto.s3.key.Key(outbucket)
-            outkey.key = parsed_uri.key_id
-            return S3OpenWrite(outbucket, outkey, **kw)
-        elif parsed_uri.scheme in ("webhdfs", ):
-            return WebHdfsOpenWrite(parsed_uri, **kw)
-        else:
-            raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
-    else:
-        raise NotImplementedError("unknown file mode %s" % mode)
+    # validate mode parameter
+    if not isinstance(mode, six.string_types):
+        raise TypeError('mode should be a string')
+    if not mode in ('r', 'rb', 'w', 'wb'):
+        raise NotImplementedError('unknown file mode %s' % mode)
+
+    if isinstance(uri, six.string_types):
+        # this method just routes the request to classes handling the specific storage
+        # schemes, depending on the URI protocol in `uri`
+        parsed_uri = ParseUri(uri)
+
+        if parsed_uri.scheme in ("file", ):
+            # local files -- both read & write supported
+            # compression, if any, is determined by the filename extension (.gz, .bz2)
+            return file_smart_open(parsed_uri.uri_path, mode)
+        elif parsed_uri.scheme in ("s3", "s3n"):
+            s3_connection = boto.connect_s3(aws_access_key_id=parsed_uri.access_id, aws_secret_access_key=parsed_uri.access_secret)
+            bucket = s3_connection.get_bucket(parsed_uri.bucket_id)
+            if mode in ('r', 'rb'):
+                key = bucket.get_key(parsed_uri.key_id)
+                if key is None:
+                    raise KeyError(parsed_uri.key_id)
+                return S3OpenRead(key, **kw)
+            else:
+                key = bucket.get_key(parsed_uri.key_id, validate=False)
+                if key is None:
+                    raise KeyError(parsed_uri.key_id)
+                return S3OpenWrite(key, **kw)
+        elif parsed_uri.scheme in ("hdfs", ):
+            if mode in ('r', 'rb'):
+                return HdfsOpenRead(parsed_uri, **kw)
+            else:
+                raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
+        elif parsed_uri.scheme in ("webhdfs", ):
+            if mode in ('r', 'rb'):
+                return WebHdfsOpenRead(parsed_uri, **kw)
+            else:
+                return WebHdfsOpenWrite(parsed_uri, **kw)
+        else:
+            raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme)
+    elif isinstance(uri, boto.s3.key.Key):
+        # handle case where we are given an S3 key directly
+        if mode in ('r', 'rb'):
+            return S3OpenRead(uri)
+        elif mode in ('w', 'wb'):
+            return S3OpenWrite(uri)
+    elif hasattr(uri, 'read'):
+        # simply pass-through if already a file-like
+        return uri
+    else:
+        raise TypeError('don\'t know how to handle uri %s' % repr(uri))


 class ParseUri(object):

Review thread (on the `if isinstance(uri, six.string_types):` line):

Comment: Rename the `uri` input parameter into `uri_or_key` as it can be either now.
Reply: That could create a rather serious backwards compatibility problem, wouldn't it? If any existing code is calling … What do you think?
Reply: Sure!
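A minimal sketch (assuming this branch of smart_open plus six and boto are installed) of the new fail-fast behaviour: the mode checks run before any URI parsing or network access, so neither call below needs a real file or bucket.

    import smart_open

    # mode validation happens first, so nothing touches the filesystem or S3
    try:
        smart_open.smart_open("./lines.txt", mode="a")  # append is not supported
    except NotImplementedError as err:
        print(err)  # unknown file mode a

    try:
        smart_open.smart_open("./lines.txt", mode=123)  # non-string mode
    except TypeError as err:
        print(err)  # mode should be a string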
@@ -214,25 +238,17 @@ class S3OpenRead(object):
     Implement streamed reader from S3, as an iterable & context manager.

     """
-    def __init__(self, parsed_uri):
-        if parsed_uri.scheme not in ("s3", "s3n"):
-            raise TypeError("can only process S3 files")
-        self.parsed_uri = parsed_uri
-        s3_connection = boto.connect_s3(
-            aws_access_key_id=parsed_uri.access_id,
-            aws_secret_access_key=parsed_uri.access_secret)
-        self.read_key = s3_connection.get_bucket(parsed_uri.bucket_id).lookup(parsed_uri.key_id)
-        if self.read_key is None:
-            raise KeyError(parsed_uri.key_id)
+    def __init__(self, read_key):
+        if not hasattr(read_key, "bucket") and not hasattr(read_key, "name") and not hasattr(read_key, "read") \
+        and not hasattr(read_key, "close"):
+            raise TypeError("can only process S3 keys")
+        self.read_key = read_key
+        self.line_generator = s3_iter_lines(self.read_key)

     def __iter__(self):
-        s3_connection = boto.connect_s3(
-            aws_access_key_id=self.parsed_uri.access_id,
-            aws_secret_access_key=self.parsed_uri.access_secret)
-        key = s3_connection.get_bucket(self.parsed_uri.bucket_id).lookup(self.parsed_uri.key_id)
+        key = self.read_key.bucket.get_key(self.read_key.name)
         if key is None:
-            raise KeyError(self.parsed_uri.key_id)
+            raise KeyError(self.read_key.name)

         return s3_iter_lines(key)
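A hedged sketch of feeding `S3OpenRead` a key fetched with boto. The bucket and key names are placeholders, and the snippet assumes boto can find AWS credentials and that the key exists; per the dispatch above, `smart_open.smart_open(key)` would return the same reader.

    import boto
    import smart_open

    # fetch an existing key ("my_bucket"/"my_key" are placeholders) and hand
    # it to the reader directly; __iter__ re-resolves the key through
    # key.bucket / key.name before streaming lines
    key = boto.connect_s3().get_bucket("my_bucket").get_key("my_key")
    for line in smart_open.S3OpenRead(key):
        print(line)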
@@ -268,6 +284,12 @@ def __enter__(self):
     def __exit__(self, type, value, traceback):
         self.read_key.close()

+    def __str__(self):
+        return "%s<key: %s>" % (
+            self.__class__.__name__, self.read_key
+        )
+


 class HdfsOpenRead(object):
     """
@@ -395,22 +417,24 @@ class S3OpenWrite(object):
     Context manager for writing into S3 files.

     """
-    def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
+    def __init__(self, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
         """
         Streamed input is uploaded in chunks, as soon as `min_part_size` bytes are
         accumulated (50MB by default). The minimum chunk size allowed by AWS S3
         is 5MB.

         """
-        self.outbucket = outbucket
+        if not hasattr(outkey, "bucket") and not hasattr(outkey, "name"):
+            raise TypeError("can only process S3 keys")
+
         self.outkey = outkey
         self.min_part_size = min_part_size

         if min_part_size < 5 * 1024 ** 2:
             logger.warning("S3 requires minimum part size >= 5MB; multipart upload may fail")

         # initialize mulitpart upload
-        self.mp = self.outbucket.initiate_multipart_upload(self.outkey, **kw)
+        self.mp = self.outkey.bucket.initiate_multipart_upload(self.outkey, **kw)

         # initialize stats
         self.lines = []
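To illustrate the changed constructor: `S3OpenWrite` now takes only the key and reaches the bucket through `key.bucket`, as the test changes further down also show. A sketch with placeholder names, assuming the bucket exists and credentials are available to boto:

    import boto
    import boto.s3.key
    import smart_open

    bucket = boto.connect_s3().get_bucket("my_bucket")  # placeholder bucket
    outkey = boto.s3.key.Key(bucket)  # the key carries its own bucket reference
    outkey.key = "my_key"             # placeholder key name

    # min_part_size below 5MB only logs a warning; the upload may still fail
    with smart_open.S3OpenWrite(outkey, min_part_size=5 * 1024 ** 2) as fout:
        fout.write(u"hello world!\n")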
@@ -419,8 +443,8 @@ def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
         self.parts = 0

     def __str__(self):
-        return "%s<bucket: %s, key: %s, min_part_size: %s>" % (
-            self.__class__.__name__, self.outbucket, self.outkey, self.min_part_size,
+        return "%s<key: %s, min_part_size: %s>" % (
+            self.__class__.__name__, self.outkey, self.min_part_size,
         )

     def write(self, b):
@@ -467,14 +491,14 @@ def close(self):
             # when the input is completely empty => abort the upload, no file created
             # TODO: or create the empty file some other way?
             logger.info("empty input, ignoring multipart upload")
-            self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
+            self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)

     def __enter__(self):
         return self

     def _termination_error(self):
         logger.exception("encountered error while terminating multipart upload; attempting cancel")
-        self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
+        self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
         logger.info("cancel completed")

     def __exit__(self, type, value, traceback):
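Continuing the previous sketch: the empty-input path above means a writer that never receives data cancels its multipart upload through `outkey.bucket` instead of completing it.

    # continuing the S3OpenWrite sketch above (same placeholder outkey);
    # writing nothing between __enter__ and __exit__ aborts the multipart
    # upload, so no empty key appears in the bucket
    with smart_open.S3OpenWrite(outkey) as fout:
        pass  # no write() calls => close() cancels the upload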
Diff 2: the test suite
@@ -136,20 +136,20 @@ def test_webhdfs_read(self):
     def test_s3_boto(self, mock_s3_iter_lines, mock_boto):
         """Is S3 line iterator called correctly?"""
         # no credentials
-        smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
+        smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
         smart_open_object.__iter__()
         mock_boto.connect_s3.assert_called_with(aws_access_key_id=None, aws_secret_access_key=None)

         # with credential
-        smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
+        smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
         smart_open_object.__iter__()
         mock_boto.connect_s3.assert_called_with(aws_access_key_id="access_id", aws_secret_access_key="access_secret")

         # lookup bucket, key; call s3_iter_lines
-        smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
+        smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
         smart_open_object.__iter__()
         mock_boto.connect_s3().get_bucket.assert_called_with("mybucket")
-        mock_boto.connect_s3().get_bucket().lookup.assert_called_with("mykey")
+        mock_boto.connect_s3().get_bucket().get_key.assert_called_with("mykey")
         self.assertTrue(mock_s3_iter_lines.called)

Comment: Please uncomment the test as we need to merge them in.
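The hunk keeps the `(self, mock_s3_iter_lines, mock_boto)` signature, which implies mock.patch decorators that the diff does not show. A sketch of that setup; the class name and patch targets here are assumptions, not taken from the diff:

    import unittest
    import mock

    class SmartOpenReadTest(unittest.TestCase):  # illustrative class name
        # decorators apply bottom-up, so mock_s3_iter_lines arrives first
        @mock.patch('smart_open.boto')            # assumed patch target
        @mock.patch('smart_open.s3_iter_lines')   # assumed patch target
        def test_s3_boto(self, mock_s3_iter_lines, mock_boto):
            ...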
@@ -176,12 +176,12 @@ def test_s3_iter_moto(self):
             fout.write(expected[-1])

         # connect to fake s3 and read from the fake key we filled above
-        smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
+        smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
         output = [line.rstrip(b'\n') for line in smart_open_object]
         self.assertEqual(output, expected)

         # same thing but using a context manager
-        with smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) as smart_open_object:
+        with smart_open.smart_open("s3://mybucket/mykey") as smart_open_object:
             output = [line.rstrip(b'\n') for line in smart_open_object]
             self.assertEqual(output, expected)
@@ -196,7 +196,7 @@ def test_s3_read_moto(self):
         with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
             fout.write(content)

-        smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
+        smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
         self.assertEqual(content[:6], smart_open_object.read(6))
         self.assertEqual(content[6:14], smart_open_object.read(8))  # ř is 2 bytes
@@ -216,7 +216,7 @@ def test_s3_seek_moto(self):
         with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
             fout.write(content)

-        smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
+        smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
         self.assertEqual(content[:6], smart_open_object.read(6))
         self.assertEqual(content[6:14], smart_open_object.read(8))  # ř is 2 bytes
@@ -344,10 +344,11 @@ def test_write_01(self):
         mybucket = conn.get_bucket("mybucket")
         mykey = boto.s3.key.Key()
         mykey.name = "testkey"
+        mykey.bucket = mybucket
         test_string = u"žluťoučký koníček".encode('utf8')

         # write into key
-        with smart_open.S3OpenWrite(mybucket, mykey) as fin:
+        with smart_open.S3OpenWrite(mykey) as fin:
             fin.write(test_string)

         # read key and test content
@@ -363,10 +364,11 @@ def test_write_01a(self):
         conn.create_bucket("mybucket")
         mybucket = conn.get_bucket("mybucket")
         mykey = boto.s3.key.Key()
+        mykey.bucket = mybucket
         mykey.name = "testkey"

         try:
-            with smart_open.S3OpenWrite(mybucket, mykey) as fin:
+            with smart_open.S3OpenWrite(mykey) as fin:
                 fin.write(None)
         except TypeError:
             pass
@@ -383,8 +385,9 @@ def test_write_02(self):
         mybucket = conn.get_bucket("mybucket")
         mykey = boto.s3.key.Key()
         mykey.name = "testkey"
+        mykey.bucket = mybucket

-        smart_open_write = smart_open.S3OpenWrite(mybucket, mykey)
+        smart_open_write = smart_open.S3OpenWrite(mykey)
         with smart_open_write as fin:
             fin.write(u"testžížáč")
             self.assertEqual(fin.total_size, 14)
@@ -399,9 +402,10 @@ def test_write_03(self):
         mybucket = conn.get_bucket("mybucket")
         mykey = boto.s3.key.Key()
         mykey.name = "testkey"
+        mykey.bucket = mybucket

         # write
-        smart_open_write = smart_open.S3OpenWrite(mybucket, mykey, min_part_size=10)
+        smart_open_write = smart_open.S3OpenWrite(mykey, min_part_size=10)
         with smart_open_write as fin:
             fin.write(u"test")  # implicit unicode=>utf8 conversion
             self.assertEqual(fin.chunk_bytes, 4)
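All four write tests make the same adjustment: the key must carry a reference to its bucket before it is handed to `S3OpenWrite`. A condensed sketch of that pattern under moto's fake S3, mirroring the tests above:

    import boto
    import boto.s3.key
    import moto
    import smart_open

    @moto.mock_s3
    def write_roundtrip():
        conn = boto.connect_s3()
        conn.create_bucket("mybucket")
        mykey = boto.s3.key.Key()
        mykey.name = "testkey"
        mykey.bucket = conn.get_bucket("mybucket")  # required by the new API

        with smart_open.S3OpenWrite(mykey) as fout:
            fout.write(u"hello")  # implicit unicode=>utf8 conversion

    write_roundtrip()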
Comment: Add the new boto key usage example to the docstring.
Reply: Done.