From 8a7b933f0d7718c27bc303e11029fafd67764d63 Mon Sep 17 00:00:00 2001 From: Jordan Cook Date: Mon, 15 Mar 2021 20:22:03 -0500 Subject: [PATCH 1/3] Some refactoring to reduce code complexity: * Pass per-request expiration in request params instead of setting as a temporary instance variable * Make use of dict ordering from python3.6+ in _normalize_parameters() * Add some more type annotations to `CachedSession` methods * Remove `expires_before` param from remove_old_entries, and always use the current time * Remove `relative_to` param from `CachedSession._determine_expiration_datetime` and use mock values in unit tests instead --- requests_cache/backends/base.py | 17 ++- requests_cache/core.py | 235 ++++++++++++-------------------- 2 files changed, 97 insertions(+), 155 deletions(-) diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 06d7f745..f97d522e 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -113,15 +113,11 @@ def clear(self): self.responses.clear() self.keys_map.clear() - def remove_old_entries(self, expires_before): - """Deletes entries from cache with expiration time older than ``expires_before``""" - if expires_before.tzinfo is None: - # if expires_before is not timezone-aware, assume local time - expires_before = expires_before.astimezone() - + def remove_old_entries(self): + """Deletes expired entries from the cache""" keys_to_delete = set() for key, (response, _) in self.responses.items(): - if response.expiration_date is not None and response.expiration_date < expires_before: + if _is_expired(response): keys_to_delete.add(key) for key in keys_to_delete: @@ -279,5 +275,12 @@ def read(self, chunk_size=1): return self._io_with_content_.read(chunk_size) +def _is_expired(response): + """Check a cached response to see if it's expired""" + if getattr(response, 'expire_after', None) is not None: + return datetime.now(timezone.utc) > response.expire_after + return False + + def 
_to_bytes(s, encoding='utf-8'): return s if isinstance(s, bytes) else bytes(s, encoding) diff --git a/requests_cache/core.py b/requests_cache/core.py index 5e68880e..af60a4be 100644 --- a/requests_cache/core.py +++ b/requests_cache/core.py @@ -4,19 +4,22 @@ Core functions for configuring cache and monkey patching ``requests`` """ +from collections.abc import Mapping from contextlib import contextmanager from datetime import datetime, timedelta, timezone from operator import itemgetter -from typing import Callable, Iterable, Union +from typing import Any, Callable, Dict, Iterable, Optional, Union import requests from requests import Session as OriginalSession from requests.hooks import dispatch_hook -from requests_cache.backends.base import BACKEND_KWARGS +from requests_cache.backends.base import BACKEND_KWARGS, _is_expired from . import backends +ExpirationTime = Union[None, int, float, datetime, timedelta] + class CacheMixin: """Mixin class that extends ``requests.Session`` with caching features. @@ -59,15 +62,11 @@ def __init__( allowable_methods: Iterable['str'] = ('GET', 'HEAD'), filter_fn: Callable = None, old_data_on_error: bool = False, - **kwargs + **kwargs, ): self.cache = backends.create_backend(backend, cache_name, kwargs) self._cache_name = cache_name - - if expire_after is not None and not isinstance(expire_after, timedelta): - expire_after = timedelta(seconds=expire_after) - self._cache_expire_after = expire_after - self._request_expire_after = 'default' + self.expire_after = _get_timedelta(expire_after) self._cache_allowable_codes = allowable_codes self._cache_allowable_methods = allowable_methods @@ -79,174 +78,100 @@ def __init__( session_kwargs = {k: v for k, v in kwargs.items() if k not in BACKEND_KWARGS} super().__init__(**session_kwargs) - def _determine_expiration_datetime(self, relative_to=None): - """Determines the absolute expiration datetime for a response. 
- Requires :attr:`self._cache_expire_after` and :attr:`self._request_expire_after` to be set. - See :meth:`request` for more information. - - :param response: the response (potentially loaded from the cache) - :type response: requests.Response - :param relative_to: Parameter for easy unit testing to fix ``now``, - defaults to ``datetime.now(timezone.utc)`` for normal use. - :type relative_to: Union[None, datetime.datetime] - :return: The absolute expiration date - :rtype: datetime.datetime - """ - now = datetime.now(timezone.utc) if relative_to is None else relative_to - - cache_expire_after = self._cache_expire_after - request_expire_after = self._request_expire_after - - def to_absolute(expire_after): - if expire_after is None: - return None - if isinstance(expire_after, timedelta): - return now + expire_after - if isinstance(expire_after, datetime): - return expire_after - return now + timedelta(seconds=expire_after) - - if request_expire_after == 'default': - return to_absolute(cache_expire_after) - return to_absolute(request_expire_after) - def send(self, request, **kwargs): - do_not_cache = ( - self._is_cache_disabled - or request.method not in self._cache_allowable_methods - or self._request_expire_after is None - ) + _request_expire_after = kwargs.get('params', {}).pop('_request_expire_after', None) + expire_after = _request_expire_after or self.expire_after + + # If we shouldn't cache the response, just send the request + do_not_cache = self._is_cache_disabled or request.method not in self._cache_allowable_methods if do_not_cache: response = super().send(request, **kwargs) response.from_cache = False response.cache_date = None - response.expiration_date = None - response.expire_after = 'default' + response.expire_after = None return response + # If a response isn't already cached, send the request and cache the response cache_key = self.cache.create_key(request) - try: response, timestamp = self.cache.get_response_and_time(cache_key) except (ImportError, 
TypeError): response, timestamp = None, None - if response is None: - return self.send_request_and_cache_response(request, cache_key, **kwargs) - - if getattr(response, 'expiration_date', None) is not None: - now = datetime.now(timezone.utc) - is_expired = now > response.expiration_date - else: - is_expired = False + return self.send_request_and_cache_response(request, cache_key, expire_after, **kwargs) - cache_invalid = response.expire_after != self._request_expire_after and self._request_expire_after != 'default' - if cache_invalid or is_expired: - if not self._return_old_data_on_error: - self.cache.delete(cache_key) - return self.send_request_and_cache_response(request, cache_key, **kwargs) + # If the cached response is invalid, send the request and cache the response + if _is_expired(response): try: - new_response = self.send_request_and_cache_response(request, cache_key, **kwargs) + new_response = self.send_request_and_cache_response(request, cache_key, expire_after, **kwargs) + self.cache.delete(cache_key) + return new_response except Exception: - return response - else: - if new_response.status_code not in self._cache_allowable_codes: + # Return the expired/invalid response on error, if specified + if self._return_old_data_on_error: return response - return new_response + self.cache.delete(cache_key) + raise - # dispatch hook here, because we've removed it before pickling + # Dispatch hook here, because we've removed it before pickling response.from_cache = True response.cache_date = timestamp response = dispatch_hook('response', request.hooks, response, **kwargs) return response - def request(self, method, url, params=None, data=None, expire_after='default', **kwargs): - """This method prepares and sends a request while automatically - performing any necessary caching operations. - - If a cache is installed, whenever a standard ``requests`` function is - called, e.g. 
:func:`requests.get`, this method is called to handle caching - and calling the original :func:`requests.request` method. - - This method adds an additional keyword argument to :func:`requests.request`, ``expire_after``. - It is used to set the expiry time for a specific request to override - the cache default, and can be omitted on subsequent calls. Subsequent - calls with different values invalidate the cache, calls with the same values (or without any values) don't. - - Given - - - the `expire_after` from the installed cache (the ``'default'``) - - the `expire_after` passed to an individual request - - the `expire_after` stored inside the cache - - the following rules hold for which `expire_after` is used: - - +-----------------------------------+------------------------------+ - | | | request(..., expire_after=X) | - +=======================+===========+===============+==============+ - | | | 'default' | other | - +-----------------------+-----------+---------------+--------------+ - | response.expire_after | 'default' | cache default | from request | - | +-----------+---------------+--------------+ - | | other | cache default | from request | - +-----------------------+-----------+---------------+--------------+ + def send_request_and_cache_response(self, request, cache_key, expire_after, **kwargs): + response = super().send(request, **kwargs) + response.from_cache = False + response.cache_date = None - That is, if the request's ``expire_after`` is set to ``'default'`` - (which is the default value) the default caching behavior is used. + # Cache the response, if possible + if response.status_code in self._cache_allowable_codes: + response.expire_after = _get_absolute_time(expire_after) # type: ignore + self.cache.save_response(cache_key, response) + return response - Whenever the request's expire_after is anything else (a number, None, - datetime, or timedelta), that value will be used. 
+ def request( + self, + method: str, + url: str, + params: dict = None, + data: Any = None, + expire_after: ExpirationTime = None, + **kwargs, + ) -> requests.Response: + """This method prepares and sends a request while automatically performing any necessary + caching operations. This will be called by any other method-specific ``requests`` functions + (get, post, etc.). In all cases, if the value is an explicit datetime it returned as is. If it is None, it is also returned as is and caches forever. All other values will be considered a relative time in the future. - :param expire_after: Specifies when the cache for a particular response - expires. Accepts multiple argument types: - - - ``'default'`` to use the default expiry from the installed cache. This is the default. - - :const:`None` to disable caching for this request - - :class:`~datetime.timedelta` to set relative expiry times - - :class:`float` values as time in seconds for :class:`~datetime.timedelta` - - :class:`~datetime.datetime` to set an explicit expiration date + Args: + expire_after: Expiration time to set only for this request; see details below. + Overrides ``CachedSession.expire_after``. Accepts all the same types as + ``CachedSession.expire_after`` except for ``None``; use + ``CachedSession.cache_disabled`` to disable caching on a per-request basis. 
- :type expire_after: Union[None, str, float, datetime.timedelta, datetime.datetime] + Returns: + Either a new or cached response """ - self._request_expire_after = expire_after # store expire_after so we can handle it in the send-method - response = super().request(method, url, _normalize_parameters(params), _normalize_parameters(data), **kwargs) + # Store expire_after to be used by send() + params = _normalize_parameters(params) + params['_request_expire_after'] = expire_after + response = super().request(method, url, params, _normalize_parameters(data), **kwargs) if self._is_cache_disabled: - try: - return response - finally: - self._request_expire_after = 'default' - + return response main_key = self.cache.create_key(response.request) - # If self._return_old_data_on_error is set, - # responses won't always have the from_cache attribute. - if hasattr(response, "from_cache") and not response.from_cache and self._filter_fn(response) is not True: + # If self._return_old_data_on_error is set, responses may not have the from_cache attribute + if hasattr(response, "from_cache") and not response.from_cache and not self._filter_fn(response): self.cache.delete(main_key) - try: - return response - finally: - self._request_expire_after = 'default' + return response for r in response.history: self.cache.add_key_mapping(self.cache.create_key(r.request), main_key) - try: - return response - finally: - self._request_expire_after = 'default' - - def send_request_and_cache_response(self, request, cache_key, **kwargs): - response = super().send(request, **kwargs) - if response.status_code in self._cache_allowable_codes: - response.expire_after = self._request_expire_after - response.expiration_date = self._determine_expiration_datetime() - self.cache.save_response(cache_key, response) - response.from_cache = False - response.cache_date = None return response @contextmanager @@ -270,14 +195,12 @@ def cache_disabled(self): def remove_expired_responses(self): """Removes expired 
responses from storage""" - self.cache.remove_old_entries(datetime.now(timezone.utc)) + self.cache.remove_old_entries() def __repr__(self): - return "" % ( - self.cache.__class__.__name__, - self._cache_name, - self._cache_expire_after, - self._cache_allowable_methods, + return ( + f"" ) @@ -294,7 +217,7 @@ def install_cache( filter_fn: Callable = None, old_data_on_error: bool = False, session_factory=CachedSession, - **kwargs + **kwargs, ): """ Installs cache for all ``Requests`` requests by monkey-patching ``Session`` @@ -318,7 +241,7 @@ def __init__(self): allowable_methods=allowable_methods, filter_fn=filter_fn, old_data_on_error=old_data_on_error, - **kwargs + **kwargs, ) _patch_session_factory(_ConfiguredCachedSession) @@ -395,14 +318,30 @@ def remove_expired_responses(): return requests.Session().remove_expired_responses() +def _get_absolute_time(expire_after: Union[int, float, datetime, timedelta]) -> Optional[datetime]: + """Convert a time value to an absolute datetime, if it's not already""" + if isinstance(expire_after, datetime): + return expire_after + if expire_after is None: + return None + return datetime.now(timezone.utc) + _get_timedelta(expire_after) # type: ignore + + +def _get_timedelta(expire_after: Union[int, float, timedelta] = None) -> Optional[timedelta]: + """Convert a time value to a timedelta, if it's not already""" + if expire_after is not None and not isinstance(expire_after, timedelta): + expire_after = timedelta(seconds=expire_after) + return expire_after + + def _patch_session_factory(session_factory=CachedSession): requests.Session = requests.sessions.Session = session_factory -def _normalize_parameters(params): +def _normalize_parameters(params: Optional[Dict]) -> Dict: """If builtin dict is passed as parameter, returns sorted list of key-value pairs """ - if type(params) is dict: - return sorted(params.items(), key=itemgetter(0)) - return params + if isinstance(params, Mapping): + return dict(sorted(params.items(), 
key=itemgetter(0))) + return params or {} From aa9579ddf3108f795767f341258395aa10ee8f45 Mon Sep 17 00:00:00 2001 From: Jordan Cook Date: Thu, 18 Mar 2021 17:20:15 -0500 Subject: [PATCH 2/3] Consolidate expiration, pre-serialization, and other response object logic into CachedResponse class: * Replace `_RawStore` with `CachedHTTPResponse` class to wrap raw responses * Maintain support for streaming requests (#68) * Improve handling for generator usage * Add support for use with `pandas.read_csv()` and similar readers (#148) * Add support for use as a context manager (#148) * Add support for `decode_content` arg * Fix streaming requests when used with memory backend (#188) * Verified that `PreparedRequest.body` is always encoded in utf-8, so no need to detect encoding (Re: TODO note) * Response creation time and expiration time are stored as CachedResponse, so the `(response, timestamp)` tuple is no longer necessary * Rename `response.expire_after` and `response.cache_date` to `expires` and `created_at`, respectively, based on browser cache directives * Add optional `expire_after` param to `CachedSession.remove_expired_responses()` * Make `CachedSession` members `allowable_codes, allowable_methods, filter_fn, old_data_on_error` public, since they can safely be modified after initialization * More type annotations and docstring updates * Move main cache documentation from `CacheMixin` to `CachedSession`, since that's probably where a user would look first * Wrap temporary `_request_expire_after` in a contextmanager * Add intersphinx links for `urllib` classes & methods * Fix linting issues raised by flake8 * Start adding some unit tests using requests-mock tmp --- requests_cache/__init__.py | 7 +- requests_cache/backends/__init__.py | 23 +- requests_cache/backends/base.py | 292 ++++++++----------------- requests_cache/core.py | 323 +++++++++++++++------------- requests_cache/response.py | 156 ++++++++++++++ 5 files changed, 452 insertions(+), 349 deletions(-) create mode 
100755 requests_cache/response.py diff --git a/requests_cache/__init__.py b/requests_cache/__init__.py index 067506f7..c4b58782 100644 --- a/requests_cache/__init__.py +++ b/requests_cache/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# flake8: noqa: E402,F401 """ requests_cache ~~~~~~~~~~~~~~ @@ -18,15 +19,17 @@ # will take approximately 5 seconds instead 50 - :copyright: (c) 2012 by Roman Haritonov. + :copyright: (c) 2021 by Roman Haritonov. :license: BSD, see LICENSE for more details. """ __docformat__ = 'restructuredtext' __version__ = '0.6.0' -# Quietly ignore importerror, if setup.py is invoked outside a virtualenv +# Quietly ignore ImportError, if setup.py is invoked outside a virtualenv try: + from .response import AnyResponse, CachedHTTPResponse, CachedResponse, ExpirationTime from .core import ( + ALL_METHODS, CachedSession, CacheMixin, clear, diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py index 178ef086..6d1282d1 100644 --- a/requests_cache/backends/__init__.py +++ b/requests_cache/backends/__init__.py @@ -1,13 +1,28 @@ -# noqa: F401 +# flake8: noqa: F401 """ requests_cache.backends ~~~~~~~~~~~~~~~~~~~~~~~ Classes and functions for cache persistence """ - - -from .base import BACKEND_KWARGS, BaseCache +from .base import BaseCache + +# All backend-specific keyword arguments combined +BACKEND_KWARGS = [ + 'connection', + 'db_name', + 'endpont_url', + 'extension', + 'fast_save', + 'ignored_parameters', + 'include_get_headers', + 'location', + 'name', + 'namespace', + 'read_capacity_units', + 'region_name', + 'write_capacity_units', +] registry = { 'memory': BaseCache, diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index f97d522e..5698c4af 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python """ requests_cache.backends.base ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -7,34 +6,19 @@ extended to support persistence. 
""" import hashlib -from copy import copy -from datetime import datetime, timezone -from io import BytesIO +import json +from pickle import PickleError from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import requests -# All backend-specific keyword arguments combined -BACKEND_KWARGS = [ - 'connection', - 'db_name', - 'endpont_url', - 'extension', - 'fast_save', - 'ignored_parameters', - 'include_get_headers', - 'location', - 'name', - 'namespace', - 'read_capacity_units', - 'region_name', - 'write_capacity_units', -] +from ..response import AnyResponse, CachedResponse, ExpirationTime + DEFAULT_HEADERS = requests.utils.default_headers() class BaseCache(object): - """Base class for cache implementations, can be used as in-memory cache. + """Base class for cache implementations, which can also be used as in-memory cache. To extend it you can provide dictionary-like objects for :attr:`keys_map` and :attr:`responses` or override public methods. @@ -48,239 +32,151 @@ def __init__(self, *args, **kwargs): self._include_get_headers = kwargs.get("include_get_headers", False) self._ignored_parameters = set(kwargs.get("ignored_parameters") or []) - def save_response(self, key, response): + def save_response(self, key: str, response: AnyResponse, expire_after: ExpirationTime = None): """Save response to cache - :param key: key for this response - :param response: response to save - - .. 
note:: Response is reduced before saving (with :meth:`reduce_response`) - to make it picklable + Args: + key: key for this response + response: response to save + expire_after: Time in seconds until this cache item should expire """ - self.responses[key] = self.reduce_response(response), datetime.now(timezone.utc) + self.responses[key] = CachedResponse(response, expire_after=expire_after) - def add_key_mapping(self, new_key, key_to_response): + def add_key_mapping(self, new_key: str, key_to_response: str): """ Adds mapping of `new_key` to `key_to_response` to make it possible to associate many keys with single response - :param new_key: new key (e.g. url from redirect) - :param key_to_response: key which can be found in :attr:`responses` - :return: + Args: + new_key: New resource key (e.g. url from redirect) + key_to_response: Key which can be found in :attr:`responses` """ self.keys_map[new_key] = key_to_response - def get_response_and_time(self, key, default=(None, None)): - """Retrieves response and timestamp for `key` if it's stored in cache, - otherwise returns `default` + def get_response(self, key: str, default=None) -> CachedResponse: + """Retrieves response for `key` if it's stored in cache, otherwise returns `default` - :param key: key of resource - :param default: return this if `key` not found in cache - :returns: tuple (response, datetime) - - .. 
note:: Response is restored after unpickling with :meth:`restore_response` + Args: + key: Key of resource + default: Value to return if `key` is not in cache """ try: if key not in self.responses: key = self.keys_map[key] - response, timestamp = self.responses[key] - except KeyError: + response = self.responses[key] + response.reset() # In case response was in memory and raw content has already been read + return response + except (KeyError, TypeError, PickleError): return default - return self.restore_response(response), timestamp - def delete(self, key): + def delete(self, key: str): """Delete `key` from cache. Also deletes all responses from response history""" try: if key in self.responses: - response, _ = self.responses[key] + response = self.responses[key] del self.responses[key] else: - response, _ = self.responses[self.keys_map[key]] + response = self.responses[self.keys_map[key]] del self.keys_map[key] for r in response.history: del self.keys_map[self.create_key(r.request)] except KeyError: pass - def delete_url(self, url): + def delete_url(self, url: str): """Delete response associated with `url` from cache. Also deletes all responses from response history. 
Works only for GET requests """ self.delete(self._url_to_key(url)) def clear(self): - """Clear cache""" + """Delete all items from the cache""" self.responses.clear() self.keys_map.clear() - def remove_old_entries(self): - """Deletes expired entries from the cache""" - keys_to_delete = set() - for key, (response, _) in self.responses.items(): - if _is_expired(response): - keys_to_delete.add(key) + def remove_expired_responses(self, expire_after: ExpirationTime = None): + """Remove expired responses from the cache, optionally with revalidation - for key in keys_to_delete: - self.delete(key) - - def has_key(self, key): + Args: + expire_after: A new expiration time used to revalidate the cache + """ + for key, response in list(self.responses.items()): + # If we're revalidating and it's not yet expired, update the cached item's expiration + if expire_after is not None and not response.revalidate(expire_after): + self.responses[key] = response + if response.is_expired: + self.delete(key) + + def has_key(self, key: str) -> bool: """Returns `True` if cache has `key`, `False` otherwise""" return key in self.responses or key in self.keys_map - def has_url(self, url): - """Returns `True` if cache has `url`, `False` otherwise. - Works only for GET request urls - """ - return self.has_key(self._url_to_key(url)) + def has_url(self, url: str) -> bool: + """Returns `True` if cache has `url`, `False` otherwise. 
Works only for GET request urls""" + return self.has_key(self._url_to_key(url)) # noqa: W601 - def _url_to_key(self, url): + def _url_to_key(self, url: str) -> str: session = requests.Session() return self.create_key(session.prepare_request(requests.Request('GET', url))) - _response_attrs = [ - '_content', - 'url', - 'status_code', - 'cookies', - 'headers', - 'encoding', - 'request', - 'reason', - 'raw', - 'expiration_date', - 'expire_after', - ] - - _raw_response_attrs = [ - '_original_response', - 'decode_content', - 'headers', - 'reason', - 'status', - 'strict', - 'version', - ] - - def reduce_response(self, response, seen=None): - """Reduce response object to make it compatible with ``pickle``""" - if seen is None: - seen = {} - try: - return seen[id(response)] - except KeyError: - pass - result = _Store() - # prefetch - content = response.content - for field in self._response_attrs: - setattr(result, field, self._picklable_field(response, field)) - seen[id(response)] = result - result.history = tuple(self.reduce_response(r, seen) for r in response.history) - # Emulate stream fp is not consumed yet. 
See #68 - if response.raw is not None: - response.raw._fp = BytesIO(content) - return result - - def _picklable_field(self, response, name): - value = getattr(response, name, None) - if name == 'request': - value = copy(value) - value.hooks = [] - elif name == 'raw': - result = _RawStore() - for field in self._raw_response_attrs: - setattr(result, field, getattr(value, field, None)) - if result._original_response is not None: - setattr(result._original_response, "fp", None) # _io.BufferedReader is not picklable - value = result - return value - - def restore_response(self, response, seen=None): - """Restore response object after unpickling""" - if seen is None: - seen = {} - try: - return seen[id(response)] - except KeyError: - pass - result = requests.Response() - for field in self._response_attrs: - setattr(result, field, getattr(response, field, None)) - result.raw._cached_content_ = result.content - seen[id(response)] = result - result.history = tuple(self.restore_response(r, seen) for r in response.history) - return result - - def _remove_ignored_parameters(self, request): - def filter_ignored_parameters(data): - return [(k, v) for k, v in data if k not in self._ignored_parameters] - - url = urlparse(request.url) - query = parse_qsl(url.query) - query = filter_ignored_parameters(query) - query = urlencode(query) - url = urlunparse((url.scheme, url.netloc, url.path, url.params, query, url.fragment)) - body = request.body - content_type = request.headers.get('content-type') - if body and content_type: - if content_type == 'application/x-www-form-urlencoded': - body = parse_qsl(body) - body = filter_ignored_parameters(body) - body = urlencode(body) - elif content_type == 'application/json': - import json - - if isinstance(body, bytes): - body = str(body, "utf8") # TODO how to get body encoding? 
- body = json.loads(body) - body = filter_ignored_parameters(sorted(body.items())) - body = json.dumps(body) - return url, body - - def create_key(self, request): - if self._ignored_parameters: - url, body = self._remove_ignored_parameters(request) - else: - url, body = request.url, request.body + def create_key(self, request: requests.PreparedRequest) -> str: + url = self._remove_ignored_url_parameters(request) + body = self._remove_ignored_body_parameters(request) key = hashlib.sha256() - key.update(_to_bytes(request.method.upper())) - key.update(_to_bytes(url)) - if request.body: - key.update(_to_bytes(body)) + key.update(_encode(request.method.upper())) + key.update(_encode(url)) + + if body: + key.update(_encode(body)) else: if self._include_get_headers and request.headers != DEFAULT_HEADERS: for name, value in sorted(request.headers.items()): - key.update(_to_bytes(name)) - key.update(_to_bytes(value)) + key.update(_encode(name)) + key.update(_encode(value)) return key.hexdigest() - def __str__(self): - return 'keys: %s\nresponses: %s' % (self.keys_map, self.responses) - - -# used for saving response attributes -class _Store(object): - pass + def _remove_ignored_url_parameters(self, request: requests.PreparedRequest) -> str: + url = str(request.url) + if not self._ignored_parameters: + return url + url = urlparse(url) + query = parse_qsl(url.query) + query = self._filter_ignored_parameters(query) + query = urlencode(query) + url = urlunparse((url.scheme, url.netloc, url.path, url.params, query, url.fragment)) + return url -class _RawStore(object): - # noop for cached response - def release_conn(self): - pass + def _remove_ignored_body_parameters(self, request: requests.PreparedRequest) -> str: + body = request.body + content_type = request.headers.get('content-type') + if not self._ignored_parameters or not body or not content_type: + return request.body + + if content_type == 'application/x-www-form-urlencoded': + body = parse_qsl(body) + body = 
self._filter_ignored_parameters(body) + body = urlencode(body) + elif content_type == 'application/json': + body = json.loads(_decode(body)) + body = self._filter_ignored_parameters(sorted(body.items())) + body = json.dumps(body) + return body + + def _filter_ignored_parameters(self, data): + return [(k, v) for k, v in data if k not in self._ignored_parameters] - # for streaming requests support - def read(self, chunk_size=1): - if not hasattr(self, "_io_with_content_"): - self._io_with_content_ = BytesIO(self._cached_content_) - return self._io_with_content_.read(chunk_size) + def __str__(self): + return f'redirects: {len(self.keys_map)}\nresponses: {len(self.responses)}' -def _is_expired(response): - """Check a cached response to see if it's expired""" - if getattr(response, 'expire_after', None) is not None: - return datetime.now(timezone.utc) > response.expire_after - return False +def _encode(value, encoding='utf-8') -> bytes: + """Encode a value, if it hasn't already been""" + return value if isinstance(value, bytes) else value.encode(encoding) -def _to_bytes(s, encoding='utf-8'): - return s if isinstance(s, bytes) else bytes(s, encoding) +def _decode(value, encoding='utf-8') -> str: + """Decode a value, if hasn't already been. + Note: PreparedRequest.body is always encoded in utf-8. 
+ """ + return value if isinstance(value, str) else value.decode(encoding) diff --git a/requests_cache/core.py b/requests_cache/core.py index af60a4be..2a4c9856 100644 --- a/requests_cache/core.py +++ b/requests_cache/core.py @@ -6,58 +6,30 @@ """ from collections.abc import Mapping from contextlib import contextmanager -from datetime import datetime, timedelta, timezone from operator import itemgetter -from typing import Any, Callable, Dict, Iterable, Optional, Union +from typing import Any, Callable, Dict, Iterable, Optional, Type import requests +from requests import PreparedRequest from requests import Session as OriginalSession from requests.hooks import dispatch_hook -from requests_cache.backends.base import BACKEND_KWARGS, _is_expired - from . import backends +from .response import AnyResponse, ExpirationTime, set_response_defaults -ExpirationTime = Union[None, int, float, datetime, timedelta] +ALL_METHODS = ['GET', 'HEAD', 'OPTIONS', 'POST', 'PUT', 'PATCH', 'DELETE'] class CacheMixin: - """Mixin class that extends ``requests.Session`` with caching features. - - Args: - cache_name: Cache prefix or namespace, depending on backend; see notes below - backend: Cache backend name; one of ``['sqlite', 'mongodb', 'gridfs', 'redis', 'dynamodb', 'memory']``. - Default behavior is to use ``'sqlite'`` if available, otherwise fallback to ``'memory'``. - expire_after: Number of seconds after which a cache entry will expire; set to ``None`` to - never expire - allowable_codes: Only cache responses with one of these codes - allowable_methods: Cache only responses for one of these HTTP methods - include_get_headers: Make request headers part of the cache key - ignored_parameters: List of request parameters to be excluded from the cache key. - filter_fn: function that takes a :py:class:`aiohttp.ClientResponse` object and - returns a boolean indicating whether or not that response should be cached. 
Will be - applied to both new and previously cached responses - old_data_on_error: Return expired cached responses if new request fails - - See individual backend classes for additional backend-specific arguments. - - The ``cache_name`` parameter will be used as follows depending on the backend: - - * ``sqlite``: Cache filename prefix, e.g ``my_cache.sqlite`` - * ``mongodb``: Database name - * ``redis``: Namespace, meaning all keys will be prefixed with ``'cache_name:'`` - - Note on cache key parameters: Set ``include_get_headers=True`` if you want responses to be - cached under different keys if they only differ by headers. You may also provide - ``ignored_parameters`` to ignore specific request params. This is useful, for example, when - requesting the same resource with different credentials or access tokens. + """Mixin class that extends :py:class:`requests.Session` with caching features. + See :py:class:`.CachedSession` for usage information. """ def __init__( self, cache_name: str = 'cache', backend: str = None, - expire_after: Union[int, float, timedelta] = None, + expire_after: ExpirationTime = None, allowable_codes: Iterable[int] = (200,), allowable_methods: Iterable['str'] = ('GET', 'HEAD'), filter_fn: Callable = None, @@ -65,115 +37,144 @@ def __init__( **kwargs, ): self.cache = backends.create_backend(backend, cache_name, kwargs) - self._cache_name = cache_name - self.expire_after = _get_timedelta(expire_after) + self.allowable_codes = allowable_codes + self.allowable_methods = allowable_methods + self.filter_fn = filter_fn or (lambda r: True) + self.old_data_on_error = old_data_on_error - self._cache_allowable_codes = allowable_codes - self._cache_allowable_methods = allowable_methods - self._filter_fn = filter_fn or (lambda r: True) - self._return_old_data_on_error = old_data_on_error - self._is_cache_disabled = False + self._cache_name = cache_name + self._expire_after = expire_after + self._request_expire_after: ExpirationTime = None + 
self._disabled = False # Remove any requests-cache-specific kwargs before passing along to superclass - session_kwargs = {k: v for k, v in kwargs.items() if k not in BACKEND_KWARGS} + session_kwargs = {k: v for k, v in kwargs.items() if k not in backends.BACKEND_KWARGS} super().__init__(**session_kwargs) - def send(self, request, **kwargs): - _request_expire_after = kwargs.get('params', {}).pop('_request_expire_after', None) - expire_after = _request_expire_after or self.expire_after - - # If we shouldn't cache the response, just send the request - do_not_cache = self._is_cache_disabled or request.method not in self._cache_allowable_methods - if do_not_cache: - response = super().send(request, **kwargs) - response.from_cache = False - response.cache_date = None - response.expire_after = None - return response - - # If a response isn't already cached, send the request and cache the response - cache_key = self.cache.create_key(request) - try: - response, timestamp = self.cache.get_response_and_time(cache_key) - except (ImportError, TypeError): - response, timestamp = None, None - if response is None: - return self.send_request_and_cache_response(request, cache_key, expire_after, **kwargs) - - # If the cached response is invalid, send the request and cache the response - if _is_expired(response): - try: - new_response = self.send_request_and_cache_response(request, cache_key, expire_after, **kwargs) - self.cache.delete(cache_key) - return new_response - except Exception: - # Return the expired/invalid response on error, if specified - if self._return_old_data_on_error: - return response - self.cache.delete(cache_key) - raise - - # Dispatch hook here, because we've removed it before pickling - response.from_cache = True - response.cache_date = timestamp - response = dispatch_hook('response', request.hooks, response, **kwargs) - return response + @property + def expire_after(self): + """Get either the per-session expiration, or per-request expiration, if set""" + return 
self._request_expire_after or self._expire_after - def send_request_and_cache_response(self, request, cache_key, expire_after, **kwargs): - response = super().send(request, **kwargs) - response.from_cache = False - response.cache_date = None + @expire_after.setter + def expire_after(self, value: ExpirationTime): + """Set per-session expiration""" + self._expire_after = value - # Cache the response, if possible - if response.status_code in self._cache_allowable_codes: - response.expire_after = _get_absolute_time(expire_after) # type: ignore - self.cache.save_response(cache_key, response) - return response + @contextmanager + def request_expire_after(self, expire_after: ExpirationTime = None): + """Temporarily override ``expire_after`` for an individual request""" + self._request_expire_after = expire_after + yield + self._request_expire_after = None def request( self, method: str, url: str, - params: dict = None, + params: Dict = None, data: Any = None, expire_after: ExpirationTime = None, **kwargs, - ) -> requests.Response: + ) -> AnyResponse: """This method prepares and sends a request while automatically performing any necessary caching operations. This will be called by any other method-specific ``requests`` functions - (get, post, etc.). + (get, post, etc.). This does not include prepared requests, which will still be cached via + ``send()``. - In all cases, if the value is an explicit datetime it returned as is. - If it is None, it is also returned as is and caches forever. - All other values will be considered a relative time in the future. + See :py:meth:`requests.Session.request` for parameters. Additional parameters: Args: expire_after: Expiration time to set only for this request; see details below. - Overrides ``CachedSession.expire_after``. Accepts all the same types as - ``CachedSession.expire_after`` except for ``None``; use - ``CachedSession.cache_disabled`` to disable caching on a per-request basis. + Overrides ``CachedSession.expire_after``. 
Accepts all the same values as + ``CachedSession.expire_after`` except for ``None``; use ``-1`` to disable expiration + on a per-request basis. Returns: Either a new or cached response - """ - # Store expire_after to be used by send() - params = _normalize_parameters(params) - params['_request_expire_after'] = expire_after - response = super().request(method, url, params, _normalize_parameters(data), **kwargs) - if self._is_cache_disabled: + **Order of operations:** A request will pass through the following methods: + + 1. :py:func:`requests.get`/:py:meth:`requests.Session.get` or other method-specific functions (optional) + 2. :py:meth:`.CachedSession.request` + 3. :py:meth:`requests.Session.request` + 4. :py:meth:`.CachedSession.send` + 5. :py:meth:`.BaseCache.get_response` + 6. :py:meth:`requests.Session.send` (if not cached) + """ + with self.request_expire_after(expire_after): + response = super().request( + method, + url, + _normalize_parameters(params), + _normalize_parameters(data), + **kwargs, + ) + if self._disabled: return response - main_key = self.cache.create_key(response.request) - # If self._return_old_data_on_error is set, responses may not have the from_cache attribute - if hasattr(response, "from_cache") and not response.from_cache and not self._filter_fn(response): + # If the request has been filtered out, delete previously cached response if it exists + main_key = self.cache.create_key(response.request) + if not response.from_cache and not self.filter_fn(response): self.cache.delete(main_key) return response + # Cache redirect history for r in response.history: self.cache.add_key_mapping(self.cache.create_key(r.request), main_key) return response + def send(self, request: PreparedRequest, **kwargs) -> AnyResponse: + """Send a prepared request, with caching.""" + # If we shouldn't cache the response, just send the request + if not self._is_cacheable(request): + response = super().send(request, **kwargs) + return set_response_defaults(response) + 
+ # Attempt to fetch the cached response + cache_key = self.cache.create_key(request) + try: + response = self.cache.get_response(cache_key) + except (ImportError, TypeError, ValueError): + response = None + + # Attempt to fetch and cache a new response, if needed + if response is None: + return self._send_and_cache(request, cache_key, **kwargs) + if response.is_expired: + return self._handle_expired_response(request, response, cache_key, **kwargs) + + # Dispatch hook here, because we've removed it before pickling + return dispatch_hook('response', request.hooks, response, **kwargs) + + def _is_cacheable(self, request: PreparedRequest) -> bool: + criteria = [ + not self._disabled, + str(request.method) in self.allowable_methods, + self.filter_fn(request), + ] + return all(criteria) + + def _handle_expired_response(self, request, response, cache_key, **kwargs) -> AnyResponse: + """Determine what to do with an expired response, depending on old_data_on_error setting""" + # Attempt to send the request and cache the new response + try: + new_response = self._send_and_cache(request, cache_key, **kwargs) + self.cache.delete(cache_key) + return new_response + # Return the expired/invalid response on error, if specified; otherwise reraise + except Exception: + if self.old_data_on_error: + return response + self.cache.delete(cache_key) + raise + + def _send_and_cache(self, request, cache_key, **kwargs): + response = super().send(request, **kwargs) + if response.status_code in self.allowable_codes: + self.cache.save_response(cache_key, response, self.expire_after) + return set_response_defaults(response) + @contextmanager def cache_disabled(self): """ @@ -184,52 +185,96 @@ def cache_disabled(self): >>> with s.cache_disabled(): ... 
s.get('http://httpbin.org/ip') """ - if self._is_cache_disabled: + if self._disabled: yield else: - self._is_cache_disabled = True + self._disabled = True try: yield finally: - self._is_cache_disabled = False + self._disabled = False + + def remove_expired_responses(self, expire_after: ExpirationTime = None): + """Remove expired responses from the cache, optionally with revalidation - def remove_expired_responses(self): - """Removes expired responses from storage""" - self.cache.remove_old_entries() + Args: + expire_after: A new expiration time used to revalidate the cache + """ + self.cache.remove_expired_responses(expire_after) def __repr__(self): return ( f"" + f"expire_after={self.expire_after}, allowable_methods={self.allowable_methods})>" ) class CachedSession(CacheMixin, OriginalSession): - pass + """Class that extends ``requests.Session`` with caching features. + See individual backend classes for additional backend-specific arguments. + + Args: + cache_name: Cache prefix or namespace, depending on backend + backend: Cache backend name; one of ``['sqlite', 'mongodb', 'gridfs', 'redis', 'dynamodb', 'memory']``. + Default behavior is to use ``'sqlite'`` if available, otherwise fallback to ``'memory'``. + expire_after: Time after which cached items will expire (see notes below) + allowable_codes: Only cache responses with one of these codes + allowable_methods: Cache only responses for one of these HTTP methods + include_get_headers: Make request headers part of the cache key + ignored_parameters: List of request parameters to be excluded from the cache key + filter_fn: function that takes a :py:class:`aiohttp.ClientResponse` object and + returns a boolean indicating whether or not that response should be cached. Will be + applied to both new and previously cached responses. 
+ old_data_on_error: Return expired cached responses if new request fails + + **Cache Name:** + + The ``cache_name`` parameter will be used as follows depending on the backend: + + * ``sqlite``: Cache filename, e.g ``my_cache.sqlite`` + * ``mongodb``: Database name + * ``redis``: Namespace, meaning all keys will be prefixed with ``'cache_name:'`` + + **Cache Keys:** + + The cache key is a hash created from request information, and is used as an index for cached + responses. There are a couple ways you can customize how the cache key is created: + + * Use ``include_get_headers`` if you want headers to be included in the cache key. In other + words, this will create separate cache items for responses with different headers. + * Use ``ignored_parameters`` to exclude specific request params from the cache key. This is + useful, for example, if you request the same resource with different credentials or access + tokens. + + **Cache Expiration:** + + Use ``expire_after`` to specify how long responses will be cached. This can be a number + (in seconds), a :py:class:`.timedelta`, or a :py:class:`datetime`. Use ``None`` or ``-1`` to + never expire. This will not apply to responses cached in the current session; to apply a + different expiration to previously cached responses, see :py:meth:`remove_expired_responses`. + """ def install_cache( cache_name: str = 'cache', backend: str = None, - expire_after: Union[int, float, timedelta] = None, + expire_after: ExpirationTime = None, allowable_codes: Iterable[int] = (200,), allowable_methods: Iterable['str'] = ('GET', 'HEAD'), filter_fn: Callable = None, old_data_on_error: bool = False, - session_factory=CachedSession, + session_factory: Type[OriginalSession] = CachedSession, **kwargs, ): """ - Installs cache for all ``Requests`` requests by monkey-patching ``Session`` + Installs cache for all ``requests`` functions by monkey-patching ``Session`` - Parameters are the same as in :class:`CachedSession`. 
Additional parameters: + Parameters are the same as in :py:class:`CachedSession`. Additional parameters: Args: session_factory: Session class to use. It must inherit from either :py:class:`CachedSession` or :py:class:`CacheMixin` """ - if backend: - backend = backends.create_backend(backend, cache_name, kwargs) class _ConfiguredCachedSession(session_factory): def __init__(self): @@ -312,30 +357,18 @@ def clear(): get_cache().clear() -def remove_expired_responses(): - """Removes expired responses from storage""" - if is_installed(): - return requests.Session().remove_expired_responses() - - -def _get_absolute_time(expire_after: Union[int, float, datetime, timedelta]) -> Optional[datetime]: - """Convert a time value to an absolute datetime, if it's not already""" - if isinstance(expire_after, datetime): - return expire_after - if expire_after is None: - return None - return datetime.now(timezone.utc) + _get_timedelta(expire_after) # type: ignore - +def remove_expired_responses(expire_after: ExpirationTime = None): + """Remove expired responses from the cache, optionally with revalidation -def _get_timedelta(expire_after: Union[int, float, timedelta] = None) -> Optional[timedelta]: - """Convert a time value to a timedelta, if it's not already""" - if expire_after is not None and not isinstance(expire_after, timedelta): - expire_after = timedelta(seconds=expire_after) - return expire_after + Args: + expire_after: A new expiration time used to revalidate the cache + """ + if is_installed(): + return requests.Session().remove_expired_responses(expire_after) -def _patch_session_factory(session_factory=CachedSession): - requests.Session = requests.sessions.Session = session_factory +def _patch_session_factory(session_factory: Type[OriginalSession] = CachedSession): + requests.Session = requests.sessions.Session = session_factory # noqa def _normalize_parameters(params: Optional[Dict]) -> Dict: diff --git a/requests_cache/response.py b/requests_cache/response.py new file 
mode 100755 index 00000000..3eb3afa9 --- /dev/null +++ b/requests_cache/response.py @@ -0,0 +1,156 @@ +from copy import copy +from datetime import datetime, timedelta +from io import BytesIO +from typing import Any, Dict, Optional, Union + +from requests import Response +from urllib3.response import HTTPResponse + +ExpirationTime = Union[None, int, float, datetime, timedelta] + +# Reponse attributes to copy +RESPONSE_ATTRS = Response.__attrs__ +RAW_RESPONSE_ATTRS = [ + 'decode_content', + 'headers', + 'reason', + 'request_method', + 'request_url', + 'status', + 'strict', + 'version', +] + + +class CachedResponse(Response): + """A serializable wrapper for :py:class:`requests.Response`. CachedResponse objects will behave + the same as the original response, but with some additional cache-related details. This class is + responsible for converting and setting cache expiration times, and converting response info into + a serializable format. + + Args: + original_response: Response object + expire_after: + """ + + def __init__(self, original_response: Response, expire_after: ExpirationTime = None): + """Create a CachedResponse based on an original Response""" + super().__init__() + # Set cache-specific attrs + self.created_at = datetime.utcnow() + self.expires = self._get_expiration_datetime(expire_after) + self.from_cache = True + + # Copy basic response attrs and original request + for k in RESPONSE_ATTRS: + setattr(self, k, getattr(original_response, k, None)) + self.request = copy(original_response.request) + self.request.hooks = [] + + # Read content to support streaming requests, and reset file pointer on original request + self._content = original_response.content + original_response.raw._fp = BytesIO(self._content or b'') + + # Copy raw response + self._raw_response = None + self._raw_response_attrs: Dict[str, Any] = {} + for k in RAW_RESPONSE_ATTRS: + self._raw_response_attrs[k] = getattr(original_response.raw, k, None) + + # Copy redirect history, if any + 
self.history = [] + for redirect in original_response.history: + self.history.append(CachedResponse(redirect)) + + def __getstate__(self): + """Override pickling behavior in ``requests.Response.__getstate__``""" + return self.__dict__ + + def _get_expiration_datetime(self, expire_after: ExpirationTime) -> Optional[datetime]: + """Convert a time value or delta to an absolute datetime, if it's not already""" + if expire_after is None or expire_after == -1: + return None + elif isinstance(expire_after, datetime): + return expire_after + + if not isinstance(expire_after, timedelta): + expire_after = timedelta(seconds=expire_after) + return self.created_at + expire_after + + def reset(self): + """Reset raw response file handler, if previously initialized""" + self._raw_response = None + + @property + def is_expired(self) -> bool: + """Determine if this cached response is expired""" + return self.expires is not None and datetime.utcnow() > self.expires + + @property + def raw(self) -> HTTPResponse: + """Reconstruct a raw urllib response object from stored attrs""" + if not self._raw_response: + self._raw_response = CachedHTTPResponse(body=self._content, **self._raw_response_attrs) + return self._raw_response + + @raw.setter + def raw(self, value): + """No-op to handle requests.Response attempting to set self.raw""" + + def revalidate(self, expire_after: ExpirationTime) -> bool: + """Set a new expiration for this response, and determine if it is now expired""" + self.expires = self._get_expiration_datetime(expire_after) + return self.is_expired + + +class CachedHTTPResponse(HTTPResponse): + """A wrapper for raw urllib response objects, which wraps cached content with support for + streaming requests + """ + + def __init__(self, body: bytes = None, **kwargs): + kwargs.setdefault('preload_content', False) + super().__init__(body=BytesIO(body or b''), **kwargs) + self._body = body + + def release_conn(self): + """No-op for compatibility""" + + def read(self, amt=None, 
decode_content=False, **kwargs): + """Simplified reader for cached content that emulates + :py:meth:`urllib3.response.HTTPResponse.read()` + """ + data = self._fp.read(amt) + decode_content = self.decode_content if decode_content is None else decode_content + + # "close" the file to inform consumers to stop reading from it + if not data: + self._fp.close() + # Decode binary content, if specified + elif decode_content: + self._init_decoder() + data = self._decode(data, decode_content=True, flush_decoder=True) + + return data + + def stream(self, amt=None, **kwargs): + """Simplified generator over cached content that emulates + :py:meth:`urllib3.response.HTTPResponse.stream()` + """ + while not self._fp.closed: + yield self.read(amt=amt, **kwargs) + + +AnyResponse = Union[Response, CachedResponse] + + +def set_response_defaults(response: AnyResponse) -> AnyResponse: + """Set some default CachedResponse values on a requests.Response object, so they can be + expected to always be present + """ + if not isinstance(response, CachedResponse): + response.created_at = None + response.expires = None + response.from_cache = False + response.is_expired = False + return response From 7215cbcb9fda53852e85b811fc4779016ac15ec4 Mon Sep 17 00:00:00 2001 From: Jordan Cook Date: Thu, 18 Mar 2021 17:22:56 -0500 Subject: [PATCH 3/3] Add and rewrite unit tests for CachedSession and CachedResponse using requests-mock and fixtures --- example_request.py | 8 +- pyproject.toml | 1 + setup.cfg | 8 + setup.py | 1 + tests/conftest.py | 73 +++ tests/test_cache.py | 856 +++++++++++++----------------- tests/test_expiration_datetime.py | 43 -- tests/test_monkey_patch.py | 39 +- tests/test_per_request_cache.py | 192 ------- tests/test_response.py | 155 ++++++ 10 files changed, 645 insertions(+), 731 deletions(-) create mode 100644 tests/conftest.py delete mode 100644 tests/test_expiration_datetime.py delete mode 100644 tests/test_per_request_cache.py create mode 100644 tests/test_response.py diff 
--git a/example_request.py b/example_request.py index 0b327388..ead525d6 100644 --- a/example_request.py +++ b/example_request.py @@ -5,7 +5,7 @@ import requests_cache -requests_cache.install_cache('example_cache', backend='memory') +requests_cache.install_cache('example_cache', backend='sqlite') def main(): @@ -15,11 +15,9 @@ def main(): response = requests.get('https://httpbin.org/get') assert response.from_cache - # Changing the expires_after time causes a cache invalidation, - # thus /get is queried again ... + # Caching with expiration + requests_cache.clear() response = requests.get('https://httpbin.org/get', expire_after=1) - assert not response.from_cache - # ... but cached for 1 second response = requests.get('https://httpbin.org/get') assert response.from_cache # After > 1 second, it's cached value is expired diff --git a/pyproject.toml b/pyproject.toml index c2aa941c..b2bafe1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ source = ['requests_cache'] profile = "black" line_length = 100 skip_gitignore = true +skip = ['requests_cache/__init__.py'] known_first_party = ['test'] # Things that are common enough they may as well be grouped with stdlib imports extra_standard_library = ['pytest', 'setuptools'] diff --git a/setup.cfg b/setup.cfg index ccaf78b4..37718f7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,11 @@ ignore = E501 # line too long W503 # line break before binary operator W504 # line break after binary operator + +# Show 7 lines of context in debugger +[ipdb] +context = 7 + +# Tell mypy to ignore external libraries without type annotations +[mypy] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index 3debd88f..2f3d1a8e 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ 'pre-commit', 'pytest>=5.0', 'pytest-cov>=2.11', + 'requests-mock>=1.8', ], } # All development/testing packages combined diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..d2e862fb --- /dev/null 
+++ b/tests/conftest.py @@ -0,0 +1,73 @@ +"""Fixtures that will be automatically picked up by pytest + +Note: The protocol ``http(s)+mock://`` helps :py:class:`requests_mock.Adapter` play nicely with +:py:class:`requests.PreparedRequest`. More info here: +https://requests-mock.readthedocs.io/en/latest/adapter.html +""" +import pytest +from tempfile import NamedTemporaryFile + +from requests_mock import ANY as ANY_METHOD +from requests_mock import Adapter + +from requests_cache import ALL_METHODS, CachedSession + +MOCKED_URL = 'http+mock://requests-cache.com/text' +MOCKED_URL_GZIP = 'https+mock://requests-cache.com/gzip' # TODO +MOCKED_URL_HTTPS = 'https+mock://requests-cache.com/text' +MOCKED_URL_JSON = 'http+mock://requests-cache.com/json' +MOCKED_URL_REDIRECT = 'http+mock://requests-cache.com/redirect' # TODO +MOCK_PROTOCOLS = ['mock://', 'http+mock://', 'https+mock://'] + + +@pytest.fixture(scope='function') +def mock_session() -> CachedSession: + """Fixture for combining requests-cache with requests-mock. This will behave the same as a + CachedSession, except it will make mock requests for ``mock://`` URLs, if it hasn't been cached + already. + + For example, ``mock_session.get(MOCKED_URL)`` will return a mock response on the first call, + and a cached mock response on the second call. Additional mock responses can be added via + ``mock_session.mock_adapter.register_uri()``. + + This uses a temporary SQLite db stored in ``/tmp``, which will be removed after the fixture has + exited. 
+ """ + with NamedTemporaryFile(suffix='.db') as temp: + session = CachedSession( + cache_name=temp.name, + backend='sqlite', + allowable_methods=ALL_METHODS, + ) + adapter = get_mock_adapter() + for protocol in MOCK_PROTOCOLS: + session.mount(protocol, adapter) + session.mock_adapter = adapter + yield session + + +def get_mock_adapter() -> Adapter: + """Get a requests-mock Adapter with some URLs mocked by default""" + adapter = Adapter() + adapter.register_uri( + ANY_METHOD, + MOCKED_URL, + headers={'Content-Type': 'text/plain'}, + text='mock response', + status_code=200, + ) + adapter.register_uri( + ANY_METHOD, + MOCKED_URL_HTTPS, + headers={'Content-Type': 'text/plain'}, + text='mock https response', + status_code=200, + ) + adapter.register_uri( + ANY_METHOD, + MOCKED_URL_JSON, + headers={'Content-Type': 'application/json'}, + json={'message': 'mock json response'}, + status_code=200, + ) + return adapter diff --git a/tests/test_cache.py b/tests/test_cache.py index efff7f78..895dcb17 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +# flake8: noqa: F841 import json import os import pytest @@ -6,499 +6,393 @@ import time import unittest from collections import defaultdict -from datetime import datetime, timedelta, timezone -from unittest import mock +from datetime import datetime, timedelta +from pickle import PickleError +from unittest.mock import PropertyMock, patch import requests from requests import Request import requests_cache -from requests_cache import CachedSession - -CACHE_BACKEND = 'sqlite' -CACHE_NAME = 'requests_cache_test' -FAST_SAVE = False - -HTTPBIN_URL = os.getenv('HTTPBIN_URL', 'http://httpbin.org/') +from requests_cache import ALL_METHODS, CachedSession +from requests_cache.backends.storage.dbdict import DbPickleDict +from tests.conftest import MOCKED_URL, MOCKED_URL_HTTPS, MOCKED_URL_JSON, MOCKED_URL_REDIRECT def httpbin(*suffix): """Returns url for HTTPBIN resource.""" - return HTTPBIN_URL 
+ '/'.join(suffix) - - -class CacheTestCase(unittest.TestCase): - def setUp(self): - self.s = CachedSession(CACHE_NAME, backend=CACHE_BACKEND, fast_save=FAST_SAVE) - self.s.cache.clear() - requests_cache.uninstall_cache() - - @classmethod - def tearDownClass(cls): - super(CacheTestCase, cls).tearDownClass() - filename = "{0}.{1}".format(CACHE_NAME, CACHE_BACKEND) - if os.path.exists(filename): - try: - os.unlink(filename) - except OSError: - pass - - def tearDown(self): - self.s.close() - - def test_expire_cache(self): - delay = 1 - url = httpbin('delay/%s' % delay) - s = CachedSession(CACHE_NAME, backend=CACHE_BACKEND, expire_after=0.06) - t = time.time() - r = s.get(url) - delta = time.time() - t - self.assertGreaterEqual(delta, delay) - time.sleep(0.5) - t = time.time() - r = s.get(url) - delta = time.time() - t - self.assertGreaterEqual(delta, delay) - s.close() - - def test_delete_urls(self): - url = httpbin('get') - r = self.s.get(url) - assert self.s.cache.has_url(url) - self.s.cache.delete_url(url) - assert not self.s.cache.has_url(url) - - def test_unregistered_backend(self): - with self.assertRaises(ValueError): - CachedSession(CACHE_NAME, backend='nonexistent') - - @mock.patch('requests_cache.backends.registry') - def test_missing_backend_dependency(self, mocked_registry): - # Testing that the correct error is thrown when a user does not have - # the Python package `redis` installed. We mock out the registry - # to simulate `redis` not being installed. 
- mocked_registry.__getitem__.side_effect = KeyError - with self.assertRaises(ImportError): - CachedSession(CACHE_NAME, backend='redis') - - def test_hooks(self): - state = defaultdict(int) - for hook in ('response',): # TODO it's only one hook here - - def hook_func(r, *args, **kwargs): - state[hook] += 1 - return r - - n = 5 - for i in range(n): - r = self.s.get(httpbin('get'), hooks={hook: hook_func}) - self.assertEqual(state[hook], n) - - def test_attr_from_cache_in_hook(self): - state = defaultdict(int) - hook = 'response' + return 'http://httpbin.org/' + '/'.join(suffix) + + +def test_unregistered_backend(): + with pytest.raises(ValueError): + CachedSession(backend='nonexistent') + + +@patch('requests_cache.backends.registry') +def test_missing_backend_dependency(mocked_registry): + """Test that the correct error is thrown when a user does not have a dependency installed""" + mocked_registry.__getitem__.side_effect = KeyError + with pytest.raises(ImportError): + CachedSession(backend='redis') + + +@pytest.mark.parametrize('method', ALL_METHODS) +@pytest.mark.parametrize('field', ['params', 'data', 'json']) +def test_all_methods(field, method, mock_session): + """Test all relevant combinations of methods and data fields. Requests with different request + params, data, or json should be cached under different keys. + """ + for params in [{'param_1': 1}, {'param_1': 2}, {'param_2': 2}]: + assert mock_session.request(method, MOCKED_URL, **{field: params}).from_cache is False + assert mock_session.request(method, MOCKED_URL, **{field: params}).from_cache is True + + +@pytest.mark.parametrize('method', ALL_METHODS) +@pytest.mark.parametrize('field', ['params', 'data', 'json']) +def test_all_methods__ignore_parameters(field, method, mock_session): + """Test all relevant combinations of methods and data fields. Requests with different request + params, data, or json should not be cached under different keys based on an ignored param. 
+ """ + mock_session.cache._ignored_parameters = ['ignored'] + params_1 = {'ignored': 1, 'not ignored': 1} + params_2 = {'ignored': 2, 'not ignored': 1} + params_3 = {'ignored': 2, 'not ignored': 2} + + assert mock_session.request(method, MOCKED_URL, **{field: params_1}).from_cache is False + assert mock_session.request(method, MOCKED_URL, **{field: params_1}).from_cache is True + assert mock_session.request(method, MOCKED_URL, **{field: params_2}).from_cache is True + mock_session.request(method, MOCKED_URL, params={'a': 'b'}) + assert mock_session.request(method, MOCKED_URL, **{field: params_3}).from_cache is False + + +# TODO: mock response with cookies +def test_cookies(mock_session): + def get_json(url): + return json.loads(mock_session.get(url).text) + + response_1 = get_json(httpbin('cookies/set/test1/test2')) + with mock_session.cache_disabled(): + assert get_json(httpbin('cookies')) == response_1 + # From cache + response_2 = get_json(httpbin('cookies')) + assert response_2 == get_json(httpbin('cookies')) + # Not from cache + with mock_session.cache_disabled(): + response_3 = get_json(httpbin('cookies/set/test3/test4')) + assert response_3 == get_json(httpbin('cookies')) + + +# TODO: mock response with gzip-compressed content +def test_gzip(mock_session): + assert mock_session.get(httpbin('gzip')).from_cache is False + assert mock_session.get(httpbin('gzip')).from_cache is True + + +def test_https(mock_session): + assert mock_session.get(MOCKED_URL_HTTPS, verify=True).from_cache is False + assert mock_session.get(MOCKED_URL_HTTPS, verify=True).from_cache is True + + +def test_json(mock_session): + assert mock_session.get(MOCKED_URL_JSON).from_cache is False + response = mock_session.get(MOCKED_URL_JSON) + assert response.from_cache is True + assert response.json()['message'] == 'mock json response' + + +# TODO: Create mock response with redirect history +@pytest.mark.skip(reason='httpbin.org/relative-redirect no longer returns redirects') +def 
test_response_history(mock_session): + r1 = mock_session.get(httpbin('relative-redirect/3')) + + def test_redirect_history(url): + r2 = mock_session.get(url) + assert r2.from_cache is True + for r11, r22 in zip(r1.history, r2.history): + assert r11.url == r22.url + + test_redirect_history(httpbin('relative-redirect/3')) + test_redirect_history(httpbin('relative-redirect/2')) + r3 = requests.get(httpbin('relative-redirect/1')) + assert len(r3.history) == 1 + + +def test_repr(): + """Test session and cache string representations""" + cache_name = 'requests_cache_test' + session = CachedSession(cache_name=cache_name, backend='memory', expire_after=10) + session.cache.responses['key'] = 'value' + session.cache.keys_map['key'] = 'value' + session.cache.keys_map['key_2'] = 'value' + + assert cache_name in repr(session) and '10' in repr(session) + assert 'redirects: 2' in str(session.cache) and 'responses: 1' in str(session.cache) + + +# TODO: More event types; make a mock response that emulates hook behavior +def test_hooks(mock_session): + state = defaultdict(int) + mock_session.get(httpbin('get')) + + for hook in ('response',): def hook_func(r, *args, **kwargs): - if state[hook] > 0: - self.assertTrue(r.from_cache) state[hook] += 1 + assert r.from_cache is True return r - n = 5 - for i in range(n): - r = self.s.get(httpbin('get'), hooks={hook: hook_func}) - self.assertEqual(state[hook], n) - - def test_post(self): - url = httpbin('post') - r1 = json.loads(self.s.post(url, data={'test1': 'test1'}).text) - r2 = json.loads(self.s.post(url, data={'test2': 'test2'}).text) - self.assertIn('test2', r2['form']) - req = Request('POST', url).prepare() - self.assertFalse(self.s.cache.has_key(self.s.cache.create_key(req))) - - def test_disabled(self): - - url = httpbin('get') - requests_cache.install_cache(CACHE_NAME, backend=CACHE_BACKEND, fast_save=FAST_SAVE) - requests.get(url) - with requests_cache.disabled(): - for i in range(2): - r = requests.get(url) - 
self.assertFalse(getattr(r, 'from_cache', False)) - with self.s.cache_disabled(): - for i in range(2): - r = self.s.get(url) - self.assertFalse(getattr(r, 'from_cache', False)) - r = self.s.get(url) - self.assertTrue(getattr(r, 'from_cache', False)) - - def test_enabled(self): - url = httpbin('get') - options = dict(cache_name=CACHE_NAME, backend=CACHE_BACKEND, fast_save=FAST_SAVE) - with requests_cache.enabled(**options): - r = requests.get(url) - self.assertFalse(getattr(r, 'from_cache', False)) - for i in range(2): - r = requests.get(url) - self.assertTrue(getattr(r, 'from_cache', False)) - r = requests.get(url) - self.assertFalse(getattr(r, 'from_cache', False)) - - def test_content_and_cookies(self): - requests_cache.install_cache(CACHE_NAME, CACHE_BACKEND) - s = requests.session() - - def js(url): - return json.loads(s.get(url).text) - - r1 = js(httpbin('cookies/set/test1/test2')) - with requests_cache.disabled(): - r2 = js(httpbin('cookies')) - self.assertEqual(r1, r2) - r3 = js(httpbin('cookies')) - with requests_cache.disabled(): - r4 = js(httpbin('cookies/set/test3/test4')) - # from cache - self.assertEqual(r3, js(httpbin('cookies'))) - # updated - with requests_cache.disabled(): - self.assertEqual(r4, js(httpbin('cookies'))) - s.close() - - # TODO: Create mock responses instead of depending on httpbin - @pytest.mark.skip(reason='httpbin.org/relative-redirect no longer returns redirects') - def test_response_history(self): - r1 = self.s.get(httpbin('relative-redirect/3')) - - def test_redirect_history(url): - r2 = self.s.get(url) - self.assertTrue(r2.from_cache) - for r11, r22 in zip(r1.history, r2.history): - self.assertEqual(r11.url, r22.url) - - test_redirect_history(httpbin('relative-redirect/3')) - test_redirect_history(httpbin('relative-redirect/2')) - r3 = requests.get(httpbin('relative-redirect/1')) - self.assertEqual(len(r3.history), 1) - - # TODO: Create mock responses instead of depending on httpbin - 
@pytest.mark.skip(reason='httpbin.org/relative-redirect no longer returns redirects') - def test_response_history_simple(self): - r1 = self.s.get(httpbin('relative-redirect/2')) - r2 = self.s.get(httpbin('relative-redirect/1')) - self.assertTrue(r2.from_cache) - - def post(self, data): - return json.loads(self.s.post(httpbin('post'), data=data).text) - - def test_post_params(self): - # issue #2 - self.s = CachedSession(CACHE_NAME, CACHE_BACKEND, allowable_methods=('GET', 'POST')) - - d = {'param1': 'test1'} - for _ in range(2): - self.assertEqual(self.post(d)['form'], d) - d = {'param1': 'test1', 'param3': 'test3'} - self.assertEqual(self.post(d)['form'], d) - - self.assertTrue(self.s.post(httpbin('post'), data=d).from_cache) - d.update({'something': 'else'}) - self.assertFalse(self.s.post(httpbin('post'), data=d).from_cache) - - def test_post_data(self): - # issue #2, raw payload - self.s = CachedSession(CACHE_NAME, CACHE_BACKEND, allowable_methods=('GET', 'POST')) - d1 = json.dumps({'param1': 'test1'}) - d2 = json.dumps({'param1': 'test1', 'param2': 'test2'}) - d3 = str('some unicode data') - bin_data = bytes('some binary data', 'utf8') - - for d in (d1, d2, d3): - self.assertEqual(self.post(d)['data'], d) - r = self.s.post(httpbin('post'), data=d) - self.assertTrue(hasattr(r, 'from_cache')) - - self.assertEqual(self.post(bin_data)['data'], bin_data.decode('utf8')) - r = self.s.post(httpbin('post'), data=bin_data) - self.assertTrue(hasattr(r, 'from_cache')) - - def test_get_params_as_argument(self): - for _ in range(5): - p = {'arg1': 'value1'} - r = self.s.get(httpbin('get'), params=p) - self.assertTrue(self.s.cache.has_url(httpbin('get?arg1=value1'))) - - @unittest.skipIf(sys.version_info < (2, 7), "No https in 2.6") - def test_https_support(self): - n = 10 - delay = 1 - url = 'https://httpbin.org/delay/%s?ar1=value1' % delay - t = time.time() - for _ in range(n): - r = self.s.get(url, verify=False) - self.assertLessEqual(time.time() - t, delay * n / 2) - - def 
test_from_cache_attribute(self): - url = httpbin('get?q=1') - self.assertFalse(self.s.get(url).from_cache) - self.assertTrue(self.s.get(url).from_cache) - self.s.cache.clear() - self.assertFalse(self.s.get(url).from_cache) - - def test_gzip_response(self): - url = httpbin('gzip') - self.assertFalse(self.s.get(url).from_cache) - self.assertTrue(self.s.get(url).from_cache) - - def test_close_response(self): - for _ in range(3): - r = self.s.get(httpbin("get")) - r.close() - - def test_get_parameters_normalization(self): - url = httpbin("get") - params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} - - self.assertFalse(self.s.get(url, params=params).from_cache) - r = self.s.get(url, params=params) - self.assertTrue(r.from_cache) - self.assertEqual(r.json()["args"], params) - self.assertFalse(self.s.get(url, params={"a": "b"}).from_cache) - self.assertTrue(self.s.get(url, params=sorted(params.items())).from_cache) - - class UserSubclass(dict): - def items(self): - return sorted(super(UserSubclass, self).items(), reverse=True) - - params["z"] = "5" - custom_dict = UserSubclass(params) - self.assertFalse(self.s.get(url, params=custom_dict).from_cache) - self.assertTrue(self.s.get(url, params=custom_dict).from_cache) - - def test_post_parameters_normalization(self): - params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} - url = httpbin("post") - s = CachedSession(CACHE_NAME, CACHE_BACKEND, allowable_methods=('GET', 'POST')) - self.assertFalse(s.post(url, data=params).from_cache) - self.assertTrue(s.post(url, data=params).from_cache) - self.assertTrue(s.post(url, data=sorted(params.items())).from_cache) - self.assertFalse(s.post(url, data=sorted(params.items(), reverse=True)).from_cache) - - def test_stream_requests_support(self): - n = 100 - url = httpbin("stream/%s" % n) - r = self.s.get(url, stream=True) - first_char = r.raw.read(1) - lines = list(r.iter_lines()) - self.assertTrue(first_char) - self.assertEqual(len(lines), n) + for i in range(5): + r = 
mock_session.get(httpbin('get'), hooks={hook: hook_func}) + assert state[hook] == 5 + + +def test_normalize_params(mock_session): + params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} + reversed_params = dict(sorted(params.items(), reverse=True)) + + assert mock_session.get(MOCKED_URL, params=params).from_cache is False + assert mock_session.get(MOCKED_URL, params=params).from_cache is True + assert mock_session.get(MOCKED_URL, params={"a": "b"}).from_cache is False + assert mock_session.get(MOCKED_URL, params=reversed_params).from_cache is True + + class UserSubclass(dict): + def items(self): + return sorted(super(UserSubclass, self).items(), reverse=True) + + params["z"] = "5" + custom_dict = UserSubclass(params) + assert mock_session.get(MOCKED_URL, params=custom_dict).from_cache is False + assert mock_session.get(MOCKED_URL, params=custom_dict).from_cache is True + + +def test_normalize_post_data(mock_session): + params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} + assert mock_session.post(MOCKED_URL, data=params).from_cache is False + assert mock_session.post(MOCKED_URL, data=params).from_cache is True + assert mock_session.post(MOCKED_URL, data=sorted(params.items())).from_cache is True + assert mock_session.post(MOCKED_URL, data=sorted(params.items(), reverse=True)).from_cache is False + + +def test_delete_response(mock_session): + mock_session.get(MOCKED_URL) + mock_session.cache.delete_url(MOCKED_URL) + assert not mock_session.cache.has_url(MOCKED_URL) + + +def test_delete_nonexistent_response(mock_session): + """Deleting a response that was either already deleted (or never added) should fail silently""" + mock_session.cache.delete_url(MOCKED_URL) + + mock_session.get(MOCKED_URL) + mock_session.cache.delete_url(MOCKED_URL) + assert not mock_session.cache.has_url(MOCKED_URL) + mock_session.cache.delete_url(MOCKED_URL) # Should fail silently + + +# TODO: Better mocking for redirects +def test_delete_redirect(mock_session): + response_key = 
mock_session.cache._url_to_key(MOCKED_URL) + redirect_key = mock_session.cache._url_to_key(MOCKED_URL_REDIRECT) + mock_session.get(MOCKED_URL) + mock_session.cache.keys_map[redirect_key] = response_key + + mock_session.cache.delete_url(MOCKED_URL_REDIRECT) + assert mock_session.cache.has_url(MOCKED_URL) + assert not mock_session.cache.has_url(MOCKED_URL_REDIRECT) + + +# TODO +def test_delete_history(mock_session): + pass + + +def test_response_defaults(mock_session): + """Both cached and new responses should always have the following attributes""" + mock_session.expire_after = datetime.utcnow() + timedelta(days=1) + response_1 = mock_session.get(MOCKED_URL) + response_2 = mock_session.get(MOCKED_URL) + response_3 = mock_session.get(MOCKED_URL) + + assert response_1.created_at is None + assert response_1.expires is None + assert response_1.from_cache is False + assert response_1.is_expired is False + + assert isinstance(response_2.created_at, datetime) + assert isinstance(response_2.expires, datetime) + assert response_2.created_at == response_3.created_at + assert response_2.expires == response_3.expires + assert response_2.from_cache is response_3.from_cache is True + assert response_2.is_expired is response_3.is_expired is False + + +def test_include_get_headers(mock_session): + """With include_get_headers, requests with different headers should have different cache keys""" + mock_session.cache._include_get_headers = True + headers_list = [{'Accept': 'text/json'}, {'Accept': 'text/xml'}, {'Accept': 'custom'}, None] + for headers in headers_list: + assert mock_session.get(MOCKED_URL, headers=headers).from_cache is False + assert mock_session.get(MOCKED_URL, headers=headers).from_cache is True + + +def test_include_get_headers_normalize(mock_session): + """With include_get_headers, the same headers (in any order) should have the same cache key""" + mock_session.cache._include_get_headers = True + headers = {'Accept': 'text/json', 'Custom': 'abc'} + reversed_headers 
= {'Custom': 'abc', 'Accept': 'text/json'} + assert mock_session.get(MOCKED_URL, headers=headers).from_cache is False + assert mock_session.get(MOCKED_URL, headers=reversed_headers).from_cache is True + + +def test_cache_error(mock_session): + """If there is an error while fetching a cached response, a new one should be fetched""" + mock_session.get(MOCKED_URL) + with patch.object(mock_session.cache, 'get_response', side_effect=ValueError): + assert mock_session.get(MOCKED_URL).from_cache is False + + +def test_expired_request_error(mock_session): + """Without old_data_on_error (default), if there is an error while re-fetching an expired + response, the request should be re-raised and the expired item deleted""" + mock_session.old_data_on_error = False + mock_session.expire_after = 0.01 + mock_session.get(MOCKED_URL) + time.sleep(0.01) + + with patch.object(mock_session.cache, 'save_response', side_effect=ValueError): + with pytest.raises(ValueError): + mock_session.get(MOCKED_URL) + assert len(mock_session.cache.responses) == 0 + + +def test_old_data_on_error(mock_session): + """With old_data_on_error, expect to get old cache data if there is an error during a request""" + mock_session.old_data_on_error = True + mock_session.expire_after = 0.1 + + assert mock_session.get(MOCKED_URL).from_cache is False + assert mock_session.get(MOCKED_URL).from_cache is True + time.sleep(0.1) + with patch.object(mock_session.cache, 'save_response', side_effect=ValueError): + response = mock_session.get(MOCKED_URL) + assert response.from_cache is True and response.is_expired is True + +@pytest.mark.parametrize('method', ['POST', 'PUT']) +def test_raw_data(method, mock_session): + """POST and PUT requests with different data (raw) should be cached under different keys""" + assert mock_session.request(method, MOCKED_URL, data='raw data').from_cache is False + assert mock_session.request(method, MOCKED_URL, data='raw data').from_cache is True + assert mock_session.request(method, 
MOCKED_URL, data='new raw data').from_cache is False + + +def test_cache_disabled(mock_session): + mock_session.get(MOCKED_URL) + with mock_session.cache_disabled(): for i in range(2): - r = self.s.get(url, stream=True) - first_char_cached = r.raw.read(1) - self.assertTrue(r.from_cache) - cached_lines = list(r.iter_lines()) - self.assertEqual(cached_lines, lines) - self.assertEqual(first_char, first_char_cached) - - def test_headers_in_get_query(self): - url = httpbin("get") - s = CachedSession(CACHE_NAME, CACHE_BACKEND, include_get_headers=True) - headers = {"Accept": "text/json"} - self.assertFalse(s.get(url, headers=headers).from_cache) - self.assertTrue(s.get(url, headers=headers).from_cache) - - headers["Accept"] = "text/xml" - self.assertFalse(s.get(url, headers=headers).from_cache) - self.assertTrue(s.get(url, headers=headers).from_cache) - - headers["X-custom-header"] = "custom" - self.assertFalse(s.get(url, headers=headers).from_cache) - self.assertTrue(s.get(url, headers=headers).from_cache) - - self.assertFalse(s.get(url).from_cache) - self.assertTrue(s.get(url).from_cache) - - def test_str_and_repr(self): - s = repr(CachedSession(CACHE_NAME, CACHE_BACKEND, expire_after=10)) - self.assertIn(CACHE_NAME, s) - self.assertIn("10", s) - - @mock.patch("requests_cache.core.datetime") - @mock.patch("requests_cache.backends.base.datetime") - def test_return_old_data_on_error(self, datetime_mock_backend, datetime_mock): - now = datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc) - datetime_mock_backend.now.return_value = now - - datetime_mock.now.return_value = now - expire_after = 100 - url = httpbin("get") - s = CachedSession(CACHE_NAME, CACHE_BACKEND, old_data_on_error=True, expire_after=expire_after) - header = "X-Tst" - - def get(n): - return s.get(url, headers={header: n}).json()["headers"][header] - - get("expired") - self.assertEqual(get("2"), "expired") - datetime_mock.now.return_value = now + timedelta(seconds=expire_after * 2) - - with 
mock.patch.object(s.cache, "save_response", side_effect=Exception): - self.assertEqual(get("3"), "expired") - - with mock.patch("requests_cache.core.OriginalSession.send") as send_mock: - resp_mock = requests.Response() - request = requests.Request("GET", url) - resp_mock.request = request.prepare() - resp_mock.status_code = 400 - resp_mock._content = '{"other": "content"}' - send_mock.return_value = resp_mock - self.assertEqual(get("4"), "expired") - - resp_mock.status_code = 200 - self.assertIs(s.get(url).content, resp_mock.content) - - # default behaviour - datetime_mock.now.return_value = now + timedelta(seconds=expire_after * 5) - s = CachedSession(CACHE_NAME, CACHE_BACKEND, old_data_on_error=False, expire_after=expire_after) - with mock.patch.object(s.cache, "save_response", side_effect=Exception): - with self.assertRaises(Exception): - s.get(url) - - def test_ignore_parameters_get(self): - url = httpbin("get") - ignored_param = "ignored" - usual_param = "some" - params = {ignored_param: "1", usual_param: "1"} - - s = CachedSession(CACHE_NAME, CACHE_BACKEND, ignored_parameters=[ignored_param]) - - r = s.get(url, params=params) - self.assertIn(ignored_param, r.json()['args'].keys()) - self.assertFalse(r.from_cache) - - self.assertTrue(s.get(url, params=params).from_cache) - - params[ignored_param] = "new" - self.assertTrue(s.get(url, params=params).from_cache) - - params[usual_param] = "new" - self.assertFalse(s.get(url, params=params).from_cache) - - def test_ignore_parameters_post(self): - url = httpbin("post") - ignored_param = "ignored" - usual_param = "some" - d = {ignored_param: "1", usual_param: "1"} - - s = CachedSession( - CACHE_NAME, - CACHE_BACKEND, - allowable_methods=('POST'), - ignored_parameters=[ignored_param], - ) - - r = s.post(url, data=d) - self.assertIn(ignored_param, r.json()['form'].keys()) - self.assertFalse(r.from_cache) - - self.assertTrue(s.post(url, data=d).from_cache) - - d[ignored_param] = "new" - self.assertTrue(s.post(url, 
data=d).from_cache) - - d[usual_param] = "new" - self.assertFalse(s.post(url, data=d).from_cache) - - def test_ignore_parameters_post_json(self): - url = httpbin("post") - ignored_param = "ignored" - usual_param = "some" - d = {ignored_param: "1", usual_param: "1"} - - s = CachedSession( - CACHE_NAME, - CACHE_BACKEND, - allowable_methods=('POST'), - ignored_parameters=[ignored_param], - ) - - r = s.post(url, json=d) - self.assertIn(ignored_param, json.loads(r.json()['data']).keys()) - self.assertFalse(r.from_cache) - - self.assertTrue(s.post(url, json=d).from_cache) - - d[ignored_param] = "new" - self.assertTrue(s.post(url, json=d).from_cache) - - d[usual_param] = "new" - self.assertFalse(s.post(url, json=d).from_cache) - - def test_ignore_parameters_post_raw(self): - url = httpbin("post") - ignored_param = "ignored" - raw_data = "raw test data" - - s = CachedSession( - CACHE_NAME, - CACHE_BACKEND, - allowable_methods=('POST'), - ignored_parameters=[ignored_param], - ) - - self.assertFalse(s.post(url, data=raw_data).from_cache) - self.assertTrue(s.post(url, data=raw_data).from_cache) - - raw_data = "new raw data" - self.assertFalse(s.post(url, data=raw_data).from_cache) - - # TODO: Create mock responses instead of depending on httpbin - @pytest.mark.skip(reason='httpbin.org/relative-redirect no longer returns redirects') - @mock.patch("requests_cache.backends.base.datetime") - @mock.patch("requests_cache.core.datetime") - def test_remove_expired_entries(self, datetime_mock, datetime_mock2): - expire_after = timedelta(minutes=10) - start_time = datetime.utcnow().replace(year=2010, minute=0) - datetime_mock.utcnow.return_value = start_time - datetime_mock2.utcnow.return_value = start_time - - s = CachedSession(CACHE_NAME, CACHE_BACKEND, expire_after=expire_after) - s.get(httpbin('get')) - s.get(httpbin('relative-redirect/3')) - datetime_mock.utcnow.return_value = start_time + expire_after * 2 - datetime_mock2.utcnow.return_value = datetime_mock.utcnow.return_value - 
- ok_url = 'get?x=1' - s.get(httpbin(ok_url)) - self.assertEqual(len(s.cache.responses), 3) - self.assertEqual(len(s.cache.keys_map), 3) - s.remove_expired_responses() - self.assertEqual(len(s.cache.responses), 1) - self.assertEqual(len(s.cache.keys_map), 0) - self.assertIn(ok_url, list(s.cache.responses.values())[0][0].url) - - def test_cache_unpickle_errors(self): - url = httpbin('get?q=1') - self.assertFalse(self.s.get(url).from_cache) - with mock.patch("requests_cache.backends.base.BaseCache.restore_response", side_effect=TypeError): - resp = self.s.get(url) - self.assertFalse(resp.from_cache) - self.assertEqual(resp.json()["args"]["q"], "1") - resp = self.s.get(url) - self.assertTrue(resp.from_cache) - self.assertEqual(resp.json()["args"]["q"], "1") - - def test_cache_date(self): - url = httpbin('get') - response1 = self.s.get(url) - response2 = self.s.get(url) - response3 = self.s.get(url) - self.assertEqual(response1.cache_date, None) - self.assertTrue(isinstance(response2.cache_date, datetime)) - self.assertEqual(response2.cache_date, response3.cache_date) - - -if __name__ == '__main__': - unittest.main() + assert mock_session.get(MOCKED_URL).from_cache is False + assert mock_session.get(MOCKED_URL).from_cache is True + + +def test_remove_expired_responses(mock_session): + unexpired_url = f'{MOCKED_URL}?x=1' + mock_session.mock_adapter.register_uri('GET', unexpired_url, status_code=200, text='mock response') + mock_session.expire_after = timedelta(seconds=0.1) + mock_session.get(MOCKED_URL) + mock_session.get(MOCKED_URL_JSON) + time.sleep(0.1) + mock_session.get(unexpired_url) + + # At this point we should have 1 unexpired response and 2 expired responses + assert len(mock_session.cache.responses) == 3 + mock_session.remove_expired_responses() + assert len(mock_session.cache.responses) == 1 + cached_response = list(mock_session.cache.responses.values())[0] + assert cached_response.url == unexpired_url + + # Now the last response should be expired as well + 
time.sleep(0.1) + mock_session.remove_expired_responses() + assert len(mock_session.cache.responses) == 0 + + +def test_remove_expired_responses__extend_expiration(mock_session): + # Start with an expired response + mock_session.expire_after = datetime.utcnow() - timedelta(seconds=0.05) + mock_session.get(MOCKED_URL) + + # Set expiration in the future and revalidate + mock_session.remove_expired_responses(expire_after=datetime.utcnow() + timedelta(seconds=0.05)) + assert len(mock_session.cache.responses) == 1 + response = mock_session.get(MOCKED_URL) + assert response.is_expired is False and response.from_cache is True + + +def test_remove_expired_responses__shorten_expiration(mock_session): + # Start with a non-expired response + mock_session.expire_after = datetime.utcnow() + timedelta(seconds=1) + mock_session.get(MOCKED_URL) + + # Set expiration in the past and revalidate + mock_session.remove_expired_responses(expire_after=datetime.utcnow() - timedelta(seconds=0.05)) + assert len(mock_session.cache.responses) == 0 + response = mock_session.get(MOCKED_URL) + assert response.is_expired is False and response.from_cache is False + + +def test_remove_expired_responses__per_request(mock_session): + # Cache 3 responses with different expiration times + second_url = f'{MOCKED_URL}/endpoint_2' + third_url = f'{MOCKED_URL}/endpoint_3' + mock_session.mock_adapter.register_uri('GET', second_url, status_code=200) + mock_session.mock_adapter.register_uri('GET', third_url, status_code=200) + mock_session.get(MOCKED_URL) + mock_session.get(second_url, expire_after=0.2) + mock_session.get(third_url, expire_after=0.4) + + # All 3 responses should still be cached + mock_session.remove_expired_responses() + assert len(mock_session.cache.responses) == 3 + + # One should be expired after 0.2s, and another should be expired after 0.4s + time.sleep(0.2) + mock_session.remove_expired_responses() + assert len(mock_session.cache.responses) == 2 + time.sleep(0.2) + 
mock_session.remove_expired_responses() + assert len(mock_session.cache.responses) == 1 + + +def test_per_request__expiration(mock_session): + """No per-session expiration is set, but then overridden with per-request expiration""" + mock_session.expire_after = None + response = mock_session.get(MOCKED_URL, expire_after=0.01) + assert response.from_cache is False + time.sleep(0.01) + response = mock_session.get(MOCKED_URL) + assert response.from_cache is False + + +def test_per_request__no_expiration(mock_session): + """A per-session expiration is set, but then overridden with no per-request expiration""" + mock_session.expire_after = 0.01 + response = mock_session.get(MOCKED_URL, expire_after=-1) + assert response.from_cache is False + time.sleep(0.01) + response = mock_session.get(MOCKED_URL) + assert response.from_cache is True + + +def test_unpickle_errors(mock_session): + """If there is an error during deserialization, the request should be made again""" + assert mock_session.get(MOCKED_URL_JSON).from_cache is False + + with patch.object(DbPickleDict, '__getitem__', side_effect=PickleError): + resp = mock_session.get(MOCKED_URL_JSON) + assert resp.from_cache is False + assert resp.json()['message'] == 'mock json response' + + resp = mock_session.get(MOCKED_URL_JSON) + assert resp.from_cache is True + assert resp.json()['message'] == 'mock json response' diff --git a/tests/test_expiration_datetime.py b/tests/test_expiration_datetime.py deleted file mode 100644 index aa577a74..00000000 --- a/tests/test_expiration_datetime.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env python -import datetime -import itertools -import unittest - -import requests - -import requests_cache - - -class ExpirationTimeTest(unittest.TestCase): - def setUp(self): - requests_cache.install_cache(backend='memory') - self.url = 'https://httpbin.org/get' - self.session = requests.Session() - self.now = datetime.datetime(2021, 2, 28, 16, 40) - self.response = requests.Response() - - def 
tearDown(self): - requests_cache.uninstall_cache() - - def test_expire_after_precedence_matrix(self): - in_five_seconds = datetime.datetime(2021, 2, 28, 16, 40, 5) - expire_afters = ['default', None, 5, datetime.timedelta(seconds=5), in_five_seconds] - - for cache_expire_after, request_expire_after, response_expire_after in itertools.product( - expire_afters, expire_afters, expire_afters - ): - if cache_expire_after == 'default': - continue # cache can never be default or cached - - expected = False - if request_expire_after == 'default': - expected = in_five_seconds if cache_expire_after is not None else None - else: - expected = in_five_seconds if request_expire_after is not None else None - - self.session._cache_expire_after = cache_expire_after - self.session._request_expire_after = request_expire_after - self.response.expire_after = response_expire_after - with self.subTest(cache=cache_expire_after, request=request_expire_after, response=response_expire_after): - actual = self.session._determine_expiration_datetime(relative_to=self.now) - self.assertEqual(actual, expected) diff --git a/tests/test_monkey_patch.py b/tests/test_monkey_patch.py index 3194f276..b4be07a3 100644 --- a/tests/test_monkey_patch.py +++ b/tests/test_monkey_patch.py @@ -71,25 +71,44 @@ class MyCache(BaseCache): session = CachedSession(backend=backend) self.assertIs(session.cache, backend) - @patch.object(BaseCache, 'remove_old_entries') - def test_remove_expired_responses(self, remove_old_entries): + @patch.object(OriginalSession, 'request') + @patch.object(CachedSession, 'request') + def test_disabled(self, cached_request, original_request): + requests_cache.install_cache() + with requests_cache.disabled(): + for i in range(3): + requests.get('some_url') + assert cached_request.call_count == 0 + assert original_request.call_count == 3 + + @patch.object(OriginalSession, 'request') + @patch.object(CachedSession, 'request') + def test_enabled(self, cached_request, original_request): + with 
requests_cache.enabled(): + for i in range(3): + requests.get('some_url') + assert cached_request.call_count == 3 + assert original_request.call_count == 0 + + @patch.object(BaseCache, 'remove_expired_responses') + def test_remove_expired_responses(self, remove_expired_responses): requests_cache.install_cache(expire_after=360) requests_cache.remove_expired_responses() - assert remove_old_entries.called is True + assert remove_expired_responses.called is True - @patch.object(BaseCache, 'remove_old_entries') - def test_remove_expired_responses__cache_not_installed(self, remove_old_entries): + @patch.object(BaseCache, 'remove_expired_responses') + def test_remove_expired_responses__cache_not_installed(self, remove_expired_responses): requests_cache.remove_expired_responses() - assert remove_old_entries.called is False + assert remove_expired_responses.called is False - @patch.object(BaseCache, 'remove_old_entries') - def test_remove_expired_responses__no_expiration(self, remove_old_entries): + @patch.object(BaseCache, 'remove_expired_responses') + def test_remove_expired_responses__no_expiration(self, remove_expired_responses): requests_cache.install_cache() requests_cache.remove_expired_responses() # Before https://github.com/reclosedev/requests-cache/pull/177, this - # was False, but with per-request caching, remove_old_entries must + # was False, but with per-request caching, remove_expired_responses must # always be called - assert remove_old_entries.called is True + assert remove_expired_responses.called is True if __name__ == '__main__': diff --git a/tests/test_per_request_cache.py b/tests/test_per_request_cache.py deleted file mode 100644 index 2a0d3e38..00000000 --- a/tests/test_per_request_cache.py +++ /dev/null @@ -1,192 +0,0 @@ -#!/usr/bin/env python -import os -import time -import unittest - -import requests - -import requests_cache - -HTTPBIN_URL = os.getenv('HTTPBIN_URL', 'http://httpbin.org/') - - -class PerRequestCachedSessionTest(unittest.TestCase): - 
def setUp(self): - requests_cache.install_cache(backend='memory') - self.url = HTTPBIN_URL + 'get' - - def tearDown(self): - requests_cache.uninstall_cache() - - def test_default_cache_always(self): - response = requests.get(self.url) - self.assertFalse(response.from_cache) - - response = requests.get(self.url) - self.assertTrue(response.from_cache) - - response = requests.get(self.url, expire_after='default') - self.assertTrue(response.from_cache) - - def test_default_cache_never(self): - requests_cache.install_cache(backend='memory', expire_after=-1) - - response = requests.get(self.url) - self.assertFalse(response.from_cache) - - response = requests.get(self.url) - self.assertFalse(response.from_cache) - - response = requests.get(self.url, expire_after='default') - self.assertFalse(response.from_cache) - - def test_positive_cache(self): - response = requests.get(self.url, expire_after=0.1) - self.assertFalse(response.from_cache) - - time.sleep(0.5) - - response = requests.get(self.url) - self.assertFalse(response.from_cache) - - # This should delete the cached entry before as it changed - response = requests.get(self.url, expire_after=5) - self.assertFalse(response.from_cache) - - # This should not delete the cached entry before as it didn't change - response = requests.get(self.url, expire_after=5) - self.assertTrue(response.from_cache) - - def test_negative_cache(self): - response = requests.get(self.url) - self.assertFalse(response.from_cache) - - response = requests.get(self.url) - self.assertTrue(response.from_cache) - - response = requests.get(self.url, expire_after=-1) - self.assertFalse(response.from_cache) - - response = requests.get(self.url, expire_after=-1) - self.assertFalse(response.from_cache) - - def test_cache_invalidation(self): - self.assertFalse(requests.get(self.url, expire_after=1).from_cache) - self.assertTrue(requests.get(self.url).from_cache) - time.sleep(1.2) - self.assertFalse(requests.get(self.url).from_cache) - - 
self.assertFalse(requests.get(self.url, expire_after=-1).from_cache) - self.assertFalse(requests.get(self.url).from_cache) - - self.assertFalse(requests.get(self.url, expire_after=1).from_cache) - self.assertTrue(requests.get(self.url).from_cache) - - def test_auto_clear_expired(self): - requests_cache.install_cache(backend='memory', expire_after=1) - - second_url = HTTPBIN_URL + 'anything' - - response = requests.get(self.url, expire_after=5) - self.assertFalse(response.from_cache) - - response = requests.get(self.url) - self.assertTrue(response.from_cache) - - response = requests.get(second_url) - self.assertFalse(response.from_cache) - - time.sleep(2) - - response = requests.get(self.url) - self.assertTrue(response.from_cache) - - response = requests.get(second_url, expire_after=10) - self.assertFalse(response.from_cache) - - response = requests.get(second_url) - self.assertTrue(response.from_cache) - - def test_remove_expired(self): - response = requests.get(self.url) - self.assertFalse(response.from_cache) - - response = requests.get(self.url) - self.assertTrue(response.from_cache) - - second_url = HTTPBIN_URL + 'anything' - - response = requests.get(second_url, expire_after=2) - self.assertFalse(response.from_cache) - - response = requests.get(second_url) - self.assertTrue(response.from_cache) - - third_url = HTTPBIN_URL - - response = requests.get(third_url, expire_after=10) - self.assertFalse(response.from_cache) - - response = requests.get(third_url) - self.assertTrue(response.from_cache) - - self.assertEqual(len(requests.Session().cache.responses), 3) - - time.sleep(2) - - requests_cache.remove_expired_responses() - - self.assertEqual(len(requests.Session().cache.responses), 2) - - def test_remove_expired_expire_by_default(self): - requests_cache.install_cache(backend='memory', expire_after=1) - - response = requests.get(self.url) - self.assertFalse(response.from_cache) - - response = requests.get(self.url) - self.assertTrue(response.from_cache) - - 
second_url = HTTPBIN_URL + 'anything' - - response = requests.get(second_url, expire_after=10) - self.assertFalse(response.from_cache) - - response = requests.get(second_url) - self.assertTrue(response.from_cache) - - self.assertEqual(len(requests.Session().cache.responses), 2) - - time.sleep(1) - - requests_cache.core.remove_expired_responses() - - self.assertEqual(len(requests.Session().cache.responses), 1) - - -class ContextManagerTest(unittest.TestCase): - def tearDown(self): - os.unlink('test_cache.sqlite') - - def test_as_context_manager(self): - url = HTTPBIN_URL + 'delay/2' - with requests_cache.enabled('test_cache', expire_after=10): - response = requests.get(url) - self.assertFalse(response.from_cache) - - response = requests.get(url) - self.assertTrue(response.from_cache) - - start = time.time() - response = requests.get(url) - end = time.time() - self.assertFalse(hasattr(response, 'from_cache')) - self.assertGreaterEqual(end - start, 1.5) - - with requests_cache.enabled('test_cache'): - response = requests.get(url) - self.assertTrue(response.from_cache) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 00000000..c85e242a --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,155 @@ +import gzip +import pytest +from datetime import datetime, timedelta +from io import BytesIO +from time import sleep + +from urllib3.response import HTTPResponse + +from requests_cache import CachedHTTPResponse, CachedResponse +from tests.conftest import MOCKED_URL + + +def test_basic_attrs(mock_session): + response = CachedResponse(mock_session.get(MOCKED_URL)) + + assert response.from_cache is True + assert response.url == MOCKED_URL + assert response.status_code == 200 + assert response.reason is None + assert response.encoding == 'ISO-8859-1' + assert response.headers['Content-Type'] == 'text/plain' + assert response.text == 'mock response' + assert response.created_at is not None + 
assert response.expires is None + assert response.is_expired is False + + +@pytest.mark.parametrize( + 'expire_after, is_expired', + [ + (None, False), + (datetime.utcnow() + timedelta(days=1), False), + (datetime.utcnow() - timedelta(days=1), True), + ], +) +def test_expiration(expire_after, is_expired, mock_session): + response = CachedResponse(mock_session.get(MOCKED_URL), expire_after) + assert response.from_cache is True + assert response.is_expired == is_expired + + +def test_history(mock_session): + original_response = mock_session.get(MOCKED_URL) + original_response.history = [mock_session.get(MOCKED_URL)] * 3 + response = CachedResponse(original_response) + assert len(response.history) == 3 + assert all([isinstance(r, CachedResponse) for r in response.history]) + + +def test_raw_response__read(mock_session): + response = CachedResponse(mock_session.get(MOCKED_URL)) + assert isinstance(response.raw, CachedHTTPResponse) + assert response.raw.read(10) == b'mock respo' + assert response.raw.read(None) == b'nse' + assert response.raw.read(1) == b'' + assert response.raw._fp.closed is True + + +def test_raw_response__close(mock_session): + response = CachedResponse(mock_session.get(MOCKED_URL)) + response.close() + assert response.raw._fp.closed is True + + +def test_raw_response__reset(mock_session): + response = CachedResponse(mock_session.get(MOCKED_URL)) + response.raw.read(None) + assert response.raw.read(1) == b'' + assert response.raw._fp.closed is True + + response.reset() + assert response.raw.read(None) == b'mock response' + + +def test_raw_response__decode(mock_session): + """Test that a gzip-compressed raw response can be manually uncompressed with decode_content""" + url = f'{MOCKED_URL}/utf-8' + mock_session.mock_adapter.register_uri( + 'GET', + url, + status_code=200, + body=BytesIO(gzip.compress(b'compressed response')), + headers={'content-encoding': 'gzip'}, + ) + response = CachedResponse(mock_session.get(url)) + # Requests will have already 
read this, but let's just pretend we want to do it manually + response.raw._fp = BytesIO(gzip.compress(b'compressed response')) + assert response.raw.read(None, decode_content=True) == b'compressed response' + + +def test_raw_response__stream(mock_session): + response = CachedResponse(mock_session.get(MOCKED_URL)) + data = b'' + for chunk in response.raw.stream(1): + data += chunk + assert data == b'mock response' + assert response.raw._fp.closed + + +def test_raw_response__iterator(mock_session): + # Set up mock response with streamed content + url = f'{MOCKED_URL}/stream' + mock_raw_response = HTTPResponse( + body=BytesIO(b'mock response'), + status=200, + request_method='GET', + decode_content=False, + preload_content=False, + ) + mock_session.mock_adapter.register_uri( + 'GET', + url, + status_code=200, + raw=mock_raw_response, + ) + + # Expect the same chunks of data from the original response and subsequent cached responses + last_request_chunks = None + for i in range(3): + response = mock_session.get(url, stream=True) + chunks = list(response.iter_lines()) + if i == 0: + assert response.from_cache is False + else: + assert response.from_cache is True + assert chunks == last_request_chunks + last_request_chunks = chunks + + +def test_revalidate__extend_expiration(mock_session): + # Start with an expired response + response = CachedResponse( + mock_session.get(MOCKED_URL), + expire_after=datetime.utcnow() - timedelta(seconds=0.05), + ) + assert response.is_expired is True + + # Set expiration in the future and revalidate + is_expired = response.revalidate(datetime.utcnow() + timedelta(seconds=0.05)) + assert is_expired is response.is_expired is False + sleep(0.1) + assert response.is_expired is True + + +def test_revalidate__shorten_expiration(mock_session): + # Start with a non-expired response + response = CachedResponse( + mock_session.get(MOCKED_URL), + expire_after=datetime.utcnow() + timedelta(seconds=1), + ) + assert response.is_expired is False + + # 
Set expiration in the past and revalidate + is_expired = response.revalidate(datetime.utcnow() - timedelta(seconds=0.05)) + assert is_expired is response.is_expired is True