
Add lock around remove_expired_responses() for SQLite, Filesystem, and GridFS backends
JWCook committed Feb 23, 2022
1 parent 9576fcf commit 0bd1446
Showing 4 changed files with 38 additions and 26 deletions.
4 changes: 4 additions & 0 deletions requests_cache/backends/filesystem.py
@@ -82,6 +82,10 @@ def clear(self):
        self.responses.clear()
        self.redirects.init_db()

    def remove_expired_responses(self, *args, **kwargs):
        with self.responses._lock:
            return super().remove_expired_responses(*args, **kwargs)


class FileDict(BaseStorage):
    """A dictionary-like interface to files on the local filesystem"""
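All three overrides follow the same pattern: hold the storage object's `_lock` for the entire expiration sweep so concurrent readers and writers can't race the deletions. A minimal self-contained sketch of that shape — the class names here are hypothetical, and requests-cache's real `BaseStorage` carries more machinery, but the locking pattern is the same:

import threading


class LockedStorage(dict):
    """Hypothetical dict-backed storage exposing a lock, like BaseStorage._lock."""

    def __init__(self):
        super().__init__()
        self._lock = threading.RLock()


class Cache:
    def __init__(self):
        self.responses = LockedStorage()

    def remove_expired_responses(self):
        # Hold the storage lock for the whole sweep, mirroring the backend
        # overrides above that wrap super().remove_expired_responses()
        with self.responses._lock:
            expired = [k for k, v in self.responses.items() if v.get('expired')]
            for key in expired:
                del self.responses[key]

A re-entrant lock (rather than a plain `Lock`) matters in this sketch: if the sweep calls back into storage methods that also take the lock, an `RLock` avoids self-deadlock on the same thread.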
4 changes: 4 additions & 0 deletions requests_cache/backends/gridfs.py
@@ -45,6 +45,10 @@ def __init__(self, db_name: str, **kwargs):
            db_name, collection_name='redirects', connection=self.responses.connection, **kwargs
        )

    def remove_expired_responses(self, *args, **kwargs):
        with self.responses._lock:
            return super().remove_expired_responses(*args, **kwargs)


class GridFSPickleDict(BaseStorage):
    """A dictionary-like interface for a GridFS database
4 changes: 4 additions & 0 deletions requests_cache/backends/sqlite.py
@@ -131,6 +131,10 @@ def clear(self):
        self.responses.init_db()
        self.redirects.init_db()

    def remove_expired_responses(self, *args, **kwargs):
        with self.responses._lock, self.redirects._lock:
            return super().remove_expired_responses(*args, **kwargs)


class SQLiteDict(BaseStorage):
    """A dictionary-like interface for SQLite"""
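The SQLite override differs from the other two in that it takes both the responses lock and the redirects lock. `with a, b:` acquires left to right and releases in reverse order, so as long as every code path takes the two locks in the same order, the pairing cannot deadlock. A small illustration with stand-in locks (hypothetical names; the real backend's locks may be re-entrant or connection-scoped):

import threading

responses_lock = threading.RLock()
redirects_lock = threading.RLock()


def remove_expired(sweep):
    # Equivalent to two nested `with` blocks: responses_lock is acquired
    # first and released last. A consistent acquisition order everywhere
    # is what rules out deadlock between concurrent sweeps.
    with responses_lock, redirects_lock:
        return sweep()


print(remove_expired(lambda: 'swept responses and redirects'))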
52 changes: 26 additions & 26 deletions tests/integration/base_cache_test.py
@@ -232,7 +232,7 @@ def test_conditional_request__max_age_0(self, cache_headers, validator_headers):
        'validator_headers', [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
    )
    @pytest.mark.parametrize('cache_headers', [{'Cache-Control': 'max-age=0'}])
    def test_conditional_request_refreshenes_expire_date(self, cache_headers, validator_headers):
    def test_conditional_request_refreshes_expire_date(self, cache_headers, validator_headers):
"""Test that revalidation attempt with 304 responses causes stale entry to become fresh again considering
Cache-Control header of the 304 response."""
        url = httpbin('response-headers')
@@ -306,31 +306,6 @@ def test_remove_expired_responses(self):
        assert not session.cache.has_url(httpbin('redirect/1'))
        assert not any([session.cache.has_url(httpbin(f)) for f in HTTPBIN_FORMATS])

    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
    def test_concurrency(self, iteration, executor_class):
        """Run multithreaded and multiprocess stress tests for each backend.
        The number of workers (threads/processes), iterations, and requests per iteration can be
        increased via the `STRESS_TEST_MULTIPLIER` environment variable.
        """
        start = time()
        url = httpbin('anything')

        session_factory = partial(self.init_session, clear=False)
        request_func = partial(_send_request, session_factory, url)
        # Use the parametrized executor class (threads or processes)
        with executor_class(max_workers=N_WORKERS) as executor:
            _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))

        # Some logging for debug purposes
        elapsed = time() - start
        average = (elapsed * 1000) / (N_ITERATIONS * N_WORKERS)
        worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes'
        logger.info(
            f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with '
            f'{N_WORKERS} {worker_type} in {elapsed} s\n'
            f'Average time per request: {average} ms'
        )

    @pytest.mark.parametrize('method', HTTPBIN_METHODS)
    def test_filter_request_headers(self, method):
        url = httpbin(method.lower())
@@ -369,6 +344,31 @@ def test_filter_request_post_data(self, post_type):
        body = json.loads(response.request.body)
        assert "api_key" not in body

    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
    def test_concurrency(self, iteration, executor_class):
        """Run multithreaded and multiprocess stress tests for each backend.
        The number of workers (threads/processes), iterations, and requests per iteration can be
        increased via the `STRESS_TEST_MULTIPLIER` environment variable.
        """
        start = time()
        url = httpbin('anything')

        session_factory = partial(self.init_session, clear=False)
        request_func = partial(_send_request, session_factory, url)
        # Use the parametrized executor class (threads or processes)
        with executor_class(max_workers=N_WORKERS) as executor:
            _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))

        # Some logging for debug purposes
        elapsed = time() - start
        average = (elapsed * 1000) / (N_ITERATIONS * N_WORKERS)
        worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes'
        logger.info(
            f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with '
            f'{N_WORKERS} {worker_type} in {elapsed} s\n'
            f'Average time per request: {average} ms'
        )


def _send_request(session_factory, url, _=None):
    """Concurrent request function for stress tests. Defined in module scope so it can be serialized
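The diff truncates `_send_request` just after its docstring opens. Conceptually it must live at module scope so `ProcessPoolExecutor` can pickle it, and its trailing `_` parameter absorbs the index that `executor.map(request_func, range(...))` supplies. A hedged reconstruction of what such a helper might look like (the repository's actual body may differ):

def _send_request(session_factory, url, _=None):
    """Hypothetical reconstruction: build a fresh session per call and send
    one request. The unused third argument soaks up the value passed by
    executor.map(); module scope keeps the function picklable for
    ProcessPoolExecutor."""
    session = session_factory()
    return session.get(url)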
