Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Support TLS and download files asynchronously. (#53)
* install the TLS dependencies for twisted with `twisted[tls]`.

* Move download result file logic into Job class.

* Replace `requests-mock` with `pytest-mock`.

* use `pytest-mock` in `test__retry_post_request`.
  • Loading branch information
willgraf committed Jun 22, 2020
1 parent 49776b7 commit eb23a8d
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -16,7 +16,7 @@ cache: pip

install:
- pip install -r requirements.txt
- pip install -U pytest pytest-twisted pytest-cov pytest-pep8 coveralls requests-mock attrs>=19.2.0
- pip install -U pytest pytest-twisted pytest-cov pytest-pep8 coveralls pytest-mock attrs>=19.2.0

script:
- pytest --cov=kiosk_client --pep8 kiosk_client
Expand Down
36 changes: 34 additions & 2 deletions kiosk_client/job.py
Expand Up @@ -38,7 +38,7 @@
from twisted.internet import error as twisted_errors
from twisted.web import _newclient as twisted_client

from kiosk_client.utils import sleep, strip_bucket_prefix
from kiosk_client.utils import sleep, strip_bucket_prefix, get_download_path


class Job(object):
Expand Down Expand Up @@ -88,6 +88,7 @@ def __init__(self, host, filepath, model_name, model_version, **kwargs):
self.expire_time = int(kwargs.get('expire_time', 3600))
self.update_interval = int(kwargs.get('update_interval', 10))
self.original_name = kwargs.get('original_name', self.filepath)
self.download_results = kwargs.get('download_results', False)

self.failed = False # for error handling
self.is_expired = False
Expand Down Expand Up @@ -196,7 +197,7 @@ def _make_post_request(self, host, **kwargs):

@defer.inlineCallbacks
def _retry_post_request_wrapper(self, host, name='REDIS', **kwargs):
retrying = True # retry loop to prevent stackoverflow
retrying = True # retry loop to prevent stackoverflow
while retrying:
created_at = timeit.default_timer()
try:
Expand Down Expand Up @@ -335,6 +336,34 @@ def expire(self):
value = response.get('value')
defer.returnValue(value) # "return" the value

@defer.inlineCallbacks
def download_output(self):
start = timeit.default_timer()
basename = self.output_url.split('/')[-1]
dest = os.path.join(get_download_path(), basename)
self.logger.info('[%s]: Downloading output file %s to %s.',
self.job_id, self.output_url, dest)
name = 'DOWNLOAD RESULTS'
retrying = True # retry loop to prevent stackoverflow
while retrying:
try:
request = treq.get(self.output_url, unbuffered=True)
response = yield request
except self._http_errors as err:
self.logger.warning('[%s]: Encountered %s during %s: %s',
self.job_id, type(err).__name__, name, err)
yield self.sleep(self.update_interval)
continue # return to top of retry loop
retrying = False # success

with open(dest, 'wb') as outfile:
yield response.collect(outfile.write)

self.logger.info('Saved output file: "%s" in %s s.',
dest, timeit.default_timer() - start)

defer.returnValue(dest)

@defer.inlineCallbacks
def restart(self, delay=0):
if not self.failed:
Expand Down Expand Up @@ -390,6 +419,9 @@ def start(self, delay=0, upload=False):
self.job_id, diff.total_seconds(),
self.status, self.output_url)

if self.download_results:
success = yield self.download_output()

elif self.status == 'failed':
reason = yield self.get_redis_value('reason')
self.logger.warning('[%s]: Found final status `%s`: %s',
Expand Down
48 changes: 42 additions & 6 deletions kiosk_client/job_test.py
Expand Up @@ -29,6 +29,7 @@
from __future__ import print_function

import datetime
import os
import random
import timeit

Expand Down Expand Up @@ -78,6 +79,7 @@ def _get_default_job(filepath='filepath.png'):
host='localhost',
model_name='model_name',
model_version='0',
download_results=True,
update_interval=0.0001)


Expand Down Expand Up @@ -171,6 +173,37 @@ def dummy_request_fail(*_, **__):
job_id = yield j.upload_file()
assert job_id is None

@pytest_twisted.inlineCallbacks
def test_download_output(self, tmpdir, mocker):

global _download_failed
_download_failed = False

@pytest_twisted.inlineCallbacks
def send_get_request(_, **__):
global _download_failed
if _download_failed:
_download_failed = False
response = Bunch(collect=lambda x: x(b'success'))
yield defer.returnValue(response)
else:
_download_failed = True
errs = _get_default_job()._http_errors
err = errs[random.randint(0, len(errs) - 1)]
raise err('on purpose')

j = _get_default_job()
j.output_url = 'fakeURL.com/testfile.txt'
mocker.patch('kiosk_client.job.get_download_path',
lambda: str(tmpdir))
mocker.patch('treq.get', send_get_request)

result = yield j.download_output()
assert os.path.isfile(result)
assert str(result).startswith(str(tmpdir))
with open(result, 'r') as f:
assert f.read() == 'success'

@pytest_twisted.inlineCallbacks
def test_summarize(self):
j = _get_default_job()
Expand Down Expand Up @@ -306,6 +339,10 @@ def dummy_request_success(*_, **__):
def dummy_upload_success(*_, **__):
yield defer.returnValue('uploads/test.png')

@pytest_twisted.inlineCallbacks
def dummy_download_success(*_, **__):
yield defer.returnValue('downloads/test-results.png')

