Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/viam/components/arm/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from grpclib.server import Stream

from viam.components.service_base import ComponentServiceBase
from viam.errors import ComponentNotFoundError
from viam.proto.component.arm import (
Expand Down Expand Up @@ -35,7 +36,7 @@ async def GetEndPosition(self, stream: Stream[GetEndPositionRequest, GetEndPosit
except ComponentNotFoundError as e:
raise e.grpc_error
timeout = stream.deadline.time_remaining() if stream.deadline else None
position = await arm.get_end_position(extra=struct_to_dict(request.extra), timeout=timeout)
position = await arm.get_end_position(extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
response = GetEndPositionResponse(pose=position)
await stream.send_message(response)

Expand All @@ -48,7 +49,9 @@ async def MoveToPosition(self, stream: Stream[MoveToPositionRequest, MoveToPosit
except ComponentNotFoundError as e:
raise e.grpc_error
timeout = stream.deadline.time_remaining() if stream.deadline else None
await arm.move_to_position(request.to, request.world_state, extra=struct_to_dict(request.extra), timeout=timeout)
await arm.move_to_position(
request.to, request.world_state, extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata
)
response = MoveToPositionResponse()
await stream.send_message(response)

Expand All @@ -61,7 +64,7 @@ async def GetJointPositions(self, stream: Stream[GetJointPositionsRequest, GetJo
except ComponentNotFoundError as e:
raise e.grpc_error
timeout = stream.deadline.time_remaining() if stream.deadline else None
positions = await arm.get_joint_positions(extra=struct_to_dict(request.extra), timeout=timeout)
positions = await arm.get_joint_positions(extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
response = GetJointPositionsResponse(positions=positions)
await stream.send_message(response)

Expand All @@ -74,7 +77,7 @@ async def MoveToJointPositions(self, stream: Stream[MoveToJointPositionsRequest,
except ComponentNotFoundError as e:
raise e.grpc_error
timeout = stream.deadline.time_remaining() if stream.deadline else None
await arm.move_to_joint_positions(request.positions, extra=struct_to_dict(request.extra), timeout=timeout)
await arm.move_to_joint_positions(request.positions, extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
response = MoveToJointPositionsResponse()
await stream.send_message(response)

Expand All @@ -87,6 +90,6 @@ async def Stop(self, stream: Stream[StopRequest, StopResponse]) -> None:
except ComponentNotFoundError as e:
raise e.grpc_error
timeout = stream.deadline.time_remaining() if stream.deadline else None
await arm.stop(extra=struct_to_dict(request.extra), timeout=timeout)
await arm.stop(extra=struct_to_dict(request.extra), timeout=timeout, metadata=stream.metadata)
response = StopResponse()
await stream.send_message(response)
23 changes: 19 additions & 4 deletions src/viam/operations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import functools
import time
from typing import Any, Callable, Coroutine, Optional, TypeVar, cast
from typing import Any, Callable, Coroutine, Mapping, Optional, TypeVar, cast
from uuid import UUID, uuid4

from typing_extensions import Self
Expand All @@ -27,8 +27,8 @@ class Operation:
_cancel_event: asyncio.Event
_cancelled: bool

def __init__(self, method: str, cancel_event: asyncio.Event) -> None:
self.id = uuid4()
def __init__(self, method: str, cancel_event: asyncio.Event, opid: Optional[UUID] = None) -> None:
self.id = uuid4() if opid is None else opid
self.method = method
self.time_started = time.time()
self._cancel_event = cancel_event
Expand Down Expand Up @@ -61,6 +61,19 @@ def _noop(cls) -> Self:
P = ParamSpec("P")
T = TypeVar("T")

METADATA_KEY = "opid"


def opid_from_metadata(metadata: Optional[Mapping[str, str]]) -> Optional[UUID]:
if metadata is None:
return None

opid = metadata.get(METADATA_KEY)
if opid is None:
return None

return UUID(opid)


def run_with_operation(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
"""Run a component function with an `Operation`.
Expand Down Expand Up @@ -89,7 +102,9 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
func_name = func.__qualname__
arg_names = ", ".join([str(a) for a in args])
kwarg_names = ", ".join([f"{key}={value}" for (key, value) in kwargs.items()])
operation = Operation(f"{func_name}({arg_names}{', ' if len(arg_names) else ''}{kwarg_names})", event)
method = f"{func_name}({arg_names}{', ' if len(arg_names) else ''}{kwarg_names})"
opid = opid_from_metadata(kwargs.get("metadata")) # type: ignore
operation = Operation(method, event, opid=opid)
kwargs[Operation.ARG_NAME] = operation
timeout = kwargs.get("timeout", None)
timer: Optional[asyncio.TimerHandle] = None
Expand Down
20 changes: 19 additions & 1 deletion tests/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import time
from uuid import UUID

import pytest

from viam.operations import Operation, run_with_operation
from viam.operations import METADATA_KEY, Operation, run_with_operation


@pytest.mark.asyncio
Expand Down Expand Up @@ -57,3 +58,20 @@ async def long_running(self, **kwargs) -> bool:
test_obj.long_running_task_cancelled = False
assert test_obj.long_running_task_cancelled is False
assert await asyncio.create_task(test_obj.long_running(timeout=0.02)) is True


@pytest.mark.asyncio
async def test_wrapper_with_metadata():
test_metadata_opid = "11111111-1111-1111-1111-111111111111"

class TestWrapperClass:
@run_with_operation
async def run(self, **kwargs) -> bool:
operation: Operation = kwargs.get(Operation.ARG_NAME, Operation._noop())
assert operation.id == UUID(test_metadata_opid)
return False

test_obj = TestWrapperClass()
metadata = {METADATA_KEY: test_metadata_opid}
result = await test_obj.run(metadata=metadata)
assert result is False