diff --git a/aioresponses/core.py b/aioresponses/core.py index a596259..1a03d20 100644 --- a/aioresponses/core.py +++ b/aioresponses/core.py @@ -8,6 +8,7 @@ from functools import wraps from typing import Callable, Dict, Tuple, Union, Optional, List # noqa from unittest.mock import Mock, patch +from uuid import uuid4 from aiohttp import ( ClientConnectionError, @@ -207,7 +208,7 @@ async def build_response( class aioresponses(object): """Mock aiohttp requests made by ClientSession.""" - _matches = None # type: List[RequestMatch] + _matches = None # type: Dict[str, RequestMatch] _responses = None # type: List[ClientResponse] requests = None # type: Dict @@ -254,7 +255,7 @@ def clear(self): def start(self): self._responses = [] - self._matches = [] + self._matches = {} self.patcher.start() self.patcher.return_value = self._request_mock @@ -297,7 +298,8 @@ def add(self, url: 'Union[URL, str, Pattern]', method: str = hdrs.METH_GET, timeout: bool = False, reason: Optional[str] = None, callback: Optional[Callable] = None) -> None: - self._matches.append(RequestMatch( + + self._matches[str(uuid4())] = (RequestMatch( url, method=method, status=status, @@ -330,7 +332,7 @@ async def match( ) -> Optional['ClientResponse']: history = [] while True: - for i, matcher in enumerate(self._matches): + for key, matcher in self._matches.items(): if matcher.match(method, url): response_or_exc = await matcher.build_response( url, allow_redirects=allow_redirects, **kwargs @@ -340,7 +342,7 @@ async def match( return None if matcher.repeat is False: - del self._matches[i] + del self._matches[key] if self.is_exception(response_or_exc): raise response_or_exc diff --git a/tests/test_aioresponses.py b/tests/test_aioresponses.py index c208a1b..f25c343 100644 --- a/tests/test_aioresponses.py +++ b/tests/test_aioresponses.py @@ -2,6 +2,7 @@ import asyncio import re from asyncio import CancelledError, TimeoutError +from random import uniform from typing import Coroutine, Generator, Union from unittest.mock import patch @@ -482,6 +483,24 @@ async def test_exception_requests_are_tracked(self, mocked): self.assertEqual(request.args, ()) self.assertEqual(request.kwargs, kwargs) + async def test_possible_race_condition(self): + async def random_sleep_cb(url, **kwargs): + await asyncio.sleep(uniform(0.1, 1)) + return CallbackResult(body='test') + + with aioresponses() as mocked: + for i in range(20): + mocked.get( + 'http://example.org/id-{}'.format(i), + callback=random_sleep_cb + ) + + tasks = [ + self.session.get('http://example.org/id-{}'.format(i)) for + i in range(20) + ] + await asyncio.gather(*tasks) + class AIOResponsesRaiseForStatusSessionTestCase(AsyncTestCase): """Test case for sessions with raise_for_status=True.