@pytest_twisted.inlineCallbacks
def dummy_request_fail(*_, **__):
yield defer.returnValue(None)
Expand All @@ -318,10 +355,11 @@ def dummy_request_fail(*_, **__):
j.get_redis_value = dummy_request_success
j.expire = dummy_request_success
j.upload_file = dummy_upload_success
j.download_output = dummy_download_success

# is_done and is_summarized
j.status = 'done'
j.output_url = ''
j.output_url = 'local'
j.created_at = datetime.datetime.now().isoformat()
j.finished_at = datetime.datetime.now().isoformat()

Expand All @@ -346,14 +384,13 @@ def dummy_request_fail(*_, **__):
assert value is False # failed

@pytest_twisted.inlineCallbacks
def test__retry_post_request_wrapper(self):
def test__retry_post_request_wrapper(self, mocker):

global _make_request_failed
_make_request_failed = False

@pytest_twisted.inlineCallbacks
def _make_post_request(*_, **__):
_j = _get_default_job()
def dummy_post_request(*_, **__):
global _make_request_failed
if _make_request_failed:
_make_request_failed = False
Expand All @@ -365,7 +402,6 @@ def _make_post_request(*_, **__):
raise err('on purpose')

j = _get_default_job()
j._make_post_request = _make_post_request

mocker.patch('treq.post', dummy_post_request)
result = yield j._retry_post_request_wrapper('host', {})
assert result.get('success')
25 changes: 1 addition & 24 deletions kiosk_client/manager.py
Expand Up @@ -43,7 +43,6 @@
from twisted.web.client import HTTPConnectionPool

from kiosk_client.job import Job
from kiosk_client.utils import get_download_path
from kiosk_client.utils import iter_image_files
from kiosk_client.utils import sleep
from kiosk_client.utils import strip_bucket_prefix
Expand Down Expand Up @@ -162,6 +161,7 @@ def make_job(self, filepath):
postprocess=self.postprocess,
upload_prefix=self.upload_prefix,
update_interval=self.update_interval,
download_results=self.download_results,
expire_time=self.expire_time,
pool=self.pool)

Expand Down Expand Up @@ -223,18 +223,6 @@ def check_job_status(self):

yield self._stop()

def download_file_from_url(self, url, dest):
"""Download a file from the URL to the destination file path."""
# TODO: resolve treq SSL issue, replace requests with treq.
start = timeit.default_timer()
self.logger.info('Downloading output file %s to %s.', url, dest)
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(dest, 'wb') as outfile:
shutil.copyfileobj(r.raw, outfile)
self.logger.info('Saved output file: "%s" in %s s.',
dest, timeit.default_timer() - start)

def summarize(self):
time_elapsed = timeit.default_timer() - self.created_at
self.logger.info('Finished %s jobs in %s seconds.',
Expand Down Expand Up @@ -264,17 +252,6 @@ def summarize(self):
len(self.all_jobs), self.start_delay, uuid.uuid4().hex)
output_filepath = os.path.join(settings.OUTPUT_DIR, output_filepath)

if self.download_results:
for job in self.all_jobs:
try:
basename = job.output_url.split('/')[-1]
filepath = os.path.join(get_download_path(), basename)
self.download_file_from_url(job.output_url, filepath)
except Exception as err: # pylint: disable=broad-except
self.logger.error('Could not download %s due to %s: %s. ',
job.output_url, type(err).__name__, err)
continue

with open(output_filepath, 'w') as jsonfile:
json.dump(jsondata, jsonfile, indent=4)
self.logger.info('Wrote job data as JSON to %s.', output_filepath)
Expand Down
12 changes: 0 additions & 12 deletions kiosk_client/manager_test.py
Expand Up @@ -169,25 +169,13 @@ def fake_download_file_bad(url, dest):
# monkey-patches for testing
mgr.cost_getter.finish = lambda: (1, 2, 3)
mgr.upload_file = fake_upload_file
mgr.download_file_from_url = fake_download_file
mgr.summarize()

# test Exceptions
mgr.cost_getter.finish = lambda: 0 / 1
mgr.upload_file = fake_upload_file_bad
mgr.download_file_from_url = fake_download_file_bad
mgr.summarize()

def test_download_file_from_url(self, tmpdir, requests_mock):
tmpdir = str(tmpdir)
expected_file = os.path.join(tmpdir, 'downloaded.txt')
requests_mock.get('http://test.com', text='data')

mgr = manager.BenchmarkingJobManager(host='localhost', job_type='job')
mgr.download_file_from_url('http://test.com', expected_file)
assert os.path.exists(expected_file)
assert os.path.isfile(expected_file)

@pytest_twisted.inlineCallbacks
def test_check_job_status(self):
mgr = manager.JobManager(
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -3,4 +3,4 @@ Pillow>=6.2.0
python-decouple>=3.1,<4
python-dateutil==2.8.0
treq==20.3.0
twisted>=20.3.0
twisted[tls]>=20.3.0
4 changes: 2 additions & 2 deletions setup.py
Expand Up @@ -41,13 +41,13 @@
'python-decouple',
'python-dateutil==2.8.0',
'treq==20.3.0',
'twisted>=20.3.0'],
'twisted[tls]>=20.3.0'],
extras_require={
'tests': ['pytest',
'pytest-twisted',
'pytest-pep8',
'pytest-cov',
'requests-mock',
'pytest-mock',
'attrs>=19.2.0'],
},
packages=find_packages())

0 comments on commit eb23a8d

Please sign in to comment.