Skip to content

Commit

Permalink
Merge pull request #53 from tangkong/enh_client_apply
Browse files Browse the repository at this point in the history
ENH: add Client.apply method, adjust TaskStatus to properly capture exceptions
  • Loading branch information
tangkong committed Jul 12, 2024
2 parents 2dbd9e1 + f7c1ad6 commit 7cc33ef
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 18 deletions.
22 changes: 22 additions & 0 deletions docs/source/upcoming_release_notes/53-enh_client_apply.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
53 enh_client_apply
###################

API Breaks
----------
- N/A

Features
--------
- Implements `Client.apply` method for writing values from `Entry` data to the control system.

Bugfixes
--------
- N/A

Maintenance
-----------
- Adjusts `TaskStatus` use to properly capture exceptions.

Contributors
------------
- tangkong
118 changes: 110 additions & 8 deletions superscore/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
"""Client for superscore. Used for programmatic interactions with superscore"""
from typing import Any, Generator
import logging
from typing import Any, Generator, List, Optional, Union
from uuid import UUID

from superscore.backends.core import _Backend
from superscore.model import Entry
from superscore.control_layers import ControlLayer
from superscore.control_layers.status import TaskStatus
from superscore.model import Entry, Setpoint, Snapshot

logger = logging.getLogger(__name__)


class Client:
backend: _Backend
cl: ControlLayer

def __init__(self, backend=None, **kwargs) -> None:
# if backend is None, startup default filestore backend
return
def __init__(self, backend: _Backend, **kwargs) -> None:
self.backend = backend
self.cl = ControlLayer()

@classmethod
def from_config(cls, cfg=None):
Expand All @@ -34,9 +41,104 @@ def compare(self, entry_l: Entry, entry_r: Entry) -> Any:
"""Compare two entries. Should be of same type, and return a diff"""
raise NotImplementedError

def apply(self, entry: Entry):
"""Apply settings found in ``entry``. If no values found, no-op"""
raise NotImplementedError
def apply(
self,
entry: Union[Setpoint, Snapshot],
sequential: bool = False
) -> Optional[List[TaskStatus]]:
"""
Apply settings found in ``entry``. If no writable values found, return.
If ``sequential`` is True, apply values in ``entry`` in sequence, blocking
with each put request. Else apply all values simultaneously (asynchronously)
Parameters
----------
entry : Union[Setpoint, Snapshot]
The entry to apply values from
sequential : bool, optional
Whether to apply values sequentially, by default False
Returns
-------
Optional[List[TaskStatus]]
TaskStatus(es) for each value applied.
"""
if not isinstance(entry, (Setpoint, Snapshot)):
logger.info("Entries must be a Snapshot or Setpoint")
return

if isinstance(entry, Setpoint):
return [self.cl.put(entry.pv_name, entry.data)]

# Gather pv-value list and apply at once
status_list = []
pv_list, data_list = self._gather_data(entry)
if sequential:
for pv, data in zip(pv_list, data_list):
logger.debug(f'Putting {pv} = {data}')
status: TaskStatus = self.cl.put(pv, data)
if status.exception():
logger.warning(f"Failed to put {pv} = {data}, "
"terminating put sequence")
return

status_list.append(status)
else:
return self.cl.put(pv_list, data_list)

def _gather_data(
self,
entry: Union[Setpoint, Snapshot, UUID],
pv_list: Optional[List[str]] = None,
data_list: Optional[List[Any]] = None
) -> Optional[tuple[List[str], List[Any]]]:
"""
Gather writable pv name - data pairs recursively.
If pv_list and data_list are provided, gathered data will be added to
these lists in-place. If both lists are omitted, this function will return
the two lists after gathering.
Queries the backend to fill any UUID values found.
Parameters
----------
entry : Union[Setpoint, Snapshot, UUID]
Entry to gather writable data from
pv_list : Optional[List[str]], optional
List of addresses to write data to, by default None
data_list : Optional[List[Any]], optional
List of data to write to addresses in ``pv_list``, by default None
Returns
-------
Optional[tuple[List[str], List[Any]]]
the filled pv_list and data_list
"""
top_level = False
if (pv_list is None) and (data_list is None):
pv_list = []
data_list = []
top_level = True
elif (pv_list is None) or (data_list is None):
raise ValueError(
"Arguments pv_list and data_list must either both be provided "
"or both omitted."
)

if isinstance(entry, Snapshot):
for child in entry.children:
self._gather_data(child, pv_list, data_list)
elif isinstance(entry, UUID):
child_entry = self.backend.get_entry(entry)
self._gather_data(child_entry, pv_list, data_list)
elif isinstance(entry, Setpoint):
pv_list.append(entry.pv_name)
data_list.append(entry.data)

# Readbacks are not writable, and are not gathered

if top_level:
return pv_list, data_list

