diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index e14f725a..91fe711f 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -22,6 +22,7 @@ """ import codecs +import collections import logging import os import os.path as P @@ -77,6 +78,34 @@ DEFAULT_ERRORS = 'strict' +Uri = collections.namedtuple( + 'Uri', + ( + 'scheme', + 'uri_path', + 'bucket_id', + 'key_id', + 'port', + 'host', + 'ordinary_calling_format', + 'access_id', + 'access_secret', + ) +) +"""Represents all the options that we parse from user input. + +Some of the above options only make sense for certain protocols, e.g. +bucket_id is only for S3. +""" +# +# Set the default values for all Uri fields to be None. This allows us to only +# specify the relevant fields when constructing a Uri. +# +# https://stackoverflow.com/questions/11351032/namedtuple-and-default-values-for-optional-keyword-arguments +# +Uri.__new__.__defaults__ = (None,) * len(Uri._fields) + + def smart_open(uri, mode="rb", **kw): """ Open the given S3 / HDFS / filesystem file pointed to by `uri` for reading or writing. @@ -231,7 +260,7 @@ def _shortcut_open(uri, mode, **kw): if not isinstance(uri, six.string_types): return None - parsed_uri = ParseUri(uri) + parsed_uri = _parse_uri(uri) if parsed_uri.scheme != 'file': return None @@ -275,7 +304,7 @@ def _open_binary_stream(uri, mode, **kw): # this method just routes the request to classes handling the specific storage # schemes, depending on the URI protocol in `uri` filename = uri.split('/')[-1] - parsed_uri = ParseUri(uri) + parsed_uri = _parse_uri(uri) unsupported = "%r mode not supported for %r scheme" % (mode, parsed_uri.scheme) if parsed_uri.scheme in ("file", ): @@ -349,9 +378,9 @@ def _s3_open_uri(parsed_uri, mode, **kwargs): return smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs) -class ParseUri(object): +def _parse_uri(uri_as_string): """ - Parse the given URI. + Parse the given URI from a string. 
Supported URI schemes are "file", "s3", "s3n", "s3u" and "hdfs". @@ -372,82 +401,105 @@ class ParseUri(object): * ./local/path/file.gz * file:///home/user/file * file:///home/user/file.bz2 - """ - def __init__(self, uri, default_scheme="file"): - """ - Assume `default_scheme` if no scheme given in `uri`. - - """ - if os.name == 'nt': - # urlsplit doesn't work on Windows -- it parses the drive as the scheme... - if '://' not in uri: - # no protocol given => assume a local file - uri = 'file://' + uri - parsed_uri = urlsplit(uri, allow_fragments=False) - self.scheme = parsed_uri.scheme if parsed_uri.scheme else default_scheme - - if self.scheme == "hdfs": - self.uri_path = parsed_uri.netloc + parsed_uri.path - self.uri_path = "/" + self.uri_path.lstrip("/") - - if not self.uri_path: - raise RuntimeError("invalid HDFS URI: %s" % uri) - elif self.scheme == "webhdfs": - self.uri_path = parsed_uri.netloc + "/webhdfs/v1" + parsed_uri.path - if parsed_uri.query: - self.uri_path += "?" + parsed_uri.query - - if not self.uri_path: - raise RuntimeError("invalid WebHDFS URI: %s" % uri) - elif self.scheme in ("s3", "s3n", "s3u"): - self.bucket_id = (parsed_uri.netloc + parsed_uri.path).split('@') - self.key_id = None - self.port = 443 - self.host = boto.config.get('s3', 'host', 's3.amazonaws.com') - self.ordinary_calling_format = False - if len(self.bucket_id) == 1: - # URI without credentials: s3://bucket/object - self.bucket_id, self.key_id = self.bucket_id[0].split('/', 1) - # "None" credentials are interpreted as "look for credentials in other locations" by boto - self.access_id, self.access_secret = None, None - elif len(self.bucket_id) == 2 and len(self.bucket_id[0].split(':')) == 2: - # URI in full format: s3://key:secret@bucket/object - # access key id: [A-Z0-9]{20} - # secret access key: [A-Za-z0-9/+=]{40} - acc, self.bucket_id = self.bucket_id - self.access_id, self.access_secret = acc.split(':') - self.bucket_id, self.key_id = self.bucket_id.split('/', 1) - elif 
len(self.bucket_id) == 3 and len(self.bucket_id[0].split(':')) == 2: - # or URI in extended format: s3://key:secret@server[:port]@bucket/object - acc, server, self.bucket_id = self.bucket_id - self.access_id, self.access_secret = acc.split(':') - self.bucket_id, self.key_id = self.bucket_id.split('/', 1) - server = server.split(':') - self.ordinary_calling_format = True - self.host = server[0] - if len(server) == 2: - self.port = int(server[1]) - else: - # more than 2 '@' means invalid uri - # Bucket names must be at least 3 and no more than 63 characters long. - # Bucket names must be a series of one or more labels. - # Adjacent labels are separated by a single period (.). - # Bucket names can contain lowercase letters, numbers, and hyphens. - # Each label must start and end with a lowercase letter or a number. - raise RuntimeError("invalid S3 URI: %s" % uri) - elif self.scheme == 'file': - self.uri_path = parsed_uri.netloc + parsed_uri.path - - # '~/tmp' may be expanded to '/Users/username/tmp' - self.uri_path = os.path.expanduser(self.uri_path) - - if not self.uri_path: - raise RuntimeError("invalid file URI: %s" % uri) - elif self.scheme.startswith('http'): - self.uri_path = uri - else: - raise NotImplementedError("unknown URI scheme %r in %r" % (self.scheme, uri)) + if os.name == 'nt': + # urlsplit doesn't work on Windows -- it parses the drive as the scheme... 
+ if '://' not in uri_as_string: + # no protocol given => assume a local file + uri_as_string = 'file://' + uri_as_string + parsed_uri = urlsplit(uri_as_string, allow_fragments=False) + + if parsed_uri.scheme == "hdfs": + return _parse_uri_hdfs(parsed_uri) + elif parsed_uri.scheme == "webhdfs": + return _parse_uri_webhdfs(parsed_uri) + elif parsed_uri.scheme in ("s3", "s3n", "s3u"): + return _parse_uri_s3x(parsed_uri) + elif parsed_uri.scheme in ('file', '', None): + return _parse_uri_file(parsed_uri) + elif parsed_uri.scheme.startswith('http'): + return Uri(scheme=parsed_uri.scheme, uri_path=uri_as_string) + else: + raise NotImplementedError( + "unknown URI scheme %r in %r" % (parsed_uri.scheme, uri_as_string) + ) + + + def _parse_uri_hdfs(parsed_uri): + assert parsed_uri.scheme == 'hdfs' + uri_path = parsed_uri.netloc + parsed_uri.path + uri_path = "/" + uri_path.lstrip("/") + if not uri_path: + raise RuntimeError("invalid HDFS URI: %s" % str(parsed_uri)) + return Uri(scheme='hdfs', uri_path=uri_path) + + + def _parse_uri_webhdfs(parsed_uri): + assert parsed_uri.scheme == 'webhdfs' + uri_path = parsed_uri.netloc + "/webhdfs/v1" + parsed_uri.path + if parsed_uri.query: + uri_path += "?" 
+ parsed_uri.query + if not uri_path: + raise RuntimeError("invalid WebHDFS URI: %s" % str(parsed_uri)) + return Uri(scheme='webhdfs', uri_path=uri_path) + + + def _parse_uri_s3x(parsed_uri): + assert parsed_uri.scheme in ("s3", "s3n", "s3u") + + bucket_id = (parsed_uri.netloc + parsed_uri.path).split('@') + key_id = None + port = 443 + host = boto.config.get('s3', 'host', 's3.amazonaws.com') + ordinary_calling_format = False + if len(bucket_id) == 1: + # URI without credentials: s3://bucket/object + bucket_id, key_id = bucket_id[0].split('/', 1) + # "None" credentials are interpreted as "look for credentials in other locations" by boto + access_id, access_secret = None, None + elif len(bucket_id) == 2 and len(bucket_id[0].split(':')) == 2: + # URI in full format: s3://key:secret@bucket/object + # access key id: [A-Z0-9]{20} + # secret access key: [A-Za-z0-9/+=]{40} + acc, bucket_id = bucket_id + access_id, access_secret = acc.split(':') + bucket_id, key_id = bucket_id.split('/', 1) + elif len(bucket_id) == 3 and len(bucket_id[0].split(':')) == 2: + # or URI in extended format: s3://key:secret@server[:port]@bucket/object + acc, server, bucket_id = bucket_id + access_id, access_secret = acc.split(':') + bucket_id, key_id = bucket_id.split('/', 1) + server = server.split(':') + ordinary_calling_format = True + host = server[0] + if len(server) == 2: + port = int(server[1]) + else: + # more than 2 '@' means invalid uri + # Bucket names must be at least 3 and no more than 63 characters long. + # Bucket names must be a series of one or more labels. + # Adjacent labels are separated by a single period (.). + # Bucket names can contain lowercase letters, numbers, and hyphens. + # Each label must start and end with a lowercase letter or a number. 
+ raise RuntimeError("invalid S3 URI: %s" % str(parsed_uri)) + + return Uri( + scheme=parsed_uri.scheme, bucket_id=bucket_id, key_id=key_id, + port=port, host=host, ordinary_calling_format=ordinary_calling_format, + access_id=access_id, access_secret=access_secret + ) + + + def _parse_uri_file(parsed_uri): + assert parsed_uri.scheme in (None, '', 'file') + uri_path = parsed_uri.netloc + parsed_uri.path + # '~/tmp' may be expanded to '/Users/username/tmp' + uri_path = os.path.expanduser(uri_path) + + if not uri_path: + raise RuntimeError("invalid file URI: %s" % str(parsed_uri)) + + return Uri(scheme='file', uri_path=uri_path) def _make_closing(base, **attrs): diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index 11bc399e..da659c43 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -39,20 +39,20 @@ def test_scheme(self): """Do URIs schemes parse correctly?""" # supported schemes for scheme in ("s3", "s3n", "hdfs", "file", "http", "https"): - parsed_uri = smart_open.ParseUri(scheme + "://mybucket/mykey") + parsed_uri = smart_open_lib._parse_uri(scheme + "://mybucket/mykey") self.assertEqual(parsed_uri.scheme, scheme) # unsupported scheme => NotImplementedError - self.assertRaises(NotImplementedError, smart_open.ParseUri, "foobar://mybucket/mykey") + self.assertRaises(NotImplementedError, smart_open_lib._parse_uri, "foobar://mybucket/mykey") # unknown scheme => default_scheme - parsed_uri = smart_open.ParseUri("blah blah") + parsed_uri = smart_open_lib._parse_uri("blah blah") self.assertEqual(parsed_uri.scheme, "file") def test_s3_uri(self): """Do S3 URIs parse correctly?""" # correct uri without credentials - parsed_uri = smart_open.ParseUri("s3://mybucket/mykey") + parsed_uri = smart_open_lib._parse_uri("s3://mybucket/mykey") self.assertEqual(parsed_uri.scheme, "s3") self.assertEqual(parsed_uri.bucket_id, "mybucket") self.assertEqual(parsed_uri.key_id, "mykey") @@ -60,7 +60,7 @@ def 
test_s3_uri(self): self.assertEqual(parsed_uri.access_secret, None) # correct uri, key contains slash - parsed_uri = smart_open.ParseUri("s3://mybucket/mydir/mykey") + parsed_uri = smart_open_lib._parse_uri("s3://mybucket/mydir/mykey") self.assertEqual(parsed_uri.scheme, "s3") self.assertEqual(parsed_uri.bucket_id, "mybucket") self.assertEqual(parsed_uri.key_id, "mydir/mykey") @@ -68,7 +68,7 @@ def test_s3_uri(self): self.assertEqual(parsed_uri.access_secret, None) # correct uri with credentials - parsed_uri = smart_open.ParseUri("s3://ACCESSID456:acces/sse_cr-et@mybucket/mykey") + parsed_uri = smart_open_lib._parse_uri("s3://ACCESSID456:acces/sse_cr-et@mybucket/mykey") self.assertEqual(parsed_uri.scheme, "s3") self.assertEqual(parsed_uri.bucket_id, "mybucket") self.assertEqual(parsed_uri.key_id, "mykey") @@ -76,7 +76,7 @@ def test_s3_uri(self): self.assertEqual(parsed_uri.access_secret, "acces/sse_cr-et") # correct uri, contains credentials - parsed_uri = smart_open.ParseUri("s3://accessid:access/secret@mybucket/mykey") + parsed_uri = smart_open_lib._parse_uri("s3://accessid:access/secret@mybucket/mykey") self.assertEqual(parsed_uri.scheme, "s3") self.assertEqual(parsed_uri.bucket_id, "mybucket") self.assertEqual(parsed_uri.key_id, "mykey") @@ -84,17 +84,17 @@ def test_s3_uri(self): self.assertEqual(parsed_uri.access_secret, "access/secret") # incorrect uri - only two '@' in uri are allowed - self.assertRaises(RuntimeError, smart_open.ParseUri, "s3://access_id@access_secret@mybucket@port/mykey") + self.assertRaises(RuntimeError, smart_open_lib._parse_uri, "s3://access_id@access_secret@mybucket@port/mykey") def test_webhdfs_uri(self): """Do webhdfs URIs parse correctly""" # valid uri, no query - parsed_uri = smart_open.ParseUri("webhdfs://host:port/path/file") + parsed_uri = smart_open_lib._parse_uri("webhdfs://host:port/path/file") self.assertEqual(parsed_uri.scheme, "webhdfs") self.assertEqual(parsed_uri.uri_path, "host:port/webhdfs/v1/path/file") # valid uri, 
with query - parsed_uri = smart_open.ParseUri("webhdfs://host:port/path/file?query_part_1&query_part2") + parsed_uri = smart_open_lib._parse_uri("webhdfs://host:port/path/file?query_part_1&query_part2") self.assertEqual(parsed_uri.scheme, "webhdfs") self.assertEqual(parsed_uri.uri_path, "host:port/webhdfs/v1/path/file?query_part_1&query_part2") @@ -825,7 +825,7 @@ def test_r(self): def test_bad_mode(self): """Bad mode should raise and exception.""" - uri = smart_open.ParseUri("s3://bucket/key") + uri = smart_open_lib._parse_uri("s3://bucket/key") self.assertRaises(NotImplementedError, smart_open.smart_open, uri, "x") @mock_s3 @@ -881,7 +881,7 @@ def test_gzip_write_mode(self): """Should always open in binary mode when writing through a codec.""" s3 = boto3.resource('s3') s3.create_bucket(Bucket='bucket') - uri = smart_open.ParseUri("s3://bucket/key.gz") + uri = smart_open_lib._parse_uri("s3://bucket/key.gz") with mock.patch('smart_open.smart_open_s3.open') as mock_open: smart_open.smart_open("s3://bucket/key.gz", "wb")