diff --git a/smart_open/doctools.py b/smart_open/doctools.py
new file mode 100644
index 00000000..0f28f356
--- /dev/null
+++ b/smart_open/doctools.py
@@ -0,0 +1,121 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
+#
+# This code is distributed under the terms and conditions from the MIT License (MIT).
+#
+"""Common functions for working with docstrings.
+
+For internal use only.
+"""
+import inspect
+import io
+
+
+def extract_kwargs(docstring):
+    """Extract keyword argument documentation from a function's docstring.
+
+    Parameters
+    ----------
+    docstring: str
+        The docstring to extract keyword arguments from.
+
+    Returns
+    -------
+    list of (str, str, list of str)
+
+        str
+            The name of the keyword argument.
+        str
+            Its type.
+        list of str
+            Its documentation as a list of lines.
+
+    Notes
+    -----
+    The implementation is rather fragile.  It expects the following:
+
+    1. The parameters are under an underlined Parameters section.
+    2. Keyword parameters have the literal ", optional" after the type.
+    3. Names and types are not indented.
+    4. Descriptions are indented with 4 spaces.
+    5. The Parameters section ends with an empty line.
+
+    Examples
+    --------
+
+    >>> docstring = '''The foo function.
+    ... Parameters
+    ... ----------
+    ... bar: str, optional
+    ...     This parameter is the bar.
+    ... baz: int, optional
+    ...     This parameter is the baz.
+    ...
+    ... '''
+    >>> kwargs = extract_kwargs(docstring)
+    >>> kwargs[0]
+    ('bar', 'str, optional', ['This parameter is the bar.'])
+
+    """
+    lines = inspect.cleandoc(docstring).split('\n')
+    retval = []
+
+    #
+    # 1. Find the underlined 'Parameters' section
+    # 2. Once there, continue parsing parameters until we hit an empty line
+    #
+    while lines[0] != 'Parameters':
+        lines.pop(0)
+    lines.pop(0)
+    lines.pop(0)
+
+    while lines and lines[0]:
+        name, type_ = lines.pop(0).split(':', 1)
+        description = []
+        while lines and lines[0].startswith('    '):
+            description.append(lines.pop(0).strip())
+        if 'optional' in type_:
+            retval.append((name.strip(), type_.strip(), description))
+
+    return retval
+
+
+def to_docstring(kwargs, lpad=u''):
+    """Reconstruct a docstring from keyword argument info.
+
+    Basically reverses :func:`extract_kwargs`.
+
+    Parameters
+    ----------
+    kwargs: list
+        Output from the extract_kwargs function
+    lpad: str, optional
+        Padding string (from the left).
+
+    Returns
+    -------
+    str
+        The docstring snippet documenting the keyword arguments.
+
+    Examples
+    --------
+
+    >>> kwargs = [
+    ...     ('bar', 'str, optional', ['This parameter is the bar.']),
+    ...     ('baz', 'int, optional', ['This parameter is the baz.']),
+    ... ]
+    >>> print(to_docstring(kwargs), end='')
+    bar: str, optional
+        This parameter is the bar.
+    baz: int, optional
+        This parameter is the baz.
+
+    """
+    buf = io.StringIO()
+    for name, type_, description in kwargs:
+        buf.write(u'%s%s: %s\n' % (lpad, name, type_))
+        for line in description:
+            buf.write(u'%s    %s\n' % (lpad, line))
+    return buf.getvalue()
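As a quick sanity check of the extract/reconstruct round trip, the sketch below mirrors the docstring templating that smart_open_lib performs further down in this patch. The template and docstring literals are hypothetical stand-ins; only the two doctools functions are real:

    from smart_open import doctools

    # A %-style template standing in for the real one in smart_open_lib,
    # which uses %(s3)s, %(http)s and %(webhdfs)s placeholders.
    template = 'Transport keyword arguments:\n\n%(http)s'

    docstring = '''Open a URL.
    Parameters
    ----------
    uri: str
        The URL to open.
    user: str, optional
        The username for authenticating over HTTP.

    '''

    # Only the ", optional" parameter survives extraction; to_docstring
    # re-renders it, left-padded to fit the enclosing docstring.
    kwargs = doctools.extract_kwargs(docstring)
    print(template % {'http': doctools.to_docstring(kwargs, lpad='    ')})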
diff --git a/smart_open/hdfs.py b/smart_open/hdfs.py
index 91313af7..742a25b0 100644
--- a/smart_open/hdfs.py
+++ b/smart_open/hdfs.py
@@ -6,6 +6,15 @@
 logger.addHandler(logging.NullHandler())
 
 
+def open(uri, mode):
+    if mode == 'rb':
+        return CliRawInputBase(uri)
+    elif mode == 'wb':
+        return CliRawOutputBase(uri)
+    else:
+        raise NotImplementedError('hdfs support for mode %r not implemented' % mode)
+
+
 class CliRawInputBase(io.RawIOBase):
     """Reads bytes from HDFS via the "hdfs dfs" command-line interface.
diff --git a/smart_open/http.py b/smart_open/http.py
index 8b93c11b..a3895382 100644
--- a/smart_open/http.py
+++ b/smart_open/http.py
@@ -23,20 +23,37 @@
 """
 
 
-class BufferedInputBase(io.BufferedIOBase):
-    """
-    Implement streamed reader from a web site.
+def open(uri, mode, kerberos=False, user=None, password=None):
+    """Implement a streamed reader from a web site.
 
+    Supports Kerberos and Basic HTTP authentication.
+
+    Parameters
+    ----------
+    uri: str
+        The URL to open.
+    mode: str
+        The mode to open using.
+    kerberos: boolean, optional
+        If True, will attempt to use the local Kerberos credentials.
+    user: str, optional
+        The username for authenticating over HTTP.
+    password: str, optional
+        The password for authenticating over HTTP.
+
+    Note
+    ----
+    If neither kerberos nor (user, password) is set, will connect unauthenticated.
+
+    """
+    if mode == 'rb':
+        return BufferedInputBase(uri, mode, kerberos=kerberos, user=user, password=password)
+    else:
+        raise NotImplementedError('http support for mode %r not implemented' % mode)
 
-    def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
-                 kerberos=False, user=None, password=None):
-        """
-        If Kerberos is True, will attempt to use the local Kerberos credentials.
-        Otherwise, will try to use "basic" HTTP authentication via username/password.
-        If none of those are set, will connect unauthenticated.
-        """
+
+class BufferedInputBase(io.BufferedIOBase):
+    def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, kerberos=False, user=None, password=None):
         if kerberos:
             import requests_kerberos
             auth = requests_kerberos.HTTPKerberosAuth()
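For illustration, here is one way the new http.open entry point can be reached through the top-level API. A minimal sketch: the URL and credentials are placeholders, and a reachable server plus the requests package are assumed:

    from smart_open import open

    # Basic HTTP auth: 'user' and 'password' travel through the transport
    # keyword arguments (tkwa) and into smart_open.http.open.
    with open('https://example.com/data.txt', 'rb',
              tkwa={'user': 'alice', 'password': 'secret'}) as fin:
        print(fin.read(64))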
+ + """ logger.debug('%r', locals()) if mode not in MODES: raise NotImplementedError('bad mode: %r expected one of %r' % (mode, MODES)) - encoding = kwargs.pop("encoding", "utf-8") - errors = kwargs.pop("errors", None) - newline = kwargs.pop("newline", None) - line_buffering = kwargs.pop("line_buffering", False) - s3_min_part_size = kwargs.pop("s3_min_part_size", DEFAULT_MIN_PART_SIZE) - if mode == READ_BINARY: - fileobj = SeekableBufferedInputBase(bucket_id, key_id, **kwargs) + fileobj = SeekableBufferedInputBase( + bucket_id, + key_id, + buffer_size=buffer_size, + session=session, + resource_kwargs=resource_kwargs, + ) elif mode == WRITE_BINARY: - fileobj = BufferedOutputBase(bucket_id, key_id, min_part_size=s3_min_part_size, **kwargs) + fileobj = BufferedOutputBase( + bucket_id, + key_id, + min_part_size=min_part_size, + session=session, + multipart_upload_kwargs=multipart_upload_kwargs, + resource_kwargs=resource_kwargs, + ) else: assert False, 'unexpected mode: %r' % mode @@ -143,12 +181,10 @@ def read(self, size=-1): class BufferedInputBase(io.BufferedIOBase): def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=BINARY_NEWLINE, **kwargs): - session = kwargs.pop( - 's3_session', - boto3.Session(profile_name=kwargs.pop('profile_name', None)) - ) - s3 = session.resource('s3', **kwargs) + line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=dict()): + if session is None: + session = boto3.Session() + s3 = session.resource('s3', **resource_kwargs) self._object = s3.Object(bucket, key) self._raw_reader = RawReader(self._object) self._content_length = self._object.content_length @@ -284,12 +320,10 @@ class SeekableBufferedInputBase(BufferedInputBase): Implements the io.BufferedIOBase interface of the standard library.""" def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=BINARY_NEWLINE, **kwargs): - session = kwargs.pop( - 's3_session', - boto3.Session(profile_name=kwargs.pop('profile_name', None)) - ) - s3 = session.resource('s3', **kwargs) + line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=dict()): + if session is None: + session = boto3.Session() + s3 = session.resource('s3', **resource_kwargs) self._object = s3.Object(bucket, key) self._raw_reader = SeekableRawReader(self._object) self._content_length = self._object.content_length @@ -350,16 +384,24 @@ class BufferedOutputBase(io.BufferedIOBase): Implements the io.BufferedIOBase interface of the standard library.""" - def __init__(self, bucket, key, min_part_size=DEFAULT_MIN_PART_SIZE, s3_upload=None, **kwargs): + def __init__( + self, + bucket, + key, + min_part_size=DEFAULT_MIN_PART_SIZE, + s3_upload=None, + session=None, + resource_kwargs=dict(), + multipart_upload_kwargs=dict(), + ): if min_part_size < MIN_MIN_PART_SIZE: logger.warning("S3 requires minimum part size >= 5MB; \ multipart upload may fail") - session = kwargs.pop( - 's3_session', - boto3.Session(profile_name=kwargs.pop('profile_name', None)) - ) - s3 = session.resource('s3', **kwargs) + if session is None: + session = boto3.Session() + + s3 = session.resource('s3', **resource_kwargs) # # https://stackoverflow.com/questions/26871884/how-can-i-easily-determine-if-a-boto-3-s3-bucket-resource-exists @@ -370,7 +412,7 @@ def __init__(self, bucket, key, min_part_size=DEFAULT_MIN_PART_SIZE, s3_upload=N raise ValueError('the bucket %r does not exist, or is forbidden for access' % bucket) self._object = s3.Object(bucket, key) self._min_part_size = min_part_size - self._mp = 
+        self._mp = self._object.initiate_multipart_upload(**multipart_upload_kwargs)
         self._buf = io.BytesIO()
         self._total_bytes = 0
 
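The refactored smart_open.s3.open can also be called directly. A minimal sketch, assuming valid AWS credentials and an existing bucket (the profile, bucket, key and endpoint below are placeholders):

    import boto3
    import smart_open.s3

    session = boto3.Session(profile_name='default')
    with smart_open.s3.open(
            'mybucket',
            'mykey.txt',
            'wb',
            session=session,
            resource_kwargs={'endpoint_url': 'https://s3.us-east-1.amazonaws.com'},
            multipart_upload_kwargs={'ContentType': 'text/plain'},
            ) as fout:
        fout.write(b'hello world!\n')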
diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py
index a70bcac2..6aac7730 100644
--- a/smart_open/smart_open_lib.py
+++ b/smart_open/smart_open_lib.py
@@ -8,15 +8,15 @@
 
 """
-Utilities for streaming from several file-like data storages: S3 / HDFS / standard
-filesystem / compressed files..., using a single, Pythonic API.
+Utilities for streaming to/from several file-like data storages: S3 / HDFS / local
+filesystem / compressed files, and many more, using a simple, Pythonic API.
 
 The streaming makes heavy use of generators and pipes, to avoid loading full
 file contents into memory, allowing work with arbitrarily large files.
 
-The main methods are:
+The main functions are:
 
-* `smart_open()`, which opens the given file for reading/writing
+* `open()`, which opens the given file for reading/writing
 * `s3_iter_bucket()`, which goes over all keys in an S3 bucket in parallel
 * `register_compressor()`, which registers callbacks for transparent compressor handling
 
@@ -25,6 +25,7 @@
 import codecs
 import collections
 import logging
+import inspect
 import os
 import os.path as P
 import importlib
@@ -41,8 +42,9 @@
 except ImportError:
     PATHLIB_SUPPORT = False
 
+import boto
+import boto3
 from boto.compat import BytesIO, urlsplit, six
-import boto.s3.key
 import six
 from six.moves.urllib import parse as urlparse
 import sys
@@ -59,20 +61,14 @@
 import smart_open.http as smart_open_http
 import smart_open.ssh as smart_open_ssh
 
+from smart_open import doctools
+
 logger = logging.getLogger(__name__)
 
 SYSTEM_ENCODING = sys.getdefaultencoding()
 
-_ISSUE_146_FSTR = (
-    "You have explicitly specified encoding=%(encoding)s, but smart_open does "
-    "not currently support decoding text via the %(scheme)s scheme. "
-    "Re-open the file without specifying an encoding to suppress this warning."
-)
 _ISSUE_189_URL = 'https://github.com/RaRe-Technologies/smart_open/issues/189'
 
-DEFAULT_ERRORS = 'strict'
-
 _COMPRESSOR_REGISTRY = {}
 
@@ -164,68 +160,161 @@ def _handle_xz(file_obj, mode):
 Uri.__new__.__defaults__ = (None,) * len(Uri._fields)
 
 
-def smart_open(uri, mode="rb", **kw):
+def _inspect_kwargs(kallable):
+    args, varargs, keywords, defaults = inspect.getargspec(kallable)
+    if not defaults:
+        return {}
+    supported_keywords = args[-len(defaults):]
+    return dict(zip(supported_keywords, defaults))
+
+
+def _check_kwargs(kallable, kwargs):
+    """Check which keyword arguments the callable supports.
+
+    Parameters
+    ----------
+    kallable: callable
+        A function or method to test
+    kwargs: dict
+        The keyword arguments to check.  If the callable doesn't support any
+        of these, a warning message will get printed.
+
+    Returns
+    -------
+    dict
+        A dictionary of argument names and values supported by the callable.
+    """
-    Open the given S3 / HDFS / filesystem file pointed to by `uri` for reading or writing.
+    supported_keywords = sorted(_inspect_kwargs(kallable))
+    unsupported_keywords = [k for k in sorted(kwargs) if k not in supported_keywords]
+    supported_kwargs = {k: v for (k, v) in kwargs.items() if k in supported_keywords}
 
-    The only supported modes for now are 'rb' (read, default) and 'wb' (replace & write).
+    if unsupported_keywords:
+        logger.warning('ignoring unsupported keyword arguments: %r', unsupported_keywords)
 
-    The reads/writes are memory efficient (streamed) and therefore suitable for
-    arbitrarily large files.
+    return supported_kwargs
 
-    The `uri` can be either:
 
-    1. a URI for the local filesystem (compressed ``.gz``, ``.bz2`` or ``.xz`` files handled
-    automatically): `./lines.txt`, `/home/joe/lines.txt.gz`, `file:///home/joe/lines.txt.bz2`
+_builtin_open = open
+
+
+def open(
+        uri,
+        mode='r',
+        buffering=-1,
+        encoding=None,
+        errors=None,
+        newline=None,
+        closefd=True,
+        opener=None,
+        ignore_ext=False,
+        tkwa=dict(),
+        ):
+    """Open the URI object, returning a file-like object.
+
+    The URI is usually a string in a variety of formats:
+
+    1. a URI for the local filesystem: `./lines.txt`, `/home/joe/lines.txt.gz`, `file:///home/joe/lines.txt.bz2`
     2. a URI for HDFS: `hdfs:///some/path/lines.txt`
     3. a URI for Amazon's S3 (can also supply credentials inside the URI):
        `s3://my_bucket/lines.txt`, `s3://my_aws_key_id:key_secret@my_bucket/lines.txt`
-    4. an instance of the boto.s3.key.Key class.
-    5. an instance of the pathlib.Path class.
-
-    Examples::
-
-    >>> # stream lines from http; you can use context managers too:
-    >>> with smart_open.smart_open('http://www.google.com') as fin:
-    ...     for line in fin:
-    ...         print line
-
-    >>> # stream lines from S3; you can use context managers too:
-    >>> with smart_open.smart_open('s3://mybucket/mykey.txt') as fin:
-    ...     for line in fin:
-    ...         print line
-
-    >>> # you can also use a boto.s3.key.Key instance directly:
-    >>> key = boto.connect_s3().get_bucket("my_bucket").get_key("my_key")
-    >>> with smart_open.smart_open(key) as fin:
-    ...     for line in fin:
-    ...         print line
-
-    >>> # stream line-by-line from an HDFS file
-    >>> for line in smart_open.smart_open('hdfs:///user/hadoop/my_file.txt'):
-    ...     print line
-
-    >>> # stream content *into* S3:
-    >>> with smart_open.smart_open('s3://mybucket/mykey.txt', 'wb') as fout:
-    ...     for line in ['first line', 'second line', 'third line']:
-    ...         fout.write(line + '\n')
-
-    >>> # stream from/to (compressed) local files:
-    >>> for line in smart_open.smart_open('/home/radim/my_file.txt'):
-    ...     print line
-    >>> for line in smart_open.smart_open('/home/radim/my_file.txt.gz'):
-    ...     print line
-    >>> with smart_open.smart_open('/home/radim/my_file.txt.gz', 'wb') as fout:
-    ...     fout.write("hello world!\n")
-    >>> with smart_open.smart_open('/home/radim/another.txt.bz2', 'wb') as fout:
-    ...     fout.write("good bye!\n")
-    >>> with smart_open.smart_open('/home/radim/another.txt.xz', 'wb') as fout:
-    ...     fout.write("never say never!\n")
-    >>> # stream from/to (compressed) local files with Expand ~ and ~user constructions:
-    >>> for line in smart_open.smart_open('~/my_file.txt'):
-    ...     print line
-    >>> for line in smart_open.smart_open('my_file.txt'):
-    ...     print line
+
+    The URI may also be one of:
+
+    - an instance of the pathlib.Path class
+    - a stream (anything that implements io.IOBase-like functionality)
+
+    This function supports transparent compression and decompression using the
+    following codecs:
+
+    - ``.gz``
+    - ``.bz2``
+    - ``.xz``
+
+    The function depends on the file extension to determine the appropriate codec.
+
+    Parameters
+    ----------
+    uri: str or object
+        The object to open.
+    mode: str, optional
+        Mimics the built-in open parameter of the same name.
+    buffering: int, optional
+        Mimics the built-in open parameter of the same name.
+    encoding: str, optional
+        Mimics the built-in open parameter of the same name.
+    errors: str, optional
+        Mimics the built-in open parameter of the same name.
+    newline: str, optional
+        Mimics the built-in open parameter of the same name.
+    closefd: boolean, optional
+        Mimics the built-in open parameter of the same name.  Ignored.
+    opener: object, optional
+        Mimics the built-in open parameter of the same name.  Ignored.
+    ignore_ext: boolean, optional
+        Disable transparent compression/decompression based on the file extension.
+    tkwa: dict, optional
+        Keyword arguments for the transport layer (see notes below).
+
+    Returns
+    -------
+    A file-like object.
+
+    Notes
+    -----
+    smart_open has several implementations for its transport layer (e.g. S3, HTTP).
+    Each transport layer has a different set of keyword arguments for overriding
+    default behavior.  If you specify a keyword argument that is *not* supported
+    by the transport layer being used, smart_open will ignore that argument and
+    log a warning message.
+
+    S3 (for details, see :mod:`smart_open.s3` and :func:`smart_open.s3.open`):
+
+%(s3)s
+    HTTP (for details, see :mod:`smart_open.http` and :func:`smart_open.http.open`):
+
+%(http)s
+    WebHDFS (for details, see :mod:`smart_open.webhdfs` and :func:`smart_open.webhdfs.open`):
+
+%(webhdfs)s
+
+    Examples
+    --------
+    >>> from smart_open import open
+    >>> # stream lines from http; you can use context managers too:
+    >>> with open('http://www.google.com') as fin:
+    ...     for line in fin:
+    ...         print(line)
+
+    >>> # stream lines from S3; you can use context managers too:
+    >>> with open('s3://mybucket/mykey.txt') as fin:
+    ...     for line in fin:
+    ...         print(line)
+
+    >>> # stream line-by-line from an HDFS file
+    >>> for line in open('hdfs:///user/hadoop/my_file.txt'):
+    ...     print(line)
+
+    >>> # stream content *into* S3:
+    >>> with open('s3://mybucket/mykey.txt', 'wb') as fout:
+    ...     for line in [b'first line', b'second line', b'third line']:
+    ...         fout.write(line + b'\n')
+
+    >>> # stream from/to (compressed) local files:
+    >>> for line in open('/home/radim/my_file.txt'):
+    ...     print(line)
+    >>> for line in open('/home/radim/my_file.txt.gz'):
+    ...     print(line)
+    >>> with open('/home/radim/my_file.txt.gz', 'wb') as fout:
+    ...     fout.write(b"hello world!\n")
+    >>> with open('/home/radim/another.txt.bz2', 'wb') as fout:
+    ...     fout.write(b"good bye!\n")
+    >>> with open('/home/radim/another.txt.xz', 'wb') as fout:
+    ...     fout.write(b"never say never!\n")
+    >>> # stream from/to (compressed) local files, expanding ~ and ~user constructions:
+    >>> for line in open('~/my_file.txt'):
+    ...     print(line)
+    >>> for line in open('my_file.txt'):
+    ...     print(line)
 
     """
     logger.debug('%r', locals())
@@ -233,7 +322,14 @@
     if not isinstance(mode, six.string_types):
         raise TypeError('mode should be a string')
 
-    fobj = _shortcut_open(uri, mode, **kw)
+    fobj = _shortcut_open(
+        uri,
+        mode,
+        ignore_ext=ignore_ext,
+        buffering=buffering,
+        encoding=encoding,
+        errors=errors,
+    )
 
     if fobj is not None:
         return fobj
@@ -245,25 +341,15 @@
     # If we change the default mode to be text, and match the normal behavior
     # of Py2 and 3, then the above assumption will be unnecessary.
     #
-    if kw.get('encoding') is not None and 'b' in mode:
+    if encoding is not None and 'b' in mode:
         mode = mode.replace('b', '')
 
     # Support opening ``pathlib.Path`` objects by casting them to strings.
    if PATHLIB_SUPPORT and isinstance(uri, pathlib.Path):
        uri = str(uri)
 
-    #
-    # Our API is very liberal with keyword arguments, making it a bit hard to
-    # manage them.  Capture the keyword arguments we'll be using in this
-    # function in advance to reduce the confusion in downstream functions.
-    #
-    # explicit_encoding is what we've been explicitly told to use.  encoding is
-    # what we'll actually end up using.  The two may be different if the user
-    # didn't actually specify the encoding.
-    #
-    ignore_extension = kw.pop('ignore_extension', False)
-    explicit_encoding = kw.get('encoding', None)
-    encoding = kw.pop('encoding', SYSTEM_ENCODING)
+    explicit_encoding = encoding
+    encoding = explicit_encoding if explicit_encoding else SYSTEM_ENCODING
 
     #
     # This is how we get from the filename to the end result.  Decompression is
@@ -283,14 +369,13 @@
             'a': 'ab', 'a+': 'ab+'}[mode]
     except KeyError:
         binary_mode = mode
-    binary, filename = _open_binary_stream(uri, binary_mode, **kw)
-    if ignore_extension:
+    binary, filename = _open_binary_stream(uri, binary_mode, tkwa)
+    if ignore_ext:
         decompressed = binary
     else:
         decompressed = _compression_wrapper(binary, filename, mode)
 
     if 'b' not in mode or explicit_encoding is not None:
-        errors = kw.pop('errors', 'strict')
         decoded = _encoding_wrapper(decompressed, mode, encoding=encoding, errors=errors)
     else:
         decoded = decompressed
@@ -298,7 +383,60 @@
     return decoded
 
 
+#
+# Inject transport keyword argument documentation into the docstring.
+#
+open.__doc__ = open.__doc__ % {
+    's3': doctools.to_docstring(
+        doctools.extract_kwargs(smart_open_s3.open.__doc__),
+        lpad=u'    ',
+    ),
+    'http': doctools.to_docstring(
+        doctools.extract_kwargs(smart_open_http.open.__doc__),
+        lpad=u'    ',
+    ),
+    'webhdfs': doctools.to_docstring(
+        doctools.extract_kwargs(smart_open_webhdfs.open.__doc__),
+        lpad=u'    ',
+    ),
+}
+
+
+def smart_open(uri, mode="rb", **kw):
+    """Deprecated, use smart_open.open instead."""
+    logger.warning('this function is deprecated, use smart_open.open instead')
+
+    #
+    # The new function uses a shorter name for this parameter, handle it separately.
+    #
+    ignore_extension = kw.pop('ignore_extension', False)
+
+    expected_kwargs = _inspect_kwargs(open)
+    scrubbed_kwargs = {}
+    tkwa = {}
+    for key, value in kw.items():
+        if key in expected_kwargs:
+            scrubbed_kwargs[key] = value
+        else:
+            #
+            # Assume that anything not explicitly supported by the new function
+            # is a transport layer keyword argument.  This is safe, because if
+            # the argument ends up being unsupported in the transport layer,
+            # it will only cause a logging warning, not a crash.
+            #
+            tkwa[key] = value
+
+    return open(uri, mode, ignore_ext=ignore_extension, tkwa=tkwa, **scrubbed_kwargs)
+
+
+def _shortcut_open(
+        uri,
+        mode,
+        ignore_ext=False,
+        buffering=-1,
+        encoding=None,
+        errors=None,
+        ):
     """Try to open the URI using the standard library io.open function.
     This can be much faster than the alternative of opening in binary mode and
@@ -325,27 +463,21 @@
         return None
 
     _, extension = P.splitext(parsed_uri.uri_path)
-    ignore_extension = kw.get('ignore_extension', False)
-    if extension in _COMPRESSOR_REGISTRY and not ignore_extension:
+    if extension in _COMPRESSOR_REGISTRY and not ignore_ext:
         return None
 
-    #
-    # https://docs.python.org/2/library/functions.html#open
-    #
-    # buffering: 0: off; 1: on; negative number: use system default
-    #
-    buffering = kw.get('buffering', -1)
-
     open_kwargs = {}
-    errors = kw.get('errors')
-    if errors is not None:
-        open_kwargs['errors'] = errors
 
-    encoding = kw.get('encoding')
     if encoding is not None:
         open_kwargs['encoding'] = encoding
         mode = mode.replace('b', '')
 
+    #
+    # binary mode of the builtin/stdlib open function doesn't take an errors argument
+    #
+    if errors and 'b' not in mode:
+        open_kwargs['errors'] = errors
+
     #
     # Under Py3, the built-in open accepts kwargs, and it's OK to use that.
     # Under Py2, the built-in open _doesn't_ accept kwargs, but we still use it
@@ -353,20 +485,20 @@
     #  kwargs, then we have no option other to use io.open.
     #
     if six.PY3:
-        return open(parsed_uri.uri_path, mode, buffering=buffering, **open_kwargs)
+        return _builtin_open(parsed_uri.uri_path, mode, buffering=buffering, **open_kwargs)
     elif not open_kwargs:
-        return open(parsed_uri.uri_path, mode, buffering=buffering)
+        return _builtin_open(parsed_uri.uri_path, mode, buffering=buffering)
     return io.open(parsed_uri.uri_path, mode, buffering=buffering, **open_kwargs)
 
 
-def _open_binary_stream(uri, mode, **kw):
+def _open_binary_stream(uri, mode, tkwa):
     """Open an arbitrary URI in the specified binary mode.
 
     Not all modes are supported for all protocols.
 
     :arg uri: The URI to open.  May be a string, or something else.
     :arg str mode: The mode to open with.  Must be rb, wb or ab.
-    :arg kw: TODO: document this.
+    :arg tkwa: Keyword arguments for the transport layer.
     :returns: A file object and the filename
     :rtype: tuple
     """
@@ -384,9 +516,7 @@
     parsed_uri = _parse_uri(uri)
     unsupported = "%r mode not supported for %r scheme" % (mode, parsed_uri.scheme)
 
-    if parsed_uri.scheme in ("file", ):
-        # local files -- both read & write supported
-        # compression, if any, is determined by the filename extension (.gz, .bz2, .xz)
+    if parsed_uri.scheme == "file":
         fobj = io.open(parsed_uri.uri_path, mode)
         return fobj, filename
     elif parsed_uri.scheme in smart_open_ssh.SCHEMES:
@@ -399,43 +529,23 @@
         )
         return fobj, filename
     elif parsed_uri.scheme in smart_open_s3.SUPPORTED_SCHEMES:
-        return _s3_open_uri(parsed_uri, mode, **kw), filename
-    elif parsed_uri.scheme in ("hdfs", ):
-        if mode == 'rb':
-            return smart_open_hdfs.CliRawInputBase(parsed_uri.uri_path), filename
-        elif mode == 'wb':
-            return smart_open_hdfs.CliRawOutputBase(parsed_uri.uri_path), filename
-        else:
-            raise NotImplementedError(unsupported)
-    elif parsed_uri.scheme in ("webhdfs", ):
-        if mode == 'rb':
-            fobj = smart_open_webhdfs.BufferedInputBase(parsed_uri.uri_path, **kw)
-        elif mode == 'wb':
-            fobj = smart_open_webhdfs.BufferedOutputBase(parsed_uri.uri_path, **kw)
-        else:
-            raise NotImplementedError(unsupported)
-        return fobj, filename
+        return _s3_open_uri(parsed_uri, mode, tkwa), filename
+    elif parsed_uri.scheme == "hdfs":
+        _check_kwargs(smart_open_hdfs.open, tkwa)
+        return smart_open_hdfs.open(parsed_uri.uri_path, mode), filename
+    elif parsed_uri.scheme == "webhdfs":
+        kw = _check_kwargs(smart_open_webhdfs.open, tkwa)
+        return smart_open_webhdfs.open(parsed_uri.uri_path, mode, **kw), filename
     elif parsed_uri.scheme.startswith('http'):
        #
        # The URI may contain a query string and fragments, which interfere
-        # with out compressed/uncompressed estimation.
+        # with our compressed/uncompressed estimation, so we strip them.
        #
        filename = P.basename(urlparse.urlparse(uri).path)
-        if mode == 'rb':
-            return smart_open_http.SeekableBufferedInputBase(uri, **kw), filename
-        else:
-            raise NotImplementedError(unsupported)
+        kw = _check_kwargs(smart_open_http.open, tkwa)
+        return smart_open_http.open(uri, mode, **kw), filename
     else:
         raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme)
-    elif isinstance(uri, boto.s3.key.Key):
-        logger.debug('%r', locals())
-        #
-        # TODO: handle boto3 keys as well
-        #
-        host = kw.pop('host', None)
-        if host is not None:
-            kw['endpoint_url'] = _add_scheme_to_host(host)
-        return smart_open_s3.open(uri.bucket.name, uri.name, mode, **kw), uri.name
     elif hasattr(uri, 'read'):
         # simply pass-through if already a file-like
         # we need to return something as the file name, but we don't know what
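The dispatch above never crashes on irrelevant transport parameters: each branch filters tkwa through _check_kwargs first. A sketch of the resulting behavior, assuming a working hdfs CLI; the path is a placeholder:

    import smart_open

    # hdfs.open() accepts no keyword arguments, so 'min_part_size' is
    # dropped, with a logged warning along the lines of:
    #   ignoring unsupported keyword arguments: ['min_part_size']
    fin = smart_open.open('hdfs:///user/hadoop/my_file.txt', 'rb',
                          tkwa={'min_part_size': 128 * 1024**2})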
@@ -448,23 +558,36 @@
         raise TypeError("don't know how to handle uri %r" % uri)
 
 
-def _s3_open_uri(parsed_uri, mode, **kwargs):
+def _s3_open_uri(parsed_uri, mode, tkwa):
     logger.debug('s3_open_uri: %r', locals())
     if mode in ('r', 'w'):
         raise ValueError('this function can only open binary streams. '
                          'Use smart_open.smart_open() to open text streams.')
     elif mode not in ('rb', 'wb'):
         raise NotImplementedError('unsupported mode: %r', mode)
 
-    if parsed_uri.access_id is not None:
-        kwargs['aws_access_key_id'] = parsed_uri.access_id
-    if parsed_uri.access_secret is not None:
-        kwargs['aws_secret_access_key'] = parsed_uri.access_secret
-
-    # Get an S3 host. It is required for sigv4 operations.
-    host = kwargs.pop('host', None)
-    if host is not None:
-        kwargs['endpoint_url'] = _add_scheme_to_host(host)
+    #
+    # There are two explicit ways we can receive session parameters from the user.
+    #
+    # 1. Via the session keyword argument (tkwa)
+    # 2. Via the URI itself
+    #
+    # They are not mutually exclusive, but we have to pick one of the two.
+    # Go with 1).
+    #
+    if tkwa.get('session') is not None and (parsed_uri.access_id or parsed_uri.access_secret):
+        logger.warning(
+            'ignoring credentials parsed from URL because '
+            'they conflict with tkwa["session"]. Set tkwa["session"] to None '
+            'to suppress this warning.'
+        )
+    elif (parsed_uri.access_id and parsed_uri.access_secret):
+        tkwa = dict(tkwa, session=boto3.Session(
+            aws_access_key_id=parsed_uri.access_id,
+            aws_secret_access_key=parsed_uri.access_secret,
+        ))
 
+    kwargs = _check_kwargs(smart_open_s3.open, tkwa)
     return smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs)
@@ -679,7 +802,7 @@ def _compression_wrapper(file_obj, filename, mode):
     return callback(file_obj, mode)
 
 
-def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
+def _encoding_wrapper(fileobj, mode, encoding=None, errors=None):
     """Decode bytes into text, if necessary.
 
     If mode specifies binary access, does nothing, unless the encoding is
@@ -708,13 +831,9 @@ def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
     if encoding is None:
         encoding = SYSTEM_ENCODING
 
+    kw = {'errors': errors} if errors else {}
     if mode[0] == 'r' or mode.endswith('+'):
-        fileobj = codecs.getreader(encoding)(fileobj, errors=errors)
+        fileobj = codecs.getreader(encoding)(fileobj, **kw)
     if mode[0] in ('w', 'a') or mode.endswith('+'):
-        fileobj = codecs.getwriter(encoding)(fileobj, errors=errors)
+        fileobj = codecs.getwriter(encoding)(fileobj, **kw)
     return fileobj
-
-
-def _add_scheme_to_host(host):
-    if host.startswith('http://') or host.startswith('https://'):
-        return host
-    return 'http://' + host
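Putting the S3 pieces together, a hedged sketch of the two session-passing styles handled by _s3_open_uri; bucket, key, profile and credential values are placeholders:

    import boto3
    from smart_open import open

    # 1. Pass an explicit boto3 session via tkwa; this takes precedence
    #    over any credentials embedded in the URI.
    session = boto3.Session(profile_name='my_credentials')
    with open('s3://mybucket/mykey.txt', 'rb', tkwa={'session': session}) as fin:
        for line in fin:
            print(line)

    # 2. Or supply the credentials inside the URI itself.
    with open('s3://access_id:access_secret@mybucket/mykey.txt', 'rb') as fin:
        first_bytes = fin.read(100)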
diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py
index 915ad57e..2783b71b 100644
--- a/smart_open/tests/test_smart_open.py
+++ b/smart_open/tests/test_smart_open.py
@@ -391,7 +391,7 @@ def test_open_side_effect(self):
 # See the _shortcut_open function for details.
 #
 _IO_OPEN = 'io.open'
-_BUILTIN_OPEN = 'smart_open.smart_open_lib.open'
+_BUILTIN_OPEN = 'smart_open.smart_open_lib._builtin_open'
 
 
 class SmartOpenReadTest(unittest.TestCase):
@@ -402,7 +402,7 @@ class SmartOpenReadTest(unittest.TestCase):
 
     def test_shortcut(self):
         fpath = os.path.join(CURR_DIR, 'test_data/crime-and-punishment.txt')
-        with mock.patch('smart_open.smart_open_lib.open') as mock_open:
+        with mock.patch('smart_open.smart_open_lib._builtin_open') as mock_open:
             smart_open.smart_open(fpath, 'r').read()
         mock_open.assert_called_with(fpath, 'r', buffering=-1)
@@ -510,7 +510,7 @@ def test_s3_iter_lines(self):
         self.assertEqual(b''.join(output), test_string)
 
     # TODO: add more complex test for file://
-    @mock.patch('smart_open.smart_open_lib.open')
+    @mock.patch('smart_open.smart_open_lib._builtin_open')
     def test_file(self, mock_smart_open):
         """Is file:// line iterator called correctly?"""
         prefix = "file://"
@@ -687,37 +687,37 @@ class SmartOpenS3KwargsTest(unittest.TestCase):
     @mock.patch('boto3.Session')
     def test_no_kwargs(self, mock_session):
         smart_open.smart_open('s3://mybucket/mykey')
-        mock_session.assert_called_with(profile_name=None)
         mock_session.return_value.resource.assert_called_with('s3')
 
     @mock.patch('boto3.Session')
     def test_credentials(self, mock_session):
         smart_open.smart_open('s3://access_id:access_secret@mybucket/mykey')
-        mock_session.assert_called_with(profile_name=None)
-        mock_session.return_value.resource.assert_called_with(
-            's3', aws_access_key_id='access_id', aws_secret_access_key='access_secret'
-        )
-
-    @mock.patch('boto3.Session')
-    def test_profile(self, mock_session):
-        smart_open.smart_open('s3://mybucket/mykey', profile_name='my_credentials')
-        mock_session.assert_called_with(profile_name='my_credentials')
+        mock_session.assert_called_with(aws_access_key_id='access_id', aws_secret_access_key='access_secret')
         mock_session.return_value.resource.assert_called_with('s3')
 
     @mock.patch('boto3.Session')
     def test_host(self, mock_session):
-        smart_open.smart_open("s3://access_id:access_secret@mybucket/mykey", host='aa.domain.com')
+        tkwa = {'resource_kwargs': {'endpoint_url': 'http://aa.domain.com'}}
+        smart_open.open("s3://access_id:access_secret@mybucket/mykey", tkwa=tkwa)
+        mock_session.assert_called_with(
+            aws_access_key_id='access_id',
+            aws_secret_access_key='access_secret',
+        )
         mock_session.return_value.resource.assert_called_with(
-            's3', aws_access_key_id='access_id', aws_secret_access_key='access_secret',
-            endpoint_url='http://aa.domain.com'
+            's3',
+            endpoint_url='http://aa.domain.com',
         )
 
     @mock.patch('boto3.Session')
     def test_s3_upload(self, mock_session):
-        smart_open.smart_open("s3://bucket/key", 'wb', s3_upload={
-            'ServerSideEncryption': 'AES256',
-            'ContentType': 'application/json'
-        })
+        smart_open.open(
+            "s3://bucket/key", 'wb', tkwa={
+                'multipart_upload_kwargs': {
+                    'ServerSideEncryption': 'AES256',
+                    'ContentType': 'application/json',
+                }
+            }
+        )
 
         # Locate the s3.Object instance (mock)
         s3_resource = mock_session.return_value.resource.return_value
@@ -737,7 +737,7 @@ def test_session_read_mode(self):
         session = boto3.Session()
         session.resource = mock.MagicMock()
 
-        smart_open.smart_open('s3://bucket/key', s3_session=session)
+        smart_open.open('s3://bucket/key', tkwa={'session': session})
         session.resource.assert_called_with('s3')
 
     def test_session_write_mode(self):
@@ -747,7 +747,7 @@
         session = boto3.Session()
         session.resource = mock.MagicMock()
 
-        smart_open.smart_open('s3://bucket/key', 'wb', s3_session=session)
+        smart_open.open('s3://bucket/key', 'wb', tkwa={'session': session})
         session.resource.assert_called_with('s3')
@@ -834,7 +834,8 @@ def test_s3_mode_mock(self, mock_session):
         """Are s3:// open modes passed correctly?"""
 
         # correct write mode, correct s3 URI
-        smart_open.smart_open("s3://mybucket/mykey", "w", host='s3.amazonaws.com')
+        tkwa = {'resource_kwargs': {'endpoint_url': 'http://s3.amazonaws.com'}}
+        smart_open.open("s3://mybucket/mykey", "w", tkwa=tkwa)
         mock_session.return_value.resource.assert_called_with(
             's3', endpoint_url='http://s3.amazonaws.com'
         )
@@ -888,11 +889,13 @@ def test_s3_metadata_write(self):
         s3.create_bucket(Bucket='mybucket')
 
         # Write data, with multipart_upload options
-        write_stream = smart_open.smart_open(
+        write_stream = smart_open.open(
             's3://mybucket/crime-and-punishment.txt.gz', 'wb',
-            s3_upload={
-                'ContentType': 'text/plain',
-                'ContentEncoding': 'gzip'
+            tkwa={
+                'multipart_upload_kwargs': {
+                    'ContentType': 'text/plain',
+                    'ContentEncoding': 'gzip',
+                }
             }
         )
         with write_stream as fout:
@@ -1188,6 +1191,7 @@ def test_rw_gzip(self):
             self.assertEqual(fin.read().decode("utf-8"), text)
 
     @mock_s3
+    @mock.patch('smart_open.smart_open_lib._inspect_kwargs', mock.Mock(return_value={}))
     def test_gzip_write_mode(self):
         """Should always open in binary mode when writing through a codec."""
         s3 = boto3.resource('s3')
@@ -1199,6 +1203,7 @@ def test_gzip_write_mode(self):
             mock_open.assert_called_with('bucket', 'key.gz', 'wb')
 
     @mock_s3
+    @mock.patch('smart_open.smart_open_lib._inspect_kwargs', mock.Mock(return_value={}))
     def test_gzip_read_mode(self):
         """Should always open in binary mode when reading through a codec."""
         s3 = boto3.resource('s3')
@@ -1294,31 +1299,6 @@ def test_write_text_gzip(self):
             actual = fin.read()
             self.assertEqual(text, actual)
 
-
-class HostNameTest(unittest.TestCase):
-
-    def test_host_name_with_http(self):
-        host = 'http://a.com/b'
-        expected = 'http://a.com/b'
-        res = smart_open_lib._add_scheme_to_host(host)
-        self.assertEqual(expected, res)
-
-    def test_host_name_without_http(self):
-        host = 'a.com/b'
-        expected = 'http://a.com/b'
-        res = smart_open_lib._add_scheme_to_host(host)
-        self.assertEqual(expected, res)
-
-    def test_host_name_with_https(self):
-        host = 'https://a.com/b'
-        expected = 'https://a.com/b'
-        res = smart_open_lib._add_scheme_to_host(host)
-        self.assertEqual(expected, res)
-
-    def test_host_name_without_http_prefix(self):
-        host = 'httpa.com/b'
-        expected = 'http://httpa.com/b'
-        res = smart_open_lib._add_scheme_to_host(host)
-        self.assertEqual(expected, res)
 
 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
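The updated tests above also show the migration path for callers. Roughly, as a sketch with placeholder bucket and key names:

    import smart_open

    # Old style: still works, but logs a deprecation warning; the unknown
    # 'min_part_size' kwarg is funneled into tkwa automatically.
    fout = smart_open.smart_open('s3://mybucket/mykey.txt', 'wb',
                                 min_part_size=100 * 1024**2)

    # New style: transport parameters are passed explicitly via tkwa.
    fout = smart_open.open('s3://mybucket/mykey.txt', 'wb',
                           tkwa={'min_part_size': 100 * 1024**2})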
+ + """ + if mode == 'rb': + return BufferedInputBase(uri) + elif mode == 'wb': + return BufferedOutputBase(uri, min_part_size=min_part_size) + else: + raise NotImplementedError('webhdfs support for mode %r not implemented' % mode) + + class BufferedInputBase(io.BufferedIOBase): def __init__(self, uri): self._uri = uri @@ -84,6 +100,13 @@ def readline(self): class BufferedOutputBase(io.BufferedIOBase): def __init__(self, uri_path, min_part_size=WEBHDFS_MIN_PART_SIZE): + """ + Parameters + ---------- + min_part_size: int, optional + For writing only. + + """ self.uri_path = uri_path self._closed = False self.min_part_size = min_part_size