Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement object_kwargs parameter #411

Merged
merged 9 commits into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions howto.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,22 @@ text/plain

```

This works only when reading and writing via S3.

## How to Specify the Request Payer (S3 only)

Some public buckets require you to [pay for S3 requests for the data in the bucket](https://docs.aws.amazon.com/AmazonS3/latest/dev/RequesterPaysBuckets.html).
This relieves the bucket owner of the data transfer costs, and spreads them among the consumers of the data.

To access such buckets, you need to pass some special transport parameters:

```python
>>> from smart_open import open
>>> p = {'object_kwargs': {'RequestPayer': 'requester'}}
>>> with open('s3://arxiv/pdf/arXiv_pdf_manifest.xml', transport_params=p) as fin:
... print(fin.read(1024))
<?xml version='1.0' standalone='yes'?>

```

This works only when reading via S3: `object_kwargs` is used during reading only.
211 changes: 102 additions & 109 deletions smart_open/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def open(
session=None,
resource_kwargs=None,
multipart_upload_kwargs=None,
object_kwargs=None,
):
"""Open an S3 object for reading or writing.

Expand All @@ -103,6 +104,9 @@ def open(
version_id: str, optional
Version of the object, used when reading object.
If None, will fetch the most recent version.
object_kwargs: dict, optional
Additional parameters to pass to boto3's object.get function.
Used during reading only.

"""
logger.debug('%r', locals())
Expand All @@ -113,21 +117,24 @@ def open(
resource_kwargs = {}
if multipart_upload_kwargs is None:
multipart_upload_kwargs = {}
if object_kwargs is None:
object_kwargs = {}

if (mode == WRITE_BINARY) and (version_id is not None):
raise ValueError("version_id must be None when writing")

if mode == READ_BINARY:
fileobj = SeekableBufferedInputBase(
fileobj = Reader(
bucket_id,
key_id,
version_id=version_id,
buffer_size=buffer_size,
session=session,
resource_kwargs=resource_kwargs,
object_kwargs=object_kwargs,
)
elif mode == WRITE_BINARY:
fileobj = BufferedOutputBase(
fileobj = MultipartWriter(
bucket_id,
key_id,
min_part_size=min_part_size,
Expand All @@ -153,28 +160,19 @@ def _get(s3_object, version=None, **kwargs):
)


class RawReader(object):
    """Sequential reader over the body of an S3 object.

    The object's body is requested once, when the reader is constructed.
    """

    def __init__(self, s3_object):
        self.position = 0
        self._object = s3_object
        self._body = s3_object.get()['Body']

    def read(self, size=-1):
        """Read from the object's body.

        A size of exactly -1 reads everything that remains; any other value
        is handed to the body's own read() unchanged.
        """
        if size != -1:
            return self._body.read(size)
        return self._body.read()

class _SeekableRawReader(object):
"""Read an S3 object.

class SeekableRawReader(object):
"""Read an S3 object."""
This class is internal to the S3 submodule.
"""

def __init__(self, s3_object, content_length, version_id=None):
def __init__(self, s3_object, content_length, version_id=None, object_kwargs=None):
self._object = s3_object
self._content_length = content_length
self._version_id = version_id
self._position = 0
self._body = None
self._object_kwargs = object_kwargs if object_kwargs else {}

def seek(self, position):
"""Seek to the specified position (byte offset) in the S3 key.
Expand Down Expand Up @@ -203,7 +201,12 @@ def _load_body(self):
#
self._body = io.BytesIO()
else:
self._body = _get(self._object, self._version_id, Range=range_string)['Body']
self._body = _get(
self._object,
version=self._version_id,
Range=range_string,
**self._object_kwargs
)['Body']

def _read_from_body(self, size=-1):
if size == -1:
Expand All @@ -230,23 +233,43 @@ def read(self, size=-1):
return binary


class BufferedInputBase(io.BufferedIOBase):
class Reader(io.BufferedIOBase):
"""Reads bytes from S3.

Implements the io.BufferedIOBase interface of the standard library."""

def __init__(self, bucket, key, version_id=None, buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=None):
line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=None,
object_kwargs=None):

self._buffer_size = buffer_size

if session is None:
session = boto3.Session()
if resource_kwargs is None:
resource_kwargs = {}
if object_kwargs is None:
object_kwargs = {}

self._session = session
self._resource_kwargs = resource_kwargs
self._object_kwargs = object_kwargs

s3 = session.resource('s3', **resource_kwargs)
self._object = s3.Object(bucket, key)
self._version_id = version_id
self._raw_reader = RawReader(self._object)
self._content_length = self._object.content_length
self._content_length = _get(self._object, self._version_id)['ContentLength']
self._content_length = _get(
self._object,
version=self._version_id,
**self._object_kwargs
)['ContentLength']

self._raw_reader = _SeekableRawReader(
self._object,
self._content_length,
self._version_id,
self._object_kwargs,
)
self._current_pos = 0
self._buffer = smart_open.bytebuffer.ByteBuffer(buffer_size)
self._eof = False
Expand All @@ -258,8 +281,9 @@ def __init__(self, bucket, key, version_id=None, buffer_size=DEFAULT_BUFFER_SIZE
self.raw = None

#
# Override some methods from io.IOBase.
# io.BufferedIOBase methods.
#

def close(self):
"""Flush and close this stream."""
logger.debug("close: called")
Expand All @@ -269,16 +293,6 @@ def readable(self):
"""Return True if the stream can be read from."""
return True

def seekable(self):
return False

#
# io.BufferedIOBase methods.
#
def detach(self):
"""Unsupported."""
raise io.UnsupportedOperation

def read(self, size=-1):
"""Read up to size bytes from the object and return them."""
if size == 0:
Expand Down Expand Up @@ -343,76 +357,6 @@ def readline(self, limit=-1):
self._fill_buffer()
return the_line.getvalue()

def terminate(self):
"""Do nothing."""
pass

def to_boto3(self):
"""Create an **independent** `boto3.s3.Object` instance that points to
the same resource as this instance.

The created instance will re-use the session and resource parameters of
the current instance, but it will be independent: changes to the
`boto3.s3.Object` may not necessary affect the current instance.

"""
s3 = self._session.resource('s3', **self._resource_kwargs)
return s3.Object(self._object.bucket_name, self._object.key)

#
# Internal methods.
#
def _read_from_buffer(self, size=-1):
"""Remove at most size bytes from our buffer and return them."""
# logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._buffer))
size = size if size >= 0 else len(self._buffer)
part = self._buffer.read(size)
self._current_pos += len(part)
# logger.debug('part: %r', part)
return part

def _fill_buffer(self, size=-1):
size = size if size >= 0 else self._buffer._chunk_size
while len(self._buffer) < size and not self._eof:
bytes_read = self._buffer.fill(self._raw_reader)
if bytes_read == 0:
logger.debug('reached EOF while filling buffer')
self._eof = True


class SeekableBufferedInputBase(BufferedInputBase):
"""Reads bytes from S3.

Implements the io.BufferedIOBase interface of the standard library."""

def __init__(self, bucket, key, version_id=None, buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=None):

self._buffer_size = buffer_size

if session is None:
session = boto3.Session()
if resource_kwargs is None:
resource_kwargs = {}

self._session = session
self._resource_kwargs = resource_kwargs
s3 = session.resource('s3', **resource_kwargs)
self._object = s3.Object(bucket, key)
self._version_id = version_id
self._content_length = _get(self._object, self._version_id)['ContentLength']

self._raw_reader = SeekableRawReader(self._object, self._content_length, self._version_id)
self._current_pos = 0
self._buffer = smart_open.bytebuffer.ByteBuffer(buffer_size)
self._eof = False
self._line_terminator = line_terminator

#
# This member is part of the io.BufferedIOBase interface.
#
self.raw = None

def seekable(self):
"""If False, seek(), tell() and truncate() will raise IOError.

Expand Down Expand Up @@ -453,14 +397,54 @@ def truncate(self, size=None):
"""Unsupported."""
raise io.UnsupportedOperation

def detach(self):
    """Unsupported: always raises io.UnsupportedOperation."""
    raise io.UnsupportedOperation

def terminate(self):
    """Do nothing and return None."""
    return None

def to_boto3(self):
    """Build a new, **independent** `boto3.s3.Object` for the same bucket
    and key as this instance.

    The returned object is created from this instance's session and
    resource parameters, but it is a separate object: changes made to it
    may not necessarily affect the current instance.

    """
    resource = self._session.resource('s3', **self._resource_kwargs)
    bucket_name = self._object.bucket_name
    key = self._object.key
    return resource.Object(bucket_name, key)

#
# Internal methods.
#
def _read_from_buffer(self, size=-1):
"""Remove at most size bytes from our buffer and return them."""
# logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._buffer))
size = size if size >= 0 else len(self._buffer)
part = self._buffer.read(size)
self._current_pos += len(part)
# logger.debug('part: %r', part)
return part

def _fill_buffer(self, size=-1):
size = size if size >= 0 else self._buffer._chunk_size
while len(self._buffer) < size and not self._eof:
bytes_read = self._buffer.fill(self._raw_reader)
if bytes_read == 0:
logger.debug('reached EOF while filling buffer')
self._eof = True

def __str__(self):
return "smart_open.s3.SeekableBufferedInputBase(%r, %r)" % (
return "smart_open.s3.Reader(%r, %r)" % (
self._object.bucket_name, self._object.key
)

def __repr__(self):
return (
"smart_open.s3.SeekableBufferedInputBase("
"smart_open.s3.Reader("
"bucket=%r, "
"key=%r, "
"version_id=%r, "
Expand All @@ -479,7 +463,7 @@ def __repr__(self):
)


class BufferedOutputBase(io.BufferedIOBase):
class MultipartWriter(io.BufferedIOBase):
"""Writes bytes to S3.

Implements the io.BufferedIOBase interface of the standard library."""
Expand Down Expand Up @@ -637,11 +621,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def __str__(self):
return "smart_open.s3.BufferedOutputBase(%r, %r)" % (self._object.bucket_name, self._object.key)
return "smart_open.s3.MultipartWriter(%r, %r)" % (
self._object.bucket_name, self._object.key,
)

def __repr__(self):
return (
"smart_open.s3.BufferedOutputBase("
"smart_open.s3.MultipartWriter("
"bucket=%r, "
"key=%r, "
"min_part_size=%r, "
Expand All @@ -658,6 +644,13 @@ def __repr__(self):
)


#
# Old class names, kept as aliases of the renamed classes
# for backward compatibility.
#
BufferedOutputBase = MultipartWriter
SeekableBufferedInputBase = Reader


def _accept_all(key):
return True

Expand Down
Loading