Skip to content

Commit

Permalink
fixes race condition while removing matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
pnuckowski committed Aug 22, 2020
1 parent 27363ea commit f52bd75
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
12 changes: 7 additions & 5 deletions aioresponses/core.py
Expand Up @@ -8,6 +8,7 @@
from functools import wraps
from typing import Callable, Dict, Tuple, Union, Optional, List # noqa
from unittest.mock import Mock, patch
from uuid import uuid4

from aiohttp import (
ClientConnectionError,
Expand Down Expand Up @@ -207,7 +208,7 @@ async def build_response(

class aioresponses(object):
"""Mock aiohttp requests made by ClientSession."""
_matches = None # type: List[RequestMatch]
_matches = None # type: Dict[str, RequestMatch]
_responses = None # type: List[ClientResponse]
requests = None # type: Dict

Expand Down Expand Up @@ -254,7 +255,7 @@ def clear(self):

def start(self):
self._responses = []
self._matches = []
self._matches = {}
self.patcher.start()
self.patcher.return_value = self._request_mock

Expand Down Expand Up @@ -297,7 +298,8 @@ def add(self, url: 'Union[URL, str, Pattern]', method: str = hdrs.METH_GET,
timeout: bool = False,
reason: Optional[str] = None,
callback: Optional[Callable] = None) -> None:
self._matches.append(RequestMatch(

self._matches[str(uuid4())] = (RequestMatch(
url,
method=method,
status=status,
Expand Down Expand Up @@ -330,7 +332,7 @@ async def match(
) -> Optional['ClientResponse']:
history = []
while True:
for i, matcher in enumerate(self._matches):
for key, matcher in self._matches.items():
if matcher.match(method, url):
response_or_exc = await matcher.build_response(
url, allow_redirects=allow_redirects, **kwargs
Expand All @@ -340,7 +342,7 @@ async def match(
return None

if matcher.repeat is False:
del self._matches[i]
del self._matches[key]

if self.is_exception(response_or_exc):
raise response_or_exc
Expand Down
19 changes: 19 additions & 0 deletions tests/test_aioresponses.py
Expand Up @@ -2,6 +2,7 @@
import asyncio
import re
from asyncio import CancelledError, TimeoutError
from random import uniform
from typing import Coroutine, Generator, Union
from unittest.mock import patch

Expand Down Expand Up @@ -482,6 +483,24 @@ async def test_exception_requests_are_tracked(self, mocked):
self.assertEqual(request.args, ())
self.assertEqual(request.kwargs, kwargs)

async def test_possible_race_condition(self):
async def random_sleep_cb(url, **kwargs):
await asyncio.sleep(uniform(0.1, 1))
return CallbackResult(body='test')

with aioresponses() as mocked:
for i in range(20):
mocked.get(
'http://example.org/id-{}'.format(i),
callback=random_sleep_cb
)

tasks = [
self.session.get('http://example.org/id-{}'.format(i)) for
i in range(20)
]
await asyncio.gather(*tasks)


class AIOResponsesRaiseForStatusSessionTestCase(AsyncTestCase):
"""Test case for sessions with raise_for_status=True.
Expand Down

0 comments on commit f52bd75

Please sign in to comment.