Skip to content

Commit

Permalink
fix switch_context (dont wait for an exception) (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
aleneum committed Oct 13, 2020
1 parent d19204d commit d1903c3
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,5 @@ dmypy.json
cython_debug/

# IntelliJ
.idea/
.idea/
.vscode/settings.json
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@pytest.fixture(params=[
pytest.param('asyncio'),
pytest.param('trio'),
pytest.param('curio'),
# pytest.param('curio'),
])
def anyio_backend(request):
return request.param
Expand Down
3 changes: 1 addition & 2 deletions tests/test_transitions_anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ async def test_async_conditions(m):
mock.assert_called_once_with()


@pytest.mark.xfail(reason="we should investigate")
async def test_multiple_models(machine_cls):
m1 = machine_cls(states=['A', 'B', 'C'], initial='A', name="m1")
m2 = machine_cls(states=['A'], initial='A', name='m2')
Expand Down Expand Up @@ -173,7 +172,7 @@ async def test_async_dispatch(machine_cls):
assert machine.initial == model3.state


@pytest.mark.xfail(reason="we should investigate")
# @pytest.mark.xfail(reason="we should investigate")
async def test_queued(machine_cls):
states = ['A', 'B', 'C', 'D']

Expand Down
21 changes: 10 additions & 11 deletions transitions_anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,20 @@ async def with_result(func):

async def process_context(self, func, model):
if self.current_context.get() is None:
try:
async with open_cancel_scope() as scope:
self.current_context.set(scope)
return await func()
except get_cancelled_exc_class():
return False
return await func()

def switch_model_context(self, model):
res = False
async with open_cancel_scope() as scope:
self.current_context.set(scope)
res = await self._process(func)
return res
return await self._process(func)

async def switch_model_context(self, model):
current_scope = self.current_context.get()
running_scope = self.async_tasks.get(model, None)
if current_scope != running_scope:
if running_scope is not None:
self.async_tasks[model].cancel()
self.async_tasks[model] = self.current_context.get()
if running_scope is not None:
await running_scope.cancel()


class AnyIOGraphMachine(GraphMachine, AnyIOMachine):
Expand Down

0 comments on commit d1903c3

Please sign in to comment.