diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py
index c018c5ad..72e2d9c1 100644
--- a/requests_cache/backends/filesystem.py
+++ b/requests_cache/backends/filesystem.py
@@ -82,6 +82,10 @@ def clear(self):
         self.responses.clear()
         self.redirects.init_db()

+    def remove_expired_responses(self, *args, **kwargs):
+        with self.responses._lock:
+            return super().remove_expired_responses(*args, **kwargs)
+

 class FileDict(BaseStorage):
     """A dictionary-like interface to files on the local filesystem"""
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index 418dba6f..db31ba4c 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -45,6 +45,10 @@ def __init__(self, db_name: str, **kwargs):
             db_name, collection_name='redirects', connection=self.responses.connection, **kwargs
         )

+    def remove_expired_responses(self, *args, **kwargs):
+        with self.responses._lock:
+            return super().remove_expired_responses(*args, **kwargs)
+

 class GridFSPickleDict(BaseStorage):
     """A dictionary-like interface for a GridFS database
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index c5e22cf0..d2217c9c 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -131,6 +131,10 @@ def clear(self):
         self.responses.init_db()
         self.redirects.init_db()

+    def remove_expired_responses(self, *args, **kwargs):
+        with self.responses._lock, self.redirects._lock:
+            return super().remove_expired_responses(*args, **kwargs)
+

 class SQLiteDict(BaseStorage):
     """A dictionary-like interface for SQLite"""
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 13e2ede6..2cbf96e9 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -232,7 +232,7 @@ def test_conditional_request__max_age_0(self, cache_headers, validator_headers):
         'validator_headers', [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
     )
     @pytest.mark.parametrize('cache_headers', [{'Cache-Control': 'max-age=0'}])
-    def test_conditional_request_refreshenes_expire_date(self, cache_headers, validator_headers):
+    def test_conditional_request_refreshes_expire_date(self, cache_headers, validator_headers):
         """Test that revalidation attempt with 304 responses causes stale entry to become fresh again
         considering Cache-Control header of the 304 response."""
         url = httpbin('response-headers')
@@ -306,31 +306,6 @@ def test_remove_expired_responses(self):
         assert not session.cache.has_url(httpbin('redirect/1'))
         assert not any([session.cache.has_url(httpbin(f)) for f in HTTPBIN_FORMATS])

-    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
-    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
-    def test_concurrency(self, iteration, executor_class):
-        """Run multithreaded and multiprocess stress tests for each backend.
-        The number of workers (thread/processes), iterations, and requests per iteration can be
-        increased via the `STRESS_TEST_MULTIPLIER` environment variable.
-        """
-        start = time()
-        url = httpbin('anything')
-
-        session_factory = partial(self.init_session, clear=False)
-        request_func = partial(_send_request, session_factory, url)
-        with ProcessPoolExecutor(max_workers=N_WORKERS) as executor:
-            _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
-
-        # Some logging for debug purposes
-        elapsed = time() - start
-        average = (elapsed * 1000) / (N_ITERATIONS * N_WORKERS)
-        worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes'
-        logger.info(
-            f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with '
-            f'{N_WORKERS} {worker_type} in {elapsed} s\n'
-            f'Average time per request: {average} ms'
-        )
-
     @pytest.mark.parametrize('method', HTTPBIN_METHODS)
     def test_filter_request_headers(self, method):
         url = httpbin(method.lower())
@@ -369,6 +344,31 @@ def test_filter_request_post_data(self, post_type):
         body = json.loads(response.request.body)
         assert "api_key" not in body

+    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
+    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
+    def test_concurrency(self, iteration, executor_class):
+        """Run multithreaded and multiprocess stress tests for each backend.
+        The number of workers (threads/processes), iterations, and requests per iteration can be
+        increased via the `STRESS_TEST_MULTIPLIER` environment variable.
+        """
+        start = time()
+        url = httpbin('anything')
+
+        session_factory = partial(self.init_session, clear=False)
+        request_func = partial(_send_request, session_factory, url)
+        with executor_class(max_workers=N_WORKERS) as executor:
+            _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
+
+        # Some logging for debug purposes
+        elapsed = time() - start
+        average = (elapsed * 1000) / (N_ITERATIONS * N_WORKERS)
+        worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes'
+        logger.info(
+            f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with '
+            f'{N_WORKERS} {worker_type} in {elapsed} s\n'
+            f'Average time per request: {average} ms'
+        )
+

 def _send_request(session_factory, url, _=None):
     """Concurrent request function for stress tests. Defined in module scope so it can be serialized
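Note for reviewers (not part of the patch): the three backend overrides above all follow the same pattern, taking the storage's `_lock` before delegating to the base `remove_expired_responses()`, so the whole expiry sweep runs as one critical section. Below is a minimal, self-contained sketch of why this works; all names (`FakeStorage`, `FakeBackend`, the string payloads) are hypothetical stand-ins, and it assumes the per-storage lock is reentrant (e.g. `threading.RLock`), since the per-item deletes re-acquire the lock the sweep is already holding.

```python
import threading


class FakeStorage:
    """Hypothetical stand-in for a BaseStorage subclass that guards item access with a lock."""

    def __init__(self):
        # Must be reentrant: the sweep below holds it while per-item
        # __delitem__ calls re-acquire it on the same thread
        self._lock = threading.RLock()
        self._items = {'stale-key': 'expired', 'fresh-key': 'fresh'}

    def __delitem__(self, key):
        with self._lock:
            del self._items[key]


class FakeBackend:
    """Hypothetical backend mirroring the override pattern added in this diff."""

    def __init__(self):
        self.responses = FakeStorage()

    def remove_expired_responses(self):
        # Hold the storage lock for the entire sweep, so a concurrent writer
        # can't interleave between the expiry check and the delete
        with self.responses._lock:
            for key, value in list(self.responses._items.items()):
                if value == 'expired':
                    del self.responses[key]  # nested acquire; safe only with a reentrant lock


backend = FakeBackend()
backend.remove_expired_responses()
assert list(backend.responses._items) == ['fresh-key']
```

The SQLite override takes both `responses._lock` and `redirects._lock`, since the base implementation can delete from both tables; with a non-reentrant lock, the nested acquisition inside `__delitem__` would deadlock rather than serialize.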