Skip to content

Commit

Permalink
Merge branch 'absolute-expire-time' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JWCook committed Feb 28, 2021
2 parents 7424ccf + 6c95230 commit 653e9d4
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 36 deletions.
35 changes: 13 additions & 22 deletions aiohttp_client_cache/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hashlib
from abc import ABCMeta, abstractmethod
from collections import UserDict
from datetime import datetime, timedelta
from datetime import timedelta
from logging import getLogger
from typing import Callable, Iterable, Optional, Union
from urllib.parse import parse_qsl, urlparse, urlunparse
Expand Down Expand Up @@ -76,26 +76,19 @@ def __init__(
self.include_headers = include_headers
self.ignored_params = set(ignored_params or [])

def is_cacheable(self, response: Union[ClientResponse, CachedResponse, None]) -> bool:
def is_cacheable(self, response: Union[AnyResponse, None]) -> bool:
"""Perform all checks needed to determine if the given response should be cached"""
if not response:
return False
cache_criteria = [
not self.disabled,
response.status in self.allowed_codes,
response.method in self.allowed_methods,
not self.is_expired(response),
self.filter_fn(response),
]
cache_criteria = {
'allowed status': response.status in self.allowed_codes,
'allowed method': response.method in self.allowed_methods,
'not disabled': not self.disabled,
'not expired': not getattr(response, 'is_expired', False),
'not filtered': self.filter_fn(response),
}
logger.debug(f'is_cacheable checks for response from {response.url}: {cache_criteria}') # type: ignore
return all(cache_criteria)

def is_expired(self, response: AnyResponse) -> bool:
"""Determine if a given response is expired"""
created_at = getattr(response, 'created_at', None)
if not created_at or not self.expire_after:
return False
return datetime.utcnow() - created_at >= self.expire_after
return all(cache_criteria.values())

async def get_response(self, key: str) -> Optional[CachedResponse]:
"""Retrieve response and timestamp for `key` if it's stored in cache,
Expand All @@ -116,15 +109,13 @@ async def get_response(self, key: str) -> Optional[CachedResponse]:
if not isinstance(response, CachedResponse):
logger.debug('Cached response is invalid')
return None

logger.info(f'Cached response found for key: {key}')

# If the item is expired or filtered out, delete it from the cache
if not self.is_cacheable(response):
logger.info('Cached response expired; deleting')
await self.delete(key)
response.is_expired = True
return None

logger.info(f'Cached response found for key: {key}')
return response

async def save_response(self, key: str, response: ClientResponse):
Expand All @@ -138,7 +129,7 @@ async def save_response(self, key: str, response: ClientResponse):
return
logger.info(f'Saving response for key: {key}')

cached_response = await CachedResponse.from_client_response(response)
cached_response = await CachedResponse.from_client_response(response, self.expire_after)
await self.responses.write(key, cached_response)

# Alias any redirect requests to the same cache key
Expand Down
23 changes: 18 additions & 5 deletions aiohttp_client_cache/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from datetime import datetime
from datetime import datetime, timedelta
from http.cookies import SimpleCookie
from typing import Any, Dict, Iterable, Mapping, Optional, Union

Expand All @@ -13,6 +13,7 @@
'_body',
'created_at',
'encoding',
'expires',
'history',
'is_expired',
'request_info',
Expand Down Expand Up @@ -41,8 +42,8 @@ def from_object(cls, request_info):

@attr.s(slots=True)
class CachedResponse:
"""A dataclass containing cached response information. Will mostly behave the same as a
:py:class:`aiohttp.ClientResponse` that has been read.
"""A dataclass containing cached response information. It will mostly behave the same as a
:py:class:`aiohttp.ClientResponse` that has been read, with some additional cache-related info.
"""

method: str = attr.ib()
Expand All @@ -55,13 +56,16 @@ class CachedResponse:
cookies: SimpleCookie = attr.ib(default=None)
created_at: datetime = attr.ib(factory=datetime.utcnow)
encoding: str = attr.ib(default=None)
expires: datetime = attr.ib(default=None)
headers: Mapping = attr.ib(factory=dict)
history: Iterable = attr.ib(factory=tuple)
is_expired: bool = attr.ib(default=False)
request_info: RequestInfo = attr.ib(default=None)

@classmethod
async def from_client_response(cls, client_response: ClientResponse):
async def from_client_response(
cls, client_response: ClientResponse, expire_after: timedelta = None
):
"""Convert a ClientResponse into a CachedReponse"""
# Response may not have been read yet, if fetched by something other than CachedSession
if not client_response._released:
await client_response.read()
Expand All @@ -74,6 +78,10 @@ async def from_client_response(cls, client_response: ClientResponse):
response._body = client_response._body
response.headers = dict(client_response.headers)

# Set expiration time
if expire_after:
response.expires = datetime.utcnow() + expire_after

# The encoding may be unset even if the response has been read
try:
response.encoding = client_response.get_encoding()
Expand All @@ -100,6 +108,11 @@ def ok(self) -> bool:
def get_encoding(self):
return self.encoding

@property
def is_expired(self) -> bool:
"""Determine if this cached response is expired"""
return bool(self.expires) and datetime.utcnow() > self.expires

async def json(self, encoding: Optional[str] = None, **kwargs) -> Optional[Dict[str, Any]]:
"""Read and decode JSON response"""

Expand Down
2 changes: 1 addition & 1 deletion aiohttp_client_cache/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def _request(self, method: str, str_or_url: StrOrURL, **kwargs) -> AnyResp

# Attempt to fetch cached response; if missing or expired, fetch new one
cached_response = await self.cache.get_response(cache_key)
if cached_response and not getattr(cached_response, 'is_expired', False):
if cached_response:
return cached_response
else:
logger.info(f'Cached response not found; making request to {str_or_url}')
Expand Down
9 changes: 5 additions & 4 deletions test/unit/backends/test_backend_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@pytest.mark.asyncio
async def test_cache_backend__get_response__cache_response_hit():
cache = CacheBackend()
mock_response = MagicMock(spec=CachedResponse)
mock_response = MagicMock(spec=CachedResponse, method='GET', status=200, is_expired=False)
await cache.responses.write('request-key', mock_response)

response = await cache.get_response('request-key')
Expand All @@ -20,7 +20,7 @@ async def test_cache_backend__get_response__cache_response_hit():
async def test_cache_backend__get_response__cache_redirect_hit():
# Set up a cache with a couple cached items and a redirect
cache = CacheBackend()
mock_response = MagicMock(spec=CachedResponse)
mock_response = MagicMock(spec=CachedResponse, method='GET', status=200, is_expired=False)
await cache.responses.write('request-key', mock_response)
await cache.redirects.write('redirect-key', 'request-key')

Expand All @@ -45,10 +45,11 @@ async def test_cache_backend__get_response__cache_miss():
@patch.object(CacheBackend, 'is_cacheable', return_value=False)
async def test_cache_backend__get_response__cache_expired(mock_is_cacheable, mock_delete):
cache = CacheBackend()
await cache.responses.write('request-key', MagicMock(spec=CachedResponse))
mock_response = MagicMock(spec=CachedResponse, method='GET', status=200, is_expired=True)
await cache.responses.write('request-key', mock_response)

response = await cache.get_response('request-key')
assert response.is_expired is True
assert response is None
mock_delete.assert_called_with('request-key')


Expand Down
13 changes: 11 additions & 2 deletions test/unit/test_response.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import asyncio
import pytest
from datetime import timedelta

from aiohttp import ClientResponseError, ClientSession, web

from aiohttp_client_cache.response import CachedResponse, RequestInfo


async def get_real_response():
# Just for debugging purposes, not used by unit tests
async with ClientSession() as session:
return await session.get('http://httpbin.org/get')


async def get_test_response(client_factory, url='/'):
async def get_test_response(client_factory, url='/', **kwargs):
app = web.Application()
client = await client_factory(app)
client_response = await client.get(url)
return await CachedResponse.from_client_response(client_response)
return await CachedResponse.from_client_response(client_response, **kwargs)


async def test_response__basic_attrs(aiohttp_client):
Expand All @@ -30,6 +33,12 @@ async def test_response__basic_attrs(aiohttp_client):
assert response.is_expired is False


async def test_response__expiration(aiohttp_client):
response = await get_test_response(aiohttp_client, expire_after=timedelta(seconds=0.01))
await asyncio.sleep(0.01)
assert response.is_expired is True


async def test_response__encoding(aiohttp_client):
response = await get_test_response(aiohttp_client)
assert response.encoding == response.get_encoding() == 'utf-8'
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ async def test_session__cache_hit(mock_request):

@pytest.mark.asyncio
@patch.object(ClientSession, '_request')
async def test_session__cache_expired(mock_request):
async def test_session__cache_expired_or_invalid(mock_request):
cache = MagicMock(spec=CacheBackend)
cache.get_response.return_value = AsyncMock(is_expired=True)
cache.get_response.return_value = None
session = CachedSession(cache=cache)

await session.get('http://test.url')
Expand Down

0 comments on commit 653e9d4

Please sign in to comment.