def validate(self, entry: Entry):
"""
Expand Down
1 change: 1 addition & 0 deletions superscore/control_layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .core import ControlLayer # noqa
from .status import TaskStatus # noqa
6 changes: 3 additions & 3 deletions superscore/control_layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def status_coro():
status = self._put_one(address, value)
if cb is not None:
status.add_callback(cb)
await status.task
await asyncio.gather(status, return_exceptions=True)
return status

return asyncio.run(status_coro())
Expand Down Expand Up @@ -185,15 +185,15 @@ async def status_coros():
status.add_callback(c)

statuses.append(status)
await asyncio.gather(*[s.task for s in statuses])
await asyncio.gather(*statuses, return_exceptions=True)
return statuses

return asyncio.run(status_coros())

@TaskStatus.wrap
async def _put_one(self, address: str, value: Any):
"""
Base async get function. Use this to construct higher-level get methods
Base async put function. Use this to construct higher-level put methods
"""
shim = self.shim_from_pv(address)
await shim.put(address, value)
Expand Down
27 changes: 25 additions & 2 deletions superscore/control_layers/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
class TaskStatus:
"""
Unified Status object for wrapping task completion information and attaching
callbacks This must be created inside of a coroutine, but can be returned to
synchronous scope for examining the task
callbacks. This must be created inside of a coroutine, but can be returned to
synchronous scope for examining the task.
Awaiting this status is similar to awaiting the wrapped task.
Largely vendored from bluesky/ophyd-async
"""
Expand Down Expand Up @@ -57,6 +59,27 @@ def success(self) -> bool:
and self.task.exception() is None
)

def wait(self, timeout=None) -> None:
"""
Block until the coroutine finishes. Raises asyncio.TimeoutError if
the timeout elapses before the task is completed
To be called in a synchronous context, if the status has not been awaited
Parameters
----------
timeout : number, optional
timeout in seconds, by default None
Raises
------
asyncio.TimeoutError
"""
# ensure task runs in the event loop it was assigned to originally
asyncio.get_event_loop().run_until_complete(
asyncio.wait_for(self.task, timeout)
)

def __repr__(self) -> str:
if self.done:
if e := self.exception():
Expand Down
27 changes: 27 additions & 0 deletions superscore/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import shutil
from pathlib import Path
from typing import List
from unittest.mock import MagicMock

import pytest

from superscore.backends.core import _Backend
from superscore.backends.filestore import FilestoreBackend
from superscore.backends.test import TestBackend
from superscore.client import Client
from superscore.control_layers._base_shim import _BaseShim
from superscore.control_layers.core import ControlLayer
from superscore.model import (Collection, Parameter, Readback, Root, Setpoint,
Expand Down Expand Up @@ -697,3 +699,28 @@ def dummy_cl() -> ControlLayer:
cl.shims['ca'] = DummyShim()
cl.shims['pva'] = DummyShim()
return cl


@pytest.fixture(scope='function')
def mock_backend() -> _Backend:
bk = _Backend()
bk.delete_entry = MagicMock()
bk.save_entry = MagicMock()
bk.get_entry = MagicMock()
bk.search = MagicMock()
bk.update_entry = MagicMock()


class MockTaskStatus:
def exception(self):
return None

@property
def done(self):
return True


@pytest.fixture(scope='function')
def mock_client(mock_backend: _Backend) -> Client:
client = Client(backend=mock_backend)
return client
8 changes: 5 additions & 3 deletions superscore/tests/test_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def test_fail(dummy_cl):
mock_ca_put = AsyncMock(side_effect=ValueError)
dummy_cl.shims['ca'].put = mock_ca_put

# exceptions get passed through the control layer
with pytest.raises(ValueError):
dummy_cl.put("THAT:PV", 4)
# exceptions get captured in status object
status = dummy_cl.put("THAT:PV", 4)
assert isinstance(status.exception(), ValueError)

assert mock_ca_put.called


def test_put_callback(dummy_cl):
Expand Down
21 changes: 21 additions & 0 deletions superscore/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest.mock import patch

from superscore.client import Client
from superscore.model import Root

from .conftest import MockTaskStatus


@patch('superscore.control_layers.core.ControlLayer.put')
def test_apply(put_mock, mock_client: Client, sample_database: Root):
put_mock.return_value = MockTaskStatus()
snap = sample_database.entries[3]
mock_client.apply(snap)
assert put_mock.call_count == 1
call_args = put_mock.call_args[0]
assert len(call_args[0]) == len(call_args[1]) == 3

put_mock.reset_mock()

mock_client.apply(snap, sequential=True)
assert put_mock.call_count == 3
42 changes: 40 additions & 2 deletions superscore/tests/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ async def inner_coroutine():
return inner_coroutine


@pytest.fixture
async def long_coroutine_status() -> TaskStatus:
@TaskStatus.wrap
async def inner_coroutine():
for i in range(100):
print(f'coro wait: {i}')
await asyncio.sleep(1)

return inner_coroutine()


async def test_status_success(normal_coroutine):
st = TaskStatus(normal_coroutine())
assert isinstance(st, TaskStatus)
Expand All @@ -40,7 +51,34 @@ async def test_status_fail(failing_coroutine):
with pytest.raises(ValueError):
await status

assert type(status.exception()) == ValueError
assert isinstance(status.exception(), ValueError)


def test_sync_status_fail(failing_coroutine):
# A usage note for the curious. If we gather these tasks with
# `return_exceptions` = False (default), the first exception will be propagated,
# though the other tasks will complete. This may stop tasks from being returned
# `retur_exceptions` = True will not raise exceptions, instead those exceptions
# will only be captured in `task.exception()`
async def wrap_coro(return_exc: bool):
status = TaskStatus(failing_coroutine())
await asyncio.gather(status, return_exceptions=return_exc)
return status

status = asyncio.run(wrap_coro(True))
assert status.done
assert isinstance(status.exception(), ValueError)

with pytest.raises(ValueError):
asyncio.run(wrap_coro(False))


def test_status_wait(long_coroutine_status):
assert not long_coroutine_status.done
with pytest.raises(asyncio.TimeoutError):
long_coroutine_status.wait(1)
assert long_coroutine_status.done
assert isinstance(long_coroutine_status.exception(), asyncio.CancelledError)


async def test_status_wrap():
Expand All @@ -50,5 +88,5 @@ async def coro_status():

st = coro_status()
assert isinstance(st, TaskStatus)
await st.task
await st
assert st.done

0 comments on commit 7cc33ef

Please sign in to comment.