Skip to content

Commit

Permalink
Replace ParsedUri class with functions, cleanup internal argument p…
Browse files Browse the repository at this point in the history
…arsing (#191)

* replace ParsedUri function masquerading as a class with a function

* add comment to explain hacky code
  • Loading branch information
mpenkov authored and menshikh-iv committed Apr 23, 2018
1 parent abb08ae commit 34b3f90
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 91 deletions.
210 changes: 131 additions & 79 deletions smart_open/smart_open_lib.py
Expand Up @@ -22,6 +22,7 @@
"""

import codecs
import collections
import logging
import os
import os.path as P
Expand Down Expand Up @@ -77,6 +78,34 @@
DEFAULT_ERRORS = 'strict'


Uri = collections.namedtuple(
'Uri',
(
'scheme',
'uri_path',
'bucket_id',
'key_id',
'port',
'host',
'ordinary_calling_format',
'access_id',
'access_secret',
)
)
"""Represents all the options that we parse from user input.
Some of the above options only make sense for certain protocols, e.g.
bucket_id is only for S3.
"""
#
# Set the default values for all Uri fields to be None. This allows us to only
# specify the relevant fields when constructing a Uri.
#
# https://stackoverflow.com/questions/11351032/namedtuple-and-default-values-for-optional-keyword-arguments
#
Uri.__new__.__defaults__ = (None,) * len(Uri._fields)


def smart_open(uri, mode="rb", **kw):
"""
Open the given S3 / HDFS / filesystem file pointed to by `uri` for reading or writing.
Expand Down Expand Up @@ -231,7 +260,7 @@ def _shortcut_open(uri, mode, **kw):
if not isinstance(uri, six.string_types):
return None

parsed_uri = ParseUri(uri)
parsed_uri = _parse_uri(uri)
if parsed_uri.scheme != 'file':
return None

Expand Down Expand Up @@ -275,7 +304,7 @@ def _open_binary_stream(uri, mode, **kw):
# this method just routes the request to classes handling the specific storage
# schemes, depending on the URI protocol in `uri`
filename = uri.split('/')[-1]
parsed_uri = ParseUri(uri)
parsed_uri = _parse_uri(uri)
unsupported = "%r mode not supported for %r scheme" % (mode, parsed_uri.scheme)

if parsed_uri.scheme in ("file", ):
Expand Down Expand Up @@ -349,9 +378,9 @@ def _s3_open_uri(parsed_uri, mode, **kwargs):
return smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs)


class ParseUri(object):
def _parse_uri(uri_as_string):
"""
Parse the given URI.
Parse the given URI from a string.
Supported URI schemes are "file", "s3", "s3n", "s3u" and "hdfs".
Expand All @@ -372,82 +401,105 @@ class ParseUri(object):
* ./local/path/file.gz
* file:///home/user/file
* file:///home/user/file.bz2
"""
def __init__(self, uri, default_scheme="file"):
"""
Assume `default_scheme` if no scheme given in `uri`.
"""
if os.name == 'nt':
# urlsplit doesn't work on Windows -- it parses the drive as the scheme...
if '://' not in uri:
# no protocol given => assume a local file
uri = 'file://' + uri
parsed_uri = urlsplit(uri, allow_fragments=False)
self.scheme = parsed_uri.scheme if parsed_uri.scheme else default_scheme

if self.scheme == "hdfs":
self.uri_path = parsed_uri.netloc + parsed_uri.path
self.uri_path = "/" + self.uri_path.lstrip("/")

if not self.uri_path:
raise RuntimeError("invalid HDFS URI: %s" % uri)
elif self.scheme == "webhdfs":
self.uri_path = parsed_uri.netloc + "/webhdfs/v1" + parsed_uri.path
if parsed_uri.query:
self.uri_path += "?" + parsed_uri.query

if not self.uri_path:
raise RuntimeError("invalid WebHDFS URI: %s" % uri)
elif self.scheme in ("s3", "s3n", "s3u"):
self.bucket_id = (parsed_uri.netloc + parsed_uri.path).split('@')
self.key_id = None
self.port = 443
self.host = boto.config.get('s3', 'host', 's3.amazonaws.com')
self.ordinary_calling_format = False
if len(self.bucket_id) == 1:
# URI without credentials: s3://bucket/object
self.bucket_id, self.key_id = self.bucket_id[0].split('/', 1)
# "None" credentials are interpreted as "look for credentials in other locations" by boto
self.access_id, self.access_secret = None, None
elif len(self.bucket_id) == 2 and len(self.bucket_id[0].split(':')) == 2:
# URI in full format: s3://key:secret@bucket/object
# access key id: [A-Z0-9]{20}
# secret access key: [A-Za-z0-9/+=]{40}
acc, self.bucket_id = self.bucket_id
self.access_id, self.access_secret = acc.split(':')
self.bucket_id, self.key_id = self.bucket_id.split('/', 1)
elif len(self.bucket_id) == 3 and len(self.bucket_id[0].split(':')) == 2:
# or URI in extended format: s3://key:secret@server[:port]@bucket/object
acc, server, self.bucket_id = self.bucket_id
self.access_id, self.access_secret = acc.split(':')
self.bucket_id, self.key_id = self.bucket_id.split('/', 1)
server = server.split(':')
self.ordinary_calling_format = True
self.host = server[0]
if len(server) == 2:
self.port = int(server[1])
else:
# more than 2 '@' means invalid uri
# Bucket names must be at least 3 and no more than 63 characters long.
# Bucket names must be a series of one or more labels.
# Adjacent labels are separated by a single period (.).
# Bucket names can contain lowercase letters, numbers, and hyphens.
# Each label must start and end with a lowercase letter or a number.
raise RuntimeError("invalid S3 URI: %s" % uri)
elif self.scheme == 'file':
self.uri_path = parsed_uri.netloc + parsed_uri.path

