diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py
index 97796a83281..f73945915cd 100644
--- a/cirq-google/cirq_google/engine/engine_client.py
+++ b/cirq-google/cirq_google/engine/engine_client.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 
 import asyncio
+from asyncio.log import logger
 import datetime
 import sys
 import threading
 from typing import (
     AsyncIterable,
+    AsyncIterator,
     Awaitable,
     Callable,
     Dict,
@@ -42,6 +44,8 @@
 
 _M = TypeVar('_M', bound=proto.Message)
 _R = TypeVar('_R')
+JobPath = str
+MessageId = str
 
 
 class EngineException(Exception):
@@ -95,6 +99,146 @@ def instance(cls):
         return cls._instance
 
 
+class ResponseDemux:
+    """An event demultiplexer for QuantumRunStreamResponses, per the async reactor pattern."""
+
+    def __init__(self):
+        self._subscribers: Dict[MessageId, asyncio.Future] = {}
+        self._next_available_message_id = 0
+
+    def subscribe(self, request: quantum.QuantumRunStreamRequest) -> asyncio.Future:
+        """Assigns the next message ID to `request` and returns a future for its response."""
+        request.message_id = str(self._next_available_message_id)
+        response_future: asyncio.Future = asyncio.get_running_loop().create_future()
+        self._subscribers[request.message_id] = response_future
+        self._next_available_message_id += 1
+        return response_future
+
+    def unsubscribe(self, request: quantum.QuantumRunStreamRequest) -> None:
+        """Drops the pending future for `request`, if one is registered."""
+        self._subscribers.pop(request.message_id, None)
+
+    def publish(self, response: quantum.QuantumRunStreamResponse) -> None:
+        """Resolves the future registered under the response's message ID, if any."""
+        # A response whose message ID has no subscriber is dropped; a future that is
+        # already done (for example, cancelled) is removed without being resolved.
+        future = self._subscribers.pop(response.message_id, None)
+        if future is not None and not future.done():
+            future.set_result(response)
+
+    def publish_exception(self, exception: BaseException) -> None:
+        """Publishes an exception to all outstanding futures."""
+        for future in self._subscribers.values():
+            if not future.done():
+                future.set_exception(exception)
+        self._subscribers = {}
+
+
+class StreamManager:
+    """Multiplexes job requests and their responses over one quantum_run_stream() call."""
+
+    def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
+        self._grpc_client = grpc_client
+        self._request_queue: asyncio.Queue = asyncio.Queue()
+        self._manage_stream_loop_future: Optional[duet.AwaitableFuture] = None
+        # TODO consider making the scope of response futures local to the relevant tasks rather than
+        # all of StreamManager.
+        self._response_demux = ResponseDemux()
+
+    @property
+    def _executor(self) -> AsyncioExecutor:
+        # We must re-use a single Executor due to multi-threading issues in gRPC
+        # clients: https://github.com/grpc/grpc/issues/25364.
+        return AsyncioExecutor.instance()
+
+    async def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]:
+        """The request iterator for quantum_run_stream().
+
+        Every call of this method generates a new iterator."""
+        while True:
+            yield await self._request_queue.get()
+
+    async def _manage_stream(self) -> None:
+        """Keeps the stream alive and routes responses to the appropriate request handler"""
+        while True:
+            try:
+                # TODO specify stream timeout below with exponential backoff
+                response_iterable = await self._grpc_client.quantum_run_stream(
+                    self._request_iterator()
+                )
+                async for response in response_iterable:
+                    logger.warning('publishing response to demux')
+                    self._response_demux.publish(response)
+            except BaseException as e:
+                # TODO Close the request iterator to close the existing stream.
+                self._response_demux.publish_exception(e)  # Raise to all request tasks
+                if isinstance(e, asyncio.CancelledError):
+                    break
+
+    async def _make_request(
+        self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
+    ) -> quantum.QuantumResult:
+        """This method is executed in a separate asyncio Task for each request."""
+        current_request = quantum.QuantumRunStreamRequest(
+            parent=project_name,
+            create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest(
+                parent=project_name, quantum_program=program, quantum_job=job
+            ),
+        )
+        get_result_request = quantum.QuantumRunStreamRequest(
+            parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name)
+        )
+
+        response_future: Optional[asyncio.Future] = None
+        response: Optional[quantum.QuantumRunStreamResponse] = None
+        while response is None:
+            try:
+                logger.warning('Making request')
+                response_future = self._response_demux.subscribe(current_request)
+                await self._request_queue.put(current_request)
+                response = await response_future
+                logger.warning('Got response')
+            except GoogleAPICallError:
+                # TODO how to distinguish between program not found vs job not found?
+                # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or
+                # job doesn't exist.
+                # TODO add exponential backoff
+                logger.warning('Got GoogleAPICallError')
+                self._response_demux.unsubscribe(current_request)
+                current_request = get_result_request
+                continue
+            # Either when this request is canceled or the _manage_stream() loop is canceled.
+            except asyncio.CancelledError:
+                if response_future is not None:
+                    response_future.cancel()
+                self._response_demux.unsubscribe(current_request)
+                await self._cancel(job.name)
+                return quantum.QuantumResult()
+
+            if response.result is not None:
+                logger.warning('Got result')
+                return response.result
+            # TODO handle QuantumJob response and retryable StreamError.
+
+    async def _cancel(self, job_name: str) -> None:
+        await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name))
+
+    def send(
+        self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
+    ) -> duet.AwaitableFuture[quantum.QuantumResult]:
+        """Sends a request over the stream and returns a future for the result."""
+        if self._manage_stream_loop_future is None:
+            self._manage_stream_loop_future = self._executor.submit(self._manage_stream)
+        return self._executor.submit(self._make_request, project_name, program, job)
+
+    def stop(self) -> duet.AwaitableFuture[None]:
+        """Cancels the stream-managing task and clears it so that `send` can start a new one."""
+        loop_future, self._manage_stream_loop_future = self._manage_stream_loop_future, None
+        if loop_future is not None:
+            loop_future.cancel()
+        return duet.completed_future(None) if loop_future is None else loop_future
+
+
 class EngineClient:
     """Client for the Quantum Engine API handling protos and gRPC client.
 
@@ -148,6 +292,10 @@ async def make_client():
 
         return self._executor.submit(make_client).result()
 
+    @cached_property
+    def _stream_manager(self) -> StreamManager:
+        return StreamManager(self.grpc_client)
+
     async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R:
         """Sends a request by invoking an asyncio callable."""
         return await self._run_retry_async(func, request)
@@ -740,6 +888,76 @@ async def get_job_results_async(
 
     get_job_results = duet.sync(get_job_results_async)
 
+    def run_job_over_stream(
+        self,
+        project_id: str,
+        program_id: str,
+        code: any_pb2.Any,
+        job_id: Optional[str],  # TODO make this non-optional.
+        processor_ids: Sequence[str],
+        run_context: any_pb2.Any,
+        priority: Optional[int] = None,
+        description: Optional[str] = None,
+        labels: Optional[Dict[str, str]] = None,
+    ) -> duet.AwaitableFuture[quantum.QuantumResult]:
+        """Sends a program and a job over the stream and returns a future for the result.
+
+        Args:
+            project_id: ID of the project that the program and job belong to.
+            program_id: ID of the program; if empty, the program name is left unset.
+            code: The program code, packed as an `Any` proto.
+            job_id: ID of the job; if empty or `None`, the job name is left unset.
+            processor_ids: IDs of the processors named in the job's scheduling config.
+            run_context: The run context, packed as an `Any` proto.
+            priority: Optional scheduling priority; must satisfy 0 <= priority < 1000.
+            description: Optional description, applied to both the program and the job.
+            labels: Optional labels, applied to both the program and the job.
+
+        Returns:
+            A future for the `QuantumResult` of the job.
+
+        Raises:
+            ValueError: If `priority` is outside the accepted range.
+        """
+        # Check program to run and program parameters.
+        if priority and not 0 <= priority < 1000:
+            raise ValueError('priority must be between 0 and 1000')
+
+        project_name = _project_name(project_id)
+
+        program_name = _program_name_from_ids(project_id, program_id) if program_id else ''
+        program = quantum.QuantumProgram(name=program_name, code=code)
+        if description:
+            program.description = description
+        if labels:
+            program.labels.update(labels)
+
+        job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else ''
+        job = quantum.QuantumJob(
+            name=job_name,
+            scheduling_config=quantum.SchedulingConfig(
+                processor_selector=quantum.SchedulingConfig.ProcessorSelector(
+                    processor_names=[
+                        _processor_name_from_ids(project_id, processor_id)
+                        for processor_id in processor_ids
+                    ]
+                )
+            ),
+            run_context=run_context,
+        )
+        if priority:
+            job.scheduling_config.priority = priority
+        if description:
+            job.description = description
+        if labels:
+            job.labels.update(labels)
+
+        return self._stream_manager.send(project_name, program, job)
+
+    def stop_stream(self) -> duet.AwaitableFuture[None]:
+        """Stops the stream manager; returns the future that `StreamManager.stop` hands back."""
+        return self._stream_manager.stop()
+
     async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]:
         """Returns a list of Processors that the user has visibility to in the
         current Engine project. The names of these processors are used to
