From 5b7147f831f99a3f43bdce0c9ed47bf4a0ef10ce Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Tue, 20 Jun 2023 21:15:21 +0000 Subject: [PATCH 01/16] Stream client prototype --- .../cirq_google/engine/engine_client.py | 190 +++++++++++++++++- 1 file changed, 188 insertions(+), 2 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index a3c93b72a1e..3b959093740 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -18,6 +18,7 @@ import threading from typing import ( AsyncIterable, + AsyncIterator, Awaitable, Callable, Dict, @@ -42,6 +43,8 @@ _M = TypeVar('_M', bound=proto.Message) _R = TypeVar('_R') +JobPath = str +MessageId = str class EngineException(Exception): @@ -95,6 +98,136 @@ def instance(cls): return cls._instance +class ResponseDemux: + """A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern. + Args: + cancel_callback: Function to be called when the future matching its request argument is + canceled. + """ + + def __init__(self): + self._subscribers: Dict[JobPath, Tuple[MessageId, duet.AwaitableFuture]] = {} + self._next_available_message_id = 0 + + def subscribe( + self, request: quantum.QuantumRunStreamRequest + ) -> duet.AwaitableFuture[quantum.QuantumRunStreamResponse]: + """Assumes the message ID has not been set.""" + + if 'create_quantum_program_and_job' in request: + job_path = request.create_quantum_program_and_job.quantum_job.name + elif 'create_quantum_job' in request: + job_path = request.create_quantum_job.quantum_job.name + else: # 'get_quantum_result' in request + job_path = request.get_quantum_result.parent + + request.message_id = self._next_available_message_id + response_future = duet.AwaitableFuture[quantum.QuantumRunStreamResponse]() + self._subscribers[job_path] = (self._next_available_message_id, response_future) + self._next_available_message_id += 1 + return response_future + + def publish(self, response: quantum.QuantumRunStreamResponse) -> None: + if 'error' in response: + job_path = next( + ( + p + for p, (message_id, _) in self._subscribers.items() + if message_id == response.message_id + ), + default='', + ) + elif 'job' in response: + job_path = response.job.name + else: # 'result' in response + job_path = response.result.parent + + if job_path not in self._subscribers: + return + + self._subscribers[job_path].try_set_result(response) + del self._subscribers[job_path] + + def publish_exception(self, exception: GoogleAPICallError) -> None: + """Publishes an exception to all outstanding futures.""" + for _, future in self._subscribers.values(): + future.try_set_exception(exception) + self._subscribers = {} + + +class StreamManager: + def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): + self._grpc_client = grpc_client + self._request_queue = asyncio.Queue() + self._response_demux = ResponseDemux() + self._manage_stream_loop_running = False + + def _executor(self) -> AsyncioExecutor: + # We must re-use a single Executor due to multi-threading issues in gRPC + # clients: https://github.com/grpc/grpc/issues/25364. + return AsyncioExecutor.instance() + + def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]: + async def iterator(): + yield await self._request_queue.get() + + # TODO how to make an iterator properly? 
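+        # (One way to do this, and the shape later patches in this series
+        # settle on: make `iterator` a long-running async generator,
+        #
+        #     async def iterator():
+        #         while True:
+        #             yield await self._request_queue.get()
+        #
+        # and pass the generator object `iterator()` to quantum_run_stream()
+        # rather than the function itself.)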
+ return iterator + + async def _manage_stream(self): + """Keeps the stream alive and routes responses to the appropriate request handler""" + while True: + try: + response_iterable = self._grpc_client.quantum_run_stream(self._request_iterator) + async for response in response_iterable: + self._response_demux.publish(response) + except GoogleAPICallError as e: # TODO what's the right error to check here? + self._response_demux.publish_exception(e) + + async def _run_program( + self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob + ) -> quantum.QuantumResult: + """This method is executed in a separate asyncio Task for each request.""" + create_program_and_job_request = quantum.QuantumRunStreamRequest( + parent=project_name, + create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest( + parent=project_name, quantum_program=program, quantum_job=job + ), + ) + get_result_request = quantum.QuantumRunStreamRequest( + parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name) + ) + + response_future = self._response_demux.subscribe(create_program_and_job_request) + await self._request_queue.put(create_program_and_job_request) + + response: Optional[quantum.QuantumRunStreamResponse] = None + while response is None: + try: + response = await response_future + except GoogleAPICallError: + response_future = self._response_demux.subscribe(get_result_request) + await self._request_queue.put(get_result_request) + + if response.result is not None: + return response.result + # TODO handle QuantumJob response and retryable StreamError. + + async def _cancel(self, job_name: str) -> None: + await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name)) + + def send( + self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob + ) -> duet.AwaitableFuture[quantum.QuantumResult]: + """Sends a request over the stream and returns a future for the result.""" + if not self._manage_stream_loop_running: + self._executor.submit(self._manage_stream) + result_future = self._executor.submit(self._run_program, project_name, program, job) + result_future.add_done_callback(lambda _: self._executor.submit(self._cancel, job.name)) + # TODO will asyncio run_program task terminate when future is cancelled? + return result_future + + class EngineClient: """Client for the Quantum Engine API handling protos and gRPC client. 
@@ -148,6 +281,10 @@ async def make_client(): return self._executor.submit(make_client).result() + @cached_property + def stream_manager(self) -> StreamManager: + return StreamManager(self.grpc_client) + async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R: """Sends a request by invoking an asyncio callable.""" return await self._run_retry_async(func, request) @@ -277,7 +414,7 @@ async def list_programs_async( val = _date_or_time_to_filter_expr('created_before', created_before) filters.append(f"create_time <= {val}") if has_labels is not None: - for (k, v) in has_labels.items(): + for k, v in has_labels.items(): filters.append(f"labels.{k}:{v}") request = quantum.ListQuantumProgramsRequest( parent=_project_name(project_id), filter=" AND ".join(filters) @@ -528,7 +665,7 @@ async def list_jobs_async( val = _date_or_time_to_filter_expr('created_before', created_before) filters.append(f"create_time <= {val}") if has_labels is not None: - for (k, v) in has_labels.items(): + for k, v in has_labels.items(): filters.append(f"labels.{k}:{v}") if execution_states is not None: state_filter = [] @@ -744,6 +881,55 @@ async def get_job_results_async( get_job_results = duet.sync(get_job_results_async) + def run_job_over_stream( + self, + project_id: str, + program_id: str, + job_id: Optional[str], + processor_ids: Sequence[str], + run_context: any_pb2.Any, + priority: Optional[int] = None, + description: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + ) -> duet.AwaitableFuture[quantum.QuantumResult]: + # Check program to run and program parameters. + if priority and not 0 <= priority < 1000: + raise ValueError('priority must be between 0 and 1000') + + # Create job. + job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else '' + job = quantum.QuantumJob( + name=job_name, + scheduling_config=quantum.SchedulingConfig( + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor_names=[ + _processor_name_from_ids(project_id, processor_id) + for processor_id in processor_ids + ] + ) + ), + run_context=run_context, + ) + if priority: + job.scheduling_config.priority = priority + if description: + job.description = description + if labels: + job.labels.update(labels) + job_request = quantum.CreateQuantumJobRequest( + parent=_program_name_from_ids(project_id, program_id), + quantum_job=job, + overwrite_existing_run_context=False, + ) + stream_request = quantum.QuantumRunStreamRequest( + message_id=self._msg_id_generator.generate(), + parent=_project_name(project_id), + create_quantum_job=job_request, + ) + return self.stream_manager.send(stream_request) + + # TODO NEXT UP: change to sending over QuantumProgram instead... + async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]: """Returns a list of Processors that the user has visibility to in the current Engine project. 
The names of these processors are used to From d69d5ca3313bc8685d84dd925a62672245ff2e95 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Thu, 22 Jun 2023 23:51:36 +0000 Subject: [PATCH 02/16] Improved cancellation; added method to stop the stream manager loop --- .../cirq_google/engine/engine_client.py | 105 +++++++++++------- 1 file changed, 66 insertions(+), 39 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 3b959093740..c1f66860dbd 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -31,6 +31,7 @@ Union, ) import warnings +from concurrent.futures import Future import duet import proto @@ -99,11 +100,7 @@ def instance(cls): class ResponseDemux: - """A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern. - Args: - cancel_callback: Function to be called when the future matching its request argument is - canceled. - """ + """A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern.""" def __init__(self): self._subscribers: Dict[JobPath, Tuple[MessageId, duet.AwaitableFuture]] = {} @@ -113,20 +110,18 @@ def subscribe( self, request: quantum.QuantumRunStreamRequest ) -> duet.AwaitableFuture[quantum.QuantumRunStreamResponse]: """Assumes the message ID has not been set.""" - - if 'create_quantum_program_and_job' in request: - job_path = request.create_quantum_program_and_job.quantum_job.name - elif 'create_quantum_job' in request: - job_path = request.create_quantum_job.quantum_job.name - else: # 'get_quantum_result' in request - job_path = request.get_quantum_result.parent - + job_path = _get_job_path_from_stream_request(request) request.message_id = self._next_available_message_id response_future = duet.AwaitableFuture[quantum.QuantumRunStreamResponse]() self._subscribers[job_path] = (self._next_available_message_id, response_future) self._next_available_message_id += 1 return response_future + def unsubscribe(self, request: quantum.QuantumRunStreamRequest) -> None: + job_path = _get_job_path_from_stream_request(request) + if job_path in self._subscribers: + del self._subscribers[job_path] + def publish(self, response: quantum.QuantumRunStreamResponse) -> None: if 'error' in response: job_path = next( @@ -160,7 +155,7 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): self._grpc_client = grpc_client self._request_queue = asyncio.Queue() self._response_demux = ResponseDemux() - self._manage_stream_loop_running = False + self._manage_stream_loop_future: Optional[duet.AwaitableFuture] = None def _executor(self) -> AsyncioExecutor: # We must re-use a single Executor due to multi-threading issues in gRPC @@ -169,26 +164,29 @@ def _executor(self) -> AsyncioExecutor: def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]: async def iterator(): - yield await self._request_queue.get() + while not self._request_queue.empty(): + yield await self._request_queue.get() - # TODO how to make an iterator properly? 
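+        # (Behavior note: this generator exits as soon as the queue is empty,
+        # which half-closes the stream; a later patch in this series makes the
+        # iterator long-running with `while True` for that reason.)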
return iterator async def _manage_stream(self): """Keeps the stream alive and routes responses to the appropriate request handler""" while True: try: + # TODO specify stream timeout below with exponential backoff response_iterable = self._grpc_client.quantum_run_stream(self._request_iterator) async for response in response_iterable: self._response_demux.publish(response) - except GoogleAPICallError as e: # TODO what's the right error to check here? - self._response_demux.publish_exception(e) + except BaseException as e: + # TODO send halfclose to the stream upon CancelledError? How does that work? + self._response_demux.publish_exception(e) # Raise to all request tasks + self._request_iterator = asyncio.Queue() # Clear requests - async def _run_program( + async def _make_request( self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob ) -> quantum.QuantumResult: """This method is executed in a separate asyncio Task for each request.""" - create_program_and_job_request = quantum.QuantumRunStreamRequest( + current_request = quantum.QuantumRunStreamRequest( parent=project_name, create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest( parent=project_name, quantum_program=program, quantum_job=job @@ -197,21 +195,31 @@ async def _run_program( get_result_request = quantum.QuantumRunStreamRequest( parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name) ) + response_future = None - response_future = self._response_demux.subscribe(create_program_and_job_request) - await self._request_queue.put(create_program_and_job_request) - - response: Optional[quantum.QuantumRunStreamResponse] = None - while response is None: - try: - response = await response_future - except GoogleAPICallError: - response_future = self._response_demux.subscribe(get_result_request) - await self._request_queue.put(get_result_request) - - if response.result is not None: - return response.result - # TODO handle QuantumJob response and retryable StreamError. + try: + response_future = self._response_demux.subscribe(current_request) + await self._request_queue.put(current_request) + + response: Optional[quantum.QuantumRunStreamResponse] = None + while response is None: + try: + response = await response_future + except GoogleAPICallError: + # TODO how to distinguish between program not found vs job not found? + # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or + # job doesn't exist. + # TODO add exponential backoff + current_request = get_result_request + + if response.result is not None: + return response.result + # TODO handle QuantumJob response and retryable StreamError. + + except asyncio.CancelledError: + if response_future is not None: + response_future.cancel() + self._response_demux.unsubscribe(current_request) async def _cancel(self, job_name: str) -> None: await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name)) @@ -220,13 +228,23 @@ def send( self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob ) -> duet.AwaitableFuture[quantum.QuantumResult]: """Sends a request over the stream and returns a future for the result.""" - if not self._manage_stream_loop_running: - self._executor.submit(self._manage_stream) - result_future = self._executor.submit(self._run_program, project_name, program, job) - result_future.add_done_callback(lambda _: self._executor.submit(self._cancel, job.name)) - # TODO will asyncio run_program task terminate when future is cancelled? 
+ if self._manage_stream_loop_future is None: + self._manage_stream_loop_future = self._executor.submit(self._manage_stream) + result_future = self._executor.submit(self._make_request, project_name, program, job) + + def cancel(future: Future): + if future.cancelled(): + self._executor.submit(self._cancel, job.name) + + result_future.add_done_callback(cancel) return result_future + def stop(self) -> None: + """Stops and resets the stream manager.""" + if self._manage_stream_loop_future is not None: + self._manage_stream_loop_future.cancel() + self._manage_stream_loop_future = None + class EngineClient: """Client for the Quantum Engine API handling protos and gRPC client. @@ -1323,3 +1341,12 @@ def _date_or_time_to_filter_expr(param_name: str, param: Union[datetime.datetime f"type {type(param)}. Supported types: datetime.datetime and" f"datetime.date" ) + + +def _get_job_path_from_stream_request(request: quantum.QuantumRunStreamRequest) -> str: + if 'create_quantum_program_and_job' in request: + return request.create_quantum_program_and_job.quantum_job.name + elif 'create_quantum_job' in request: + return request.create_quantum_job.quantum_job.name + # 'get_quantum_result' in request + return request.get_quantum_result.parent From cfa8a528deb2b3f533bb746b271c9ddd24a4fe33 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 23 Jun 2023 00:14:59 +0000 Subject: [PATCH 03/16] Moved make_request task cancellation handler to be inside the task --- cirq-google/cirq_google/engine/engine_client.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index c1f66860dbd..8f30a7d5473 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -216,10 +216,12 @@ async def _make_request( return response.result # TODO handle QuantumJob response and retryable StreamError. + # Either when this request is canceled or the _manage_stream() loop is canceled. 
except asyncio.CancelledError: if response_future is not None: response_future.cancel() self._response_demux.unsubscribe(current_request) + await self._cancel(job.name) async def _cancel(self, job_name: str) -> None: await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name)) @@ -230,14 +232,7 @@ def send( """Sends a request over the stream and returns a future for the result.""" if self._manage_stream_loop_future is None: self._manage_stream_loop_future = self._executor.submit(self._manage_stream) - result_future = self._executor.submit(self._make_request, project_name, program, job) - - def cancel(future: Future): - if future.cancelled(): - self._executor.submit(self._cancel, job.name) - - result_future.add_done_callback(cancel) - return result_future + return self._executor.submit(self._make_request, project_name, program, job) def stop(self) -> None: """Stops and resets the stream manager.""" From 658f1b3a92e6150982045ff17aca608dbaaa364e Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 23 Jun 2023 01:02:06 +0000 Subject: [PATCH 04/16] Fix the call to send() from run_job_over_stream() --- .../cirq_google/engine/engine_client.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 8f30a7d5473..6921d947bb3 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -898,6 +898,7 @@ def run_job_over_stream( self, project_id: str, program_id: str, + code: any_pb2.Any, job_id: Optional[str], processor_ids: Sequence[str], run_context: any_pb2.Any, @@ -909,7 +910,15 @@ def run_job_over_stream( if priority and not 0 <= priority < 1000: raise ValueError('priority must be between 0 and 1000') - # Create job. + project_name = _project_name(project_id) + + program_name = _program_name_from_ids(project_id, program_id) if program_id else '' + program = quantum.QuantumProgram(name=program_name, code=code) + if description: + program.description = description + if labels: + program.labels.update(labels) + job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else '' job = quantum.QuantumJob( name=job_name, @@ -929,19 +938,8 @@ def run_job_over_stream( job.description = description if labels: job.labels.update(labels) - job_request = quantum.CreateQuantumJobRequest( - parent=_program_name_from_ids(project_id, program_id), - quantum_job=job, - overwrite_existing_run_context=False, - ) - stream_request = quantum.QuantumRunStreamRequest( - message_id=self._msg_id_generator.generate(), - parent=_project_name(project_id), - create_quantum_job=job_request, - ) - return self.stream_manager.send(stream_request) - # TODO NEXT UP: change to sending over QuantumProgram instead... 
+ return self.stream_manager.send(project_name, program, job) async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]: """Returns a list of Processors that the user has visibility to in the From b2e0b114d1fd493bb604ef8c210c3bdcff2998ab Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 23 Jun 2023 18:54:39 +0000 Subject: [PATCH 05/16] [WIP] Testing --- .../cirq_google/engine/engine_client.py | 7 ++- .../cirq_google/engine/engine_client_test.py | 53 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 6921d947bb3..402568f77d3 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -295,7 +295,7 @@ async def make_client(): return self._executor.submit(make_client).result() @cached_property - def stream_manager(self) -> StreamManager: + def _stream_manager(self) -> StreamManager: return StreamManager(self.grpc_client) async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R: @@ -939,7 +939,10 @@ def run_job_over_stream( if labels: job.labels.update(labels) - return self.stream_manager.send(project_name, program, job) + return self._stream_manager.send(project_name, program, job) + + def stop_stream(self): + self._stream_manager.stop() async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]: """Returns a list of Processors that the user has visibility to in the diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index e45127eb21c..f69386bf2ce 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for EngineClient.""" +from typing import AsyncIterable, AsyncIterator, Awaitable + import asyncio import datetime from unittest import mock @@ -26,6 +28,7 @@ from cirq_google.engine.engine_client import EngineClient, EngineException from cirq_google.engine.test_utils import uses_async_mock from cirq_google.cloud import quantum +from cirq_google.cloud.quantum_v1alpha1.types import engine def setup_mock_(client_constructor): @@ -34,6 +37,33 @@ def setup_mock_(client_constructor): return grpc_client +def setup_fake_quantum_run_stream_client(client_constructor): + grpc_client = _FakeQuantumRunStream() + client_constructor.return_value = grpc_client + return grpc_client + + +class _FakeQuantumRunStream: + def __init__(self): + self.request_count = 0 + + def set_response_list(self, response_list): + self._response_list = response_list + + async def quantum_run_stream( + self, requests: AsyncIterator[engine.QuantumRunStreamRequest] = None, **kwargs + ) -> AsyncIterable[engine.QuantumRunStreamResponse]: + async def run_async_iterator(): + logger.warning('Server: Waiting for requests...') + async for _ in requests: + logger.warning('Server: Got a request, yielding response next') + self.request_count += 1 + yield self._response_list.pop(0) + + await asyncio.sleep(0.0001) + return run_async_iterator() + + @uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_program(client_constructor): @@ -1229,3 +1259,26 @@ def test_list_time_slots(client_constructor): client = EngineClient() assert client.list_time_slots('proj', 'processor0') == results + + +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) +def test_run_job_over_stream(client_constructor): + fake_client = setup_fake_quantum_run_stream_client(client_constructor) + + code = any_pb2.Any() + run_context = any_pb2.Any() + labels = {'hello': 'world'} + client = EngineClient() + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] + fake_client.set_response_list(mock_responses) + + actual_result = client.run_job_over_stream( + 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels + ).result(timeout=5) + client.stop_stream() + + # TODO(verult) test that response listener message IDs are being deleted + + assert actual_result == expected_result + assert fake_client.request_count == 1 From 2f2263d1b6e1e705a6c26bd821be3aa76bf1a19e Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 23 Jun 2023 21:27:47 +0000 Subject: [PATCH 06/16] Bug fix: make_request didn't subscribe after retrying --- .../cirq_google/engine/engine_client.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 402568f77d3..0dd1ba3762c 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -195,33 +195,33 @@ async def _make_request( get_result_request = quantum.QuantumRunStreamRequest( parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name) ) - response_future = None - try: - response_future = self._response_demux.subscribe(current_request) - await self._request_queue.put(current_request) - - response: Optional[quantum.QuantumRunStreamResponse] = None - while response is None: - try: - response = await 
response_future - except GoogleAPICallError: - # TODO how to distinguish between program not found vs job not found? - # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or - # job doesn't exist. - # TODO add exponential backoff - current_request = get_result_request - - if response.result is not None: - return response.result - # TODO handle QuantumJob response and retryable StreamError. - - # Either when this request is canceled or the _manage_stream() loop is canceled. - except asyncio.CancelledError: - if response_future is not None: - response_future.cancel() - self._response_demux.unsubscribe(current_request) - await self._cancel(job.name) + response_future: Optional[duet.AwaitableFuture[quantum.QuantumRunStreamResponse]] = None + response: Optional[quantum.QuantumRunStreamResponse] = None + while response is None: + try: + response_future = self._response_demux.subscribe(current_request) + await self._request_queue.put(current_request) + response = await response_future + + except GoogleAPICallError: + # TODO how to distinguish between program not found vs job not found? + # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or + # job doesn't exist. + # TODO add exponential backoff + current_request = get_result_request + + # Either when this request is canceled or the _manage_stream() loop is canceled. + except asyncio.CancelledError: + if response_future is not None: + response_future.cancel() + self._response_demux.unsubscribe(current_request) + await self._cancel(job.name) + return + + if response.result is not None: + return response.result + # TODO handle QuantumJob response and retryable StreamError. async def _cancel(self, job_name: str) -> None: await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name)) From f5295b5a2db1b4d20c593a449204f013e3fa58e3 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Sat, 24 Jun 2023 00:54:21 +0000 Subject: [PATCH 07/16] Change ResponseDemux to use asyncio.Future --- cirq-google/cirq_google/engine/engine_client.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 0dd1ba3762c..4def0b4ca33 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -103,16 +103,14 @@ class ResponseDemux: """A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern.""" def __init__(self): - self._subscribers: Dict[JobPath, Tuple[MessageId, duet.AwaitableFuture]] = {} + self._subscribers: Dict[JobPath, Tuple[MessageId, asyncio.Future]] = {} self._next_available_message_id = 0 - def subscribe( - self, request: quantum.QuantumRunStreamRequest - ) -> duet.AwaitableFuture[quantum.QuantumRunStreamResponse]: + def subscribe(self, request: quantum.QuantumRunStreamRequest) -> asyncio.Future: """Assumes the message ID has not been set.""" job_path = _get_job_path_from_stream_request(request) request.message_id = self._next_available_message_id - response_future = duet.AwaitableFuture[quantum.QuantumRunStreamResponse]() + response_future = asyncio.Future() self._subscribers[job_path] = (self._next_available_message_id, response_future) self._next_available_message_id += 1 return response_future @@ -140,13 +138,16 @@ def publish(self, response: quantum.QuantumRunStreamResponse) -> None: if job_path not in self._subscribers: return - 
self._subscribers[job_path].try_set_result(response) + _, future = self._subscribers[job_path] + if not future.done(): + future.set_result(response) del self._subscribers[job_path] def publish_exception(self, exception: GoogleAPICallError) -> None: """Publishes an exception to all outstanding futures.""" for _, future in self._subscribers.values(): - future.try_set_exception(exception) + if not future.done(): + future.set_exception(exception) self._subscribers = {} @@ -196,7 +197,7 @@ async def _make_request( parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name) ) - response_future: Optional[duet.AwaitableFuture[quantum.QuantumRunStreamResponse]] = None + response_future: Optional[asyncio.Future] = None response: Optional[quantum.QuantumRunStreamResponse] = None while response is None: try: From 748c591d5d7db947e5b4a907e8bdd42bda3c6257 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Sat, 24 Jun 2023 01:06:32 +0000 Subject: [PATCH 08/16] More bug fixes: * Message ID should be str in the request. * _executor and _request_iterator should be properties * quantum_run_stream() call needs to be awaited. --- .../cirq_google/engine/engine_client.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 4def0b4ca33..c9f9e0b3ed2 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -31,7 +31,6 @@ Union, ) import warnings -from concurrent.futures import Future import duet import proto @@ -109,9 +108,9 @@ def __init__(self): def subscribe(self, request: quantum.QuantumRunStreamRequest) -> asyncio.Future: """Assumes the message ID has not been set.""" job_path = _get_job_path_from_stream_request(request) - request.message_id = self._next_available_message_id + request.message_id = str(self._next_available_message_id) response_future = asyncio.Future() - self._subscribers[job_path] = (self._next_available_message_id, response_future) + self._subscribers[job_path] = (request.message_id, response_future) self._next_available_message_id += 1 return response_future @@ -158,30 +157,31 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): self._response_demux = ResponseDemux() self._manage_stream_loop_future: Optional[duet.AwaitableFuture] = None + @property def _executor(self) -> AsyncioExecutor: # We must re-use a single Executor due to multi-threading issues in gRPC # clients: https://github.com/grpc/grpc/issues/25364. 
return AsyncioExecutor.instance() - def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]: - async def iterator(): - while not self._request_queue.empty(): - yield await self._request_queue.get() - - return iterator + @property + async def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]: + while not self._request_queue.empty(): + yield await self._request_queue.get() async def _manage_stream(self): """Keeps the stream alive and routes responses to the appropriate request handler""" while True: try: # TODO specify stream timeout below with exponential backoff - response_iterable = self._grpc_client.quantum_run_stream(self._request_iterator) + response_iterable = await self._grpc_client.quantum_run_stream( + self._request_iterator + ) async for response in response_iterable: self._response_demux.publish(response) except BaseException as e: # TODO send halfclose to the stream upon CancelledError? How does that work? self._response_demux.publish_exception(e) # Raise to all request tasks - self._request_iterator = asyncio.Queue() # Clear requests + self._request_queue = asyncio.Queue() # Clear requests async def _make_request( self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob From 7f2295b8e40551fe631ccf8c957e230e7c969fa0 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Sat, 24 Jun 2023 01:08:18 +0000 Subject: [PATCH 09/16] Basic test of successful path --- cirq-google/cirq_google/engine/engine_client_test.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index f69386bf2ce..40a58943c2f 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for EngineClient.""" -from typing import AsyncIterable, AsyncIterator, Awaitable +from typing import AsyncIterable, AsyncIterator import asyncio import datetime @@ -54,9 +54,7 @@ async def quantum_run_stream( self, requests: AsyncIterator[engine.QuantumRunStreamRequest] = None, **kwargs ) -> AsyncIterable[engine.QuantumRunStreamResponse]: async def run_async_iterator(): - logger.warning('Server: Waiting for requests...') async for _ in requests: - logger.warning('Server: Got a request, yielding response next') self.request_count += 1 yield self._response_list.pop(0) @@ -1278,7 +1276,5 @@ def test_run_job_over_stream(client_constructor): ).result(timeout=5) client.stop_stream() - # TODO(verult) test that response listener message IDs are being deleted - assert actual_result == expected_result assert fake_client.request_count == 1 From 4c2ed33b57615416ddaa01fb58dc907ff43a7efa Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Sat, 24 Jun 2023 01:33:15 +0000 Subject: [PATCH 10/16] Cancellation test --- .../cirq_google/engine/engine_client_test.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index 40a58943c2f..ff1e392384f 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -45,7 +45,8 @@ def setup_fake_quantum_run_stream_client(client_constructor): class _FakeQuantumRunStream: def __init__(self): - self.request_count = 0 + self.stream_request_count = 0 + self.cancel_requests = [] def set_response_list(self, response_list): self._response_list = response_list @@ -55,12 +56,18 @@ async def quantum_run_stream( ) -> AsyncIterable[engine.QuantumRunStreamResponse]: async def run_async_iterator(): async for _ in requests: - self.request_count += 1 + self.stream_request_count += 1 + while not self._response_list: + await asyncio.sleep(1) yield self._response_list.pop(0) await asyncio.sleep(0.0001) return run_async_iterator() + async def cancel_quantum_job(self, request: engine.CancelQuantumJobRequest) -> None: + self.cancel_requests.append(request) + await asyncio.sleep(0.0001) + @uses_async_mock @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) @@ -1277,4 +1284,27 @@ def test_run_job_over_stream(client_constructor): client.stop_stream() assert actual_result == expected_result - assert fake_client.request_count == 1 + assert fake_client.stream_request_count == 1 + + +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) +def test_run_job_over_stream_cancellation(client_constructor): + fake_client = setup_fake_quantum_run_stream_client(client_constructor) + + code = any_pb2.Any() + run_context = any_pb2.Any() + labels = {'hello': 'world'} + client = EngineClient() + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] + fake_client.set_response_list(mock_responses) + + result_future = client.run_job_over_stream( + 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels + ) + result_future.cancel() + client.stop_stream() + + assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest( + name='projects/proj/programs/prog/jobs/job0' + ) From d54b10fb754da8e6ccccca0e76241ba4f5c1976f Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Sun, 25 Jun 2023 02:06:13 +0000 
Subject: [PATCH 11/16] Bugfix: response.result throws when GoogleAPICallError is caught --- cirq-google/cirq_google/engine/engine_client.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index c9f9e0b3ed2..a769ef3d7ee 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -204,9 +204,12 @@ async def _make_request( response_future = self._response_demux.subscribe(current_request) await self._request_queue.put(current_request) response = await response_future + if response.result is not None: + return response.result except GoogleAPICallError: # TODO how to distinguish between program not found vs job not found? + # TODO handle QuantumJob response and retryable StreamError. # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or # job doesn't exist. # TODO add exponential backoff @@ -220,10 +223,6 @@ async def _make_request( await self._cancel(job.name) return - if response.result is not None: - return response.result - # TODO handle QuantumJob response and retryable StreamError. - async def _cancel(self, job_name: str) -> None: await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name)) From 9a65eb0072c898d9b72d8844ce690ff1cb36524f Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Sun, 25 Jun 2023 02:08:08 +0000 Subject: [PATCH 12/16] Stream break test --- .../cirq_google/engine/engine_client_test.py | 46 +++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index ff1e392384f..04cfd472601 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for EngineClient.""" -from typing import AsyncIterable, AsyncIterator +from typing import AsyncIterable, AsyncIterator, List import asyncio import datetime @@ -47,9 +47,12 @@ class _FakeQuantumRunStream: def __init__(self): self.stream_request_count = 0 self.cancel_requests = [] + self.responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException] = [] - def set_response_list(self, response_list): - self._response_list = response_list + def add_responses_and_exceptions( + self, responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException] + ): + self.responses_and_exceptions.extend(responses_and_exceptions) async def quantum_run_stream( self, requests: AsyncIterator[engine.QuantumRunStreamRequest] = None, **kwargs @@ -57,9 +60,12 @@ async def quantum_run_stream( async def run_async_iterator(): async for _ in requests: self.stream_request_count += 1 - while not self._response_list: - await asyncio.sleep(1) - yield self._response_list.pop(0) + while not self.responses_and_exceptions: + await asyncio.sleep(0.1) + response_or_exception = self.responses_and_exceptions.pop(0) + if isinstance(response_or_exception, BaseException): + raise response_or_exception + yield response_or_exception await asyncio.sleep(0.0001) return run_async_iterator() @@ -1276,7 +1282,7 @@ def test_run_job_over_stream(client_constructor): client = EngineClient() expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] - fake_client.set_response_list(mock_responses) + fake_client.add_responses_and_exceptions(mock_responses) actual_result = client.run_job_over_stream( 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels @@ -1297,7 +1303,7 @@ def test_run_job_over_stream_cancellation(client_constructor): client = EngineClient() expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] - fake_client.set_response_list(mock_responses) + fake_client.add_responses_and_exceptions(mock_responses) result_future = client.run_job_over_stream( 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels @@ -1308,3 +1314,27 @@ def test_run_job_over_stream_cancellation(client_constructor): assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest( name='projects/proj/programs/prog/jobs/job0' ) + + +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) +def test_run_job_over_stream_stream_broken(client_constructor): + fake_client = setup_fake_quantum_run_stream_client(client_constructor) + + code = any_pb2.Any() + run_context = any_pb2.Any() + labels = {'hello': 'world'} + client = EngineClient() + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + mock_responses_and_exceptions = [ + exceptions.Aborted('aborted'), + quantum.QuantumRunStreamResponse(message_id='0', result=expected_result), + ] + fake_client.add_responses_and_exceptions(mock_responses_and_exceptions) + + actual_result = client.run_job_over_stream( + 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels + ).result(timeout=5) + client.stop_stream() + + assert actual_result == expected_result + assert fake_client.stream_request_count == 2 From 077850114ef0f4bb1b85d132bc2703219c04350a Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 28 Jun 2023 
21:53:24 +0000 Subject: [PATCH 13/16] Typecheck fixes --- cirq-google/cirq_google/engine/engine_client.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index a769ef3d7ee..43e7e64cc03 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -109,7 +109,7 @@ def subscribe(self, request: quantum.QuantumRunStreamRequest) -> asyncio.Future: """Assumes the message ID has not been set.""" job_path = _get_job_path_from_stream_request(request) request.message_id = str(self._next_available_message_id) - response_future = asyncio.Future() + response_future: asyncio.Future = asyncio.Future() self._subscribers[job_path] = (request.message_id, response_future) self._next_available_message_id += 1 return response_future @@ -153,7 +153,7 @@ def publish_exception(self, exception: GoogleAPICallError) -> None: class StreamManager: def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): self._grpc_client = grpc_client - self._request_queue = asyncio.Queue() + self._request_queue: asyncio.Queue = asyncio.Queue() self._response_demux = ResponseDemux() self._manage_stream_loop_future: Optional[duet.AwaitableFuture] = None @@ -165,6 +165,7 @@ def _executor(self) -> AsyncioExecutor: @property async def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]: + # TODO might need to keep this long-running in order to not close the stream. while not self._request_queue.empty(): yield await self._request_queue.get() @@ -204,16 +205,14 @@ async def _make_request( response_future = self._response_demux.subscribe(current_request) await self._request_queue.put(current_request) response = await response_future - if response.result is not None: - return response.result except GoogleAPICallError: # TODO how to distinguish between program not found vs job not found? - # TODO handle QuantumJob response and retryable StreamError. # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or # job doesn't exist. # TODO add exponential backoff current_request = get_result_request + continue # Either when this request is canceled or the _manage_stream() loop is canceled. except asyncio.CancelledError: @@ -221,7 +220,11 @@ async def _make_request( response_future.cancel() self._response_demux.unsubscribe(current_request) await self._cancel(job.name) - return + return quantum.QuantumResult() + + if response.result is not None: + return response.result + # TODO handle QuantumJob response and retryable StreamError. 
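+        # (One possible shape for that handling, not implemented here: if the
+        # response carries a QuantumJob rather than a QuantumResult, presumably
+        # the job exists but has no result yet, so the task could resubscribe
+        # with `get_result_request` and await another response, much as the
+        # GoogleAPICallError branch above retries; a retryable StreamError
+        # would likewise re-enter the loop, ideally with backoff.)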
async def _cancel(self, job_name: str) -> None: await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name)) From 7300522d70c012805676953506ce1222dc4e7a17 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 28 Jun 2023 23:45:24 +0000 Subject: [PATCH 14/16] Addressed wcourtney's comments --- .../cirq_google/engine/engine_client.py | 10 ++-- .../cirq_google/engine/engine_client_test.py | 55 +++++++++++-------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 43e7e64cc03..1fa0fb488d1 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -180,7 +180,7 @@ async def _manage_stream(self): async for response in response_iterable: self._response_demux.publish(response) except BaseException as e: - # TODO send halfclose to the stream upon CancelledError? How does that work? + # TODO Close the request iterator to close the existing stream. self._response_demux.publish_exception(e) # Raise to all request tasks self._request_queue = asyncio.Queue() # Clear requests @@ -902,7 +902,7 @@ def run_job_over_stream( project_id: str, program_id: str, code: any_pb2.Any, - job_id: Optional[str], + job_id: Optional[str], # TODO make this non-optional. processor_ids: Sequence[str], run_context: any_pb2.Any, priority: Optional[int] = None, @@ -1347,5 +1347,7 @@ def _get_job_path_from_stream_request(request: quantum.QuantumRunStreamRequest) return request.create_quantum_program_and_job.quantum_job.name elif 'create_quantum_job' in request: return request.create_quantum_job.quantum_job.name - # 'get_quantum_result' in request - return request.get_quantum_result.parent + elif 'get_quantum_result' in request: + return request.get_quantum_result.parent + else: + raise ValueError(f'Unrecognized request type in request: {request}') diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index 04cfd472601..80e068e07c3 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -37,17 +37,19 @@ def setup_mock_(client_constructor): return grpc_client -def setup_fake_quantum_run_stream_client(client_constructor): - grpc_client = _FakeQuantumRunStream() +def setup_fake_quantum_run_stream_client(client_constructor, responses_and_exceptions): + grpc_client = _FakeQuantumRunStream(responses_and_exceptions) client_constructor.return_value = grpc_client return grpc_client class _FakeQuantumRunStream: - def __init__(self): + def __init__( + self, responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException] + ): self.stream_request_count = 0 - self.cancel_requests = [] - self.responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException] = [] + self.cancel_requests: List[engine.CancelQuantumJobRequest] = [] + self.responses_and_exceptions = responses_and_exceptions def add_responses_and_exceptions( self, responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException] @@ -61,13 +63,13 @@ async def run_async_iterator(): async for _ in requests: self.stream_request_count += 1 while not self.responses_and_exceptions: - await asyncio.sleep(0.1) + await asyncio.sleep(0) response_or_exception = self.responses_and_exceptions.pop(0) if isinstance(response_or_exception, BaseException): raise response_or_exception yield response_or_exception - await 
asyncio.sleep(0.0001) + await asyncio.sleep(0) return run_async_iterator() async def cancel_quantum_job(self, request: engine.CancelQuantumJobRequest) -> None: @@ -1273,16 +1275,17 @@ def test_list_time_slots(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) -def test_run_job_over_stream(client_constructor): - fake_client = setup_fake_quantum_run_stream_client(client_constructor) +def test_run_job_over_stream_send_job_expects_result_response(client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] + fake_client = setup_fake_quantum_run_stream_client( + client_constructor, responses_and_exceptions=mock_responses + ) code = any_pb2.Any() run_context = any_pb2.Any() labels = {'hello': 'world'} client = EngineClient() - expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] - fake_client.add_responses_and_exceptions(mock_responses) actual_result = client.run_job_over_stream( 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels @@ -1294,15 +1297,17 @@ def test_run_job_over_stream(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) -def test_run_job_over_stream_cancellation(client_constructor): - fake_client = setup_fake_quantum_run_stream_client(client_constructor) +def test_run_job_over_stream_cancel_expects_engine_cancellation_rpc_call(client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] + fake_client = setup_fake_quantum_run_stream_client( + client_constructor, responses_and_exceptions=mock_responses + ) code = any_pb2.Any() run_context = any_pb2.Any() labels = {'hello': 'world'} client = EngineClient() - expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] fake_client.add_responses_and_exceptions(mock_responses) result_future = client.run_job_over_stream( @@ -1317,18 +1322,22 @@ def test_run_job_over_stream_cancellation(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) -def test_run_job_over_stream_stream_broken(client_constructor): - fake_client = setup_fake_quantum_run_stream_client(client_constructor) - - code = any_pb2.Any() - run_context = any_pb2.Any() - labels = {'hello': 'world'} - client = EngineClient() +def test_run_job_over_stream_stream_broken_expects_retry_with_get_quantum_result( + client_constructor, +): expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') mock_responses_and_exceptions = [ exceptions.Aborted('aborted'), quantum.QuantumRunStreamResponse(message_id='0', result=expected_result), ] + fake_client = setup_fake_quantum_run_stream_client( + client_constructor, responses_and_exceptions=mock_responses_and_exceptions + ) + + code = any_pb2.Any() + run_context = any_pb2.Any() + labels = {'hello': 'world'} + client = EngineClient() fake_client.add_responses_and_exceptions(mock_responses_and_exceptions) actual_result = client.run_job_over_stream( From 8c8e901bd5049cd4dbeea400d04051bf8891afda Mon Sep 17 00:00:00 2001 From: 
Cheng Xing Date: Fri, 7 Jul 2023 21:55:31 +0000 Subject: [PATCH 15/16] Addressed maffoo's feedback --- .../cirq_google/engine/engine_client.py | 65 +++++++++---------- .../cirq_google/engine/engine_client_test.py | 29 ++++----- 2 files changed, 43 insertions(+), 51 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 1fa0fb488d1..898a6bba8c7 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from asyncio.log import logger import datetime import sys import threading @@ -102,49 +103,32 @@ class ResponseDemux: """A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern.""" def __init__(self): - self._subscribers: Dict[JobPath, Tuple[MessageId, asyncio.Future]] = {} + self._subscribers: Dict[MessageId, asyncio.Future] = {} self._next_available_message_id = 0 def subscribe(self, request: quantum.QuantumRunStreamRequest) -> asyncio.Future: """Assumes the message ID has not been set.""" - job_path = _get_job_path_from_stream_request(request) request.message_id = str(self._next_available_message_id) - response_future: asyncio.Future = asyncio.Future() - self._subscribers[job_path] = (request.message_id, response_future) + response_future: asyncio.Future = asyncio.get_running_loop().create_future() + self._subscribers[request.message_id] = response_future self._next_available_message_id += 1 return response_future def unsubscribe(self, request: quantum.QuantumRunStreamRequest) -> None: - job_path = _get_job_path_from_stream_request(request) - if job_path in self._subscribers: - del self._subscribers[job_path] + if request.message_id in self._subscribers: + del self._subscribers[request.message_id] def publish(self, response: quantum.QuantumRunStreamResponse) -> None: - if 'error' in response: - job_path = next( - ( - p - for p, (message_id, _) in self._subscribers.items() - if message_id == response.message_id - ), - default='', - ) - elif 'job' in response: - job_path = response.job.name - else: # 'result' in response - job_path = response.result.parent - - if job_path not in self._subscribers: + if response.message_id not in self._subscribers: return - _, future = self._subscribers[job_path] + future = self._subscribers.pop(response.message_id) if not future.done(): future.set_result(response) - del self._subscribers[job_path] def publish_exception(self, exception: GoogleAPICallError) -> None: """Publishes an exception to all outstanding futures.""" - for _, future in self._subscribers.values(): + for future in self._subscribers.values(): if not future.done(): future.set_exception(exception) self._subscribers = {} @@ -163,10 +147,11 @@ def _executor(self) -> AsyncioExecutor: # clients: https://github.com/grpc/grpc/issues/25364. return AsyncioExecutor.instance() - @property async def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]: - # TODO might need to keep this long-running in order to not close the stream. - while not self._request_queue.empty(): + """The request iterator for quantum_run_stream(). 
+ + Every call of this method generates a new iterator.""" + while True: yield await self._request_queue.get() async def _manage_stream(self): @@ -175,14 +160,16 @@ async def _manage_stream(self): try: # TODO specify stream timeout below with exponential backoff response_iterable = await self._grpc_client.quantum_run_stream( - self._request_iterator + self._request_iterator() ) async for response in response_iterable: + logger.warning('publishing response to demux') self._response_demux.publish(response) except BaseException as e: # TODO Close the request iterator to close the existing stream. self._response_demux.publish_exception(e) # Raise to all request tasks - self._request_queue = asyncio.Queue() # Clear requests + if isinstance(e, asyncio.CancelledError): + break async def _make_request( self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob @@ -202,15 +189,19 @@ async def _make_request( response: Optional[quantum.QuantumRunStreamResponse] = None while response is None: try: + logger.warn('Making request') response_future = self._response_demux.subscribe(current_request) await self._request_queue.put(current_request) response = await response_future + logger.warning('Got response') except GoogleAPICallError: # TODO how to distinguish between program not found vs job not found? # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or # job doesn't exist. # TODO add exponential backoff + logger.warn('Got GoogleAPICallError') + self._response_demux.unsubscribe(current_request) current_request = get_result_request continue @@ -223,6 +214,7 @@ async def _make_request( return quantum.QuantumResult() if response.result is not None: + logger.warning('Got result') return response.result # TODO handle QuantumJob response and retryable StreamError. @@ -237,11 +229,12 @@ def send( self._manage_stream_loop_future = self._executor.submit(self._manage_stream) return self._executor.submit(self._make_request, project_name, program, job) - def stop(self) -> None: + def stop(self) -> duet.AwaitableFuture[None]: """Stops and resets the stream manager.""" - if self._manage_stream_loop_future is not None: - self._manage_stream_loop_future.cancel() - self._manage_stream_loop_future = None + if self._manage_stream_loop_future is None: + return duet.completed_future(None) + self._manage_stream_loop_future.cancel() + return self._manage_stream_loop_future class EngineClient: @@ -944,8 +937,8 @@ def run_job_over_stream( return self._stream_manager.send(project_name, program, job) - def stop_stream(self): - self._stream_manager.stop() + def stop_stream(self) -> duet.AwaitableFuture[None]: + return self._stream_manager.stop() async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]: """Returns a list of Processors that the user has visibility to in the diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index 80e068e07c3..7bcbf6d8aa5 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for EngineClient.""" -from typing import AsyncIterable, AsyncIterator, List +from typing import AsyncIterable, AsyncIterator, Awaitable, List import asyncio +from asyncio.log import logger import datetime from unittest import mock @@ -58,15 +59,17 @@ def add_responses_and_exceptions( async def quantum_run_stream( self, requests: AsyncIterator[engine.QuantumRunStreamRequest] = None, **kwargs - ) -> AsyncIterable[engine.QuantumRunStreamResponse]: + ) -> Awaitable[AsyncIterable[engine.QuantumRunStreamResponse]]: async def run_async_iterator(): - async for _ in requests: + async for request in requests: self.stream_request_count += 1 while not self.responses_and_exceptions: await asyncio.sleep(0) + logger.warning('Responding') response_or_exception = self.responses_and_exceptions.pop(0) if isinstance(response_or_exception, BaseException): raise response_or_exception + response_or_exception.message_id = request.message_id yield response_or_exception await asyncio.sleep(0) @@ -74,7 +77,7 @@ async def run_async_iterator(): async def cancel_quantum_job(self, request: engine.CancelQuantumJobRequest) -> None: self.cancel_requests.append(request) - await asyncio.sleep(0.0001) + await asyncio.sleep(0) @uses_async_mock @@ -1277,7 +1280,7 @@ def test_list_time_slots(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_run_job_over_stream_send_job_expects_result_response(client_constructor): expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] + mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)] fake_client = setup_fake_quantum_run_stream_client( client_constructor, responses_and_exceptions=mock_responses ) @@ -1290,7 +1293,7 @@ def test_run_job_over_stream_send_job_expects_result_response(client_constructor actual_result = client.run_job_over_stream( 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels ).result(timeout=5) - client.stop_stream() + client.stop_stream().result assert actual_result == expected_result assert fake_client.stream_request_count == 1 @@ -1298,23 +1301,20 @@ def test_run_job_over_stream_send_job_expects_result_response(client_constructor @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_run_job_over_stream_cancel_expects_engine_cancellation_rpc_call(client_constructor): - expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - mock_responses = [quantum.QuantumRunStreamResponse(message_id='0', result=expected_result)] fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses + client_constructor, responses_and_exceptions=[] ) code = any_pb2.Any() run_context = any_pb2.Any() labels = {'hello': 'world'} client = EngineClient() - fake_client.add_responses_and_exceptions(mock_responses) result_future = client.run_job_over_stream( 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels ) result_future.cancel() - client.stop_stream() + client.stop_stream().result assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest( name='projects/proj/programs/prog/jobs/job0' @@ -1328,7 +1328,7 @@ def test_run_job_over_stream_stream_broken_expects_retry_with_get_quantum_result expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') mock_responses_and_exceptions = [ 
exceptions.Aborted('aborted'), - quantum.QuantumRunStreamResponse(message_id='0', result=expected_result), + quantum.QuantumRunStreamResponse(result=expected_result), ] fake_client = setup_fake_quantum_run_stream_client( client_constructor, responses_and_exceptions=mock_responses_and_exceptions @@ -1338,12 +1338,11 @@ def test_run_job_over_stream_stream_broken_expects_retry_with_get_quantum_result run_context = any_pb2.Any() labels = {'hello': 'world'} client = EngineClient() - fake_client.add_responses_and_exceptions(mock_responses_and_exceptions) actual_result = client.run_job_over_stream( 'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels - ).result(timeout=5) - client.stop_stream() + ).result(timeout=1) + client.stop_stream().result assert actual_result == expected_result assert fake_client.stream_request_count == 2 From 31ad529cb8854dea4247cb5deb53eed37b9459d5 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 7 Jul 2023 22:13:09 +0000 Subject: [PATCH 16/16] Add TODO to restrict response future scopes to just the relevant tasks --- cirq-google/cirq_google/engine/engine_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index e503e22b77a..f73945915cd 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -138,8 +138,10 @@ class StreamManager: def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): self._grpc_client = grpc_client self._request_queue: asyncio.Queue = asyncio.Queue() - self._response_demux = ResponseDemux() self._manage_stream_loop_future: Optional[duet.AwaitableFuture] = None + # TODO consider making the scope of response futures local to the relevant tasks rather than + # all of StreamManager. + self._response_demux = ResponseDemux() @property def _executor(self) -> AsyncioExecutor:
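PATCH 15 re-keys ResponseDemux subscribers by message ID instead of job path, so each response is matched to the exact request that carried that ID and unknown or stale IDs are simply ignored. The following is a minimal, self-contained sketch of that routing pattern on a plain asyncio event loop; MessageIdDemux and the other names are illustrative stand-ins, not cirq-google API.

import asyncio
from typing import Dict, Tuple


class MessageIdDemux:
    """Illustrative stand-in for ResponseDemux: one future per outstanding message ID."""

    def __init__(self) -> None:
        self._futures: Dict[str, asyncio.Future] = {}
        self._next_message_id = 0

    def subscribe(self) -> Tuple[str, asyncio.Future]:
        # Tag each request with a fresh message ID; the response is expected to echo it back.
        message_id = str(self._next_message_id)
        self._next_message_id += 1
        future = asyncio.get_running_loop().create_future()
        self._futures[message_id] = future
        return message_id, future

    def publish(self, message_id: str, payload: object) -> None:
        # Unknown or already-resolved IDs are ignored, mirroring the early return in the patch.
        future = self._futures.pop(message_id, None)
        if future is not None and not future.done():
            future.set_result(payload)

    def publish_exception(self, exception: BaseException) -> None:
        # A broken stream fails every outstanding request at once.
        for future in self._futures.values():
            if not future.done():
                future.set_exception(exception)
        self._futures = {}


async def main() -> None:
    demux = MessageIdDemux()
    message_id, future = demux.subscribe()
    # In the real client, the task reading the gRPC stream would call publish().
    demux.publish(message_id, 'response for message ' + message_id)
    print(await future)


asyncio.run(main())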
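PATCH 15 also makes the request iterator long-lived: quantum_run_stream() needs an iterator that blocks on the queue rather than returning as soon as the queue is momentarily empty, otherwise the bidirectional stream would close after the first burst of requests, and each stream attempt gets a fresh iterator. Below is a small sketch of that pattern, assuming plain strings stand in for QuantumRunStreamRequest protos; the names here are hypothetical.

import asyncio
from typing import AsyncIterator


async def request_iterator(queue: 'asyncio.Queue[str]') -> AsyncIterator[str]:
    # Every call creates a fresh iterator. It loops forever instead of stopping when
    # the queue is empty, so the stream stays open until the consumer stops reading.
    while True:
        yield await queue.get()


async def main() -> None:
    queue: 'asyncio.Queue[str]' = asyncio.Queue()
    for request in ('create_quantum_program_and_job', 'get_quantum_result', 'stop'):
        await queue.put(request)

    async for request in request_iterator(queue):
        print('sending', request)
        if request == 'stop':
            break


asyncio.run(main())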
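The broken-stream test at the end of PATCH 15 exercises the fallback in _make_request: when the stream fails before a result arrives, the pending future is unsubscribed and the request is retried as a GetQuantumResultRequest for the same job. Below is a simplified sketch of that control flow, with a stand-in StreamBroken exception and a fake send coroutine in place of the real gRPC stream; these names are assumptions for illustration only.

import asyncio


class StreamBroken(Exception):
    """Stand-in for GoogleAPICallError in this sketch."""


async def run_job(send, create_request: str, get_result_request: str) -> str:
    # Try the create-program-and-job request first; if the stream breaks before the
    # result arrives, retry by asking only for the result of the (possibly existing) job.
    current_request = create_request
    while True:
        try:
            return await send(current_request)
        except StreamBroken:
            current_request = get_result_request


async def main() -> None:
    sent = []

    async def fake_send(request: str) -> str:
        sent.append(request)
        if len(sent) == 1:
            raise StreamBroken('stream aborted')
        return 'result via ' + request

    print(await run_job(fake_send, 'create_quantum_program_and_job', 'get_quantum_result'))
    print('requests sent:', sent)


asyncio.run(main())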