diff --git a/.travis.yml b/.travis.yml index bb918f7..d13976e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,7 @@ cache: pip install: - pip install -r requirements.txt - - pip install -U pytest pytest-twisted pytest-cov pytest-pep8 coveralls requests-mock attrs>=19.2.0 + - pip install -U pytest pytest-twisted pytest-cov pytest-pep8 coveralls pytest-mock attrs>=19.2.0 script: - pytest --cov=kiosk_client --pep8 kiosk_client diff --git a/kiosk_client/job.py b/kiosk_client/job.py index 57b75a7..264ce80 100644 --- a/kiosk_client/job.py +++ b/kiosk_client/job.py @@ -38,7 +38,7 @@ from twisted.internet import error as twisted_errors from twisted.web import _newclient as twisted_client -from kiosk_client.utils import sleep, strip_bucket_prefix +from kiosk_client.utils import sleep, strip_bucket_prefix, get_download_path class Job(object): @@ -88,6 +88,7 @@ def __init__(self, host, filepath, model_name, model_version, **kwargs): self.expire_time = int(kwargs.get('expire_time', 3600)) self.update_interval = int(kwargs.get('update_interval', 10)) self.original_name = kwargs.get('original_name', self.filepath) + self.download_results = kwargs.get('download_results', False) self.failed = False # for error handling self.is_expired = False @@ -196,7 +197,7 @@ def _make_post_request(self, host, **kwargs): @defer.inlineCallbacks def _retry_post_request_wrapper(self, host, name='REDIS', **kwargs): - retrying = True # retry loop to prevent stackoverflow + retrying = True # retry loop to prevent stackoverflow while retrying: created_at = timeit.default_timer() try: @@ -335,6 +336,34 @@ def expire(self): value = response.get('value') defer.returnValue(value) # "return" the value + @defer.inlineCallbacks + def download_output(self): + start = timeit.default_timer() + basename = self.output_url.split('/')[-1] + dest = os.path.join(get_download_path(), basename) + self.logger.info('[%s]: Downloading output file %s to %s.', + self.job_id, self.output_url, dest) + name = 'DOWNLOAD RESULTS' + retrying = True # retry loop to prevent stackoverflow + while retrying: + try: + request = treq.get(self.output_url, unbuffered=True) + response = yield request + except self._http_errors as err: + self.logger.warning('[%s]: Encountered %s during %s: %s', + self.job_id, type(err).__name__, name, err) + yield self.sleep(self.update_interval) + continue # return to top of retry loop + retrying = False # success + + with open(dest, 'wb') as outfile: + yield response.collect(outfile.write) + + self.logger.info('Saved output file: "%s" in %s s.', + dest, timeit.default_timer() - start) + + defer.returnValue(dest) + @defer.inlineCallbacks def restart(self, delay=0): if not self.failed: @@ -390,6 +419,9 @@ def start(self, delay=0, upload=False): self.job_id, diff.total_seconds(), self.status, self.output_url) + if self.download_results: + success = yield self.download_output() + elif self.status == 'failed': reason = yield self.get_redis_value('reason') self.logger.warning('[%s]: Found final status `%s`: %s', diff --git a/kiosk_client/job_test.py b/kiosk_client/job_test.py index f504dea..298589f 100644 --- a/kiosk_client/job_test.py +++ b/kiosk_client/job_test.py @@ -29,6 +29,7 @@ from __future__ import print_function import datetime +import os import random import timeit @@ -78,6 +79,7 @@ def _get_default_job(filepath='filepath.png'): host='localhost', model_name='model_name', model_version='0', + download_results=True, update_interval=0.0001) @@ -171,6 +173,37 @@ def dummy_request_fail(*_, **__): job_id = yield j.upload_file() assert job_id is None + @pytest_twisted.inlineCallbacks + def test_download_output(self, tmpdir, mocker): + + global _download_failed + _download_failed = False + + @pytest_twisted.inlineCallbacks + def send_get_request(_, **__): + global _download_failed + if _download_failed: + _download_failed = False + response = Bunch(collect=lambda x: x(b'success')) + yield defer.returnValue(response) + else: + _download_failed = True + errs = _get_default_job()._http_errors + err = errs[random.randint(0, len(errs) - 1)] + raise err('on purpose') + + j = _get_default_job() + j.output_url = 'fakeURL.com/testfile.txt' + mocker.patch('kiosk_client.job.get_download_path', + lambda: str(tmpdir)) + mocker.patch('treq.get', send_get_request) + + result = yield j.download_output() + assert os.path.isfile(result) + assert str(result).startswith(str(tmpdir)) + with open(result, 'r') as f: + assert f.read() == 'success' + @pytest_twisted.inlineCallbacks def test_summarize(self): j = _get_default_job() @@ -306,6 +339,10 @@ def dummy_request_success(*_, **__): def dummy_upload_success(*_, **__): yield defer.returnValue('uploads/test.png') + @pytest_twisted.inlineCallbacks + def dummy_download_success(*_, **__): + yield defer.returnValue('downloads/test-results.png') + @pytest_twisted.inlineCallbacks def dummy_request_fail(*_, **__): yield defer.returnValue(None) @@ -318,10 +355,11 @@ def dummy_request_fail(*_, **__): j.get_redis_value = dummy_request_success j.expire = dummy_request_success j.upload_file = dummy_upload_success + j.download_output = dummy_download_success # is_done and is_summarized j.status = 'done' - j.output_url = '' + j.output_url = 'local' j.created_at = datetime.datetime.now().isoformat() j.finished_at = datetime.datetime.now().isoformat() @@ -346,14 +384,13 @@ def dummy_request_fail(*_, **__): assert value is False # failed @pytest_twisted.inlineCallbacks - def test__retry_post_request_wrapper(self): + def test__retry_post_request_wrapper(self, mocker): global _make_request_failed _make_request_failed = False @pytest_twisted.inlineCallbacks - def _make_post_request(*_, **__): - _j = _get_default_job() + def dummy_post_request(*_, **__): global _make_request_failed if _make_request_failed: _make_request_failed = False @@ -365,7 +402,6 @@ def _make_post_request(*_, **__): raise err('on purpose') j = _get_default_job() - j._make_post_request = _make_post_request - + mocker.patch('treq.post', dummy_post_request) result = yield j._retry_post_request_wrapper('host', {}) assert result.get('success') diff --git a/kiosk_client/manager.py b/kiosk_client/manager.py index df82ab7..6e44049 100644 --- a/kiosk_client/manager.py +++ b/kiosk_client/manager.py @@ -43,7 +43,6 @@ from twisted.web.client import HTTPConnectionPool from kiosk_client.job import Job -from kiosk_client.utils import get_download_path from kiosk_client.utils import iter_image_files from kiosk_client.utils import sleep from kiosk_client.utils import strip_bucket_prefix @@ -162,6 +161,7 @@ def make_job(self, filepath): postprocess=self.postprocess, upload_prefix=self.upload_prefix, update_interval=self.update_interval, + download_results=self.download_results, expire_time=self.expire_time, pool=self.pool) @@ -223,18 +223,6 @@ def check_job_status(self): yield self._stop() - def download_file_from_url(self, url, dest): - """Download a file from the URL to the destination file path.""" - # TODO: resolve treq SSL issue, replace requests with treq. - start = timeit.default_timer() - self.logger.info('Downloading output file %s to %s.', url, dest) - with requests.get(url, stream=True) as r: - r.raise_for_status() - with open(dest, 'wb') as outfile: - shutil.copyfileobj(r.raw, outfile) - self.logger.info('Saved output file: "%s" in %s s.', - dest, timeit.default_timer() - start) - def summarize(self): time_elapsed = timeit.default_timer() - self.created_at self.logger.info('Finished %s jobs in %s seconds.', @@ -264,17 +252,6 @@ def summarize(self): len(self.all_jobs), self.start_delay, uuid.uuid4().hex) output_filepath = os.path.join(settings.OUTPUT_DIR, output_filepath) - if self.download_results: - for job in self.all_jobs: - try: - basename = job.output_url.split('/')[-1] - filepath = os.path.join(get_download_path(), basename) - self.download_file_from_url(job.output_url, filepath) - except Exception as err: # pylint: disable=broad-except - self.logger.error('Could not download %s due to %s: %s. ', - job.output_url, type(err).__name__, err) - continue - with open(output_filepath, 'w') as jsonfile: json.dump(jsondata, jsonfile, indent=4) self.logger.info('Wrote job data as JSON to %s.', output_filepath) diff --git a/kiosk_client/manager_test.py b/kiosk_client/manager_test.py index 597dd41..c254e80 100644 --- a/kiosk_client/manager_test.py +++ b/kiosk_client/manager_test.py @@ -169,25 +169,13 @@ def fake_download_file_bad(url, dest): # monkey-patches for testing mgr.cost_getter.finish = lambda: (1, 2, 3) mgr.upload_file = fake_upload_file - mgr.download_file_from_url = fake_download_file mgr.summarize() # test Exceptions mgr.cost_getter.finish = lambda: 0 / 1 mgr.upload_file = fake_upload_file_bad - mgr.download_file_from_url = fake_download_file_bad mgr.summarize() - def test_download_file_from_url(self, tmpdir, requests_mock): - tmpdir = str(tmpdir) - expected_file = os.path.join(tmpdir, 'downloaded.txt') - requests_mock.get('http://test.com', text='data') - - mgr = manager.BenchmarkingJobManager(host='localhost', job_type='job') - mgr.download_file_from_url('http://test.com', expected_file) - assert os.path.exists(expected_file) - assert os.path.isfile(expected_file) - @pytest_twisted.inlineCallbacks def test_check_job_status(self): mgr = manager.JobManager( diff --git a/requirements.txt b/requirements.txt index 77c48a7..016ed64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ Pillow>=6.2.0 python-decouple>=3.1,<4 python-dateutil==2.8.0 treq==20.3.0 -twisted>=20.3.0 +twisted[tls]>=20.3.0 diff --git a/setup.py b/setup.py index f4f0118..1be69d9 100644 --- a/setup.py +++ b/setup.py @@ -41,13 +41,13 @@ 'python-decouple', 'python-dateutil==2.8.0', 'treq==20.3.0', - 'twisted>=20.3.0'], + 'twisted[tls]>=20.3.0'], extras_require={ 'tests': ['pytest', 'pytest-twisted', 'pytest-pep8', 'pytest-cov', - 'requests-mock', + 'pytest-mock', 'attrs>=19.2.0'], }, packages=find_packages())