Clean up kwargs #230

Closed
wants to merge 10 commits
21 changes: 2 additions & 19 deletions smart_open/http.py
@@ -16,26 +16,9 @@


class BufferedInputBase(io.BufferedIOBase):
"""
Implement streamed reader from a web site.
Supports Kerberos and Basic HTTP authentication.
"""

def __init__(self, url, mode='r', kerberos=False, user=None, password=None):
"""
If Kerberos is True, will attempt to use the local Kerberos credentials.
Otherwise, will try to use "basic" HTTP authentication via username/password.

If none of those are set, will connect unauthenticated.
"""
if kerberos:
import requests_kerberos
auth = requests_kerberos.HTTPKerberosAuth()
elif user is not None and password is not None:
auth = (user, password)
else:
auth = None
"""Implement streamed reader from a web site."""

def __init__(self, url, mode='r', auth=None):
self.response = requests.get(url, auth=auth, stream=True, headers=_HEADERS)

if not self.response.ok:
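
A note on migrating: the reader no longer builds the auth object itself; callers construct one and it is passed straight through to requests.get. A minimal sketch of the new call, where the URL and credentials are placeholders:

import requests_kerberos
from smart_open.http import BufferedInputBase

# Previously: BufferedInputBase(url, kerberos=True)
reader = BufferedInputBase('https://example.com/data.bin',
                           auth=requests_kerberos.HTTPKerberosAuth())

# Previously: BufferedInputBase(url, user='alice', password='secret');
# requests accepts a (user, password) tuple for basic HTTP auth.
reader = BufferedInputBase('https://example.com/data.bin', auth=('alice', 'secret'))
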
121 changes: 83 additions & 38 deletions smart_open/s3.py
@@ -57,21 +57,42 @@ def _clamp(value, minval, maxval):
return max(min(value, maxval), minval)


def open(bucket_id, key_id, mode, **kwargs):
def open(bucket_id, key_id, mode,
resource=None,
min_part_size=DEFAULT_MIN_PART_SIZE,
multipart_upload_kwargs=None,
):
"""
Open s3://{bucket_id}/{key_id}

Use the resource object to override the default session (e.g. profile name,
access keys, and endpoint URL).

:param str bucket_id:
:param str key_id:
:param str mode: must be one of rb or wb
:param int min_part_size: For writing only.
:param boto3.s3.ServiceResource resource: The resource for accessing S3.
:param dict multipart_upload_kwargs: For writing only.
"""
logger.debug('%r', locals())
if mode not in MODES:
raise NotImplementedError('bad mode: %r expected one of %r' % (mode, MODES))

encoding = kwargs.pop("encoding", "utf-8")
errors = kwargs.pop("errors", None)
newline = kwargs.pop("newline", None)
line_buffering = kwargs.pop("line_buffering", False)
s3_min_part_size = kwargs.pop("s3_min_part_size", DEFAULT_MIN_PART_SIZE)

if mode == READ_BINARY:
fileobj = SeekableBufferedInputBase(bucket_id, key_id, **kwargs)
fileobj = SeekableBufferedInputBase(
bucket_id,
key_id,
resource=resource
)
elif mode == WRITE_BINARY:
fileobj = BufferedOutputBase(bucket_id, key_id, min_part_size=s3_min_part_size, **kwargs)
fileobj = BufferedOutputBase(
bucket_id,
key_id,
resource=resource,
min_part_size=min_part_size,
multipart_upload_kwargs=multipart_upload_kwargs,
)
else:
assert False, 'unexpected mode: %r' % mode

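Callers that previously tuned the session through kwargs (profile name, access keys, endpoint URL) now build a boto3 resource themselves and hand it in. A sketch of the new call, with a placeholder bucket, key, profile and endpoint:

import boto3
from smart_open import s3

session = boto3.Session(profile_name='my-profile')
resource = session.resource('s3', endpoint_url='https://s3.example.com')

# The returned file object supports the usual context-manager protocol.
with s3.open('my-bucket', 'my-key.txt', 'rb', resource=resource) as fin:
    data = fin.read()
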
@@ -138,13 +159,20 @@ def read(self, size=-1):


class BufferedInputBase(io.BufferedIOBase):
def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE, **kwargs):
session = kwargs.pop(
's3_session',
boto3.Session(profile_name=kwargs.pop('profile_name', None))
)
s3 = session.resource('s3', **kwargs)
def __init__(self,
bucket,
key,
buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE,
session=None,
profile_name=None,
aws_access_key_id=None,
aws_secret_access_key=None,
endpoint_url=None,
):
if session is None:
session = boto3.Session(profile_name=profile_name)
s3 = session.resource('s3', endpoint_url=endpoint_url)
self._object = s3.Object(bucket, key)
self._raw_reader = RawReader(self._object)
self._content_length = self._object.content_length
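
The reader no longer pops s3_session or profile_name out of **kwargs; a non-default session is passed in explicitly. A sketch with placeholder names, exercising only the session and endpoint_url arguments that the lines above actually use:

import boto3
from smart_open import s3

session = boto3.Session(profile_name='my-profile')
reader = s3.BufferedInputBase('my-bucket', 'my-key.bin',
                              session=session,
                              endpoint_url='https://storage.example.com')
first_kb = reader.read(1024)
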
@@ -279,14 +307,16 @@ class SeekableBufferedInputBase(BufferedInputBase):

Implements the io.BufferedIOBase interface of the standard library."""

def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE, **kwargs):
session = kwargs.pop(
's3_session',
boto3.Session(profile_name=kwargs.pop('profile_name', None))
)
s3 = session.resource('s3', **kwargs)
self._object = s3.Object(bucket, key)
def __init__(self,
bucket,
key,
buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=BINARY_NEWLINE,
resource=None,
):
if resource is None:
resource = boto3.Session().resource('s3')
self._object = resource.Object(bucket, key)
self._raw_reader = SeekableRawReader(self._object)
self._content_length = self._object.content_length
self._current_pos = 0
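
The seekable reader now takes a single ready-made resource instead of rebuilding a session from **kwargs; this is also what open() forwards for read mode. A sketch with a placeholder bucket and key:

import boto3
from smart_open import s3

resource = boto3.Session(profile_name='my-profile').resource('s3')
reader = s3.SeekableBufferedInputBase('my-bucket', 'logs/2018-01-01.log',
                                      resource=resource)
reader.seek(1024)          # skip the first kilobyte
chunk = reader.read(4096)
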
@@ -346,27 +376,33 @@ class BufferedOutputBase(io.BufferedIOBase):

Implements the io.BufferedIOBase interface of the standard library."""

def __init__(self, bucket, key, min_part_size=DEFAULT_MIN_PART_SIZE, s3_upload=None, **kwargs):
def __init__(self,
bucket,
key,
resource=None,
min_part_size=DEFAULT_MIN_PART_SIZE,
multipart_upload_kwargs=None,
):
if min_part_size < MIN_MIN_PART_SIZE:
logger.warning("S3 requires minimum part size >= 5MB; \
multipart upload may fail")

session = kwargs.pop(
's3_session',
boto3.Session(profile_name=kwargs.pop('profile_name', None))
)
s3 = session.resource('s3', **kwargs)
if resource is None:
resource = boto3.Session().resource('s3')

#
# https://stackoverflow.com/questions/26871884/how-can-i-easily-determine-if-a-boto-3-s3-bucket-resource-exists
#
try:
s3.meta.client.head_bucket(Bucket=bucket)
resource.meta.client.head_bucket(Bucket=bucket)
except botocore.client.ClientError:
raise ValueError('the bucket %r does not exist, or is forbidden for access' % bucket)
self._object = s3.Object(bucket, key)
self._object = resource.Object(bucket, key)
self._min_part_size = min_part_size
self._mp = self._object.initiate_multipart_upload(**(s3_upload or {}))

if multipart_upload_kwargs is None:
multipart_upload_kwargs = {}
self._mp = self._object.initiate_multipart_upload(**multipart_upload_kwargs)

self._buf = io.BytesIO()
self._total_bytes = 0
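
Because multipart_upload_kwargs is forwarded verbatim to initiate_multipart_upload, S3 options such as server-side encryption can be requested when the writer is created. A sketch with placeholder names:

import boto3
from smart_open import s3

resource = boto3.Session().resource('s3')
writer = s3.BufferedOutputBase(
    'my-bucket', 'encrypted.bin',
    resource=resource,
    # Forwarded as-is to Object.initiate_multipart_upload.
    multipart_upload_kwargs={'ServerSideEncryption': 'AES256'},
)
writer.write(b'hello world')
writer.close()
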
@@ -475,7 +511,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):


def iter_bucket(bucket_name, prefix='', accept_key=lambda key: True,
key_limit=None, workers=16, retries=3):
key_limit=None, workers=16, retries=3, profile_name=None,
aws_access_key_id=None, aws_secret_access_key=None):
"""
Iterate and download all S3 files under `bucket/prefix`, yielding out
`(key, key content)` 2-tuples (generator).
@@ -511,7 +548,10 @@ def iter_bucket(bucket_name, prefix='', accept_key=lambda key: True,

total_size, key_no = 0, -1
key_iterator = _list_bucket(bucket_name, prefix=prefix, accept_key=accept_key)
download_key = functools.partial(_download_key, bucket_name=bucket_name, retries=retries)
download_key = functools.partial(
_download_key, bucket_name=bucket_name, retries=retries, profile_name=profile_name,
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key,
)

with _create_process_pool(processes=workers) as pool:
result_iterator = pool.imap_unordered(download_key, key_iterator)
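
Because the downloads run in a separate process pool, credentials travel to the workers as plain keyword arguments and each worker rebuilds its own session (see _download_key below). A sketch of the call, with a placeholder bucket and profile:

from smart_open import s3

for key, content in s3.iter_bucket('my-bucket', prefix='logs/',
                                   workers=8, profile_name='my-profile'):
    print(key, len(content))
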
@@ -550,15 +590,20 @@ def _list_bucket(bucket_name, prefix='', accept_key=lambda k: True):
break


def _download_key(key_name, bucket_name=None, retries=3):
def _download_key(key_name, bucket_name=None, retries=3, profile_name=None,
aws_access_key_id=None, aws_secret_access_key=None, endpoint_url=None):
if bucket_name is None:
raise ValueError('bucket_name may not be None')

#
# https://geekpete.com/blog/multithreading-boto3/
#
session = boto3.session.Session()
s3 = session.resource('s3')
session = boto3.Session(
profile_name=profile_name, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key
)

s3 = session.resource('s3', endpoint_url=endpoint_url)
bucket = s3.Bucket(bucket_name)

# Sometimes, https://github.com/boto/boto/issues/2409 can happen because of network issues on either side.
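
The remainder of the function is not shown in this hunk; it retries the download up to retries times when such transient errors occur. A rough, illustrative sketch of that pattern only, not the PR's exact code:

# Illustrative only: a bounded retry around a single key download,
# reusing the bucket, key_name and retries names from the code above.
for attempt in range(retries):
    try:
        content = bucket.Object(key_name).get()['Body'].read()
        return key_name, content
    except botocore.client.ClientError:
        if attempt == retries - 1:
            raise
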