Skip to content

Commit

Permalink
Minor fixes (#291)
Browse files Browse the repository at this point in the history
* fix warning, fix #288 
* include VERSION in package_data, fix #289 
* fix #285, add special handling for question marks during url parsing
* add unit test, fix #47
  • Loading branch information
mpenkov committed Apr 16, 2019
1 parent ca29842 commit 9d5f32c
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 5 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ include LICENSE
include README.rst
include CHANGELOG.rst
include setup.py
include smart_open/VERSION
18 changes: 16 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ def read(fname):
return io.open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8').read()


#
# This code intentionally duplicates a similar function in __init__.py. The
# alternative would be to somehow import that module to access the function,
# which would be too messy for a setup.py script.
#
def _get_version():
    """Return the package version read from ``smart_open/VERSION``.

    The file is located relative to the directory containing this script,
    and surrounding whitespace is stripped from its contents.
    """
    version_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'smart_open',
        'VERSION',
    )
    with open(version_path) as version_file:
        contents = version_file.read()
    return contents.strip()


tests_require = [
'mock',
'moto==1.3.4',
Expand All @@ -32,12 +43,15 @@ def read(fname):

setup(
name='smart_open',
version='1.8.1',
version=_get_version(),
description='Utils for streaming large files (S3, HDFS, gzip, bz2...)',
long_description=read('README.rst'),

packages=find_packages(),
package_data={"smart_open.tests": ["test_data/*gz"]},
package_data={
"smart_open": ["VERSION"],
"smart_open.tests": ["test_data/*gz"],
},

author='Radim Rehurek',
author_email='me@radimrehurek.com',
Expand Down
1 change: 1 addition & 0 deletions smart_open/VERSION
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1.8.1
10 changes: 10 additions & 0 deletions smart_open/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,20 @@
"""

import logging
import os.path

from .smart_open_lib import open, smart_open, register_compressor
from .s3 import iter_bucket as s3_iter_bucket
__all__ = ['open', 'smart_open', 's3_iter_bucket', 'register_compressor']


def _get_version():
    """Return the version stored in the VERSION file next to this module.

    Returns
    -------
    The contents of the VERSION file with surrounding whitespace stripped.
    """
    #
    # This module imports smart_open's own ``open`` under the builtin's name,
    # so call io.open explicitly: reading a plain local file at import time
    # should not depend on the library's URI parsing and transport code.
    #
    import io

    curr_dir = os.path.dirname(os.path.abspath(__file__))
    with io.open(os.path.join(curr_dir, 'VERSION')) as fin:
        return fin.read().strip()


__version__ = _get_version()

This comment has been minimized.

Copy link
@piskvorky

piskvorky Apr 19, 2019

Owner

May be safer like this?

import pkg_resources
pkg_resources.get_distribution("smart_open").version

That is, keep a plain version string only in setup.py; read the run-time version from pkg_resources.

I'm not sure if there are any disadvantages / gotchas, but it seems cleaner and less error-prone (no need for any extra file and extra code).

This comment has been minimized.

Copy link
@mpenkov

mpenkov Apr 20, 2019

Author Collaborator

There are many ways to achieve this.

My personal preference is to use the VERSION file (it's option 4 in the recommended practices list) as it's simple and there are no gotchas. Given that what we have in place already works and requires no maintenance (other than bumping the VERSION file), my preference is to keep things as is. Let me know if you're OK with this, otherwise we can move to an alternative option.

BTW, the pkg_resources option is number 5 on that list, so it's also a common practice, but I'm not familiar with it and there may be gotchas.

This comment has been minimized.

Copy link
@piskvorky

piskvorky Apr 20, 2019

Owner

I haven't seen that resource / summary before, thanks for sharing.

I don't feel strongly about this. Just my personal preference is always for the simplest solution, with the least room for error or duplication.


logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
3 changes: 2 additions & 1 deletion smart_open/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import contextlib
import functools
import logging
import warnings

import boto3
import botocore.client
Expand All @@ -20,7 +21,7 @@
import multiprocessing.pool
_MULTIPROCESSING = True
except ImportError:
logger.warning("multiprocessing could not be imported and won't be used")
warnings.warn("multiprocessing could not be imported and won't be used")


DEFAULT_MIN_PART_SIZE = 50 * 1024**2
Expand Down
29 changes: 28 additions & 1 deletion smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,32 @@ def _s3_open_uri(parsed_uri, mode, transport_params):
return smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs)


def _my_urlsplit(url):
"""This is a hack to prevent the regular urlsplit from splitting around question marks.
A question mark (?) in a URL typically indicates the start of a
querystring, and the standard library's urlparse function handles the
querystring separately. Unfortunately, question marks can also appear
_inside_ the actual URL for some schemas like S3.
Replaces question marks with newlines prior to splitting. This is safe because:
1. The standard library's urlsplit completely ignores newlines
2. Raw newlines will never occur in innocuous URLs. They are always URL-encoded.
See Also
--------
https://github.com/python/cpython/blob/3.7/Lib/urllib/parse.py
https://github.com/RaRe-Technologies/smart_open/issues/285
"""
if '?' not in url:
return urlsplit(url, allow_fragments=False)

sr = urlsplit(url.replace('?', '\n'), allow_fragments=False)
SplitResult = collections.namedtuple('SplitResult', 'scheme netloc path query fragment')
return SplitResult(sr.scheme, sr.netloc, sr.path.replace('\n', '?'), '', '')


def _parse_uri(uri_as_string):
"""
Parse the given URI from a string.
Expand Down Expand Up @@ -604,7 +630,8 @@ def _parse_uri(uri_as_string):
if '://' not in uri_as_string:
# no protocol given => assume a local file
uri_as_string = 'file://' + uri_as_string
parsed_uri = urlsplit(uri_as_string, allow_fragments=False)

parsed_uri = _my_urlsplit(uri_as_string)

if parsed_uri.scheme == "hdfs":
return _parse_uri_hdfs(parsed_uri)
Expand Down
3 changes: 2 additions & 1 deletion smart_open/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@

import getpass
import logging
import warnings

logger = logging.getLogger(__name__)

try:
import paramiko
except ImportError:
logger.warning('paramiko missing, opening SSH/SCP/SFTP paths will be disabled. `pip install paramiko` to suppress')
warnings.warn('paramiko missing, opening SSH/SCP/SFTP paths will be disabled. `pip install paramiko` to suppress')

#
# Global storage for SSH connections.
Expand Down
19 changes: 19 additions & 0 deletions smart_open/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,25 @@ def test_old(self):
self.assertEqual(result, expected)


@maybe_mock_s3
class IterBucketSingleProcessTest(unittest.TestCase):
    """Run iter_bucket with the s3 module's _MULTIPROCESSING flag set to False."""

    def setUp(self):
        # Remember the current flag so tearDown can put it back.
        self.old_flag = smart_open.s3._MULTIPROCESSING
        smart_open.s3._MULTIPROCESSING = False

    def tearDown(self):
        # Restore the flag saved in setUp.
        smart_open.s3._MULTIPROCESSING = self.old_flag

    def test(self):
        num_keys = 101
        populate_bucket(num_keys=num_keys)

        actual = list(smart_open.s3.iter_bucket(BUCKET_NAME))
        self.assertEqual(len(actual), num_keys)

        # Each expected item is a (key name, content) pair.
        expected = []
        for index in range(num_keys):
            expected.append(('key_%d' % index, b'%d' % index))
        self.assertEqual(sorted(actual), sorted(expected))


@maybe_mock_s3
class DownloadKeyTest(unittest.TestCase):

Expand Down
10 changes: 10 additions & 0 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ def test_s3_uri_has_atmark_in_key_name2(self):
self.assertEqual(parsed_uri.host, "hostname")
self.assertEqual(parsed_uri.port, 1234)

def test_s3_handles_fragments(self):
    """A '#' in an S3 URI stays in the key instead of starting a fragment."""
    parsed = smart_open_lib._parse_uri('s3://bucket-name/folder/picture #1.jpg')
    self.assertEqual(parsed.key_id, "folder/picture #1.jpg")

def test_s3_handles_querystring(self):
    """A '?' in an S3 URI stays in the key instead of starting a querystring."""
    parsed = smart_open_lib._parse_uri('s3://bucket-name/folder/picture1.jpg?bar')
    self.assertEqual(parsed.key_id, "folder/picture1.jpg?bar")

def test_s3_invalid_url_atmark_in_bucket_name(self):
self.assertRaises(ValueError, smart_open_lib._parse_uri, "s3://access_id:access_secret@my@bucket@port/mykey")

Expand Down

0 comments on commit 9d5f32c

Please sign in to comment.