Skip to content

Commit

Permalink
Allow using '@' in object (key) names. Fix #94 (#224)
Browse files Browse the repository at this point in the history
* Allow using '@' in bucket and object names

* add unit tests for processing '@' in key

* split large test into smaller ones

* add explicit unit test for issue 223
  • Loading branch information
mpenkov authored and menshikh-iv committed Sep 2, 2018
1 parent 1d7538f commit ec58647
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 33 deletions.
61 changes: 32 additions & 29 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,8 @@ def _parse_uri_hdfs(parsed_uri):
uri_path = parsed_uri.netloc + parsed_uri.path
uri_path = "/" + uri_path.lstrip("/")
if not uri_path:
raise RuntimeError("invalid HDFS URI: %s" % parsed_uri)
raise RuntimeError("invalid HDFS URI: %s" % str(parsed_uri))

return Uri(scheme='hdfs', uri_path=uri_path)


Expand All @@ -462,42 +463,44 @@ def _parse_uri_webhdfs(parsed_uri):
if parsed_uri.query:
uri_path += "?" + parsed_uri.query
if not uri_path:
raise RuntimeError("invalid WebHDFS URI: %s" % parsed_uri)
raise RuntimeError("invalid WebHDFS URI: %s" % str(parsed_uri))

return Uri(scheme='webhdfs', uri_path=uri_path)


def _parse_uri_s3x(parsed_uri):
assert parsed_uri.scheme in ("s3", "s3n", "s3u")

bucket_id = (parsed_uri.netloc + parsed_uri.path).split('@')
key_id = None
port = 443
host = boto.config.get('s3', 'host', 's3.amazonaws.com')
ordinary_calling_format = False
if len(bucket_id) == 1:
# URI without credentials: s3://bucket/object
bucket_id, key_id = bucket_id[0].split('/', 1)
# "None" credentials are interpreted as "look for credentials in other locations" by boto
access_id, access_secret = None, None
elif len(bucket_id) == 2 and len(bucket_id[0].split(':')) == 2:
# URI in full format: s3://key:secret@bucket/object
# access key id: [A-Z0-9]{20}
# secret access key: [A-Za-z0-9/+=]{40}
acc, bucket_id = bucket_id
access_id, access_secret = acc.split(':')
bucket_id, key_id = bucket_id.split('/', 1)
elif len(bucket_id) == 3 and len(bucket_id[0].split(':')) == 2:
# or URI in extended format: s3://key:secret@server[:port]@bucket/object
acc, server, bucket_id = bucket_id
access_id, access_secret = acc.split(':')
bucket_id, key_id = bucket_id.split('/', 1)
server = server.split(':')
ordinary_calling_format = True
host = server[0]
if len(server) == 2:
port = int(server[1])
else:
# more than 2 '@' means invalid uri

# Common URI template [secret:key@][host[:port]@]bucket/object
try:
uri = parsed_uri.netloc + parsed_uri.path
# Separate authentication from URI if exist
if ':' in uri.split('@')[0]:
auth, uri = uri.split('@', 1)
access_id, access_secret = auth.split(':')
else:
# "None" credentials are interpreted as "look for credentials in other locations" by boto
access_id, access_secret = None, None

# Split [host[:port]@]bucket/path
host_bucket, key_id = uri.split('/', 1)
if '@' in host_bucket:
host_port, bucket_id = host_bucket.split('@')
ordinary_calling_format = True
if ':' in host_port:
server = host_port.split(':')
host = server[0]
if len(server) == 2:
port = int(server[1])
else:
host = host_port
else:
bucket_id = host_bucket
except Exception:
# Bucket names must be at least 3 and no more than 63 characters long.
# Bucket names must be a series of one or more labels.
# Adjacent labels are separated by a single period (.).
Expand All @@ -519,7 +522,7 @@ def _parse_uri_file(parsed_uri):
uri_path = os.path.expanduser(uri_path)

if not uri_path:
raise RuntimeError("invalid file URI: %s" % parsed_uri)
raise RuntimeError("invalid file URI: %s" % str(parsed_uri))

return Uri(scheme='file', uri_path=uri_path)

Expand Down
37 changes: 33 additions & 4 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,31 +60,52 @@ def test_s3_uri(self):
self.assertEqual(parsed_uri.access_id, None)
self.assertEqual(parsed_uri.access_secret, None)

# correct uri, key contains slash
def test_s3_uri_contains_slash(self):
parsed_uri = smart_open_lib._parse_uri("s3://mybucket/mydir/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mydir/mykey")
self.assertEqual(parsed_uri.access_id, None)
self.assertEqual(parsed_uri.access_secret, None)

# correct uri with credentials
def test_s3_uri_with_credentials(self):
parsed_uri = smart_open_lib._parse_uri("s3://ACCESSID456:acces/sse_cr-et@mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, "ACCESSID456")
self.assertEqual(parsed_uri.access_secret, "acces/sse_cr-et")

# correct uri, contains credentials
def test_s3_uri_with_credentials2(self):
parsed_uri = smart_open_lib._parse_uri("s3://accessid:access/secret@mybucket/mykey")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "mykey")
self.assertEqual(parsed_uri.access_id, "accessid")
self.assertEqual(parsed_uri.access_secret, "access/secret")

# incorrect uri - only two '@' in uri are allowed
def test_s3_uri_has_atmark_in_key_name(self):
parsed_uri = smart_open_lib._parse_uri("s3://accessid:access/secret@mybucket/my@ke@y")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "my@ke@y")
self.assertEqual(parsed_uri.access_id, "accessid")
self.assertEqual(parsed_uri.access_secret, "access/secret")

def test_s3_uri_has_atmark_in_key_name2(self):
parsed_uri = smart_open_lib._parse_uri("s3://accessid:access/secret@hostname:1234@mybucket/dir/my@ke@y")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "mybucket")
self.assertEqual(parsed_uri.key_id, "dir/my@ke@y")
self.assertEqual(parsed_uri.access_id, "accessid")
self.assertEqual(parsed_uri.access_secret, "access/secret")
self.assertEqual(parsed_uri.host, "hostname")
self.assertEqual(parsed_uri.port, 1234)

def test_s3_invalid_url_atmark_in_bucket_name(self):
self.assertRaises(RuntimeError, smart_open_lib._parse_uri, "s3://access_id:access_secret@my@bucket@port/mykey")

def test_s3_invalid_uri_missing_colon(self):
self.assertRaises(RuntimeError, smart_open_lib._parse_uri, "s3://access_id@access_secret@mybucket@port/mykey")

def test_webhdfs_uri(self):
Expand All @@ -99,6 +120,14 @@ def test_webhdfs_uri(self):
self.assertEqual(parsed_uri.scheme, "webhdfs")
self.assertEqual(parsed_uri.uri_path, "host:port/webhdfs/v1/path/file?query_part_1&query_part2")

def test_uri_from_issue_223_works(self):
parsed_uri = smart_open_lib._parse_uri("s3://:@omax-mis/twilio-messages-media/final/MEcd7c36e75f87dc6dd9e33702cdcd8fb6")
self.assertEqual(parsed_uri.scheme, "s3")
self.assertEqual(parsed_uri.bucket_id, "omax-mis")
self.assertEqual(parsed_uri.key_id, "twilio-messages-media/final/MEcd7c36e75f87dc6dd9e33702cdcd8fb6")
self.assertEqual(parsed_uri.access_id, "")
self.assertEqual(parsed_uri.access_secret, "")


class SmartOpenHttpTest(unittest.TestCase):
"""
Expand Down

0 comments on commit ec58647

Please sign in to comment.