# '~/tmp' may be expanded to '/Users/username/tmp'
self.uri_path = os.path.expanduser(self.uri_path)

if not self.uri_path:
raise RuntimeError("invalid file URI: %s" % uri)
elif self.scheme.startswith('http'):
self.uri_path = uri
else:
raise NotImplementedError("unknown URI scheme %r in %r" % (self.scheme, uri))
if os.name == 'nt':
# urlsplit doesn't work on Windows -- it parses the drive as the scheme...
if '://' not in uri_as_string:
# no protocol given => assume a local file
uri_as_string = 'file://' + uri_as_string
parsed_uri = urlsplit(uri_as_string, allow_fragments=False)

if parsed_uri.scheme == "hdfs":
return _parse_uri_hdfs(parsed_uri)
elif parsed_uri.scheme == "webhdfs":
return _parse_uri_webhdfs(parsed_uri)
elif parsed_uri.scheme in ("s3", "s3n", "s3u"):
return _parse_uri_s3x(parsed_uri)
elif parsed_uri.scheme in ('file', '', None):
return _parse_uri_file(parsed_uri)
elif parsed_uri.scheme.startswith('http'):
return Uri(scheme=parsed_uri.scheme, uri_path=uri_as_string)
else:
raise NotImplementedError(
"unknown URI scheme %r in %r" % (parsed_uri.scheme, uri_as_string)
)


def _parse_uri_hdfs(parsed_uri):
assert parsed_uri.scheme == 'hdfs'
uri_path = parsed_uri.netloc + parsed_uri.path
uri_path = "/" + uri_path.lstrip("/")
if not uri_path:
raise RuntimeError("invalid HDFS URI: %s" % uri)
return Uri(scheme='hdfs', uri_path=uri_path)


def _parse_uri_webhdfs(parsed_uri):
assert parsed_uri.scheme == 'webhdfs'
uri_path = parsed_uri.netloc + "/webhdfs/v1" + parsed_uri.path
if parsed_uri.query:
uri_path += "?" + parsed_uri.query
if not uri_path:
raise RuntimeError("invalid WebHDFS URI: %s" % uri)
return Uri(scheme='webhdfs', uri_path=uri_path)


def _parse_uri_s3x(parsed_uri):
assert parsed_uri.scheme in ("s3", "s3n", "s3u")

bucket_id = (parsed_uri.netloc + parsed_uri.path).split('@')
key_id = None
port = 443
host = boto.config.get('s3', 'host', 's3.amazonaws.com')
ordinary_calling_format = False
if len(bucket_id) == 1:
# URI without credentials: s3://bucket/object
bucket_id, key_id = bucket_id[0].split('/', 1)
# "None" credentials are interpreted as "look for credentials in other locations" by boto
access_id, access_secret = None, None
elif len(bucket_id) == 2 and len(bucket_id[0].split(':')) == 2:
# URI in full format: s3://key:secret@bucket/object
# access key id: [A-Z0-9]{20}
# secret access key: [A-Za-z0-9/+=]{40}
acc, bucket_id = bucket_id
access_id, access_secret = acc.split(':')
bucket_id, key_id = bucket_id.split('/', 1)
elif len(bucket_id) == 3 and len(bucket_id[0].split(':')) == 2:
# or URI in extended format: s3://key:secret@server[:port]@bucket/object
acc, server, bucket_id = bucket_id
access_id, access_secret = acc.split(':')
bucket_id, key_id = bucket_id.split('/', 1)
server = server.split(':')
ordinary_calling_format = True
host = server[0]
if len(server) == 2:
port = int(server[1])
else:
# more than 2 '@' means invalid uri
# Bucket names must be at least 3 and no more than 63 characters long.
# Bucket names must be a series of one or more labels.
# Adjacent labels are separated by a single period (.).
# Bucket names can contain lowercase letters, numbers, and hyphens.
# Each label must start and end with a lowercase letter or a number.
raise RuntimeError("invalid S3 URI: %s" % str(parsed_uri))

