Replace ParsedUri class with functions, cleanup internal argument parsing #191

Merged 2 commits on Apr 23, 2018
Changes from 1 commit
203 changes: 124 additions & 79 deletions smart_open/smart_open_lib.py
@@ -22,6 +22,7 @@
"""

import codecs
import collections
import logging
import os
import os.path as P
@@ -77,6 +78,27 @@
DEFAULT_ERRORS = 'strict'


Uri = collections.namedtuple(
'Uri',
(
'scheme',
'uri_path',
'bucket_id',
'key_id',
'port',
'host',
'ordinary_calling_format',
'access_id',
'access_secret',
)
)
"""Represents all the options that we parse from user input.

Some of the above options only make sense for certain protocols, e.g.
bucket_id is only for S3."""
Uri.__new__.__defaults__ = (None,) * len(Uri._fields)

Contributor commented:

definitely need a comment here, too hacky

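
For context on the line in question: assigning to Uri.__new__.__defaults__ makes every namedtuple field optional (defaulting to None), so each scheme-specific parser only fills in the fields that apply to its scheme. A minimal sketch of the same trick, using a hypothetical Point tuple rather than anything from this diff:

import collections

# Hypothetical example: give every field of a namedtuple a default of None
# by assigning to the defaults of the generated __new__.
Point = collections.namedtuple('Point', ('x', 'y', 'z'))
Point.__new__.__defaults__ = (None,) * len(Point._fields)

p = Point(x=1)                     # y and z fall back to None
assert p == Point(1, None, None)

On Python 3.7+ the same effect is available through the defaults= keyword of namedtuple, but the __defaults__ assignment also works on Python 2.7, which matters here because the module still relies on six for Python 2 compatibility.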


def smart_open(uri, mode="rb", **kw):
"""
Open the given S3 / HDFS / filesystem file pointed to by `uri` for reading or writing.
@@ -231,7 +253,7 @@ def _shortcut_open(uri, mode, **kw):
if not isinstance(uri, six.string_types):
return None

parsed_uri = ParseUri(uri)
parsed_uri = _parse_uri(uri)
if parsed_uri.scheme != 'file':
return None

@@ -275,7 +297,7 @@ def _open_binary_stream(uri, mode, **kw):
# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
filename = uri.split('/')[-1]
parsed_uri = ParseUri(uri)
parsed_uri = _parse_uri(uri)
unsupported = "%r mode not supported for %r scheme" % (mode, parsed_uri.scheme)

if parsed_uri.scheme in ("file", ):
@@ -349,9 +371,9 @@ def _s3_open_uri(parsed_uri, mode, **kwargs):
return smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs)


class ParseUri(object):
def _parse_uri(uri_as_string):
"""
Parse the given URI.
Parse the given URI from a string.

Supported URI schemes are "file", "s3", "s3n", "s3u" and "hdfs".

@@ -372,82 +394,105 @@ class ParseUri(object):
* ./local/path/file.gz
* file:///home/user/file
* file:///home/user/file.bz2

"""
def __init__(self, uri, default_scheme="file"):
"""
Assume `default_scheme` if no scheme given in `uri`.

"""
if os.name == 'nt':
# urlsplit doesn't work on Windows -- it parses the drive as the scheme...
if '://' not in uri:
# no protocol given => assume a local file
uri = 'file://' + uri
parsed_uri = urlsplit(uri, allow_fragments=False)
self.scheme = parsed_uri.scheme if parsed_uri.scheme else default_scheme

if self.scheme == "hdfs":
self.uri_path = parsed_uri.netloc + parsed_uri.path
self.uri_path = "/" + self.uri_path.lstrip("/")

if not self.uri_path:
raise RuntimeError("invalid HDFS URI: %s" % uri)
elif self.scheme == "webhdfs":
self.uri_path = parsed_uri.netloc + "/webhdfs/v1" + parsed_uri.path
if parsed_uri.query:
self.uri_path += "?" + parsed_uri.query

if not self.uri_path:
raise RuntimeError("invalid WebHDFS URI: %s" % uri)
elif self.scheme in ("s3", "s3n", "s3u"):
self.bucket_id = (parsed_uri.netloc + parsed_uri.path).split('@')
self.key_id = None
self.port = 443
self.host = boto.config.get('s3', 'host', 's3.amazonaws.com')
self.ordinary_calling_format = False
if len(self.bucket_id) == 1:
# URI without credentials: s3://bucket/object
self.bucket_id, self.key_id = self.bucket_id[0].split('/', 1)
# "None" credentials are interpreted as "look for credentials in other locations" by boto
self.access_id, self.access_secret = None, None
elif len(self.bucket_id) == 2 and len(self.bucket_id[0].split(':')) == 2:
# URI in full format: s3://key:secret@bucket/object
# access key id: [A-Z0-9]{20}
# secret access key: [A-Za-z0-9/+=]{40}
acc, self.bucket_id = self.bucket_id
self.access_id, self.access_secret = acc.split(':')
self.bucket_id, self.key_id = self.bucket_id.split('/', 1)
elif len(self.bucket_id) == 3 and len(self.bucket_id[0].split(':')) == 2:
# or URI in extended format: s3://key:secret@server[:port]@bucket/object
acc, server, self.bucket_id = self.bucket_id
self.access_id, self.access_secret = acc.split(':')
self.bucket_id, self.key_id = self.bucket_id.split('/', 1)
server = server.split(':')
self.ordinary_calling_format = True
self.host = server[0]
if len(server) == 2:
self.port = int(server[1])
else:
# more than 2 '@' means invalid uri
# Bucket names must be at least 3 and no more than 63 characters long.
# Bucket names must be a series of one or more labels.
# Adjacent labels are separated by a single period (.).
# Bucket names can contain lowercase letters, numbers, and hyphens.
# Each label must start and end with a lowercase letter or a number.
raise RuntimeError("invalid S3 URI: %s" % uri)
elif self.scheme == 'file':
self.uri_path = parsed_uri.netloc + parsed_uri.path

# '~/tmp' may be expanded to '/Users/username/tmp'
self.uri_path = os.path.expanduser(self.uri_path)

if not self.uri_path:
raise RuntimeError("invalid file URI: %s" % uri)
elif self.scheme.startswith('http'):
self.uri_path = uri
else:
raise NotImplementedError("unknown URI scheme %r in %r" % (self.scheme, uri))
if os.name == 'nt':
# urlsplit doesn't work on Windows -- it parses the drive as the scheme...
if '://' not in uri_as_string:
# no protocol given => assume a local file
uri_as_string = 'file://' + uri_as_string
parsed_uri = urlsplit(uri_as_string, allow_fragments=False)

if parsed_uri.scheme == "hdfs":
return _parse_uri_hdfs(parsed_uri)
elif parsed_uri.scheme == "webhdfs":
return _parse_uri_webhdfs(parsed_uri)
elif parsed_uri.scheme in ("s3", "s3n", "s3u"):
return _parse_uri_s3x(parsed_uri)
elif parsed_uri.scheme in ('file', '', None):
return _parse_uri_file(parsed_uri)
elif parsed_uri.scheme.startswith('http'):
return Uri(scheme=parsed_uri.scheme, uri_path=uri_as_string)
else:
raise NotImplementedError(
"unknown URI scheme %r in %r" % (parsed_uri.scheme, uri_as_string)
)


def _parse_uri_hdfs(parsed_uri):
assert parsed_uri.scheme == 'hdfs'

Contributor commented:

Why are the asserts needed (you already check the exact same conditions in the if/elif dispatcher), here and everywhere else? Sanity check only?

Collaborator (author) commented:

Yes, it's a sanity check only.
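
Worth noting for readers: assert statements are stripped entirely when CPython runs with the -O flag, so they work as cheap developer sanity checks but cannot replace the explicit RuntimeError validation of user-supplied URIs below. A tiny hypothetical illustration, not part of the diff:

# sanity_check.py (hypothetical)
assert 1 + 1 == 3, "developer sanity check"
# python sanity_check.py      -> AssertionError: developer sanity check
# python -O sanity_check.py   -> the assert is compiled away, nothing is raised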

uri_path = parsed_uri.netloc + parsed_uri.path
uri_path = "/" + uri_path.lstrip("/")
if not uri_path:
raise RuntimeError("invalid HDFS URI: %s" % uri)
return Uri(scheme='hdfs', uri_path=uri_path)


def _parse_uri_webhdfs(parsed_uri):
assert parsed_uri.scheme == 'webhdfs'
uri_path = parsed_uri.netloc + "/webhdfs/v1" + parsed_uri.path
if parsed_uri.query:
uri_path += "?" + parsed_uri.query
if not uri_path:
raise RuntimeError("invalid WebHDFS URI: %s" % uri)
return Uri(scheme='webhdfs', uri_path=uri_path)


def _parse_uri_s3x(parsed_uri):
assert parsed_uri.scheme in ("s3", "s3n", "s3u")

bucket_id = (parsed_uri.netloc + parsed_uri.path).split('@')
key_id = None
port = 443
host = boto.config.get('s3', 'host', 's3.amazonaws.com')
ordinary_calling_format = False
if len(bucket_id) == 1:
# URI without credentials: s3://bucket/object
bucket_id, key_id = bucket_id[0].split('/', 1)
# "None" credentials are interpreted as "look for credentials in other locations" by boto
access_id, access_secret = None, None
elif len(bucket_id) == 2 and len(bucket_id[0].split(':')) == 2:
# URI in full format: s3://key:secret@bucket/object
# access key id: [A-Z0-9]{20}
# secret access key: [A-Za-z0-9/+=]{40}
acc, bucket_id = bucket_id
access_id, access_secret = acc.split(':')
bucket_id, key_id = bucket_id.split('/', 1)
elif len(bucket_id) == 3 and len(bucket_id[0].split(':')) == 2:
# or URI in extended format: s3://key:secret@server[:port]@bucket/object
acc, server, bucket_id = bucket_id
access_id, access_secret = acc.split(':')
bucket_id, key_id = bucket_id.split('/', 1)
server = server.split(':')
ordinary_calling_format = True
host = server[0]
if len(server) == 2:
port = int(server[1])
else:
# more than 2 '@' means invalid uri
# Bucket names must be at least 3 and no more than 63 characters long.
# Bucket names must be a series of one or more labels.
# Adjacent labels are separated by a single period (.).
# Bucket names can contain lowercase letters, numbers, and hyphens.
# Each label must start and end with a lowercase letter or a number.
raise RuntimeError("invalid S3 URI: %s" % str(parsed_uri))

return Uri(
scheme=parsed_uri.scheme, bucket_id=bucket_id, key_id=key_id,
port=port, host=host, ordinary_calling_format=ordinary_calling_format,
access_id=access_id, access_secret=access_secret
)


def _parse_uri_file(parsed_uri):
assert parsed_uri.scheme in (None, '', 'file')
uri_path = parsed_uri.netloc + parsed_uri.path
# '~/tmp' may be expanded to '/Users/username/tmp'
uri_path = os.path.expanduser(uri_path)

if not uri_path:
raise RuntimeError("invalid file URI: %s" % uri)

return Uri(scheme='file', uri_path=uri_path)


def _make_closing(base, **attrs):
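Not part of the diff, but a quick sketch of how the new function-based API behaves, assuming the module layout above; the expected field values follow directly from the parsing code in this file:

from smart_open.smart_open_lib import _parse_uri

# Plain S3 URI: bucket and key, no credentials.
uri = _parse_uri("s3://mybucket/mydir/mykey")
print(uri.scheme, uri.bucket_id, uri.key_id)   # s3 mybucket mydir/mykey

# Extended S3 form with credentials, host and port.
uri = _parse_uri("s3u://key:secret@localhost:9000@mybucket/mykey")
print(uri.host, uri.port, uri.access_id)       # localhost 9000 key

# A bare local path falls through to the 'file' scheme.
uri = _parse_uri("./local/path/file.gz")
print(uri.scheme, uri.uri_path)                # file ./local/path/file.gz

Fields that do not apply to a scheme (for example, port for a local file) simply stay None thanks to the namedtuple defaults.
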
24 changes: 12 additions & 12 deletions smart_open/tests/test_smart_open.py
@@ -39,62 +39,62 @@ def test_scheme(self):
"""Do URIs schemes parse correctly?"""
# supported schemes
for scheme in ("s3", "s3n", "hdfs", "file", "http", "https"):
parsed_uri = smart_open.ParseUri(scheme + "://mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri(scheme + "://mybucket/mykey")
self.assertEqual(parsed_uri.scheme, scheme)

# unsupported scheme => NotImplementedError
self.assertRaises(NotImplementedError, smart_open.ParseUri, "foobar://mybucket/mykey")
self.assertRaises(NotImplementedError, smart_open_lib._parse_uri, "foobar://mybucket/mykey")

# unknown scheme => default_scheme
parsed_uri = smart_open.ParseUri("blah blah")
parsed_uri = smart_open_lib._parse_uri("blah blah")
self.assertEqual(parsed_uri.scheme, "file")

def test_s3_uri(self):
"""Do S3 URIs parse correctly?"""
# correct uri without credentials
parsed_uri = smart_open.ParseUri("s3://mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, None)
self.assertEqual(parsed_uri.access_secret, None)

# correct uri, key contains slash
parsed_uri = smart_open.ParseUri("s3://mybucket/mydir/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://mybucket/mydir/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mydir/mykey")
self.assertEqual(parsed_uri.access_id, None)
self.assertEqual(parsed_uri.access_secret, None)

# correct uri with credentials
parsed_uri = smart_open.ParseUri("s3://ACCESSID456:acces/sse_cr-et@mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://ACCESSID456:acces/sse_cr-et@mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, "ACCESSID456")
self.assertEqual(parsed_uri.access_secret, "acces/sse_cr-et")

# correct uri, contains credentials
parsed_uri = smart_open.ParseUri("s3://accessid:access/secret@mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://accessid:access/secret@mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, "accessid")
self.assertEqual(parsed_uri.access_secret, "access/secret")

# incorrect uri - only two '@' in uri are allowed
self.assertRaises(RuntimeError, smart_open.ParseUri, "s3://access_id@access_secret@mybucket@port/mykey")
self.assertRaises(RuntimeError, smart_open_lib._parse_uri, "s3://access_id@access_secret@mybucket@port/mykey")

def test_webhdfs_uri(self):
"""Do webhdfs URIs parse correctly"""
# valid uri, no query
parsed_uri = smart_open.ParseUri("webhdfs://host:port/path/file")
parsed_uri = smart_open_lib._parse_uri("webhdfs://host:port/path/file")
self.assertEqual(parsed_uri.scheme, "webhdfs")
self.assertEqual(parsed_uri.uri_path, "host:port/webhdfs/v1/path/file")

# valid uri, with query
parsed_uri = smart_open.ParseUri("webhdfs://host:port/path/file?query_part_1&query_part2")
parsed_uri = smart_open_lib._parse_uri("webhdfs://host:port/path/file?query_part_1&query_part2")
self.assertEqual(parsed_uri.scheme, "webhdfs")
self.assertEqual(parsed_uri.uri_path, "host:port/webhdfs/v1/path/file?query_part_1&query_part2")

@@ -825,7 +825,7 @@ def test_r(self):

def test_bad_mode(self):
"""Bad mode should raise and exception."""
uri = smart_open.ParseUri("s3://bucket/key")
uri = smart_open_lib._parse_uri("s3://bucket/key")
self.assertRaises(NotImplementedError, smart_open.smart_open, uri, "x")

@mock_s3
@@ -881,7 +881,7 @@ def test_gzip_write_mode(self):
"""Should always open in binary mode when writing through a codec."""
s3 = boto3.resource('s3')
s3.create_bucket(Bucket='bucket')
uri = smart_open.ParseUri("s3://bucket/key.gz")
uri = smart_open_lib._parse_uri("s3://bucket/key.gz")

with mock.patch('smart_open.smart_open_s3.open') as mock_open:
smart_open.smart_open("s3://bucket/key.gz", "wb")
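
One note on why the test changes above are purely mechanical: the Uri namedtuple exposes the same attribute names the old ParseUri class did (scheme, bucket_id, key_id, access_id, and so on), so call sites only swap the constructor while attribute access stays unchanged. A hypothetical sketch:

from smart_open import smart_open_lib

# Before: parsed = smart_open.ParseUri("s3://mybucket/mykey")
parsed = smart_open_lib._parse_uri("s3://mybucket/mykey")
assert (parsed.scheme, parsed.bucket_id, parsed.key_id) == ("s3", "mybucket", "mykey")
assert parsed.access_id is None and parsed.access_secret is None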