diff --git a/planet/api/__init__.py b/planet/api/__init__.py index 08829dd4b..840cd0f27 100644 --- a/planet/api/__init__.py +++ b/planet/api/__init__.py @@ -14,12 +14,12 @@ from .exceptions import (APIException, BadQuery, InvalidAPIKey) from .exceptions import (NoPermission, MissingResource, OverQuota) -from .exceptions import (ServerError,) +from .exceptions import (ServerError, RequestCancelled) from .client import Client from .utils import write_to_file __all__ = [ Client, APIException, BadQuery, InvalidAPIKey, - NoPermission, MissingResource, OverQuota, ServerError, + NoPermission, MissingResource, OverQuota, ServerError, RequestCancelled, write_to_file ] diff --git a/planet/api/client.py b/planet/api/client.py index 2b1cc0b48..ff5bc43da 100644 --- a/planet/api/client.py +++ b/planet/api/client.py @@ -35,6 +35,9 @@ def __init__(self, api_key=None, base_url='https://api.planet.com/v0/', self.base_url = base_url self.dispatcher = RequestsDispatcher(workers) + def shutdown(self): + self.dispatcher.session.executor.shutdown(wait=False) + def _url(self, path): if path.startswith('http'): url = path diff --git a/planet/api/exceptions.py b/planet/api/exceptions.py index 064a32214..c40003c8c 100644 --- a/planet/api/exceptions.py +++ b/planet/api/exceptions.py @@ -45,3 +45,7 @@ class ServerError(APIException): class InvalidIdentity(APIException): '''raised when logging in with identity''' pass + + +class RequestCancelled(Exception): + '''When requests get cancelled''' diff --git a/planet/api/models.py b/planet/api/models.py index aad7e59d4..bb9685b35 100644 --- a/planet/api/models.py +++ b/planet/api/models.py @@ -13,6 +13,7 @@ # limitations under the License. 
from ._fatomic import atomic_open +from .exceptions import RequestCancelled from .utils import get_filename from .utils import check_status from .utils import GeneratorAdapter @@ -30,6 +31,7 @@ def __init__(self, request, dispatcher): self._dispatcher = dispatcher self._body = None self._future = None + self._cancel = False def _create_body(self, response): return self.request.body_type(self.request, response, self._dispatcher) @@ -45,6 +47,8 @@ def get_body(self): return self._body def _async_callback(self, session, response): + if self._cancel: + raise RequestCancelled() check_status(response) self._body = self._create_body(response) self._handler(self._body) @@ -68,6 +72,13 @@ def await(self): self._future.result() return self._body + def cancel(self): + '''Cancel the request, or its body download once one has started.''' + if self._body is not None: + self._body._cancel = True + else: + self._cancel = True + class Request(object): @@ -86,6 +97,7 @@ def __init__(self, request, http_response, dispatcher): self._dispatcher = dispatcher self.size = int(self.response.headers.get('content-length', 0)) self.name = get_filename(self.response) + self._cancel = False def __len__(self): return self.size @@ -108,6 +120,8 @@ def _write(self, fp, callback): callback = lambda x: None callback(self) for chunk in self: + if self._cancel: + raise RequestCancelled() fp.write(chunk) size = len(chunk) total += size diff --git a/planet/api/sync.py b/planet/api/sync.py index dbb2a97b6..2dced1adf 100644 --- a/planet/api/sync.py +++ b/planet/api/sync.py @@ -13,13 +13,19 @@ # limitations under the License. import itertools import json +import logging import os from os import path import threading from ._fatomic import atomic_open -from .utils import write_to_file +from . 
import exceptions +from .utils import complete from .utils import strp_timestamp from .utils import strf_timestamp +from .utils import write_to_file + + +_logger = logging.getLogger(__name__) class _SyncTool(object): @@ -35,6 +41,8 @@ def __init__(self, client, destination, aoi, scene_type, products, self.workspace = filters.get('workspace', None) self._init() self.sync_file = path.join(self.destination, 'sync.json') + self.error_handler = _logger.exception + self._cancel = False def _init(self): dest = self.destination @@ -87,7 +95,7 @@ def sync(self, callback): summary = _SyncSummary(self._scene_count * len(self.products)) all_scenes = self.get_scenes_to_sync() - while True: + while not self._cancel: # bite of chunks of work to not bog down on too many queued jobs scenes = list(itertools.islice(all_scenes, 100)) if not scenes: @@ -100,10 +108,9 @@ def sync(self, callback): for h in handlers: h.run(self.client, self.scene_type, self.products) # synchronously await them and then write metadata - for h in handlers: - h.finish() + complete(handlers, self._future_handler, self.client) - if summary.latest: + if summary.latest and not self._cancel: sync = self._read_sync_file() sync['latest'] = strf_timestamp(summary.latest) with atomic_open(self.sync_file, 'wb') as fp: @@ -111,6 +118,16 @@ def sync(self, callback): return summary + def _future_handler(self, futures): + for f in futures: + try: + f.finish() + except exceptions.RequestCancelled: + self._cancel = True + break + except Exception: + self.error_handler('Unexpected error') + class _SyncSummary(object): '''Keep track of summary state, thread safe.''' @@ -137,19 +154,40 @@ def __init__(self, destination, summary, metadata, user_callback): self.summary = summary self.metadata = metadata self.user_callback = user_callback or (lambda *args: None) + self._cancel = False + self.futures = [] def run(self, client, scene_type, products): - self.futures = [] + '''start asynchronous execution, must call finish to await''' + if 
self._cancel: + return for product in products: self.futures.extend(client.fetch_scene_geotiffs( [self.metadata['id']], scene_type, product, callback=self)) + def cancel(self): + '''cancel pending downloads''' + self._cancel = True + futures = getattr(self, 'futures', []) + for f in futures: + f.cancel() + def finish(self): + '''await pending downloads and write out metadata + @todo this is not an atomic operation - it's possible that one + product gets downloaded and the other fails. + ''' + if self._cancel: + return + for f in self.futures: f.await() + if self._cancel: + return + # write out metadata metadata = os.path.join(self.destination, '%s_metadata.json' % self.metadata['id']) diff --git a/planet/api/utils.py b/planet/api/utils.py index 7383c5846..6ee0db6ba 100644 --- a/planet/api/utils.py +++ b/planet/api/utils.py @@ -17,6 +17,7 @@ import json import os import re +import threading from ._fatomic import atomic_open _ISO_FMT = '%Y-%m-%dT%H:%M:%S.%f+00:00' @@ -187,3 +188,33 @@ def probably_geojson(input): ]) valid = typename in supported_types return input if valid else None + + +def complete(futures, check, client): + '''Wait for the future requests to complete without blocking the main + thread. This is a means to intercept a KeyboardInterrupt and gracefully + shutdown current requests without waiting for them to complete. + + The cancel function on each future object should abort processing - any + blocking functions/IO will not be interrupted and this function should + return immediately. 
+ + :param futures: sequence of objects with a cancel function + :param check: function that will be called with the provided futures from + a background thread + :param client: the Client to terminate on interrupt + ''' + # start a background thread to not block main (otherwise hangs on 2.7) + def run(): + check(futures) + t = threading.Thread(target=run) + t.start() + # poll (or we miss the interrupt) and await completion + try: + while t.is_alive(): + t.join(.1) + except KeyboardInterrupt: + for f in futures: + f.cancel() + client.shutdown() + raise diff --git a/planet/scripts/__init__.py b/planet/scripts/__init__.py index 730932679..83aab4a3d 100644 --- a/planet/scripts/__init__.py +++ b/planet/scripts/__init__.py @@ -25,6 +25,7 @@ import planet from planet.api.sync import _SyncTool from planet import api +from planet.api.utils import complete from requests.packages.urllib3 import exceptions as urllib3exc @@ -110,6 +111,8 @@ def check_futures(futures): click_exception(invalid) except api.APIException as other: click.echo('WARNING %s' % other.message) + except api.RequestCancelled: + pass def summarize_throughput(bytes, start_time): @@ -303,9 +306,10 @@ def fetch_scene_geotiff(scene_ids, scene_type, product, dest): return start_time = time.time() - futures = client().fetch_scene_geotiffs(scene_ids, scene_type, product, - api.utils.write_to_file(dest)) - check_futures(futures) + cl = client() + futures = cl.fetch_scene_geotiffs(scene_ids, scene_type, product, + api.utils.write_to_file(dest)) + complete(futures, check_futures, cl) summarize_throughput(total_bytes(futures), start_time) @@ -324,9 +328,10 @@ def fetch_scene_thumbnails(scene_ids, scene_type, size, fmt, dest): if not scene_ids: return - futures = client().fetch_scene_thumbnails(scene_ids, scene_type, size, fmt, - api.write_to_file(dest)) - check_futures(futures) + cl = client() + futures = cl.fetch_scene_thumbnails(scene_ids, scene_type, size, fmt, + api.write_to_file(dest)) + complete(futures, 
check_futures, cl) @scene_type @@ -426,9 +431,10 @@ def download_quads(mosaic_name, quad_ids, dest): Download quad geotiffs """ quad_ids = read(quad_ids, split=True) - futures = call_and_wrap(client().fetch_mosaic_quad_geotiffs, mosaic_name, + cl = client() + futures = call_and_wrap(cl.fetch_mosaic_quad_geotiffs, mosaic_name, quad_ids, api.write_to_file(dest)) - check_futures(futures) + complete(futures, check_futures, cl) @pretty