return Uri(
scheme=parsed_uri.scheme, bucket_id=bucket_id, key_id=key_id,
port=port, host=host, ordinary_calling_format=ordinary_calling_format,
access_id=access_id, access_secret=access_secret
)


def _parse_uri_file(parsed_uri):
assert parsed_uri.scheme in (None, '', 'file')
uri_path = parsed_uri.netloc + parsed_uri.path
# '~/tmp' may be expanded to '/Users/username/tmp'
uri_path = os.path.expanduser(uri_path)

if not uri_path:
raise RuntimeError("invalid file URI: %s" % uri)

return Uri(scheme='file', uri_path=uri_path)


def _make_closing(base, **attrs):
Expand Down
24 changes: 12 additions & 12 deletions smart_open/tests/test_smart_open.py
Expand Up @@ -39,62 +39,62 @@ def test_scheme(self):
"""Do URIs schemes parse correctly?"""
# supported schemes
for scheme in ("s3", "s3n", "hdfs", "file", "http", "https"):
parsed_uri = smart_open.ParseUri(scheme + "://mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri(scheme + "://mybucket/mykey")
self.assertEqual(parsed_uri.scheme, scheme)

# unsupported scheme => NotImplementedError
self.assertRaises(NotImplementedError, smart_open.ParseUri, "foobar://mybucket/mykey")
self.assertRaises(NotImplementedError, smart_open_lib._parse_uri, "foobar://mybucket/mykey")

# unknown scheme => default_scheme
parsed_uri = smart_open.ParseUri("blah blah")
parsed_uri = smart_open_lib._parse_uri("blah blah")
self.assertEqual(parsed_uri.scheme, "file")

def test_s3_uri(self):
"""Do S3 URIs parse correctly?"""
# correct uri without credentials
parsed_uri = smart_open.ParseUri("s3://mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, None)
self.assertEqual(parsed_uri.access_secret, None)

# correct uri, key contains slash
parsed_uri = smart_open.ParseUri("s3://mybucket/mydir/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://mybucket/mydir/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mydir/mykey")
self.assertEqual(parsed_uri.access_id, None)
self.assertEqual(parsed_uri.access_secret, None)

# correct uri with credentials
parsed_uri = smart_open.ParseUri("s3://ACCESSID456:acces/sse_cr-et@mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://ACCESSID456:acces/sse_cr-et@mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, "ACCESSID456")
self.assertEqual(parsed_uri.access_secret, "acces/sse_cr-et")

# correct uri, contains credentials
parsed_uri = smart_open.ParseUri("s3://accessid:access/secret@mybucket/mykey")
parsed_uri = smart_open_lib._parse_uri("s3://accessid:access/secret@mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, "accessid")
self.assertEqual(parsed_uri.access_secret, "access/secret")

# incorrect uri - only two '@' in uri are allowed
self.assertRaises(RuntimeError, smart_open.ParseUri, "s3://access_id@access_secret@mybucket@port/mykey")
self.assertRaises(RuntimeError, smart_open_lib._parse_uri, "s3://access_id@access_secret@mybucket@port/mykey")

def test_webhdfs_uri(self):
"""Do webhdfs URIs parse correctly"""
# valid uri, no query
parsed_uri = smart_open.ParseUri("webhdfs://host:port/path/file")
parsed_uri = smart_open_lib._parse_uri("webhdfs://host:port/path/file")
self.assertEqual(parsed_uri.scheme, "webhdfs")
self.assertEqual(parsed_uri.uri_path, "host:port/webhdfs/v1/path/file")

# valid uri, with query
parsed_uri = smart_open.ParseUri("webhdfs://host:port/path/file?query_part_1&query_part2")
parsed_uri = smart_open_lib._parse_uri("webhdfs://host:port/path/file?query_part_1&query_part2")
self.assertEqual(parsed_uri.scheme, "webhdfs")
self.assertEqual(parsed_uri.uri_path, "host:port/webhdfs/v1/path/file?query_part_1&query_part2")

Expand Down Expand Up @@ -825,7 +825,7 @@ def test_r(self):

def test_bad_mode(self):
"""Bad mode should raise and exception."""
uri = smart_open.ParseUri("s3://bucket/key")
uri = smart_open_lib._parse_uri("s3://bucket/key")
self.assertRaises(NotImplementedError, smart_open.smart_open, uri, "x")

@mock_s3
Expand Down Expand Up @@ -881,7 +881,7 @@ def test_gzip_write_mode(self):
"""Should always open in binary mode when writing through a codec."""
s3 = boto3.resource('s3')
s3.create_bucket(Bucket='bucket')
uri = smart_open.ParseUri("s3://bucket/key.gz")
uri = smart_open_lib._parse_uri("s3://bucket/key.gz")

with mock.patch('smart_open.smart_open_s3.open') as mock_open:
smart_open.smart_open("s3://bucket/key.gz", "wb")
Expand Down

0 comments on commit 34b3f90

Please sign in to comment.