@@ -1133,3 +1351,18 @@ def _date_or_time_to_filter_expr(param_name: str, param: Union[datetime.datetime
         f"type {type(param)}. Supported types: datetime.datetime and"
         f"datetime.date"
     )
+
+
+def _get_job_path_from_stream_request(request: quantum.QuantumRunStreamRequest) -> JobPath:
+    """Returns the resource name of the job that a stream request refers to.
+
+    Raises:
+        ValueError: If the request carries none of the recognized request types.
+    """
+    if 'create_quantum_program_and_job' in request:
+        return request.create_quantum_program_and_job.quantum_job.name
+    if 'create_quantum_job' in request:
+        return request.create_quantum_job.quantum_job.name
+    if 'get_quantum_result' in request:
+        return request.get_quantum_result.parent
+    raise ValueError(f'Unrecognized request type in request: {request}')
diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py
index b416f75b7d1..ba4da750ab2 100644
--- a/cirq-google/cirq_google/engine/engine_client_test.py
+++ b/cirq-google/cirq_google/engine/engine_client_test.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tests for EngineClient."""
+from typing import AsyncIterable, AsyncIterator, Awaitable, List, Union
+
 import asyncio
+from asyncio.log import logger
 import datetime
 from unittest import mock
 
@@ -25,6 +28,7 @@
 
 from cirq_google.engine.engine_client import EngineClient, EngineException
 from cirq_google.cloud import quantum
+from cirq_google.cloud.quantum_v1alpha1.types import engine
 
 
 def setup_mock_(client_constructor):
@@ -33,6 +37,50 @@ def setup_mock_(client_constructor):
     return grpc_client
 
 
+def setup_fake_quantum_run_stream_client(client_constructor, responses_and_exceptions):
+    grpc_client = _FakeQuantumRunStream(responses_and_exceptions)
+    client_constructor.return_value = grpc_client
+    return grpc_client
+
+
+class _FakeQuantumRunStream:
+    """Fake gRPC client that replays canned stream responses and records cancel requests."""
+
+    def __init__(
+        self, responses_and_exceptions: List[Union[engine.QuantumRunStreamResponse, BaseException]]
+    ):
+        self.stream_request_count = 0
+        self.cancel_requests: List[engine.CancelQuantumJobRequest] = []
+        self.responses_and_exceptions = responses_and_exceptions
+
+    def add_responses_and_exceptions(
+        self, responses_and_exceptions: List[Union[engine.QuantumRunStreamResponse, BaseException]]
+    ):
+        self.responses_and_exceptions.extend(responses_and_exceptions)
+
+    async def quantum_run_stream(
+        self, requests: AsyncIterator[engine.QuantumRunStreamRequest] = None, **kwargs
+    ) -> Awaitable[AsyncIterable[engine.QuantumRunStreamResponse]]:
+        async def run_async_iterator():
+            async for request in requests:
+                self.stream_request_count += 1
+                while not self.responses_and_exceptions:
+                    await asyncio.sleep(0)
+                logger.warning('Responding')
+                response_or_exception = self.responses_and_exceptions.pop(0)
+                if isinstance(response_or_exception, BaseException):
+                    raise response_or_exception
+                response_or_exception.message_id = request.message_id
+                yield response_or_exception
+
+        await asyncio.sleep(0)
+        return run_async_iterator()
+
+    async def cancel_quantum_job(self, request: engine.CancelQuantumJobRequest) -> None:
+        self.cancel_requests.append(request)
+        await asyncio.sleep(0)
+
+
 @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
 def test_create_program(client_constructor):
     grpc_client = setup_mock_(client_constructor)
@@ -1176,3 +1224,74 @@ def test_list_time_slots(client_constructor):
 
     client = EngineClient()
     assert client.list_time_slots('proj', 'processor0') == results
+
+
+@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
+def test_run_job_over_stream_send_job_expects_result_response(client_constructor):
+    expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
+    mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)]
+    fake_client = setup_fake_quantum_run_stream_client(
+        client_constructor, responses_and_exceptions=mock_responses
+    )
+
+    code = any_pb2.Any()
+    run_context = any_pb2.Any()
+    labels = {'hello': 'world'}
+    client = EngineClient()
+
+    actual_result = client.run_job_over_stream(
+        'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
+    ).result(timeout=5)
+    client.stop_stream()
+
+    assert actual_result == expected_result
+    assert fake_client.stream_request_count == 1
+
+
+@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
+def test_run_job_over_stream_cancel_expects_engine_cancellation_rpc_call(client_constructor):
+    fake_client = setup_fake_quantum_run_stream_client(
+        client_constructor, responses_and_exceptions=[]
+    )
+
+    code = any_pb2.Any()
+    run_context = any_pb2.Any()
+    labels = {'hello': 'world'}
+    client = EngineClient()
+
+    result_future = client.run_job_over_stream(
+        'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
+    )
+    result_future.cancel()
+    client.stop_stream()
+
+    assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest(
+        name='projects/proj/programs/prog/jobs/job0'
+    )
+
+
+@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
+def test_run_job_over_stream_stream_broken_expects_retry_with_get_quantum_result(
+    client_constructor,
+):
+    expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
+    mock_responses_and_exceptions = [
+        exceptions.Aborted('aborted'),
+        quantum.QuantumRunStreamResponse(result=expected_result),
+    ]
+    fake_client = setup_fake_quantum_run_stream_client(
+        client_constructor, responses_and_exceptions=mock_responses_and_exceptions
+    )
+
+    code = any_pb2.Any()
+    run_context = any_pb2.Any()
+    labels = {'hello': 'world'}
+    client = EngineClient()
+
+    actual_result = client.run_job_over_stream(
+        'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
+    ).result(timeout=1)
+    client.stop_stream()
+
+    assert actual_result == expected_result
+    assert fake_client.stream_request_count == 2