Merge pull request #38 from asieira/master
Allowing boto keys to be passed directly to smart_open (#35)
tmylk committed Dec 18, 2015
2 parents 0c85092 + 43e3192 commit 23f6921
Showing 3 changed files with 101 additions and 70 deletions.
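In short, this commit lets a boto.s3.key.Key object be handed straight to smart_open() instead of an "s3://..." URI. A minimal sketch of the new usage (bucket and key names are placeholders; it assumes the objects exist and boto credentials are configured):

```python
import boto
import smart_open

# Fetch an existing key object with plain boto; "my_bucket" and "my_key"
# are hypothetical names -- substitute ones that exist in your account.
key = boto.connect_s3().get_bucket("my_bucket").get_key("my_key")

# The Key instance can now be passed to smart_open directly; read mode ("rb")
# is the default.
with smart_open.smart_open(key) as fin:
    for line in fin:
        print(line)

# The same works for writing: pass a Key and mode "wb" (this replaces the
# object's contents).
with smart_open.smart_open(key, "wb") as fout:
    fout.write(b"hello world!\n")
```

Passing the Key directly avoids rebuilding a URI and re-resolving the key inside smart_open when the caller already holds a Key object.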
9 changes: 6 additions & 3 deletions CHANGELOG.rst
@@ -1,8 +1,11 @@
* 1.3.1, 16th December 2015
* 1.4.0, 16th December 2015

- Updated smart_open to accept an instance of boto.s3.key.Key (PR #38, @asieira)

- Disable multiprocessing if unavailable. Allows to run on Google Compute Engine. (PR #41, @nikicc)
- Httpretty updated to allow LC_ALL=C locale config. (PR #39, @jsphpl)
* 1.3.1, 16th December 2015

- Disable multiprocessing if unavailable. Allows to run on Google Compute Engine. (PR #41, @nikicc)
- Httpretty updated to allow LC_ALL=C locale config. (PR #39, @jsphpl)

* 1.3.0, 19th September 2015

134 changes: 79 additions & 55 deletions smart_open/smart_open_lib.py
@@ -63,11 +63,12 @@ def smart_open(uri, mode="rb", **kw):
The `uri` can be either:
1. local filesystem (compressed ``.gz`` or ``.bz2`` files handled automatically):
1. a URI for the local filesystem (compressed ``.gz`` or ``.bz2`` files handled automatically):
`./lines.txt`, `/home/joe/lines.txt.gz`, `file:///home/joe/lines.txt.bz2`
2. Amazon's S3 (can also supply credentials inside the URI):
2. a URI for HDFS: `hdfs:///some/path/lines.txt`
3. a URI for Amazon's S3 (can also supply credentials inside the URI):
`s3://my_bucket/lines.txt`, `s3://my_aws_key_id:key_secret@my_bucket/lines.txt`
3. HDFS: `hdfs:///some/path/lines.txt`
4. an instance of the boto.s3.key.Key class.
Examples::
@@ -76,6 +77,12 @@ def smart_open(uri, mode="rb", **kw):
... for line in fin:
... print line
>>> # you can also use a boto.s3.key.Key instance directly:
>>> key = boto.connect_s3().get_bucket("my_bucket").get_key("my_key")
>>> with smart_open.smart_open(key) as fin:
... for line in fin:
... print line
>>> # stream line-by-line from an HDFS file
>>> for line in smart_open.smart_open('hdfs:///user/hadoop/my_file.txt'):
... print line
@@ -96,41 +103,58 @@ def smart_open(uri, mode="rb", **kw):
... fout.write("good bye!\n")
"""
# simply pass-through if already a file-like
if not isinstance(uri, six.string_types) and hasattr(uri, 'read'):
return uri

# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
parsed_uri = ParseUri(uri)

if parsed_uri.scheme in ("file", ):
# local files -- both read & write supported
# compression, if any, is determined by the filename extension (.gz, .bz2)
return file_smart_open(parsed_uri.uri_path, mode)

if mode in ('r', 'rb'):
if parsed_uri.scheme in ("s3", "s3n"):
return S3OpenRead(parsed_uri, **kw)
elif parsed_uri.scheme in ("hdfs", ):
return HdfsOpenRead(parsed_uri, **kw)
elif parsed_uri.scheme in ("webhdfs", ):
return WebHdfsOpenRead(parsed_uri, **kw)
else:
raise NotImplementedError("read mode not supported for %r scheme", parsed_uri.scheme)
elif mode in ('w', 'wb'):
if parsed_uri.scheme in ("s3", "s3n"):
# validate mode parameter
if not isinstance(mode, six.string_types):
raise TypeError('mode should be a string')
if not mode in ('r', 'rb', 'w', 'wb'):
raise NotImplementedError('unknown file mode %s' % mode)

if isinstance(uri, six.string_types):
# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
parsed_uri = ParseUri(uri)

if parsed_uri.scheme in ("file", ):
# local files -- both read & write supported
# compression, if any, is determined by the filename extension (.gz, .bz2)
return file_smart_open(parsed_uri.uri_path, mode)
elif parsed_uri.scheme in ("s3", "s3n"):
s3_connection = boto.connect_s3(aws_access_key_id=parsed_uri.access_id, aws_secret_access_key=parsed_uri.access_secret)
outbucket = s3_connection.get_bucket(parsed_uri.bucket_id)
outkey = boto.s3.key.Key(outbucket)
outkey.key = parsed_uri.key_id
return S3OpenWrite(outbucket, outkey, **kw)
bucket = s3_connection.get_bucket(parsed_uri.bucket_id)
if mode in ('r', 'rb'):
key = bucket.get_key(parsed_uri.key_id)
if key is None:
raise KeyError(parsed_uri.key_id)
return S3OpenRead(key, **kw)
else:
key = bucket.get_key(parsed_uri.key_id, validate=False)
if key is None:
raise KeyError(parsed_uri.key_id)
return S3OpenWrite(key, **kw)
elif parsed_uri.scheme in ("hdfs", ):
if mode in ('r', 'rb'):
return HdfsOpenRead(parsed_uri, **kw)
else:
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
elif parsed_uri.scheme in ("webhdfs", ):
return WebHdfsOpenWrite(parsed_uri, **kw)
if mode in ('r', 'rb'):
return WebHdfsOpenRead(parsed_uri, **kw)
else:
return WebHdfsOpenWrite(parsed_uri, **kw)
else:
raise NotImplementedError("write mode not supported for %r scheme", parsed_uri.scheme)
raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme)
elif isinstance(uri, boto.s3.key.Key):
# handle case where we are given an S3 key directly
if mode in ('r', 'rb'):
return S3OpenRead(uri)
elif mode in ('w', 'wb'):
return S3OpenWrite(uri)
elif hasattr(uri, 'read'):
# simply pass-through if already a file-like
return uri
else:
raise NotImplementedError("unknown file mode %s" % mode)
raise TypeError('don\'t know how to handle uri %s' % repr(uri))


class ParseUri(object):
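For orientation (not part of the diff itself): after this change, smart_open() validates the mode up front and then dispatches on the type of `uri` — string URI, boto Key, or file-like object. A small sketch of the resulting behaviour, with illustrative inputs:

```python
import io

import smart_open

# 1. The mode is validated before the uri is even inspected.
try:
    smart_open.smart_open("s3://mybucket/mykey", mode="a")  # append is not supported
except NotImplementedError as err:
    print(err)  # unknown file mode a

# 2. Anything that is not a string, a boto Key or a file-like object is rejected.
try:
    smart_open.smart_open(12345)
except TypeError as err:
    print(err)  # don't know how to handle uri 12345

# 3. File-like objects are still passed through untouched.
buf = io.BytesIO(b"already open\n")
assert smart_open.smart_open(buf) is buf
```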
@@ -214,25 +238,17 @@ class S3OpenRead(object):
Implement streamed reader from S3, as an iterable & context manager.
"""
def __init__(self, parsed_uri):
if parsed_uri.scheme not in ("s3", "s3n"):
raise TypeError("can only process S3 files")
self.parsed_uri = parsed_uri
s3_connection = boto.connect_s3(
aws_access_key_id=parsed_uri.access_id,
aws_secret_access_key=parsed_uri.access_secret)
self.read_key = s3_connection.get_bucket(parsed_uri.bucket_id).lookup(parsed_uri.key_id)
if self.read_key is None:
raise KeyError(parsed_uri.key_id)
def __init__(self, read_key):
if not hasattr(read_key, "bucket") and not hasattr(read_key, "name") and not hasattr(read_key, "read") \
and not hasattr(read_key, "close"):
raise TypeError("can only process S3 keys")
self.read_key = read_key
self.line_generator = s3_iter_lines(self.read_key)

def __iter__(self):
s3_connection = boto.connect_s3(
aws_access_key_id=self.parsed_uri.access_id,
aws_secret_access_key=self.parsed_uri.access_secret)
key = s3_connection.get_bucket(self.parsed_uri.bucket_id).lookup(self.parsed_uri.key_id)
key = self.read_key.bucket.get_key(self.read_key.name)
if key is None:
raise KeyError(self.parsed_uri.key_id)
raise KeyError(self.read_key.name)

return s3_iter_lines(key)

@@ -268,6 +284,12 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
self.read_key.close()

def __str__(self):
return "%s<key: %s>" % (
self.__class__.__name__, self.read_key
)



class HdfsOpenRead(object):
"""
@@ -395,22 +417,24 @@ class S3OpenWrite(object):
Context manager for writing into S3 files.
"""
def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
def __init__(self, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
"""
Streamed input is uploaded in chunks, as soon as `min_part_size` bytes are
accumulated (50MB by default). The minimum chunk size allowed by AWS S3
is 5MB.
"""
self.outbucket = outbucket
if not hasattr(outkey, "bucket") and not hasattr(outkey, "name"):
raise TypeError("can only process S3 keys")

self.outkey = outkey
self.min_part_size = min_part_size

if min_part_size < 5 * 1024 ** 2:
logger.warning("S3 requires minimum part size >= 5MB; multipart upload may fail")

# initialize mulitpart upload
self.mp = self.outbucket.initiate_multipart_upload(self.outkey, **kw)
self.mp = self.outkey.bucket.initiate_multipart_upload(self.outkey, **kw)

# initialize stats
self.lines = []
@@ -419,8 +443,8 @@ def __init__(self, outbucket, outkey, min_part_size=S3_MIN_PART_SIZE, **kw):
self.parts = 0

def __str__(self):
return "%s<bucket: %s, key: %s, min_part_size: %s>" % (
self.__class__.__name__, self.outbucket, self.outkey, self.min_part_size,
return "%s<key: %s, min_part_size: %s>" % (
self.__class__.__name__, self.outkey, self.min_part_size,
)

def write(self, b):
@@ -467,14 +491,14 @@ def close(self):
# when the input is completely empty => abort the upload, no file created
# TODO: or create the empty file some other way?
logger.info("empty input, ignoring multipart upload")
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)

def __enter__(self):
return self

def _termination_error(self):
logger.exception("encountered error while terminating multipart upload; attempting cancel")
self.outbucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
self.outkey.bucket.cancel_multipart_upload(self.mp.key_name, self.mp.id)
logger.info("cancel completed")

def __exit__(self, type, value, traceback):
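Taken together, S3OpenRead and S3OpenWrite now wrap a boto Key directly rather than a parsed URI plus a separate bucket object. A sketch of direct construction under the new signatures (bucket and key names are hypothetical; smart_open() remains the usual entry point):

```python
import boto
import boto.s3.key

import smart_open

bucket = boto.connect_s3().get_bucket("my_bucket")

# Reading: hand an existing key object to S3OpenRead.
with smart_open.S3OpenRead(bucket.get_key("existing_key")) as fin:
    for line in fin:
        print(line)

# Writing: the key only needs a bucket and a name; the multipart upload is
# initiated through key.bucket. min_part_size controls how many bytes are
# buffered before each part is uploaded (AWS requires at least 5MB per part).
new_key = boto.s3.key.Key(bucket)
new_key.name = "new_key"
with smart_open.S3OpenWrite(new_key, min_part_size=50 * 1024 ** 2) as fout:
    fout.write(u"first line\n")
    fout.write(u"second line\n")
```

Because the upload is driven through key.bucket, the Key handed to S3OpenWrite must carry both a bucket and a name — which is what the new hasattr checks in the constructors are meant to enforce.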
28 changes: 16 additions & 12 deletions smart_open/tests/test_smart_open.py
@@ -136,20 +136,20 @@ def test_webhdfs_read(self):
def test_s3_boto(self, mock_s3_iter_lines, mock_boto):
"""Is S3 line iterator called correctly?"""
# no credentials
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
smart_open_object.__iter__()
mock_boto.connect_s3.assert_called_with(aws_access_key_id=None, aws_secret_access_key=None)

# with credential
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
smart_open_object.__iter__()
mock_boto.connect_s3.assert_called_with(aws_access_key_id="access_id", aws_secret_access_key="access_secret")

# lookup bucket, key; call s3_iter_lines
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://access_id:access_secret@mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey")
smart_open_object.__iter__()
mock_boto.connect_s3().get_bucket.assert_called_with("mybucket")
mock_boto.connect_s3().get_bucket().lookup.assert_called_with("mykey")
mock_boto.connect_s3().get_bucket().get_key.assert_called_with("mykey")
self.assertTrue(mock_s3_iter_lines.called)


@@ -176,12 +176,12 @@ def test_s3_iter_moto(self):
fout.write(expected[-1])

# connect to fake s3 and read from the fake key we filled above
smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
output = [line.rstrip(b'\n') for line in smart_open_object]
self.assertEqual(output, expected)

# same thing but using a context manager
with smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey")) as smart_open_object:
with smart_open.smart_open("s3://mybucket/mykey") as smart_open_object:
output = [line.rstrip(b'\n') for line in smart_open_object]
self.assertEqual(output, expected)

@@ -196,7 +196,7 @@ def test_s3_read_moto(self):
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
fout.write(content)

smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
self.assertEqual(content[:6], smart_open_object.read(6))
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes

@@ -216,7 +216,7 @@ def test_s3_seek_moto(self):
with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
fout.write(content)

smart_open_object = smart_open.S3OpenRead(smart_open.ParseUri("s3://mybucket/mykey"))
smart_open_object = smart_open.smart_open("s3://mybucket/mykey")
self.assertEqual(content[:6], smart_open_object.read(6))
self.assertEqual(content[6:14], smart_open_object.read(8)) # ř is 2 bytes

Expand Down Expand Up @@ -344,10 +344,11 @@ def test_write_01(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket
test_string = u"žluťoučký koníček".encode('utf8')

# write into key
with smart_open.S3OpenWrite(mybucket, mykey) as fin:
with smart_open.S3OpenWrite(mykey) as fin:
fin.write(test_string)

# read key and test content
Expand All @@ -363,10 +364,11 @@ def test_write_01a(self):
conn.create_bucket("mybucket")
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.bucket = mybucket
mykey.name = "testkey"

try:
with smart_open.S3OpenWrite(mybucket, mykey) as fin:
with smart_open.S3OpenWrite(mykey) as fin:
fin.write(None)
except TypeError:
pass
@@ -383,8 +385,9 @@ def test_write_02(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket

smart_open_write = smart_open.S3OpenWrite(mybucket, mykey)
smart_open_write = smart_open.S3OpenWrite(mykey)
with smart_open_write as fin:
fin.write(u"testžížáč")
self.assertEqual(fin.total_size, 14)
@@ -399,9 +402,10 @@ def test_write_03(self):
mybucket = conn.get_bucket("mybucket")
mykey = boto.s3.key.Key()
mykey.name = "testkey"
mykey.bucket = mybucket

# write
smart_open_write = smart_open.S3OpenWrite(mybucket, mykey, min_part_size=10)
smart_open_write = smart_open.S3OpenWrite(mykey, min_part_size=10)
with smart_open_write as fin:
fin.write(u"test") # implicit unicode=>utf8 conversion
self.assertEqual(fin.chunk_bytes, 4)
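A possible follow-up test in the same moto-backed style as the ones above, exercising the new pass-a-Key path end to end (a sketch only; this test is not part of the commit):

```python
import unittest

import boto
import moto

import smart_open


class TestKeyPassthrough(unittest.TestCase):
    @moto.mock_s3
    def test_smart_open_accepts_boto_key(self):
        """Can smart_open read data back through a boto.s3.key.Key instance?"""
        conn = boto.connect_s3()
        conn.create_bucket("mybucket")

        # write through the ordinary URI path first
        with smart_open.smart_open("s3://mybucket/mykey", "wb") as fout:
            fout.write(u"hello world!\nhow are you?".encode('utf8'))

        # then hand the resulting Key object to smart_open directly
        mykey = conn.get_bucket("mybucket").get_key("mykey")
        with smart_open.smart_open(mykey) as fin:
            output = [line.rstrip(b'\n') for line in fin]

        self.assertEqual(output, [b"hello world!", b"how are you?"])
```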
