-
-
Notifications
You must be signed in to change notification settings - Fork 379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allowing boto keys to be passed directly to smart_open (#35) #38
Changes from 2 commits
c51f8dc
76c5cf1
0ab2d9e
43e3192
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,41 +85,58 @@ def smart_open(uri, mode="rb", **kw): | |
... fout.write("good bye!\n") | ||
|
||
""" | ||
# simply pass-through if already a file-like | ||
if not isinstance(uri, six.string_types) and hasattr(uri, 'read'): | ||
return uri | ||
|
||
# this method just routes the request to classes handling the specific storage | ||
# schemes, depending on the URI protocol in `uri` | ||
parsed_uri = ParseUri(uri) | ||
|
||
if parsed_uri.scheme in ("file", ): | ||
# local files -- both read & write supported | ||
# compression, if any, is determined by the filename extension (.gz, .bz2) | ||
return file_smart_open(parsed_uri.uri_path, mode) | ||
|
||
if mode in ('r', 'rb'): | ||
if parsed_uri.scheme in ("s3", "s3n"): | ||
return S3OpenRead(parsed_uri, **kw) | ||
elif parsed_uri.scheme in ("hdfs", ): | ||
return HdfsOpenRead(parsed_uri, **kw) | ||
elif parsed_uri.scheme in ("webhdfs", ): | ||
return WebHdfsOpenRead(parsed_uri, **kw) | ||
else: | ||
raise NotImplementedError("read mode not supported for %r scheme", parsed_uri.scheme) | ||
elif mode in ('w', 'wb'): | ||
if parsed_uri.scheme in ("s3", "s3n"): | ||
# validate mode parameter | ||
if not isinstance(mode, six.string_types): | ||
raise TypeError('mode should be a string') | ||
if not mode in ('r', 'rb', 'w', 'wb'): | ||
raise NotImplementedError('unknown file mode %s' % mode) | ||
|
||
if isinstance(uri, six.string_types): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Rename the uri input parameter into uri_or_key, as it can be either now. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That could create a rather serious backwards compatibility problem, wouldn't it? If any existing code is calling smart_open with uri as a keyword argument, renaming the parameter would break it. What do you think? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure! |
||
# this method just routes the request to classes handling the specific storage | ||
# schemes, depending on the URI protocol in `uri` | ||
parsed_uri = ParseUri(uri) | ||
|
||
if parsed_uri.scheme in ("file", ): | ||
# local files -- both read & write supported | ||
# compression, if any, is determined by the filename extension (.gz, .bz2) | ||
return file_smart_open(parsed_uri.uri_path, mode) | ||
elif parsed_uri.scheme in ("s3", "s3n"): | ||
s3_connection = boto.connect_s3(aws_access_key_id=parsed_uri.access_id, aws_secret_access_key=parsed_uri.access_secret) | ||
outbucket = s3_connection.get_bucket(parsed_uri.bucket_id) | ||
outkey = boto.s3.key.Key(outbucket) | ||
outkey.key = parsed_uri.key_id | ||
return S3OpenWrite(outbucket, outkey, **kw) | ||
bucket = s3_connection.get_bucket(parsed_uri.bucket_id) | ||
if mode in ('r', 'rb'): | ||
key = bucket.get_key(parsed_uri.key_id) | ||
if key is None: | ||
raise KeyError(parsed_uri.key_id) | ||
return S3OpenRead(key, **kw) | ||
else: | ||
key = bucket.get_key(parsed_uri.key_id, validate=False) | ||
if key is None: | ||
raise KeyError(parsed_uri.key_id) | ||
return S3OpenWrite(key, **kw) | ||
elif parsed_uri.scheme in ("hdfs", ): | ||
if mode in ('r', 'rb'): | ||
return HdfsOpenRead(parsed_uri, **kw) | ||
else: | ||
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme) | ||
elif parsed_uri.scheme in ("webhdfs", ): | ||
return WebHdfsOpenWrite(parsed_uri, **kw) | ||
if mode in ('r', 'rb'): | ||
return WebHdfsOpenRead(parsed_uri, **kw) | ||
else: | ||
return WebHdfsOpenWrite(parsed_uri, **kw) | ||
else: | ||
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme) | ||
raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme) | ||
elif isinstance(uri, boto.s3.key.Key): | ||
# handle case where we are given an S3 key directly | ||
if mode in ('r', 'rb'): | ||
return S3OpenRead(uri) | ||
elif mode in ('w', 'wb'): | ||
return S3OpenWrite(uri) | ||
elif hasattr(uri, 'read'): | ||
# simply pass-through if already a file-like | ||
return uri | ||
else: | ||
raise NotImplementedError("unknown file mode %s" % mode) | ||
raise TypeError('don\'t know how to handle uri %s' % repr(uri)) | ||
|
||
|
||
class ParseUri(object): | ||
|
@@ -203,25 +220,16 @@ class S3OpenRead(object): | |
Implement streamed reader from S3, as an iterable & context manager. | ||
|
||
""" | ||
def __init__(self, parsed_uri): | ||
if parsed_uri.scheme not in ("s3", "s3n"): | ||
raise TypeError("can only process S3 files") | ||
self.parsed_uri = parsed_uri | ||
s3_connection = boto.connect_s3( | ||
aws_access_key_id=parsed_uri.access_id, | ||
aws_secret_access_key=parsed_uri.access_secret) | ||
self.read_key = s3_connection.get_bucket(parsed_uri.bucket_id).lookup(parsed_uri.key_id) | ||
if self.read_key is None: | ||
raise KeyError(parsed_uri.key_id) | ||
def __init__(self, read_key): | ||
if not isinstance(read_key, boto.s3.key.Key): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This check is redundant as you have already checked the type in the constructor. It needs to be removed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added this because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The mock object is of class MagicMock so isinstance won't work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, so I added the |
||
raise TypeError("can only process S3 keys") | ||
self.read_key = read_key | ||
self.line_generator = s3_iter_lines(self.read_key) | ||
|
||
def __iter__(self): | ||
s3_connection = boto.connect_s3( | ||
aws_access_key_id=self.parsed_uri.access_id, | ||
aws_secret_access_key=self.parsed_uri.access_secret) | ||
key = s3_connection.get_bucket(self.parsed_uri.bucket_id).lookup(self.parsed_uri.key_id) | ||
key = self.read_key.bucket.get_key(self.read_key.name) | ||
if key is None: | ||
raise KeyError(self.parsed_uri.key_id) | ||
raise KeyError(self.read_key.name) | ||
|
||
return s3_iter_lines(key) | ||
|
||
|
@@ -257,6 +265,12 @@ def __enter__(self): | |
def __exit__(self, type, value, traceback): | ||
self.read_key.close() | ||
|
||
def __str__(self): | ||
return "%s<key: %s>" % ( | ||
self.__class__.__name__, self.read_key | ||
) | ||
|
||
|
||
|
||
class HdfsOpenRead(object): | ||
""" | ||
|
@@ -384,22 +398,21 @@ class S3OpenWrite(object): | |
Context manager for writing into S3 files. | ||
|
||
""" | ||
def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw): | ||
def __init__(self, outkey, min_part_size=S3_MIN_PART_SIZE, **kw): | ||
""" | ||
Streamed input is uploaded in chunks, as soon as `min_part_size` bytes are | ||
accumulated (50MB by default). The minimum chunk size allowed by AWS S3 | ||
is 5MB. | ||
|
||
""" | ||
self.outbucket = outbucket | ||
self.outkey = outkey | ||
self.min_part_size = min_part_size | ||
|
||
if min_part_size < 5 * 1024 ** 2: | ||
logger.warning("S3 requires minimum part size >= 5MB; multipart upload may fail") | ||
|
||
# initialize mulitpart upload | ||
self.mp = self.outbucket.initiate_multipart_upload(self.outkey, **kw) | ||
self.mp = self.outkey.bucket.initiate_multipart_upload(self.outkey, **kw) | ||
|
||
# initialize stats | ||
self.lines = [] | ||
|
@@ -408,8 +421,8 @@ def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw): | |
self.parts = 0 | ||
|
||
def __str__(self): | ||
return "%s<bucket: %s, key: %s, min_part_size: %s>" % ( | ||
self.__class__.__name__, self.outbucket, self.outkey, self.min_part_size, | ||
return "%s<key: %s, min_part_size: %s>" % ( | ||
self.__class__.__name__, self.outkey, self.min_part_size, | ||
) | ||
|
||
def write(self, b): | ||
|
@@ -456,14 +469,14 @@ def close(self): | |
# when the input is completely empty => abort the upload, no file created | ||
# TODO: or create the empty file some other way? | ||
logger.info("empty input, ignoring multipart upload") | ||
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id) | ||
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def _termination_error(self): | ||
logger.exception("encountered error while terminating multipart upload; attempting cancel") | ||
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id) | ||
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id) | ||
logger.info("cancel completed") | ||
|
||
def __exit__(self, type, value, traceback): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,26 +131,26 @@ def test_webhdfs_read(self): | |
smart_open_object = smart_open.WebHdfsOpenRead(smart_open.ParseUri("webhdfs://127.0.0.1:8440/path/file")) | ||
self.assertEqual(smart_open_object.read().decode("utf-8"), "line1\nline2") | ||
|
||
@mock.patch('smart_open.smart_open_lib.boto') | ||
@mock.patch('smart_open.smart_open_lib.s3_iter_lines') | ||
def test_s3_boto(self, mock_s3_iter_lines, mock_boto): | ||
"""Is S3 line iterator called correctly?""" | ||
# no credentials | ||
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) | ||
smart_open_object.__iter__() | ||
mock_boto.connect_s3.assert_called_with(aws_access_key_id=None, aws_secret_access_key=None) | ||
|
||
# with credential | ||
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey")) | ||
smart_open_object.__iter__() | ||
mock_boto.connect_s3.assert_called_with(aws_access_key_id="access_id", aws_secret_access_key="access_secret") | ||
|
||
# lookup bucket, key; call s3_iter_lines | ||
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey")) | ||
smart_open_object.__iter__() | ||
mock_boto.connect_s3().get_bucket.assert_called_with("mybucket") | ||
mock_boto.connect_s3().get_bucket().lookup.assert_called_with("mykey") | ||
self.assertTrue(mock_s3_iter_lines.called) | ||
# @mock.patch('smart_open.smart_open_lib.boto') | ||
# @mock.patch('smart_open.smart_open_lib.s3_iter_lines') | ||
# def test_s3_boto(self, mock_s3_iter_lines, mock_boto): | ||
# """Is S3 line iterator called correctly?""" | ||
# # no credentials | ||
# smart_open_object = smart_open.smart_open("s3://mybucket/mykey") | ||
# smart_open_object.__iter__() | ||
# mock_boto.connect_s3.assert_called_with(aws_access_key_id=None, aws_secret_access_key=None) | ||
# | ||
# # with credential | ||
# smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey") | ||
# smart_open_object.__iter__() | ||
# mock_boto.connect_s3.assert_called_with(aws_access_key_id="access_id", aws_secret_access_key="access_secret") | ||
# | ||
# # lookup bucket, key; call s3_iter_lines | ||
# smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey") | ||
# smart_open_object.__iter__() | ||
# mock_boto.connect_s3().get_bucket.assert_called_with("mybucket") | ||
# mock_boto.connect_s3().get_bucket().lookup.assert_called_with("mykey") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please update the test as you are using get_key instead of lookup There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch, thank you! |
||
# self.assertTrue(mock_s3_iter_lines.called) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please uncomment the test as we need to merge them in. |
||
|
||
|
||
@mock_s3 | ||
|
@@ -176,12 +176,12 @@ def test_s3_iter_moto(self): | |
fout.write(expected[-1]) | ||
|
||
# connect to fake s3 and read from the fake key we filled above | ||
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) | ||
smart_open_object = smart_open.smart_open("s3://mybucket/mykey") | ||
output = [line.rstrip(b'\n') for line in smart_open_object] | ||
self.assertEqual(output, expected) | ||
|
||
# same thing but using a context manager | ||
with smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) as smart_open_object: | ||
with smart_open.smart_open("s3://mybucket/mykey") as smart_open_object: | ||
output = [line.rstrip(b'\n') for line in smart_open_object] | ||
self.assertEqual(output, expected) | ||
|
||
|
@@ -196,7 +196,7 @@ def test_s3_read_moto(self): | |
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout: | ||
fout.write(content) | ||
|
||
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) | ||
smart_open_object = smart_open.smart_open("s3://mybucket/mykey") | ||
self.assertEqual(content[:6], smart_open_object.read(6)) | ||
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes | ||
|
||
|
@@ -216,7 +216,7 @@ def test_s3_seek_moto(self): | |
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout: | ||
fout.write(content) | ||
|
||
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) | ||
smart_open_object = smart_open.smart_open("s3://mybucket/mykey") | ||
self.assertEqual(content[:6], smart_open_object.read(6)) | ||
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes | ||
|
||
|
@@ -344,10 +344,11 @@ def test_write_01(self): | |
mybucket = conn.get_bucket("mybucket") | ||
mykey = boto.s3.key.Key() | ||
mykey.name = "testkey" | ||
mykey.bucket = mybucket | ||
test_string = u"žluťoučký koníček".encode('utf8') | ||
|
||
# write into key | ||
with smart_open.S3OpenWrite(mybucket, mykey) as fin: | ||
with smart_open.S3OpenWrite(mykey) as fin: | ||
fin.write(test_string) | ||
|
||
# read key and test content | ||
|
@@ -363,10 +364,11 @@ def test_write_01a(self): | |
conn.create_bucket("mybucket") | ||
mybucket = conn.get_bucket("mybucket") | ||
mykey = boto.s3.key.Key() | ||
mykey.bucket = mybucket | ||
mykey.name = "testkey" | ||
|
||
try: | ||
with smart_open.S3OpenWrite(mybucket, mykey) as fin: | ||
with smart_open.S3OpenWrite(mykey) as fin: | ||
fin.write(None) | ||
except TypeError: | ||
pass | ||
|
@@ -383,8 +385,9 @@ def test_write_02(self): | |
mybucket = conn.get_bucket("mybucket") | ||
mykey = boto.s3.key.Key() | ||
mykey.name = "testkey" | ||
mykey.bucket = mybucket | ||
|
||
smart_open_write = smart_open.S3OpenWrite(mybucket, mykey) | ||
smart_open_write = smart_open.S3OpenWrite(mykey) | ||
with smart_open_write as fin: | ||
fin.write(u"testžížáč") | ||
self.assertEqual(fin.total_size, 14) | ||
|
@@ -399,9 +402,10 @@ def test_write_03(self): | |
mybucket = conn.get_bucket("mybucket") | ||
mykey = boto.s3.key.Key() | ||
mykey.name = "testkey" | ||
mykey.bucket = mybucket | ||
|
||
# write | ||
smart_open_write = smart_open.S3OpenWrite(mybucket, mykey, min_part_size=10) | ||
smart_open_write = smart_open.S3OpenWrite(mykey, min_part_size=10) | ||
with smart_open_write as fin: | ||
fin.write(u"test") # implicit unicode=>utf8 conversion | ||
self.assertEqual(fin.chunk_bytes, 4) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the new boto key usage example to the docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.